diff --git a/src/browser/stores/WorkspaceStore.ts b/src/browser/stores/WorkspaceStore.ts
index 8b217422d6..55c6df41bd 100644
--- a/src/browser/stores/WorkspaceStore.ts
+++ b/src/browser/stores/WorkspaceStore.ts
@@ -161,6 +161,15 @@ export class WorkspaceStore {
       data: WorkspaceChatMessage
     ) => void
   > = {
+    "stream-pending": (workspaceId, aggregator, data) => {
+      aggregator.handleStreamPending(data as never);
+      if (this.onModelUsed) {
+        this.onModelUsed((data as { model: string }).model);
+      }
+      this.states.bump(workspaceId);
+      // Bump usage store so liveUsage can show the current model even before streaming starts
+      this.usageStore.bump(workspaceId);
+    },
     "stream-start": (workspaceId, aggregator, data) => {
       aggregator.handleStreamStart(data as never);
       if (this.onModelUsed) {
@@ -484,7 +493,7 @@ export class WorkspaceStore {
       name: metadata?.name ?? workspaceId, // Fall back to ID if metadata missing
       messages: aggregator.getDisplayedMessages(),
       queuedMessage: this.queuedMessages.get(workspaceId) ?? null,
-      canInterrupt: activeStreams.length > 0,
+      canInterrupt: activeStreams.length > 0 || aggregator.hasInFlightStreams(),
       isCompacting: aggregator.isCompacting(),
       awaitingUserQuestion: aggregator.hasAwaitingUserQuestion(),
       loading: !hasMessages && !isCaughtUp,
@@ -969,7 +978,8 @@ export class WorkspaceStore {
     // Check if there's an active stream in buffered events (reconnection scenario)
     const pendingEvents = this.pendingStreamEvents.get(workspaceId) ?? [];
     const hasActiveStream = pendingEvents.some(
-      (event) => "type" in event && event.type === "stream-start"
+      (event) =>
+        "type" in event && (event.type === "stream-start" || event.type === "stream-pending")
     );

     // Load historical messages first
diff --git a/src/browser/utils/messages/StreamingMessageAggregator.ts b/src/browser/utils/messages/StreamingMessageAggregator.ts
index 4d70a00ebb..85c121b614 100644
--- a/src/browser/utils/messages/StreamingMessageAggregator.ts
+++ b/src/browser/utils/messages/StreamingMessageAggregator.ts
@@ -7,6 +7,7 @@ import type {
 } from "@/common/types/message";
 import { createMuxMessage } from "@/common/types/message";
 import type {
+  StreamPendingEvent,
   StreamStartEvent,
   StreamDeltaEvent,
   UsageDeltaEvent,
@@ -52,6 +53,17 @@ interface StreamingContext {
   model: string;
 }

+type InFlightStreamState =
+  | {
+      phase: "pending";
+      pendingAt: number;
+      model: string;
+    }
+  | {
+      phase: "active";
+      context: StreamingContext;
+    };
+
 /**
  * Check if a tool result indicates success (for tools that return { success: boolean })
  */
@@ -136,7 +148,9 @@ function mergeAdjacentParts(parts: MuxMessage["parts"]): MuxMessage["parts"] {

 export class StreamingMessageAggregator {
   private messages = new Map<string, MuxMessage>();
-  private activeStreams = new Map<string, StreamingContext>();
+
+  // Streams that are in-flight (pending: `stream-pending` received; active: `stream-start` received).
+  private inFlightStreams = new Map<string, InFlightStreamState>();

   // Simple cache for derived values (invalidated on every mutation)
   private cachedAllMessages: MuxMessage[] | null = null;
@@ -336,14 +350,14 @@ export class StreamingMessageAggregator {
    * Called by handleStreamEnd, handleStreamAbort, and handleStreamError.
    *
    * Clears:
-   * - Active stream tracking (this.activeStreams)
+   * - In-flight stream tracking (this.inFlightStreams)
    * - Current TODOs (this.currentTodos) - reconstructed from history on reload
    *
    * Does NOT clear:
    * - agentStatus - persists after stream completion to show last activity
    */
   private cleanupStreamState(messageId: string): void {
-    this.activeStreams.delete(messageId);
+    this.inFlightStreams.delete(messageId);
     // Clear todos when stream ends - they're stream-scoped state
     // On reload, todos will be reconstructed from completed tool_write calls in history
     this.currentTodos = [];
@@ -461,8 +475,15 @@
     this.pendingStreamStartTime = time;
   }

+  hasInFlightStreams(): boolean {
+    return this.inFlightStreams.size > 0;
+  }
   getActiveStreams(): StreamingContext[] {
-    return Array.from(this.activeStreams.values());
+    const active: StreamingContext[] = [];
+    for (const stream of this.inFlightStreams.values()) {
+      if (stream.phase === "active") active.push(stream.context);
+    }
+    return active;
   }

   /**
@@ -470,12 +491,15 @@
    * Returns undefined if no streams are active
    */
   getActiveStreamMessageId(): string | undefined {
-    return this.activeStreams.keys().next().value;
+    for (const [messageId, stream] of this.inFlightStreams.entries()) {
+      if (stream.phase === "active") return messageId;
+    }
+    return undefined;
   }

   isCompacting(): boolean {
-    for (const context of this.activeStreams.values()) {
-      if (context.isCompacting) {
+    for (const stream of this.inFlightStreams.values()) {
+      if (stream.phase === "active" && stream.context.isCompacting) {
         return true;
       }
     }
@@ -484,8 +508,13 @@

   getCurrentModel(): string | undefined {
     // If there's an active stream, return its model
-    for (const context of this.activeStreams.values()) {
-      return context.model;
+    for (const stream of this.inFlightStreams.values()) {
+      if (stream.phase === "active") return stream.context.model;
+    }
+
+    // If we're pending (stream-pending), return that model
+    for (const stream of this.inFlightStreams.values()) {
+      if (stream.phase === "pending") return stream.model;
     }

     // Otherwise, return the model from the most recent assistant message
@@ -501,12 +530,14 @@
   }

   clearActiveStreams(): void {
-    this.activeStreams.clear();
+    this.setPendingStreamStartTime(null);
+    this.inFlightStreams.clear();
+    this.invalidateCache();
   }

   clear(): void {
     this.messages.clear();
-    this.activeStreams.clear();
+    this.inFlightStreams.clear();
     this.invalidateCache();
   }

@@ -529,8 +560,24 @@
   }

   // Unified event handlers that encapsulate all complex logic
+  handleStreamPending(data: StreamPendingEvent): void {
+    // Clear pending stream start timestamp - backend has accepted the request.
+    this.setPendingStreamStartTime(null);
+
+    const existing = this.inFlightStreams.get(data.messageId);
+    if (existing?.phase === "active") return;
+
+    this.inFlightStreams.set(data.messageId, {
+      phase: "pending",
+      pendingAt: Date.now(),
+      model: data.model,
+    });
+
+    this.invalidateCache();
+  }
+
   handleStreamStart(data: StreamStartEvent): void {
-    // Clear pending stream start timestamp - stream has started
+    // Clear pending stream start timestamp - stream has started.
     this.setPendingStreamStartTime(null);

     // NOTE: We do NOT clear agentStatus or currentTodos here.
@@ -551,7 +598,7 @@
     // Use messageId as key - ensures only ONE stream per message
     // If called twice (e.g., during replay), second call safely overwrites first
-    this.activeStreams.set(data.messageId, context);
+    this.inFlightStreams.set(data.messageId, { phase: "active", context });

     // Create initial streaming message with empty parts (deltas will append)
     const streamingMessage = createMuxMessage(data.messageId, "assistant", "", {
@@ -583,7 +630,8 @@

   handleStreamEnd(data: StreamEndEvent): void {
     // Direct lookup by messageId - O(1) instead of O(n) find
-    const activeStream = this.activeStreams.get(data.messageId);
+    const stream = this.inFlightStreams.get(data.messageId);
+    const activeStream = stream?.phase === "active" ? stream.context : undefined;

     if (activeStream) {
       // Normal streaming case: we've been tracking this stream from the start
@@ -650,7 +698,8 @@

   handleStreamAbort(data: StreamAbortEvent): void {
     // Direct lookup by messageId
-    const activeStream = this.activeStreams.get(data.messageId);
+    const stream = this.inFlightStreams.get(data.messageId);
+    const activeStream = stream?.phase === "active" ? stream.context : undefined;

     if (activeStream) {
       // Mark the message as interrupted and merge metadata (consistent with handleStreamEnd)
@@ -673,10 +722,9 @@
   }

   handleStreamError(data: StreamErrorMessage): void {
-    // Direct lookup by messageId
-    const activeStream = this.activeStreams.get(data.messageId);
+    const isTrackedStream = this.inFlightStreams.has(data.messageId);

-    if (activeStream) {
+    if (isTrackedStream) {
       // Mark the message with error metadata
       const message = this.messages.get(data.messageId);
       if (message?.metadata) {
@@ -688,32 +736,33 @@
       this.compactMessageParts(message);
     }

-      // Clean up stream-scoped state (active stream tracking, TODOs)
+      // Clean up stream-scoped state (active/pending tracking, TODOs)
       this.cleanupStreamState(data.messageId);
       this.invalidateCache();
-    } else {
-      // Pre-stream error (e.g., API key not configured before streaming starts)
-      // Create a synthetic error message since there's no active stream to attach to
-      // Get the highest historySequence from existing messages so this appears at the end
-      const maxSequence = Math.max(
-        0,
-        ...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
-      );
-      const errorMessage: MuxMessage = {
-        id: data.messageId,
-        role: "assistant",
-        parts: [],
-        metadata: {
-          partial: true,
-          error: data.error,
-          errorType: data.errorType,
-          timestamp: Date.now(),
-          historySequence: maxSequence + 1,
-        },
-      };
-      this.messages.set(data.messageId, errorMessage);
-      this.invalidateCache();
+      return;
     }
+
+    // Pre-stream error (e.g., API key not configured before streaming starts)
+    // Create a synthetic error message since there's no tracked stream to attach to.
+    // Get the highest historySequence from existing messages so this appears at the end.
+    const maxSequence = Math.max(
+      0,
+      ...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
+    );
+    const errorMessage: MuxMessage = {
+      id: data.messageId,
+      role: "assistant",
+      parts: [],
+      metadata: {
+        partial: true,
+        error: data.error,
+        errorType: data.errorType,
+        timestamp: Date.now(),
+        historySequence: maxSequence + 1,
+      },
+    };
+    this.messages.set(data.messageId, errorMessage);
+    this.invalidateCache();
   }

   handleToolCallStart(data: ToolCallStartEvent): void {
@@ -844,7 +893,7 @@

   handleReasoningEnd(_data: ReasoningEndEvent): void {
     // Reasoning-end is just a signal - no state to update
-    // Streaming status is inferred from activeStreams in getDisplayedMessages
+    // Streaming status is inferred from inFlightStreams in getDisplayedMessages
     this.invalidateCache();
   }
@@ -1035,7 +1084,7 @@
       // Check if this message has an active stream (for inferring streaming status)
       // Direct Map.has() check - O(1) instead of O(n) iteration
-      const hasActiveStream = this.activeStreams.has(message.id);
+      const hasActiveStream = this.inFlightStreams.get(message.id)?.phase === "active";

       // Merge adjacent text/reasoning parts for display
       const mergedParts = mergeAdjacentParts(message.parts);
diff --git a/src/browser/utils/messages/retryEligibility.test.ts b/src/browser/utils/messages/retryEligibility.test.ts
index 6924d182e4..bebe979a29 100644
--- a/src/browser/utils/messages/retryEligibility.test.ts
+++ b/src/browser/utils/messages/retryEligibility.test.ts
@@ -239,6 +239,35 @@ describe("hasInterruptedStream", () => {
     expect(hasInterruptedStream(messages, null)).toBe(true);
   });

+  it("returns false when pendingStreamStartTime is null but last user message timestamp is recent (replay/reload)", () => {
+    const justSentTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS - 500);
+    const messages: DisplayedMessage[] = [
+      {
+        type: "user",
+        id: "user-1",
+        historyId: "user-1",
+        content: "Hello",
+        historySequence: 1,
+        timestamp: justSentTimestamp,
+      },
+    ];
+    expect(hasInterruptedStream(messages, null)).toBe(false);
+  });
+
+  it("returns true when pendingStreamStartTime is null and last user message timestamp is old (replay/reload)", () => {
+    const longAgoTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS + 1000);
+    const messages: DisplayedMessage[] = [
+      {
+        type: "user",
+        id: "user-1",
+        historyId: "user-1",
+        content: "Hello",
+        historySequence: 1,
+        timestamp: longAgoTimestamp,
+      },
+    ];
+    expect(hasInterruptedStream(messages, null)).toBe(true);
+  });

   it("returns false when user message just sent (within grace period)", () => {
     const messages: DisplayedMessage[] = [
       {
diff --git a/src/browser/utils/messages/retryEligibility.ts b/src/browser/utils/messages/retryEligibility.ts
index c4084f6c83..3be1df4e6c 100644
--- a/src/browser/utils/messages/retryEligibility.ts
+++ b/src/browser/utils/messages/retryEligibility.ts
@@ -80,16 +80,22 @@ export function hasInterruptedStream(
 ): boolean {
   if (messages.length === 0) return false;

-  // Don't show retry barrier if user message was sent very recently (within the grace period)
-  // This prevents flash during normal send flow while stream-start event arrives
-  // After the grace period, assume something is wrong and show the barrier
-  if (pendingStreamStartTime !== null) {
-    const elapsed = Date.now() - pendingStreamStartTime;
+  const lastMessage = messages[messages.length - 1];
+
+  // Don't show retry barrier if the last user message was sent very recently (within the grace period).
+  //
+  // We prefer the explicit pendingStreamStartTime (set during the live send flow).
+  // But during history replay / app reload, pendingStreamStartTime can be null even when the last
+  // message is a fresh user message. In that case, fall back to the user message timestamp.
+  const graceStartTime =
+    pendingStreamStartTime ??
+    (lastMessage.type === "user" ? (lastMessage.timestamp ?? null) : null);
+
+  if (graceStartTime !== null) {
+    const elapsed = Date.now() - graceStartTime;
     if (elapsed < PENDING_STREAM_START_GRACE_PERIOD_MS) return false;
   }

-  const lastMessage = messages[messages.length - 1];
-
   // ask_user_question is a special case: an unfinished tool call represents an
   // intentional "waiting for user input" state, not a stream interruption.
   //
diff --git a/src/common/orpc/schemas.ts b/src/common/orpc/schemas.ts
index 53a58f1f9f..bdc99c7aef 100644
--- a/src/common/orpc/schemas.ts
+++ b/src/common/orpc/schemas.ts
@@ -88,6 +88,7 @@ export {
   StreamDeltaEventSchema,
   StreamEndEventSchema,
   StreamErrorMessageSchema,
+  StreamPendingEventSchema,
   StreamStartEventSchema,
   ToolCallDeltaEventSchema,
   ToolCallEndEventSchema,
diff --git a/src/common/orpc/schemas/stream.ts b/src/common/orpc/schemas/stream.ts
index 8a39a5a4a3..334ac87e64 100644
--- a/src/common/orpc/schemas/stream.ts
+++ b/src/common/orpc/schemas/stream.ts
@@ -27,6 +27,17 @@ export const DeleteMessageSchema = z.object({
   historySequences: z.array(z.number()),
 });

+// Emitted when a stream has been registered and is abortable, but before streaming begins.
+// This prevents RetryBarrier flash during slow provider connection/setup.
+export const StreamPendingEventSchema = z.object({
+  type: z.literal("stream-pending"),
+  workspaceId: z.string(),
+  messageId: z.string(),
+  model: z.string(),
+  historySequence: z.number().meta({
+    description: "Backend assigns global message ordering",
+  }),
+});
 export const StreamStartEventSchema = z.object({
   type: z.literal("stream-start"),
   workspaceId: z.string(),
@@ -273,6 +284,7 @@ export const WorkspaceChatMessageSchema = z.discriminatedUnion("type", [
   CaughtUpMessageSchema,
   StreamErrorMessageSchema,
   DeleteMessageSchema,
+  StreamPendingEventSchema,
   StreamStartEventSchema,
   StreamDeltaEventSchema,
   StreamEndEventSchema,
diff --git a/src/common/orpc/types.ts b/src/common/orpc/types.ts
index 51a7759141..05a131f7fd 100644
--- a/src/common/orpc/types.ts
+++ b/src/common/orpc/types.ts
@@ -2,6 +2,7 @@ import type { z } from "zod";
 import type * as schemas from "./schemas";

 import type {
+  StreamPendingEvent,
   StreamStartEvent,
   StreamDeltaEvent,
   StreamEndEvent,
@@ -43,6 +44,10 @@ export function isStreamError(msg: WorkspaceChatMessage): msg is StreamErrorMess
   return (msg as { type?: string }).type === "stream-error";
 }

+export function isStreamPending(msg: WorkspaceChatMessage): msg is StreamPendingEvent {
+  return (msg as { type?: string }).type === "stream-pending";
+}
+
 export function isDeleteMessage(msg: WorkspaceChatMessage): msg is DeleteMessage {
   return (msg as { type?: string }).type === "delete";
 }
diff --git a/src/common/types/stream.ts b/src/common/types/stream.ts
index 2d11a8034f..b053231207 100644
--- a/src/common/types/stream.ts
+++ b/src/common/types/stream.ts
@@ -11,6 +11,7 @@ import type {
   StreamAbortEventSchema,
   StreamDeltaEventSchema,
   StreamEndEventSchema,
+  StreamPendingEventSchema,
   StreamStartEventSchema,
   ToolCallDeltaEventSchema,
   ToolCallEndEventSchema,
@@ -22,6 +23,7 @@ import type {
  * Completed message part (reasoning, text, or tool) suitable for serialization
 * Used in StreamEndEvent and partial message storage
 */
+export type StreamPendingEvent = z.infer<typeof StreamPendingEventSchema>;
 export type CompletedMessagePart = MuxReasoningPart | MuxTextPart | MuxToolPart;

 export type StreamStartEvent = z.infer<typeof StreamStartEventSchema>;
@@ -45,6 +47,7 @@ export type ReasoningEndEvent = z.infer<typeof ReasoningEndEventSchema>;
 export type UsageDeltaEvent = z.infer<typeof UsageDeltaEventSchema>;

 export type AIServiceEvent =
+  | StreamPendingEvent
   | StreamStartEvent
   | StreamDeltaEvent
   | StreamEndEvent
diff --git a/src/common/utils/streamLifecycle.ts b/src/common/utils/streamLifecycle.ts
new file mode 100644
index 0000000000..8bf4717eb3
--- /dev/null
+++ b/src/common/utils/streamLifecycle.ts
@@ -0,0 +1,42 @@
+// Stream lifecycle events are emitted during an in-flight assistant response.
+//
+// Keeping the event list centralized makes it harder to accidentally forget to forward/buffer a
+// newly introduced lifecycle event.
+
+export const STREAM_LIFECYCLE_EVENTS = [
+  "stream-pending",
+  "stream-start",
+  "stream-delta",
+  "stream-abort",
+  "stream-end",
+] as const;
+
+export type StreamLifecycleEventName = (typeof STREAM_LIFECYCLE_EVENTS)[number];
+
+// Events that can be forwarded 1:1 from StreamManager -> AIService.
+// (`stream-abort` needs additional bookkeeping in AIService.)
+export const STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD = [
+  "stream-pending",
+  "stream-start",
+  "stream-delta",
+  "stream-end",
+] as const satisfies readonly StreamLifecycleEventName[];
+
+// Events that can be forwarded 1:1 from AIService -> AgentSession -> renderer.
+// (`stream-end` has additional session-side behavior.)
+export const STREAM_LIFECYCLE_EVENTS_SIMPLE_FORWARD = [
+  "stream-pending",
+  "stream-start",
+  "stream-delta",
+  "stream-abort",
+] as const satisfies readonly StreamLifecycleEventName[];
+
+export function forwardStreamLifecycleEvents(params: {
+  events: readonly StreamLifecycleEventName[];
+  listen: (event: StreamLifecycleEventName, handler: (payload: unknown) => void) => void;
+  emit: (event: StreamLifecycleEventName, payload: unknown) => void;
+}): void {
+  for (const event of params.events) {
+    params.listen(event, (payload) => params.emit(event, payload));
+  }
+}
diff --git a/src/node/services/agentSession.ts b/src/node/services/agentSession.ts
index 360148de1c..f89c9a6f9e 100644
--- a/src/node/services/agentSession.ts
+++ b/src/node/services/agentSession.ts
@@ -35,6 +35,10 @@ import { AttachmentService } from "./attachmentService";
 import type { PostCompactionAttachment, PostCompactionExclusions } from "@/common/types/attachment";
 import { TURNS_BETWEEN_ATTACHMENTS } from "@/common/constants/attachments";
 import { extractEditedFileDiffs } from "@/common/utils/messages/extractEditedFiles";
+import {
+  forwardStreamLifecycleEvents,
+  STREAM_LIFECYCLE_EVENTS_SIMPLE_FORWARD,
+} from "@/common/utils/streamLifecycle";
 import { isValidModelFormat } from "@/common/utils/ai/models";

 /**
@@ -619,8 +623,15 @@
       this.aiService.on(event, wrapped as never);
     };

-    forward("stream-start", (payload) => this.emitChatEvent(payload));
-    forward("stream-delta", (payload) => this.emitChatEvent(payload));
+    forwardStreamLifecycleEvents({
+      events: STREAM_LIFECYCLE_EVENTS_SIMPLE_FORWARD,
+      listen: (event, handler) => {
+        forward(event, handler);
+      },
+      emit: (_event, payload) => {
+        this.emitChatEvent(payload as WorkspaceChatMessage);
+      },
+    });
     forward("tool-call-start", (payload) => this.emitChatEvent(payload));
     forward("tool-call-delta", (payload) => this.emitChatEvent(payload));
     forward("tool-call-end", (payload) => {
@@ -643,7 +654,6 @@ export class AgentSession {
forward("reasoning-delta", (payload) => this.emitChatEvent(payload)); forward("reasoning-end", (payload) => this.emitChatEvent(payload)); forward("usage-delta", (payload) => this.emitChatEvent(payload)); - forward("stream-abort", (payload) => this.emitChatEvent(payload)); forward("stream-end", async (payload) => { const handled = await this.compactionHandler.handleCompletion(payload as StreamEndEvent); diff --git a/src/node/services/aiService.ts b/src/node/services/aiService.ts index 5c230b9dc8..c001d0a575 100644 --- a/src/node/services/aiService.ts +++ b/src/node/services/aiService.ts @@ -58,6 +58,10 @@ import { applyToolPolicy, type ToolPolicy } from "@/common/utils/tools/toolPolic import { MockScenarioPlayer } from "./mock/mockScenarioPlayer"; import { EnvHttpProxyAgent, type Dispatcher } from "undici"; import { getPlanFilePath } from "@/common/utils/planStorage"; +import { + forwardStreamLifecycleEvents, + STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD, +} from "@/common/utils/streamLifecycle"; import { getPlanModeInstruction } from "@/common/utils/ui/modeUtils"; import type { UIMode } from "@/common/types/mode"; import { MUX_APP_ATTRIBUTION_TITLE, MUX_APP_ATTRIBUTION_URL } from "@/constants/appAttribution"; @@ -348,9 +352,15 @@ export class AIService extends EventEmitter { * Forward all stream events from StreamManager to AIService consumers */ private setupStreamEventForwarding(): void { - this.streamManager.on("stream-start", (data) => this.emit("stream-start", data)); - this.streamManager.on("stream-delta", (data) => this.emit("stream-delta", data)); - this.streamManager.on("stream-end", (data) => this.emit("stream-end", data)); + forwardStreamLifecycleEvents({ + events: STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD, + listen: (event, handler) => { + this.streamManager.on(event, handler); + }, + emit: (event, payload) => { + this.emit(event, payload); + }, + }); // Handle stream-abort: dispose of partial based on abandonPartial flag this.streamManager.on("stream-abort", (data: StreamAbortEvent) => { @@ -1404,6 +1414,7 @@ export class AIService extends EventEmitter { // Delegate to StreamManager with model instance, system message, tools, historySequence, and initial metadata const streamResult = await this.streamManager.startStream( workspaceId, + assistantMessageId, finalMessages, modelResult.data, modelString, diff --git a/src/node/services/mock/mockScenarioPlayer.ts b/src/node/services/mock/mockScenarioPlayer.ts index 606e38b278..ee62621e68 100644 --- a/src/node/services/mock/mockScenarioPlayer.ts +++ b/src/node/services/mock/mockScenarioPlayer.ts @@ -266,6 +266,15 @@ export class MockScenarioPlayer { ): Promise { switch (event.kind) { case "stream-start": { + // Mirror real runtime: emit stream-pending before stream-start + this.deps.aiService.emit("stream-pending", { + type: "stream-pending", + workspaceId, + messageId, + model: event.model, + historySequence, + }); + const payload: StreamStartEvent = { type: "stream-start", workspaceId, diff --git a/src/node/services/streamManager.test.ts b/src/node/services/streamManager.test.ts index 756932f0d2..9a42b9eb67 100644 --- a/src/node/services/streamManager.test.ts +++ b/src/node/services/streamManager.test.ts @@ -84,6 +84,7 @@ describe("StreamManager - Concurrent Stream Prevention", () => { // Start first stream const result1 = await streamManager.startStream( workspaceId, + "assistant-1", [{ role: "user", content: "Say hello and nothing else" }], model, KNOWN_MODELS.SONNET.id, @@ -102,6 +103,7 @@ describe("StreamManager - Concurrent Stream 
Prevention", () => { // Start second stream - should cancel first const result2 = await streamManager.startStream( workspaceId, + "assistant-2", [{ role: "user", content: "Say goodbye and nothing else" }], model, KNOWN_MODELS.SONNET.id, @@ -164,6 +166,11 @@ describe("StreamManager - Concurrent Stream Prevention", () => { } const ensureStreamSafetyValue = Reflect.get(streamManager, "ensureStreamSafety") as unknown; + const pendingMessageIds: string[] = []; + streamManager.on("stream-pending", (event) => { + pendingMessageIds.push((event as { messageId: string }).messageId); + }); + if (typeof ensureStreamSafetyValue !== "function") { throw new Error("StreamManager.ensureStreamSafety is unavailable for testing"); } @@ -199,7 +206,10 @@ describe("StreamManager - Concurrent Stream Prevention", () => { "createStreamAtomically", ( wsId: string, + assistantMessageId: string, streamToken: string, + _runtimeTempDir: string, + _runtime: unknown, messages: unknown, modelArg: unknown, modelString: string, @@ -228,7 +238,7 @@ describe("StreamManager - Concurrent Stream Prevention", () => { providerMetadata: Promise.resolve(undefined), }, abortController, - messageId: `test-${Math.random().toString(36).slice(2)}`, + messageId: assistantMessageId, token: streamToken, startTime: Date.now(), model: modelString, @@ -274,6 +284,7 @@ describe("StreamManager - Concurrent Stream Prevention", () => { const promises = [ streamManager.startStream( workspaceId, + "assistant-mutex-1", [{ role: "user", content: "test 1" }], model, KNOWN_MODELS.SONNET.id, @@ -285,6 +296,7 @@ describe("StreamManager - Concurrent Stream Prevention", () => { ), streamManager.startStream( workspaceId, + "assistant-mutex-2", [{ role: "user", content: "test 2" }], model, KNOWN_MODELS.SONNET.id, @@ -296,6 +308,7 @@ describe("StreamManager - Concurrent Stream Prevention", () => { ), streamManager.startStream( workspaceId, + "assistant-mutex-3", [{ role: "user", content: "test 3" }], model, KNOWN_MODELS.SONNET.id, @@ -317,6 +330,12 @@ describe("StreamManager - Concurrent Stream Prevention", () => { expect(ensureOperations[i]).toBe("ensure-start"); expect(ensureOperations[i + 1]).toBe("ensure-end"); } + + expect(pendingMessageIds).toEqual([ + "assistant-mutex-1", + "assistant-mutex-2", + "assistant-mutex-3", + ]); }); }); diff --git a/src/node/services/streamManager.ts b/src/node/services/streamManager.ts index 44f40dc918..55e4d5db07 100644 --- a/src/node/services/streamManager.ts +++ b/src/node/services/streamManager.ts @@ -16,6 +16,7 @@ import type { Result } from "@/common/types/result"; import { Ok, Err } from "@/common/types/result"; import { log } from "./log"; import type { + StreamPendingEvent, StreamStartEvent, StreamEndEvent, UsageDeltaEvent, @@ -599,6 +600,7 @@ export class StreamManager extends EventEmitter { */ private createStreamAtomically( workspaceId: WorkspaceId, + assistantMessageId: string, streamToken: StreamToken, runtimeTempDir: string, runtime: Runtime, @@ -683,7 +685,7 @@ export class StreamManager extends EventEmitter { throw error; } - const messageId = `assistant-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const messageId = assistantMessageId; const streamInfo: WorkspaceStreamInfo = { state: StreamState.STARTING, @@ -1379,6 +1381,7 @@ export class StreamManager extends EventEmitter { */ async startStream( workspaceId: string, + assistantMessageId: string, messages: ModelMessage[], model: LanguageModel, modelString: string, @@ -1425,6 +1428,7 @@ export class StreamManager extends EventEmitter { // 
     const streamInfo = this.createStreamAtomically(
       typedWorkspaceId,
+      assistantMessageId,
       streamToken,
       runtimeTempDir,
       runtime,
@@ -1441,6 +1445,16 @@
       toolPolicy
     );

+    // Emit stream-pending as soon as the stream is registered and abortable.
+    // This lets the frontend suppress RetryBarrier while providers are slow to produce stream-start.
+    this.emit("stream-pending", {
+      type: "stream-pending",
+      workspaceId,
+      messageId: streamInfo.messageId,
+      model: streamInfo.model,
+      historySequence,
+    } satisfies StreamPendingEvent);
+
     // Step 5: Track the processing promise for guaranteed cleanup
     // This allows cancelStreamSafely to wait for full exit
     streamInfo.processingPromise = this.processStreamWithCleanup(
diff --git a/tests/e2e/scenarios/basicChat.spec.ts b/tests/e2e/scenarios/basicChat.spec.ts
index 595754284e..049d5a1c6d 100644
--- a/tests/e2e/scenarios/basicChat.spec.ts
+++ b/tests/e2e/scenarios/basicChat.spec.ts
@@ -17,7 +17,8 @@ test("basic chat streaming flow", async ({ ui }) => {
   expect(timeline.events.length).toBeGreaterThan(0);

   const eventTypes = timeline.events.map((event) => event.type);
-  expect(eventTypes[0]).toBe("stream-start");
+  expect(eventTypes[0]).toBe("stream-pending");
+  expect(eventTypes.indexOf("stream-start")).toBeGreaterThan(0);
   const deltaCount = eventTypes.filter((type) => type === "stream-delta").length;
   expect(deltaCount).toBeGreaterThan(1);
   expect(eventTypes[eventTypes.length - 1]).toBe("stream-end");
diff --git a/tests/e2e/scenarios/permissionModes.spec.ts b/tests/e2e/scenarios/permissionModes.spec.ts
index 39f1951422..41c8d05ca6 100644
--- a/tests/e2e/scenarios/permissionModes.spec.ts
+++ b/tests/e2e/scenarios/permissionModes.spec.ts
@@ -20,7 +20,8 @@ test.describe("permission mode behavior", () => {
     }

     const eventTypes = timeline.events.map((event) => event.type);
-    expect(eventTypes[0]).toBe("stream-start");
+    expect(eventTypes[0]).toBe("stream-pending");
+    expect(eventTypes.indexOf("stream-start")).toBeGreaterThan(0);
     expect(eventTypes[eventTypes.length - 1]).toBe("stream-end");
     expect(eventTypes.includes("tool-call-start")).toBe(false);
     expect(eventTypes.includes("tool-call-end")).toBe(false);
diff --git a/tests/e2e/scenarios/toolFlows.spec.ts b/tests/e2e/scenarios/toolFlows.spec.ts
index 5e1ed07bca..8f2306b5f3 100644
--- a/tests/e2e/scenarios/toolFlows.spec.ts
+++ b/tests/e2e/scenarios/toolFlows.spec.ts
@@ -121,7 +121,10 @@ test.describe("tool and reasoning flows", () => {
     if (finalTimeline.events.length === 0) {
       throw new Error("Recall turn produced no events");
     }
-    expect(finalTimeline.events[0]?.type).toBe("stream-start");
+    expect(finalTimeline.events[0]?.type).toBe("stream-pending");
+    expect(
+      finalTimeline.events.findIndex((event) => event.type === "stream-start")
+    ).toBeGreaterThan(0);
     expect(finalTimeline.events.some((event) => event.type === "tool-call-start")).toBeFalsy();

     await ui.chat.expectTranscriptContains("contains the line 'hello'");
diff --git a/tests/ipc/mcpConfig.test.ts b/tests/ipc/mcpConfig.test.ts
index 03ac00bb49..2b293631fb 100644
--- a/tests/ipc/mcpConfig.test.ts
+++ b/tests/ipc/mcpConfig.test.ts
@@ -17,6 +17,10 @@ import {
   extractTextFromEvents,
   HAIKU_MODEL,
 } from "./helpers";
+import { configureTestRetries } from "./sendMessageTestHelpers";
+
+configureTestRetries(3);
+
 import type { StreamCollector } from "./streamCollector";

 const describeIntegration = shouldRunIntegrationTests() ? describe : describe.skip;
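---

Reviewer notes: the sketches below are illustrative only and are not part of the patch. Types and constants are re-declared locally as simplified stand-ins for the real ones (src/common/types/stream.ts, src/common/orpc/types.ts), so each snippet is self-contained and runnable on its own.

1) Event ordering. With this change the lifecycle is stream-pending -> stream-start -> stream-delta* -> stream-end, so a renderer-side consumer can surface the model as soon as the backend has registered the stream, before the first token arrives:

    // Simplified stand-in for WorkspaceChatMessage (assumed shape).
    type ChatEvent =
      | { type: "stream-pending"; messageId: string; model: string }
      | { type: "stream-start"; messageId: string; model: string }
      | { type: "stream-delta"; messageId: string; delta: string }
      | { type: "stream-end"; messageId: string };

    // Mirrors the isStreamPending type guard added in src/common/orpc/types.ts.
    function isStreamPending(msg: ChatEvent): msg is Extract<ChatEvent, { type: "stream-pending" }> {
      return msg.type === "stream-pending";
    }

    function onChatEvent(msg: ChatEvent): void {
      if (isStreamPending(msg)) {
        // Show "waiting on <model>" in the UI instead of flashing a retry barrier.
        console.log(`waiting on ${msg.model} for ${msg.messageId}`);
      }
    }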
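2) Phase transitions. InFlightStreamState is a two-phase union; the invariant visible in handleStreamPending is that a late or replayed stream-pending event must never downgrade a stream that is already active, while stream-start always overwrites. A minimal model of those two transitions:

    interface StreamingContext {
      model: string;
    }

    type InFlightStreamState =
      | { phase: "pending"; pendingAt: number; model: string }
      | { phase: "active"; context: StreamingContext };

    const inFlightStreams = new Map<string, InFlightStreamState>();

    // stream-pending: record the model, but never downgrade an active stream.
    function markPending(messageId: string, model: string): void {
      if (inFlightStreams.get(messageId)?.phase === "active") return;
      inFlightStreams.set(messageId, { phase: "pending", pendingAt: Date.now(), model });
    }

    // stream-start: always wins, replacing whatever phase was recorded before.
    function markActive(messageId: string, model: string): void {
      inFlightStreams.set(messageId, { phase: "active", context: { model } });
    }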
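3) Grace-period fallback. hasInterruptedStream previously relied only on pendingStreamStartTime, which is null after a reload, so a freshly sent user message could flash the retry barrier. The new check condenses to the sketch below (the grace-period value is assumed here; the real constant lives alongside retryEligibility.ts):

    const PENDING_STREAM_START_GRACE_PERIOD_MS = 3_000; // assumed value for illustration

    function withinGracePeriod(
      pendingStreamStartTime: number | null,
      lastUserMessageTimestamp: number | undefined
    ): boolean {
      // Prefer the explicit timestamp from the live send flow; fall back to the
      // last user message's timestamp during history replay / app reload.
      const graceStartTime = pendingStreamStartTime ?? lastUserMessageTimestamp ?? null;
      return (
        graceStartTime !== null &&
        Date.now() - graceStartTime < PENDING_STREAM_START_GRACE_PERIOD_MS
      );
    }

    withinGracePeriod(null, Date.now() - 500); // true: just sent, suppress the barrier
    withinGracePeriod(null, Date.now() - 60_000); // false: stale, show the barrier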
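4) Forwarding helper. forwardStreamLifecycleEvents is a typed fan-out loop; wiring two plain Node EventEmitters (stand-ins for StreamManager and AIService) shows the shape and why a newly added lifecycle event only needs to be appended to one list:

    import { EventEmitter } from "node:events";

    const STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD = [
      "stream-pending",
      "stream-start",
      "stream-delta",
      "stream-end",
    ] as const;
    type StreamLifecycleEventName = (typeof STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD)[number];

    function forwardStreamLifecycleEvents(params: {
      events: readonly StreamLifecycleEventName[];
      listen: (event: StreamLifecycleEventName, handler: (payload: unknown) => void) => void;
      emit: (event: StreamLifecycleEventName, payload: unknown) => void;
    }): void {
      for (const event of params.events) {
        params.listen(event, (payload) => params.emit(event, payload));
      }
    }

    const streamManager = new EventEmitter(); // stand-in for StreamManager
    const aiService = new EventEmitter(); // stand-in for AIService

    forwardStreamLifecycleEvents({
      events: STREAM_LIFECYCLE_EVENTS_DIRECT_FORWARD,
      listen: (event, handler) => streamManager.on(event, handler),
      emit: (event, payload) => aiService.emit(event, payload),
    });

    aiService.on("stream-pending", (payload) => console.log("forwarded:", payload));
    streamManager.emit("stream-pending", { messageId: "assistant-1" });
    // -> forwarded: { messageId: 'assistant-1' }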