220 changes: 220 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
@@ -1139,6 +1139,226 @@ describe("OpenAiHandler", () => {
)
})
})

describe("Mistral/Devstral Family Models", () => {
const systemPrompt = "You are a helpful assistant."
const messagesWithToolResult: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [{ type: "text", text: "Hello!" }],
},
{
role: "assistant",
content: [
{
type: "tool_use",
id: "call_test_123456789",
name: "read_file",
input: { path: "test.ts" },
},
],
},
{
role: "user",
content: [
{
type: "tool_result",
tool_use_id: "call_test_123456789",
content: "File content here",
},
{
type: "text",
text: "<environment_details>Details here</environment_details>",
},
],
},
]

it("should detect Mistral models and apply mergeToolResultText", async () => {
const mistralHandler = new OpenAiHandler({
...mockOptions,
openAiModelId: "mistral-large-latest",
})

const stream = mistralHandler.createMessage(systemPrompt, messagesWithToolResult)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]

// Find the tool message - there should NOT be a user message after it,
// because mergeToolResultText merges the trailing text into the tool message
const messages = callArgs.messages
const toolMessageIndex = messages.findIndex((m: any) => m.role === "tool")

// Assert tool message exists - test setup should always produce a tool message
expect(toolMessageIndex).not.toBe(-1)
const toolMessage = messages[toolMessageIndex]

// Verify the tool message contains both the original content AND the merged environment_details
// This is the key verification that mergeToolResultText is working correctly
expect(toolMessage.content).toContain("File content here")
expect(toolMessage.content).toContain("environment_details")
})

it("should not have user message after tool message for Mistral models", async () => {
// Create a message sequence that includes a follow-up after the tool result
// to verify the Mistral constraint is enforced
const messagesWithFollowUp: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [{ type: "text", text: "Read test.ts and explain it" }],
},
{
role: "assistant",
content: [
{
type: "tool_use",
id: "call_abc123xyz",
name: "read_file",
input: { path: "test.ts" },
},
],
},
{
role: "user",
content: [
{
type: "tool_result",
tool_use_id: "call_abc123xyz",
content: "export const foo = 'bar'",
},
{
type: "text",
text: "<environment_details>Current directory: /project</environment_details>",
},
],
},
{
role: "assistant",
content: [{ type: "text", text: "This file exports a constant named foo with value 'bar'." }],
},
{
role: "user",
content: [{ type: "text", text: "Thanks!" }],
},
]

const mistralHandler = new OpenAiHandler({
...mockOptions,
openAiModelId: "mistral-large-latest",
})

const stream = mistralHandler.createMessage(systemPrompt, messagesWithFollowUp)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
const messages = callArgs.messages

// Find the tool message
const toolMessageIndex = messages.findIndex((m: any) => m.role === "tool")
expect(toolMessageIndex).not.toBe(-1)

// Verify there IS a next message (the assistant response)
const nextMessage = messages[toolMessageIndex + 1]
expect(nextMessage).toBeDefined()

// Verify the next message is NOT a user message
// This is the Mistral constraint: after tool, only assistant or tool is allowed, never user
// Per mistral_common validator: elif previous_role == Roles.tool: expected_roles = {Roles.assistant, Roles.tool}
expect(nextMessage.role).not.toBe("user")
expect(nextMessage.role).toBe("assistant")
})

it("should detect Devstral models and apply mergeToolResultText", async () => {
const devstralHandler = new OpenAiHandler({
...mockOptions,
openAiModelId: "devstral-small-2",
})

const stream = devstralHandler.createMessage(systemPrompt, messagesWithToolResult)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]

// Verify the model ID was passed correctly
expect(callArgs.model).toBe("devstral-small-2")

// Verify mergeToolResultText was applied for Devstral too: the trailing
// environment_details text should be merged into the tool message
const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
expect(toolMessage).toBeDefined()
expect(toolMessage.content).toContain("environment_details")
})

it("should normalize tool call IDs to 9-char alphanumeric for Mistral models", async () => {
const mistralHandler = new OpenAiHandler({
...mockOptions,
openAiModelId: "mistral-medium",
})

const stream = mistralHandler.createMessage(systemPrompt, messagesWithToolResult)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]

// Find the tool message and verify the tool_call_id is normalized
const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
// Assert tool message exists - test setup should always produce a tool message
expect(toolMessage).toBeDefined()
// The ID should be normalized to 9 alphanumeric characters
expect(toolMessage.tool_call_id).toMatch(/^[a-zA-Z0-9]{9}$/)
})

it("should NOT apply Mistral-specific handling for non-Mistral models", async () => {
const gpt4Handler = new OpenAiHandler({
...mockOptions,
openAiModelId: "gpt-4-turbo",
})

const stream = gpt4Handler.createMessage(systemPrompt, messagesWithToolResult)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]

// For non-Mistral models, tool_call_id should retain original format
const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
// Assert tool message exists - test setup should always produce a tool message
expect(toolMessage).toBeDefined()
// The original ID format should be preserved (not normalized)
expect(toolMessage.tool_call_id).toBe("call_test_123456789")
})

it("should handle case-insensitive model detection", async () => {
const mixedCaseHandler = new OpenAiHandler({
...mockOptions,
openAiModelId: "Mistral-Large-LATEST",
})

const stream = mixedCaseHandler.createMessage(systemPrompt, messagesWithToolResult)
for await (const _chunk of stream) {
// Consume the stream
}

expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]

// Verify model detection worked despite mixed case
const toolMessage = callArgs.messages.find((m: any) => m.role === "tool")
// Assert tool message exists - test setup should always produce a tool message
expect(toolMessage).toBeDefined()
// The ID should be normalized (indicating Mistral detection worked)
expect(toolMessage.tool_call_id).toMatch(/^[a-zA-Z0-9]{9}$/)
})
})
})

describe("getOpenAiModels", () => {
47 changes: 42 additions & 5 deletions src/api/providers/openai.ts
@@ -14,7 +14,8 @@ import type { ApiHandlerOptions } from "../../shared/api"

import { TagMatcher } from "../../utils/tag-matcher"

import { convertToOpenAiMessages } from "../transform/openai-format"
import { convertToOpenAiMessages, ConvertToOpenAiMessagesOptions } from "../transform/openai-format"
import { normalizeMistralToolCallId } from "../transform/mistral-format"
import { convertToR1Format } from "../transform/r1-format"
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { getModelParams } from "../transform/model-params"
@@ -91,6 +92,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
const isAzureAiInference = this._isAzureAiInference(modelUrl)
const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format

// Mistral/Devstral models require strict tool message ordering and normalized tool call IDs
const mistralConversionOptions = this._getMistralConversionOptions(modelId)

if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
return
@@ -121,7 +125,7 @@
}
}

convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]
convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages, mistralConversionOptions)]

if (modelInfo.supportsPromptCache) {
// Note: the following logic is copied from openrouter:
@@ -225,7 +229,7 @@
model: modelId,
messages: deepseekReasoner
? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
: [systemMessage, ...convertToOpenAiMessages(messages)],
: [systemMessage, ...convertToOpenAiMessages(messages, mistralConversionOptions)],
// Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS)
tools: this.convertToolsForOpenAI(metadata?.tools),
tool_choice: metadata?.tool_choice,
@@ -329,6 +333,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
const modelInfo = this.getModel().info
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

// Mistral/Devstral models require strict tool message ordering and normalized tool call IDs
const mistralConversionOptions = this._getMistralConversionOptions(modelId)

if (this.options.openAiStreamingEnabled ?? true) {
const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

@@ -339,7 +346,7 @@
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
...convertToOpenAiMessages(messages, mistralConversionOptions),
],
stream: true,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
@@ -375,7 +382,7 @@
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
...convertToOpenAiMessages(messages, mistralConversionOptions),
],
reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
temperature: undefined,
@@ -508,6 +515,36 @@
return urlHost.endsWith(".services.ai.azure.com")
}

/**
* Checks if the model is part of the Mistral/Devstral family.
* Mistral models require strict message ordering (no user message after tool message)
* and have specific tool call ID format requirements (9-char alphanumeric).
* @param modelId - The model identifier to check
* @returns true if the model is a Mistral/Devstral family model
*/
private _isMistralFamily(modelId: string): boolean {
const modelIdLower = modelId.toLowerCase()
return modelIdLower.includes("mistral") || modelIdLower.includes("devstral")
}
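
// Illustrative detection examples (mirroring the cases exercised in the spec above);
// the substring match is intentionally case-insensitive:
//   this._isMistralFamily("Mistral-Large-LATEST") // => true
//   this._isMistralFamily("devstral-small-2")     // => true
//   this._isMistralFamily("gpt-4-turbo")          // => false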

/**
* Gets the conversion options for Mistral/Devstral models.
* When the model is in the Mistral family, returns options to:
* 1. Merge text content after tool results into the last tool message (prevents user-after-tool error)
* 2. Normalize tool call IDs to 9-char alphanumeric format (Mistral's strict requirement)
* @param modelId - The model identifier
* @returns Conversion options for convertToOpenAiMessages, or undefined for non-Mistral models
*/
private _getMistralConversionOptions(modelId: string): ConvertToOpenAiMessagesOptions | undefined {
if (this._isMistralFamily(modelId)) {
return {
mergeToolResultText: true,
normalizeToolCallId: normalizeMistralToolCallId,
}
}
return undefined
}
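
// Sketch of the resulting behavior (values per the tests above; normalizeMistralToolCallId
// is assumed to map arbitrary IDs to 9 alphanumeric characters):
//   this._getMistralConversionOptions("mistral-medium")
//   // => { mergeToolResultText: true, normalizeToolCallId: normalizeMistralToolCallId }
//   this._getMistralConversionOptions("gpt-4-turbo")
//   // => undefined
//   normalizeMistralToolCallId("call_test_123456789") // => e.g. "abc123XYZ" (hypothetical value)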

/**
* Adds max_completion_tokens to the request body if needed based on provider configuration
* Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation