Skip to content

Commit 56ff5d0

Browse files
committed
fix: send model thoughts with tool usage recording
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent b24f893 commit 56ff5d0

7 files changed

Lines changed: 52 additions & 147 deletions

File tree

bridge_integration_test.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,20 +216,17 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) {
216216
assert.Contains(t, sp.AllEvents(), "message_stop")
217217
}
218218

219-
// Verify model thoughts were captured and associated with the tool call.
220-
thoughts := recorderClient.RecordedModelThoughts()
221-
require.Len(t, thoughts, 1)
222-
assert.Contains(t, thoughts[0].Content, "The user wants me to read")
223-
assert.Contains(t, thoughts[0].Content, tc.expectedThinkingSubstr)
224-
assert.NotEmpty(t, thoughts[0].InterceptionID)
225-
assert.Equal(t, tc.expectedToolCallID, thoughts[0].ProviderToolCallID)
226-
227-
// Verify tool usage was also recorded.
219+
// Verify tool usage was recorded with associated model thoughts.
228220
toolUsages := recorderClient.RecordedToolUsages()
229221
require.Len(t, toolUsages, 1)
230222
assert.Equal(t, "Read", toolUsages[0].Tool)
231223
assert.Equal(t, tc.expectedToolCallID, toolUsages[0].ToolCallID)
232224

225+
// Model thoughts should be embedded in the tool usage record.
226+
require.Len(t, toolUsages[0].ModelThoughts, 1)
227+
assert.Contains(t, toolUsages[0].ModelThoughts[0].Content, "The user wants me to read")
228+
assert.Contains(t, toolUsages[0].ModelThoughts[0].Content, tc.expectedThinkingSubstr)
229+
233230
recorderClient.VerifyAllInterceptionsEnded(t)
234231
})
235232
}
@@ -271,9 +268,10 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) {
271268
sp := aibridge.NewSSEParser()
272269
require.NoError(t, sp.Parse(resp.Body))
273270

274-
// No thoughts should be recorded when there are no tool calls.
275-
thoughts := recorderClient.RecordedModelThoughts()
276-
assert.Empty(t, thoughts)
271+
// No tool usages (and therefore no thoughts) should be recorded
272+
// when there are no tool calls.
273+
toolUsages := recorderClient.RecordedToolUsages()
274+
assert.Empty(t, toolUsages)
277275

278276
recorderClient.VerifyAllInterceptionsEnded(t)
279277
})

intercept/messages/blocking.go

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,16 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
137137

138138
// Capture any thinking blocks that were returned.
139139
var thoughtRecords []*recorder.ModelThoughtRecord
140-
if !i.isSmallFastModel() {
141-
for _, block := range resp.Content {
142-
switch variant := block.AsAny().(type) {
143-
case anthropic.ThinkingBlock:
144-
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
145-
InterceptionID: i.ID().String(),
146-
Content: variant.Thinking,
147-
})
148-
case anthropic.RedactedThinkingBlock:
149-
// For redacted thinking, there's nothing useful we can capture.
150-
continue
151-
}
140+
for _, block := range resp.Content {
141+
switch variant := block.AsAny().(type) {
142+
case anthropic.ThinkingBlock:
143+
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
144+
Content: variant.Thinking,
145+
CreatedAt: time.Now(),
146+
})
147+
case anthropic.RedactedThinkingBlock:
148+
// For redacted thinking, there's nothing useful we can capture.
149+
continue
152150
}
153151
}
154152

@@ -173,22 +171,15 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
173171
Tool: toolUse.Name,
174172
Args: toolUse.Input,
175173
Injected: false,
174+
ModelThoughts: thoughtRecords,
176175
})
177-
178-
// Associate the model thoughts with this tool call.
179-
for _, thought := range thoughtRecords {
180-
thought.ProviderToolCallID = toolUse.ID
181-
}
176+
// Clear after first use to avoid duplicating across
177+
// multiple tool calls in the same message.
178+
thoughtRecords = nil
182179
}
183180

184-
// If no injected tool calls, persist thoughts and we're done.
181+
// If no injected tool calls, we're done.
185182
if len(pendingToolCalls) == 0 {
186-
for _, thought := range thoughtRecords {
187-
if thought.ProviderToolCallID == "" {
188-
continue
189-
}
190-
_ = i.recorder.RecordModelThought(ctx, thought)
191-
}
192183
break
193184
}
194185

@@ -223,12 +214,11 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
223214
Args: tc.Input,
224215
Injected: true,
225216
InvocationError: err,
217+
ModelThoughts: thoughtRecords,
226218
})
227-
228-
// Associate the model thoughts with this tool call.
229-
for _, thought := range thoughtRecords {
230-
thought.ProviderToolCallID = tc.ID
231-
}
219+
// Clear after first use to avoid duplicating across
220+
// multiple tool calls in the same message.
221+
thoughtRecords = nil
232222

