Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/browser/stores/WorkspaceStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ export class WorkspaceStore {
data: WorkspaceChatMessage
) => void
> = {
"stream-pending": (workspaceId, aggregator, data) => {
aggregator.handleStreamPending(data as never);
if (this.onModelUsed) {
this.onModelUsed((data as { model: string }).model);
}
this.states.bump(workspaceId);
// Bump usage store so liveUsage can show the current model even before streaming starts
this.usageStore.bump(workspaceId);
},
"stream-start": (workspaceId, aggregator, data) => {
aggregator.handleStreamStart(data as never);
if (this.onModelUsed) {
Expand Down Expand Up @@ -484,7 +493,7 @@ export class WorkspaceStore {
name: metadata?.name ?? workspaceId, // Fall back to ID if metadata missing
messages: aggregator.getDisplayedMessages(),
queuedMessage: this.queuedMessages.get(workspaceId) ?? null,
canInterrupt: activeStreams.length > 0,
canInterrupt: activeStreams.length > 0 || aggregator.hasInFlightStreams(),
isCompacting: aggregator.isCompacting(),
awaitingUserQuestion: aggregator.hasAwaitingUserQuestion(),
loading: !hasMessages && !isCaughtUp,
Expand Down Expand Up @@ -969,7 +978,8 @@ export class WorkspaceStore {
// Check if there's an active stream in buffered events (reconnection scenario)
const pendingEvents = this.pendingStreamEvents.get(workspaceId) ?? [];
const hasActiveStream = pendingEvents.some(
(event) => "type" in event && event.type === "stream-start"
(event) =>
"type" in event && (event.type === "stream-start" || event.type === "stream-pending")
);

// Load historical messages first
Expand Down
135 changes: 92 additions & 43 deletions src/browser/utils/messages/StreamingMessageAggregator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
} from "@/common/types/message";
import { createMuxMessage } from "@/common/types/message";
import type {
StreamPendingEvent,
StreamStartEvent,
StreamDeltaEvent,
UsageDeltaEvent,
Expand Down Expand Up @@ -52,6 +53,17 @@ interface StreamingContext {
model: string;
}

/**
 * Tracking state for one in-flight stream, discriminated by `phase`.
 *
 * - "pending": no streaming context exists yet; only the model and a timestamp are kept.
 * - "active": the stream carries its full StreamingContext.
 */
type InFlightStreamState =
| {
phase: "pending";
// Millisecond timestamp captured when the pending entry was stored.
pendingAt: number;
// Model reported for the pending stream.
model: string;
}
| {
phase: "active";
// Full per-stream context; only available in the active phase.
context: StreamingContext;
};

/**
* Check if a tool result indicates success (for tools that return { success: boolean })
*/
Expand Down Expand Up @@ -136,7 +148,9 @@ function mergeAdjacentParts(parts: MuxMessage["parts"]): MuxMessage["parts"] {

export class StreamingMessageAggregator {
private messages = new Map<string, MuxMessage>();
private activeStreams = new Map<string, StreamingContext>();

// Streams that are in-flight (pending: `stream-pending` received; active: `stream-start` received).
private inFlightStreams = new Map<string, InFlightStreamState>();

// Simple cache for derived values (invalidated on every mutation)
private cachedAllMessages: MuxMessage[] | null = null;
Expand Down Expand Up @@ -336,14 +350,14 @@ export class StreamingMessageAggregator {
* Called by handleStreamEnd, handleStreamAbort, and handleStreamError.
*
* Clears:
* - Active stream tracking (this.activeStreams)
* - In-flight stream tracking (this.inFlightStreams)
* - Current TODOs (this.currentTodos) - reconstructed from history on reload
*
* Does NOT clear:
* - agentStatus - persists after stream completion to show last activity
*/
private cleanupStreamState(messageId: string): void {
this.activeStreams.delete(messageId);
this.inFlightStreams.delete(messageId);
// Clear todos when stream ends - they're stream-scoped state
// On reload, todos will be reconstructed from completed tool_write calls in history
this.currentTodos = [];
Expand Down Expand Up @@ -461,21 +475,31 @@ export class StreamingMessageAggregator {
this.pendingStreamStartTime = time;
}

/**
 * Whether any stream is currently tracked, regardless of phase
 * (pending and active entries both count).
 */
hasInFlightStreams(): boolean {
  const trackedCount = this.inFlightStreams.size;
  return trackedCount !== 0;
}
/**
 * Collect the streaming contexts of every in-flight stream in the "active" phase.
 *
 * Pending entries have no context and are skipped.
 *
 * @returns A new array of contexts; empty when no stream is active.
 */
getActiveStreams(): StreamingContext[] {
  const active: StreamingContext[] = [];
  for (const stream of this.inFlightStreams.values()) {
    // Only the "active" variant carries a context.
    if (stream.phase === "active") active.push(stream.context);
  }
  return active;
}

/**
 * Get the messageId of the first active stream (for token tracking).
 *
 * Entries still in the "pending" phase are skipped.
 *
 * @returns The messageId of the first active stream, or undefined if no streams are active.
 */
getActiveStreamMessageId(): string | undefined {
  for (const [messageId, stream] of this.inFlightStreams.entries()) {
    if (stream.phase === "active") return messageId;
  }
  return undefined;
}

isCompacting(): boolean {
for (const context of this.activeStreams.values()) {
if (context.isCompacting) {
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "active" && stream.context.isCompacting) {
return true;
}
}
Expand All @@ -484,8 +508,13 @@ export class StreamingMessageAggregator {

getCurrentModel(): string | undefined {
// If there's an active stream, return its model
for (const context of this.activeStreams.values()) {
return context.model;
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "active") return stream.context.model;
}

// If we're pending (stream-pending), return that model
for (const stream of this.inFlightStreams.values()) {
if (stream.phase === "pending") return stream.model;
}

// Otherwise, return the model from the most recent assistant message
Expand All @@ -501,12 +530,14 @@ export class StreamingMessageAggregator {
}

/**
 * Drop all stream tracking without touching stored messages.
 *
 * Empties the stream maps, resets the pending stream start time to null,
 * and invalidates the cache of derived values.
 */
clearActiveStreams(): void {
// NOTE(review): `activeStreams` looks like the older map that `inFlightStreams` supersedes — confirm this line is still needed.
this.activeStreams.clear();
this.setPendingStreamStartTime(null);
this.inFlightStreams.clear();
this.invalidateCache();
}

/**
 * Reset the aggregator: remove all stored messages, empty the stream maps,
 * and invalidate the cache of derived values.
 */
clear(): void {
this.messages.clear();
// NOTE(review): `activeStreams` looks like the older map that `inFlightStreams` supersedes — confirm this line is still needed.
this.activeStreams.clear();
this.inFlightStreams.clear();
this.invalidateCache();
}

Expand All @@ -529,8 +560,24 @@ export class StreamingMessageAggregator {
}

// Unified event handlers that encapsulate all complex logic
/**
 * Handle a `stream-pending` event by recording the stream as in-flight in the "pending" phase.
 *
 * The pending stream start timestamp is always cleared. If the same messageId is
 * already tracked as "active", the entry is left untouched and the cache is not
 * invalidated, so a late pending event never downgrades an active stream.
 *
 * @param data - Pending event carrying the stream's messageId and model.
 */
handleStreamPending(data: StreamPendingEvent): void {
  // Clear pending stream start timestamp - backend has accepted the request.
  this.setPendingStreamStartTime(null);

  const { messageId } = data;
  const tracked = this.inFlightStreams.get(messageId);
  if (tracked !== undefined && tracked.phase === "active") {
    return;
  }

  const pendingState: InFlightStreamState = {
    phase: "pending",
    pendingAt: Date.now(),
    model: data.model,
  };
  this.inFlightStreams.set(messageId, pendingState);

  this.invalidateCache();
}

handleStreamStart(data: StreamStartEvent): void {
// Clear pending stream start timestamp - stream has started
// Clear pending stream start timestamp - stream has started.
this.setPendingStreamStartTime(null);

// NOTE: We do NOT clear agentStatus or currentTodos here.
Expand All @@ -551,7 +598,7 @@ export class StreamingMessageAggregator {

// Use messageId as key - ensures only ONE stream per message
// If called twice (e.g., during replay), second call safely overwrites first
this.activeStreams.set(data.messageId, context);
this.inFlightStreams.set(data.messageId, { phase: "active", context });

// Create initial streaming message with empty parts (deltas will append)
const streamingMessage = createMuxMessage(data.messageId, "assistant", "", {
Expand Down Expand Up @@ -583,7 +630,8 @@ export class StreamingMessageAggregator {

handleStreamEnd(data: StreamEndEvent): void {
// Direct lookup by messageId - O(1) instead of O(n) find
const activeStream = this.activeStreams.get(data.messageId);
const stream = this.inFlightStreams.get(data.messageId);
const activeStream = stream?.phase === "active" ? stream.context : undefined;

if (activeStream) {
// Normal streaming case: we've been tracking this stream from the start
Expand Down Expand Up @@ -650,7 +698,8 @@ export class StreamingMessageAggregator {

handleStreamAbort(data: StreamAbortEvent): void {
// Direct lookup by messageId
const activeStream = this.activeStreams.get(data.messageId);
const stream = this.inFlightStreams.get(data.messageId);
const activeStream = stream?.phase === "active" ? stream.context : undefined;

if (activeStream) {
// Mark the message as interrupted and merge metadata (consistent with handleStreamEnd)
Expand All @@ -673,10 +722,9 @@ export class StreamingMessageAggregator {
}

handleStreamError(data: StreamErrorMessage): void {
// Direct lookup by messageId
const activeStream = this.activeStreams.get(data.messageId);
const isTrackedStream = this.inFlightStreams.has(data.messageId);

if (activeStream) {
if (isTrackedStream) {
// Mark the message with error metadata
const message = this.messages.get(data.messageId);
if (message?.metadata) {
Expand All @@ -688,32 +736,33 @@ export class StreamingMessageAggregator {
this.compactMessageParts(message);
}

// Clean up stream-scoped state (active stream tracking, TODOs)
// Clean up stream-scoped state (active/connecting tracking, TODOs)
this.cleanupStreamState(data.messageId);
this.invalidateCache();
} else {
// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no active stream to attach to
// Get the highest historySequence from existing messages so this appears at the end
const maxSequence = Math.max(
0,
...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
);
const errorMessage: MuxMessage = {
id: data.messageId,
role: "assistant",
parts: [],
metadata: {
partial: true,
error: data.error,
errorType: data.errorType,
timestamp: Date.now(),
historySequence: maxSequence + 1,
},
};
this.messages.set(data.messageId, errorMessage);
this.invalidateCache();
return;
}

// Pre-stream error (e.g., API key not configured before streaming starts)
// Create a synthetic error message since there's no tracked stream to attach to.
// Get the highest historySequence from existing messages so this appears at the end.
const maxSequence = Math.max(
0,
...Array.from(this.messages.values()).map((m) => m.metadata?.historySequence ?? 0)
);
const errorMessage: MuxMessage = {
id: data.messageId,
role: "assistant",
parts: [],
metadata: {
partial: true,
error: data.error,
errorType: data.errorType,
timestamp: Date.now(),
historySequence: maxSequence + 1,
},
};
this.messages.set(data.messageId, errorMessage);
this.invalidateCache();
}

handleToolCallStart(data: ToolCallStartEvent): void {
Expand Down Expand Up @@ -844,7 +893,7 @@ export class StreamingMessageAggregator {

/**
 * Handle a reasoning-end event. The event payload is unused; the only effect
 * is invalidating the cache of derived values.
 */
handleReasoningEnd(_data: ReasoningEndEvent): void {
// Reasoning-end is just a signal - no state to update
// Streaming status is inferred from inFlightStreams in getDisplayedMessages
this.invalidateCache();
}

Expand Down Expand Up @@ -1035,7 +1084,7 @@ export class StreamingMessageAggregator {

// Check if this message has an active stream (for inferring streaming status)
// Direct Map.has() check - O(1) instead of O(n) iteration
const hasActiveStream = this.activeStreams.has(message.id);
const hasActiveStream = this.inFlightStreams.get(message.id)?.phase === "active";

// Merge adjacent text/reasoning parts for display
const mergedParts = mergeAdjacentParts(message.parts);
Expand Down
29 changes: 29 additions & 0 deletions src/browser/utils/messages/retryEligibility.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,35 @@ describe("hasInterruptedStream", () => {
expect(hasInterruptedStream(messages, null)).toBe(true);
});

it("returns false when pendingStreamStartTime is null but last user message timestamp is recent (replay/reload)", () => {
  // The only message is a user message sent 500ms inside the grace period.
  const recentTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS - 500);
  const history: DisplayedMessage[] = [
    {
      type: "user",
      id: "user-1",
      historyId: "user-1",
      content: "Hello",
      historySequence: 1,
      timestamp: recentTimestamp,
    },
  ];

  const interrupted = hasInterruptedStream(history, null);
  expect(interrupted).toBe(false);
});

it("returns true when pendingStreamStartTime is null and last user message timestamp is old (replay/reload)", () => {
  // The only message is a user message sent one second past the grace period.
  const staleTimestamp = Date.now() - (PENDING_STREAM_START_GRACE_PERIOD_MS + 1000);
  const history: DisplayedMessage[] = [
    {
      type: "user",
      id: "user-1",
      historyId: "user-1",
      content: "Hello",
      historySequence: 1,
      timestamp: staleTimestamp,
    },
  ];

  const interrupted = hasInterruptedStream(history, null);
  expect(interrupted).toBe(true);
});
it("returns false when user message just sent (within grace period)", () => {
const messages: DisplayedMessage[] = [
{
Expand Down
20 changes: 13 additions & 7 deletions src/browser/utils/messages/retryEligibility.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,22 @@ export function hasInterruptedStream(
): boolean {
if (messages.length === 0) return false;

// Don't show retry barrier if user message was sent very recently (within the grace period)
// This prevents flash during normal send flow while stream-start event arrives
// After the grace period, assume something is wrong and show the barrier
if (pendingStreamStartTime !== null) {
const elapsed = Date.now() - pendingStreamStartTime;
const lastMessage = messages[messages.length - 1];

// Don't show retry barrier if the last user message was sent very recently (within the grace period).
//
// We prefer the explicit pendingStreamStartTime (set during the live send flow).
// But during history replay / app reload, pendingStreamStartTime can be null even when the last
// message is a fresh user message. In that case, fall back to the user message timestamp.
const graceStartTime =
pendingStreamStartTime ??
(lastMessage.type === "user" ? (lastMessage.timestamp ?? null) : null);

if (graceStartTime !== null) {
const elapsed = Date.now() - graceStartTime;
if (elapsed < PENDING_STREAM_START_GRACE_PERIOD_MS) return false;
}

const lastMessage = messages[messages.length - 1];

// ask_user_question is a special case: an unfinished tool call represents an
// intentional "waiting for user input" state, not a stream interruption.
//
Expand Down
1 change: 1 addition & 0 deletions src/common/orpc/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export {
StreamDeltaEventSchema,
StreamEndEventSchema,
StreamErrorMessageSchema,
StreamPendingEventSchema,
StreamStartEventSchema,
ToolCallDeltaEventSchema,
ToolCallEndEventSchema,
Expand Down
Loading
Loading