Skip to content

Commit 809aae2

Browse files
committed
chore: replace ResponsesNewParamsWrapper with ResponsesRequestPayload in responses interceptor
1 parent 53eb065 commit 809aae2

11 files changed

Lines changed: 634 additions & 234 deletions

File tree

intercept/responses/base.go

Lines changed: 56 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ const (
3838

3939
type responsesInterceptionBase struct {
4040
id uuid.UUID
41-
req *ResponsesNewParamsWrapper
42-
reqPayload []byte
41+
reqPayload ResponsesRequestPayload
4342
cfg config.OpenAI
44-
model string
4543
recorder recorder.Recorder
4644
mcpProxy mcp.ServerProxier
4745
logger slog.Logger
@@ -71,26 +69,37 @@ func (i *responsesInterceptionBase) ID() uuid.UUID {
7169
}
7270

7371
func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
74-
i.logger = logger.With(slog.F("model", i.model))
72+
i.logger = logger.With(slog.F("model", i.Model()))
7573
i.recorder = recorder
7674
i.mcpProxy = mcpProxy
7775
}
7876

7977
func (i *responsesInterceptionBase) Model() string {
80-
return i.model
78+
return i.reqPayload.model()
8179
}
8280

8381
func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
84-
if len(i.req.Input.OfInputItemList) == 0 {
82+
items := gjson.GetBytes(i.reqPayload, "input")
83+
if !items.IsArray() {
8584
return nil
8685
}
8786

88-
// The tool result should be the last input message.
89-
item := i.req.Input.OfInputItemList[len(i.req.Input.OfInputItemList)-1]
90-
if item.OfFunctionCallOutput == nil {
87+
arr := items.Array()
88+
if len(arr) == 0 {
9189
return nil
9290
}
93-
return &item.OfFunctionCallOutput.CallID
91+
92+
last := arr[len(arr)-1]
93+
if last.Get(string(constant.ValueOf[constant.Type]())).String() != string(constant.ValueOf[constant.FunctionCallOutput]()) {
94+
return nil
95+
}
96+
97+
callID := last.Get("call_id").String()
98+
if callID == "" {
99+
return nil
100+
}
101+
102+
return &callID
94103
}
95104

96105
func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
@@ -105,13 +114,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami
105114
}
106115

107116
func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error {
108-
if i.req == nil {
109-
err := errors.New("developer error: req is nil")
110-
i.sendCustomErr(ctx, w, http.StatusInternalServerError, err)
111-
return err
112-
}
113-
114-
if i.req.Background.Value {
117+
if i.reqPayload.background() {
115118
err := fmt.Errorf("background requests are currently not supported by AI Bridge")
116119
i.sendCustomErr(ctx, w, http.StatusNotImplemented, err)
117120
return err
@@ -144,15 +147,15 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
144147
// eg. Codex CLI produces requests without ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id
145148
// when re-encoded, ID field is set to empty string which results
146149
// in bad request while not sending ID field at all somehow works.
147-
option.WithRequestBody("application/json", i.reqPayload),
150+
option.WithRequestBody("application/json", []byte(i.reqPayload)),
148151

149152
// copyMiddleware copies body of original response body to the buffer in responseCopier,
150153
// also reference to headers and status code is kept responseCopier.
151154
// responseCopier is used by interceptors to forward response as it was received,
152155
// eliminating any possibility of JSON re-encoding issues.
153156
option.WithMiddleware(respCopy.copyMiddleware),
154157
}
155-
if !i.req.Stream {
158+
if !i.reqPayload.Stream() {
156159
opts = append(opts, option.WithRequestTimeout(requestTimeout))
157160
}
158161
return opts
@@ -161,81 +164,83 @@ func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []o
161164
// lastUserPrompt returns input text with "user" role from last input item
162165
// or string input value if it is present + bool indicating if input was found or not.
163166
// If no such input was found empty string + false is returned.
164-
func (i *responsesInterceptionBase) lastUserPrompt(ctx context.Context) (string, bool, error) {
167+
func (i *responsesInterceptionBase) lastUserPrompt() (string, bool, error) {
165168
if i == nil {
166169
return "", false, errors.New("cannot get last user prompt: nil struct")
167170
}
168-
if i.req == nil {
171+
if i.reqPayload == nil {
169172
return "", false, errors.New("cannot get last user prompt: nil request struct")
170173
}
171174

172-
// 'input' field can be a string or array of objects:
175+
// 'input' can be either a string or an array of input items:
173176
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input
174-
175-
// Check string variant
176-
if i.req.Input.OfString.Valid() {
177-
return i.req.Input.OfString.Value, true, nil
177+
inputItems := gjson.GetBytes(i.reqPayload, "input")
178+
if !inputItems.Exists() || inputItems.Type == gjson.Null {
179+
return "", false, nil
178180
}
179181

180-
// Fallback to parsing original bytes since golang SDK doesn't properly decode 'Input' field.
181-
// If 'type' field of input item is not set it will be omitted from 'Input.OfInputItemList'
182-
// It is an optional field according to API: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message
183-
// example: fixtures/openai/responses/blocking/builtin_tool.txtar
184-
inputItems := gjson.GetBytes(i.reqPayload, "input")
182+
// String variant: treat the whole input as the user prompt.
183+
if inputItems.Type == gjson.String {
184+
return inputItems.String(), true, nil
185+
}
185186

187+
// Array variant: checking only the last input item
186188
if !inputItems.IsArray() {
187-
if inputItems.Type == gjson.Null {
188-
return "", false, nil
189-
}
190-
return "", false, fmt.Errorf("unexpected input type: %v", inputItems.Type.String())
189+
return "", false, fmt.Errorf("unexpected input type: %s", inputItems.Type)
191190
}
192191

193192
inputItemsArr := inputItems.Array()
194193
if len(inputItemsArr) == 0 {
195194
return "", false, nil
196195
}
197-
lastItem := inputItemsArr[len(inputItemsArr)-1]
198196

199-
// Request was likely not human-initiated.
197+
lastItem := inputItemsArr[len(inputItemsArr)-1]
200198
if lastItem.Get("role").Str != string(constant.ValueOf[constant.User]()) {
199+
// Request was likely not initiated by a prompt but is an iteration of agentic loop.
201200
return "", false, nil
202201
}
203202

204-
// content can be a string or array of objects:
203+
// Message content can be either a string or an array of typed content items:
205204
// https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-input_message-content
206205
content := lastItem.Get(string(constant.ValueOf[constant.Content]()))
206+
if !content.Exists() || content.Type == gjson.Null {
207+
return "", false, nil
208+
}
209+
210+
// String variant: use it directly as the prompt.
211+
if content.Type == gjson.String {
212+
return content.Str, true, nil
213+
}
207214

208-
// non array case, should be string
209215
if !content.IsArray() {
210-
if content.Type == gjson.String {
211-
return content.Str, true, nil
212-
}
213-
return "", false, fmt.Errorf("unexpected input content type: %v", content.Type.String())
216+
return "", false, fmt.Errorf("unexpected input content type: %s", content.Type)
214217
}
215218

216219
var sb strings.Builder
217220
promptExists := false
218221
for _, c := range content.Array() {
219-
// ignore inputs of not `input_text` type
222+
// Ignore non-text content blocks such as images or files.
220223
if c.Get(string(constant.ValueOf[constant.Type]())).Str != string(constant.ValueOf[constant.InputText]()) {
221224
continue
222225
}
223226

224227
text := c.Get(string(constant.ValueOf[constant.Text]()))
225-
if text.Type == gjson.String {
226-
promptExists = true
227-
sb.WriteString(text.Str + "\n")
228-
} else {
229-
i.logger.Warn(ctx, fmt.Sprintf("unexpected input content array element text type: %v", text.Type))
228+
if text.Type != gjson.String {
229+
continue
230+
}
231+
232+
if promptExists {
233+
sb.WriteByte('\n')
230234
}
235+
promptExists = true
236+
sb.WriteString(text.Str)
231237
}
232238

233239
if !promptExists {
234240
return "", false, nil
235241
}
236242

237-
prompt := strings.TrimSuffix(sb.String(), "\n")
238-
return prompt, true, nil
243+
return sb.String(), true, nil
239244
}
240245

241246
func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {

intercept/responses/base_test.go

Lines changed: 36 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/coder/aibridge/utils"
1313
"github.com/google/uuid"
1414
oairesponses "github.com/openai/openai-go/v3/responses"
15+
"github.com/stretchr/testify/assert"
1516
"github.com/stretchr/testify/require"
1617
)
1718

@@ -20,95 +21,53 @@ func TestScanForCorrelatingToolCallID(t *testing.T) {
2021

2122
tests := []struct {
2223
name string
23-
input []oairesponses.ResponseInputItemUnionParam
24-
expected *string
24+
payload []byte
25+
wantCall *string
2526
}{
2627
{
27-
name: "no input items",
28-
input: nil,
29-
expected: nil,
28+
name: "no input",
29+
payload: []byte(`{"model":"gpt-4o"}`),
3030
},
3131
{
32-
name: "no function_call_output items",
33-
input: []oairesponses.ResponseInputItemUnionParam{
34-
{
35-
OfMessage: &oairesponses.EasyInputMessageParam{
36-
Role: "user",
37-
},
38-
},
39-
},
40-
expected: nil,
32+
name: "empty input array",
33+
payload: []byte(`{"model":"gpt-4o","input":[]}`),
4134
},
4235
{
43-
name: "single function_call_output",
44-
input: []oairesponses.ResponseInputItemUnionParam{
45-
{
46-
OfMessage: &oairesponses.EasyInputMessageParam{
47-
Role: "user",
48-
},
49-
},
50-
{
51-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
52-
CallID: "call_abc",
53-
},
54-
},
55-
},
56-
expected: utils.PtrTo("call_abc"),
36+
name: "no function_call_output items",
37+
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
5738
},
5839
{
59-
name: "multiple function_call_outputs returns last",
60-
input: []oairesponses.ResponseInputItemUnionParam{
61-
{
62-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
63-
CallID: "call_first",
64-
},
65-
},
66-
{
67-
OfMessage: &oairesponses.EasyInputMessageParam{
68-
Role: "user",
69-
},
70-
},
71-
{
72-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
73-
CallID: "call_second",
74-
},
75-
},
76-
},
77-
expected: utils.PtrTo("call_second"),
40+
name: "single function_call_output",
41+
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
42+
wantCall: utils.PtrTo("call_abc"),
7843
},
7944
{
80-
name: "last input is not a tool result",
81-
input: []oairesponses.ResponseInputItemUnionParam{
82-
{
83-
OfFunctionCallOutput: &oairesponses.ResponseInputItemFunctionCallOutputParam{
84-
CallID: "call_first",
85-
},
86-
},
87-
{
88-
OfMessage: &oairesponses.EasyInputMessageParam{
89-
Role: "user",
90-
},
91-
},
92-
},
93-
expected: nil,
45+
name: "multiple function_call_outputs returns last",
46+
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
47+
wantCall: utils.PtrTo("call_second"),
48+
},
49+
{
50+
name: "last input is not a tool result",
51+
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
52+
},
53+
{
54+
name: "missing call id",
55+
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
9456
},
9557
}
9658

9759
for _, tc := range tests {
9860
t.Run(tc.name, func(t *testing.T) {
9961
t.Parallel()
10062

63+
rp, err := NewResponsesRequestPayload(tc.payload)
64+
require.NoError(t, err)
10165
base := &responsesInterceptionBase{
102-
req: &ResponsesNewParamsWrapper{
103-
ResponseNewParams: oairesponses.ResponseNewParams{
104-
Input: oairesponses.ResponseNewParamsInputUnion{
105-
OfInputItemList: tc.input,
106-
},
107-
},
108-
},
66+
reqPayload: rp,
10967
}
11068

111-
require.Equal(t, tc.expected, base.CorrelatingToolCallID())
69+
callID := base.CorrelatingToolCallID()
70+
assert.Equal(t, tc.wantCall, callID)
11271
})
11372
}
11473
}
@@ -161,16 +120,13 @@ func TestLastUserPrompt(t *testing.T) {
161120
t.Run(tc.name, func(t *testing.T) {
162121
t.Parallel()
163122

164-
req := &ResponsesNewParamsWrapper{}
165-
err := req.UnmarshalJSON(tc.reqPayload)
123+
rp, err := NewResponsesRequestPayload(tc.reqPayload)
166124
require.NoError(t, err)
167-
168125
base := &responsesInterceptionBase{
169-
req: req,
170-
reqPayload: tc.reqPayload,
126+
reqPayload: rp,
171127
}
172128

173-
prompt, promptFound, err := base.lastUserPrompt(t.Context())
129+
prompt, promptFound, err := base.lastUserPrompt()
174130
require.NoError(t, err)
175131
require.Equal(t, tc.expect, prompt)
176132
require.True(t, promptFound)
@@ -185,7 +141,7 @@ func TestLastUserPromptNotFound(t *testing.T) {
185141
t.Parallel()
186142

187143
var base *responsesInterceptionBase
188-
prompt, promptFound, err := base.lastUserPrompt(t.Context())
144+
prompt, promptFound, err := base.lastUserPrompt()
189145
require.Error(t, err)
190146
require.Empty(t, prompt)
191147
require.False(t, promptFound)
@@ -196,7 +152,7 @@ func TestLastUserPromptNotFound(t *testing.T) {
196152
t.Parallel()
197153

198154
base := responsesInterceptionBase{}
199-
prompt, promptFound, err := base.lastUserPrompt(t.Context())
155+
prompt, promptFound, err := base.lastUserPrompt()
200156
require.Error(t, err)
201157
require.Empty(t, prompt)
202158
require.False(t, promptFound)
@@ -253,16 +209,14 @@ func TestLastUserPromptNotFound(t *testing.T) {
253209
t.Run(tc.name, func(t *testing.T) {
254210
t.Parallel()
255211

256-
req := &ResponsesNewParamsWrapper{}
257-
err := req.UnmarshalJSON(tc.reqPayload)
212+
rp, err := NewResponsesRequestPayload(tc.reqPayload)
258213
require.NoError(t, err)
259214

260215
base := &responsesInterceptionBase{
261-
req: req,
262-
reqPayload: tc.reqPayload,
216+
reqPayload: rp,
263217
}
264218

265-
prompt, promptFound, err := base.lastUserPrompt(t.Context())
219+
prompt, promptFound, err := base.lastUserPrompt()
266220
if tc.expectErr != "" {
267221
require.Error(t, err)
268222
require.Contains(t, err.Error(), tc.expectErr)

0 commit comments

Comments
 (0)