233223
if err != nil {
234224
// Always provide a tool_result even if the tool call failed
@@ -315,14 +305,6 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
315305
}
316306
}
317307

318-
// Only persist thoughts that are associated to a tool call.
319-
for _, thought := range thoughtRecords {
320-
if thought.ProviderToolCallID == "" {
321-
continue
322-
}
323-
_ = i.recorder.RecordModelThought(ctx, thought)
324-
}
325-
326308
// Sync the raw payload with updated messages so that withBody()
327309
// sends the updated payload on the next iteration.
328310
if err := i.syncPayloadMessages(messages.Messages); err != nil {

intercept/messages/streaming.go

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -255,18 +255,16 @@ newStream:
255255

256256
// Capture any thinking blocks that were returned.
257257
var thoughtRecords []*recorder.ModelThoughtRecord
258-
if !i.isSmallFastModel() { // TODO: remove.
259-
for _, block := range message.Content {
260-
switch variant := block.AsAny().(type) {
261-
case anthropic.ThinkingBlock:
262-
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
263-
InterceptionID: i.ID().String(),
264-
Content: variant.Thinking,
265-
})
266-
case anthropic.RedactedThinkingBlock:
267-
// For redacted thinking, there's nothing useful we can capture.
268-
continue
269-
}
258+
for _, block := range message.Content {
259+
switch variant := block.AsAny().(type) {
260+
case anthropic.ThinkingBlock:
261+
thoughtRecords = append(thoughtRecords, &recorder.ModelThoughtRecord{
262+
Content: variant.Thinking,
263+
CreatedAt: time.Now(),
264+
})
265+
case anthropic.RedactedThinkingBlock:
266+
// For redacted thinking, there's nothing useful we can capture.
267+
continue
270268
}
271269
}
272270

@@ -322,12 +320,11 @@ newStream:
322320
Args: input,
323321
Injected: true,
324322
InvocationError: err,
323+
ModelThoughts: thoughtRecords,
325324
})
326-
327-
// Associate the model thoughts with this tool call.
328-
for _, thought := range thoughtRecords {
329-
thought.ProviderToolCallID = id
330-
}
325+
// Clear after first use to avoid duplicating across
326+
// multiple tool calls in the same message.
327+
thoughtRecords = nil
331328

332329
if err != nil {
333330
// Always provide a tool_result even if the tool call failed
@@ -413,15 +410,6 @@ newStream:
413410
}
414411
}
415412

416-
// Only persist thoughts that are associated to a tool call.
417-
for _, thought := range thoughtRecords {
418-
if thought.ProviderToolCallID == "" {
419-
continue
420-
}
421-
422-
_ = i.recorder.RecordModelThought(streamCtx, thought)
423-
}
424-
425413
// Sync the raw payload with updated messages so that withBody()
426414
// sends the updated payload on the next iteration.
427415
if syncErr := i.syncPayloadMessages(messages.Messages); syncErr != nil {
@@ -448,23 +436,13 @@ newStream:
448436
Tool: variant.Name,
449437
Args: variant.Input,
450438
Injected: false,
439+
ModelThoughts: thoughtRecords,
451440
})
452-
453-
// Associate the model thoughts with this tool call.
454-
for _, thought := range thoughtRecords {
455-
thought.ProviderToolCallID = variant.ID
456-
}
441+
// Clear after first use to avoid duplicating across
442+
// multiple tool calls in the same message.
443+
thoughtRecords = nil
457444
}
458445
}
459-
460-
// Only persist thoughts that are associated to a tool call.
461-
for _, thought := range thoughtRecords {
462-
if thought.ProviderToolCallID == "" {
463-
continue
464-
}
465-
466-
_ = i.recorder.RecordModelThought(streamCtx, thought)
467-
}
468446
}
469447
}
470448

internal/testutil/mock_recorder.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ type MockRecorder struct {
2020
tokenUsages []*recorder.TokenUsageRecord
2121
userPrompts []*recorder.PromptUsageRecord
2222
toolUsages []*recorder.ToolUsageRecord
23-
modelThoughts []*recorder.ModelThoughtRecord
2423
interceptionsEnd map[string]*recorder.InterceptionRecordEnded
2524
}
2625

@@ -65,13 +64,6 @@ func (m *MockRecorder) RecordToolUsage(ctx context.Context, req *recorder.ToolUs
6564
return nil
6665
}
6766

68-
func (m *MockRecorder) RecordModelThought(ctx context.Context, req *recorder.ModelThoughtRecord) error {
69-
m.mu.Lock()
70-
defer m.mu.Unlock()
71-
m.modelThoughts = append(m.modelThoughts, req)
72-
return nil
73-
}
74-
7567
// RecordedTokenUsages returns a copy of recorded token usages in a thread-safe manner.
7668
// Note: This is a shallow clone - the slice is copied but the pointers reference the
7769
// same underlying records. This is sufficient for our test assertions which only read
@@ -114,13 +106,6 @@ func (m *MockRecorder) ToolUsages() []*recorder.ToolUsageRecord {
114106
return m.toolUsages
115107
}
116108

