From 0de9eda88a3bda28ff0c2e06e5b0f8d0c7ef690c Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Thu, 19 Mar 2026 17:13:55 -0700 Subject: [PATCH 01/10] feat(agent): agent command --- .github/workflows/releases.yml | 3 + .goreleaser.yml | 3 + Makefile | 13 +- Taskfile.yml | 5 +- go.mod | 10 +- go.sum | 21 ++++ pkg/cmd/agent/agent.go | 216 +++++++++++++++++++++++++++++++++ pkg/cmd/agent/render.go | 144 ++++++++++++++++++++++ pkg/cmd/root/root.go | 2 + 9 files changed, 410 insertions(+), 7 deletions(-) create mode 100644 pkg/cmd/agent/agent.go create mode 100644 pkg/cmd/agent/render.go diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 20d39933..f240e711 100644 --- a/.github/workflows/releases.yml +++ b/.github/workflows/releases.yml @@ -42,6 +42,9 @@ jobs: ALGOLIA_API_URL: ${{ secrets.ALGOLIA_API_URL }} ALGOLIA_OAUTH_CLIENT_ID: ${{ secrets.ALGOLIA_OAUTH_CLIENT_ID }} ALGOLIA_OAUTH_SCOPE: ${{ secrets.ALGOLIA_OAUTH_SCOPE }} + ALGOLIA_AGENT_ID: ${{ secrets.ALGOLIA_AGENT_ID }} + ALGOLIA_AGENT_APP_ID: ${{ secrets.ALGOLIA_AGENT_APP_ID }} + ALGOLIA_AGENT_API_KEY: ${{ secrets.ALGOLIA_AGENT_API_KEY }} - name: Docs checkout uses: actions/checkout@v4 with: diff --git a/.goreleaser.yml b/.goreleaser.yml index ff8debf0..10457c0c 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -22,6 +22,9 @@ builds: -X github.com/algolia/cli/api/dashboard.DefaultAPIURL={{ .Env.ALGOLIA_API_URL }} -X github.com/algolia/cli/pkg/auth.DefaultOAuthClientID={{ .Env.ALGOLIA_OAUTH_CLIENT_ID }} -X github.com/algolia/cli/api/dashboard.DefaultOAuthScope={{ .Env.ALGOLIA_OAUTH_SCOPE }} + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentID={{ .Env.ALGOLIA_AGENT_ID }} + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAppID={{ .Env.ALGOLIA_AGENT_APP_ID }} + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAPIKey={{ .Env.ALGOLIA_AGENT_API_KEY }} id: macos goos: [darwin] goarch: [amd64, arm64] diff --git a/Makefile b/Makefile index ea17875c..9e891638 100644 --- a/Makefile +++ b/Makefile @@ -63,11 +63,14 @@ build: go generate ./... go build -ldflags "\ -s -w \ - -X=github.com/algolia/cli/pkg/version.Version=$(VERSION) \ - -X=github.com/algolia/cli/api/dashboard.DefaultDashboardURL=$(ALGOLIA_DASHBOARD_URL) \ - -X=github.com/algolia/cli/api/dashboard.DefaultAPIURL=$(ALGOLIA_API_URL) \ - -X=github.com/algolia/cli/pkg/auth.DefaultOAuthClientID=$(ALGOLIA_OAUTH_CLIENT_ID) \ - -X 'github.com/algolia/cli/api/dashboard.DefaultOAuthScope=$(ALGOLIA_OAUTH_SCOPE)'" \ + -X github.com/algolia/cli/pkg/version.Version=$(VERSION) \ + -X github.com/algolia/cli/api/dashboard.DefaultDashboardURL=$(ALGOLIA_DASHBOARD_URL) \ + -X github.com/algolia/cli/api/dashboard.DefaultAPIURL=$(ALGOLIA_API_URL) \ + -X github.com/algolia/cli/pkg/auth.DefaultOAuthClientID=$(ALGOLIA_OAUTH_CLIENT_ID) \ + -X 'github.com/algolia/cli/api/dashboard.DefaultOAuthScope=$(ALGOLIA_OAUTH_SCOPE)' \ + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentID=$(ALGOLIA_AGENT_ID) \ + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAppID=$(ALGOLIA_AGENT_APP_ID) \ + -X 'github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAPIKey=$(ALGOLIA_AGENT_API_KEY)'" \ -o algolia cmd/algolia/main.go .PHONY: build diff --git a/Taskfile.yml b/Taskfile.yml index ef474d86..a767bfec 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -19,7 +19,10 @@ tasks: -X github.com/algolia/cli/api/dashboard.DefaultDashboardURL=$ALGOLIA_DASHBOARD_URL -X github.com/algolia/cli/api/dashboard.DefaultAPIURL=$ALGOLIA_API_URL -X github.com/algolia/cli/pkg/auth.DefaultOAuthClientID=$ALGOLIA_OAUTH_CLIENT_ID - -X 'github.com/algolia/cli/api/dashboard.DefaultOAuthScope=$ALGOLIA_OAUTH_SCOPE'" + -X 'github.com/algolia/cli/api/dashboard.DefaultOAuthScope=$ALGOLIA_OAUTH_SCOPE' + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentID=$ALGOLIA_AGENT_ID + -X github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAppID=$ALGOLIA_AGENT_APP_ID + -X 'github.com/algolia/cli/pkg/cmd/agent.DefaultAgentAPIKey=$ALGOLIA_AGENT_API_KEY'" -o algolia cmd/algolia/main.go vars: VERSION: '{{ .VERSION | default "main" }}' diff --git a/go.mod b/go.mod index 1a67f678..a63a3fef 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,10 @@ require ( require ( al.essio.dev/pkg/shellescape v1.5.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/chzyer/readline v1.5.1 // indirect + github.com/clipperhouse/displaywidth v0.10.0 // indirect + github.com/clipperhouse/uax29/v2 v2.6.0 // indirect github.com/danieljoos/wincred v1.2.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fatih/color v1.18.0 // indirect @@ -57,8 +61,12 @@ require ( github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/magiconair/properties v1.8.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect + github.com/olekukonko/errors v1.2.0 // indirect + github.com/olekukonko/ll v0.1.6 // indirect + github.com/olekukonko/tablewriter v1.1.4 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect diff --git a/go.sum b/go.sum index c747ca61..3e86c384 100644 --- a/go.sum +++ b/go.sum @@ -16,10 +16,20 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4Yn github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/briandowns/spinner v1.23.2 h1:Zc6ecUnI+YzLmJniCfDNaMbW0Wid1d5+qcTq4L2FW8w= github.com/briandowns/spinner v1.23.2/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 h1:QDrhR4JA2n3ij9YQN0u5ZeuvRIIvsUGmf5yPlTS0w8E= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24/go.mod h1:rr9GNING0onuVw8MnracQHn7PcchnFlP882Y0II2KZk= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= +github.com/clipperhouse/displaywidth v0.10.0 h1:GhBG8WuerxjFQQYeuZAeVTuyxuX+UraiZGD4HJQ3Y8g= +github.com/clipperhouse/displaywidth v0.10.0/go.mod h1:XqJajYsaiEwkxOj4bowCTMcT1SgvHo9flfF3jQasdbs= +github.com/clipperhouse/uax29/v2 v2.6.0 h1:z0cDbUV+aPASdFb2/ndFnS9ts/WNXgTNNGFoKXuhpos= +github.com/clipperhouse/uax29/v2 v2.6.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= @@ -93,6 +103,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= @@ -104,6 +116,14 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= +github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo= +github.com/olekukonko/errors v1.2.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= +github.com/olekukonko/ll v0.1.6 h1:lGVTHO+Qc4Qm+fce/2h2m5y9LvqaW+DCN7xW9hsU3uA= +github.com/olekukonko/ll v0.1.6/go.mod h1:NVUmjBb/aCtUpjKk75BhWrOlARz3dqsM+OtszpY4o88= +github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= +github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -171,6 +191,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go new file mode 100644 index 00000000..564ae8fb --- /dev/null +++ b/pkg/cmd/agent/agent.go @@ -0,0 +1,216 @@ +package agent + +import ( + "bufio" + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + + "github.com/MakeNowJust/heredoc" + "github.com/spf13/cobra" + + "github.com/algolia/cli/pkg/auth" + "github.com/algolia/cli/pkg/cmdutil" + "github.com/algolia/cli/pkg/iostreams" +) + +// Build-time variables injected via ldflags in .goreleaser.yml. +var ( + DefaultAgentID string + DefaultAgentAppID string + DefaultAgentAPIKey string +) + +// AgentOptions holds the configuration for the agent command. +type AgentOptions struct { + IO *iostreams.IOStreams + + AgentID string + AppID string + APIKey string +} + +// message represents a single message in the conversation. +type message struct { + ID string `json:"id,omitempty"` + Role string `json:"role"` + Parts []part `json:"parts"` +} + +// part represents a content part within a message. +type part struct { + Type string `json:"type,omitempty"` + Text string `json:"text"` +} + +// completionRequest is the request body sent to Agent Studio. +type completionRequest struct { + Messages []message `json:"messages"` +} + +// sseEvent represents a parsed SSE data payload. +type sseEvent struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + MessageID string `json:"messageId,omitempty"` + Delta string `json:"delta,omitempty"` +} + +func NewAgentCmd(f *cmdutil.Factory) *cobra.Command { + opts := &AgentOptions{ + IO: f.IOStreams, + AgentID: envOrDefault("ALGOLIA_AGENT_ID", DefaultAgentID), + AppID: envOrDefault("ALGOLIA_AGENT_APP_ID", DefaultAgentAppID), + APIKey: envOrDefault("ALGOLIA_AGENT_API_KEY", DefaultAgentAPIKey), + } + + cmd := &cobra.Command{ + Use: "agent", + Short: "Chat with an AI agent that suggests Algolia CLI commands", + Long: "Interactive chat with an AI agent that advises CLI commands for your use case. The agent only prints suggestions — it does not execute commands.", + Example: heredoc.Doc(` + $ algolia agent + `), + RunE: func(cmd *cobra.Command, args []string) error { + return runAgent(opts) + }, + } + + auth.DisableAuthCheck(cmd) + + return cmd +} + +func runAgent(opts *AgentOptions) error { + if opts.AgentID == "" || opts.AppID == "" || opts.APIKey == "" { + return fmt.Errorf("agent credentials are not configured") + } + + out := opts.IO.Out + scanner := bufio.NewScanner(os.Stdin) + + fmt.Fprintln(out, "Algolia CLI Agent (type \"exit\" to quit)") + fmt.Fprintln(out) + + var history []message + msgCounter := 0 + + for { + fmt.Fprint(out, "> ") + + if !scanner.Scan() { + break + } + input := strings.TrimSpace(scanner.Text()) + if input == "" { + continue + } + if input == "exit" { + break + } + + msgCounter++ + userMsg := message{ + ID: fmt.Sprintf("alg_msg_%d", msgCounter), + Role: "user", + Parts: []part{ + {Text: input}, + }, + } + history = append(history, userMsg) + + assistantText, assistantID, err := sendCompletion(opts, history) + if err != nil { + fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", err) + // Remove the failed user message from history. + history = history[:len(history)-1] + continue + } + + fmt.Fprintln(out) + fmt.Fprintln(out, renderMarkdown(opts.IO.ColorScheme(), assistantText)) + fmt.Fprintln(out) + + history = append(history, message{ + ID: assistantID, + Role: "assistant", + Parts: []part{ + {Type: "text", Text: assistantText}, + }, + }) + } + + return nil +} + +// sendCompletion sends the conversation to Agent Studio and streams the response. +// Returns the assistant text and the server-generated message ID. +func sendCompletion(opts *AgentOptions, messages []message) (string, string, error) { + url := fmt.Sprintf( + "https://%s.algolia.net/agent-studio/1/agents/%s/completions?stream=true&compatibilityMode=ai-sdk-5", + opts.AppID, opts.AgentID, + ) + + reqBody := completionRequest{ + Messages: messages, + } + body, err := json.Marshal(reqBody) + if err != nil { + return "", "", fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, strings.NewReader(string(body))) + if err != nil { + return "", "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-algolia-application-id", opts.AppID) + req.Header.Set("X-Algolia-API-Key", opts.APIKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("unexpected status: %s", resp.Status) + } + + // Parse SSE stream and collect text deltas. + var result strings.Builder + var messageID string + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var event sseEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + switch event.Type { + case "start": + messageID = event.MessageID + case "text-delta": + result.WriteString(event.Delta) + } + } + + return result.String(), messageID, nil +} + +func envOrDefault(key, defaultVal string) string { + if v := os.Getenv(key); v != "" { + return v + } + return defaultVal +} diff --git a/pkg/cmd/agent/render.go b/pkg/cmd/agent/render.go new file mode 100644 index 00000000..a92e8f90 --- /dev/null +++ b/pkg/cmd/agent/render.go @@ -0,0 +1,144 @@ +package agent + +import ( + "regexp" + "strings" + + "github.com/olekukonko/tablewriter" + + "github.com/algolia/cli/pkg/iostreams" +) + +const algoliaBlue = "3369e7" + +// renderMarkdown converts a markdown string into ANSI-styled terminal output. +// It handles: headers, bold, inline code, fenced code blocks, and tables. +func renderMarkdown(cs *iostreams.ColorScheme, text string) string { + lines := strings.Split(text, "\n") + var out []string + var tableBuffer []string + inCodeBlock := false + + for _, line := range lines { + if strings.HasPrefix(line, "```") { + inCodeBlock = !inCodeBlock + continue + } + + if inCodeBlock { + out = append(out, cs.HexToRGB(algoliaBlue, line)) + continue + } + + if strings.HasPrefix(strings.TrimSpace(line), "|") { + tableBuffer = append(tableBuffer, line) + continue + } + + if len(tableBuffer) > 0 { + out = append(out, renderTable(cs, tableBuffer)) + tableBuffer = nil + } + + if strings.HasPrefix(line, "#") { + stripped := strings.TrimLeft(line, "# ") + out = append(out, cs.Bold(stripped)) + continue + } + + line = renderInline(cs, line) + out = append(out, line) + } + + if len(tableBuffer) > 0 { + out = append(out, renderTable(cs, tableBuffer)) + } + + return strings.Join(out, "\n") +} + +// Bold regex: **text** +var boldRe = regexp.MustCompile(`\*\*(.+?)\*\*`) + +// Inline code regex: `text` +var codeRe = regexp.MustCompile("`([^`]+)`") + +// renderInline applies bold and inline code styling to a single line. +func renderInline(cs *iostreams.ColorScheme, line string) string { + line = boldRe.ReplaceAllStringFunc(line, func(match string) string { + inner := boldRe.FindStringSubmatch(match)[1] + return cs.Bold(inner) + }) + line = codeRe.ReplaceAllStringFunc(line, func(match string) string { + inner := codeRe.FindStringSubmatch(match)[1] + return cs.HexToRGB(algoliaBlue, inner) + }) + return line +} + +// isSeparatorRow checks if a table row is a markdown separator (e.g. |---|---|). +func isSeparatorRow(line string) bool { + for _, cell := range parseTableRow(line) { + stripped := strings.Trim(cell, " :-") + if stripped != "" { + return false + } + } + return true +} + +// parseTableRow splits a markdown table row into cell values. +func parseTableRow(line string) []string { + line = strings.TrimSpace(line) + line = strings.Trim(line, "|") + parts := strings.Split(line, "|") + for i := range parts { + parts[i] = strings.TrimSpace(parts[i]) + } + return parts +} + +// renderTable renders markdown table lines using tablewriter. +func renderTable(cs *iostreams.ColorScheme, lines []string) string { + var header []string + var dataRows [][]string + foundSep := false + + for _, line := range lines { + if isSeparatorRow(line) { + foundSep = true + continue + } + cells := parseTableRow(line) + if !foundSep && header == nil { + header = cells + } else { + dataRows = append(dataRows, cells) + } + } + + var buf strings.Builder + table := tablewriter.NewWriter(&buf) + + // Tablewriter auto-formats headers (uppercase + bold). + if header != nil { + plain := make([]any, len(header)) + for i, h := range header { + plain[i] = h + } + table.Header(plain...) + } + + // Apply inline styling to data cells and append rows. + for _, row := range dataRows { + styled := make([]any, len(row)) + for i, cell := range row { + styled[i] = renderInline(cs, cell) + } + table.Append(styled...) + } + + table.Render() + + return strings.TrimRight(buf.String(), "\n") +} diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index b5c33fa7..99e0c95b 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -25,6 +25,7 @@ import ( "github.com/algolia/cli/pkg/auth" "github.com/algolia/cli/pkg/cmd/apikeys" authcmd "github.com/algolia/cli/pkg/cmd/auth" + "github.com/algolia/cli/pkg/cmd/agent" "github.com/algolia/cli/pkg/cmd/crawler" "github.com/algolia/cli/pkg/cmd/describe" "github.com/algolia/cli/pkg/cmd/dictionary" @@ -103,6 +104,7 @@ func NewRootCmd(f *cmdutil.Factory) *cobra.Command { // Convenience commands cmd.AddCommand(open.NewOpenCmd(f)) + cmd.AddCommand(agent.NewAgentCmd(f)) // API related commands cmd.AddCommand(search.NewSearchCmd(f)) From 6d90da56b77e97bb427d1d2d4056251ae590c480 Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Thu, 19 Mar 2026 17:23:17 -0700 Subject: [PATCH 02/10] feat(agent): add loading animation --- pkg/cmd/agent/agent.go | 12 +++- pkg/cmd/agent/agent_test.go | 119 ++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 pkg/cmd/agent/agent_test.go diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index 564ae8fb..645a1145 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "fmt" + "io" "net/http" "os" "strings" @@ -121,7 +122,9 @@ func runAgent(opts *AgentOptions) error { } history = append(history, userMsg) + opts.IO.StartProgressIndicator() assistantText, assistantID, err := sendCompletion(opts, history) + opts.IO.StopProgressIndicator() if err != nil { fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", err) // Remove the failed user message from history. @@ -179,10 +182,15 @@ func sendCompletion(opts *AgentOptions, messages []message) (string, string, err return "", "", fmt.Errorf("unexpected status: %s", resp.Status) } - // Parse SSE stream and collect text deltas. + return parseSSEStream(resp.Body) +} + +// parseSSEStream reads an SSE stream and collects text deltas. +// Returns the assembled text and the server-generated message ID. +func parseSSEStream(r io.Reader) (string, string, error) { var result strings.Builder var messageID string - scanner := bufio.NewScanner(resp.Body) + scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { diff --git a/pkg/cmd/agent/agent_test.go b/pkg/cmd/agent/agent_test.go new file mode 100644 index 00000000..a88497f5 --- /dev/null +++ b/pkg/cmd/agent/agent_test.go @@ -0,0 +1,119 @@ +package agent + +import ( + "os" + "strings" + "testing" +) + +func TestParseSSEStream(t *testing.T) { + tests := []struct { + name string + input string + wantText string + wantMsgID string + }{ + { + name: "parses a complete stream with start and text-delta events", + input: "data: {\"type\":\"start\",\"messageId\":\"msg_123\"}\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"Hello \"}\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"world\"}\n" + + "data: [DONE]\n", + wantText: "Hello world", + wantMsgID: "msg_123", + }, + { + name: "returns empty on empty stream", + input: "", + wantText: "", + wantMsgID: "", + }, + { + name: "ignores non-data lines", + input: "event: message\n" + + "id: 1\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"hi\"}\n" + + "data: [DONE]\n", + wantText: "hi", + wantMsgID: "", + }, + { + name: "skips malformed JSON", + input: "data: not-json\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"ok\"}\n" + + "data: [DONE]\n", + wantText: "ok", + wantMsgID: "", + }, + { + name: "stops at DONE marker", + input: "data: {\"type\":\"text-delta\",\"delta\":\"before\"}\n" + + "data: [DONE]\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"after\"}\n", + wantText: "before", + wantMsgID: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := strings.NewReader(tt.input) + gotText, gotMsgID, err := parseSSEStream(r) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotText != tt.wantText { + t.Errorf("text = %q, want %q", gotText, tt.wantText) + } + if gotMsgID != tt.wantMsgID { + t.Errorf("messageID = %q, want %q", gotMsgID, tt.wantMsgID) + } + }) + } +} + +func TestEnvOrDefault(t *testing.T) { + tests := []struct { + name string + key string + envValue string + defaultVal string + want string + }{ + { + name: "returns default when env is not set", + key: "TEST_ENV_OR_DEFAULT_UNSET", + defaultVal: "fallback", + want: "fallback", + }, + { + name: "returns env value when set", + key: "TEST_ENV_OR_DEFAULT_SET", + envValue: "from-env", + defaultVal: "fallback", + want: "from-env", + }, + { + name: "returns default when env is empty string", + key: "TEST_ENV_OR_DEFAULT_EMPTY", + envValue: "", + defaultVal: "fallback", + want: "fallback", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Unsetenv(tt.key) + if tt.envValue != "" { + os.Setenv(tt.key, tt.envValue) + defer os.Unsetenv(tt.key) + } + + got := envOrDefault(tt.key, tt.defaultVal) + if got != tt.want { + t.Errorf("envOrDefault(%q, %q) = %q, want %q", tt.key, tt.defaultVal, got, tt.want) + } + }) + } +} From 4bd2b218224b2e12d7d704ec41c800da820d17fe Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Thu, 19 Mar 2026 17:35:43 -0700 Subject: [PATCH 03/10] chore: handle italic mardown --- go.mod | 5 ++--- go.sum | 11 ++--------- pkg/cmd/agent/render.go | 22 +++++++++++++++++++--- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index a63a3fef..c0b7f15e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/mitchellh/go-homedir v1.1.0 github.com/muesli/reflow v0.3.0 github.com/muesli/termenv v0.16.0 + github.com/olekukonko/tablewriter v1.1.4 github.com/segmentio/analytics-go/v3 v3.3.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 @@ -40,7 +41,6 @@ require ( al.essio.dev/pkg/shellescape v1.5.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/chzyer/readline v1.5.1 // indirect github.com/clipperhouse/displaywidth v0.10.0 // indirect github.com/clipperhouse/uax29/v2 v2.6.0 // indirect github.com/danieljoos/wincred v1.2.2 // indirect @@ -58,7 +58,7 @@ require ( github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/magiconair/properties v1.8.9 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect @@ -66,7 +66,6 @@ require ( github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.2.0 // indirect github.com/olekukonko/ll v0.1.6 // indirect - github.com/olekukonko/tablewriter v1.1.4 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect diff --git a/go.sum b/go.sum index 3e86c384..6cc2b50f 100644 --- a/go.sum +++ b/go.sum @@ -18,10 +18,6 @@ github.com/briandowns/spinner v1.23.2 h1:Zc6ecUnI+YzLmJniCfDNaMbW0Wid1d5+qcTq4L2 github.com/briandowns/spinner v1.23.2/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= -github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= -github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 h1:QDrhR4JA2n3ij9YQN0u5ZeuvRIIvsUGmf5yPlTS0w8E= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24/go.mod h1:rr9GNING0onuVw8MnracQHn7PcchnFlP882Y0II2KZk= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= @@ -86,8 +82,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM= github.com/magiconair/properties v1.8.9/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= @@ -101,8 +97,6 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= @@ -191,7 +185,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/cmd/agent/render.go b/pkg/cmd/agent/render.go index a92e8f90..d8d6593a 100644 --- a/pkg/cmd/agent/render.go +++ b/pkg/cmd/agent/render.go @@ -9,10 +9,11 @@ import ( "github.com/algolia/cli/pkg/iostreams" ) -const algoliaBlue = "3369e7" +// algoliaBlue is the Algolia Blue brand color. +const algoliaBlue = "3970ff" // renderMarkdown converts a markdown string into ANSI-styled terminal output. -// It handles: headers, bold, inline code, fenced code blocks, and tables. +// It handles: headers, bold, italic, inline code, fenced code blocks, and tables. func renderMarkdown(cs *iostreams.ColorScheme, text string) string { lines := strings.Split(text, "\n") var out []string @@ -60,10 +61,13 @@ func renderMarkdown(cs *iostreams.ColorScheme, text string) string { // Bold regex: **text** var boldRe = regexp.MustCompile(`\*\*(.+?)\*\*`) +// Italic regex: *text* (but not **text**) +var italicRe = regexp.MustCompile(`(?:^|[^*])\*([^*]+?)\*(?:[^*]|$)`) + // Inline code regex: `text` var codeRe = regexp.MustCompile("`([^`]+)`") -// renderInline applies bold and inline code styling to a single line. +// renderInline applies bold, italic, and inline code styling to a single line. func renderInline(cs *iostreams.ColorScheme, line string) string { line = boldRe.ReplaceAllStringFunc(line, func(match string) string { inner := boldRe.FindStringSubmatch(match)[1] @@ -73,6 +77,18 @@ func renderInline(cs *iostreams.ColorScheme, line string) string { inner := codeRe.FindStringSubmatch(match)[1] return cs.HexToRGB(algoliaBlue, inner) }) + line = italicRe.ReplaceAllStringFunc(line, func(match string) string { + inner := italicRe.FindStringSubmatch(match)[1] + prefix := "" + suffix := "" + if len(match) > 0 && match[0] != '*' { + prefix = string(match[0]) + } + if len(match) > 0 && match[len(match)-1] != '*' { + suffix = string(match[len(match)-1]) + } + return prefix + cs.Gray(inner) + suffix + }) return line } From d476344bf218892ac912c1f619f52fa9ff5c124e Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Thu, 19 Mar 2026 20:28:23 -0700 Subject: [PATCH 04/10] chore: more colors --- go.mod | 1 + go.sum | 2 + pkg/cmd/agent/agent.go | 7 +- pkg/cmd/agent/render.go | 260 +++++++++++++++++++++++++--------------- pkg/cmd/root/root.go | 2 +- 5 files changed, 173 insertions(+), 99 deletions(-) diff --git a/go.mod b/go.mod index c0b7f15e..7f362b4c 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( github.com/spf13/afero v1.12.0 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/yuin/goldmark v1.7.17 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/sys v0.31.0 // indirect diff --git a/go.sum b/go.sum index 6cc2b50f..d9794ea2 100644 --- a/go.sum +++ b/go.sum @@ -167,6 +167,8 @@ github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSW github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c h1:3lbZUMbMiGUW/LMkfsEABsc5zNT9+b1CvsJx47JzJ8g= github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c/go.mod h1:UrdRz5enIKZ63MEE3IF9l2/ebyx59GyGgPi+tICQdmM= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.7.17 h1:p36OVWwRb246iHxA/U4p8OPEpOTESm4n+g+8t0EE5uA= +github.com/yuin/goldmark v1.7.17/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index 645a1145..0524b45a 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -92,8 +92,11 @@ func runAgent(opts *AgentOptions) error { out := opts.IO.Out scanner := bufio.NewScanner(os.Stdin) + cs := opts.IO.ColorScheme() + separator := cs.Gray(strings.Repeat("─", opts.IO.TerminalWidth())) + fmt.Fprintln(out, "Algolia CLI Agent (type \"exit\" to quit)") - fmt.Fprintln(out) + fmt.Fprintln(out, separator) var history []message msgCounter := 0 @@ -132,9 +135,11 @@ func runAgent(opts *AgentOptions) error { continue } + fmt.Fprintln(out, separator) fmt.Fprintln(out) fmt.Fprintln(out, renderMarkdown(opts.IO.ColorScheme(), assistantText)) fmt.Fprintln(out) + fmt.Fprintln(out, separator) history = append(history, message{ ID: assistantID, diff --git a/pkg/cmd/agent/render.go b/pkg/cmd/agent/render.go index d8d6593a..5000ccff 100644 --- a/pkg/cmd/agent/render.go +++ b/pkg/cmd/agent/render.go @@ -5,6 +5,11 @@ import ( "strings" "github.com/olekukonko/tablewriter" + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/extension" + east "github.com/yuin/goldmark/extension/ast" + "github.com/yuin/goldmark/text" "github.com/algolia/cli/pkg/iostreams" ) @@ -12,121 +17,184 @@ import ( // algoliaBlue is the Algolia Blue brand color. const algoliaBlue = "3970ff" -// renderMarkdown converts a markdown string into ANSI-styled terminal output. -// It handles: headers, bold, italic, inline code, fenced code blocks, and tables. -func renderMarkdown(cs *iostreams.ColorScheme, text string) string { - lines := strings.Split(text, "\n") - var out []string - var tableBuffer []string - inCodeBlock := false - - for _, line := range lines { - if strings.HasPrefix(line, "```") { - inCodeBlock = !inCodeBlock - continue - } +// algoliaLightBlue is a lighter blue for placeholders. +const algoliaLightBlue = "00aeff" - if inCodeBlock { - out = append(out, cs.HexToRGB(algoliaBlue, line)) - continue - } +// algoliaTeal is the Java teal for flags. +const algoliaTeal = "1cc7d0" + +// renderMarkdown converts a markdown string into ANSI-styled terminal output +// by parsing it with goldmark and walking the AST. +func renderMarkdown(cs *iostreams.ColorScheme, input string) string { + source := []byte(input) + + md := goldmark.New( + goldmark.WithExtensions(extension.Table), + ) + reader := text.NewReader(source) + doc := md.Parser().Parse(reader) + + var out strings.Builder + renderNode(&out, cs, doc, source) + return strings.TrimRight(out.String(), "\n") +} + +// renderNode recursively walks the AST and writes ANSI-styled text to out. +func renderNode(out *strings.Builder, cs *iostreams.ColorScheme, n ast.Node, source []byte) { + switch n.Kind() { - if strings.HasPrefix(strings.TrimSpace(line), "|") { - tableBuffer = append(tableBuffer, line) - continue + case ast.KindDocument: + renderChildren(out, cs, n, source) + + case ast.KindHeading: + var headingText strings.Builder + renderChildrenTo(&headingText, cs, n, source) + out.WriteString(cs.Bold(headingText.String())) + out.WriteString("\n\n") + + case ast.KindParagraph: + renderChildren(out, cs, n, source) + out.WriteString("\n\n") + + case ast.KindTextBlock: + renderChildren(out, cs, n, source) + out.WriteString("\n") + + case ast.KindText: + t := n.(*ast.Text) + out.Write(t.Segment.Value(source)) + if t.SoftLineBreak() { + out.WriteString("\n") + } + if t.HardLineBreak() { + out.WriteString("\n") } - if len(tableBuffer) > 0 { - out = append(out, renderTable(cs, tableBuffer)) - tableBuffer = nil + case ast.KindEmphasis: + e := n.(*ast.Emphasis) + var content strings.Builder + renderChildrenTo(&content, cs, n, source) + if e.Level == 2 { + out.WriteString(cs.Bold(content.String())) + } else { + out.WriteString(cs.Gray(content.String())) } - if strings.HasPrefix(line, "#") { - stripped := strings.TrimLeft(line, "# ") - out = append(out, cs.Bold(stripped)) - continue + case ast.KindCodeSpan: + var code strings.Builder + for child := n.FirstChild(); child != nil; child = child.NextSibling() { + if child.Kind() == ast.KindText { + t := child.(*ast.Text) + code.Write(t.Segment.Value(source)) + } + } + out.WriteString(colorCodeSpan(cs, code.String())) + + case ast.KindFencedCodeBlock, ast.KindCodeBlock: + lines := n.Lines() + for i := 0; i < lines.Len(); i++ { + seg := lines.At(i) + line := strings.TrimRight(string(seg.Value(source)), "\n") + if idx := strings.Index(line, "#"); idx >= 0 { + cmd := line[:idx] + comment := line[idx:] + out.WriteString(cs.HexToRGB(algoliaBlue, cmd)) + out.WriteString(cs.Green(comment)) + } else { + out.WriteString(cs.HexToRGB(algoliaBlue, line)) + } + out.WriteString("\n") } + out.WriteString("\n") + + case ast.KindList: + renderChildren(out, cs, n, source) + out.WriteString("\n") + + case ast.KindListItem: + out.WriteString("- ") + // Render list item children inline (skip the paragraph newlines). + for child := n.FirstChild(); child != nil; child = child.NextSibling() { + if child.Kind() == ast.KindParagraph { + renderChildren(out, cs, child, source) + } else { + renderNode(out, cs, child, source) + } + } + out.WriteString("\n") - line = renderInline(cs, line) - out = append(out, line) - } + case ast.KindThematicBreak: + out.WriteString("───\n\n") - if len(tableBuffer) > 0 { - out = append(out, renderTable(cs, tableBuffer)) - } + case east.KindTable: + renderTable(out, cs, n, source) + + case ast.KindString: + s := n.(*ast.String) + out.Write(s.Value) - return strings.Join(out, "\n") + default: + // For any unhandled node, just render its children. + renderChildren(out, cs, n, source) + } } -// Bold regex: **text** -var boldRe = regexp.MustCompile(`\*\*(.+?)\*\*`) - -// Italic regex: *text* (but not **text**) -var italicRe = regexp.MustCompile(`(?:^|[^*])\*([^*]+?)\*(?:[^*]|$)`) - -// Inline code regex: `text` -var codeRe = regexp.MustCompile("`([^`]+)`") - -// renderInline applies bold, italic, and inline code styling to a single line. -func renderInline(cs *iostreams.ColorScheme, line string) string { - line = boldRe.ReplaceAllStringFunc(line, func(match string) string { - inner := boldRe.FindStringSubmatch(match)[1] - return cs.Bold(inner) - }) - line = codeRe.ReplaceAllStringFunc(line, func(match string) string { - inner := codeRe.FindStringSubmatch(match)[1] - return cs.HexToRGB(algoliaBlue, inner) - }) - line = italicRe.ReplaceAllStringFunc(line, func(match string) string { - inner := italicRe.FindStringSubmatch(match)[1] - prefix := "" - suffix := "" - if len(match) > 0 && match[0] != '*' { - prefix = string(match[0]) - } - if len(match) > 0 && match[len(match)-1] != '*' { - suffix = string(match[len(match)-1]) - } - return prefix + cs.Gray(inner) + suffix - }) - return line +// renderChildren renders all children of a node to the shared output. +func renderChildren(out *strings.Builder, cs *iostreams.ColorScheme, n ast.Node, source []byte) { + for child := n.FirstChild(); child != nil; child = child.NextSibling() { + renderNode(out, cs, child, source) + } } -// isSeparatorRow checks if a table row is a markdown separator (e.g. |---|---|). -func isSeparatorRow(line string) bool { - for _, cell := range parseTableRow(line) { - stripped := strings.Trim(cell, " :-") - if stripped != "" { - return false +// codeTokenRe matches placeholders (<...>) and flags (--word). +var codeTokenRe = regexp.MustCompile(`<[^>]+>|\[[^\]]+\]|--\w[\w-]*|-\w\b`) + +// colorCodeSpan colors inline code with different colors per token type: +// - placeholders () in light blue +// - flags (--query) in cyan +// - everything else in Algolia Blue +func colorCodeSpan(cs *iostreams.ColorScheme, code string) string { + var out strings.Builder + last := 0 + for _, match := range codeTokenRe.FindAllStringIndex(code, -1) { + if match[0] > last { + out.WriteString(cs.HexToRGB(algoliaBlue, code[last:match[0]])) + } + token := code[match[0]:match[1]] + if strings.HasPrefix(token, "<") || strings.HasPrefix(token, "[") { + out.WriteString(cs.HexToRGB(algoliaLightBlue, token)) + } else { + out.WriteString(cs.HexToRGB(algoliaTeal, token)) } + last = match[1] } - return true + if last < len(code) { + out.WriteString(cs.HexToRGB(algoliaBlue, code[last:])) + } + return out.String() } -// parseTableRow splits a markdown table row into cell values. -func parseTableRow(line string) []string { - line = strings.TrimSpace(line) - line = strings.Trim(line, "|") - parts := strings.Split(line, "|") - for i := range parts { - parts[i] = strings.TrimSpace(parts[i]) +// renderChildrenTo renders all children into a separate builder (for wrapping in styles). +func renderChildrenTo(buf *strings.Builder, cs *iostreams.ColorScheme, n ast.Node, source []byte) { + for child := n.FirstChild(); child != nil; child = child.NextSibling() { + renderNode(buf, cs, child, source) } - return parts } -// renderTable renders markdown table lines using tablewriter. -func renderTable(cs *iostreams.ColorScheme, lines []string) string { +// renderTable collects rows from a goldmark Table node and renders via tablewriter. +func renderTable(out *strings.Builder, cs *iostreams.ColorScheme, n ast.Node, source []byte) { var header []string var dataRows [][]string - foundSep := false - for _, line := range lines { - if isSeparatorRow(line) { - foundSep = true - continue + for child := n.FirstChild(); child != nil; child = child.NextSibling() { + var cells []string + for cell := child.FirstChild(); cell != nil; cell = cell.NextSibling() { + var cellBuf strings.Builder + renderChildrenTo(&cellBuf, cs, cell, source) + cells = append(cells, cellBuf.String()) } - cells := parseTableRow(line) - if !foundSep && header == nil { + + if child.Kind() == east.KindTableHeader { header = cells } else { dataRows = append(dataRows, cells) @@ -136,7 +204,6 @@ func renderTable(cs *iostreams.ColorScheme, lines []string) string { var buf strings.Builder table := tablewriter.NewWriter(&buf) - // Tablewriter auto-formats headers (uppercase + bold). if header != nil { plain := make([]any, len(header)) for i, h := range header { @@ -145,16 +212,15 @@ func renderTable(cs *iostreams.ColorScheme, lines []string) string { table.Header(plain...) } - // Apply inline styling to data cells and append rows. for _, row := range dataRows { styled := make([]any, len(row)) for i, cell := range row { - styled[i] = renderInline(cs, cell) + styled[i] = cell } - table.Append(styled...) + _ = table.Append(styled...) } - table.Render() - - return strings.TrimRight(buf.String(), "\n") + _ = table.Render() + out.WriteString(strings.TrimRight(buf.String(), "\n")) + out.WriteString("\n\n") } diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index 49246e3c..c263910c 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -21,9 +21,9 @@ import ( "github.com/algolia/cli/internal/update" "github.com/algolia/cli/pkg/auth" + "github.com/algolia/cli/pkg/cmd/agent" "github.com/algolia/cli/pkg/cmd/apikeys" authcmd "github.com/algolia/cli/pkg/cmd/auth" - "github.com/algolia/cli/pkg/cmd/agent" "github.com/algolia/cli/pkg/cmd/crawler" "github.com/algolia/cli/pkg/cmd/describe" "github.com/algolia/cli/pkg/cmd/dictionary" From 16e9427a25c418bc23437b69c048e63e5c729689 Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Fri, 20 Mar 2026 09:20:04 -0700 Subject: [PATCH 05/10] chore: display suggested command to run from suggestCommand tool --- pkg/cmd/agent/agent.go | 70 +++++++++++++++++++++++++------------ pkg/cmd/agent/agent_test.go | 32 ++++++++++++----- 2 files changed, 71 insertions(+), 31 deletions(-) diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index 0524b45a..51a13ff6 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -53,10 +53,17 @@ type completionRequest struct { // sseEvent represents a parsed SSE data payload. type sseEvent struct { - Type string `json:"type"` - ID string `json:"id,omitempty"` - MessageID string `json:"messageId,omitempty"` - Delta string `json:"delta,omitempty"` + Type string `json:"type"` + ID string `json:"id,omitempty"` + MessageID string `json:"messageId,omitempty"` + Delta string `json:"delta,omitempty"` + ToolName string `json:"toolName,omitempty"` + Input json.RawMessage `json:"input,omitempty"` +} + +// suggestCommandInput represents the input for the suggestCommand tool. +type suggestCommandInput struct { + Command string `json:"command"` } func NewAgentCmd(f *cmdutil.Factory) *cobra.Command { @@ -126,7 +133,7 @@ func runAgent(opts *AgentOptions) error { history = append(history, userMsg) opts.IO.StartProgressIndicator() - assistantText, assistantID, err := sendCompletion(opts, history) + result, err := sendCompletion(opts, history) opts.IO.StopProgressIndicator() if err != nil { fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", err) @@ -137,15 +144,21 @@ func runAgent(opts *AgentOptions) error { fmt.Fprintln(out, separator) fmt.Fprintln(out) - fmt.Fprintln(out, renderMarkdown(opts.IO.ColorScheme(), assistantText)) - fmt.Fprintln(out) + if result.Text != "" { + fmt.Fprintln(out, renderMarkdown(opts.IO.ColorScheme(), result.Text)) + fmt.Fprintln(out) + } + if result.Command != "" { + fmt.Fprintf(out, "%s %s\n", cs.Bold("Suggested command:"), cs.Cyan(result.Command)) + fmt.Fprintln(out) + } fmt.Fprintln(out, separator) history = append(history, message{ - ID: assistantID, + ID: result.MessageID, Role: "assistant", Parts: []part{ - {Type: "text", Text: assistantText}, + {Type: "text", Text: result.Text}, }, }) } @@ -154,8 +167,7 @@ func runAgent(opts *AgentOptions) error { } // sendCompletion sends the conversation to Agent Studio and streams the response. -// Returns the assistant text and the server-generated message ID. -func sendCompletion(opts *AgentOptions, messages []message) (string, string, error) { +func sendCompletion(opts *AgentOptions, messages []message) (completionResult, error) { url := fmt.Sprintf( "https://%s.algolia.net/agent-studio/1/agents/%s/completions?stream=true&compatibilityMode=ai-sdk-5", opts.AppID, opts.AgentID, @@ -166,12 +178,12 @@ func sendCompletion(opts *AgentOptions, messages []message) (string, string, err } body, err := json.Marshal(reqBody) if err != nil { - return "", "", fmt.Errorf("failed to marshal request: %w", err) + return completionResult{}, fmt.Errorf("failed to marshal request: %w", err) } req, err := http.NewRequest("POST", url, strings.NewReader(string(body))) if err != nil { - return "", "", fmt.Errorf("failed to create request: %w", err) + return completionResult{}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("x-algolia-application-id", opts.AppID) @@ -179,22 +191,28 @@ func sendCompletion(opts *AgentOptions, messages []message) (string, string, err resp, err := http.DefaultClient.Do(req) if err != nil { - return "", "", fmt.Errorf("request failed: %w", err) + return completionResult{}, fmt.Errorf("request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", "", fmt.Errorf("unexpected status: %s", resp.Status) + return completionResult{}, fmt.Errorf("unexpected status: %s", resp.Status) } return parseSSEStream(resp.Body) } +// completionResult holds the parsed response from the SSE stream. +type completionResult struct { + Text string + MessageID string + Command string // optional, from suggestCommand tool +} + // parseSSEStream reads an SSE stream and collects text deltas. -// Returns the assembled text and the server-generated message ID. -func parseSSEStream(r io.Reader) (string, string, error) { - var result strings.Builder - var messageID string +func parseSSEStream(r io.Reader) (completionResult, error) { + var res completionResult + var textBuf strings.Builder scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() @@ -212,13 +230,21 @@ func parseSSEStream(r io.Reader) (string, string, error) { } switch event.Type { case "start": - messageID = event.MessageID + res.MessageID = event.MessageID case "text-delta": - result.WriteString(event.Delta) + textBuf.WriteString(event.Delta) + case "tool-input-available": + if event.ToolName == "suggestCommand" { + var input suggestCommandInput + if err := json.Unmarshal(event.Input, &input); err == nil { + res.Command = input.Command + } + } } } - return result.String(), messageID, nil + res.Text = textBuf.String() + return res, nil } func envOrDefault(key, defaultVal string) string { diff --git a/pkg/cmd/agent/agent_test.go b/pkg/cmd/agent/agent_test.go index a88497f5..fbd86a0b 100644 --- a/pkg/cmd/agent/agent_test.go +++ b/pkg/cmd/agent/agent_test.go @@ -8,10 +8,11 @@ import ( func TestParseSSEStream(t *testing.T) { tests := []struct { - name string - input string - wantText string - wantMsgID string + name string + input string + wantText string + wantMsgID string + wantCommand string }{ { name: "parses a complete stream with start and text-delta events", @@ -53,20 +54,33 @@ func TestParseSSEStream(t *testing.T) { wantText: "before", wantMsgID: "", }, + { + name: "parses suggestCommand tool call", + input: "data: {\"type\":\"start\",\"messageId\":\"msg_456\"}\n" + + "data: {\"type\":\"text-delta\",\"delta\":\"Try this:\"}\n" + + "data: {\"type\":\"tool-input-available\",\"toolName\":\"suggestCommand\",\"input\":{\"command\":\"algolia search MOVIES\"}}\n" + + "data: [DONE]\n", + wantText: "Try this:", + wantMsgID: "msg_456", + wantCommand: "algolia search MOVIES", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := strings.NewReader(tt.input) - gotText, gotMsgID, err := parseSSEStream(r) + result, err := parseSSEStream(r) if err != nil { t.Fatalf("unexpected error: %v", err) } - if gotText != tt.wantText { - t.Errorf("text = %q, want %q", gotText, tt.wantText) + if result.Text != tt.wantText { + t.Errorf("text = %q, want %q", result.Text, tt.wantText) + } + if result.MessageID != tt.wantMsgID { + t.Errorf("messageID = %q, want %q", result.MessageID, tt.wantMsgID) } - if gotMsgID != tt.wantMsgID { - t.Errorf("messageID = %q, want %q", gotMsgID, tt.wantMsgID) + if result.Command != tt.wantCommand { + t.Errorf("command = %q, want %q", result.Command, tt.wantCommand) } }) } From fa9ebed71fd60822d8702f475b60070b633992ec Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Fri, 20 Mar 2026 12:49:37 -0700 Subject: [PATCH 06/10] feat(agent): execute suggested commands from agent conversation --- go.mod | 2 + go.sum | 7 + pkg/cmd/agent/agent.go | 285 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 272 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index 7f362b4c..4bc67cd4 100644 --- a/go.mod +++ b/go.mod @@ -41,8 +41,10 @@ require ( al.essio.dev/pkg/shellescape v1.5.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/chzyer/readline v1.5.1 // indirect github.com/clipperhouse/displaywidth v0.10.0 // indirect github.com/clipperhouse/uax29/v2 v2.6.0 // indirect + github.com/creack/pty v1.1.24 // indirect github.com/danieljoos/wincred v1.2.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fatih/color v1.18.0 // indirect diff --git a/go.sum b/go.sum index d9794ea2..78e05fb3 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,10 @@ github.com/briandowns/spinner v1.23.2 h1:Zc6ecUnI+YzLmJniCfDNaMbW0Wid1d5+qcTq4L2 github.com/briandowns/spinner v1.23.2/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24 h1:QDrhR4JA2n3ij9YQN0u5ZeuvRIIvsUGmf5yPlTS0w8E= github.com/cli/go-internal v0.0.0-20241025142207-6c48bcd5ce24/go.mod h1:rr9GNING0onuVw8MnracQHn7PcchnFlP882Y0II2KZk= github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= @@ -30,6 +34,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6N github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/danieljoos/wincred v1.2.2 h1:774zMFJrqaeYCK2W57BgAem/MLi6mtSE47MB6BOJ0i0= github.com/danieljoos/wincred v1.2.2/go.mod h1:w7w4Utbrz8lqeMbDAK0lkNJUv5sAOkFi7nd/ogr0Uh8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -187,6 +193,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index 51a13ff6..750f2d70 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -2,16 +2,23 @@ package agent import ( "bufio" + "bytes" "encoding/json" "fmt" "io" "net/http" "os" + "os/exec" "strings" + "github.com/chzyer/readline" + "github.com/creack/pty" + "github.com/MakeNowJust/heredoc" "github.com/spf13/cobra" + "github.com/google/uuid" + "github.com/algolia/cli/pkg/auth" "github.com/algolia/cli/pkg/cmdutil" "github.com/algolia/cli/pkg/iostreams" @@ -48,6 +55,7 @@ type part struct { // completionRequest is the request body sent to Agent Studio. type completionRequest struct { + ID string `json:"id"` Messages []message `json:"messages"` } @@ -97,7 +105,15 @@ func runAgent(opts *AgentOptions) error { } out := opts.IO.Out - scanner := bufio.NewScanner(os.Stdin) + + rl, err := readline.NewEx(&readline.Config{ + Prompt: "> ", + HistoryFile: "", + }) + if err != nil { + return fmt.Errorf("failed to initialize readline: %w", err) + } + defer rl.Close() cs := opts.IO.ColorScheme() separator := cs.Gray(strings.Repeat("─", opts.IO.TerminalWidth())) @@ -105,22 +121,40 @@ func runAgent(opts *AgentOptions) error { fmt.Fprintln(out, "Algolia CLI Agent (type \"exit\" to quit)") fmt.Fprintln(out, separator) + conversationID, err := newConversationID() + if err != nil { + return err + } + var history []message msgCounter := 0 for { - fmt.Fprint(out, "> ") - - if !scanner.Scan() { + line, err := rl.Readline() + if err != nil { // io.EOF or interrupt break } - input := strings.TrimSpace(scanner.Text()) + input := strings.TrimSpace(line) if input == "" { continue } if input == "exit" { break } + if input == "/clear" { + history = nil + msgCounter = 0 + newID, idErr := newConversationID() + if idErr != nil { + fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", idErr) + continue + } + conversationID = newID + fmt.Fprintln(out, separator) + fmt.Fprintln(out, cs.Gray("Conversation cleared.")) + fmt.Fprintln(out, separator) + continue + } msgCounter++ userMsg := message{ @@ -132,8 +166,8 @@ func runAgent(opts *AgentOptions) error { } history = append(history, userMsg) - opts.IO.StartProgressIndicator() - result, err := sendCompletion(opts, history) + opts.IO.StartProgressIndicatorWithLabel("\nThinking...") + result, err := sendCompletion(opts, conversationID, history) opts.IO.StopProgressIndicator() if err != nil { fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", err) @@ -143,37 +177,87 @@ func runAgent(opts *AgentOptions) error { } fmt.Fprintln(out, separator) - fmt.Fprintln(out) - if result.Text != "" { - fmt.Fprintln(out, renderMarkdown(opts.IO.ColorScheme(), result.Text)) + + for { fmt.Fprintln(out) - } - if result.Command != "" { - fmt.Fprintf(out, "%s %s\n", cs.Bold("Suggested command:"), cs.Cyan(result.Command)) + if result.Text != "" { + fmt.Fprintln(out, renderMarkdown(cs, result.Text)) + } + history = append(history, message{ + ID: result.MessageID, + Role: "assistant", + Parts: []part{ + {Type: "text", Text: result.Text}, + }, + }) + + if result.Command == "" { + fmt.Fprintln(out) + break + } + + if isSafeCommand(result.Command) { + fmt.Fprintf(out, "%s\n", cs.Gray(fmt.Sprintf("\033[3mRunning %s\033[0m", result.Command))) + } else { + fmt.Fprintf(out, "%s %s\n", cs.Bold("Suggested command:"), cs.Cyan(result.Command)) + fmt.Fprintln(out) + rl.SetPrompt("Run this command? [Y/n] ") + confirmLine, confirmErr := rl.Readline() + rl.SetPrompt("> ") + if confirmErr != nil { + break + } + answer := strings.TrimSpace(strings.ToLower(confirmLine)) + // Rewrite the prompt line with the actual answer. + fmt.Fprintf(out, "\033[1A\033[2K") + if answer == "" || answer == "y" || answer == "yes" { + fmt.Fprintf(out, "Run this command? %s\n", cs.Green("Y")) + } else { + fmt.Fprintf(out, "Run this command? %s\n", cs.Red("n\n")) + break + } + } + fmt.Fprintln(out) + cmdOutput, cmdErr := executeCommand(result.Command) + msgCounter++ + outputText := fmt.Sprintf("Command `%s` was executed.\nOutput:\n\n%s\n", result.Command, cmdOutput) + if cmdErr != nil { + outputText += fmt.Sprintf("\nError: %s", cmdErr) + } + history = append(history, message{ + ID: fmt.Sprintf("alg_msg_%d", msgCounter), + Role: "user", + Parts: []part{ + {Text: outputText}, + }, + }) + + opts.IO.StartProgressIndicatorWithLabel("\nThinking...") + followUp, err := sendCompletion(opts, conversationID, history) + opts.IO.StopProgressIndicator() + if err != nil { + fmt.Fprintf(opts.IO.ErrOut, "Error: %s\n", err) + break + } + result = followUp } - fmt.Fprintln(out, separator) - history = append(history, message{ - ID: result.MessageID, - Role: "assistant", - Parts: []part{ - {Type: "text", Text: result.Text}, - }, - }) + fmt.Fprintln(out, separator) } return nil } // sendCompletion sends the conversation to Agent Studio and streams the response. -func sendCompletion(opts *AgentOptions, messages []message) (completionResult, error) { +func sendCompletion(opts *AgentOptions, conversationID string, messages []message) (completionResult, error) { url := fmt.Sprintf( "https://%s.algolia.net/agent-studio/1/agents/%s/completions?stream=true&compatibilityMode=ai-sdk-5", opts.AppID, opts.AgentID, ) reqBody := completionRequest{ + ID: conversationID, Messages: messages, } body, err := json.Marshal(reqBody) @@ -247,6 +331,163 @@ func parseSSEStream(r io.Reader) (completionResult, error) { return res, nil } +// validateCommand checks that a command string does not contain dangerous shell metacharacters. +func validateCommand(command string) error { + for _, pattern := range []string{"&&", "||", ";", "$(", "`"} { + if strings.Contains(command, pattern) { + return fmt.Errorf("command contains disallowed shell operator: %s", pattern) + } + } + return nil +} + +// safeCommands lists read-only command prefixes that can be auto-run without confirmation. +var safeCommands = []string{ + "profile list", + "application list", + "indices list", + "apikeys list", + "search ", + "objects browse", + "settings get", + "rules browse", + "synonyms browse", + "dictionary settings get", + "dictionary entries browse", + "describe", + "open", + "events tail", + "crawler list", + "crawler get", + "crawler stats", + "indices config export", + "indices analyze", +} + +// isSafeCommand checks if a command is read-only and can be auto-run. +func isSafeCommand(command string) bool { + // Strip the leading "algolia " to get the subcommand. + sub := strings.TrimPrefix(command, "algolia ") + if sub == command { + return false + } + for _, safe := range safeCommands { + if strings.HasPrefix(sub, safe) { + return true + } + } + return false +} + +// executeCommand runs a command string inside a PTY so that the child process +// sees a real terminal (IsTerminal returns true). Output is tee'd to the user's +// terminal and captured for the agent context. +func executeCommand(command string) (string, error) { + if command == "" { + return "", fmt.Errorf("empty command") + } + if err := validateCommand(command); err != nil { + return "", err + } + command = replaceAlgoliaBinary(command) + cmd := exec.Command("sh", "-c", command) + cmd.Env = append(os.Environ(), "PAGER=cat") + + ptmx, err := pty.Start(cmd) + if err != nil { + return "", fmt.Errorf("failed to start command: %w", err) + } + defer ptmx.Close() + + // Tee PTY output to both the real terminal and a buffer. + var buf bytes.Buffer + _, _ = io.Copy(io.MultiWriter(os.Stdout, &buf), ptmx) + + _ = cmd.Wait() + return strings.TrimSpace(stripANSI(buf.String())), nil +} + +// stripANSI removes ANSI escape sequences from a string and simulates +// carriage return behavior (overwrites the current line). +func stripANSI(s string) string { + var lines []string + var cur strings.Builder + i := 0 + for i < len(s) { + if s[i] == '\033' { + // Skip CSI sequences: ESC [ ... final byte + if i+1 < len(s) && s[i+1] == '[' { + j := i + 2 + for j < len(s) && s[j] >= 0x20 && s[j] <= 0x3F { + j++ + } + if j < len(s) { + j++ // skip final byte + } + i = j + continue + } + // Skip other ESC sequences (ESC + one byte) + i += 2 + continue + } + if s[i] == '\r' { + // \r\n is a normal newline, not a spinner overwrite. + if i+1 < len(s) && s[i+1] == '\n' { + lines = append(lines, cur.String()) + cur.Reset() + i += 2 + continue + } + // Standalone \r: discard current line content (spinner overwrite) + cur.Reset() + i++ + continue + } + if s[i] == '\n' { + lines = append(lines, cur.String()) + cur.Reset() + i++ + continue + } + cur.WriteByte(s[i]) + i++ + } + if cur.Len() > 0 { + lines = append(lines, cur.String()) + } + // Filter out empty lines from spinner artifacts. + var result []string + for _, l := range lines { + if strings.TrimSpace(l) != "" { + result = append(result, l) + } + } + return strings.Join(result, "\n") +} + +// replaceAlgoliaBinary replaces "algolia" at command positions with the actual binary path +// (e.g. "./algolia" in dev). Only replaces at the start and after pipes. +func replaceAlgoliaBinary(command string) string { + bin := os.Args[0] + if bin == "algolia" { + return command + } + if strings.HasPrefix(command, "algolia ") { + command = bin + command[len("algolia"):] + } + command = strings.ReplaceAll(command, "| algolia ", "| "+bin+" ") + return command +} + +func newConversationID() (string, error) { + id, err := uuid.NewRandom() + if err != nil { + return "", fmt.Errorf("failed to generate conversation ID: %w", err) + } + return "alg_cnv_" + id.String(), nil +} + func envOrDefault(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v From 0b374da465217c237cc7cc9c0a6756f0323a369a Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Fri, 20 Mar 2026 14:15:01 -0700 Subject: [PATCH 07/10] feat(agent): execute suggested commands with safety controls --- pkg/cmd/agent/agent.go | 229 ++---------------------------------- pkg/cmd/agent/agent_test.go | 143 ++++++++++++++++++++++ pkg/cmd/agent/command.go | 128 ++++++++++++++++++++ pkg/cmd/agent/output.go | 120 +++++++++++++++++++ pkg/cmd/agent/sse.go | 68 +++++++++++ 5 files changed, 472 insertions(+), 216 deletions(-) create mode 100644 pkg/cmd/agent/command.go create mode 100644 pkg/cmd/agent/output.go create mode 100644 pkg/cmd/agent/sse.go diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index 750f2d70..dd703dc5 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -1,18 +1,14 @@ package agent import ( - "bufio" - "bytes" "encoding/json" "fmt" - "io" "net/http" "os" - "os/exec" "strings" + "time" "github.com/chzyer/readline" - "github.com/creack/pty" "github.com/MakeNowJust/heredoc" "github.com/spf13/cobra" @@ -59,21 +55,6 @@ type completionRequest struct { Messages []message `json:"messages"` } -// sseEvent represents a parsed SSE data payload. -type sseEvent struct { - Type string `json:"type"` - ID string `json:"id,omitempty"` - MessageID string `json:"messageId,omitempty"` - Delta string `json:"delta,omitempty"` - ToolName string `json:"toolName,omitempty"` - Input json.RawMessage `json:"input,omitempty"` -} - -// suggestCommandInput represents the input for the suggestCommand tool. -type suggestCommandInput struct { - Command string `json:"command"` -} - func NewAgentCmd(f *cmdutil.Factory) *cobra.Command { opts := &AgentOptions{ IO: f.IOStreams, @@ -85,7 +66,7 @@ func NewAgentCmd(f *cmdutil.Factory) *cobra.Command { cmd := &cobra.Command{ Use: "agent", Short: "Chat with an AI agent that suggests Algolia CLI commands", - Long: "Interactive chat with an AI agent that advises CLI commands for your use case. The agent only prints suggestions — it does not execute commands.", + Long: "Interactive chat with an AI agent that can suggest and execute Algolia CLI commands for your use case.", Example: heredoc.Doc(` $ algolia agent `), @@ -196,6 +177,12 @@ func runAgent(opts *AgentOptions) error { break } + if isBlockedCommand(result.Command) { + fmt.Fprintf(out, "%s %s\n", cs.Bold("Suggested command:"), cs.Cyan(result.Command)) + fmt.Fprintln(out) + break + } + if isSafeCommand(result.Command) { fmt.Fprintf(out, "%s\n", cs.Gray(fmt.Sprintf("\033[3mRunning %s\033[0m", result.Command))) } else { @@ -221,6 +208,8 @@ func runAgent(opts *AgentOptions) error { fmt.Fprintln(out) cmdOutput, cmdErr := executeCommand(result.Command) msgCounter++ + cmdOutput = compactJSON(cmdOutput) + cmdOutput = truncateOutput(cmdOutput, 10) outputText := fmt.Sprintf("Command `%s` was executed.\nOutput:\n\n%s\n", result.Command, cmdOutput) if cmdErr != nil { outputText += fmt.Sprintf("\nError: %s", cmdErr) @@ -273,7 +262,8 @@ func sendCompletion(opts *AgentOptions, conversationID string, messages []messag req.Header.Set("x-algolia-application-id", opts.AppID) req.Header.Set("X-Algolia-API-Key", opts.APIKey) - resp, err := http.DefaultClient.Do(req) + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) if err != nil { return completionResult{}, fmt.Errorf("request failed: %w", err) } @@ -286,199 +276,6 @@ func sendCompletion(opts *AgentOptions, conversationID string, messages []messag return parseSSEStream(resp.Body) } -// completionResult holds the parsed response from the SSE stream. -type completionResult struct { - Text string - MessageID string - Command string // optional, from suggestCommand tool -} - -// parseSSEStream reads an SSE stream and collects text deltas. -func parseSSEStream(r io.Reader) (completionResult, error) { - var res completionResult - var textBuf strings.Builder - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") { - continue - } - data := strings.TrimPrefix(line, "data: ") - if data == "[DONE]" { - break - } - - var event sseEvent - if err := json.Unmarshal([]byte(data), &event); err != nil { - continue - } - switch event.Type { - case "start": - res.MessageID = event.MessageID - case "text-delta": - textBuf.WriteString(event.Delta) - case "tool-input-available": - if event.ToolName == "suggestCommand" { - var input suggestCommandInput - if err := json.Unmarshal(event.Input, &input); err == nil { - res.Command = input.Command - } - } - } - } - - res.Text = textBuf.String() - return res, nil -} - -// validateCommand checks that a command string does not contain dangerous shell metacharacters. -func validateCommand(command string) error { - for _, pattern := range []string{"&&", "||", ";", "$(", "`"} { - if strings.Contains(command, pattern) { - return fmt.Errorf("command contains disallowed shell operator: %s", pattern) - } - } - return nil -} - -// safeCommands lists read-only command prefixes that can be auto-run without confirmation. -var safeCommands = []string{ - "profile list", - "application list", - "indices list", - "apikeys list", - "search ", - "objects browse", - "settings get", - "rules browse", - "synonyms browse", - "dictionary settings get", - "dictionary entries browse", - "describe", - "open", - "events tail", - "crawler list", - "crawler get", - "crawler stats", - "indices config export", - "indices analyze", -} - -// isSafeCommand checks if a command is read-only and can be auto-run. -func isSafeCommand(command string) bool { - // Strip the leading "algolia " to get the subcommand. - sub := strings.TrimPrefix(command, "algolia ") - if sub == command { - return false - } - for _, safe := range safeCommands { - if strings.HasPrefix(sub, safe) { - return true - } - } - return false -} - -// executeCommand runs a command string inside a PTY so that the child process -// sees a real terminal (IsTerminal returns true). Output is tee'd to the user's -// terminal and captured for the agent context. -func executeCommand(command string) (string, error) { - if command == "" { - return "", fmt.Errorf("empty command") - } - if err := validateCommand(command); err != nil { - return "", err - } - command = replaceAlgoliaBinary(command) - cmd := exec.Command("sh", "-c", command) - cmd.Env = append(os.Environ(), "PAGER=cat") - - ptmx, err := pty.Start(cmd) - if err != nil { - return "", fmt.Errorf("failed to start command: %w", err) - } - defer ptmx.Close() - - // Tee PTY output to both the real terminal and a buffer. - var buf bytes.Buffer - _, _ = io.Copy(io.MultiWriter(os.Stdout, &buf), ptmx) - - _ = cmd.Wait() - return strings.TrimSpace(stripANSI(buf.String())), nil -} - -// stripANSI removes ANSI escape sequences from a string and simulates -// carriage return behavior (overwrites the current line). -func stripANSI(s string) string { - var lines []string - var cur strings.Builder - i := 0 - for i < len(s) { - if s[i] == '\033' { - // Skip CSI sequences: ESC [ ... final byte - if i+1 < len(s) && s[i+1] == '[' { - j := i + 2 - for j < len(s) && s[j] >= 0x20 && s[j] <= 0x3F { - j++ - } - if j < len(s) { - j++ // skip final byte - } - i = j - continue - } - // Skip other ESC sequences (ESC + one byte) - i += 2 - continue - } - if s[i] == '\r' { - // \r\n is a normal newline, not a spinner overwrite. - if i+1 < len(s) && s[i+1] == '\n' { - lines = append(lines, cur.String()) - cur.Reset() - i += 2 - continue - } - // Standalone \r: discard current line content (spinner overwrite) - cur.Reset() - i++ - continue - } - if s[i] == '\n' { - lines = append(lines, cur.String()) - cur.Reset() - i++ - continue - } - cur.WriteByte(s[i]) - i++ - } - if cur.Len() > 0 { - lines = append(lines, cur.String()) - } - // Filter out empty lines from spinner artifacts. - var result []string - for _, l := range lines { - if strings.TrimSpace(l) != "" { - result = append(result, l) - } - } - return strings.Join(result, "\n") -} - -// replaceAlgoliaBinary replaces "algolia" at command positions with the actual binary path -// (e.g. "./algolia" in dev). Only replaces at the start and after pipes. -func replaceAlgoliaBinary(command string) string { - bin := os.Args[0] - if bin == "algolia" { - return command - } - if strings.HasPrefix(command, "algolia ") { - command = bin + command[len("algolia"):] - } - command = strings.ReplaceAll(command, "| algolia ", "| "+bin+" ") - return command -} func newConversationID() (string, error) { id, err := uuid.NewRandom() @@ -493,4 +290,4 @@ func envOrDefault(key, defaultVal string) string { return v } return defaultVal -} +} \ No newline at end of file diff --git a/pkg/cmd/agent/agent_test.go b/pkg/cmd/agent/agent_test.go index fbd86a0b..a61fd99d 100644 --- a/pkg/cmd/agent/agent_test.go +++ b/pkg/cmd/agent/agent_test.go @@ -86,6 +86,149 @@ func TestParseSSEStream(t *testing.T) { } } +func TestValidateCommand(t *testing.T) { + tests := []struct { + name string + command string + wantErr bool + }{ + {name: "simple command", command: "algolia indices list", wantErr: false}, + {name: "command with pipe", command: "algolia profile list | head -5", wantErr: false}, + {name: "command with quotes", command: `algolia profile list`, wantErr: false}, + {name: "blocks objects browse", command: "algolia objects browse MOVIES", wantErr: true}, + {name: "blocks search", command: "algolia search MOVIES --query test", wantErr: true}, + {name: "blocks rules browse", command: "algolia rules browse MOVIES", wantErr: true}, + {name: "blocks synonyms browse", command: "algolia synonyms browse MOVIES", wantErr: true}, + {name: "blocks dictionary entries browse", command: "algolia dictionary entries browse stopwords", wantErr: true}, + {name: "blocks semicolon", command: "algolia indices list; rm -rf /", wantErr: true}, + {name: "blocks double ampersand", command: "algolia indices list && echo pwned", wantErr: true}, + {name: "blocks double pipe", command: "algolia indices list || echo fallback", wantErr: true}, + {name: "blocks dollar paren", command: "algolia indices list $(whoami)", wantErr: true}, + {name: "blocks backtick", command: "algolia indices list `whoami`", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCommand(tt.command) + if (err != nil) != tt.wantErr { + t.Errorf("validateCommand(%q) error = %v, wantErr %v", tt.command, err, tt.wantErr) + } + }) + } +} + +func TestIsSafeCommand(t *testing.T) { + tests := []struct { + name string + command string + want bool + }{ + {name: "profile list is safe", command: "algolia profile list", want: true}, + {name: "search is not safe", command: "algolia search MOVIES --query test", want: false}, + {name: "objects browse is not safe", command: "algolia objects browse MOVIES", want: false}, + {name: "indices list is safe", command: "algolia indices list", want: true}, + {name: "describe is safe", command: "algolia describe search", want: true}, + {name: "delete is not safe", command: "algolia indices delete MOVIES -y", want: false}, + {name: "objects import is not safe", command: "algolia objects import MOVIES -F data.ndjson", want: false}, + {name: "non-algolia command is not safe", command: "rm -rf /", want: false}, + {name: "empty string is not safe", command: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSafeCommand(tt.command) + if got != tt.want { + t.Errorf("isSafeCommand(%q) = %v, want %v", tt.command, got, tt.want) + } + }) + } +} + +func TestStripANSI(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {name: "plain text unchanged", input: "hello world", want: "hello world"}, + {name: "strips color codes", input: "\033[31mred\033[0m", want: "red"}, + {name: "strips bold", input: "\033[1mbold\033[0m", want: "bold"}, + {name: "handles spinner overwrite", input: "Loading ⣾\rLoading ⣽\rDone", want: "Done"}, + {name: "preserves newlines", input: "line1\nline2", want: "line1\nline2"}, + {name: "handles CRLF", input: "line1\r\nline2", want: "line1\nline2"}, + {name: "filters empty lines from spinner", input: "Fetching\rFetching\r\n\nresult", want: "Fetching\nresult"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripANSI(tt.input) + if got != tt.want { + t.Errorf("stripANSI(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestReplaceAlgoliaBinary(t *testing.T) { + // Save and restore os.Args[0] + origArg := os.Args[0] + defer func() { os.Args[0] = origArg }() + + t.Run("no replacement when binary is algolia", func(t *testing.T) { + os.Args[0] = "algolia" + got := replaceAlgoliaBinary("algolia search MOVIES") + if got != "algolia search MOVIES" { + t.Errorf("got %q, want %q", got, "algolia search MOVIES") + } + }) + + t.Run("replaces at start", func(t *testing.T) { + os.Args[0] = "./algolia" + got := replaceAlgoliaBinary("algolia search MOVIES") + if got != "./algolia search MOVIES" { + t.Errorf("got %q, want %q", got, "./algolia search MOVIES") + } + }) + + t.Run("replaces after pipe", func(t *testing.T) { + os.Args[0] = "./algolia" + got := replaceAlgoliaBinary("algolia objects browse SRC | algolia objects import DST -F -") + if got != "./algolia objects browse SRC | ./algolia objects import DST -F -" { + t.Errorf("got %q, want %q", got, "./algolia objects browse SRC | ./algolia objects import DST -F -") + } + }) + + t.Run("does not replace in arguments", func(t *testing.T) { + os.Args[0] = "./algolia" + got := replaceAlgoliaBinary("algolia search hello-algolia") + if got != "./algolia search hello-algolia" { + t.Errorf("got %q, want %q", got, "./algolia search hello-algolia") + } + }) +} + +func TestTruncateOutput(t *testing.T) { + tests := []struct { + name string + input string + maxLines int + want string + }{ + {name: "short output unchanged", input: "line1\nline2", maxLines: 10, want: "line1\nline2"}, + {name: "truncates at limit", input: "1\n2\n3\n4\n5", maxLines: 3, want: "1\n2\n3\n[... 2 more lines truncated]"}, + {name: "single line", input: "hello", maxLines: 10, want: "hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateOutput(tt.input, tt.maxLines) + if got != tt.want { + t.Errorf("truncateOutput() = %q, want %q", got, tt.want) + } + }) + } +} + func TestEnvOrDefault(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/agent/command.go b/pkg/cmd/agent/command.go new file mode 100644 index 00000000..cb0f8cf2 --- /dev/null +++ b/pkg/cmd/agent/command.go @@ -0,0 +1,128 @@ +package agent + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "strings" + "time" + + "github.com/creack/pty" +) + +// blockedCommands lists command prefixes that cannot be executed from the agent. +// These commands will be shown as suggestions only. +var blockedCommands = []string{ + "objects browse", + "search ", + "rules browse", + "synonyms browse", + "dictionary entries browse", +} + +// safeCommands lists read-only command prefixes that can be auto-run without confirmation. +var safeCommands = []string{ + "profile list", + "application list", + "indices list", + "apikeys list", + "settings get", + "dictionary settings get", + "describe", + "open", + "events tail", + "crawler list", + "crawler get", + "crawler stats", + "indices config export", + "indices analyze", +} + +// isSafeCommand checks if a command is read-only and can be auto-run. +func isSafeCommand(command string) bool { + // Strip the leading "algolia " to get the subcommand. + sub := strings.TrimPrefix(command, "algolia ") + if sub == command { + return false + } + for _, safe := range safeCommands { + if strings.HasPrefix(sub, safe) { + return true + } + } + return false +} + +// isBlockedCommand checks if a command is not allowed to run from the agent. +func isBlockedCommand(command string) bool { + sub := strings.TrimPrefix(command, "algolia ") + if sub == command { + return false + } + for _, blocked := range blockedCommands { + if strings.HasPrefix(sub, blocked) { + return true + } + } + return false +} + +// validateCommand checks that a command string does not contain dangerous shell metacharacters. +func validateCommand(command string) error { + if isBlockedCommand(command) { + return fmt.Errorf("command is not allowed from the agent") + } + for _, pattern := range []string{"&&", "||", ";", "$(", "`"} { + if strings.Contains(command, pattern) { + return fmt.Errorf("command contains disallowed shell operator: %s", pattern) + } + } + return nil +} + +// executeCommand runs a command string inside a PTY so that the child process +// sees a real terminal (IsTerminal returns true). Output is tee'd to the user's +// terminal and captured for the agent context. +func executeCommand(command string) (string, error) { + if command == "" { + return "", fmt.Errorf("empty command") + } + if err := validateCommand(command); err != nil { + return "", err + } + command = replaceAlgoliaBinary(command) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "sh", "-c", command) + cmd.Env = append(os.Environ(), "PAGER=cat") + + ptmx, err := pty.Start(cmd) + if err != nil { + return "", fmt.Errorf("failed to start command: %w", err) + } + defer ptmx.Close() + + // Tee PTY output to both the real terminal and a buffer. + var buf bytes.Buffer + _, _ = io.Copy(io.MultiWriter(os.Stdout, &buf), ptmx) + + _ = cmd.Wait() + return strings.TrimSpace(stripANSI(buf.String())), nil +} + +// replaceAlgoliaBinary replaces "algolia" at command positions with the actual binary path +// (e.g. "./algolia" in dev). Only replaces at the start and after pipes. +func replaceAlgoliaBinary(command string) string { + bin := os.Args[0] + if bin == "algolia" { + return command + } + if strings.HasPrefix(command, "algolia ") { + command = bin + command[len("algolia"):] + } + command = strings.ReplaceAll(command, "| algolia ", "| "+bin+" ") + return command +} diff --git a/pkg/cmd/agent/output.go b/pkg/cmd/agent/output.go new file mode 100644 index 00000000..922edcd8 --- /dev/null +++ b/pkg/cmd/agent/output.go @@ -0,0 +1,120 @@ +package agent + +import ( + "fmt" + "strings" +) + +// stripANSI removes ANSI escape sequences from a string and simulates +// carriage return behavior (overwrites the current line). +func stripANSI(s string) string { + var lines []string + var cur strings.Builder + i := 0 + for i < len(s) { + if s[i] == '\033' { + // Skip CSI sequences: ESC [ ... final byte + if i+1 < len(s) && s[i+1] == '[' { + j := i + 2 + for j < len(s) && s[j] >= 0x20 && s[j] <= 0x3F { + j++ + } + if j < len(s) { + j++ // skip final byte + } + i = j + continue + } + // Skip other ESC sequences (ESC + one byte) + i += 2 + continue + } + if s[i] == '\r' { + // \r\n is a normal newline, not a spinner overwrite. + if i+1 < len(s) && s[i+1] == '\n' { + lines = append(lines, cur.String()) + cur.Reset() + i += 2 + continue + } + // Standalone \r: discard current line content (spinner overwrite) + cur.Reset() + i++ + continue + } + if s[i] == '\n' { + lines = append(lines, cur.String()) + cur.Reset() + i++ + continue + } + cur.WriteByte(s[i]) + i++ + } + if cur.Len() > 0 { + lines = append(lines, cur.String()) + } + // Filter out empty lines from spinner artifacts. + var result []string + for _, l := range lines { + if strings.TrimSpace(l) != "" { + result = append(result, l) + } + } + return strings.Join(result, "\n") +} + +// compactJSON attempts to compact each JSON object in the output to a single line. +// Non-JSON lines are left unchanged. +func compactJSON(s string) string { + lines := strings.Split(s, "\n") + var result []string + var jsonBuf strings.Builder + depth := 0 + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" && depth == 0 { + continue + } + + for _, ch := range trimmed { + switch ch { + case '{', '[': + depth++ + case '}', ']': + depth-- + } + } + + if depth > 0 { + jsonBuf.WriteString(trimmed) + continue + } + + if jsonBuf.Len() > 0 { + jsonBuf.WriteString(trimmed) + result = append(result, jsonBuf.String()) + jsonBuf.Reset() + } else { + result = append(result, trimmed) + } + } + + if jsonBuf.Len() > 0 { + result = append(result, jsonBuf.String()) + } + + return strings.Join(result, "\n") +} + +// truncateOutput limits the output to maxLines non-empty lines. +// If truncated, appends a note indicating how many lines were omitted. +func truncateOutput(s string, maxLines int) string { + lines := strings.Split(s, "\n") + if len(lines) <= maxLines { + return s + } + truncated := strings.Join(lines[:maxLines], "\n") + return truncated + fmt.Sprintf("\n[... %d more lines truncated]", len(lines)-maxLines) +} diff --git a/pkg/cmd/agent/sse.go b/pkg/cmd/agent/sse.go new file mode 100644 index 00000000..4063541c --- /dev/null +++ b/pkg/cmd/agent/sse.go @@ -0,0 +1,68 @@ +package agent + +import ( + "bufio" + "encoding/json" + "io" + "strings" +) + +// sseEvent represents a parsed SSE data payload. +type sseEvent struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + MessageID string `json:"messageId,omitempty"` + Delta string `json:"delta,omitempty"` + ToolName string `json:"toolName,omitempty"` + Input json.RawMessage `json:"input,omitempty"` +} + +// suggestCommandInput represents the input for the suggestCommand tool. +type suggestCommandInput struct { + Command string `json:"command"` +} + +// completionResult holds the parsed response from the SSE stream. +type completionResult struct { + Text string + MessageID string + Command string // optional, from suggestCommand tool +} + +// parseSSEStream reads an SSE stream and collects text deltas. +func parseSSEStream(r io.Reader) (completionResult, error) { + var res completionResult + var textBuf strings.Builder + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var event sseEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + switch event.Type { + case "start": + res.MessageID = event.MessageID + case "text-delta": + textBuf.WriteString(event.Delta) + case "tool-input-available": + if event.ToolName == "suggestCommand" { + var input suggestCommandInput + if err := json.Unmarshal(event.Input, &input); err == nil { + res.Command = input.Command + } + } + } + } + + res.Text = textBuf.String() + return res, nil +} From fe96481fe9e49b1a19af449f16f58707c2e17674 Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Fri, 20 Mar 2026 14:47:36 -0700 Subject: [PATCH 08/10] feat(agent): add safety controls and output formatting for command execution --- pkg/cmd/agent/agent.go | 4 +-- pkg/cmd/agent/agent_test.go | 22 ------------- pkg/cmd/agent/command.go | 61 ++++++++++++++++++++++++++++++++++++- pkg/cmd/agent/output.go | 12 -------- pkg/iostreams/iostreams.go | 1 + 5 files changed, 62 insertions(+), 38 deletions(-) diff --git a/pkg/cmd/agent/agent.go b/pkg/cmd/agent/agent.go index dd703dc5..b1a9b7b8 100644 --- a/pkg/cmd/agent/agent.go +++ b/pkg/cmd/agent/agent.go @@ -209,7 +209,6 @@ func runAgent(opts *AgentOptions) error { cmdOutput, cmdErr := executeCommand(result.Command) msgCounter++ cmdOutput = compactJSON(cmdOutput) - cmdOutput = truncateOutput(cmdOutput, 10) outputText := fmt.Sprintf("Command `%s` was executed.\nOutput:\n\n%s\n", result.Command, cmdOutput) if cmdErr != nil { outputText += fmt.Sprintf("\nError: %s", cmdErr) @@ -276,7 +275,6 @@ func sendCompletion(opts *AgentOptions, conversationID string, messages []messag return parseSSEStream(resp.Body) } - func newConversationID() (string, error) { id, err := uuid.NewRandom() if err != nil { @@ -290,4 +288,4 @@ func envOrDefault(key, defaultVal string) string { return v } return defaultVal -} \ No newline at end of file +} diff --git a/pkg/cmd/agent/agent_test.go b/pkg/cmd/agent/agent_test.go index a61fd99d..d28661f0 100644 --- a/pkg/cmd/agent/agent_test.go +++ b/pkg/cmd/agent/agent_test.go @@ -207,28 +207,6 @@ func TestReplaceAlgoliaBinary(t *testing.T) { }) } -func TestTruncateOutput(t *testing.T) { - tests := []struct { - name string - input string - maxLines int - want string - }{ - {name: "short output unchanged", input: "line1\nline2", maxLines: 10, want: "line1\nline2"}, - {name: "truncates at limit", input: "1\n2\n3\n4\n5", maxLines: 3, want: "1\n2\n3\n[... 2 more lines truncated]"}, - {name: "single line", input: "hello", maxLines: 10, want: "hello"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := truncateOutput(tt.input, tt.maxLines) - if got != tt.want { - t.Errorf("truncateOutput() = %q, want %q", got, tt.want) - } - }) - } -} - func TestEnvOrDefault(t *testing.T) { tests := []struct { name string diff --git a/pkg/cmd/agent/command.go b/pkg/cmd/agent/command.go index cb0f8cf2..53ec4c32 100644 --- a/pkg/cmd/agent/command.go +++ b/pkg/cmd/agent/command.go @@ -41,6 +41,36 @@ var safeCommands = []string{ "indices analyze", } +// jsonOutputCommands lists command prefixes that support the -o json flag. +var jsonOutputCommands = []string{ + "profile list", + "application list", + "indices list", + "apikeys list", + "settings get", + "crawler list", + "crawler get", + "crawler stats", + "indices analyze", +} + +// forceJSONOutput appends -o json to the command if it supports it and doesn't already have it. +func forceJSONOutput(command string) string { + if strings.Contains(command, " -o ") || strings.Contains(command, " --output ") { + return command + } + sub := strings.TrimPrefix(command, "algolia ") + if sub == command { + return command + } + for _, prefix := range jsonOutputCommands { + if strings.HasPrefix(sub, prefix) { + return command + " -o json" + } + } + return command +} + // isSafeCommand checks if a command is read-only and can be auto-run. func isSafeCommand(command string) bool { // Strip the leading "algolia " to get the subcommand. @@ -70,11 +100,39 @@ func isBlockedCommand(command string) bool { return false } +// requiredFlags maps command prefixes to flags that must be present when run from the agent. +var requiredFlags = map[string][]string{ + "auth login": {"--no-browser", "--app-name"}, + "auth signup": {"--no-browser", "--app-name"}, + "application create": {"--region"}, +} + +// validateRequiredFlags checks that commands include required flags when run from the agent. +func validateRequiredFlags(command string) error { + sub := strings.TrimPrefix(command, "algolia ") + if sub == command { + return nil + } + for prefix, flags := range requiredFlags { + if strings.HasPrefix(sub, prefix) { + for _, flag := range flags { + if !strings.Contains(command, flag) { + return fmt.Errorf("%s requires %s when run from the agent", prefix, flag) + } + } + } + } + return nil +} + // validateCommand checks that a command string does not contain dangerous shell metacharacters. func validateCommand(command string) error { if isBlockedCommand(command) { return fmt.Errorf("command is not allowed from the agent") } + if err := validateRequiredFlags(command); err != nil { + return err + } for _, pattern := range []string{"&&", "||", ";", "$(", "`"} { if strings.Contains(command, pattern) { return fmt.Errorf("command contains disallowed shell operator: %s", pattern) @@ -93,11 +151,12 @@ func executeCommand(command string) (string, error) { if err := validateCommand(command); err != nil { return "", err } + command = forceJSONOutput(command) command = replaceAlgoliaBinary(command) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() cmd := exec.CommandContext(ctx, "sh", "-c", command) - cmd.Env = append(os.Environ(), "PAGER=cat") + cmd.Env = append(os.Environ(), "PAGER=cat", "ALGOLIA_NO_PROMPT=1") ptmx, err := pty.Start(cmd) if err != nil { diff --git a/pkg/cmd/agent/output.go b/pkg/cmd/agent/output.go index 922edcd8..f649690e 100644 --- a/pkg/cmd/agent/output.go +++ b/pkg/cmd/agent/output.go @@ -1,7 +1,6 @@ package agent import ( - "fmt" "strings" ) @@ -107,14 +106,3 @@ func compactJSON(s string) string { return strings.Join(result, "\n") } - -// truncateOutput limits the output to maxLines non-empty lines. -// If truncated, appends a note indicating how many lines were omitted. -func truncateOutput(s string, maxLines int) string { - lines := strings.Split(s, "\n") - if len(lines) <= maxLines { - return s - } - truncated := strings.Join(lines[:maxLines], "\n") - return truncated + fmt.Sprintf("\n[... %d more lines truncated]", len(lines)-maxLines) -} diff --git a/pkg/iostreams/iostreams.go b/pkg/iostreams/iostreams.go index 49f3a169..6de89f9c 100644 --- a/pkg/iostreams/iostreams.go +++ b/pkg/iostreams/iostreams.go @@ -397,6 +397,7 @@ func System() *IOStreams { is256enabled: assumeTrueColor || Is256ColorSupported(), hasTrueColor: assumeTrueColor || IsTrueColorSupported(), pagerCommand: os.Getenv("PAGER"), + neverPrompt: os.Getenv("ALGOLIA_NO_PROMPT") == "1", ttySize: ttySize, } From 4082e3f9265a91745f2dd89ae0e1bd15b0a989d7 Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Fri, 20 Mar 2026 15:05:00 -0700 Subject: [PATCH 09/10] fix(agent): remove commands without -o from jsonOutputCommands --- pkg/cmd/agent/command.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/cmd/agent/command.go b/pkg/cmd/agent/command.go index 53ec4c32..227d3b82 100644 --- a/pkg/cmd/agent/command.go +++ b/pkg/cmd/agent/command.go @@ -43,11 +43,9 @@ var safeCommands = []string{ // jsonOutputCommands lists command prefixes that support the -o json flag. var jsonOutputCommands = []string{ - "profile list", "application list", "indices list", "apikeys list", - "settings get", "crawler list", "crawler get", "crawler stats", From f16e47fa7df9d7addc86acb65be6ea5f765e07cc Mon Sep 17 00:00:00 2001 From: Lorris Saint-Genez Date: Mon, 23 Mar 2026 08:11:38 -0700 Subject: [PATCH 10/10] fix: use ANSI colors --- pkg/cmd/agent/render.go | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/pkg/cmd/agent/render.go b/pkg/cmd/agent/render.go index 5000ccff..4b995d52 100644 --- a/pkg/cmd/agent/render.go +++ b/pkg/cmd/agent/render.go @@ -14,14 +14,6 @@ import ( "github.com/algolia/cli/pkg/iostreams" ) -// algoliaBlue is the Algolia Blue brand color. -const algoliaBlue = "3970ff" - -// algoliaLightBlue is a lighter blue for placeholders. -const algoliaLightBlue = "00aeff" - -// algoliaTeal is the Java teal for flags. -const algoliaTeal = "1cc7d0" // renderMarkdown converts a markdown string into ANSI-styled terminal output // by parsing it with goldmark and walking the AST. @@ -98,10 +90,10 @@ func renderNode(out *strings.Builder, cs *iostreams.ColorScheme, n ast.Node, sou if idx := strings.Index(line, "#"); idx >= 0 { cmd := line[:idx] comment := line[idx:] - out.WriteString(cs.HexToRGB(algoliaBlue, cmd)) + out.WriteString(cs.Blue(cmd)) out.WriteString(cs.Green(comment)) } else { - out.WriteString(cs.HexToRGB(algoliaBlue, line)) + out.WriteString(cs.Blue(line)) } out.WriteString("\n") } @@ -158,18 +150,14 @@ func colorCodeSpan(cs *iostreams.ColorScheme, code string) string { last := 0 for _, match := range codeTokenRe.FindAllStringIndex(code, -1) { if match[0] > last { - out.WriteString(cs.HexToRGB(algoliaBlue, code[last:match[0]])) + out.WriteString(cs.Blue(code[last:match[0]])) } token := code[match[0]:match[1]] - if strings.HasPrefix(token, "<") || strings.HasPrefix(token, "[") { - out.WriteString(cs.HexToRGB(algoliaLightBlue, token)) - } else { - out.WriteString(cs.HexToRGB(algoliaTeal, token)) - } + out.WriteString(cs.Cyan(token)) last = match[1] } if last < len(code) { - out.WriteString(cs.HexToRGB(algoliaBlue, code[last:])) + out.WriteString(cs.Blue(code[last:])) } return out.String() }