Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/browser/components/ChatInput/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ const ChatInputInner: React.FC<ChatInputProps> = (props) => {
variant === "workspace" ? props.workspaceId : getProjectScopeId(props.projectPath)
);
// Extract models for convenience (don't create separate state - use hook as single source of truth)
// - preferredModel: gateway-transformed model for API calls
// - preferredModel: canonical model used for backend routing
// - baseModel: canonical format for UI display and policy checks (e.g., ThinkingSlider)
const preferredModel = sendMessageOptions.model;
const baseModel = sendMessageOptions.baseModel;
Expand Down
7 changes: 6 additions & 1 deletion src/browser/components/Messages/AssistantMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,12 @@ export const AssistantMessage: React.FC<AssistantMessageProps> = ({

return (
<div className="flex items-center gap-2">
{modelName && <ModelDisplay modelString={modelName} />}
{modelName && (
<ModelDisplay
modelString={modelName}
routedThroughGateway={message.routedThroughGateway}
/>
)}
{isCompacted && (
<span className="text-plan-mode bg-plan-mode/10 inline-flex items-center gap-1 rounded-sm px-1.5 py-0.5 text-[10px] font-medium uppercase">
{isIdleCompacted ? (
Expand Down
28 changes: 24 additions & 4 deletions src/browser/components/Messages/ModelDisplay.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@ interface ModelDisplayProps {
modelString: string;
/** Whether to show the tooltip on hover (default: true, set to false when used within another tooltip) */
showTooltip?: boolean;
/** Explicit signal that the model was routed through Mux Gateway (canonical modelString). */
routedThroughGateway?: boolean;
}

/**
* Parse a model string into provider and model name.
* Handles mux-gateway format: "mux-gateway:inner-provider/model-name"
* Returns: { provider, modelName, isMuxGateway, innerProvider }
*/
function parseModelString(modelString: string): {
function parseModelString(
modelString: string,
routedThroughGateway?: boolean
): {
provider: string;
modelName: string;
isMuxGateway: boolean;
Expand All @@ -30,6 +35,10 @@ function parseModelString(modelString: string): {
return { provider, modelName, isMuxGateway: true, innerProvider };
}

if (routedThroughGateway && provider && rest) {
return { provider, modelName: rest, isMuxGateway: true, innerProvider: provider };
}

return { provider, modelName: rest, isMuxGateway: false, innerProvider: "" };
}

Expand All @@ -42,13 +51,24 @@ function parseModelString(modelString: string): {
* Uses standard inline layout for natural text alignment.
* Icon is 1em (matches font size) with vertical-align: middle.
*/
export const ModelDisplay: React.FC<ModelDisplayProps> = ({ modelString, showTooltip = true }) => {
const { provider, modelName, isMuxGateway, innerProvider } = parseModelString(modelString);
export const ModelDisplay: React.FC<ModelDisplayProps> = ({
modelString,
showTooltip = true,
routedThroughGateway,
}) => {
const { provider, modelName, isMuxGateway, innerProvider } = parseModelString(
modelString,
routedThroughGateway
);

// For mux-gateway, show the inner provider's icon (the model's actual provider)
const iconProvider = isMuxGateway ? innerProvider : provider;
const displayName = formatModelDisplayName(modelName);
const suffix = isMuxGateway ? " (mux gateway)" : "";
const tooltipModelString =
isMuxGateway && provider !== "mux-gateway" && provider.length > 0
? `mux-gateway:${provider}/${modelName}`
: modelString;

const iconClass =
"mr-[0.3em] inline-block h-[1.1em] w-[1.1em] align-[-0.19em] [&_svg]:block [&_svg]:h-full [&_svg]:w-full [&_svg_.st0]:fill-current [&_svg_circle]:!fill-current [&_svg_path]:!fill-current [&_svg_rect]:!fill-current";
Expand All @@ -73,7 +93,7 @@ export const ModelDisplay: React.FC<ModelDisplayProps> = ({ modelString, showToo
<span data-model-display-tooltip>{content}</span>
</TooltipTrigger>
<TooltipContent align="center" data-model-tooltip-text>
{modelString}
{tooltipModelString}
</TooltipContent>
</Tooltip>
);
Expand Down
28 changes: 13 additions & 15 deletions src/browser/hooks/useGatewayModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ import {
GATEWAY_ENABLED_KEY,
GATEWAY_MODELS_KEY,
} from "@/common/constants/storage";

/**
* Providers that Mux Gateway supports routing to.
* Based on Vercel AI Gateway supported providers.
*
* Excluded:
* - ollama: Local-only provider, not routable through cloud gateway
* - openrouter: Already a gateway/aggregator, routing through another gateway is redundant
* - bedrock: Complex auth (AWS credentials), not simple API key routing
* - mux-gateway: Already gateway format
*/
const GATEWAY_SUPPORTED_PROVIDERS = new Set(["anthropic", "openai", "google", "xai"]);
import {
MUX_GATEWAY_SUPPORTED_PROVIDERS,
isValidProvider,
type ProviderName,
} from "@/common/constants/providers";

// ============================================================================
// Pure utility functions (no side effects, used for message sending)
Expand All @@ -27,17 +20,22 @@ const GATEWAY_SUPPORTED_PROVIDERS = new Set(["anthropic", "openai", "google", "x
/**
* Extract provider from a model ID.
*/
function getProvider(modelId: string): string | null {
function getProvider(modelId: string): ProviderName | null {
const colonIndex = modelId.indexOf(":");
return colonIndex === -1 ? null : modelId.slice(0, colonIndex);
if (colonIndex === -1) {
return null;
}

const provider = modelId.slice(0, colonIndex);
return isValidProvider(provider) ? provider : null;
}

/**
* Check if a model's provider can route through Mux Gateway.
*/
export function isProviderSupported(modelId: string): boolean {
const provider = getProvider(modelId);
return provider !== null && GATEWAY_SUPPORTED_PROVIDERS.has(provider);
return provider !== null && MUX_GATEWAY_SUPPORTED_PROVIDERS.has(provider);
}

/**
Expand Down
56 changes: 9 additions & 47 deletions src/browser/hooks/useSendMessageOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { useThinkingLevel } from "./useThinkingLevel";
import { useAgent } from "@/browser/contexts/AgentContext";
import { usePersistedState } from "./usePersistedState";
import { getDefaultModel } from "./useModelsFromSettings";
import { migrateGatewayModel, useGateway, isProviderSupported } from "./useGatewayModels";
import { migrateGatewayModel } from "./useGatewayModels";
import {
getModelKey,
PREFERRED_SYSTEM_1_MODEL_KEY,
Expand All @@ -13,28 +13,9 @@ import { coerceThinkingLevel, type ThinkingLevel } from "@/common/types/thinking
import type { MuxProviderOptions } from "@/common/types/providerOptions";
import { getSendOptionsFromStorage } from "@/browser/utils/messages/sendOptions";
import { useProviderOptions } from "./useProviderOptions";
import type { GatewayState } from "./useGatewayModels";
import { useExperimentOverrideValue } from "./useExperiments";
import { EXPERIMENT_IDS } from "@/common/constants/experiments";

/**
* Transform model to gateway format using reactive gateway state.
* This ensures the component re-renders when gateway toggles change.
*/
function applyGatewayTransform(modelId: string, gateway: GatewayState): string {
if (!gateway.isActive || !isProviderSupported(modelId) || !gateway.modelUsesGateway(modelId)) {
return modelId;
}

// Transform provider:model to mux-gateway:provider/model
const colonIndex = modelId.indexOf(":");
if (colonIndex === -1) return modelId;

const provider = modelId.slice(0, colonIndex);
const model = modelId.slice(colonIndex + 1);
return `mux-gateway:${provider}/${model}`;
}

interface ExperimentValues {
programmaticToolCalling: boolean | undefined;
programmaticToolCallingExclusive: boolean | undefined;
Expand All @@ -50,31 +31,17 @@ interface ExperimentValues {
function constructSendMessageOptions(
agentId: string,
thinkingLevel: ThinkingLevel,
preferredModel: string | null | undefined,
baseModel: string,
providerOptions: MuxProviderOptions,
fallbackModel: string,
gateway: GatewayState,
experimentValues: ExperimentValues,
system1Model: string | undefined,
system1ThinkingLevel: ThinkingLevel | undefined
): SendMessageOptions {
// Ensure model is always a valid string (defensive against corrupted localStorage)
const rawModel =
typeof preferredModel === "string" && preferredModel ? preferredModel : fallbackModel;

// Migrate any legacy mux-gateway:provider/model format to canonical form
const baseModel = migrateGatewayModel(rawModel);

// Preserve the user's preferred thinking level; backend enforces per-model policy.
const uiThinking = thinkingLevel;

// Transform to gateway format if gateway is enabled for this model (reactive)
const model = applyGatewayTransform(baseModel, gateway);

const system1ModelForBackend =
system1Model !== undefined
? applyGatewayTransform(migrateGatewayModel(system1Model), gateway)
: undefined;
system1Model !== undefined ? migrateGatewayModel(system1Model) : undefined;

const system1ThinkingLevelForBackend =
system1ThinkingLevel !== undefined && system1ThinkingLevel !== "off"
Expand All @@ -83,7 +50,7 @@ function constructSendMessageOptions(

return {
thinkingLevel: uiThinking,
model,
model: baseModel,
...(system1ModelForBackend ? { system1Model: system1ModelForBackend } : {}),
...(system1ThinkingLevelForBackend
? { system1ThinkingLevel: system1ThinkingLevelForBackend }
Expand All @@ -100,8 +67,8 @@ function constructSendMessageOptions(
}

/**
* Extended send options that includes both the gateway-transformed model
* and the base model (for UI components that need canonical model names).
* Extended send options that includes both the canonical model used for backend routing
* and a base model string for UI components that need a stable display value.
*/
export interface SendMessageOptionsWithBase extends SendMessageOptions {
/** Base model in canonical format (e.g., "openai:gpt-5.1-codex-max") for UI/policy checks */
Expand All @@ -118,8 +85,8 @@ export interface SendMessageOptionsWithBase extends SendMessageOptions {
* Uses usePersistedState which has listener mode, so changes to preferences
* propagate automatically to all components using this hook.
*
* Returns both `model` (possibly gateway-transformed for API calls) and
* `baseModel` (canonical format for UI display and policy checks).
* Returns both `model` (canonical for backend routing) and `baseModel`
* (canonical format for UI display and policy checks).
*/
export function useSendMessageOptions(workspaceId: string): SendMessageOptionsWithBase {
const [thinkingLevel] = useThinkingLevel();
Expand All @@ -132,9 +99,6 @@ export function useSendMessageOptions(workspaceId: string): SendMessageOptionsWi
{ listener: true } // Listen for changes from ModelSelector and other sources
);

// Subscribe to gateway state so we re-render when user toggles gateway
const gateway = useGateway();

// Subscribe to local override state so toggles apply immediately.
// If undefined, the backend will apply the PostHog assignment.
const programmaticToolCalling = useExperimentOverrideValue(
Expand Down Expand Up @@ -170,10 +134,8 @@ export function useSendMessageOptions(workspaceId: string): SendMessageOptionsWi
const options = constructSendMessageOptions(
agentId,
thinkingLevel,
preferredModel,
baseModel,
providerOptions,
defaultModel,
gateway,
{ programmaticToolCalling, programmaticToolCallingExclusive, system1 },
system1Model,
system1ThinkingLevel
Expand Down
8 changes: 4 additions & 4 deletions src/browser/stories/App.chat.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,9 @@ export const ModeHelpTooltip: AppStory = {
/**
* Model selector pretty display with mux-gateway enabled.
*
* Regression test: when gateway is enabled, `useSendMessageOptions().model` becomes
* `mux-gateway:provider/model`, but the UI should still display the canonical
* provider:model form (e.g. GPT-4o, not \"Openai/gpt 4o\").
* Regression test: when gateway is enabled, routing happens in the backend,
* but the UI should still display the canonical provider:model form
* (e.g. GPT-4o, not \"Openai/gpt 4o\").
*/
export const ModelSelectorPrettyWithGateway: AppStory = {
render: () => (
Expand All @@ -887,7 +887,7 @@ export const ModelSelectorPrettyWithGateway: AppStory = {
const workspaceId = "ws-gateway-model";
const baseModel = "openai:gpt-4o";

// Ensure the gateway transform actually kicks in (so the regression would reproduce).
// Ensure the gateway indicator is active (so the regression would reproduce).
updatePersistedState(getModelKey(workspaceId), baseModel);
updatePersistedState("gateway-enabled", true);
updatePersistedState("gateway-available", true);
Expand Down
5 changes: 5 additions & 0 deletions src/browser/utils/messages/ChatEventProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ export function createChatEventProcessor(): ChatEventProcessor {
historySequence: start.metadata?.historySequence ?? start.historySequence,
timestamp: start.metadata?.timestamp ?? start.timestamp,
model: start.metadata?.model ?? start.model,
routedThroughGateway: start.metadata?.routedThroughGateway ?? start.routedThroughGateway,
muxMetadata: start.metadata?.muxMetadata,
partial: true,
});
Expand Down Expand Up @@ -217,6 +218,10 @@ export function createChatEventProcessor(): ChatEventProcessor {
partial: false,
timestamp: metadata.timestamp ?? message.metadata?.timestamp,
model: metadata.model ?? message.metadata?.model ?? event.metadata.model,
routedThroughGateway:
metadata.routedThroughGateway ??
message.metadata?.routedThroughGateway ??
event.metadata.routedThroughGateway,
usage: metadata.usage ?? message.metadata?.usage,
providerMetadata: metadata.providerMetadata ?? message.metadata?.providerMetadata,
systemMessageTokens: metadata.systemMessageTokens ?? message.metadata?.systemMessageTokens,
Expand Down
5 changes: 5 additions & 0 deletions src/browser/utils/messages/StreamingMessageAggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ interface StreamingContext {
isCompacting: boolean;
hasCompactionContinue: boolean;
model: string;
routedThroughGateway?: boolean;

/** Timestamp of first content token (text or reasoning delta) - backend Date.now() */
serverFirstTokenTime: number | null;
Expand Down Expand Up @@ -1196,6 +1197,7 @@ export class StreamingMessageAggregator {
isCompacting,
hasCompactionContinue,
model: data.model,
routedThroughGateway: data.routedThroughGateway,
serverFirstTokenTime: null,
toolExecutionMs: 0,
pendingToolStarts: new Map(),
Expand All @@ -1211,6 +1213,7 @@ export class StreamingMessageAggregator {
historySequence: data.historySequence,
timestamp: Date.now(),
model: data.model,
routedThroughGateway: data.routedThroughGateway,
mode: data.mode,
});

Expand Down Expand Up @@ -2032,6 +2035,7 @@ export class StreamingMessageAggregator {
isCompacted: !!message.metadata?.compacted,
isIdleCompacted: message.metadata?.compacted === "idle",
model: message.metadata?.model,
routedThroughGateway: message.metadata?.routedThroughGateway,
mode: message.metadata?.mode,
agentId: message.metadata?.agentId ?? message.metadata?.mode,
timestamp: part.timestamp ?? baseTimestamp,
Expand Down Expand Up @@ -2119,6 +2123,7 @@ export class StreamingMessageAggregator {
errorType: message.metadata.errorType ?? "unknown",
historySequence,
model: message.metadata.model,
routedThroughGateway: message.metadata?.routedThroughGateway,
timestamp: baseTimestamp,
});
}
Expand Down
18 changes: 18 additions & 0 deletions src/browser/utils/messages/compactionOptions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ describe("applyCompactionOverrides", () => {
expect(result.model).toBe(KNOWN_MODELS.HAIKU.id);
});

it("falls back to workspace model when override is empty", () => {
  // An empty-string model override must not win over the workspace default.
  const emptyOverride: CompactionRequestData = { model: "" };

  const merged = applyCompactionOverrides(baseOptions, emptyOverride);

  expect(merged.model).toBe(KNOWN_MODELS.SONNET.id);
});

it("falls back to workspace model when override is whitespace", () => {
  // A whitespace-only model override is treated the same as no override.
  const whitespaceOverride: CompactionRequestData = { model: " " };

  const merged = applyCompactionOverrides(baseOptions, whitespaceOverride);

  expect(merged.model).toBe(KNOWN_MODELS.SONNET.id);
});

it("enforces thinking policy for the compaction model", () => {
// Test Anthropic model (supports medium)
const anthropicData: CompactionRequestData = {
Expand Down
8 changes: 5 additions & 3 deletions src/browser/utils/messages/compactionOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
*/

import { readPersistedState } from "@/browser/hooks/usePersistedState";
import { toGatewayModel } from "@/browser/hooks/useGatewayModels";
import { AGENT_AI_DEFAULTS_KEY } from "@/common/constants/storage";
import type { SendMessageOptions } from "@/common/orpc/types";
import type { CompactionRequestData } from "@/common/types/message";
Expand All @@ -29,8 +28,11 @@ export function applyCompactionOverrides(
baseOptions: SendMessageOptions,
compactData: CompactionRequestData
): SendMessageOptions {
// Apply gateway transformation - compactData.model is raw, baseOptions.model is already transformed.
const compactionModel = compactData.model ? toGatewayModel(compactData.model) : baseOptions.model;
const compactionModelOverride = compactData.model?.trim();
const compactionModel =
compactionModelOverride === undefined || compactionModelOverride === ""
? baseOptions.model
: compactionModelOverride;

const agentAiDefaults = readPersistedState<AgentAiDefaults>(AGENT_AI_DEFAULTS_KEY, {});
const preferredThinking = agentAiDefaults.compact?.thinkingLevel;
Expand Down
Loading
Loading