Skip to content

Commit fcf5b8c

Browse files
committed
Fix tool call streaming and non-streaming response handling
1 parent b3cca07 commit fcf5b8c

3 files changed

Lines changed: 88 additions & 10 deletions

File tree

internal/api/commandcode.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ type CCRequestBody struct {
4848
type CCStreamEvent struct {
4949
Type string `json:"type"`
5050
Text string `json:"text"`
51+
ToolCallID string `json:"toolCallId"`
52+
ToolName string `json:"toolName"`
5153
FinishReason string `json:"finishReason"`
5254
Error *struct {
5355
Message string `json:"message"`

internal/api/openai.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type OpenAIChatRequest struct {
4747
Temperature *float64 `json:"temperature,omitempty"`
4848
MaxTokens *int `json:"max_tokens,omitempty"`
4949
Stream bool `json:"stream,omitempty"`
50+
Tools []any `json:"tools,omitempty"`
5051
}
5152

5253
type OpenAIChoice struct {
@@ -57,8 +58,21 @@ type OpenAIChoice struct {
5758
}
5859

5960
type OpenAIDelta struct {
60-
Role string `json:"role,omitempty"`
61-
Content string `json:"content,omitempty"`
61+
Role string `json:"role,omitempty"`
62+
Content string `json:"content,omitempty"`
63+
ToolCalls []OpenAIDeltaToolCall `json:"tool_calls,omitempty"`
64+
}
65+
66+
type OpenAIDeltaToolCall struct {
67+
Index int `json:"index"`
68+
ID string `json:"id,omitempty"`
69+
Type string `json:"type,omitempty"`
70+
Function *OpenAIDeltaFunction `json:"function,omitempty"`
71+
}
72+
73+
type OpenAIDeltaFunction struct {
74+
Name string `json:"name,omitempty"`
75+
Arguments string `json:"arguments,omitempty"`
6276
}
6377

6478
type OpenAIUsage struct {

internal/proxy/proxy.go

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ func (p *Proxy) BuildRequest(openAIReq api.OpenAIChatRequest) (api.CCRequestBody
5151
maxTokens = *openAIReq.MaxTokens
5252
}
5353

54+
tools := openAIReq.Tools
55+
if tools == nil {
56+
tools = []any{}
57+
}
58+
5459
ccBody := api.CCRequestBody{
5560
Config: api.CCConfig{
5661
WorkingDir: ".",
@@ -69,7 +74,7 @@ func (p *Proxy) BuildRequest(openAIReq api.OpenAIChatRequest) (api.CCRequestBody
6974
Params: api.CCChatParams{
7075
Model: model,
7176
Messages: ccMessages,
72-
Tools: []any{},
77+
Tools: tools,
7378
System: system,
7479
MaxTokens: maxTokens,
7580
Temperature: temperature,
@@ -204,6 +209,7 @@ func (p *Proxy) StreamResponse(w http.ResponseWriter, r *http.Request, ccResp *h
204209
scanner := bufio.NewScanner(ccResp.Body)
205210
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
206211
sentRole := false
212+
toolCallIndex := 0
207213

208214
for scanner.Scan() {
209215
select {
@@ -237,6 +243,40 @@ func (p *Proxy) StreamResponse(w http.ResponseWriter, r *http.Request, ccResp *h
237243
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
238244
})
239245

246+
case "tool-use":
247+
toolCalls := []api.OpenAIDeltaToolCall{{
248+
Index: toolCallIndex,
249+
ID: event.ToolCallID,
250+
Type: "function",
251+
Function: &api.OpenAIDeltaFunction{Name: event.ToolName},
252+
}}
253+
delta := api.OpenAIDelta{ToolCalls: toolCalls}
254+
if !sentRole {
255+
delta.Role = "assistant"
256+
sentRole = true
257+
}
258+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
259+
ID: requestID,
260+
Object: "chat.completion.chunk",
261+
Created: created,
262+
Model: model,
263+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
264+
})
265+
toolCallIndex++
266+
267+
case "tool-delta":
268+
toolCalls := []api.OpenAIDeltaToolCall{{
269+
Index: toolCallIndex - 1,
270+
Function: &api.OpenAIDeltaFunction{Arguments: event.Text},
271+
}}
272+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
273+
ID: requestID,
274+
Object: "chat.completion.chunk",
275+
Created: created,
276+
Model: model,
277+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &api.OpenAIDelta{ToolCalls: toolCalls}}},
278+
})
279+
240280
case "finish":
241281
reason := "stop"
242282
if event.FinishReason == "tool_calls" {
@@ -280,6 +320,8 @@ func (p *Proxy) NonStreamResponse(w http.ResponseWriter, ccResp *http.Response,
280320

281321
var content strings.Builder
282322
var inputTokens, outputTokens int
323+
var hasToolCalls bool
324+
var toolCalls []api.ToolCall
283325

284326
for scanner.Scan() {
285327
line := strings.TrimSpace(scanner.Text())
@@ -295,6 +337,20 @@ func (p *Proxy) NonStreamResponse(w http.ResponseWriter, ccResp *http.Response,
295337
switch event.Type {
296338
case "text-delta":
297339
content.WriteString(event.Text)
340+
case "tool-use":
341+
hasToolCalls = true
342+
toolCalls = append(toolCalls, api.ToolCall{
343+
ID: event.ToolCallID,
344+
Type: "function",
345+
Function: api.FunctionCall{
346+
Name: event.ToolName,
347+
Arguments: "",
348+
},
349+
})
350+
case "tool-delta":
351+
if len(toolCalls) > 0 {
352+
toolCalls[len(toolCalls)-1].Function.Arguments += event.Text
353+
}
298354
case "finish":
299355
if event.TotalUsage != nil {
300356
inputTokens = event.TotalUsage.InputTokens
@@ -305,26 +361,32 @@ func (p *Proxy) NonStreamResponse(w http.ResponseWriter, ccResp *http.Response,
305361
}
306362
}
307363

364+
msg := &api.OpenAIMessage{
365+
Role: "assistant",
366+
Content: content.String(),
367+
}
368+
finishReason := "stop"
369+
if hasToolCalls {
370+
msg.ToolCalls = toolCalls
371+
finishReason = "tool_calls"
372+
}
373+
308374
response := api.OpenAIChatResponse{
309375
ID: requestID,
310376
Object: "chat.completion",
311377
Created: created,
312378
Model: model,
313379
Choices: []api.OpenAIChoice{{
314-
Index: 0,
315-
Message: &api.OpenAIMessage{
316-
Role: "assistant",
317-
Content: content.String(),
318-
},
319-
FinishReason: new(string),
380+
Index: 0,
381+
Message: msg,
382+
FinishReason: &finishReason,
320383
}},
321384
Usage: &api.OpenAIUsage{
322385
PromptTokens: inputTokens,
323386
CompletionTokens: outputTokens,
324387
TotalTokens: inputTokens + outputTokens,
325388
},
326389
}
327-
*response.Choices[0].FinishReason = "stop"
328390

329391
w.Header().Set("Content-Type", "application/json")
330392
json.NewEncoder(w).Encode(response)

0 commit comments

Comments
 (0)