diff --git a/src/core/assistant-message/NativeToolCallParser.ts b/src/core/assistant-message/NativeToolCallParser.ts index 250afdc3890..7fd63b804a9 100644 --- a/src/core/assistant-message/NativeToolCallParser.ts +++ b/src/core/assistant-message/NativeToolCallParser.ts @@ -400,7 +400,17 @@ export class NativeToolCallParser { break case "apply_diff": - if (partialArgs.path !== undefined || partialArgs.diff !== undefined) { + // Multi-file format (from multi_apply_diff schema) + if (partialArgs.files && Array.isArray(partialArgs.files)) { + nativeArgs = { + files: partialArgs.files.map((f: any) => ({ + path: f.path, + diff: f.diff, + })), + } + } + // Single-file format (from apply_diff schema) + else if (partialArgs.path !== undefined || partialArgs.diff !== undefined) { nativeArgs = { path: partialArgs.path, diff: partialArgs.diff, @@ -633,7 +643,17 @@ export class NativeToolCallParser { break case "apply_diff": - if (args.path !== undefined && args.diff !== undefined) { + // Multi-file format (from multi_apply_diff schema) + if (args.files && Array.isArray(args.files)) { + nativeArgs = { + files: args.files.map((f: any) => ({ + path: f.path, + diff: f.diff, + })), + } as NativeArgsFor + } + // Single-file format (from apply_diff schema) + else if (args.path !== undefined && args.diff !== undefined) { nativeArgs = { path: args.path, diff: args.diff, diff --git a/src/core/assistant-message/__tests__/NativeToolCallParser.spec.ts b/src/core/assistant-message/__tests__/NativeToolCallParser.spec.ts index 0e81671cc15..23d7c58d76b 100644 --- a/src/core/assistant-message/__tests__/NativeToolCallParser.spec.ts +++ b/src/core/assistant-message/__tests__/NativeToolCallParser.spec.ts @@ -237,5 +237,206 @@ describe("NativeToolCallParser", () => { } }) }) + + describe("parseToolCall", () => { + describe("apply_diff tool", () => { + it("should handle single-file format (path and diff)", () => { + const toolCall = { + id: "toolu_123", + name: "apply_diff" as const, + arguments: JSON.stringify({ + path: "src/test.ts", + diff: "<<<<<<< SEARCH\nold code\n=======\nnew code\n>>>>>>> REPLACE", + }), + } + + const result = NativeToolCallParser.parseToolCall(toolCall) + + expect(result).not.toBeNull() + expect(result?.type).toBe("tool_use") + if (result?.type === "tool_use") { + expect(result.nativeArgs).toBeDefined() + const nativeArgs = result.nativeArgs as { path: string; diff: string } + expect(nativeArgs.path).toBe("src/test.ts") + expect(nativeArgs.diff).toContain("<<<<<<< SEARCH") + expect(nativeArgs.diff).toContain(">>>>>>> REPLACE") + } + }) + + it("should handle multi-file format (files array)", () => { + const toolCall = { + id: "toolu_456", + name: "apply_diff" as const, + arguments: JSON.stringify({ + files: [ + { + path: "src/file1.ts", + diff: "<<<<<<< SEARCH\nold code 1\n=======\nnew code 1\n>>>>>>> REPLACE", + }, + { + path: "src/file2.ts", + diff: "<<<<<<< SEARCH\nold code 2\n=======\nnew code 2\n>>>>>>> REPLACE", + }, + ], + }), + } + + const result = NativeToolCallParser.parseToolCall(toolCall) + + expect(result).not.toBeNull() + expect(result?.type).toBe("tool_use") + if (result?.type === "tool_use") { + expect(result.nativeArgs).toBeDefined() + const nativeArgs = result.nativeArgs as { + files: Array<{ path: string; diff: string }> + } + expect(nativeArgs.files).toHaveLength(2) + expect(nativeArgs.files[0].path).toBe("src/file1.ts") + expect(nativeArgs.files[0].diff).toContain("old code 1") + expect(nativeArgs.files[1].path).toBe("src/file2.ts") + expect(nativeArgs.files[1].diff).toContain("old code 2") + } + }) + + it("should handle multi-file format with single file", () => { + const toolCall = { + id: "toolu_789", + name: "apply_diff" as const, + arguments: JSON.stringify({ + files: [ + { + path: "src/single.ts", + diff: "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE", + }, + ], + }), + } + + const result = NativeToolCallParser.parseToolCall(toolCall) + + expect(result).not.toBeNull() + expect(result?.type).toBe("tool_use") + if (result?.type === "tool_use") { + const nativeArgs = result.nativeArgs as { + files: Array<{ path: string; diff: string }> + } + expect(nativeArgs.files).toHaveLength(1) + expect(nativeArgs.files[0].path).toBe("src/single.ts") + } + }) + }) + }) + + describe("processStreamingChunk", () => { + describe("apply_diff tool", () => { + it("should handle single-file format during streaming", () => { + const id = "toolu_streaming_apply_diff_single" + NativeToolCallParser.startStreamingToolCall(id, "apply_diff") + + const fullArgs = JSON.stringify({ + path: "streaming/test.ts", + diff: "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE", + }) + + const result = NativeToolCallParser.processStreamingChunk(id, fullArgs) + + expect(result).not.toBeNull() + expect(result?.nativeArgs).toBeDefined() + const nativeArgs = result?.nativeArgs as { path: string; diff: string } + expect(nativeArgs.path).toBe("streaming/test.ts") + expect(nativeArgs.diff).toContain("<<<<<<< SEARCH") + }) + + it("should handle multi-file format during streaming", () => { + const id = "toolu_streaming_apply_diff_multi" + NativeToolCallParser.startStreamingToolCall(id, "apply_diff") + + const fullArgs = JSON.stringify({ + files: [ + { + path: "streaming/file1.ts", + diff: "<<<<<<< SEARCH\nold1\n=======\nnew1\n>>>>>>> REPLACE", + }, + { + path: "streaming/file2.ts", + diff: "<<<<<<< SEARCH\nold2\n=======\nnew2\n>>>>>>> REPLACE", + }, + ], + }) + + const result = NativeToolCallParser.processStreamingChunk(id, fullArgs) + + expect(result).not.toBeNull() + expect(result?.nativeArgs).toBeDefined() + const nativeArgs = result?.nativeArgs as { + files: Array<{ path: string; diff: string }> + } + expect(nativeArgs.files).toHaveLength(2) + expect(nativeArgs.files[0].path).toBe("streaming/file1.ts") + expect(nativeArgs.files[1].path).toBe("streaming/file2.ts") + }) + }) + }) + + describe("finalizeStreamingToolCall", () => { + describe("apply_diff tool", () => { + it("should finalize single-file format correctly", () => { + const id = "toolu_finalize_apply_diff_single" + NativeToolCallParser.startStreamingToolCall(id, "apply_diff") + + NativeToolCallParser.processStreamingChunk( + id, + JSON.stringify({ + path: "finalized/single.ts", + diff: "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE", + }), + ) + + const result = NativeToolCallParser.finalizeStreamingToolCall(id) + + expect(result).not.toBeNull() + expect(result?.type).toBe("tool_use") + if (result?.type === "tool_use") { + const nativeArgs = result.nativeArgs as { path: string; diff: string } + expect(nativeArgs.path).toBe("finalized/single.ts") + expect(nativeArgs.diff).toContain("<<<<<<< SEARCH") + } + }) + + it("should finalize multi-file format correctly", () => { + const id = "toolu_finalize_apply_diff_multi" + NativeToolCallParser.startStreamingToolCall(id, "apply_diff") + + NativeToolCallParser.processStreamingChunk( + id, + JSON.stringify({ + files: [ + { + path: "finalized/file1.ts", + diff: "<<<<<<< SEARCH\nold1\n=======\nnew1\n>>>>>>> REPLACE", + }, + { + path: "finalized/file2.ts", + diff: "<<<<<<< SEARCH\nold2\n=======\nnew2\n>>>>>>> REPLACE", + }, + ], + }), + ) + + const result = NativeToolCallParser.finalizeStreamingToolCall(id) + + expect(result).not.toBeNull() + expect(result?.type).toBe("tool_use") + if (result?.type === "tool_use") { + const nativeArgs = result.nativeArgs as { + files: Array<{ path: string; diff: string }> + } + expect(nativeArgs.files).toHaveLength(2) + expect(nativeArgs.files[0].path).toBe("finalized/file1.ts") + expect(nativeArgs.files[1].path).toBe("finalized/file2.ts") + } + }) + }) + }) }) }) diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 2e8b791b349..6618439bc61 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -377,8 +377,24 @@ export async function presentAssistantMessage(cline: Task) { case "write_to_file": return `[${block.name} for '${block.params.path}']` case "apply_diff": - // Handle both legacy format and new multi-file format - if (block.params.path) { + // Handle native multi-file format (from multi_apply_diff schema) + if (block.nativeArgs?.files && Array.isArray(block.nativeArgs.files)) { + const files = block.nativeArgs.files + const firstPath = files[0]?.path + if (firstPath) { + if (files.length > 1) { + return `[${block.name} for '${firstPath}' and ${files.length - 1} more file${files.length > 2 ? "s" : ""}]` + } else { + return `[${block.name} for '${firstPath}']` + } + } + } + // Handle native single-file format + else if (block.nativeArgs?.path) { + return `[${block.name} for '${block.nativeArgs.path}']` + } + // Handle XML legacy format + else if (block.params.path) { return `[${block.name} for '${block.params.path}']` } else if (block.params.args) { // Try to extract first file path from args for display @@ -722,6 +738,7 @@ export async function presentAssistantMessage(cline: Task) { block.params, stateExperiments, includedTools, + block.nativeArgs, ) } catch (error) { cline.consecutiveMistakeCount++ @@ -817,17 +834,38 @@ export async function presentAssistantMessage(cline: Task) { // Check if this tool call came from native protocol by checking for ID // Native calls always have IDs, XML calls never do if (toolProtocol === TOOL_PROTOCOL.NATIVE) { - await applyDiffToolClass.handle(cline, block as ToolUse<"apply_diff">, { - askApproval, - handleError, - pushToolResult, - removeClosingTag, - toolProtocol, - }) + // For native protocol, route based on nativeArgs format: + // - nativeArgs.files (array) -> multi-file tool (from multi_apply_diff schema) + // - nativeArgs.path (string) -> single-file tool (from apply_diff schema) + const nativeArgs = block.nativeArgs as + | { files: Array<{ path: string; diff: string }> } + | { path: string; diff: string } + | undefined + + if (nativeArgs && "files" in nativeArgs && Array.isArray(nativeArgs.files)) { + // Multi-file format: use MultiApplyDiffTool + await applyDiffTool( + cline, + block, + askApproval, + handleError, + pushToolResult, + removeClosingTag, + ) + } else { + // Single-file format: use ApplyDiffTool + await applyDiffToolClass.handle(cline, block as ToolUse<"apply_diff">, { + askApproval, + handleError, + pushToolResult, + removeClosingTag, + toolProtocol, + }) + } break } - // Get the provider and state to check experiment settings + // For XML protocol, check experiment settings to determine routing const provider = cline.providerRef.deref() let isMultiFileApplyDiffEnabled = false diff --git a/src/core/prompts/tools/native-tools/apply_diff.ts b/src/core/prompts/tools/native-tools/apply_diff.ts index 3938e4886a0..14baee926de 100644 --- a/src/core/prompts/tools/native-tools/apply_diff.ts +++ b/src/core/prompts/tools/native-tools/apply_diff.ts @@ -1,4 +1,5 @@ import type OpenAI from "openai" +import { multi_apply_diff } from "./multi_apply_diff" const APPLY_DIFF_DESCRIPTION = `Apply precise, targeted modifications to an existing file using one or more search/replace blocks. This tool is for surgical edits only; the 'SEARCH' block must exactly match the existing content, including whitespace and indentation. To make multiple targeted changes, provide multiple SEARCH/REPLACE blocks in the 'diff' parameter. Use the 'read_file' tool first if you are not confident in the exact content to search for.` @@ -33,3 +34,14 @@ export const apply_diff = { }, }, } satisfies OpenAI.Chat.ChatCompletionTool + +/** + * Creates the apply_diff tool definition, selecting between single-file and multi-file + * schemas based on whether the multi-file experiment is enabled. + * + * @param multiFileEnabled - Whether to use the multi-file schema (default: false) + * @returns Native tool definition for apply_diff + */ +export function createApplyDiffTool(multiFileEnabled: boolean = false): OpenAI.Chat.ChatCompletionTool { + return multiFileEnabled ? multi_apply_diff : apply_diff +} diff --git a/src/core/prompts/tools/native-tools/index.ts b/src/core/prompts/tools/native-tools/index.ts index 760d987b47b..f4a180f2ac9 100644 --- a/src/core/prompts/tools/native-tools/index.ts +++ b/src/core/prompts/tools/native-tools/index.ts @@ -1,6 +1,6 @@ import type OpenAI from "openai" import accessMcpResource from "./access_mcp_resource" -import { apply_diff } from "./apply_diff" +import { createApplyDiffTool } from "./apply_diff" import applyPatch from "./apply_patch" import askFollowupQuestion from "./ask_followup_question" import attemptCompletion from "./attempt_completion" @@ -27,12 +27,16 @@ export { convertOpenAIToolToAnthropic, convertOpenAIToolsToAnthropic } from "./c * Get native tools array, optionally customizing based on settings. * * @param partialReadsEnabled - Whether to include line_ranges support in read_file tool (default: true) + * @param multiFileApplyDiffEnabled - Whether to use multi-file apply_diff schema (default: false) * @returns Array of native tool definitions */ -export function getNativeTools(partialReadsEnabled: boolean = true): OpenAI.Chat.ChatCompletionTool[] { +export function getNativeTools( + partialReadsEnabled: boolean = true, + multiFileApplyDiffEnabled: boolean = false, +): OpenAI.Chat.ChatCompletionTool[] { return [ accessMcpResource, - apply_diff, + createApplyDiffTool(multiFileApplyDiffEnabled), applyPatch, askFollowupQuestion, attemptCompletion, diff --git a/src/core/prompts/tools/native-tools/multi_apply_diff.ts b/src/core/prompts/tools/native-tools/multi_apply_diff.ts new file mode 100644 index 00000000000..28a1ef3f3ef --- /dev/null +++ b/src/core/prompts/tools/native-tools/multi_apply_diff.ts @@ -0,0 +1,54 @@ +import type OpenAI from "openai" + +const MULTI_APPLY_DIFF_DESCRIPTION = `Apply precise, targeted modifications to one or more files using search/replace blocks. This tool supports batch operations across multiple files in a single request, maximizing efficiency. For each file, the 'SEARCH' block must exactly match the existing content, including whitespace and indentation. Use the 'read_file' tool first if you are not confident in the exact content to search for.` + +const DIFF_PARAMETER_DESCRIPTION = `A string containing one or more search/replace blocks defining the changes. The ':start_line:' is required and indicates the starting line number of the original content. You must not add a start line for the replacement content. Each block must follow this format: +<<<<<<< SEARCH +:start_line:[line_number] +------- +[exact content to find] +======= +[new content to replace with] +>>>>>>> REPLACE` + +/** + * Multi-file apply_diff schema for native tool calling. + * This schema is used when the MULTI_FILE_APPLY_DIFF experiment is enabled. + * It allows batch operations across multiple files in a single tool call. + */ +export const multi_apply_diff = { + type: "function", + function: { + name: "apply_diff", // Same name - model sees "apply_diff" + description: MULTI_APPLY_DIFF_DESCRIPTION, + parameters: { + type: "object", + properties: { + files: { + type: "array", + description: + "List of files to modify with their diffs. Include multiple files to batch related changes efficiently.", + items: { + type: "object", + properties: { + path: { + type: "string", + description: + "The path of the file to modify, relative to the current workspace directory.", + }, + diff: { + type: "string", + description: DIFF_PARAMETER_DESCRIPTION, + }, + }, + required: ["path", "diff"], + additionalProperties: false, + }, + minItems: 1, + }, + }, + required: ["files"], + additionalProperties: false, + }, + }, +} satisfies OpenAI.Chat.ChatCompletionTool diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index 575b31580e6..a527cb89477 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -3,6 +3,7 @@ import type { ProviderSettings, ModeConfig, ModelInfo } from "@roo-code/types" import type { ClineProvider } from "../webview/ClineProvider" import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools" import { filterNativeToolsForMode, filterMcpToolsForMode } from "../prompts/tools/filter-tools-for-mode" +import { experiments as experimentsModule, EXPERIMENT_IDS } from "../../shared/experiments" interface BuildToolsOptions { provider: ClineProvider @@ -55,8 +56,14 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise // Determine if partial reads are enabled based on maxReadFileLine setting const partialReadsEnabled = maxReadFileLine !== -1 - // Build native tools with dynamic read_file tool based on partialReadsEnabled - const nativeTools = getNativeTools(partialReadsEnabled) + // Determine if multi-file apply_diff is enabled based on experiment flag + const multiFileApplyDiffEnabled = experimentsModule.isEnabled( + experiments ?? {}, + EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF, + ) + + // Build native tools with dynamic read_file and apply_diff tools based on settings + const nativeTools = getNativeTools(partialReadsEnabled, multiFileApplyDiffEnabled) // Filter native tools based on mode restrictions const filteredNativeTools = filterNativeToolsForMode( diff --git a/src/core/tools/MultiApplyDiffTool.ts b/src/core/tools/MultiApplyDiffTool.ts index 7e076d27a9a..186aed52655 100644 --- a/src/core/tools/MultiApplyDiffTool.ts +++ b/src/core/tools/MultiApplyDiffTool.ts @@ -61,45 +61,20 @@ export async function applyDiffTool( pushToolResult: PushToolResult, removeClosingTag: RemoveClosingTag, ) { - // Check if native protocol is enabled - if so, always use single-file class-based tool const toolProtocol = resolveToolProtocol(cline.apiConfiguration, cline.api.getModel().info) - if (isNativeProtocol(toolProtocol)) { - return applyDiffToolClass.handle(cline, block as ToolUse<"apply_diff">, { - askApproval, - handleError, - pushToolResult, - removeClosingTag, - toolProtocol, - }) - } - - // Check if MULTI_FILE_APPLY_DIFF experiment is enabled - const provider = cline.providerRef.deref() - const state = await provider?.getState() - if (provider && state) { - const isMultiFileApplyDiffEnabled = experiments.isEnabled( - state.experiments ?? {}, - EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF, - ) - // If experiment is disabled, use single-file class-based tool - if (!isMultiFileApplyDiffEnabled) { - return applyDiffToolClass.handle(cline, block as ToolUse<"apply_diff">, { - askApproval, - handleError, - pushToolResult, - removeClosingTag, - toolProtocol, - }) - } - } + // Note: Routing between single-file and multi-file tools is now done in presentAssistantMessage.ts + // based on nativeArgs format. This function is only called for multi-file operations. - // Otherwise, continue with new multi-file implementation + // Multi-file implementation const argsXmlTag: string | undefined = block.params.args const legacyPath: string | undefined = block.params.path const legacyDiffContent: string | undefined = block.params.diff const legacyStartLineStr: string | undefined = block.params.start_line + // Native multi-file format from nativeArgs.files + const nativeFiles = (block.nativeArgs as { files?: Array<{ path: string; diff: string }> } | undefined)?.files + let operationsMap: Record = {} let usingLegacyParams = false let filteredOperationErrors: string[] = [] @@ -107,7 +82,12 @@ export async function applyDiffTool( // Handle partial message first if (block.partial) { let filePath = "" - if (argsXmlTag) { + // Native multi-file format + if (nativeFiles && nativeFiles.length > 0) { + filePath = nativeFiles[0].path || "" + } + // XML args format + else if (argsXmlTag) { const match = argsXmlTag.match(/.*?([^<]+)<\/path>/s) if (match) { filePath = match[1] @@ -126,7 +106,33 @@ export async function applyDiffTool( return } - if (argsXmlTag) { + // Handle native multi-file format (from nativeArgs.files via multi_apply_diff schema) + if (nativeFiles && nativeFiles.length > 0) { + for (const file of nativeFiles) { + if (!file.path || !file.diff) continue + + const filePath = file.path + + // Initialize the operation in the map if it doesn't exist + if (!operationsMap[filePath]) { + operationsMap[filePath] = { + path: filePath, + diff: [], + } + } + + // Native format has a single diff content per file entry + // The diff content contains the full SEARCH/REPLACE block(s) + if (file.diff) { + operationsMap[filePath].diff.push({ + content: file.diff, + startLine: undefined, // Native format doesn't include start_line per file + }) + } + } + } + // Handle XML args format (from XML protocol) + else if (argsXmlTag) { // Parse file entries from XML (new way) try { // IMPORTANT: We use parseXmlForDiff here instead of parseXml to prevent HTML entity decoding diff --git a/src/core/tools/__tests__/applyDiffTool.experiment.spec.ts b/src/core/tools/__tests__/applyDiffTool.experiment.spec.ts index 4e0044c5ee2..8c6d89c6d8a 100644 --- a/src/core/tools/__tests__/applyDiffTool.experiment.spec.ts +++ b/src/core/tools/__tests__/applyDiffTool.experiment.spec.ts @@ -8,17 +8,21 @@ vi.mock("vscode", () => ({ }, })) -// Mock the ApplyDiffTool module -vi.mock("../ApplyDiffTool", () => ({ - applyDiffTool: { - handle: vi.fn(), - }, -})) - // Import after mocking to get the mocked version import { applyDiffTool as multiApplyDiffTool } from "../MultiApplyDiffTool" -import { applyDiffTool as applyDiffToolClass } from "../ApplyDiffTool" +/** + * These tests verify that multiApplyDiffTool properly handles multi-file operations. + * + * NOTE: Routing between single-file and multi-file tools is now done in presentAssistantMessage.ts + * based on nativeArgs format. multiApplyDiffTool is responsible for: + * - Handling XML args format (multi-file XML protocol) + * - Handling nativeArgs.files format (multi-file native protocol) + * - Handling legacy path/diff params (single file, XML protocol) + * + * The routing logic tests (when to use applyDiffTool vs multiApplyDiffTool) should be + * tested in presentAssistantMessage.spec.ts or integration tests. + */ describe("applyDiffTool experiment routing", () => { let mockCline: any let mockBlock: any @@ -68,9 +72,22 @@ describe("applyDiffTool experiment routing", () => { }), }, processQueuedMessages: vi.fn(), + consecutiveMistakeCount: 0, + recordToolError: vi.fn(), + sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"), + // Required for file access validation + rooIgnoreController: { + validateAccess: vi.fn().mockReturnValue(true), + }, + rooProtectedController: { + isWriteProtected: vi.fn().mockReturnValue(false), + }, + say: vi.fn().mockResolvedValue(undefined), } as any mockBlock = { + type: "tool_use", + name: "apply_diff", params: { path: "test.ts", diff: "test diff", @@ -84,16 +101,18 @@ describe("applyDiffTool experiment routing", () => { mockRemoveClosingTag = vi.fn((tag, value) => value) }) - it("should use legacy tool when MULTI_FILE_APPLY_DIFF experiment is disabled", async () => { + it("should handle legacy params directly when experiment is disabled", async () => { + // With new architecture, multiApplyDiffTool handles the request directly + // when called with legacy params. Routing to applyDiffTool (single-file) + // is now done in presentAssistantMessage.ts BEFORE calling this function. mockProvider.getState.mockResolvedValue({ experiments: { [EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF]: false, }, }) - // Mock the class-based tool to resolve successfully - ;(applyDiffToolClass.handle as any).mockResolvedValue(undefined) - + // This will result in an error because the file doesn't exist, + // but it verifies the function processes the request directly await multiApplyDiffTool( mockCline, mockBlock, @@ -103,21 +122,15 @@ describe("applyDiffTool experiment routing", () => { mockRemoveClosingTag, ) - expect(applyDiffToolClass.handle).toHaveBeenCalledWith(mockCline, mockBlock, { - askApproval: mockAskApproval, - handleError: mockHandleError, - pushToolResult: mockPushToolResult, - removeClosingTag: mockRemoveClosingTag, - toolProtocol: "xml", - }) + // Function should process the request - it will push a result (error about missing file) + expect(mockPushToolResult).toHaveBeenCalled() }) - it("should use legacy tool when experiments are not defined", async () => { + it("should handle legacy params directly when experiments are not defined", async () => { mockProvider.getState.mockResolvedValue({}) - // Mock the class-based tool to resolve successfully - ;(applyDiffToolClass.handle as any).mockResolvedValue(undefined) - + // This will result in an error because the file doesn't exist, + // but it verifies the function processes the request directly await multiApplyDiffTool( mockCline, mockBlock, @@ -127,24 +140,19 @@ describe("applyDiffTool experiment routing", () => { mockRemoveClosingTag, ) - expect(applyDiffToolClass.handle).toHaveBeenCalledWith(mockCline, mockBlock, { - askApproval: mockAskApproval, - handleError: mockHandleError, - pushToolResult: mockPushToolResult, - removeClosingTag: mockRemoveClosingTag, - toolProtocol: "xml", - }) + // Function should process the request - it will push a result (error about missing file) + expect(mockPushToolResult).toHaveBeenCalled() }) - it("should use multi-file tool when MULTI_FILE_APPLY_DIFF experiment is enabled and using XML protocol", async () => { + it("should handle multi-file operations when MULTI_FILE_APPLY_DIFF experiment is enabled", async () => { mockProvider.getState.mockResolvedValue({ experiments: { [EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF]: true, }, }) - // Mock the new tool behavior - it should continue with the multi-file implementation - // Since we're not mocking the entire function, we'll just verify it doesn't call the class-based tool + // This will result in an error because the file doesn't exist, + // but it verifies the function processes the request directly await multiApplyDiffTool( mockCline, mockBlock, @@ -154,45 +162,74 @@ describe("applyDiffTool experiment routing", () => { mockRemoveClosingTag, ) - expect(applyDiffToolClass.handle).not.toHaveBeenCalled() + // Function should process the request directly + expect(mockPushToolResult).toHaveBeenCalled() }) - it("should use class-based tool when model defaults to native protocol", async () => { - // Update model to support native tools and default to native protocol - mockCline.api.getModel = vi.fn().mockReturnValue({ - id: "test-model", - info: { - maxTokens: 4096, - contextWindow: 128000, - supportsPromptCache: false, - supportsNativeTools: true, // Model supports native tools - defaultToolProtocol: "native", // Model defaults to native protocol + it("should handle native multi-file format (nativeArgs.files)", async () => { + // Test that multiApplyDiffTool properly handles native multi-file format + mockProvider.getState.mockResolvedValue({ + experiments: { + [EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF]: true, }, }) + const blockWithNativeArgs = { + type: "tool_use" as const, + name: "apply_diff" as const, + params: {}, + partial: false, + nativeArgs: { + files: [ + { path: "test1.ts", diff: "test diff 1" }, + { path: "test2.ts", diff: "test diff 2" }, + ], + }, + } + + await multiApplyDiffTool( + mockCline, + blockWithNativeArgs, + mockAskApproval, + mockHandleError, + mockPushToolResult, + mockRemoveClosingTag, + ) + + // Function should process the multi-file request + expect(mockPushToolResult).toHaveBeenCalled() + }) + + it("should handle partial messages for native multi-file format", async () => { mockProvider.getState.mockResolvedValue({ experiments: { [EXPERIMENT_IDS.MULTI_FILE_APPLY_DIFF]: true, }, }) - ;(applyDiffToolClass.handle as any).mockResolvedValue(undefined) + + const mockAsk = vi.fn().mockResolvedValue({}) + mockCline.ask = mockAsk + + const blockWithNativeArgs = { + type: "tool_use" as const, + name: "apply_diff" as const, + params: {}, + partial: true, + nativeArgs: { + files: [{ path: "test1.ts", diff: "partial diff" }], + }, + } await multiApplyDiffTool( mockCline, - mockBlock, + blockWithNativeArgs, mockAskApproval, mockHandleError, mockPushToolResult, mockRemoveClosingTag, ) - // When native protocol is used, should always use class-based tool - expect(applyDiffToolClass.handle).toHaveBeenCalledWith(mockCline, mockBlock, { - askApproval: mockAskApproval, - handleError: mockHandleError, - pushToolResult: mockPushToolResult, - removeClosingTag: mockRemoveClosingTag, - toolProtocol: "native", - }) + // For partial messages, should call ask with partial=true + expect(mockAsk).toHaveBeenCalledWith("tool", expect.any(String), true) }) }) diff --git a/src/core/tools/validateToolUse.ts b/src/core/tools/validateToolUse.ts index d0570337414..4a784da7905 100644 --- a/src/core/tools/validateToolUse.ts +++ b/src/core/tools/validateToolUse.ts @@ -32,6 +32,7 @@ export function validateToolUse( toolParams?: Record, experiments?: Record, includedTools?: string[], + nativeArgs?: Record, ): void { // First, check if the tool name is actually a valid/known tool // This catches completely invalid tool names like "edit_file" that don't exist @@ -51,6 +52,7 @@ export function validateToolUse( toolParams, experiments, includedTools, + nativeArgs, ) ) { throw new Error(`Tool "${toolName}" is not allowed in ${mode} mode.`) @@ -81,6 +83,7 @@ export function isToolAllowedForMode( toolParams?: Record, // All tool parameters experiments?: Record, includedTools?: string[], // Opt-in tools explicitly included (e.g., from modelInfo) + nativeArgs?: Record, // Native protocol arguments (e.g., nativeArgs.files for multi-file apply_diff) ): boolean { // Always allow these tools if (ALWAYS_AVAILABLE_TOOLS.includes(tool as any)) { @@ -188,6 +191,37 @@ export function isToolAllowedForMode( console.warn(`Failed to parse XML args for file restriction validation: ${error}`) } } + + // Handle native protocol multi-file format (nativeArgs.files from multi_apply_diff schema) + if (nativeArgs?.files && Array.isArray(nativeArgs.files)) { + for (const file of nativeArgs.files) { + const filePath = file?.path + if (filePath && typeof filePath === "string") { + if (!doesFileMatchRegex(filePath, options.fileRegex)) { + throw new FileRestrictionError( + mode.name, + options.fileRegex, + options.description, + filePath, + tool, + ) + } + } + } + } + + // Handle native protocol single-file format (nativeArgs.path from apply_diff schema) + if (nativeArgs?.path && typeof nativeArgs.path === "string" && nativeArgs?.diff) { + if (!doesFileMatchRegex(nativeArgs.path, options.fileRegex)) { + throw new FileRestrictionError( + mode.name, + options.fileRegex, + options.description, + nativeArgs.path, + tool, + ) + } + } } return true diff --git a/src/shared/__tests__/modes.spec.ts b/src/shared/__tests__/modes.spec.ts index a00abde7879..db42cc2b571 100644 --- a/src/shared/__tests__/modes.spec.ts +++ b/src/shared/__tests__/modes.spec.ts @@ -300,6 +300,191 @@ describe("isToolAllowedForMode", () => { }), ).toThrow(/Markdown files only/) }) + + it("applies restrictions to apply_diff with native protocol multi-file format (nativeArgs.files)", () => { + // Test apply_diff with nativeArgs.files (native protocol multi-file format) + // This simulates the multi_apply_diff schema used with native protocol + + // Should allow markdown files in architect mode + expect( + isToolAllowedForMode( + "apply_diff", + "architect", + [], + undefined, + {}, // toolParams is empty for native protocol + undefined, + undefined, + { files: [{ path: "test.md", diff: "- old\n+ new" }] }, + ), + ).toBe(true) + + // Test with non-markdown file - should throw error + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + files: [{ path: "test.py", diff: "- old\n+ new" }], + }), + ).toThrow(FileRestrictionError) + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + files: [{ path: "test.py", diff: "- old\n+ new" }], + }), + ).toThrow(/Markdown files only/) + + // Test with multiple markdown files - should allow + expect( + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + files: [ + { path: "readme.md", diff: "- old\n+ new" }, + { path: "docs.md", diff: "- old\n+ new" }, + ], + }), + ).toBe(true) + + // Test with mixed file types - should throw error for non-markdown + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + files: [ + { path: "readme.md", diff: "- old\n+ new" }, + { path: "script.py", diff: "- old\n+ new" }, + ], + }), + ).toThrow(FileRestrictionError) + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + files: [ + { path: "readme.md", diff: "- old\n+ new" }, + { path: "script.py", diff: "- old\n+ new" }, + ], + }), + ).toThrow(/Markdown files only/) + }) + + it("applies restrictions to apply_diff with native protocol single-file format (nativeArgs.path)", () => { + // Test apply_diff with nativeArgs.path (native protocol single-file format) + // This simulates the apply_diff schema used with native protocol + + // Should allow markdown files in architect mode + expect( + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + path: "test.md", + diff: "- old\n+ new", + }), + ).toBe(true) + + // Test with non-markdown file - should throw error + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + path: "test.py", + diff: "- old\n+ new", + }), + ).toThrow(FileRestrictionError) + expect(() => + isToolAllowedForMode("apply_diff", "architect", [], undefined, {}, undefined, undefined, { + path: "test.py", + diff: "- old\n+ new", + }), + ).toThrow(/Markdown files only/) + }) + + it("applies native protocol file restrictions to custom modes with fileRegex", () => { + // Test that custom mode file restrictions work with native protocol formats + const customModesWithRegex: ModeConfig[] = [ + { + slug: "ts-editor", + name: "TypeScript Editor", + roleDefinition: "You are a TypeScript editor", + groups: [ + "read", + ["edit", { fileRegex: "\\.tsx?$", description: "TypeScript files only" }], + "browser", + ], + }, + ] + + // Test native multi-file format with valid TS files + expect( + isToolAllowedForMode( + "apply_diff", + "ts-editor", + customModesWithRegex, + undefined, + {}, + undefined, + undefined, + { + files: [ + { path: "app.ts", diff: "- old\n+ new" }, + { path: "component.tsx", diff: "- old\n+ new" }, + ], + }, + ), + ).toBe(true) + + // Test native multi-file format with invalid file + expect(() => + isToolAllowedForMode( + "apply_diff", + "ts-editor", + customModesWithRegex, + undefined, + {}, + undefined, + undefined, + { + files: [ + { path: "app.ts", diff: "- old\n+ new" }, + { path: "styles.css", diff: "- old\n+ new" }, + ], + }, + ), + ).toThrow(FileRestrictionError) + expect(() => + isToolAllowedForMode( + "apply_diff", + "ts-editor", + customModesWithRegex, + undefined, + {}, + undefined, + undefined, + { + files: [ + { path: "app.ts", diff: "- old\n+ new" }, + { path: "styles.css", diff: "- old\n+ new" }, + ], + }, + ), + ).toThrow(/TypeScript files only/) + + // Test native single-file format with valid TS file + expect( + isToolAllowedForMode( + "apply_diff", + "ts-editor", + customModesWithRegex, + undefined, + {}, + undefined, + undefined, + { path: "app.ts", diff: "- old\n+ new" }, + ), + ).toBe(true) + + // Test native single-file format with invalid file + expect(() => + isToolAllowedForMode( + "apply_diff", + "ts-editor", + customModesWithRegex, + undefined, + {}, + undefined, + undefined, + { path: "styles.css", diff: "- old\n+ new" }, + ), + ).toThrow(FileRestrictionError) + }) }) it("handles non-existent modes", () => { diff --git a/src/shared/tools.ts b/src/shared/tools.ts index de7a65bfb79..b5b3b461d11 100644 --- a/src/shared/tools.ts +++ b/src/shared/tools.ts @@ -90,7 +90,8 @@ export type NativeToolArgs = { read_file: { files: FileEntry[] } attempt_completion: { result: string } execute_command: { command: string; cwd?: string } - apply_diff: { path: string; diff: string } + // Union type to support both single-file and multi-file formats + apply_diff: { path: string; diff: string } | { files: Array<{ path: string; diff: string }> } search_and_replace: { path: string; operations: Array<{ search: string; replace: string }> } search_replace: { file_path: string; old_string: string; new_string: string } apply_patch: { patch: string }