diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts
index d95860d5739..104117affc3 100644
--- a/src/api/providers/__tests__/openai.spec.ts
+++ b/src/api/providers/__tests__/openai.spec.ts
@@ -1139,6 +1139,226 @@ describe("OpenAiHandler", () => {
 			)
 		})
 	})
+
+	describe("Mistral/Devstral Family Models", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messagesWithToolResult: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text", text: "Hello!" }],
+			},
+			{
+				role: "assistant",
+				content: [
+					{
+						type: "tool_use",
+						id: "call_test_123456789",
+						name: "read_file",
+						input: { path: "test.ts" },
+					},
+				],
+			},
+			{
+				role: "user",
+				content: [
+					{
+						type: "tool_result",
+						tool_use_id: "call_test_123456789",
+						content: "File content here",
+					},
+					{
+						type: "text",
+						text: "Details here",
+					},
+				],
+			},
+		]
+
+		it("should detect Mistral models and apply mergeToolResultText", async () => {
+			const mistralHandler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "mistral-large-latest",
+			})
+
+			const stream = mistralHandler.createMessage(systemPrompt, messagesWithToolResult)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+
+			// The converted messages should NOT have a user message after the tool message,
+			// because mergeToolResultText merges that trailing text into the tool message
+			const messages = callArgs.messages
+			const toolMessageIndex = messages.findIndex((m: any) => m.role === "tool")
+
+			// Assert the tool message exists - the test setup should always produce one
+			expect(toolMessageIndex).not.toBe(-1)
+			const toolMessage = messages[toolMessageIndex]
+
+			// Verify the tool message contains both the original tool result AND the merged
+			// trailing text - the key verification that mergeToolResultText works correctly
+			expect(toolMessage.content).toContain("File content here")
+			expect(toolMessage.content).toContain("Details here")
+		})
+
+		it("should not have user message after tool message for Mistral models", async () => {
+			// Create a message sequence that includes a follow-up after the tool result
+			// to verify the Mistral constraint is enforced
+			const messagesWithFollowUp: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: [{ type: "text", text: "Read test.ts and explain it" }],
+				},
+				{
+					role: "assistant",
+					content: [
+						{
+							type: "tool_use",
+							id: "call_abc123xyz",
+							name: "read_file",
+							input: { path: "test.ts" },
+						},
+					],
+				},
+				{
+					role: "user",
+					content: [
+						{
+							type: "tool_result",
+							tool_use_id: "call_abc123xyz",
+							content: "export const foo = 'bar'",
+						},
+						{
+							type: "text",
+							text: "Current directory: /project",
+						},
+					],
+				},
+				{
+					role: "assistant",
+					content: [{ type: "text", text: "This file exports a constant named foo with value 'bar'." }],
+				},
+				{
+					role: "user",
+					content: [{ type: "text", text: "Thanks!" }],
+				},
+			]
+
+			const mistralHandler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "mistral-large-latest",
+			})
+
+			const stream = mistralHandler.createMessage(systemPrompt, messagesWithFollowUp)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			const messages = callArgs.messages
+
+			// Find the tool message
+			const toolMessageIndex = messages.findIndex((m: any) => m.role === "tool")
+			expect(toolMessageIndex).not.toBe(-1)
+
+			// Verify there IS a next message (the assistant response)
+			const nextMessage = messages[toolMessageIndex + 1]
+			expect(nextMessage).toBeDefined()
+
+			// Verify the next message is NOT a user message. This is the Mistral
+			// constraint: after a tool message, only assistant or tool is allowed,
+			// never user. Per the mistral_common validator:
+			//   elif previous_role == Roles.tool: expected_roles = {Roles.assistant, Roles.tool}
+			expect(nextMessage.role).not.toBe("user")
+			expect(nextMessage.role).toBe("assistant")
+		})
+
+		it("should detect Devstral models and apply mergeToolResultText", async () => {
+			const devstralHandler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "devstral-small-2",
+			})
+
+			const stream = devstralHandler.createMessage(systemPrompt, messagesWithToolResult)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+
+			// Verify the model ID was passed correctly
+			expect(callArgs.model).toBe("devstral-small-2")
+		})
+
+		it("should normalize tool call IDs to 9-char alphanumeric for Mistral models", async () => {
+			const mistralHandler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "mistral-medium",
+			})
+
+			const stream = mistralHandler.createMessage(systemPrompt, messagesWithToolResult)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+
+			// Find the tool message and verify the tool_call_id is normalized
+			const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
+			// Assert the tool message exists - the test setup should always produce one
+			expect(toolMessage).toBeDefined()
+			// The ID should be normalized to 9 alphanumeric characters
+			expect(toolMessage.tool_call_id).toMatch(/^[a-zA-Z0-9]{9}$/)
+		})
+
+		it("should NOT apply Mistral-specific handling for non-Mistral models", async () => {
+			const gpt4Handler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "gpt-4-turbo",
+			})
+
+			const stream = gpt4Handler.createMessage(systemPrompt, messagesWithToolResult)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+
+			// For non-Mistral models, the tool_call_id should retain its original format
+			const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
+			// Assert the tool message exists - the test setup should always produce one
+			expect(toolMessage).toBeDefined()
+			// The original ID format should be preserved (not normalized)
+			expect(toolMessage.tool_call_id).toBe("call_test_123456789")
+		})
+
+		it("should handle case-insensitive model detection", async () => {
+			const mixedCaseHandler = new OpenAiHandler({
+				...mockOptions,
+				openAiModelId: "Mistral-Large-LATEST",
+			})
+
+			const stream = mixedCaseHandler.createMessage(systemPrompt, messagesWithToolResult)
+			for await (const _chunk of stream) {
+				// Consume the stream
+			}
+
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+
+			// Verify model detection worked despite mixed case
+			const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
+			// Assert the tool message exists - the test setup should always produce one
+			expect(toolMessage).toBeDefined()
+			// The ID should be normalized (indicating Mistral detection worked)
+			expect(toolMessage.tool_call_id).toMatch(/^[a-zA-Z0-9]{9}$/)
+		})
+	})
 })
 
 describe("getOpenAiModels", () => {
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 74cbb511138..ca6a8a78f64 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -14,7 +14,8 @@ import type { ApiHandlerOptions } from "../../shared/api"
 
 import { TagMatcher } from "../../utils/tag-matcher"
 
-import { convertToOpenAiMessages } from "../transform/openai-format"
+import { convertToOpenAiMessages, ConvertToOpenAiMessagesOptions } from "../transform/openai-format"
+import { normalizeMistralToolCallId } from "../transform/mistral-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
@@ -91,6 +92,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const isAzureAiInference = this._isAzureAiInference(modelUrl)
 		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
 
+		// Mistral/Devstral models require strict tool message ordering and normalized tool call IDs
+		const mistralConversionOptions = this._getMistralConversionOptions(modelId)
+
 		if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
 			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
 			return
@@ -121,7 +125,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 		}
 
-		convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]
+		convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages, mistralConversionOptions)]
 
 		if (modelInfo.supportsPromptCache) {
 			// Note: the following logic is copied from openrouter:
@@ -225,7 +229,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			model: modelId,
 			messages: deepseekReasoner
 				? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-				: [systemMessage, ...convertToOpenAiMessages(messages)],
+				: [systemMessage, ...convertToOpenAiMessages(messages, mistralConversionOptions)],
 			// Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS)
 			tools: this.convertToolsForOpenAI(metadata?.tools),
 			tool_choice: metadata?.tool_choice,
@@ -329,6 +333,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const modelInfo = this.getModel().info
 		const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 
+		// Mistral/Devstral models require strict tool message ordering and normalized tool call IDs
+		const mistralConversionOptions = this._getMistralConversionOptions(modelId)
+
 		if (this.options.openAiStreamingEnabled ?? true) {
 			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
 
@@ -339,7 +346,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					role: "developer",
 					content: `Formatting re-enabled\n${systemPrompt}`,
 				},
-				...convertToOpenAiMessages(messages),
+				...convertToOpenAiMessages(messages, mistralConversionOptions),
 			],
 			stream: true,
 			...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
@@ -375,7 +382,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					role: "developer",
 					content: `Formatting re-enabled\n${systemPrompt}`,
 				},
-				...convertToOpenAiMessages(messages),
+				...convertToOpenAiMessages(messages, mistralConversionOptions),
 			],
 			reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
 			temperature: undefined,
@@ -508,6 +515,36 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		return urlHost.endsWith(".services.ai.azure.com")
 	}
 
+	/**
+	 * Checks if the model is part of the Mistral/Devstral family.
+	 * Mistral models require strict message ordering (no user message after a tool message)
+	 * and have specific tool call ID format requirements (9-char alphanumeric).
+	 * @param modelId - The model identifier to check
+	 * @returns true if the model is a Mistral/Devstral family model
+	 */
+	private _isMistralFamily(modelId: string): boolean {
+		const modelIdLower = modelId.toLowerCase()
+		return modelIdLower.includes("mistral") || modelIdLower.includes("devstral")
+	}
+
+	/**
+	 * Gets the conversion options for Mistral/Devstral models.
+	 * When the model is in the Mistral family, returns options to:
+	 * 1. Merge text content after tool results into the last tool message (prevents the user-after-tool error)
+	 * 2. Normalize tool call IDs to the 9-char alphanumeric format Mistral strictly requires
+	 * @param modelId - The model identifier
+	 * @returns Conversion options for convertToOpenAiMessages, or undefined for non-Mistral models
+	 */
+	private _getMistralConversionOptions(modelId: string): ConvertToOpenAiMessagesOptions | undefined {
+		if (this._isMistralFamily(modelId)) {
+			return {
+				mergeToolResultText: true,
+				normalizeToolCallId: normalizeMistralToolCallId,
+			}
+		}
+		return undefined
+	}
+
 	/**
 	 * Adds max_completion_tokens to the request body if needed based on provider configuration
 	 * Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation
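
Note for reviewers: the mergeToolResultText behavior the new tests exercise is implemented in src/api/transform/openai-format.ts and is not part of this diff. A minimal sketch of the idea, assuming the converter buffers text blocks that follow a tool_result within the same user turn (the helper name mergeTrailingText and its signature are illustrative, not the real API):

import type OpenAI from "openai"

// Illustrative only: fold trailing user text into the preceding tool message so
// that no user-role message immediately follows a tool-role message, which
// Mistral's validator rejects. Non-Mistral models keep the separate user message.
function mergeTrailingText(
	messages: OpenAI.Chat.ChatCompletionMessageParam[],
	trailingText: string,
	merge: boolean,
): void {
	const last = messages[messages.length - 1]
	if (merge && last && last.role === "tool" && typeof last.content === "string") {
		// Append the text to the tool result content instead of emitting a user message
		last.content = `${last.content}\n\n${trailingText}`
	} else {
		messages.push({ role: "user", content: trailingText })
	}
}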
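
Likewise, normalizeMistralToolCallId comes from src/api/transform/mistral-format.ts, which this diff does not include. A sketch that satisfies the /^[a-zA-Z0-9]{9}$/ format the tests assert; the strip-slice-pad strategy here is an assumption, but any implementation must be deterministic so that a tool message's tool_call_id still matches the id in the assistant's preceding tool_calls after normalization:

// Illustrative sketch - Mistral requires tool call IDs of exactly nine alphanumeric characters.
export function normalizeMistralToolCallId(id: string): string {
	// Drop non-alphanumeric characters (e.g. the underscores in "call_test_123456789")
	const alphanumeric = id.replace(/[^a-zA-Z0-9]/g, "")
	// Keep the last nine characters so distinct suffixes stay distinct,
	// padding short IDs with zeros (the padding choice is an assumption)
	return alphanumeric.slice(-9).padStart(9, "0")
}

With this sketch, "call_test_123456789" maps to "123456789" and "call_abc123xyz" maps to "abc123xyz", consistent with the regex the tests check.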