From 9c8e01da5ec381c760a8d6fde85d4d5a5a4e3bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 15:54:38 +0200 Subject: [PATCH 1/7] Add user stop handling and stop metadata Refactor and extend abort/stop behavior to support targeted user stops. Introduces userStopRequest/plan/result types and handleUserStop/executeUserStopPlan to resolve and execute room-wide, active-turn, or queued-turn stops. Replaces direct abortRoom calls with handleUserStop in command and message handlers. Adds assistantStopMetadata and propagates it into streamingState and UI message metadata (including response status mapping). Tracks room run targets (source/initial events) and binds streaming state to room runs. Implements queue operations to drain or remove pending items by source event and finalizes stopped queue items, preserving ACK reaction removal and session notifications. Adjusts streaming finish logic to treat cancelled vs stop reasons appropriately. Includes unit tests for plan resolution, queued removal, and metadata emission. --- bridges/ai/abort_helpers.go | 207 +++++++++++++++++++++++-- bridges/ai/abort_helpers_test.go | 148 ++++++++++++++++++ bridges/ai/commands_parity.go | 10 +- bridges/ai/handlematrix.go | 11 +- bridges/ai/pending_queue.go | 70 +++++++++ bridges/ai/room_runs.go | 60 +++++++ bridges/ai/streaming_error_handling.go | 12 +- bridges/ai/streaming_init.go | 1 + bridges/ai/streaming_state.go | 2 + bridges/ai/turn_data.go | 3 + bridges/ai/ui_message_metadata.go | 11 ++ 11 files changed, 512 insertions(+), 23 deletions(-) create mode 100644 bridges/ai/abort_helpers_test.go diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 45607429..e5890f70 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -3,31 +3,204 @@ package ai import ( "context" "fmt" + "strings" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" ) -func formatAbortNotice(stopped int) string { - if stopped <= 0 { - return "Agent was aborted." +type stopPlanKind string + +const ( + stopPlanKindNoMatch stopPlanKind = "no-match" + stopPlanKindRoomWide stopPlanKind = "room-wide" + stopPlanKindActive stopPlanKind = "active-turn" + stopPlanKindQueued stopPlanKind = "queued-turn" +) + +type userStopRequest struct { + Portal *bridgev2.Portal + Meta *PortalMetadata + ReplyTo id.EventID + RequestedByEventID id.EventID + RequestedVia string +} + +type userStopPlan struct { + Kind stopPlanKind + Scope string + TargetKind string + TargetEventID id.EventID +} + +type userStopResult struct { + Plan userStopPlan + ActiveStopped bool + QueuedStopped int + SubagentsStopped int +} + +func stopLabel(count int, singular string) string { + if count == 1 { + return singular + } + return singular + "s" +} + +func formatAbortNotice(result userStopResult) string { + switch result.Plan.Kind { + case stopPlanKindNoMatch: + return "No matching active or queued turn found for that reply." + case stopPlanKindActive: + if result.SubagentsStopped > 0 { + return fmt.Sprintf("Stopped that turn. Stopped %d %s.", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent")) + } + return "Stopped that turn." + case stopPlanKindQueued: + if result.QueuedStopped <= 1 { + return "Stopped that queued turn." + } + return fmt.Sprintf("Stopped %d queued %s.", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn")) + case stopPlanKindRoomWide: + parts := make([]string, 0, 3) + if result.ActiveStopped { + parts = append(parts, "stopped the active turn") + } + if result.QueuedStopped > 0 { + parts = append(parts, fmt.Sprintf("removed %d queued %s", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn"))) + } + if result.SubagentsStopped > 0 { + parts = append(parts, fmt.Sprintf("stopped %d %s", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent"))) + } + if len(parts) == 0 { + return "No active or queued turns to stop." + } + suffix := "" + if len(parts) > 1 { + suffix = " " + strings.Join(parts[1:], ". ") + "." + } + return strings.ToUpper(parts[0][:1]) + parts[0][1:] + "." + suffix + default: + return "No active or queued turns to stop." + } +} + +func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { + if oc == nil || roomID == "" || sourceEventID == "" { + return false + } + queue := oc.getQueueSnapshot(roomID) + if queue == nil { + return false + } + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + return true + } + } + return false +} + +func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { + return &assistantStopMetadata{ + Reason: "user_stop", + Scope: plan.Scope, + TargetKind: plan.TargetKind, + TargetEventID: plan.TargetEventID.String(), + RequestedByEventID: req.RequestedByEventID.String(), + RequestedVia: strings.TrimSpace(req.RequestedVia), + } +} + +func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan { + if req.Portal == nil || req.Portal.MXID == "" { + return userStopPlan{Kind: stopPlanKindNoMatch} + } + if req.ReplyTo == "" { + return userStopPlan{ + Kind: stopPlanKindRoomWide, + Scope: "room", + TargetKind: "all", + } + } + + _, sourceEventID, initialEventID, _ := oc.roomRunTarget(req.Portal.MXID) + if initialEventID != "" && req.ReplyTo == initialEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "placeholder_event", + TargetEventID: req.ReplyTo, + } + } + if sourceEventID != "" && req.ReplyTo == sourceEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, + } } - label := "sub-agents" - if stopped == 1 { - label = "sub-agent" + if oc.pendingQueueHasSourceEvent(req.Portal.MXID, req.ReplyTo) { + return userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, + } + } + return userStopPlan{ + Kind: stopPlanKindNoMatch, + Scope: "turn", + TargetEventID: req.ReplyTo, } - return fmt.Sprintf("Agent was aborted. Stopped %d %s.", stopped, label) } -func (oc *AIClient) abortRoom(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int { - if portal == nil { - return 0 +func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int { + for _, item := range items { + if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) + } + oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.") } - oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) - stopped := oc.stopSubagentRuns(portal.MXID) - if meta != nil { - meta.AbortedLastRun = true - oc.savePortalQuiet(ctx, portal, "abort") + return len(items) +} + +func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest, plan userStopPlan) userStopResult { + result := userStopResult{Plan: plan} + if req.Portal == nil || req.Portal.MXID == "" { + return result } - return stopped + roomID := req.Portal.MXID + switch plan.Kind { + case stopPlanKindRoomWide: + if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + result.ActiveStopped = oc.cancelRoomRun(roomID) + } + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) + result.SubagentsStopped = oc.stopSubagentRuns(roomID) + case stopPlanKindActive: + if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + result.ActiveStopped = oc.cancelRoomRun(roomID) + if result.ActiveStopped { + result.SubagentsStopped = oc.stopSubagentRuns(roomID) + } + } + case stopPlanKindQueued: + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) + } + + if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) { + req.Meta.AbortedLastRun = true + oc.savePortalQuiet(ctx, req.Portal, "stop") + } + if req.Meta != nil && result.QueuedStopped > 0 { + oc.notifySessionMutation(ctx, req.Portal, req.Meta, false) + } + return result +} + +func (oc *AIClient) handleUserStop(ctx context.Context, req userStopRequest) userStopResult { + plan := oc.resolveUserStopPlan(req) + return oc.executeUserStopPlan(ctx, req, plan) } diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go new file mode 100644 index 00000000..ce350392 --- /dev/null +++ b/bridges/ai/abort_helpers_test.go @@ -0,0 +1,148 @@ +package ai + +import ( + "context" + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/id" + + bridgesdk "github.com/beeper/agentremote/sdk" +) + +func TestResolveUserStopPlanRoomWideWithoutReply(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + req := userStopRequest{Portal: portal, RequestedVia: "command"} + + plan := oc.resolveUserStopPlan(req) + if plan.Kind != stopPlanKindRoomWide { + t.Fatalf("expected room-wide stop, got %#v", plan) + } + if plan.TargetKind != "all" || plan.Scope != "room" { + t.Fatalf("unexpected room-wide stop plan: %#v", plan) + } +} + +func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + initialEvent: id.EventID("$assistant"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + placeholderPlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$assistant"), + }) + if placeholderPlan.Kind != stopPlanKindActive || placeholderPlan.TargetKind != "placeholder_event" { + t.Fatalf("expected placeholder-targeted active stop, got %#v", placeholderPlan) + } + + sourcePlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }) + if sourcePlan.Kind != stopPlanKindActive || sourcePlan.TargetKind != "source_event" { + t.Fatalf("expected source-targeted active stop, got %#v", sourcePlan) + } +} + +func TestResolveUserStopPlanMatchesQueuedReplyTarget(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{{ + pending: pendingMessage{SourceEventID: id.EventID("$queued")}, + }}, + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + plan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$queued"), + }) + if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" { + t.Fatalf("expected queued stop plan, got %#v", plan) + } +} + +func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{ + {pending: pendingMessage{SourceEventID: id.EventID("$one")}}, + {pending: pendingMessage{SourceEventID: id.EventID("$two")}}, + }, + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$one"), + }) + if result.QueuedStopped != 1 { + t.Fatalf("expected one queued turn to stop, got %#v", result) + } + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil || len(snapshot.items) != 1 { + t.Fatalf("expected one queued item to remain, got %#v", snapshot) + } + if got := snapshot.items[0].pending.sourceEventID(); got != id.EventID("$two") { + t.Fatalf("expected remaining queued event $two, got %q", got) + } +} + +func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { + oc := &AIClient{} + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"}) + turn.SetID("turn-stop") + state := &streamingState{ + turn: turn, + finishReason: "stop", + stop: &assistantStopMetadata{ + Reason: "user_stop", + Scope: "turn", + TargetKind: "source_event", + TargetEventID: "$user", + RequestedByEventID: "$stop", + RequestedVia: "command", + }, + responseID: "resp_123", + completedAtMs: 1, + } + + ui := oc.buildStreamUIMessage(state, nil, nil) + metadata, ok := ui["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected metadata map, got %T", ui["metadata"]) + } + stop, ok := metadata["stop"].(map[string]any) + if !ok { + t.Fatalf("expected nested stop metadata, got %#v", metadata["stop"]) + } + if stop["reason"] != "user_stop" || stop["requested_via"] != "command" { + t.Fatalf("unexpected stop metadata: %#v", stop) + } + if metadata["response_status"] != "cancelled" { + t.Fatalf("expected cancelled response status for stopped turn, got %#v", metadata["response_status"]) + } +} diff --git a/bridges/ai/commands_parity.go b/bridges/ai/commands_parity.go index bafb22de..cdbaf9de 100644 --- a/bridges/ai/commands_parity.go +++ b/bridges/ai/commands_parity.go @@ -65,6 +65,12 @@ func fnStop(ce *commands.Event) { if !ok { return } - stopped := client.abortRoom(ce.Ctx, ce.Portal, meta) - ce.Reply("%s", formatAbortNotice(stopped)) + result := client.handleUserStop(ce.Ctx, userStopRequest{ + Portal: ce.Portal, + Meta: meta, + ReplyTo: ce.ReplyTo, + RequestedByEventID: ce.EventID, + RequestedVia: "command", + }) + ce.Reply("%s", formatAbortNotice(result)) } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 4604d084..d1864069 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -135,8 +135,15 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri return &bridgev2.MatrixMessageResponse{Pending: false}, nil } if commandAuthorized && airuntime.IsAbortTriggerText(commandBody) { - stopped := oc.abortRoom(ctx, portal, meta) - oc.sendSystemNotice(ctx, portal, formatAbortNotice(stopped)) + replyCtx := extractInboundReplyContext(msg.Event) + result := oc.handleUserStop(ctx, userStopRequest{ + Portal: portal, + Meta: meta, + ReplyTo: replyCtx.ReplyTo, + RequestedByEventID: msg.Event.ID, + RequestedVia: "text-trigger", + }) + oc.sendSystemNotice(ctx, portal, formatAbortNotice(result)) logCtx.Debug().Msg("Abort trigger handled") return &bridgev2.MatrixMessageResponse{Pending: false}, nil } diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index aee69c52..ac02e3f4 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -38,6 +38,16 @@ type pendingQueue struct { lastItem *pendingQueueItem } +func (pm pendingMessage) sourceEventID() id.EventID { + if pm.SourceEventID != "" { + return pm.SourceEventID + } + if pm.Event != nil { + return pm.Event.ID + } + return "" +} + type pendingQueueDispatchCandidate struct { items []pendingQueueItem summaryPrompt string @@ -85,6 +95,66 @@ func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { } } +func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { + if oc == nil || roomID == "" { + return nil + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + delete(oc.pendingQueues, roomID) + oc.pendingQueuesMu.Unlock() + + queue.mu.Lock() + items := slices.Clone(queue.items) + queue.items = nil + queue.summaryLines = nil + queue.droppedCount = 0 + queue.lastItem = nil + queue.mu.Unlock() + + oc.stopQueueTyping(roomID) + return items +} + +func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { + if oc == nil || roomID == "" || sourceEventID == "" { + return nil + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + queue.mu.Lock() + removed := make([]pendingQueueItem, 0, 1) + kept := queue.items[:0] + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + removed = append(removed, item) + continue + } + kept = append(kept, item) + } + clear(queue.items[len(kept):]) + queue.items = kept + empty := len(queue.items) == 0 && queue.droppedCount == 0 + if empty { + delete(oc.pendingQueues, roomID) + } + queue.mu.Unlock() + oc.pendingQueuesMu.Unlock() + + if empty { + oc.stopQueueTyping(roomID) + } + return removed +} + func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, settings airuntime.QueueSettings) bool { queue := oc.getPendingQueue(roomID, settings) if queue == nil { diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 64071164..35083e76 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -13,6 +13,11 @@ type roomRunState struct { cancel context.CancelFunc mu sync.Mutex + state *streamingState + stop *assistantStopMetadata + turnID string + sourceEvent id.EventID + initialEvent id.EventID streaming bool steerQueue []pendingQueueItem statusEvents []*event.Event @@ -97,6 +102,61 @@ func (oc *AIClient) markRoomRunStreaming(roomID id.RoomID, streaming bool) { run.mu.Unlock() } +func (oc *AIClient) bindRoomRunState(roomID id.RoomID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return + } + run.mu.Lock() + run.state = state + if run.stop != nil && state != nil { + state.stop = run.stop + } + if state != nil && state.turn != nil { + run.turnID = state.turn.ID() + run.sourceEvent = state.sourceEventID() + run.initialEvent = state.turn.InitialEventID() + } + run.mu.Unlock() +} + +func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventID, initialEventID id.EventID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return "", "", "", nil + } + run.mu.Lock() + defer run.mu.Unlock() + state = run.state + turnID = run.turnID + sourceEventID = run.sourceEvent + initialEventID = run.initialEvent + if state == nil || state.turn == nil { + return turnID, sourceEventID, initialEventID, state + } + turnID = state.turn.ID() + sourceEventID = state.sourceEventID() + initialEventID = state.turn.InitialEventID() + run.turnID = turnID + run.sourceEvent = sourceEventID + run.initialEvent = initialEventID + return turnID, sourceEventID, initialEventID, state +} + +func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMetadata) bool { + run := oc.getRoomRun(roomID) + if run == nil || stop == nil { + return false + } + run.mu.Lock() + run.stop = stop + if run.state != nil { + run.state.stop = stop + } + run.mu.Unlock() + return true +} + func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) bool { run := oc.getRoomRun(roomID) if run == nil { diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 14a848fa..0ea6a9e5 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -40,6 +40,9 @@ func (oc *AIClient) finishStreamingWithFailure( reason string, err error, ) error { + if state != nil && state.stop != nil && reason == "cancelled" { + reason = "stop" + } state.finishReason = reason state.completedAtMs = time.Now().UnixMilli() _ = log @@ -47,12 +50,17 @@ func (oc *AIClient) finishStreamingWithFailure( if writer := state.writer(); writer != nil { writer.MessageMetadata(ctx, oc.buildUIMessageMetadata(state, meta, true)) } - if reason == "cancelled" { + switch reason { + case "cancelled": state.writer().Abort(ctx, "cancelled") if state != nil && state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) } - } else { + case "stop": + if state != nil && state.turn != nil { + state.turn.End(msgconv.MapFinishReason(reason)) + } + default: if state != nil && state.turn != nil { state.turn.EndWithError(err.Error()) } diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index 1833ec07..7f929cfd 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -128,6 +128,7 @@ func (oc *AIClient) prepareStreamingRun( // Create SDK Turn for writer/emitter/session management. turn := oc.createStreamingTurn(ctx, portal, meta, state, sourceEventID, senderID) state.turn = turn + oc.bindRoomRunState(roomID, state) state.replyTarget = oc.resolveInitialReplyTarget(evt) if state.replyTarget.ThreadRoot != "" { diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index 18990d95..f4ba5951 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -69,6 +69,8 @@ type streamingState struct { // Pending MCP approvals to resolve before the turn can continue. pendingMcpApprovals []mcpApprovalRequest pendingMcpApprovalsSeen map[string]bool + + stop *assistantStopMetadata } // sourceEventID returns the triggering user message event ID from the turn's source ref. diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index 6448b6da..bd5e0020 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -72,6 +72,9 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } + if state.stop != nil { + return "cancelled" + } switch strings.TrimSpace(state.finishReason) { case "", "stop": diff --git a/bridges/ai/ui_message_metadata.go b/bridges/ai/ui_message_metadata.go index 5c888330..e96b7fe5 100644 --- a/bridges/ai/ui_message_metadata.go +++ b/bridges/ai/ui_message_metadata.go @@ -13,6 +13,15 @@ type assistantUsageMetadata struct { TotalTokens int64 `json:"total_tokens,omitempty"` } +type assistantStopMetadata struct { + Reason string `json:"reason,omitempty"` + Scope string `json:"scope,omitempty"` + TargetKind string `json:"target_kind,omitempty"` + TargetEventID string `json:"target_event_id,omitempty"` + RequestedByEventID string `json:"requested_by_event_id,omitempty"` + RequestedVia string `json:"requested_via,omitempty"` +} + type assistantTurnMetadata struct { TurnID string `json:"turn_id,omitempty"` AgentID string `json:"agent_id,omitempty"` @@ -28,6 +37,7 @@ type assistantTurnMetadata struct { SourceEventID string `json:"source_event_id,omitempty"` GeneratedFileRefs []GeneratedFileRef `json:"generated_file_refs,omitempty"` Usage *assistantUsageMetadata `json:"usage,omitempty"` + Stop *assistantStopMetadata `json:"stop,omitempty"` } func buildAssistantUsageMetadata(state *streamingState) *assistantUsageMetadata { @@ -70,5 +80,6 @@ func buildAssistantTurnMetadata(state *streamingState, turnID, networkMessageID, SourceEventID: state.sourceEventID().String(), GeneratedFileRefs: agentremote.GeneratedFileRefsFromParts(state.generatedFiles), Usage: buildAssistantUsageMetadata(state), + Stop: state.stop, }) } From a691f3cd3485e65dfa7e620776622b20ec28893d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:04:30 +0200 Subject: [PATCH 2/7] Refactor queue checks, abort text, and streaming errors Move pendingQueueHasSourceEvent into pending_queue.go and implement proper locking (pendingQueuesMu and queue.mu) to safely inspect queue items. Simplify drainPendingQueue to delete the queue map entry and return its items directly. Remove the duplicate helper from abort_helpers.go. Improve formatAbortNotice by capitalizing each sentence part and joining them with ". " for clearer messages. Remove redundant run field assignments in roomRunTarget. Adjust finishStreamingWithFailure to fall through from "cancelled" to "stop" so cancelled streams call End like stop cases and remove some redundant nil checks. These changes tidy concurrency handling, clarify abort messaging, and simplify streaming error handling. --- bridges/ai/abort_helpers.go | 22 +++--------------- bridges/ai/pending_queue.go | 32 +++++++++++++++++++------- bridges/ai/room_runs.go | 3 --- bridges/ai/streaming_error_handling.go | 8 +++---- 4 files changed, 30 insertions(+), 35 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index e5890f70..545209f1 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -75,31 +75,15 @@ func formatAbortNotice(result userStopResult) string { if len(parts) == 0 { return "No active or queued turns to stop." } - suffix := "" - if len(parts) > 1 { - suffix = " " + strings.Join(parts[1:], ". ") + "." + for i := range parts { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] } - return strings.ToUpper(parts[0][:1]) + parts[0][1:] + "." + suffix + return strings.Join(parts, ". ") + "." default: return "No active or queued turns to stop." } } -func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { - if oc == nil || roomID == "" || sourceEventID == "" { - return false - } - queue := oc.getQueueSnapshot(roomID) - if queue == nil { - return false - } - for _, item := range queue.items { - if item.pending.sourceEventID() == sourceEventID { - return true - } - } - return false -} func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { return &assistantStopMetadata{ diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index ac02e3f4..b252ff48 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -106,20 +106,36 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return nil } delete(oc.pendingQueues, roomID) + items := queue.items oc.pendingQueuesMu.Unlock() - queue.mu.Lock() - items := slices.Clone(queue.items) - queue.items = nil - queue.summaryLines = nil - queue.droppedCount = 0 - queue.lastItem = nil - queue.mu.Unlock() - oc.stopQueueTyping(roomID) return items } +func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { + if oc == nil || roomID == "" || sourceEventID == "" { + return false + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return false + } + queue.mu.Lock() + found := false + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + found = true + break + } + } + queue.mu.Unlock() + oc.pendingQueuesMu.Unlock() + return found +} + func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { if oc == nil || roomID == "" || sourceEventID == "" { return nil diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 35083e76..a12f5af6 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -137,9 +137,6 @@ func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventI turnID = state.turn.ID() sourceEventID = state.sourceEventID() initialEventID = state.turn.InitialEventID() - run.turnID = turnID - run.sourceEvent = sourceEventID - run.initialEvent = initialEventID return turnID, sourceEventID, initialEventID, state } diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 0ea6a9e5..6f949bdb 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -53,15 +53,13 @@ func (oc *AIClient) finishStreamingWithFailure( switch reason { case "cancelled": state.writer().Abort(ctx, "cancelled") - if state != nil && state.turn != nil { - state.turn.End(msgconv.MapFinishReason(reason)) - } + fallthrough case "stop": - if state != nil && state.turn != nil { + if state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) } default: - if state != nil && state.turn != nil { + if state.turn != nil { state.turn.EndWithError(err.Error()) } } From 5d7e0b2773454a95d55964a61aca455ac7e01aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:16:40 +0200 Subject: [PATCH 3/7] Make stop metadata atomic and thread-safe Change streamingState.stop to an atomic.Pointer[assistantStopMetadata] and update all callsites to use .Load()/.Store() to avoid races. Fix pending queue drain to lock the queue when accessing items to prevent concurrent access. Improve room run logic to prefer current state.turn when present and store stop metadata atomically when marking a run stopped. Use utf8 + unicode for correct Unicode-aware capitalization in abort notices and update tests to store stop metadata via the new atomic API. --- bridges/ai/abort_helpers.go | 6 ++++-- bridges/ai/abort_helpers_test.go | 20 ++++++++++---------- bridges/ai/pending_queue.go | 2 ++ bridges/ai/room_runs.go | 18 ++++++------------ bridges/ai/streaming_error_handling.go | 2 +- bridges/ai/streaming_state.go | 3 ++- bridges/ai/turn_data.go | 2 +- bridges/ai/ui_message_metadata.go | 2 +- 8 files changed, 27 insertions(+), 28 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 545209f1..d88261ac 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "strings" + "unicode" + "unicode/utf8" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" @@ -76,7 +78,8 @@ func formatAbortNotice(result userStopResult) string { return "No active or queued turns to stop." } for i := range parts { - parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + r, size := utf8.DecodeRuneInString(parts[i]) + parts[i] = string(unicode.ToUpper(r)) + parts[i][size:] } return strings.Join(parts, ". ") + "." default: @@ -84,7 +87,6 @@ func formatAbortNotice(result userStopResult) string { } } - func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { return &assistantStopMetadata{ Reason: "user_stop", diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index ce350392..ff2994de 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -116,19 +116,19 @@ func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"}) turn.SetID("turn-stop") state := &streamingState{ - turn: turn, - finishReason: "stop", - stop: &assistantStopMetadata{ - Reason: "user_stop", - Scope: "turn", - TargetKind: "source_event", - TargetEventID: "$user", - RequestedByEventID: "$stop", - RequestedVia: "command", - }, + turn: turn, + finishReason: "stop", responseID: "resp_123", completedAtMs: 1, } + state.stop.Store(&assistantStopMetadata{ + Reason: "user_stop", + Scope: "turn", + TargetKind: "source_event", + TargetEventID: "$user", + RequestedByEventID: "$stop", + RequestedVia: "command", + }) ui := oc.buildStreamUIMessage(state, nil, nil) metadata, ok := ui["metadata"].(map[string]any) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index b252ff48..2b4f9f35 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -106,7 +106,9 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return nil } delete(oc.pendingQueues, roomID) + queue.mu.Lock() items := queue.items + queue.mu.Unlock() oc.pendingQueuesMu.Unlock() oc.stopQueueTyping(roomID) diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index a12f5af6..274463f0 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -110,7 +110,7 @@ func (oc *AIClient) bindRoomRunState(roomID id.RoomID, state *streamingState) { run.mu.Lock() run.state = state if run.stop != nil && state != nil { - state.stop = run.stop + state.stop.Store(run.stop) } if state != nil && state.turn != nil { run.turnID = state.turn.ID() @@ -128,16 +128,10 @@ func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventI run.mu.Lock() defer run.mu.Unlock() state = run.state - turnID = run.turnID - sourceEventID = run.sourceEvent - initialEventID = run.initialEvent - if state == nil || state.turn == nil { - return turnID, sourceEventID, initialEventID, state - } - turnID = state.turn.ID() - sourceEventID = state.sourceEventID() - initialEventID = state.turn.InitialEventID() - return turnID, sourceEventID, initialEventID, state + if state != nil && state.turn != nil { + return state.turn.ID(), state.sourceEventID(), state.turn.InitialEventID(), state + } + return run.turnID, run.sourceEvent, run.initialEvent, state } func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMetadata) bool { @@ -148,7 +142,7 @@ func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMeta run.mu.Lock() run.stop = stop if run.state != nil { - run.state.stop = stop + run.state.stop.Store(stop) } run.mu.Unlock() return true diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 6f949bdb..92bb9769 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -40,7 +40,7 @@ func (oc *AIClient) finishStreamingWithFailure( reason string, err error, ) error { - if state != nil && state.stop != nil && reason == "cancelled" { + if state != nil && state.stop.Load() != nil && reason == "cancelled" { reason = "stop" } state.finishReason = reason diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index f4ba5951..cb205045 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -3,6 +3,7 @@ package ai import ( "context" "strings" + "sync/atomic" "time" "github.com/openai/openai-go/v3/packages/param" @@ -70,7 +71,7 @@ type streamingState struct { pendingMcpApprovals []mcpApprovalRequest pendingMcpApprovalsSeen map[string]bool - stop *assistantStopMetadata + stop atomic.Pointer[assistantStopMetadata] } // sourceEventID returns the triggering user message event ID from the turn's source ref. diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index bd5e0020..24c138aa 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -72,7 +72,7 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } - if state.stop != nil { + if state.stop.Load() != nil { return "cancelled" } diff --git a/bridges/ai/ui_message_metadata.go b/bridges/ai/ui_message_metadata.go index e96b7fe5..b55abf32 100644 --- a/bridges/ai/ui_message_metadata.go +++ b/bridges/ai/ui_message_metadata.go @@ -80,6 +80,6 @@ func buildAssistantTurnMetadata(state *streamingState, turnID, networkMessageID, SourceEventID: state.sourceEventID().String(), GeneratedFileRefs: agentremote.GeneratedFileRefsFromParts(state.generatedFiles), Usage: buildAssistantUsageMetadata(state), - Stop: state.stop, + Stop: state.stop.Load(), }) } From ed72a84b34bd1a29e0a4606458c28578f6f89861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:26:27 +0200 Subject: [PATCH 4/7] Centralize ack removal and simplify queue logic Move AckReactionRemoveAfter checks into removePendingAckReactions and remove duplicate guarded calls throughout the codebase, so callers simply invoke the removal and the function decides whether to act. Simplify pending queue management by replacing clearPendingQueue usage with drainPendingQueue and delete the pendingQueueHasSourceEvent helper. Adjust stop-plan handling to speculatively return queued stops and add a fallback in executeUserStopPlan to convert a queued plan to no-match if nothing was drained. Update tests to reflect the new speculative behavior and the fallback. --- bridges/ai/abort_helpers.go | 18 +++++---------- bridges/ai/abort_helpers_test.go | 39 ++++++++++++++++++++------------ bridges/ai/client.go | 6 ++--- bridges/ai/pending_queue.go | 37 +++--------------------------- bridges/ai/room_runs.go | 4 +--- bridges/ai/subagent_registry.go | 6 ++--- 6 files changed, 39 insertions(+), 71 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index d88261ac..985cbde4 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -127,26 +127,17 @@ func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan { TargetEventID: req.ReplyTo, } } - if oc.pendingQueueHasSourceEvent(req.Portal.MXID, req.ReplyTo) { - return userStopPlan{ - Kind: stopPlanKindQueued, - Scope: "turn", - TargetKind: "source_event", - TargetEventID: req.ReplyTo, - } - } return userStopPlan{ - Kind: stopPlanKindNoMatch, + Kind: stopPlanKindQueued, Scope: "turn", + TargetKind: "source_event", TargetEventID: req.ReplyTo, } } func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int { for _, item := range items { - if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.") } return len(items) @@ -174,6 +165,9 @@ func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest } case stopPlanKindQueued: result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) + if result.QueuedStopped == 0 { + result.Plan.Kind = stopPlanKindNoMatch + } } if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) { diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index ff2994de..a1c7c108 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -54,25 +54,36 @@ func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) { } } -func TestResolveUserStopPlanMatchesQueuedReplyTarget(t *testing.T) { - roomID := id.RoomID("!room:test") - oc := &AIClient{ - pendingQueues: map[id.RoomID]*pendingQueue{ - roomID: { - items: []pendingQueueItem{{ - pending: pendingMessage{SourceEventID: id.EventID("$queued")}, - }}, - }, - }, - } - portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} +func TestResolveUserStopPlanSpeculativelyReturnsQueued(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} plan := oc.resolveUserStopPlan(userStopRequest{ Portal: portal, - ReplyTo: id.EventID("$queued"), + ReplyTo: id.EventID("$unknown"), }) if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" { - t.Fatalf("expected queued stop plan, got %#v", plan) + t.Fatalf("expected speculative queued stop plan, got %#v", plan) + } +} + +func TestExecuteUserStopPlanFallsBackToNoMatch(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$nonexistent"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback, got %#v", result.Plan) + } + if result.QueuedStopped != 0 { + t.Fatalf("expected zero queued stopped, got %d", result.QueuedStopped) } } diff --git a/bridges/ai/client.go b/bridges/ai/client.go index b181d712..ed4a8666 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -667,7 +667,7 @@ func (oc *AIClient) dispatchOrQueueCore( metaSnapshot := clonePortalMetadata(meta) go func(metaSnapshot *PortalMetadata) { defer func() { - if hasDBMessage && metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { + if hasDBMessage { oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) } oc.releaseRoom(roomID) @@ -815,9 +815,7 @@ func (oc *AIClient) processPendingQueue(ctx context.Context, roomID id.RoomID) { if err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to build prompt for pending queue item") oc.notifyMatrixSendFailure(ctx, item.pending.Portal, item.pending.Event, err) - if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) return diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 2b4f9f35..6b8b684d 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -86,13 +86,7 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe } func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { - oc.pendingQueuesMu.Lock() - _, existed := oc.pendingQueues[roomID] - delete(oc.pendingQueues, roomID) - oc.pendingQueuesMu.Unlock() - if existed { - oc.stopQueueTyping(roomID) - } + oc.drainPendingQueue(roomID) } func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { @@ -115,29 +109,6 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { return items } -func (oc *AIClient) pendingQueueHasSourceEvent(roomID id.RoomID, sourceEventID id.EventID) bool { - if oc == nil || roomID == "" || sourceEventID == "" { - return false - } - oc.pendingQueuesMu.Lock() - queue := oc.pendingQueues[roomID] - if queue == nil { - oc.pendingQueuesMu.Unlock() - return false - } - queue.mu.Lock() - found := false - for _, item := range queue.items { - if item.pending.sourceEventID() == sourceEventID { - found = true - break - } - } - queue.mu.Unlock() - oc.pendingQueuesMu.Unlock() - return found -} - func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { if oc == nil || roomID == "" || sourceEventID == "" { return nil @@ -509,9 +480,7 @@ func (oc *AIClient) dispatchQueuedPrompt( metaSnapshot := clonePortalMetadata(item.pending.Meta) go func() { defer func() { - if metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) if item.backlogAfter { followup := item followup.backlogAfter = false @@ -527,7 +496,7 @@ func (oc *AIClient) dispatchQueuedPrompt( } func (oc *AIClient) removePendingAckReactions(ctx context.Context, portal *bridgev2.Portal, pending pendingMessage) { - if portal == nil { + if portal == nil || pending.Meta == nil || !pending.Meta.AckReactionRemoveAfter { return } ids := pending.AckEventIDs diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 274463f0..c83c7e7b 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -76,9 +76,7 @@ func (oc *AIClient) clearRoomRun(roomID id.RoomID) { } ctx := oc.backgroundContext(context.Background()) for _, pending := range ackPending { - if pending.Meta != nil && pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(ctx, pending.Portal, pending) - } + oc.removePendingAckReactions(ctx, pending.Portal, pending) } } diff --git a/bridges/ai/subagent_registry.go b/bridges/ai/subagent_registry.go index 2a7bf963..4772b5a5 100644 --- a/bridges/ai/subagent_registry.go +++ b/bridges/ai/subagent_registry.go @@ -43,10 +43,8 @@ func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { continue } canceled := oc.cancelRoomRun(run.ChildRoomID) - queueSnapshot := oc.getQueueSnapshot(run.ChildRoomID) - hasQueued := queueSnapshot != nil && (len(queueSnapshot.items) > 0 || queueSnapshot.droppedCount > 0) - oc.clearPendingQueue(run.ChildRoomID) - if canceled || hasQueued { + drained := oc.drainPendingQueue(run.ChildRoomID) + if canceled || len(drained) > 0 { stopped++ } } From 02c907a13849285ae692a6321e891e5ae029e9d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 16:48:56 +0200 Subject: [PATCH 5/7] Update pending_queue.go --- bridges/ai/pending_queue.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 6b8b684d..efdaa03b 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -239,14 +239,22 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { } queue.mu.Lock() defer queue.mu.Unlock() - clone := *queue - clone.items = slices.Clone(queue.items) - clone.summaryLines = slices.Clone(queue.summaryLines) + clone := &pendingQueue{ + items: slices.Clone(queue.items), + draining: queue.draining, + lastEnqueuedAt: queue.lastEnqueuedAt, + mode: queue.mode, + debounceMs: queue.debounceMs, + cap: queue.cap, + dropPolicy: queue.dropPolicy, + droppedCount: queue.droppedCount, + summaryLines: slices.Clone(queue.summaryLines), + } if queue.lastItem != nil { lastItem := *queue.lastItem clone.lastItem = &lastItem } - return &clone + return clone } func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { From a242d262b5d7d0c4b4b8ebc667b02aa5aa92afee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 17:30:09 +0200 Subject: [PATCH 6/7] Propagate context to queue/stop ops and refine stop handling Pass context through pending-queue and subagent stop helpers (clearPendingQueue, stopSubagentRuns, finalizeStoppedQueueItems) and always finalize/drain pending items when clearing queues. Fix executeUserStopPlan logic to mark active stops before cancelling and fall back to a no-match when an active stop is a no-op. Ensure removePendingAckReactions is always invoked in goroutine cleanup. Adjust finishStreamingWithFailure to properly end turns on cancelled streams without falling through, and prefer explicit stop flag in canonicalResponseStatus. Add tests covering the no-op active stop fallback, cancelled finish behavior, and canonicalResponseStatus preference. --- bridges/ai/abort_helpers.go | 13 ++++++---- bridges/ai/abort_helpers_test.go | 28 +++++++++++++++++++++ bridges/ai/client.go | 6 ++--- bridges/ai/commands_parity.go | 2 +- bridges/ai/internal_dispatch.go | 2 +- bridges/ai/pending_queue.go | 4 +-- bridges/ai/streaming_error_handling.go | 4 ++- bridges/ai/streaming_error_handling_test.go | 27 ++++++++++++++++++++ bridges/ai/subagent_registry.go | 6 +++-- bridges/ai/turn_data.go | 6 ++--- bridges/ai/turn_data_test.go | 9 +++++++ 11 files changed, 88 insertions(+), 19 deletions(-) diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 985cbde4..54e66822 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -155,13 +155,16 @@ func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest result.ActiveStopped = oc.cancelRoomRun(roomID) } result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) - result.SubagentsStopped = oc.stopSubagentRuns(roomID) + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) case stopPlanKindActive: - if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + markedStopped := oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) + if markedStopped { result.ActiveStopped = oc.cancelRoomRun(roomID) - if result.ActiveStopped { - result.SubagentsStopped = oc.stopSubagentRuns(roomID) - } + } + if result.ActiveStopped { + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) + } else { + result.Plan.Kind = stopPlanKindNoMatch } case stopPlanKindQueued: result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go index a1c7c108..ca9597ee 100644 --- a/bridges/ai/abort_helpers_test.go +++ b/bridges/ai/abort_helpers_test.go @@ -121,6 +121,34 @@ func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) { } } +func TestExecuteUserStopPlanActiveNoOpFallsBackToNoMatch(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }, userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$user"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback for no-op active stop, got %#v", result.Plan) + } + if result.ActiveStopped { + t.Fatalf("expected active stop to report false, got %#v", result) + } +} + func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { oc := &AIClient{} conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) diff --git a/bridges/ai/client.go b/bridges/ai/client.go index ed4a8666..01916cfc 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -641,7 +641,7 @@ func (oc *AIClient) dispatchOrQueueCore( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, roomBusy, false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(roomID) - oc.clearPendingQueue(roomID) + oc.clearPendingQueue(ctx, roomID) roomBusy = false } if !roomBusy && oc.acquireRoom(roomID) { @@ -667,9 +667,7 @@ func (oc *AIClient) dispatchOrQueueCore( metaSnapshot := clonePortalMetadata(meta) go func(metaSnapshot *PortalMetadata) { defer func() { - if hasDBMessage { - oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) }() diff --git a/bridges/ai/commands_parity.go b/bridges/ai/commands_parity.go index cdbaf9de..e3ee4b63 100644 --- a/bridges/ai/commands_parity.go +++ b/bridges/ai/commands_parity.go @@ -45,7 +45,7 @@ func fnReset(ce *commands.Event) { meta.SessionResetAt = time.Now().UnixMilli() client.savePortalQuiet(ce.Ctx, ce.Portal, "session reset") - client.clearPendingQueue(ce.Portal.MXID) + client.clearPendingQueue(ce.Ctx, ce.Portal.MXID) client.cancelRoomRun(ce.Portal.MXID) ce.Reply("%s", formatSystemAck("Session reset.")) diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index d4cfd657..aed0669e 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -107,7 +107,7 @@ func (oc *AIClient) dispatchInternalMessage( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, oc.roomHasActiveRun(portal.MXID), false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) + oc.clearPendingQueue(ctx, portal.MXID) } if shouldSteer && pending.Type == pendingTypeText { queueItem.prompt = pending.MessageBody diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index efdaa03b..35650828 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -85,8 +85,8 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe return queue } -func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { - oc.drainPendingQueue(roomID) +func (oc *AIClient) clearPendingQueue(ctx context.Context, roomID id.RoomID) { + oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) } func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 92bb9769..63932df0 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -53,7 +53,9 @@ func (oc *AIClient) finishStreamingWithFailure( switch reason { case "cancelled": state.writer().Abort(ctx, "cancelled") - fallthrough + if state.turn != nil { + state.turn.End("cancelled") + } case "stop": if state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index 6ae8fbfa..770e6aae 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -5,10 +5,12 @@ import ( "errors" "testing" + "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + "github.com/beeper/agentremote/pkg/shared/streamui" bridgesdk "github.com/beeper/agentremote/sdk" ) @@ -84,3 +86,28 @@ func TestStreamFailureErrorUsesAnyMessageTarget(t *testing.T) { } }) } + +func TestFinishStreamingWithFailureCancelledEndsTurnAsCancelled(t *testing.T) { + state := newTestStreamingStateWithTurn() + state.turn.SetSuppressSend(true) + state.writer().TextDelta(context.Background(), "hello") + + err := (&AIClient{}).finishStreamingWithFailure( + context.Background(), + zerolog.Nop(), + nil, + state, + nil, + "cancelled", + context.Canceled, + ) + if err == nil { + t.Fatal("expected wrapped cancellation error") + } + + message := streamui.SnapshotUIMessage(state.turn.UIState()) + metadata, _ := message["metadata"].(map[string]any) + if metadata["finish_reason"] != "cancelled" { + t.Fatalf("expected cancelled finish_reason, got %#v", metadata["finish_reason"]) + } +} diff --git a/bridges/ai/subagent_registry.go b/bridges/ai/subagent_registry.go index 4772b5a5..a6ecf9bf 100644 --- a/bridges/ai/subagent_registry.go +++ b/bridges/ai/subagent_registry.go @@ -1,6 +1,7 @@ package ai import ( + "context" "time" "maunium.net/go/mautrix/id" @@ -32,7 +33,7 @@ func (oc *AIClient) listSubagentRunsForParent(parent id.RoomID) []*subagentRun { return runs } -func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { +func (oc *AIClient) stopSubagentRuns(ctx context.Context, parent id.RoomID) int { if oc == nil || parent == "" { return 0 } @@ -44,7 +45,8 @@ func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { } canceled := oc.cancelRoomRun(run.ChildRoomID) drained := oc.drainPendingQueue(run.ChildRoomID) - if canceled || len(drained) > 0 { + finalized := oc.finalizeStoppedQueueItems(ctx, drained) + if canceled || finalized > 0 { stopped++ } } diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index 24c138aa..2dbb9759 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -59,6 +59,9 @@ func canonicalResponseStatus(state *streamingState) string { if state == nil { return "" } + if state.stop.Load() != nil { + return "cancelled" + } status := strings.TrimSpace(state.responseStatus) if state.completedAtMs == 0 { return status @@ -72,9 +75,6 @@ func canonicalResponseStatus(state *streamingState) string { if strings.TrimSpace(state.responseID) == "" { return status } - if state.stop.Load() != nil { - return "cancelled" - } switch strings.TrimSpace(state.finishReason) { case "", "stop": diff --git a/bridges/ai/turn_data_test.go b/bridges/ai/turn_data_test.go index fbd2b7c2..36919e63 100644 --- a/bridges/ai/turn_data_test.go +++ b/bridges/ai/turn_data_test.go @@ -100,3 +100,12 @@ func TestBuildTurnDataMetadataUsesResponderSnapshot(t *testing.T) { t.Fatalf("did not expect flat prompt_tokens field, got %#v", meta["prompt_tokens"]) } } + +func TestCanonicalResponseStatusPrefersExplicitStopWithoutResponseID(t *testing.T) { + state := testStreamingState("turn-cancelled") + state.stop.Store(&assistantStopMetadata{Reason: "user_stop"}) + + if got := canonicalResponseStatus(state); got != "cancelled" { + t.Fatalf("expected cancelled status from explicit stop, got %q", got) + } +} From bf332afb695b1804f2a43142864941548cd416b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Sun, 5 Apr 2026 18:11:45 +0200 Subject: [PATCH 7/7] Fix pending queue locking and lastItem updates Rework pending queue locking and item housekeeping to avoid races and stale pointers. getPendingQueue now locks the queue.mu before releasing pendingQueuesMu and consistently applies settings for both new and existing queues. drainPendingQueue clears queue.items and lastItem when removing a queue. removePendingQueueBySourceEvent reassigns lastItem to the new tail if the removed item was the last. enqueuePendingItem removed a now-redundant explicit lock (the returned queue is already locked). Added a unit test to verify lastItem is cleared/reassigned, and tightened an error assertion in the streaming test to use errors.Is for wrapped cancellations. --- bridges/ai/pending_queue.go | 38 ++++++++++++--------- bridges/ai/queue_status_test.go | 30 ++++++++++++++++ bridges/ai/streaming_error_handling_test.go | 4 +-- 3 files changed, 54 insertions(+), 18 deletions(-) diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index 35650828..9604c8a9 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -57,7 +57,6 @@ type pendingQueueDispatchCandidate struct { func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSettings) *pendingQueue { oc.pendingQueuesMu.Lock() - defer oc.pendingQueuesMu.Unlock() queue := oc.pendingQueues[roomID] if queue == nil { queue = &pendingQueue{ @@ -68,20 +67,19 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe dropPolicy: settings.DropPolicy, } oc.pendingQueues[roomID] = queue - } else { - queue.mu.Lock() - queue.mode = settings.Mode - if settings.DebounceMs >= 0 { - queue.debounceMs = settings.DebounceMs - } - if settings.Cap > 0 { - queue.cap = settings.Cap - } - if settings.DropPolicy != "" { - queue.dropPolicy = settings.DropPolicy - } - queue.mu.Unlock() } + queue.mu.Lock() + queue.mode = settings.Mode + if settings.DebounceMs >= 0 { + queue.debounceMs = settings.DebounceMs + } + if settings.Cap > 0 { + queue.cap = settings.Cap + } + if settings.DropPolicy != "" { + queue.dropPolicy = settings.DropPolicy + } + oc.pendingQueuesMu.Unlock() return queue } @@ -99,9 +97,11 @@ func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { oc.pendingQueuesMu.Unlock() return nil } - delete(oc.pendingQueues, roomID) queue.mu.Lock() + delete(oc.pendingQueues, roomID) items := queue.items + queue.items = nil + queue.lastItem = nil queue.mu.Unlock() oc.pendingQueuesMu.Unlock() @@ -131,6 +131,13 @@ func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEven } clear(queue.items[len(kept):]) queue.items = kept + if queue.lastItem != nil && queue.lastItem.pending.sourceEventID() == sourceEventID { + queue.lastItem = nil + if len(kept) > 0 { + lastItem := kept[len(kept)-1] + queue.lastItem = &lastItem + } + } empty := len(queue.items) == 0 && queue.droppedCount == 0 if empty { delete(oc.pendingQueues, roomID) @@ -149,7 +156,6 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, if queue == nil { return false } - queue.mu.Lock() defer queue.mu.Unlock() for _, existing := range queue.items { diff --git a/bridges/ai/queue_status_test.go b/bridges/ai/queue_status_test.go index b30b6092..2785af96 100644 --- a/bridges/ai/queue_status_test.go +++ b/bridges/ai/queue_status_test.go @@ -190,3 +190,33 @@ func TestDispatchOrQueueQueuesBehindExistingPendingWork(t *testing.T) { t.Fatalf("expected room to remain unacquired while backlog exists") } } + +func TestRemovePendingQueueBySourceEventClearsRemovedLastItem(t *testing.T) { + roomID := id.RoomID("!room:example.com") + first := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$one")}} + last := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$two")}} + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{first, last}, + lastItem: &last, + }, + }, + } + + removed := oc.removePendingQueueBySourceEvent(roomID, id.EventID("$two")) + if len(removed) != 1 { + t.Fatalf("expected one removed item, got %d", len(removed)) + } + + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil { + t.Fatal("expected queue snapshot to remain") + } + if snapshot.lastItem == nil { + t.Fatal("expected lastItem to be reassigned to the new tail") + } + if got := snapshot.lastItem.pending.sourceEventID(); got != id.EventID("$one") { + t.Fatalf("expected lastItem to point at remaining item, got %q", got) + } +} diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index 770e6aae..36435cb0 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -101,8 +101,8 @@ func TestFinishStreamingWithFailureCancelledEndsTurnAsCancelled(t *testing.T) { "cancelled", context.Canceled, ) - if err == nil { - t.Fatal("expected wrapped cancellation error") + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected wrapped cancellation error, got %#v", err) } message := streamui.SnapshotUIMessage(state.turn.UIState())