117-
// RecordedModelThoughts returns a copy of recorded model thoughts in a thread-safe manner.
118-
func (m *MockRecorder) RecordedModelThoughts() []*recorder.ModelThoughtRecord {
119-
m.mu.Lock()
120-
defer m.mu.Unlock()
121-
return slices.Clone(m.modelThoughts)
122-
}
123-
124109
// RecordedInterceptionEnd returns the stored InterceptionRecordEnded for the
125110
// given interception ID, or nil if not found.
126111
func (m *MockRecorder) RecordedInterceptionEnd(id string) *recorder.InterceptionRecordEnded {

recorder/recorder.go

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -116,24 +116,6 @@ func (r *RecorderWrapper) RecordToolUsage(ctx context.Context, req *ToolUsageRec
116116
return err
117117
}
118118

119-
func (r *RecorderWrapper) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) (outErr error) {
120-
ctx, span := r.tracer.Start(ctx, "Intercept.RecordModelThought", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
121-
defer tracing.EndSpanErr(span, &outErr)
122-
123-
client, err := r.clientFn()
124-
if err != nil {
125-
return fmt.Errorf("acquire client: %w", err)
126-
}
127-
128-
req.CreatedAt = time.Now()
129-
if err = client.RecordModelThought(ctx, req); err == nil {
130-
return nil
131-
}
132-
133-
r.logger.Warn(ctx, "failed to record model thought", slog.Error(err), slog.F("interception_id", req.InterceptionID))
134-
return err
135-
}
136-
137119
func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) *RecorderWrapper {
138120
return &RecorderWrapper{
139121
logger: logger,
@@ -270,22 +252,6 @@ func (a *AsyncRecorder) RecordToolUsage(ctx context.Context, req *ToolUsageRecor
270252
return nil // Caller is not interested in error.
271253
}
272254

273-
func (a *AsyncRecorder) RecordModelThought(ctx context.Context, req *ModelThoughtRecord) error {
274-
a.wg.Add(1)
275-
go func() {
276-
defer a.wg.Done()
277-
timedCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), a.timeout)
278-
defer cancel()
279-
280-
err := a.wrapped.RecordModelThought(timedCtx, req)
281-
if err != nil {
282-
a.logger.Warn(timedCtx, "failed to record model thought", slog.Error(err), slog.F("payload", req))
283-
}
284-
}()
285-
286-
return nil // Caller is not interested in error.
287-
}
288-
289255
func (a *AsyncRecorder) Wait() {
290256
a.wg.Wait()
291257
}

recorder/types.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ type Recorder interface {
1818
// RecordPromptUsage records the prompts used in an interception with an upstream AI provider.
1919
RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error
2020
// RecordToolUsage records the tools used in an interception with an upstream AI provider.
21+
// Any associated model thoughts should be included in the ToolUsageRecord.
2122
RecordToolUsage(ctx context.Context, req *ToolUsageRecord) error
22-
// RecordModelThought records the reasoning/thinking produced in an interception with an upstream AI provider.
23-
RecordModelThought(ctx context.Context, req *ModelThoughtRecord) error
2423
}
2524

2625
type ToolArgs any
@@ -74,12 +73,11 @@ type ToolUsageRecord struct {
7473
InvocationError error
7574
Metadata Metadata
7675
CreatedAt time.Time
76+
ModelThoughts []*ModelThoughtRecord
7777
}
7878

7979
type ModelThoughtRecord struct {
80-
InterceptionID string
81-
ProviderToolCallID string
82-
Content string
83-
Metadata Metadata
84-
CreatedAt time.Time
80+
Content string
81+
Metadata Metadata
82+
CreatedAt time.Time
8583
}

trace_integration_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ func TestTraceAnthropic(t *testing.T) {
4646
{"Intercept.RecordPromptUsage", 1, codes.Unset},
4747
{"Intercept.RecordTokenUsage", 1, codes.Unset},
4848
{"Intercept.RecordToolUsage", 1, codes.Unset},
49-
{"Intercept.RecordModelThought", 1, codes.Unset},
5049
{"Intercept.ProcessRequest.Upstream", 1, codes.Unset},
5150
}
5251

@@ -59,7 +58,6 @@ func TestTraceAnthropic(t *testing.T) {
5958
{"Intercept.RecordPromptUsage", 1, codes.Unset},
6059
{"Intercept.RecordTokenUsage", 2, codes.Unset},
6160
{"Intercept.RecordToolUsage", 1, codes.Unset},
62-
{"Intercept.RecordModelThought", 1, codes.Unset},
6361
{"Intercept.ProcessRequest.Upstream", 1, codes.Unset},
6462
}
6563

0 commit comments

Comments
 (0)