diff --git a/bridges/ai/abort_helpers.go b/bridges/ai/abort_helpers.go index 45607429..54e66822 100644 --- a/bridges/ai/abort_helpers.go +++ b/bridges/ai/abort_helpers.go @@ -3,31 +3,187 @@ package ai import ( "context" "fmt" + "strings" + "unicode" + "unicode/utf8" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" ) -func formatAbortNotice(stopped int) string { - if stopped <= 0 { - return "Agent was aborted." +type stopPlanKind string + +const ( + stopPlanKindNoMatch stopPlanKind = "no-match" + stopPlanKindRoomWide stopPlanKind = "room-wide" + stopPlanKindActive stopPlanKind = "active-turn" + stopPlanKindQueued stopPlanKind = "queued-turn" +) + +type userStopRequest struct { + Portal *bridgev2.Portal + Meta *PortalMetadata + ReplyTo id.EventID + RequestedByEventID id.EventID + RequestedVia string +} + +type userStopPlan struct { + Kind stopPlanKind + Scope string + TargetKind string + TargetEventID id.EventID +} + +type userStopResult struct { + Plan userStopPlan + ActiveStopped bool + QueuedStopped int + SubagentsStopped int +} + +func stopLabel(count int, singular string) string { + if count == 1 { + return singular + } + return singular + "s" +} + +func formatAbortNotice(result userStopResult) string { + switch result.Plan.Kind { + case stopPlanKindNoMatch: + return "No matching active or queued turn found for that reply." + case stopPlanKindActive: + if result.SubagentsStopped > 0 { + return fmt.Sprintf("Stopped that turn. Stopped %d %s.", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent")) + } + return "Stopped that turn." + case stopPlanKindQueued: + if result.QueuedStopped <= 1 { + return "Stopped that queued turn." + } + return fmt.Sprintf("Stopped %d queued %s.", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn")) + case stopPlanKindRoomWide: + parts := make([]string, 0, 3) + if result.ActiveStopped { + parts = append(parts, "stopped the active turn") + } + if result.QueuedStopped > 0 { + parts = append(parts, fmt.Sprintf("removed %d queued %s", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn"))) + } + if result.SubagentsStopped > 0 { + parts = append(parts, fmt.Sprintf("stopped %d %s", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent"))) + } + if len(parts) == 0 { + return "No active or queued turns to stop." + } + for i := range parts { + r, size := utf8.DecodeRuneInString(parts[i]) + parts[i] = string(unicode.ToUpper(r)) + parts[i][size:] + } + return strings.Join(parts, ". ") + "." + default: + return "No active or queued turns to stop." + } +} + +func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata { + return &assistantStopMetadata{ + Reason: "user_stop", + Scope: plan.Scope, + TargetKind: plan.TargetKind, + TargetEventID: plan.TargetEventID.String(), + RequestedByEventID: req.RequestedByEventID.String(), + RequestedVia: strings.TrimSpace(req.RequestedVia), } - label := "sub-agents" - if stopped == 1 { - label = "sub-agent" +} + +func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan { + if req.Portal == nil || req.Portal.MXID == "" { + return userStopPlan{Kind: stopPlanKindNoMatch} + } + if req.ReplyTo == "" { + return userStopPlan{ + Kind: stopPlanKindRoomWide, + Scope: "room", + TargetKind: "all", + } + } + + _, sourceEventID, initialEventID, _ := oc.roomRunTarget(req.Portal.MXID) + if initialEventID != "" && req.ReplyTo == initialEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "placeholder_event", + TargetEventID: req.ReplyTo, + } + } + if sourceEventID != "" && req.ReplyTo == sourceEventID { + return userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, + } + } + return userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: req.ReplyTo, } - return fmt.Sprintf("Agent was aborted. Stopped %d %s.", stopped, label) } -func (oc *AIClient) abortRoom(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int { - if portal == nil { - return 0 +func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int { + for _, item := range items { + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) + oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.") } - oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) - stopped := oc.stopSubagentRuns(portal.MXID) - if meta != nil { - meta.AbortedLastRun = true - oc.savePortalQuiet(ctx, portal, "abort") + return len(items) +} + +func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest, plan userStopPlan) userStopResult { + result := userStopResult{Plan: plan} + if req.Portal == nil || req.Portal.MXID == "" { + return result + } + roomID := req.Portal.MXID + switch plan.Kind { + case stopPlanKindRoomWide: + if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) { + result.ActiveStopped = oc.cancelRoomRun(roomID) + } + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) + case stopPlanKindActive: + markedStopped := oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) + if markedStopped { + result.ActiveStopped = oc.cancelRoomRun(roomID) + } + if result.ActiveStopped { + result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID) + } else { + result.Plan.Kind = stopPlanKindNoMatch + } + case stopPlanKindQueued: + result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID)) + if result.QueuedStopped == 0 { + result.Plan.Kind = stopPlanKindNoMatch + } + } + + if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) { + req.Meta.AbortedLastRun = true + oc.savePortalQuiet(ctx, req.Portal, "stop") } - return stopped + if req.Meta != nil && result.QueuedStopped > 0 { + oc.notifySessionMutation(ctx, req.Portal, req.Meta, false) + } + return result +} + +func (oc *AIClient) handleUserStop(ctx context.Context, req userStopRequest) userStopResult { + plan := oc.resolveUserStopPlan(req) + return oc.executeUserStopPlan(ctx, req, plan) } diff --git a/bridges/ai/abort_helpers_test.go b/bridges/ai/abort_helpers_test.go new file mode 100644 index 00000000..ca9597ee --- /dev/null +++ b/bridges/ai/abort_helpers_test.go @@ -0,0 +1,187 @@ +package ai + +import ( + "context" + "testing" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/id" + + bridgesdk "github.com/beeper/agentremote/sdk" +) + +func TestResolveUserStopPlanRoomWideWithoutReply(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + req := userStopRequest{Portal: portal, RequestedVia: "command"} + + plan := oc.resolveUserStopPlan(req) + if plan.Kind != stopPlanKindRoomWide { + t.Fatalf("expected room-wide stop, got %#v", plan) + } + if plan.TargetKind != "all" || plan.Scope != "room" { + t.Fatalf("unexpected room-wide stop plan: %#v", plan) + } +} + +func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + initialEvent: id.EventID("$assistant"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + placeholderPlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$assistant"), + }) + if placeholderPlan.Kind != stopPlanKindActive || placeholderPlan.TargetKind != "placeholder_event" { + t.Fatalf("expected placeholder-targeted active stop, got %#v", placeholderPlan) + } + + sourcePlan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }) + if sourcePlan.Kind != stopPlanKindActive || sourcePlan.TargetKind != "source_event" { + t.Fatalf("expected source-targeted active stop, got %#v", sourcePlan) + } +} + +func TestResolveUserStopPlanSpeculativelyReturnsQueued(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + + plan := oc.resolveUserStopPlan(userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$unknown"), + }) + if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" { + t.Fatalf("expected speculative queued stop plan, got %#v", plan) + } +} + +func TestExecuteUserStopPlanFallsBackToNoMatch(t *testing.T) { + oc := &AIClient{} + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$nonexistent"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback, got %#v", result.Plan) + } + if result.QueuedStopped != 0 { + t.Fatalf("expected zero queued stopped, got %d", result.QueuedStopped) + } +} + +func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{ + {pending: pendingMessage{SourceEventID: id.EventID("$one")}}, + {pending: pendingMessage{SourceEventID: id.EventID("$two")}}, + }, + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + }, userStopPlan{ + Kind: stopPlanKindQueued, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$one"), + }) + if result.QueuedStopped != 1 { + t.Fatalf("expected one queued turn to stop, got %#v", result) + } + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil || len(snapshot.items) != 1 { + t.Fatalf("expected one queued item to remain, got %#v", snapshot) + } + if got := snapshot.items[0].pending.sourceEventID(); got != id.EventID("$two") { + t.Fatalf("expected remaining queued event $two, got %q", got) + } +} + +func TestExecuteUserStopPlanActiveNoOpFallsBackToNoMatch(t *testing.T) { + roomID := id.RoomID("!room:test") + oc := &AIClient{ + activeRoomRuns: map[id.RoomID]*roomRunState{ + roomID: { + sourceEvent: id.EventID("$user"), + }, + }, + } + portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}} + + result := oc.executeUserStopPlan(context.Background(), userStopRequest{ + Portal: portal, + ReplyTo: id.EventID("$user"), + }, userStopPlan{ + Kind: stopPlanKindActive, + Scope: "turn", + TargetKind: "source_event", + TargetEventID: id.EventID("$user"), + }) + if result.Plan.Kind != stopPlanKindNoMatch { + t.Fatalf("expected no-match fallback for no-op active stop, got %#v", result.Plan) + } + if result.ActiveStopped { + t.Fatalf("expected active stop to report false, got %#v", result) + } +} + +func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) { + oc := &AIClient{} + conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil) + turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"}) + turn.SetID("turn-stop") + state := &streamingState{ + turn: turn, + finishReason: "stop", + responseID: "resp_123", + completedAtMs: 1, + } + state.stop.Store(&assistantStopMetadata{ + Reason: "user_stop", + Scope: "turn", + TargetKind: "source_event", + TargetEventID: "$user", + RequestedByEventID: "$stop", + RequestedVia: "command", + }) + + ui := oc.buildStreamUIMessage(state, nil, nil) + metadata, ok := ui["metadata"].(map[string]any) + if !ok { + t.Fatalf("expected metadata map, got %T", ui["metadata"]) + } + stop, ok := metadata["stop"].(map[string]any) + if !ok { + t.Fatalf("expected nested stop metadata, got %#v", metadata["stop"]) + } + if stop["reason"] != "user_stop" || stop["requested_via"] != "command" { + t.Fatalf("unexpected stop metadata: %#v", stop) + } + if metadata["response_status"] != "cancelled" { + t.Fatalf("expected cancelled response status for stopped turn, got %#v", metadata["response_status"]) + } +} diff --git a/bridges/ai/client.go b/bridges/ai/client.go index b181d712..01916cfc 100644 --- a/bridges/ai/client.go +++ b/bridges/ai/client.go @@ -641,7 +641,7 @@ func (oc *AIClient) dispatchOrQueueCore( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, roomBusy, false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(roomID) - oc.clearPendingQueue(roomID) + oc.clearPendingQueue(ctx, roomID) roomBusy = false } if !roomBusy && oc.acquireRoom(roomID) { @@ -667,9 +667,7 @@ func (oc *AIClient) dispatchOrQueueCore( metaSnapshot := clonePortalMetadata(meta) go func(metaSnapshot *PortalMetadata) { defer func() { - if hasDBMessage && metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), portal, queueItem.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) }() @@ -815,9 +813,7 @@ func (oc *AIClient) processPendingQueue(ctx context.Context, roomID id.RoomID) { if err != nil { oc.loggerForContext(ctx).Err(err).Msg("Failed to build prompt for pending queue item") oc.notifyMatrixSendFailure(ctx, item.pending.Portal, item.pending.Event, err) - if item.pending.Meta != nil && item.pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) oc.releaseRoom(roomID) oc.processPendingQueue(oc.backgroundContext(ctx), roomID) return diff --git a/bridges/ai/commands_parity.go b/bridges/ai/commands_parity.go index bafb22de..e3ee4b63 100644 --- a/bridges/ai/commands_parity.go +++ b/bridges/ai/commands_parity.go @@ -45,7 +45,7 @@ func fnReset(ce *commands.Event) { meta.SessionResetAt = time.Now().UnixMilli() client.savePortalQuiet(ce.Ctx, ce.Portal, "session reset") - client.clearPendingQueue(ce.Portal.MXID) + client.clearPendingQueue(ce.Ctx, ce.Portal.MXID) client.cancelRoomRun(ce.Portal.MXID) ce.Reply("%s", formatSystemAck("Session reset.")) @@ -65,6 +65,12 @@ func fnStop(ce *commands.Event) { if !ok { return } - stopped := client.abortRoom(ce.Ctx, ce.Portal, meta) - ce.Reply("%s", formatAbortNotice(stopped)) + result := client.handleUserStop(ce.Ctx, userStopRequest{ + Portal: ce.Portal, + Meta: meta, + ReplyTo: ce.ReplyTo, + RequestedByEventID: ce.EventID, + RequestedVia: "command", + }) + ce.Reply("%s", formatAbortNotice(result)) } diff --git a/bridges/ai/handlematrix.go b/bridges/ai/handlematrix.go index 4604d084..d1864069 100644 --- a/bridges/ai/handlematrix.go +++ b/bridges/ai/handlematrix.go @@ -135,8 +135,15 @@ func (oc *AIClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Matri return &bridgev2.MatrixMessageResponse{Pending: false}, nil } if commandAuthorized && airuntime.IsAbortTriggerText(commandBody) { - stopped := oc.abortRoom(ctx, portal, meta) - oc.sendSystemNotice(ctx, portal, formatAbortNotice(stopped)) + replyCtx := extractInboundReplyContext(msg.Event) + result := oc.handleUserStop(ctx, userStopRequest{ + Portal: portal, + Meta: meta, + ReplyTo: replyCtx.ReplyTo, + RequestedByEventID: msg.Event.ID, + RequestedVia: "text-trigger", + }) + oc.sendSystemNotice(ctx, portal, formatAbortNotice(result)) logCtx.Debug().Msg("Abort trigger handled") return &bridgev2.MatrixMessageResponse{Pending: false}, nil } diff --git a/bridges/ai/internal_dispatch.go b/bridges/ai/internal_dispatch.go index d4cfd657..aed0669e 100644 --- a/bridges/ai/internal_dispatch.go +++ b/bridges/ai/internal_dispatch.go @@ -107,7 +107,7 @@ func (oc *AIClient) dispatchInternalMessage( queueDecision := airuntime.DecideQueueAction(queueSettings.Mode, oc.roomHasActiveRun(portal.MXID), false) if queueDecision.Action == airuntime.QueueActionInterruptAndRun { oc.cancelRoomRun(portal.MXID) - oc.clearPendingQueue(portal.MXID) + oc.clearPendingQueue(ctx, portal.MXID) } if shouldSteer && pending.Type == pendingTypeText { queueItem.prompt = pending.MessageBody diff --git a/bridges/ai/pending_queue.go b/bridges/ai/pending_queue.go index aee69c52..9604c8a9 100644 --- a/bridges/ai/pending_queue.go +++ b/bridges/ai/pending_queue.go @@ -38,6 +38,16 @@ type pendingQueue struct { lastItem *pendingQueueItem } +func (pm pendingMessage) sourceEventID() id.EventID { + if pm.SourceEventID != "" { + return pm.SourceEventID + } + if pm.Event != nil { + return pm.Event.ID + } + return "" +} + type pendingQueueDispatchCandidate struct { items []pendingQueueItem summaryPrompt string @@ -47,7 +57,6 @@ type pendingQueueDispatchCandidate struct { func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSettings) *pendingQueue { oc.pendingQueuesMu.Lock() - defer oc.pendingQueuesMu.Unlock() queue := oc.pendingQueues[roomID] if queue == nil { queue = &pendingQueue{ @@ -58,31 +67,88 @@ func (oc *AIClient) getPendingQueue(roomID id.RoomID, settings airuntime.QueueSe dropPolicy: settings.DropPolicy, } oc.pendingQueues[roomID] = queue - } else { - queue.mu.Lock() - queue.mode = settings.Mode - if settings.DebounceMs >= 0 { - queue.debounceMs = settings.DebounceMs - } - if settings.Cap > 0 { - queue.cap = settings.Cap - } - if settings.DropPolicy != "" { - queue.dropPolicy = settings.DropPolicy - } - queue.mu.Unlock() } + queue.mu.Lock() + queue.mode = settings.Mode + if settings.DebounceMs >= 0 { + queue.debounceMs = settings.DebounceMs + } + if settings.Cap > 0 { + queue.cap = settings.Cap + } + if settings.DropPolicy != "" { + queue.dropPolicy = settings.DropPolicy + } + oc.pendingQueuesMu.Unlock() return queue } -func (oc *AIClient) clearPendingQueue(roomID id.RoomID) { +func (oc *AIClient) clearPendingQueue(ctx context.Context, roomID id.RoomID) { + oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID)) +} + +func (oc *AIClient) drainPendingQueue(roomID id.RoomID) []pendingQueueItem { + if oc == nil || roomID == "" { + return nil + } oc.pendingQueuesMu.Lock() - _, existed := oc.pendingQueues[roomID] + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + queue.mu.Lock() delete(oc.pendingQueues, roomID) + items := queue.items + queue.items = nil + queue.lastItem = nil + queue.mu.Unlock() oc.pendingQueuesMu.Unlock() - if existed { + + oc.stopQueueTyping(roomID) + return items +} + +func (oc *AIClient) removePendingQueueBySourceEvent(roomID id.RoomID, sourceEventID id.EventID) []pendingQueueItem { + if oc == nil || roomID == "" || sourceEventID == "" { + return nil + } + oc.pendingQueuesMu.Lock() + queue := oc.pendingQueues[roomID] + if queue == nil { + oc.pendingQueuesMu.Unlock() + return nil + } + queue.mu.Lock() + removed := make([]pendingQueueItem, 0, 1) + kept := queue.items[:0] + for _, item := range queue.items { + if item.pending.sourceEventID() == sourceEventID { + removed = append(removed, item) + continue + } + kept = append(kept, item) + } + clear(queue.items[len(kept):]) + queue.items = kept + if queue.lastItem != nil && queue.lastItem.pending.sourceEventID() == sourceEventID { + queue.lastItem = nil + if len(kept) > 0 { + lastItem := kept[len(kept)-1] + queue.lastItem = &lastItem + } + } + empty := len(queue.items) == 0 && queue.droppedCount == 0 + if empty { + delete(oc.pendingQueues, roomID) + } + queue.mu.Unlock() + oc.pendingQueuesMu.Unlock() + + if empty { oc.stopQueueTyping(roomID) } + return removed } func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, settings airuntime.QueueSettings) bool { @@ -90,7 +156,6 @@ func (oc *AIClient) enqueuePendingItem(roomID id.RoomID, item pendingQueueItem, if queue == nil { return false } - queue.mu.Lock() defer queue.mu.Unlock() for _, existing := range queue.items { @@ -180,14 +245,22 @@ func (oc *AIClient) getQueueSnapshot(roomID id.RoomID) *pendingQueue { } queue.mu.Lock() defer queue.mu.Unlock() - clone := *queue - clone.items = slices.Clone(queue.items) - clone.summaryLines = slices.Clone(queue.summaryLines) + clone := &pendingQueue{ + items: slices.Clone(queue.items), + draining: queue.draining, + lastEnqueuedAt: queue.lastEnqueuedAt, + mode: queue.mode, + debounceMs: queue.debounceMs, + cap: queue.cap, + dropPolicy: queue.dropPolicy, + droppedCount: queue.droppedCount, + summaryLines: slices.Clone(queue.summaryLines), + } if queue.lastItem != nil { lastItem := *queue.lastItem clone.lastItem = &lastItem } - return &clone + return clone } func (oc *AIClient) roomHasPendingQueueWork(roomID id.RoomID) bool { @@ -421,9 +494,7 @@ func (oc *AIClient) dispatchQueuedPrompt( metaSnapshot := clonePortalMetadata(item.pending.Meta) go func() { defer func() { - if metaSnapshot != nil && metaSnapshot.AckReactionRemoveAfter { - oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) - } + oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending) if item.backlogAfter { followup := item followup.backlogAfter = false @@ -439,7 +510,7 @@ func (oc *AIClient) dispatchQueuedPrompt( } func (oc *AIClient) removePendingAckReactions(ctx context.Context, portal *bridgev2.Portal, pending pendingMessage) { - if portal == nil { + if portal == nil || pending.Meta == nil || !pending.Meta.AckReactionRemoveAfter { return } ids := pending.AckEventIDs diff --git a/bridges/ai/queue_status_test.go b/bridges/ai/queue_status_test.go index b30b6092..2785af96 100644 --- a/bridges/ai/queue_status_test.go +++ b/bridges/ai/queue_status_test.go @@ -190,3 +190,33 @@ func TestDispatchOrQueueQueuesBehindExistingPendingWork(t *testing.T) { t.Fatalf("expected room to remain unacquired while backlog exists") } } + +func TestRemovePendingQueueBySourceEventClearsRemovedLastItem(t *testing.T) { + roomID := id.RoomID("!room:example.com") + first := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$one")}} + last := pendingQueueItem{pending: pendingMessage{SourceEventID: id.EventID("$two")}} + oc := &AIClient{ + pendingQueues: map[id.RoomID]*pendingQueue{ + roomID: { + items: []pendingQueueItem{first, last}, + lastItem: &last, + }, + }, + } + + removed := oc.removePendingQueueBySourceEvent(roomID, id.EventID("$two")) + if len(removed) != 1 { + t.Fatalf("expected one removed item, got %d", len(removed)) + } + + snapshot := oc.getQueueSnapshot(roomID) + if snapshot == nil { + t.Fatal("expected queue snapshot to remain") + } + if snapshot.lastItem == nil { + t.Fatal("expected lastItem to be reassigned to the new tail") + } + if got := snapshot.lastItem.pending.sourceEventID(); got != id.EventID("$one") { + t.Fatalf("expected lastItem to point at remaining item, got %q", got) + } +} diff --git a/bridges/ai/room_runs.go b/bridges/ai/room_runs.go index 64071164..c83c7e7b 100644 --- a/bridges/ai/room_runs.go +++ b/bridges/ai/room_runs.go @@ -13,6 +13,11 @@ type roomRunState struct { cancel context.CancelFunc mu sync.Mutex + state *streamingState + stop *assistantStopMetadata + turnID string + sourceEvent id.EventID + initialEvent id.EventID streaming bool steerQueue []pendingQueueItem statusEvents []*event.Event @@ -71,9 +76,7 @@ func (oc *AIClient) clearRoomRun(roomID id.RoomID) { } ctx := oc.backgroundContext(context.Background()) for _, pending := range ackPending { - if pending.Meta != nil && pending.Meta.AckReactionRemoveAfter { - oc.removePendingAckReactions(ctx, pending.Portal, pending) - } + oc.removePendingAckReactions(ctx, pending.Portal, pending) } } @@ -97,6 +100,52 @@ func (oc *AIClient) markRoomRunStreaming(roomID id.RoomID, streaming bool) { run.mu.Unlock() } +func (oc *AIClient) bindRoomRunState(roomID id.RoomID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return + } + run.mu.Lock() + run.state = state + if run.stop != nil && state != nil { + state.stop.Store(run.stop) + } + if state != nil && state.turn != nil { + run.turnID = state.turn.ID() + run.sourceEvent = state.sourceEventID() + run.initialEvent = state.turn.InitialEventID() + } + run.mu.Unlock() +} + +func (oc *AIClient) roomRunTarget(roomID id.RoomID) (turnID string, sourceEventID, initialEventID id.EventID, state *streamingState) { + run := oc.getRoomRun(roomID) + if run == nil { + return "", "", "", nil + } + run.mu.Lock() + defer run.mu.Unlock() + state = run.state + if state != nil && state.turn != nil { + return state.turn.ID(), state.sourceEventID(), state.turn.InitialEventID(), state + } + return run.turnID, run.sourceEvent, run.initialEvent, state +} + +func (oc *AIClient) markRoomRunStopped(roomID id.RoomID, stop *assistantStopMetadata) bool { + run := oc.getRoomRun(roomID) + if run == nil || stop == nil { + return false + } + run.mu.Lock() + run.stop = stop + if run.state != nil { + run.state.stop.Store(stop) + } + run.mu.Unlock() + return true +} + func (oc *AIClient) enqueueSteerQueue(roomID id.RoomID, item pendingQueueItem) bool { run := oc.getRoomRun(roomID) if run == nil { diff --git a/bridges/ai/streaming_error_handling.go b/bridges/ai/streaming_error_handling.go index 14a848fa..63932df0 100644 --- a/bridges/ai/streaming_error_handling.go +++ b/bridges/ai/streaming_error_handling.go @@ -40,6 +40,9 @@ func (oc *AIClient) finishStreamingWithFailure( reason string, err error, ) error { + if state != nil && state.stop.Load() != nil && reason == "cancelled" { + reason = "stop" + } state.finishReason = reason state.completedAtMs = time.Now().UnixMilli() _ = log @@ -47,13 +50,18 @@ func (oc *AIClient) finishStreamingWithFailure( if writer := state.writer(); writer != nil { writer.MessageMetadata(ctx, oc.buildUIMessageMetadata(state, meta, true)) } - if reason == "cancelled" { + switch reason { + case "cancelled": state.writer().Abort(ctx, "cancelled") - if state != nil && state.turn != nil { + if state.turn != nil { + state.turn.End("cancelled") + } + case "stop": + if state.turn != nil { state.turn.End(msgconv.MapFinishReason(reason)) } - } else { - if state != nil && state.turn != nil { + default: + if state.turn != nil { state.turn.EndWithError(err.Error()) } } diff --git a/bridges/ai/streaming_error_handling_test.go b/bridges/ai/streaming_error_handling_test.go index 6ae8fbfa..36435cb0 100644 --- a/bridges/ai/streaming_error_handling_test.go +++ b/bridges/ai/streaming_error_handling_test.go @@ -5,10 +5,12 @@ import ( "errors" "testing" + "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" + "github.com/beeper/agentremote/pkg/shared/streamui" bridgesdk "github.com/beeper/agentremote/sdk" ) @@ -84,3 +86,28 @@ func TestStreamFailureErrorUsesAnyMessageTarget(t *testing.T) { } }) } + +func TestFinishStreamingWithFailureCancelledEndsTurnAsCancelled(t *testing.T) { + state := newTestStreamingStateWithTurn() + state.turn.SetSuppressSend(true) + state.writer().TextDelta(context.Background(), "hello") + + err := (&AIClient{}).finishStreamingWithFailure( + context.Background(), + zerolog.Nop(), + nil, + state, + nil, + "cancelled", + context.Canceled, + ) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected wrapped cancellation error, got %#v", err) + } + + message := streamui.SnapshotUIMessage(state.turn.UIState()) + metadata, _ := message["metadata"].(map[string]any) + if metadata["finish_reason"] != "cancelled" { + t.Fatalf("expected cancelled finish_reason, got %#v", metadata["finish_reason"]) + } +} diff --git a/bridges/ai/streaming_init.go b/bridges/ai/streaming_init.go index 1833ec07..7f929cfd 100644 --- a/bridges/ai/streaming_init.go +++ b/bridges/ai/streaming_init.go @@ -128,6 +128,7 @@ func (oc *AIClient) prepareStreamingRun( // Create SDK Turn for writer/emitter/session management. turn := oc.createStreamingTurn(ctx, portal, meta, state, sourceEventID, senderID) state.turn = turn + oc.bindRoomRunState(roomID, state) state.replyTarget = oc.resolveInitialReplyTarget(evt) if state.replyTarget.ThreadRoot != "" { diff --git a/bridges/ai/streaming_state.go b/bridges/ai/streaming_state.go index 18990d95..cb205045 100644 --- a/bridges/ai/streaming_state.go +++ b/bridges/ai/streaming_state.go @@ -3,6 +3,7 @@ package ai import ( "context" "strings" + "sync/atomic" "time" "github.com/openai/openai-go/v3/packages/param" @@ -69,6 +70,8 @@ type streamingState struct { // Pending MCP approvals to resolve before the turn can continue. pendingMcpApprovals []mcpApprovalRequest pendingMcpApprovalsSeen map[string]bool + + stop atomic.Pointer[assistantStopMetadata] } // sourceEventID returns the triggering user message event ID from the turn's source ref. diff --git a/bridges/ai/subagent_registry.go b/bridges/ai/subagent_registry.go index 2a7bf963..a6ecf9bf 100644 --- a/bridges/ai/subagent_registry.go +++ b/bridges/ai/subagent_registry.go @@ -1,6 +1,7 @@ package ai import ( + "context" "time" "maunium.net/go/mautrix/id" @@ -32,7 +33,7 @@ func (oc *AIClient) listSubagentRunsForParent(parent id.RoomID) []*subagentRun { return runs } -func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { +func (oc *AIClient) stopSubagentRuns(ctx context.Context, parent id.RoomID) int { if oc == nil || parent == "" { return 0 } @@ -43,10 +44,9 @@ func (oc *AIClient) stopSubagentRuns(parent id.RoomID) int { continue } canceled := oc.cancelRoomRun(run.ChildRoomID) - queueSnapshot := oc.getQueueSnapshot(run.ChildRoomID) - hasQueued := queueSnapshot != nil && (len(queueSnapshot.items) > 0 || queueSnapshot.droppedCount > 0) - oc.clearPendingQueue(run.ChildRoomID) - if canceled || hasQueued { + drained := oc.drainPendingQueue(run.ChildRoomID) + finalized := oc.finalizeStoppedQueueItems(ctx, drained) + if canceled || finalized > 0 { stopped++ } } diff --git a/bridges/ai/turn_data.go b/bridges/ai/turn_data.go index 6448b6da..2dbb9759 100644 --- a/bridges/ai/turn_data.go +++ b/bridges/ai/turn_data.go @@ -59,6 +59,9 @@ func canonicalResponseStatus(state *streamingState) string { if state == nil { return "" } + if state.stop.Load() != nil { + return "cancelled" + } status := strings.TrimSpace(state.responseStatus) if state.completedAtMs == 0 { return status diff --git a/bridges/ai/turn_data_test.go b/bridges/ai/turn_data_test.go index fbd2b7c2..36919e63 100644 --- a/bridges/ai/turn_data_test.go +++ b/bridges/ai/turn_data_test.go @@ -100,3 +100,12 @@ func TestBuildTurnDataMetadataUsesResponderSnapshot(t *testing.T) { t.Fatalf("did not expect flat prompt_tokens field, got %#v", meta["prompt_tokens"]) } } + +func TestCanonicalResponseStatusPrefersExplicitStopWithoutResponseID(t *testing.T) { + state := testStreamingState("turn-cancelled") + state.stop.Store(&assistantStopMetadata{Reason: "user_stop"}) + + if got := canonicalResponseStatus(state); got != "cancelled" { + t.Fatalf("expected cancelled status from explicit stop, got %q", got) + } +} diff --git a/bridges/ai/ui_message_metadata.go b/bridges/ai/ui_message_metadata.go index 5c888330..b55abf32 100644 --- a/bridges/ai/ui_message_metadata.go +++ b/bridges/ai/ui_message_metadata.go @@ -13,6 +13,15 @@ type assistantUsageMetadata struct { TotalTokens int64 `json:"total_tokens,omitempty"` } +type assistantStopMetadata struct { + Reason string `json:"reason,omitempty"` + Scope string `json:"scope,omitempty"` + TargetKind string `json:"target_kind,omitempty"` + TargetEventID string `json:"target_event_id,omitempty"` + RequestedByEventID string `json:"requested_by_event_id,omitempty"` + RequestedVia string `json:"requested_via,omitempty"` +} + type assistantTurnMetadata struct { TurnID string `json:"turn_id,omitempty"` AgentID string `json:"agent_id,omitempty"` @@ -28,6 +37,7 @@ type assistantTurnMetadata struct { SourceEventID string `json:"source_event_id,omitempty"` GeneratedFileRefs []GeneratedFileRef `json:"generated_file_refs,omitempty"` Usage *assistantUsageMetadata `json:"usage,omitempty"` + Stop *assistantStopMetadata `json:"stop,omitempty"` } func buildAssistantUsageMetadata(state *streamingState) *assistantUsageMetadata { @@ -70,5 +80,6 @@ func buildAssistantTurnMetadata(state *streamingState, turnID, networkMessageID, SourceEventID: state.sourceEventID().String(), GeneratedFileRefs: agentremote.GeneratedFileRefsFromParts(state.generatedFiles), Usage: buildAssistantUsageMetadata(state), + Stop: state.stop.Load(), }) }