diff --git a/packages/types/src/tool.ts b/packages/types/src/tool.ts index 4f90b63e9fc..6220ebb7077 100644 --- a/packages/types/src/tool.ts +++ b/packages/types/src/tool.ts @@ -46,6 +46,10 @@ export const toolNames = [ "skill", "generate_image", "custom_tool", + "go_to_definition", + "find_references", + "workspace_symbols", + "document_symbols", ] as const export const toolNamesSchema = z.enum(toolNames) diff --git a/packages/types/src/vscode-extension-host.ts b/packages/types/src/vscode-extension-host.ts index b20539afe49..a60f17a0ad7 100644 --- a/packages/types/src/vscode-extension-host.ts +++ b/packages/types/src/vscode-extension-host.ts @@ -783,6 +783,10 @@ export interface ClineSayTool { | "runSlashCommand" | "updateTodoList" | "skill" + | "goToDefinition" + | "findReferences" + | "workspaceSymbols" + | "documentSymbols" path?: string // For readCommandOutput readStart?: number diff --git a/src/core/assistant-message/NativeToolCallParser.ts b/src/core/assistant-message/NativeToolCallParser.ts index bda7c71eb8d..e8cfc45c3b4 100644 --- a/src/core/assistant-message/NativeToolCallParser.ts +++ b/src/core/assistant-message/NativeToolCallParser.ts @@ -538,6 +538,42 @@ export class NativeToolCallParser { } break + case "go_to_definition": + if (partialArgs.path !== undefined || partialArgs.line !== undefined) { + nativeArgs = { + path: partialArgs.path, + line: partialArgs.line, + character: partialArgs.character, + } + } + break + + case "find_references": + if (partialArgs.path !== undefined || partialArgs.line !== undefined) { + nativeArgs = { + path: partialArgs.path, + line: partialArgs.line, + character: partialArgs.character, + } + } + break + + case "workspace_symbols": + if (partialArgs.query !== undefined) { + nativeArgs = { + query: partialArgs.query, + } + } + break + + case "document_symbols": + if (partialArgs.path !== undefined) { + nativeArgs = { + path: partialArgs.path, + } + } + break + case "switch_mode": if (partialArgs.mode_slug !== undefined || partialArgs.reason !== undefined) { nativeArgs = { @@ -874,6 +910,42 @@ export class NativeToolCallParser { } break + case "go_to_definition": + if (args.path !== undefined && args.line !== undefined && args.character !== undefined) { + nativeArgs = { + path: args.path, + line: args.line, + character: args.character, + } as NativeArgsFor + } + break + + case "find_references": + if (args.path !== undefined && args.line !== undefined && args.character !== undefined) { + nativeArgs = { + path: args.path, + line: args.line, + character: args.character, + } as NativeArgsFor + } + break + + case "workspace_symbols": + if (args.query !== undefined) { + nativeArgs = { + query: args.query, + } as NativeArgsFor + } + break + + case "document_symbols": + if (args.path !== undefined) { + nativeArgs = { + path: args.path, + } as NativeArgsFor + } + break + case "switch_mode": if (args.mode_slug !== undefined && args.reason !== undefined) { nativeArgs = { diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 7f5862be154..46f1259207a 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -37,6 +37,10 @@ import { generateImageTool } from "../tools/GenerateImageTool" import { applyDiffTool as applyDiffToolClass } from "../tools/ApplyDiffTool" import { isValidToolName, validateToolUse } from "../tools/validateToolUse" import { codebaseSearchTool } from "../tools/CodebaseSearchTool" +import { goToDefinitionTool } from "../tools/GoToDefinitionTool" +import { findReferencesTool } from "../tools/FindReferencesTool" +import { workspaceSymbolsTool } from "../tools/WorkspaceSymbolsTool" +import { documentSymbolsTool } from "../tools/DocumentSymbolsTool" import { formatResponse } from "../prompts/responses" import { sanitizeToolUseId } from "../../utils/tool-id" @@ -383,6 +387,14 @@ export async function presentAssistantMessage(cline: Task) { return `[${block.name} for '${block.params.skill}'${block.params.args ? ` with args: ${block.params.args}` : ""}]` case "generate_image": return `[${block.name} for '${block.params.path}']` + case "go_to_definition": + return `[${block.name} for '${block.params.path}:${block.params.line}:${block.params.character}']` + case "find_references": + return `[${block.name} for '${block.params.path}:${block.params.line}:${block.params.character}']` + case "workspace_symbols": + return `[${block.name} for '${block.params.query}']` + case "document_symbols": + return `[${block.name} for '${block.params.path}']` default: return `[${block.name}]` } @@ -761,6 +773,34 @@ export async function presentAssistantMessage(cline: Task) { pushToolResult, }) break + case "go_to_definition": + await goToDefinitionTool.handle(cline, block as ToolUse<"go_to_definition">, { + askApproval, + handleError, + pushToolResult, + }) + break + case "find_references": + await findReferencesTool.handle(cline, block as ToolUse<"find_references">, { + askApproval, + handleError, + pushToolResult, + }) + break + case "workspace_symbols": + await workspaceSymbolsTool.handle(cline, block as ToolUse<"workspace_symbols">, { + askApproval, + handleError, + pushToolResult, + }) + break + case "document_symbols": + await documentSymbolsTool.handle(cline, block as ToolUse<"document_symbols">, { + askApproval, + handleError, + pushToolResult, + }) + break case "execute_command": await executeCommandTool.handle(cline, block as ToolUse<"execute_command">, { askApproval, diff --git a/src/core/prompts/tools/native-tools/document_symbols.ts b/src/core/prompts/tools/native-tools/document_symbols.ts new file mode 100644 index 00000000000..9c8b99d7f98 --- /dev/null +++ b/src/core/prompts/tools/native-tools/document_symbols.ts @@ -0,0 +1,29 @@ +import type OpenAI from "openai" + +const DOCUMENT_SYMBOLS_DESCRIPTION = `Request to list all symbols defined in a specific file. Uses the VS Code language server to enumerate all functions, classes, methods, variables, types, and other symbols in the document. Returns a flat list of symbols with their names, kinds, and line ranges. + +Parameters: +- path: (required) The file path (relative to the current workspace directory) to analyze. + +Example: Listing all symbols in a TypeScript file +{ "path": "src/services/auth.ts" }` + +export default { + type: "function", + function: { + name: "document_symbols", + description: DOCUMENT_SYMBOLS_DESCRIPTION, + strict: true, + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "File path relative to the workspace to analyze", + }, + }, + required: ["path"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/prompts/tools/native-tools/find_references.ts b/src/core/prompts/tools/native-tools/find_references.ts new file mode 100644 index 00000000000..d5e916a4b04 --- /dev/null +++ b/src/core/prompts/tools/native-tools/find_references.ts @@ -0,0 +1,39 @@ +import type OpenAI from "openai" + +const FIND_REFERENCES_DESCRIPTION = `Request to find all references to a symbol at a given position in a file. Uses the VS Code language server to locate every place a symbol (function, class, variable, type, etc.) is referenced across the workspace. Returns a list of file paths and positions. Results are capped at 50 locations to keep context size reasonable. + +Parameters: +- path: (required) The file path (relative to the current workspace directory) containing the symbol. +- line: (required) The 1-based line number of the symbol. +- character: (required) The 0-based character offset of the symbol on the line. + +Example: Finding all references to a function at line 10, character 15 +{ "path": "src/services/user.ts", "line": 10, "character": 15 }` + +export default { + type: "function", + function: { + name: "find_references", + description: FIND_REFERENCES_DESCRIPTION, + strict: true, + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "File path relative to the workspace containing the symbol", + }, + line: { + type: "number", + description: "1-based line number of the symbol", + }, + character: { + type: "number", + description: "0-based character offset of the symbol on the line", + }, + }, + required: ["path", "line", "character"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/prompts/tools/native-tools/go_to_definition.ts b/src/core/prompts/tools/native-tools/go_to_definition.ts new file mode 100644 index 00000000000..e7c220ff57d --- /dev/null +++ b/src/core/prompts/tools/native-tools/go_to_definition.ts @@ -0,0 +1,39 @@ +import type OpenAI from "openai" + +const GO_TO_DEFINITION_DESCRIPTION = `Request to find the definition of a symbol at a given position in a file. Uses the VS Code language server to locate where a symbol (function, class, variable, type, etc.) is defined. Returns the file path and position of the definition. + +Parameters: +- path: (required) The file path (relative to the current workspace directory) containing the symbol. +- line: (required) The 1-based line number of the symbol. +- character: (required) The 0-based character offset of the symbol on the line. + +Example: Finding the definition of a function call at line 15, character 8 +{ "path": "src/utils/auth.ts", "line": 15, "character": 8 }` + +export default { + type: "function", + function: { + name: "go_to_definition", + description: GO_TO_DEFINITION_DESCRIPTION, + strict: true, + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "File path relative to the workspace containing the symbol", + }, + line: { + type: "number", + description: "1-based line number of the symbol", + }, + character: { + type: "number", + description: "0-based character offset of the symbol on the line", + }, + }, + required: ["path", "line", "character"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/prompts/tools/native-tools/index.ts b/src/core/prompts/tools/native-tools/index.ts index 758914d2d65..e4bee9a4245 100644 --- a/src/core/prompts/tools/native-tools/index.ts +++ b/src/core/prompts/tools/native-tools/index.ts @@ -20,6 +20,10 @@ import searchFiles from "./search_files" import switchMode from "./switch_mode" import updateTodoList from "./update_todo_list" import writeToFile from "./write_to_file" +import goToDefinition from "./go_to_definition" +import findReferences from "./find_references" +import workspaceSymbols from "./workspace_symbols" +import documentSymbols from "./document_symbols" export { getMcpServerTools } from "./mcp_server" export { convertOpenAIToolToAnthropic, convertOpenAIToolsToAnthropic } from "./converters" @@ -68,6 +72,10 @@ export function getNativeTools(options: NativeToolsOptions = {}): OpenAI.Chat.Ch switchMode, updateTodoList, writeToFile, + goToDefinition, + findReferences, + workspaceSymbols, + documentSymbols, ] satisfies OpenAI.Chat.ChatCompletionTool[] } diff --git a/src/core/prompts/tools/native-tools/workspace_symbols.ts b/src/core/prompts/tools/native-tools/workspace_symbols.ts new file mode 100644 index 00000000000..fb2b09a9209 --- /dev/null +++ b/src/core/prompts/tools/native-tools/workspace_symbols.ts @@ -0,0 +1,29 @@ +import type OpenAI from "openai" + +const WORKSPACE_SYMBOLS_DESCRIPTION = `Request to search for symbols across the entire workspace by name. Uses the VS Code language server to find classes, functions, variables, types, and other symbols that match a query string. Returns up to 100 matching symbols with their names, kinds, file paths, and positions. + +Parameters: +- query: (required) The search query to match symbol names against. Can be a partial name (e.g., "User" will match "UserService", "getUser", etc.). + +Example: Searching for all symbols containing "Payment" +{ "query": "Payment" }` + +export default { + type: "function", + function: { + name: "workspace_symbols", + description: WORKSPACE_SYMBOLS_DESCRIPTION, + strict: true, + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "Search query to match symbol names against", + }, + }, + required: ["query"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/tools/DocumentSymbolsTool.ts b/src/core/tools/DocumentSymbolsTool.ts new file mode 100644 index 00000000000..e7ca087d745 --- /dev/null +++ b/src/core/tools/DocumentSymbolsTool.ts @@ -0,0 +1,165 @@ +import path from "path" + +import * as vscode from "vscode" + +import { Task } from "../task/Task" +import { getReadablePath } from "../../utils/path" +import type { ToolUse } from "../../shared/tools" + +import { BaseTool, ToolCallbacks } from "./BaseTool" + +interface DocumentSymbolsParams { + path: string +} + +/** + * Maps VS Code SymbolKind enum values to human-readable strings. + */ +function symbolKindToString(kind: vscode.SymbolKind): string { + const kindMap: Record = { + [vscode.SymbolKind.File]: "File", + [vscode.SymbolKind.Module]: "Module", + [vscode.SymbolKind.Namespace]: "Namespace", + [vscode.SymbolKind.Package]: "Package", + [vscode.SymbolKind.Class]: "Class", + [vscode.SymbolKind.Method]: "Method", + [vscode.SymbolKind.Property]: "Property", + [vscode.SymbolKind.Field]: "Field", + [vscode.SymbolKind.Constructor]: "Constructor", + [vscode.SymbolKind.Enum]: "Enum", + [vscode.SymbolKind.Interface]: "Interface", + [vscode.SymbolKind.Function]: "Function", + [vscode.SymbolKind.Variable]: "Variable", + [vscode.SymbolKind.Constant]: "Constant", + [vscode.SymbolKind.String]: "String", + [vscode.SymbolKind.Number]: "Number", + [vscode.SymbolKind.Boolean]: "Boolean", + [vscode.SymbolKind.Array]: "Array", + [vscode.SymbolKind.Object]: "Object", + [vscode.SymbolKind.Key]: "Key", + [vscode.SymbolKind.Null]: "Null", + [vscode.SymbolKind.EnumMember]: "EnumMember", + [vscode.SymbolKind.Struct]: "Struct", + [vscode.SymbolKind.Event]: "Event", + [vscode.SymbolKind.Operator]: "Operator", + [vscode.SymbolKind.TypeParameter]: "TypeParameter", + } + return kindMap[kind] ?? "Unknown" +} + +interface FlatSymbol { + name: string + kind: string + line: number + endLine: number + children?: FlatSymbol[] +} + +/** + * Flatten a DocumentSymbol tree into a simple list with nesting preserved via `children`. + */ +function flattenSymbols(symbols: vscode.DocumentSymbol[]): FlatSymbol[] { + return symbols.map((sym) => { + const result: FlatSymbol = { + name: sym.name, + kind: symbolKindToString(sym.kind), + line: sym.range.start.line + 1, + endLine: sym.range.end.line + 1, + } + if (sym.children && sym.children.length > 0) { + result.children = flattenSymbols(sym.children) + } + return result + }) +} + +export class DocumentSymbolsTool extends BaseTool<"document_symbols"> { + readonly name = "document_symbols" as const + + async execute(params: DocumentSymbolsParams, task: Task, callbacks: ToolCallbacks): Promise { + const { askApproval, handleError, pushToolResult } = callbacks + + const relPath = params.path + + if (!relPath) { + task.consecutiveMistakeCount++ + task.recordToolError("document_symbols") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("document_symbols", "path")) + return + } + + task.consecutiveMistakeCount = 0 + + const absolutePath = path.resolve(task.cwd, relPath) + const uri = vscode.Uri.file(absolutePath) + + try { + const symbols = await vscode.commands.executeCommand<(vscode.DocumentSymbol | vscode.SymbolInformation)[]>( + "vscode.executeDocumentSymbolProvider", + uri, + ) + + if (!symbols || symbols.length === 0) { + const message = `No symbols found in ${getReadablePath(task.cwd, relPath)}` + const didApprove = await askApproval( + "tool", + JSON.stringify({ + tool: "documentSymbols", + path: getReadablePath(task.cwd, relPath), + content: message, + }), + ) + if (!didApprove) { + return + } + pushToolResult(message) + return + } + + // DocumentSymbol has children; SymbolInformation is flat. + let results: FlatSymbol[] + if ("range" in symbols[0] && "children" in symbols[0]) { + // DocumentSymbol[] + results = flattenSymbols(symbols as vscode.DocumentSymbol[]) + } else { + // SymbolInformation[] (fallback) + results = (symbols as vscode.SymbolInformation[]).map((sym) => ({ + name: sym.name, + kind: symbolKindToString(sym.kind), + line: sym.location.range.start.line + 1, + endLine: sym.location.range.end.line + 1, + })) + } + + const content = JSON.stringify(results, null, 2) + const didApprove = await askApproval( + "tool", + JSON.stringify({ tool: "documentSymbols", path: getReadablePath(task.cwd, relPath), content }), + ) + + if (!didApprove) { + return + } + + pushToolResult(content) + } catch (error) { + await handleError("listing document symbols", error as Error) + } + } + + override async handlePartial(task: Task, block: ToolUse<"document_symbols">): Promise { + const relPath = block.params.path + if (!this.hasPathStabilized(relPath)) { + return + } + const partialMessage = JSON.stringify({ + tool: "documentSymbols", + path: getReadablePath(task.cwd, relPath ?? ""), + content: "", + }) + await task.ask("tool", partialMessage, block.partial).catch(() => {}) + } +} + +export const documentSymbolsTool = new DocumentSymbolsTool() diff --git a/src/core/tools/FindReferencesTool.ts b/src/core/tools/FindReferencesTool.ts new file mode 100644 index 00000000000..9826d52327d --- /dev/null +++ b/src/core/tools/FindReferencesTool.ts @@ -0,0 +1,133 @@ +import path from "path" + +import * as vscode from "vscode" + +import { Task } from "../task/Task" +import { getReadablePath } from "../../utils/path" +import type { ToolUse } from "../../shared/tools" + +import { BaseTool, ToolCallbacks } from "./BaseTool" + +interface FindReferencesParams { + path: string + line: number + character: number +} + +const MAX_RESULTS = 50 + +export class FindReferencesTool extends BaseTool<"find_references"> { + readonly name = "find_references" as const + + async execute(params: FindReferencesParams, task: Task, callbacks: ToolCallbacks): Promise { + const { askApproval, handleError, pushToolResult } = callbacks + + const relPath = params.path + const line = params.line + const character = params.character + + if (!relPath) { + task.consecutiveMistakeCount++ + task.recordToolError("find_references") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("find_references", "path")) + return + } + + if (line === undefined || line === null) { + task.consecutiveMistakeCount++ + task.recordToolError("find_references") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("find_references", "line")) + return + } + + if (character === undefined || character === null) { + task.consecutiveMistakeCount++ + task.recordToolError("find_references") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("find_references", "character")) + return + } + + task.consecutiveMistakeCount = 0 + + const absolutePath = path.resolve(task.cwd, relPath) + const uri = vscode.Uri.file(absolutePath) + const position = new vscode.Position(line - 1, character) // Convert 1-based line to 0-based + + try { + const locations = await vscode.commands.executeCommand( + "vscode.executeReferenceProvider", + uri, + position, + ) + + if (!locations || locations.length === 0) { + const message = `No references found for symbol at ${getReadablePath(task.cwd, relPath)}:${line}:${character}` + const didApprove = await askApproval( + "tool", + JSON.stringify({ + tool: "findReferences", + path: getReadablePath(task.cwd, relPath), + content: message, + }), + ) + if (!didApprove) { + return + } + pushToolResult(message) + return + } + + const truncated = locations.length > MAX_RESULTS + const results = locations.slice(0, MAX_RESULTS).map((loc) => { + const targetPath = vscode.workspace.asRelativePath(loc.uri) + return { + path: targetPath, + line: loc.range.start.line + 1, // Convert 0-based to 1-based + character: loc.range.start.character, + endLine: loc.range.end.line + 1, + endCharacter: loc.range.end.character, + } + }) + + const output: { results: typeof results; totalCount: number; truncated?: boolean } = { + results, + totalCount: locations.length, + } + if (truncated) { + output.truncated = true + } + + const content = JSON.stringify(output, null, 2) + const didApprove = await askApproval( + "tool", + JSON.stringify({ tool: "findReferences", path: getReadablePath(task.cwd, relPath), content }), + ) + + if (!didApprove) { + return + } + + pushToolResult(content) + } catch (error) { + await handleError("finding references", error as Error) + } + } + + override async handlePartial(task: Task, block: ToolUse<"find_references">): Promise { + const relPath = block.params.path + if (!this.hasPathStabilized(relPath)) { + return + } + const partialMessage = JSON.stringify({ + tool: "findReferences", + path: getReadablePath(task.cwd, relPath ?? ""), + content: "", + }) + await task.ask("tool", partialMessage, block.partial).catch(() => {}) + } +} + +export const findReferencesTool = new FindReferencesTool() diff --git a/src/core/tools/GoToDefinitionTool.ts b/src/core/tools/GoToDefinitionTool.ts new file mode 100644 index 00000000000..7e2917b8bb2 --- /dev/null +++ b/src/core/tools/GoToDefinitionTool.ts @@ -0,0 +1,137 @@ +import path from "path" + +import * as vscode from "vscode" + +import { Task } from "../task/Task" +import { getReadablePath } from "../../utils/path" +import type { ToolUse } from "../../shared/tools" + +import { BaseTool, ToolCallbacks } from "./BaseTool" + +interface GoToDefinitionParams { + path: string + line: number + character: number +} + +const MAX_RESULTS = 50 + +export class GoToDefinitionTool extends BaseTool<"go_to_definition"> { + readonly name = "go_to_definition" as const + + async execute(params: GoToDefinitionParams, task: Task, callbacks: ToolCallbacks): Promise { + const { askApproval, handleError, pushToolResult } = callbacks + + const relPath = params.path + const line = params.line + const character = params.character + + if (!relPath) { + task.consecutiveMistakeCount++ + task.recordToolError("go_to_definition") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("go_to_definition", "path")) + return + } + + if (line === undefined || line === null) { + task.consecutiveMistakeCount++ + task.recordToolError("go_to_definition") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("go_to_definition", "line")) + return + } + + if (character === undefined || character === null) { + task.consecutiveMistakeCount++ + task.recordToolError("go_to_definition") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("go_to_definition", "character")) + return + } + + task.consecutiveMistakeCount = 0 + + const absolutePath = path.resolve(task.cwd, relPath) + const uri = vscode.Uri.file(absolutePath) + const position = new vscode.Position(line - 1, character) // Convert 1-based line to 0-based + + try { + const locations = await vscode.commands.executeCommand<(vscode.Location | vscode.LocationLink)[]>( + "vscode.executeDefinitionProvider", + uri, + position, + ) + + if (!locations || locations.length === 0) { + const message = `No definition found for symbol at ${getReadablePath(task.cwd, relPath)}:${line}:${character}` + const didApprove = await askApproval( + "tool", + JSON.stringify({ + tool: "goToDefinition", + path: getReadablePath(task.cwd, relPath), + content: message, + }), + ) + if (!didApprove) { + return + } + pushToolResult(message) + return + } + + const results = locations.slice(0, MAX_RESULTS).map((loc) => { + if ("targetUri" in loc) { + // LocationLink + const targetPath = vscode.workspace.asRelativePath(loc.targetUri) + return { + path: targetPath, + line: loc.targetRange.start.line + 1, // Convert 0-based to 1-based + character: loc.targetRange.start.character, + endLine: loc.targetRange.end.line + 1, + endCharacter: loc.targetRange.end.character, + } + } else { + // Location + const targetPath = vscode.workspace.asRelativePath(loc.uri) + return { + path: targetPath, + line: loc.range.start.line + 1, + character: loc.range.start.character, + endLine: loc.range.end.line + 1, + endCharacter: loc.range.end.character, + } + } + }) + + const content = JSON.stringify(results, null, 2) + const didApprove = await askApproval( + "tool", + JSON.stringify({ tool: "goToDefinition", path: getReadablePath(task.cwd, relPath), content }), + ) + + if (!didApprove) { + return + } + + pushToolResult(content) + } catch (error) { + await handleError("finding definition", error as Error) + } + } + + override async handlePartial(task: Task, block: ToolUse<"go_to_definition">): Promise { + const relPath = block.params.path + if (!this.hasPathStabilized(relPath)) { + return + } + const partialMessage = JSON.stringify({ + tool: "goToDefinition", + path: getReadablePath(task.cwd, relPath ?? ""), + content: "", + }) + await task.ask("tool", partialMessage, block.partial).catch(() => {}) + } +} + +export const goToDefinitionTool = new GoToDefinitionTool() diff --git a/src/core/tools/WorkspaceSymbolsTool.ts b/src/core/tools/WorkspaceSymbolsTool.ts new file mode 100644 index 00000000000..b79e0bb739f --- /dev/null +++ b/src/core/tools/WorkspaceSymbolsTool.ts @@ -0,0 +1,131 @@ +import * as vscode from "vscode" + +import { Task } from "../task/Task" +import type { ToolUse } from "../../shared/tools" + +import { BaseTool, ToolCallbacks } from "./BaseTool" + +interface WorkspaceSymbolsParams { + query: string +} + +const MAX_RESULTS = 100 + +/** + * Maps VS Code SymbolKind enum values to human-readable strings. + */ +function symbolKindToString(kind: vscode.SymbolKind): string { + const kindMap: Record = { + [vscode.SymbolKind.File]: "File", + [vscode.SymbolKind.Module]: "Module", + [vscode.SymbolKind.Namespace]: "Namespace", + [vscode.SymbolKind.Package]: "Package", + [vscode.SymbolKind.Class]: "Class", + [vscode.SymbolKind.Method]: "Method", + [vscode.SymbolKind.Property]: "Property", + [vscode.SymbolKind.Field]: "Field", + [vscode.SymbolKind.Constructor]: "Constructor", + [vscode.SymbolKind.Enum]: "Enum", + [vscode.SymbolKind.Interface]: "Interface", + [vscode.SymbolKind.Function]: "Function", + [vscode.SymbolKind.Variable]: "Variable", + [vscode.SymbolKind.Constant]: "Constant", + [vscode.SymbolKind.String]: "String", + [vscode.SymbolKind.Number]: "Number", + [vscode.SymbolKind.Boolean]: "Boolean", + [vscode.SymbolKind.Array]: "Array", + [vscode.SymbolKind.Object]: "Object", + [vscode.SymbolKind.Key]: "Key", + [vscode.SymbolKind.Null]: "Null", + [vscode.SymbolKind.EnumMember]: "EnumMember", + [vscode.SymbolKind.Struct]: "Struct", + [vscode.SymbolKind.Event]: "Event", + [vscode.SymbolKind.Operator]: "Operator", + [vscode.SymbolKind.TypeParameter]: "TypeParameter", + } + return kindMap[kind] ?? "Unknown" +} + +export class WorkspaceSymbolsTool extends BaseTool<"workspace_symbols"> { + readonly name = "workspace_symbols" as const + + async execute(params: WorkspaceSymbolsParams, task: Task, callbacks: ToolCallbacks): Promise { + const { askApproval, handleError, pushToolResult } = callbacks + + const query = params.query + + if (!query) { + task.consecutiveMistakeCount++ + task.recordToolError("workspace_symbols") + task.didToolFailInCurrentTurn = true + pushToolResult(await task.sayAndCreateMissingParamError("workspace_symbols", "query")) + return + } + + task.consecutiveMistakeCount = 0 + + try { + const symbols = await vscode.commands.executeCommand( + "vscode.executeWorkspaceSymbolProvider", + query, + ) + + if (!symbols || symbols.length === 0) { + const message = `No workspace symbols found matching "${query}"` + const didApprove = await askApproval( + "tool", + JSON.stringify({ tool: "workspaceSymbols", query, content: message }), + ) + if (!didApprove) { + return + } + pushToolResult(message) + return + } + + const truncated = symbols.length > MAX_RESULTS + const results = symbols.slice(0, MAX_RESULTS).map((sym) => { + const symPath = vscode.workspace.asRelativePath(sym.location.uri) + return { + name: sym.name, + kind: symbolKindToString(sym.kind), + path: symPath, + line: sym.location.range.start.line + 1, + character: sym.location.range.start.character, + containerName: sym.containerName || undefined, + } + }) + + const output: { results: typeof results; totalCount: number; truncated?: boolean } = { + results, + totalCount: symbols.length, + } + if (truncated) { + output.truncated = true + } + + const content = JSON.stringify(output, null, 2) + const didApprove = await askApproval("tool", JSON.stringify({ tool: "workspaceSymbols", query, content })) + + if (!didApprove) { + return + } + + pushToolResult(content) + } catch (error) { + await handleError("searching workspace symbols", error as Error) + } + } + + override async handlePartial(task: Task, block: ToolUse<"workspace_symbols">): Promise { + const query = block.params.query + const partialMessage = JSON.stringify({ + tool: "workspaceSymbols", + query: query ?? "", + content: "", + }) + await task.ask("tool", partialMessage, block.partial).catch(() => {}) + } +} + +export const workspaceSymbolsTool = new WorkspaceSymbolsTool() diff --git a/src/core/tools/__tests__/documentSymbolsTool.spec.ts b/src/core/tools/__tests__/documentSymbolsTool.spec.ts new file mode 100644 index 00000000000..dc02dc295a4 --- /dev/null +++ b/src/core/tools/__tests__/documentSymbolsTool.spec.ts @@ -0,0 +1,192 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import * as vscode from "vscode" +import { documentSymbolsTool } from "../DocumentSymbolsTool" +import { Task } from "../../task/Task" +import type { ToolUse } from "../../../shared/tools" + +vi.mock("vscode", () => { + const SymbolKind = { + File: 0, + Module: 1, + Namespace: 2, + Package: 3, + Class: 4, + Method: 5, + Property: 6, + Field: 7, + Constructor: 8, + Enum: 9, + Interface: 10, + Function: 11, + Variable: 12, + Constant: 13, + String: 14, + Number: 15, + Boolean: 16, + Array: 17, + Object: 18, + Key: 19, + Null: 20, + EnumMember: 21, + Struct: 22, + Event: 23, + Operator: 24, + TypeParameter: 25, + } + return { + Uri: { + file: vi.fn((path: string) => ({ fsPath: path, path, scheme: "file" })), + }, + Position: vi.fn((line: number, character: number) => ({ line, character })), + commands: { + executeCommand: vi.fn(), + }, + workspace: { + asRelativePath: vi.fn((uri: any) => { + const p = typeof uri === "string" ? uri : uri.path || uri.fsPath + return p.replace(/^\/test\/project\//, "") + }), + }, + SymbolKind, + } +}) + +describe("DocumentSymbolsTool", () => { + let mockTask: any + let mockCallbacks: any + + beforeEach(() => { + vi.clearAllMocks() + + mockTask = { + consecutiveMistakeCount: 0, + didToolFailInCurrentTurn: false, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + ask: vi.fn().mockResolvedValue({}), + cwd: "/test/project", + } + + mockCallbacks = { + askApproval: vi.fn().mockResolvedValue(true), + handleError: vi.fn(), + pushToolResult: vi.fn(), + } + }) + + it("should handle missing path parameter", async () => { + const block: ToolUse<"document_symbols"> = { + type: "tool_use" as const, + name: "document_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + path: "", + }, + } + + await documentSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("document_symbols") + }) + + it("should return document symbols (DocumentSymbol format)", async () => { + const mockSymbols = [ + { + name: "MyClass", + kind: vscode.SymbolKind.Class, + range: { start: { line: 0, character: 0 }, end: { line: 20, character: 1 } }, + children: [ + { + name: "myMethod", + kind: vscode.SymbolKind.Method, + range: { start: { line: 5, character: 2 }, end: { line: 15, character: 3 } }, + children: [], + }, + ], + }, + { + name: "helperFunction", + kind: vscode.SymbolKind.Function, + range: { start: { line: 22, character: 0 }, end: { line: 30, character: 1 } }, + children: [], + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockSymbols as any) + + const block: ToolUse<"document_symbols"> = { + type: "tool_use" as const, + name: "document_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + }, + } + + await documentSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result).toHaveLength(2) + expect(result[0].name).toBe("MyClass") + expect(result[0].kind).toBe("Class") + expect(result[0].line).toBe(1) // 0-based 0 => 1-based 1 + expect(result[0].endLine).toBe(21) // 0-based 20 => 1-based 21 + expect(result[0].children).toHaveLength(1) + expect(result[0].children[0].name).toBe("myMethod") + expect(result[1].name).toBe("helperFunction") + expect(result[1].kind).toBe("Function") + }) + + it("should return document symbols (SymbolInformation format)", async () => { + const mockSymbols = [ + { + name: "myVar", + kind: vscode.SymbolKind.Variable, + location: { + uri: { fsPath: "/test/project/src/test.ts", path: "/test/project/src/test.ts" }, + range: { start: { line: 0, character: 0 }, end: { line: 0, character: 20 } }, + }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockSymbols as any) + + const block: ToolUse<"document_symbols"> = { + type: "tool_use" as const, + name: "document_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + }, + } + + await documentSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result).toHaveLength(1) + expect(result[0].name).toBe("myVar") + expect(result[0].kind).toBe("Variable") + }) + + it("should handle no symbols found", async () => { + vi.mocked(vscode.commands.executeCommand).mockResolvedValue([] as any) + + const block: ToolUse<"document_symbols"> = { + type: "tool_use" as const, + name: "document_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + }, + } + + await documentSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockCallbacks.pushToolResult).toHaveBeenCalledWith(expect.stringContaining("No symbols found")) + }) +}) diff --git a/src/core/tools/__tests__/findReferencesTool.spec.ts b/src/core/tools/__tests__/findReferencesTool.spec.ts new file mode 100644 index 00000000000..8f051d48050 --- /dev/null +++ b/src/core/tools/__tests__/findReferencesTool.spec.ts @@ -0,0 +1,178 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import * as vscode from "vscode" +import { findReferencesTool } from "../FindReferencesTool" +import { Task } from "../../task/Task" +import type { ToolUse } from "../../../shared/tools" + +vi.mock("vscode", () => { + const SymbolKind = { + File: 0, + Module: 1, + Namespace: 2, + Package: 3, + Class: 4, + Method: 5, + Property: 6, + Field: 7, + Constructor: 8, + Enum: 9, + Interface: 10, + Function: 11, + Variable: 12, + Constant: 13, + String: 14, + Number: 15, + Boolean: 16, + Array: 17, + Object: 18, + Key: 19, + Null: 20, + EnumMember: 21, + Struct: 22, + Event: 23, + Operator: 24, + TypeParameter: 25, + } + return { + Uri: { + file: vi.fn((path: string) => ({ fsPath: path, path, scheme: "file" })), + }, + Position: vi.fn((line: number, character: number) => ({ line, character })), + commands: { + executeCommand: vi.fn(), + }, + workspace: { + asRelativePath: vi.fn((uri: any) => { + const p = typeof uri === "string" ? uri : uri.path || uri.fsPath + return p.replace(/^\/test\/project\//, "") + }), + }, + SymbolKind, + } +}) + +describe("FindReferencesTool", () => { + let mockTask: any + let mockCallbacks: any + + beforeEach(() => { + vi.clearAllMocks() + + mockTask = { + consecutiveMistakeCount: 0, + didToolFailInCurrentTurn: false, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + ask: vi.fn().mockResolvedValue({}), + cwd: "/test/project", + } + + mockCallbacks = { + askApproval: vi.fn().mockResolvedValue(true), + handleError: vi.fn(), + pushToolResult: vi.fn(), + } + }) + + it("should handle missing path parameter", async () => { + const block: ToolUse<"find_references"> = { + type: "tool_use" as const, + name: "find_references" as const, + params: {}, + partial: false, + nativeArgs: { + path: "", + line: 10, + character: 5, + }, + } + + await findReferencesTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("find_references") + }) + + it("should return references when found", async () => { + const mockLocations = [ + { + uri: { fsPath: "/test/project/src/a.ts", path: "/test/project/src/a.ts" }, + range: { start: { line: 9, character: 0 }, end: { line: 9, character: 10 } }, + }, + { + uri: { fsPath: "/test/project/src/b.ts", path: "/test/project/src/b.ts" }, + range: { start: { line: 19, character: 5 }, end: { line: 19, character: 15 } }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockLocations as any) + + const block: ToolUse<"find_references"> = { + type: "tool_use" as const, + name: "find_references" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await findReferencesTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result.results).toHaveLength(2) + expect(result.totalCount).toBe(2) + expect(result.truncated).toBeUndefined() + }) + + it("should truncate results when exceeding MAX_RESULTS", async () => { + // Create 55 mock locations + const mockLocations = Array.from({ length: 55 }, (_, i) => ({ + uri: { fsPath: `/test/project/src/file${i}.ts`, path: `/test/project/src/file${i}.ts` }, + range: { start: { line: i, character: 0 }, end: { line: i, character: 10 } }, + })) + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockLocations as any) + + const block: ToolUse<"find_references"> = { + type: "tool_use" as const, + name: "find_references" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await findReferencesTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result.results).toHaveLength(50) + expect(result.totalCount).toBe(55) + expect(result.truncated).toBe(true) + }) + + it("should handle no references found", async () => { + vi.mocked(vscode.commands.executeCommand).mockResolvedValue([] as any) + + const block: ToolUse<"find_references"> = { + type: "tool_use" as const, + name: "find_references" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await findReferencesTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockCallbacks.pushToolResult).toHaveBeenCalledWith(expect.stringContaining("No references found")) + }) +}) diff --git a/src/core/tools/__tests__/goToDefinitionTool.spec.ts b/src/core/tools/__tests__/goToDefinitionTool.spec.ts new file mode 100644 index 00000000000..a47236f43a1 --- /dev/null +++ b/src/core/tools/__tests__/goToDefinitionTool.spec.ts @@ -0,0 +1,239 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import * as vscode from "vscode" +import { goToDefinitionTool } from "../GoToDefinitionTool" +import { Task } from "../../task/Task" +import type { ToolUse } from "../../../shared/tools" + +vi.mock("vscode", () => { + const SymbolKind = { + File: 0, + Module: 1, + Namespace: 2, + Package: 3, + Class: 4, + Method: 5, + Property: 6, + Field: 7, + Constructor: 8, + Enum: 9, + Interface: 10, + Function: 11, + Variable: 12, + Constant: 13, + String: 14, + Number: 15, + Boolean: 16, + Array: 17, + Object: 18, + Key: 19, + Null: 20, + EnumMember: 21, + Struct: 22, + Event: 23, + Operator: 24, + TypeParameter: 25, + } + return { + Uri: { + file: vi.fn((path: string) => ({ fsPath: path, path, scheme: "file" })), + }, + Position: vi.fn((line: number, character: number) => ({ line, character })), + commands: { + executeCommand: vi.fn(), + }, + workspace: { + asRelativePath: vi.fn((uri: any) => { + const p = typeof uri === "string" ? uri : uri.path || uri.fsPath + return p.replace(/^\/test\/project\//, "") + }), + }, + SymbolKind, + } +}) + +describe("GoToDefinitionTool", () => { + let mockTask: any + let mockCallbacks: any + + beforeEach(() => { + vi.clearAllMocks() + + mockTask = { + consecutiveMistakeCount: 0, + didToolFailInCurrentTurn: false, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + ask: vi.fn().mockResolvedValue({}), + cwd: "/test/project", + } + + mockCallbacks = { + askApproval: vi.fn().mockResolvedValue(true), + handleError: vi.fn(), + pushToolResult: vi.fn(), + } + }) + + it("should handle missing path parameter", async () => { + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "", + line: 10, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("go_to_definition") + expect(mockCallbacks.pushToolResult).toHaveBeenCalledWith("Missing parameter error") + }) + + it("should handle missing line parameter", async () => { + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: undefined as any, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("go_to_definition") + }) + + it("should return definitions when found", async () => { + const mockLocations = [ + { + uri: { fsPath: "/test/project/src/other.ts", path: "/test/project/src/other.ts" }, + range: { + start: { line: 9, character: 0 }, + end: { line: 9, character: 20 }, + }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockLocations as any) + + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + expect(vscode.commands.executeCommand).toHaveBeenCalledWith( + "vscode.executeDefinitionProvider", + expect.anything(), + expect.anything(), + ) + expect(mockCallbacks.askApproval).toHaveBeenCalled() + expect(mockCallbacks.pushToolResult).toHaveBeenCalled() + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result).toHaveLength(1) + expect(result[0].line).toBe(10) + expect(result[0].character).toBe(0) + }) + + it("should handle no definitions found", async () => { + vi.mocked(vscode.commands.executeCommand).mockResolvedValue([] as any) + + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockCallbacks.pushToolResult).toHaveBeenCalledWith(expect.stringContaining("No definition found")) + }) + + it("should handle LocationLink results", async () => { + const mockLocations = [ + { + targetUri: { fsPath: "/test/project/src/other.ts", path: "/test/project/src/other.ts" }, + targetRange: { + start: { line: 4, character: 0 }, + end: { line: 4, character: 15 }, + }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockLocations as any) + + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result).toHaveLength(1) + expect(result[0].line).toBe(5) // 0-based 4 => 1-based 5 + }) + + it("should not push result when approval is denied", async () => { + const mockLocations = [ + { + uri: { fsPath: "/test/project/src/other.ts", path: "/test/project/src/other.ts" }, + range: { + start: { line: 9, character: 0 }, + end: { line: 9, character: 20 }, + }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockLocations as any) + mockCallbacks.askApproval.mockResolvedValue(false) + + const block: ToolUse<"go_to_definition"> = { + type: "tool_use" as const, + name: "go_to_definition" as const, + params: {}, + partial: false, + nativeArgs: { + path: "src/test.ts", + line: 10, + character: 5, + }, + } + + await goToDefinitionTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockCallbacks.pushToolResult).not.toHaveBeenCalled() + }) +}) diff --git a/src/core/tools/__tests__/workspaceSymbolsTool.spec.ts b/src/core/tools/__tests__/workspaceSymbolsTool.spec.ts new file mode 100644 index 00000000000..3612599b5cc --- /dev/null +++ b/src/core/tools/__tests__/workspaceSymbolsTool.spec.ts @@ -0,0 +1,155 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import * as vscode from "vscode" +import { workspaceSymbolsTool } from "../WorkspaceSymbolsTool" +import { Task } from "../../task/Task" +import type { ToolUse } from "../../../shared/tools" + +vi.mock("vscode", () => { + const SymbolKind = { + File: 0, + Module: 1, + Namespace: 2, + Package: 3, + Class: 4, + Method: 5, + Property: 6, + Field: 7, + Constructor: 8, + Enum: 9, + Interface: 10, + Function: 11, + Variable: 12, + Constant: 13, + String: 14, + Number: 15, + Boolean: 16, + Array: 17, + Object: 18, + Key: 19, + Null: 20, + EnumMember: 21, + Struct: 22, + Event: 23, + Operator: 24, + TypeParameter: 25, + } + return { + Uri: { + file: vi.fn((path: string) => ({ fsPath: path, path, scheme: "file" })), + }, + Position: vi.fn((line: number, character: number) => ({ line, character })), + commands: { + executeCommand: vi.fn(), + }, + workspace: { + asRelativePath: vi.fn((uri: any) => { + const p = typeof uri === "string" ? uri : uri.path || uri.fsPath + return p.replace(/^\/test\/project\//, "") + }), + }, + SymbolKind, + } +}) + +describe("WorkspaceSymbolsTool", () => { + let mockTask: any + let mockCallbacks: any + + beforeEach(() => { + vi.clearAllMocks() + + mockTask = { + consecutiveMistakeCount: 0, + didToolFailInCurrentTurn: false, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + ask: vi.fn().mockResolvedValue({}), + cwd: "/test/project", + } + + mockCallbacks = { + askApproval: vi.fn().mockResolvedValue(true), + handleError: vi.fn(), + pushToolResult: vi.fn(), + } + }) + + it("should handle missing query parameter", async () => { + const block: ToolUse<"workspace_symbols"> = { + type: "tool_use" as const, + name: "workspace_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + query: "", + }, + } + + await workspaceSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockTask.consecutiveMistakeCount).toBe(1) + expect(mockTask.recordToolError).toHaveBeenCalledWith("workspace_symbols") + }) + + it("should return symbols when found", async () => { + const mockSymbols = [ + { + name: "UserService", + kind: vscode.SymbolKind.Class, + containerName: "", + location: { + uri: { fsPath: "/test/project/src/services/user.ts", path: "/test/project/src/services/user.ts" }, + range: { start: { line: 4, character: 0 }, end: { line: 50, character: 1 } }, + }, + }, + { + name: "getUser", + kind: vscode.SymbolKind.Function, + containerName: "UserService", + location: { + uri: { fsPath: "/test/project/src/services/user.ts", path: "/test/project/src/services/user.ts" }, + range: { start: { line: 10, character: 2 }, end: { line: 20, character: 3 } }, + }, + }, + ] + + vi.mocked(vscode.commands.executeCommand).mockResolvedValue(mockSymbols as any) + + const block: ToolUse<"workspace_symbols"> = { + type: "tool_use" as const, + name: "workspace_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + query: "User", + }, + } + + await workspaceSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + const result = JSON.parse(mockCallbacks.pushToolResult.mock.calls[0][0]) + expect(result.results).toHaveLength(2) + expect(result.results[0].name).toBe("UserService") + expect(result.results[0].kind).toBe("Class") + expect(result.results[0].line).toBe(5) // 0-based 4 => 1-based 5 + expect(result.results[1].containerName).toBe("UserService") + }) + + it("should handle no symbols found", async () => { + vi.mocked(vscode.commands.executeCommand).mockResolvedValue([] as any) + + const block: ToolUse<"workspace_symbols"> = { + type: "tool_use" as const, + name: "workspace_symbols" as const, + params: {}, + partial: false, + nativeArgs: { + query: "NonExistent", + }, + } + + await workspaceSymbolsTool.handle(mockTask as Task, block, mockCallbacks) + + expect(mockCallbacks.pushToolResult).toHaveBeenCalledWith(expect.stringContaining("No workspace symbols found")) + }) +}) diff --git a/src/shared/tools.ts b/src/shared/tools.ts index d2dd9907b17..85823585b1e 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -81,6 +81,7 @@ export const toolParamNames = [ // read_file legacy format parameter (backward compatibility) "files", "line_ranges", + "character", // go_to_definition and find_references parameter ] as const export type ToolParamName = (typeof toolParamNames)[number] @@ -116,6 +117,10 @@ export type NativeToolArgs = { update_todo_list: { todos: string } use_mcp_tool: { server_name: string; tool_name: string; arguments?: Record } write_to_file: { path: string; content: string } + go_to_definition: { path: string; line: number; character: number } + find_references: { path: string; line: number; character: number } + workspace_symbols: { query: string } + document_symbols: { path: string } // Add more tools as they are migrated to native protocol } @@ -290,12 +295,25 @@ export const TOOL_DISPLAY_NAMES: Record = { skill: "load skill", generate_image: "generate images", custom_tool: "use custom tools", + go_to_definition: "go to definition", + find_references: "find references", + workspace_symbols: "search workspace symbols", + document_symbols: "list document symbols", } as const // Define available tool groups. export const TOOL_GROUPS: Record = { read: { - tools: ["read_file", "search_files", "list_files", "codebase_search"], + tools: [ + "read_file", + "search_files", + "list_files", + "codebase_search", + "go_to_definition", + "find_references", + "workspace_symbols", + "document_symbols", + ], }, edit: { tools: ["apply_diff", "write_to_file", "generate_image"],