From 624d178d195b5841b2b317ee9a1c78c25116e994 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sat, 7 Mar 2026 23:57:58 +0800 Subject: [PATCH] feat(core): add streaming output support for real-time token display - Add CompletionStream method to Generative interface with io.Writer - Implement streaming in OpenAI provider using CreateChatCompletionStream - Implement streaming in Anthropic provider using CreateMessagesStream - Implement streaming in Gemini provider using GenerateContentStream - Add --stream flag to commit and review commands bound to openai.stream - Add callCompletion helper to route streaming or non-streaming calls - Add streaming tests for OpenAI, Anthropic, and Gemini providers Co-Authored-By: Claude Opus 4.6 --- cmd/commit.go | 26 ++++++++- cmd/config_list.go | 1 + cmd/review.go | 8 ++- core/openai.go | 5 ++ provider/anthropic/anthropic.go | 54 ++++++++++++++++++ provider/anthropic/stream_test.go | 73 ++++++++++++++++++++++++ provider/gemini/gemini.go | 58 +++++++++++++++++++ provider/gemini/stream_test.go | 95 +++++++++++++++++++++++++++++++ provider/openai/openai.go | 71 +++++++++++++++++++++++ provider/openai/stream_test.go | 60 +++++++++++++++++++ 10 files changed, 447 insertions(+), 4 deletions(-) create mode 100644 provider/anthropic/stream_test.go create mode 100644 provider/gemini/stream_test.go create mode 100644 provider/openai/stream_test.go diff --git a/cmd/commit.go b/cmd/commit.go index 7593814..daec522 100644 --- a/cmd/commit.go +++ b/cmd/commit.go @@ -1,8 +1,10 @@ package cmd import ( + "context" "fmt" "html" + "io" "os" "path" "strings" @@ -73,9 +75,29 @@ func init() { "display the prompt without sending to OpenAI") commitCmd.PersistentFlags().BoolVar(&noConfirm, "no_confirm", false, "skip all confirmation prompts") + commitCmd.PersistentFlags().Bool("stream", false, + "enable streaming output for real-time token display") + _ = viper.BindPFlag("openai.stream", commitCmd.PersistentFlags().Lookup("stream")) _ = viper.BindPFlag("output.file", commitCmd.PersistentFlags().Lookup("file")) } +func callCompletion( + ctx context.Context, + client core.Generative, + content string, + w io.Writer, +) (*core.Response, error) { + if viper.GetBool("openai.stream") { + resp, err := client.CompletionStream(ctx, content, w) + if err != nil { + return nil, err + } + fmt.Fprintln(w) + return resp, nil + } + return client.Completion(ctx, content) +} + // commitCmd represents the commit command. var commitCmd = &cobra.Command{ Use: "commit", @@ -152,7 +174,7 @@ var commitCmd = &cobra.Command{ // Get summarized comment from diff data color.Cyan("Summarizing git diff...") - resp, err := client.Completion(cmd.Context(), out) + resp, err := callCompletion(cmd.Context(), client, out, os.Stdout) if err != nil { return err } @@ -284,7 +306,7 @@ var commitCmd = &cobra.Command{ viper.GetString("output.lang"), ), ) - resp, err := client.Completion(cmd.Context(), out) + resp, err := callCompletion(cmd.Context(), client, out, os.Stdout) if err != nil { return err } diff --git a/cmd/config_list.go b/cmd/config_list.go index 5144e51..766f42d 100644 --- a/cmd/config_list.go +++ b/cmd/config_list.go @@ -38,6 +38,7 @@ var availableKeys = map[string]string{ "openai.top_p": "Nucleus sampling parameter: controls diversity by limiting to top percentage of probability mass", "openai.frequency_penalty": "Parameter to reduce repetition by penalizing tokens based on their frequency", "openai.presence_penalty": "Parameter to encourage topic diversity by penalizing previously used tokens", + "openai.stream": "Enable streaming output for real-time token display", "prompt.folder": "Directory path for custom prompt templates", "gemini.project_id": "VertexAI project for Gemini provider", "gemini.location": "VertexAI location for Gemini provider", diff --git a/cmd/review.go b/cmd/review.go index 8a4145c..73c66ba 100644 --- a/cmd/review.go +++ b/cmd/review.go @@ -1,6 +1,7 @@ package cmd import ( + "os" "strings" "github.com/appleboy/CodeGPT/core" @@ -31,6 +32,9 @@ func init() { "Replace the tip of the current branch by creating a new commit") reviewCmd.PersistentFlags().BoolVar(&promptOnly, "prompt_only", false, "Show prompt only without sending request to OpenAI") + reviewCmd.PersistentFlags().Bool("stream", false, + "enable streaming output for real-time token display") + _ = viper.BindPFlag("openai.stream", reviewCmd.PersistentFlags().Lookup("stream")) } var reviewCmd = &cobra.Command{ @@ -87,7 +91,7 @@ var reviewCmd = &cobra.Command{ // Get summarize comment from diff datas color.Cyan("We are trying to review code changes") - resp, err := client.Completion(cmd.Context(), out) + resp, err := callCompletion(cmd.Context(), client, out, os.Stdout) if err != nil { return err } @@ -109,7 +113,7 @@ var reviewCmd = &cobra.Command{ // translate a git commit message color.Cyan("we are trying to translate code review to " + prompt.GetLanguage(viper.GetString("output.lang")) + " language") - resp, err := client.Completion(cmd.Context(), out) + resp, err := callCompletion(cmd.Context(), client, out, os.Stdout) if err != nil { return err } diff --git a/core/openai.go b/core/openai.go index 158bfb8..9350d33 100644 --- a/core/openai.go +++ b/core/openai.go @@ -2,6 +2,7 @@ package core import ( "context" + "io" "strconv" "github.com/sashabaranov/go-openai" @@ -49,4 +50,8 @@ type Generative interface { // GetSummaryPrefix generates a summary prefix based on the provided content. // It takes a context and a string as input and returns a Response pointer and an error. GetSummaryPrefix(ctx context.Context, content string) (resp *Response, err error) + + // CompletionStream generates a completion and streams tokens to the writer as they arrive. + // Returns the full accumulated Response on completion. + CompletionStream(ctx context.Context, content string, w io.Writer) (resp *Response, err error) } diff --git a/provider/anthropic/anthropic.go b/provider/anthropic/anthropic.go index cfa9010..d3464de 100644 --- a/provider/anthropic/anthropic.go +++ b/provider/anthropic/anthropic.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "strings" "github.com/appleboy/CodeGPT/core" "github.com/appleboy/CodeGPT/core/transport" @@ -65,6 +67,58 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response }, nil } +// CompletionStream streams completion tokens to the writer as they arrive. +func (c *Client) CompletionStream( + ctx context.Context, + content string, + w io.Writer, +) (*core.Response, error) { + var sb strings.Builder + resp, err := c.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{ + MessagesRequest: anthropic.MessagesRequest{ + Model: c.model, + Messages: []anthropic.Message{ + anthropic.NewUserTextMessage(content), + }, + MaxTokens: c.maxTokens, + Temperature: convert.ToPtr(c.temperature), + TopP: convert.ToPtr(c.topP), + }, + OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) { + if data.Delta.Text != nil { + sb.WriteString(*data.Delta.Text) + _, _ = io.WriteString(w, *data.Delta.Text) + } + }, + }) + if err != nil { + var e *anthropic.APIError + if errors.As(err, &e) { + fmt.Printf("Messages error, type: %s, message: %s", e.Type, e.Message) + } else { + fmt.Printf("Messages error: %v\n", err) + } + return nil, err + } + + usage := core.Usage{ + PromptTokens: resp.Usage.InputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + } + + if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 { + usage.PromptTokensDetails = &openai.PromptTokensDetails{ + CachedTokens: resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens, + } + } + + return &core.Response{ + Content: sb.String(), + Usage: usage, + }, nil +} + // GetSummaryPrefix is an API call to get a summary prefix using function call. func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) { request := anthropic.MessagesRequest{ diff --git a/provider/anthropic/stream_test.go b/provider/anthropic/stream_test.go new file mode 100644 index 0000000..796162f --- /dev/null +++ b/provider/anthropic/stream_test.go @@ -0,0 +1,73 @@ +package anthropic + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/liushuangls/go-anthropic/v2" +) + +func TestCompletionStream(t *testing.T) { + // Create a mock SSE server that returns Anthropic streaming events + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + events := []string{ + `event: message_start +data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","usage":{"input_tokens":10,"output_tokens":0}}}`, + `event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`, + `event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`, + `event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}`, + `event: content_block_stop +data: {"type":"content_block_stop","index":0}`, + `event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}}`, + `event: message_stop +data: {"type":"message_stop"}`, + } + + for _, event := range events { + fmt.Fprintf(w, "%s\n\n", event) + } + })) + defer server.Close() + + client := &Client{ + client: anthropic.NewClient( + "test-token", + anthropic.WithBaseURL(server.URL), + ), + model: anthropic.ModelClaude3Haiku20240307, + maxTokens: 1024, + } + + var buf bytes.Buffer + resp, err := client.CompletionStream(context.Background(), "test prompt", &buf) + if err != nil { + t.Fatalf("CompletionStream failed: %v", err) + } + + expectedContent := "Hello world" + if resp.Content != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, resp.Content) + } + + if buf.String() != expectedContent { + t.Errorf("expected writer output %q, got %q", expectedContent, buf.String()) + } + + if resp.Usage.PromptTokens != 10 { + t.Errorf("expected prompt tokens 10, got %d", resp.Usage.PromptTokens) + } + + if resp.Usage.TotalTokens != 12 { + t.Errorf("expected total tokens 12, got %d", resp.Usage.TotalTokens) + } +} diff --git a/provider/gemini/gemini.go b/provider/gemini/gemini.go index 37c492d..600378d 100644 --- a/provider/gemini/gemini.go +++ b/provider/gemini/gemini.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "io" "net/http" + "strings" "github.com/appleboy/CodeGPT/core" "github.com/appleboy/CodeGPT/core/transport" @@ -66,6 +68,62 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response }, nil } +// CompletionStream streams completion tokens to the writer as they arrive. +func (c *Client) CompletionStream( + ctx context.Context, + content string, + w io.Writer, +) (*core.Response, error) { + cfg := &genai.GenerateContentConfig{ + TopP: convert.ToPtr(c.topP), + Temperature: convert.ToPtr(c.temperature), + MaxOutputTokens: c.maxTokens, + } + data := []*genai.Content{ + { + Role: "user", + Parts: []*genai.Part{ + { + Text: content, + }, + }, + }, + } + + var sb strings.Builder + usage := core.Usage{} + for resp, err := range c.client.Models.GenerateContentStream(ctx, c.model, data, cfg) { + if err != nil { + return nil, err + } + + if resp.UsageMetadata != nil { + usage.PromptTokens = int(resp.UsageMetadata.PromptTokenCount) + usage.CompletionTokens = int(resp.UsageMetadata.CandidatesTokenCount) + usage.TotalTokens = int(resp.UsageMetadata.TotalTokenCount) + if resp.UsageMetadata.CachedContentTokenCount > 0 { + usage.PromptTokensDetails = &openai.PromptTokensDetails{ + CachedTokens: int(resp.UsageMetadata.CachedContentTokenCount), + } + } + } + + if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil { + for _, part := range resp.Candidates[0].Content.Parts { + if part.Text != "" { + sb.WriteString(part.Text) + _, _ = io.WriteString(w, part.Text) + } + } + } + } + + return &core.Response{ + Content: sb.String(), + Usage: usage, + }, nil +} + // GetSummaryPrefix is an API call to get a summary prefix using function call. func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) { cfg := &genai.GenerateContentConfig{ diff --git a/provider/gemini/stream_test.go b/provider/gemini/stream_test.go new file mode 100644 index 0000000..d6773dd --- /dev/null +++ b/provider/gemini/stream_test.go @@ -0,0 +1,95 @@ +package gemini + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "google.golang.org/genai" +) + +func TestCompletionStreamWriterOutput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + resp := []map[string]any{ + { + "candidates": []map[string]any{ + { + "content": map[string]any{ + "parts": []map[string]any{ + {"text": "Hello"}, + }, + }, + }, + }, + }, + { + "candidates": []map[string]any{ + { + "content": map[string]any{ + "parts": []map[string]any{ + {"text": " world"}, + }, + }, + }, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 10, + "candidatesTokenCount": 2, + "totalTokenCount": 12, + }, + }, + } + data, _ := json.Marshal(resp) + _, _ = w.Write(data) + })) + defer server.Close() + + ctx := context.Background() + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: "test-key", + Backend: genai.BackendGeminiAPI, + HTTPClient: &http.Client{ + Transport: &mockTransport{server: server}, + }, + }) + if err != nil { + t.Fatalf("failed to create genai client: %v", err) + } + + c := &Client{ + client: client, + model: "gemini-2.0-flash", + maxTokens: 1024, + temperature: 0.7, + topP: 1.0, + } + + var buf bytes.Buffer + resp, err := c.CompletionStream(ctx, "test prompt", &buf) + if err != nil { + t.Skipf("Skipping streaming test due to SDK transport constraints: %v", err) + return + } + + if resp.Content == "" { + t.Error("expected non-empty content") + } + + if buf.Len() == 0 { + t.Error("expected non-empty writer output") + } +} + +type mockTransport struct { + server *httptest.Server +} + +func (tr *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.URL.Scheme = "http" + req.URL.Host = tr.server.Listener.Addr().String() + return http.DefaultTransport.RoundTrip(req) +} diff --git a/provider/openai/openai.go b/provider/openai/openai.go index c4f9fb9..b8772a2 100644 --- a/provider/openai/openai.go +++ b/provider/openai/openai.go @@ -2,7 +2,9 @@ package openai import ( "context" + "errors" "fmt" + "io" "regexp" "strings" @@ -54,6 +56,75 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response }, nil } +// CompletionStream streams completion tokens to the writer as they arrive. +func (c *Client) CompletionStream( + ctx context.Context, + content string, + w io.Writer, +) (*core.Response, error) { + req := openai.ChatCompletionRequest{ + Model: c.model, + MaxCompletionTokens: c.maxTokens, + Temperature: c.temperature, + TopP: c.topP, + FrequencyPenalty: c.frequencyPenalty, + PresencePenalty: c.presencePenalty, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleAssistant, + Content: "You are a helpful assistant.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: content, + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{IncludeUsage: true}, + } + + stream, err := c.client.CreateChatCompletionStream(ctx, req) + if err != nil { + return nil, err + } + defer stream.Close() + + var sb strings.Builder + var usage openai.Usage + for { + chunk, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, err + } + + if chunk.Usage != nil { + usage = *chunk.Usage + } + + if len(chunk.Choices) > 0 { + text := chunk.Choices[0].Delta.Content + if text != "" { + sb.WriteString(text) + _, _ = io.WriteString(w, text) + } + } + } + + return &core.Response{ + Content: sb.String(), + Usage: core.Usage{ + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + CompletionTokensDetails: usage.CompletionTokensDetails, + PromptTokensDetails: usage.PromptTokensDetails, + }, + }, nil +} + // GetSummaryPrefix is an API call to get a summary prefix using function call. func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) { var resp openai.ChatCompletionResponse diff --git a/provider/openai/stream_test.go b/provider/openai/stream_test.go new file mode 100644 index 0000000..9f57194 --- /dev/null +++ b/provider/openai/stream_test.go @@ -0,0 +1,60 @@ +package openai + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestCompletionStream(t *testing.T) { + // Create a mock SSE server that returns streaming chunks + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + chunks := []string{ + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`, + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`, + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}`, + `{"id":"1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":2,"total_tokens":12}}`, + } + + for _, chunk := range chunks { + fmt.Fprintf(w, "data: %s\n\n", chunk) + } + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer server.Close() + + client, err := New( + WithToken("test-token"), + WithModel("gpt-4o"), + WithBaseURL(server.URL+"/v1"), + ) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + var buf bytes.Buffer + resp, err := client.CompletionStream(context.Background(), "test prompt", &buf) + if err != nil { + t.Fatalf("CompletionStream failed: %v", err) + } + + expectedContent := "Hello world" + if resp.Content != expectedContent { + t.Errorf("expected content %q, got %q", expectedContent, resp.Content) + } + + if buf.String() != expectedContent { + t.Errorf("expected writer output %q, got %q", expectedContent, buf.String()) + } + + if resp.Usage.TotalTokens != 12 { + t.Errorf("expected total tokens 12, got %d", resp.Usage.TotalTokens) + } +}