From d73e25c3c86496982eddfe1c446c2e74af6080ca Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 00:22:43 +0100 Subject: [PATCH 1/4] Add flag name suggestions for misspelled flags When users misspell a flag (e.g., --outpu instead of --output), they now get a "Did you mean" suggestion. This extends Cobra's existing command suggestion behavior to flags using Levenshtein distance with a threshold of 2. Hidden and deprecated flags are excluded from suggestions. Both long flags (--flagname) and shorthand flags (-x) are handled. Co-authored-by: Isaac --- cmd/root/flag_suggestions.go | 160 ++++++++++++++++++++++++++++++ cmd/root/flag_suggestions_test.go | 152 ++++++++++++++++++++++++++++ cmd/root/root.go | 4 +- 3 files changed, 315 insertions(+), 1 deletion(-) create mode 100644 cmd/root/flag_suggestions.go create mode 100644 cmd/root/flag_suggestions_test.go diff --git a/cmd/root/flag_suggestions.go b/cmd/root/flag_suggestions.go new file mode 100644 index 0000000000..ef5708d946 --- /dev/null +++ b/cmd/root/flag_suggestions.go @@ -0,0 +1,160 @@ +package root + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +const ( + unknownFlagPrefix = "unknown flag: " + unknownShorthandFlagPrefix = "unknown shorthand flag: " + maxSuggestionDistance = 2 +) + +// levenshteinDistance computes the edit distance between two strings. +func levenshteinDistance(a, b string) int { + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + + // Use a single row for the DP table. + prev := make([]int, len(b)+1) + for j := range len(b) + 1 { + prev[j] = j + } + + for i := range len(a) { + curr := make([]int, len(b)+1) + curr[0] = i + 1 + for j := range len(b) { + cost := 1 + if a[i] == b[j] { + cost = 0 + } + curr[j+1] = min( + curr[j]+1, // insertion + prev[j+1]+1, // deletion + prev[j]+cost, // substitution + ) + } + prev = curr + } + + return prev[len(b)] +} + +// suggestFlagFromError inspects the error message from Cobra for "unknown flag" patterns. +// If a close match is found among the command's flags, it returns an enhanced error +// with a "Did you mean" suggestion appended. Otherwise it returns the original error. +func suggestFlagFromError(cmd *cobra.Command, err error) error { + msg := err.Error() + + if strings.HasPrefix(msg, unknownShorthandFlagPrefix) { + return suggestShorthandFlag(cmd, err, msg) + } + + if strings.HasPrefix(msg, unknownFlagPrefix) { + return suggestLongFlag(cmd, err, msg) + } + + return err +} + +// suggestLongFlag suggests a matching long flag name for an "unknown flag: --xyz" error. +func suggestLongFlag(cmd *cobra.Command, original error, msg string) error { + // Extract the flag name: "unknown flag: --flagname" -> "flagname" + flagName := strings.TrimPrefix(msg, unknownFlagPrefix) + flagName = strings.TrimLeft(flagName, "-") + if flagName == "" { + return original + } + + best, bestDist := findClosestFlag(cmd, flagName) + if best == "" || bestDist > maxSuggestionDistance { + return original + } + + return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best) +} + +// suggestShorthandFlag suggests a matching shorthand for an +// "unknown shorthand flag: 'x' in -x" error. +func suggestShorthandFlag(cmd *cobra.Command, original error, msg string) error { + // Extract the shorthand character: "unknown shorthand flag: 'x' in -x" + rest := strings.TrimPrefix(msg, unknownShorthandFlagPrefix) + if len(rest) < 3 || rest[0] != '\'' || rest[2] != '\'' { + return original + } + ch := string(rest[1]) + + best := findClosestShorthand(cmd, ch) + if best == "" { + return original + } + + return fmt.Errorf("%w\n\nDid you mean \"-%s\"?", original, best) +} + +// findClosestFlag returns the closest non-hidden, non-deprecated long flag name +// and its edit distance from the given misspelled name. +func findClosestFlag(cmd *cobra.Command, name string) (string, int) { + best := "" + bestDist := maxSuggestionDistance + 1 + + seen := map[string]bool{} + check := func(f *pflag.Flag) { + if f.Hidden || f.Deprecated != "" { + return + } + if seen[f.Name] { + return + } + seen[f.Name] = true + + d := levenshteinDistance(name, f.Name) + if d < bestDist { + bestDist = d + best = f.Name + } + } + + cmd.Flags().VisitAll(check) + cmd.InheritedFlags().VisitAll(check) + + return best, bestDist +} + +// findClosestShorthand returns the closest non-hidden, non-deprecated shorthand +// that differs by at most 1 edit from the given character. +func findClosestShorthand(cmd *cobra.Command, ch string) string { + best := "" + bestDist := maxSuggestionDistance + 1 + + seen := map[string]bool{} + check := func(f *pflag.Flag) { + if f.Hidden || f.Deprecated != "" || f.Shorthand == "" { + return + } + if seen[f.Shorthand] { + return + } + seen[f.Shorthand] = true + + d := levenshteinDistance(ch, f.Shorthand) + if d < bestDist { + bestDist = d + best = f.Shorthand + } + } + + cmd.Flags().VisitAll(check) + cmd.InheritedFlags().VisitAll(check) + + return best +} diff --git a/cmd/root/flag_suggestions_test.go b/cmd/root/flag_suggestions_test.go new file mode 100644 index 0000000000..1728f60d9e --- /dev/null +++ b/cmd/root/flag_suggestions_test.go @@ -0,0 +1,152 @@ +package root + +import ( + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func TestLevenshteinDistance(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"", "", 0}, + {"abc", "abc", 0}, + {"", "abc", 3}, + {"abc", "", 3}, + {"kitten", "sitting", 3}, + {"output", "outpu", 1}, // deletion + {"output", "ouptut", 2}, // transposition = 2 edits + {"output", "outpux", 1}, // substitution + {"output", "outputx", 1}, // insertion + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%s", tt.a, tt.b), func(t *testing.T) { + assert.Equal(t, tt.want, levenshteinDistance(tt.a, tt.b)) + }) + } +} + +func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + err := fmt.Errorf("unknown flag: --outpu") + got := suggestFlagFromError(cmd, err) + assert.Contains(t, got.Error(), `Did you mean "--output"?`) + assert.Contains(t, got.Error(), "unknown flag: --outpu") +} + +func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + err := fmt.Errorf("unknown flag: --zzzzzzz") + got := suggestFlagFromError(cmd, err) + assert.Equal(t, err.Error(), got.Error()) +} + +func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + + err := fmt.Errorf("unknown shorthand flag: 'O' in -O") + got := suggestFlagFromError(cmd, err) + assert.Contains(t, got.Error(), `Did you mean "-o"?`) +} + +func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("secret", "", "secret flag") + _ = cmd.Flags().MarkHidden("secret") + + err := fmt.Errorf("unknown flag: --secre") + got := suggestFlagFromError(cmd, err) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("legacy", "", "old flag") + _ = cmd.Flags().MarkDeprecated("legacy", "use --new instead") + + err := fmt.Errorf("unknown flag: --legac") + got := suggestFlagFromError(cmd, err) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_InheritedFlags(t *testing.T) { + parent := &cobra.Command{Use: "parent"} + parent.PersistentFlags().String("profile", "", "auth profile") + + child := &cobra.Command{Use: "child"} + parent.AddCommand(child) + + err := fmt.Errorf("unknown flag: --profil") + got := suggestFlagFromError(child, err) + assert.Contains(t, got.Error(), `Did you mean "--profile"?`) +} + +func TestSuggestFlagFromError_NonFlagError(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + err := fmt.Errorf("flag needs an argument: --output") + got := suggestFlagFromError(cmd, err) + assert.Equal(t, err.Error(), got.Error()) +} + +func TestSuggestFlagFromError_CobraErrorFormats(t *testing.T) { + tests := []struct { + name string + errMsg string + flags map[string]string + contains string + }{ + { + name: "long flag with double dash", + errMsg: "unknown flag: --outpu", + flags: map[string]string{"output": ""}, + contains: `"--output"`, + }, + { + name: "shorthand with quote format", + errMsg: "unknown shorthand flag: 'x' in -x", + flags: map[string]string{}, + contains: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + for name, usage := range tt.flags { + cmd.Flags().String(name, "", usage) + } + err := fmt.Errorf("%s", tt.errMsg) + got := suggestFlagFromError(cmd, err) + if tt.contains != "" { + assert.Contains(t, got.Error(), tt.contains) + } + }) + } +} + +func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { + parent := &cobra.Command{Use: "parent"} + parent.PersistentFlags().String("target", "", "deployment target") + + child := &cobra.Command{Use: "child"} + child.Flags().String("target", "", "deployment target") + parent.AddCommand(child) + + err := fmt.Errorf("unknown flag: --targe") + got := suggestFlagFromError(child, err) + + // Should suggest once, not panic or produce duplicate suggestions. + assert.Contains(t, got.Error(), `Did you mean "--target"?`) +} diff --git a/cmd/root/root.go b/cmd/root/root.go index ca47a23d3a..f10768f21c 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -97,8 +97,10 @@ func New(ctx context.Context) *cobra.Command { return cmd } -// Wrap flag errors to include the usage string. +// flagErrorFunc wraps flag errors to include the usage string and, for unknown +// flags, a "Did you mean" suggestion based on Levenshtein distance. func flagErrorFunc(c *cobra.Command, err error) error { + err = suggestFlagFromError(c, err) return fmt.Errorf("%w\n\n%s", err, c.UsageString()) } From 1db66eab48a5d85876c6b7f549867610430cc64c Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 03:57:15 +0100 Subject: [PATCH 2/4] Fix review findings: shorthand suggestions, tests, TrimPrefix Co-authored-by: Isaac --- cmd/root/flag_suggestions.go | 20 +++++++------------- cmd/root/flag_suggestions_test.go | 29 +++++++++++++++++++++++------ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/cmd/root/flag_suggestions.go b/cmd/root/flag_suggestions.go index ef5708d946..f531c35fd1 100644 --- a/cmd/root/flag_suggestions.go +++ b/cmd/root/flag_suggestions.go @@ -70,7 +70,7 @@ func suggestFlagFromError(cmd *cobra.Command, err error) error { func suggestLongFlag(cmd *cobra.Command, original error, msg string) error { // Extract the flag name: "unknown flag: --flagname" -> "flagname" flagName := strings.TrimPrefix(msg, unknownFlagPrefix) - flagName = strings.TrimLeft(flagName, "-") + flagName = strings.TrimPrefix(flagName, "--") if flagName == "" { return original } @@ -109,7 +109,7 @@ func findClosestFlag(cmd *cobra.Command, name string) (string, int) { seen := map[string]bool{} check := func(f *pflag.Flag) { - if f.Hidden || f.Deprecated != "" { + if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" { return } if seen[f.Name] { @@ -130,31 +130,25 @@ func findClosestFlag(cmd *cobra.Command, name string) (string, int) { return best, bestDist } -// findClosestShorthand returns the closest non-hidden, non-deprecated shorthand -// that differs by at most 1 edit from the given character. +// findClosestShorthand returns a case-insensitive exact match for the given +// shorthand character. Levenshtein is not useful for single characters because +// any two distinct characters always have distance 1. func findClosestShorthand(cmd *cobra.Command, ch string) string { best := "" - bestDist := maxSuggestionDistance + 1 - seen := map[string]bool{} check := func(f *pflag.Flag) { - if f.Hidden || f.Deprecated != "" || f.Shorthand == "" { + if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" || f.Shorthand == "" { return } if seen[f.Shorthand] { return } seen[f.Shorthand] = true - - d := levenshteinDistance(ch, f.Shorthand) - if d < bestDist { - bestDist = d + if strings.EqualFold(ch, f.Shorthand) { best = f.Shorthand } } - cmd.Flags().VisitAll(check) cmd.InheritedFlags().VisitAll(check) - return best } diff --git a/cmd/root/flag_suggestions_test.go b/cmd/root/flag_suggestions_test.go index 1728f60d9e..1afacb0bc3 100644 --- a/cmd/root/flag_suggestions_test.go +++ b/cmd/root/flag_suggestions_test.go @@ -1,6 +1,7 @@ package root import ( + "errors" "fmt" "testing" @@ -114,10 +115,10 @@ func TestSuggestFlagFromError_CobraErrorFormats(t *testing.T) { contains: `"--output"`, }, { - name: "shorthand with quote format", + name: "shorthand with no matching flags", errMsg: "unknown shorthand flag: 'x' in -x", flags: map[string]string{}, - contains: "", + contains: "unknown shorthand flag: 'x' in -x", }, } @@ -127,11 +128,9 @@ func TestSuggestFlagFromError_CobraErrorFormats(t *testing.T) { for name, usage := range tt.flags { cmd.Flags().String(name, "", usage) } - err := fmt.Errorf("%s", tt.errMsg) + err := errors.New(tt.errMsg) got := suggestFlagFromError(cmd, err) - if tt.contains != "" { - assert.Contains(t, got.Error(), tt.contains) - } + assert.Contains(t, got.Error(), tt.contains) }) } } @@ -150,3 +149,21 @@ func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { // Should suggest once, not panic or produce duplicate suggestions. assert.Contains(t, got.Error(), `Did you mean "--target"?`) } + +func TestSuggestFlagFromError_EmptyFlagName(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + err := fmt.Errorf("unknown flag: --") + got := suggestFlagFromError(cmd, err) + assert.Equal(t, err.Error(), got.Error()) +} + +func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + + err := fmt.Errorf("unknown shorthand flag: 'z' in -z") + got := suggestFlagFromError(cmd, err) + assert.NotContains(t, got.Error(), "Did you mean") + assert.Equal(t, err.Error(), got.Error()) +} From 27d25172c1cbc086fe2054668be6d20aedb87b79 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 06:49:46 +0100 Subject: [PATCH 3/4] Fix gofmt formatting and lint issues Co-authored-by: Isaac --- cmd/root/flag_suggestions.go | 4 ++-- cmd/root/flag_suggestions_test.go | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cmd/root/flag_suggestions.go b/cmd/root/flag_suggestions.go index f531c35fd1..ed54237b1a 100644 --- a/cmd/root/flag_suggestions.go +++ b/cmd/root/flag_suggestions.go @@ -11,7 +11,7 @@ import ( const ( unknownFlagPrefix = "unknown flag: " unknownShorthandFlagPrefix = "unknown shorthand flag: " - maxSuggestionDistance = 2 + maxSuggestionDistance = 2 ) // levenshteinDistance computes the edit distance between two strings. @@ -38,7 +38,7 @@ func levenshteinDistance(a, b string) int { cost = 0 } curr[j+1] = min( - curr[j]+1, // insertion + curr[j]+1, // insertion prev[j+1]+1, // deletion prev[j]+cost, // substitution ) diff --git a/cmd/root/flag_suggestions_test.go b/cmd/root/flag_suggestions_test.go index 1afacb0bc3..6cf21fa5a1 100644 --- a/cmd/root/flag_suggestions_test.go +++ b/cmd/root/flag_suggestions_test.go @@ -36,7 +36,7 @@ func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := fmt.Errorf("unknown flag: --outpu") + err := errors.New("unknown flag: --outpu") got := suggestFlagFromError(cmd, err) assert.Contains(t, got.Error(), `Did you mean "--output"?`) assert.Contains(t, got.Error(), "unknown flag: --outpu") @@ -46,7 +46,7 @@ func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := fmt.Errorf("unknown flag: --zzzzzzz") + err := errors.New("unknown flag: --zzzzzzz") got := suggestFlagFromError(cmd, err) assert.Equal(t, err.Error(), got.Error()) } @@ -55,7 +55,7 @@ func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().StringP("output", "o", "", "output format") - err := fmt.Errorf("unknown shorthand flag: 'O' in -O") + err := errors.New("unknown shorthand flag: 'O' in -O") got := suggestFlagFromError(cmd, err) assert.Contains(t, got.Error(), `Did you mean "-o"?`) } @@ -65,7 +65,7 @@ func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) { cmd.Flags().String("secret", "", "secret flag") _ = cmd.Flags().MarkHidden("secret") - err := fmt.Errorf("unknown flag: --secre") + err := errors.New("unknown flag: --secre") got := suggestFlagFromError(cmd, err) assert.NotContains(t, got.Error(), "Did you mean") } @@ -75,7 +75,7 @@ func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) { cmd.Flags().String("legacy", "", "old flag") _ = cmd.Flags().MarkDeprecated("legacy", "use --new instead") - err := fmt.Errorf("unknown flag: --legac") + err := errors.New("unknown flag: --legac") got := suggestFlagFromError(cmd, err) assert.NotContains(t, got.Error(), "Did you mean") } @@ -87,7 +87,7 @@ func TestSuggestFlagFromError_InheritedFlags(t *testing.T) { child := &cobra.Command{Use: "child"} parent.AddCommand(child) - err := fmt.Errorf("unknown flag: --profil") + err := errors.New("unknown flag: --profil") got := suggestFlagFromError(child, err) assert.Contains(t, got.Error(), `Did you mean "--profile"?`) } @@ -96,7 +96,7 @@ func TestSuggestFlagFromError_NonFlagError(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := fmt.Errorf("flag needs an argument: --output") + err := errors.New("flag needs an argument: --output") got := suggestFlagFromError(cmd, err) assert.Equal(t, err.Error(), got.Error()) } @@ -143,7 +143,7 @@ func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { child.Flags().String("target", "", "deployment target") parent.AddCommand(child) - err := fmt.Errorf("unknown flag: --targe") + err := errors.New("unknown flag: --targe") got := suggestFlagFromError(child, err) // Should suggest once, not panic or produce duplicate suggestions. @@ -153,7 +153,7 @@ func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { func TestSuggestFlagFromError_EmptyFlagName(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := fmt.Errorf("unknown flag: --") + err := errors.New("unknown flag: --") got := suggestFlagFromError(cmd, err) assert.Equal(t, err.Error(), got.Error()) } @@ -162,7 +162,7 @@ func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().StringP("output", "o", "", "output format") - err := fmt.Errorf("unknown shorthand flag: 'z' in -z") + err := errors.New("unknown shorthand flag: 'z' in -z") got := suggestFlagFromError(cmd, err) assert.NotContains(t, got.Error(), "Did you mean") assert.Equal(t, err.Error(), got.Error()) From a0ef7d4c37641ddacfde77a37c6af91c2a79fe04 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 15:40:50 +0100 Subject: [PATCH 4/4] Use pflag.NotExistError for typed error matching, fix ShorthandDeprecated filtering, add tests Address PR review comments: - Replace string parsing (HasPrefix on error messages) with errors.As matching on pflag.NotExistError, using GetSpecifiedName() and GetSpecifiedShortnames() to extract flag info. - Fix findClosestFlag to no longer exclude flags with ShorthandDeprecated from long-flag suggestions. Only exclude deprecated shorthands from shorthand suggestions (in findClosestShorthand). - Add tests: integration through flagErrorFunc, ShorthandDeprecated filtering for both long and short suggestions, tie-breaking for equidistant flags. - Rewrite existing tests to use real Cobra flag parsing instead of hand-crafted error strings. --- cmd/root/flag_suggestions.go | 45 ++++---- cmd/root/flag_suggestions_test.go | 171 +++++++++++++++++++----------- 2 files changed, 128 insertions(+), 88 deletions(-) diff --git a/cmd/root/flag_suggestions.go b/cmd/root/flag_suggestions.go index ed54237b1a..effef1fcca 100644 --- a/cmd/root/flag_suggestions.go +++ b/cmd/root/flag_suggestions.go @@ -1,6 +1,7 @@ package root import ( + "errors" "fmt" "strings" @@ -8,11 +9,7 @@ import ( "github.com/spf13/pflag" ) -const ( - unknownFlagPrefix = "unknown flag: " - unknownShorthandFlagPrefix = "unknown shorthand flag: " - maxSuggestionDistance = 2 -) +const maxSuggestionDistance = 2 // levenshteinDistance computes the edit distance between two strings. func levenshteinDistance(a, b string) int { @@ -49,28 +46,27 @@ func levenshteinDistance(a, b string) int { return prev[len(b)] } -// suggestFlagFromError inspects the error message from Cobra for "unknown flag" patterns. +// suggestFlagFromError inspects the error from Cobra for unknown-flag errors. // If a close match is found among the command's flags, it returns an enhanced error // with a "Did you mean" suggestion appended. Otherwise it returns the original error. func suggestFlagFromError(cmd *cobra.Command, err error) error { - msg := err.Error() - - if strings.HasPrefix(msg, unknownShorthandFlagPrefix) { - return suggestShorthandFlag(cmd, err, msg) + var notExist *pflag.NotExistError + if !errors.As(err, ¬Exist) { + return err } - if strings.HasPrefix(msg, unknownFlagPrefix) { - return suggestLongFlag(cmd, err, msg) + flagName := notExist.GetSpecifiedName() + isShorthand := notExist.GetSpecifiedShortnames() != "" + + if isShorthand { + return suggestShorthandFlag(cmd, err, flagName) } - return err + return suggestLongFlag(cmd, err, flagName) } -// suggestLongFlag suggests a matching long flag name for an "unknown flag: --xyz" error. -func suggestLongFlag(cmd *cobra.Command, original error, msg string) error { - // Extract the flag name: "unknown flag: --flagname" -> "flagname" - flagName := strings.TrimPrefix(msg, unknownFlagPrefix) - flagName = strings.TrimPrefix(flagName, "--") +// suggestLongFlag suggests a matching long flag name for an unknown long flag error. +func suggestLongFlag(cmd *cobra.Command, original error, flagName string) error { if flagName == "" { return original } @@ -83,15 +79,12 @@ func suggestLongFlag(cmd *cobra.Command, original error, msg string) error { return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best) } -// suggestShorthandFlag suggests a matching shorthand for an -// "unknown shorthand flag: 'x' in -x" error. -func suggestShorthandFlag(cmd *cobra.Command, original error, msg string) error { - // Extract the shorthand character: "unknown shorthand flag: 'x' in -x" - rest := strings.TrimPrefix(msg, unknownShorthandFlagPrefix) - if len(rest) < 3 || rest[0] != '\'' || rest[2] != '\'' { +// suggestShorthandFlag suggests a matching shorthand for an unknown shorthand flag error. +func suggestShorthandFlag(cmd *cobra.Command, original error, flagName string) error { + if flagName == "" { return original } - ch := string(rest[1]) + ch := string(flagName[0]) best := findClosestShorthand(cmd, ch) if best == "" { @@ -109,7 +102,7 @@ func findClosestFlag(cmd *cobra.Command, name string) (string, int) { seen := map[string]bool{} check := func(f *pflag.Flag) { - if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" { + if f.Hidden || f.Deprecated != "" { return } if seen[f.Name] { diff --git a/cmd/root/flag_suggestions_test.go b/cmd/root/flag_suggestions_test.go index 6cf21fa5a1..6403e6dd30 100644 --- a/cmd/root/flag_suggestions_test.go +++ b/cmd/root/flag_suggestions_test.go @@ -6,9 +6,20 @@ import ( "testing" "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// parseUnknownFlag triggers Cobra's flag parsing on args and returns the error. +// The command is set up with DisableFlagParsing=false (default) and a +// RunE that does nothing, so the only errors come from flag parsing. +func parseUnknownFlag(cmd *cobra.Command, args []string) error { + cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil } + cmd.SetArgs(args) + return cmd.Execute() +} + func TestLevenshteinDistance(t *testing.T) { tests := []struct { a, b string @@ -36,8 +47,15 @@ func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := errors.New("unknown flag: --outpu") - got := suggestFlagFromError(cmd, err) + err := &pflag.NotExistError{} + // Parse "--outpu" to get a real error + parseErr := parseUnknownFlag(cmd, []string{"--outpu"}) + require.Error(t, parseErr) + + // Extract the pflag error from the cobra wrapping + require.ErrorAs(t, parseErr, &err) + + got := suggestFlagFromError(cmd, parseErr) assert.Contains(t, got.Error(), `Did you mean "--output"?`) assert.Contains(t, got.Error(), "unknown flag: --outpu") } @@ -46,37 +64,45 @@ func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := errors.New("unknown flag: --zzzzzzz") - got := suggestFlagFromError(cmd, err) - assert.Equal(t, err.Error(), got.Error()) + parseErr := parseUnknownFlag(cmd, []string{"--zzzzzzz"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") } func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().StringP("output", "o", "", "output format") - err := errors.New("unknown shorthand flag: 'O' in -O") - got := suggestFlagFromError(cmd, err) + parseErr := parseUnknownFlag(cmd, []string{"-O"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) assert.Contains(t, got.Error(), `Did you mean "-o"?`) } func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("secret", "", "secret flag") - _ = cmd.Flags().MarkHidden("secret") + require.NoError(t, cmd.Flags().MarkHidden("secret")) - err := errors.New("unknown flag: --secre") - got := suggestFlagFromError(cmd, err) + parseErr := parseUnknownFlag(cmd, []string{"--secre"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) assert.NotContains(t, got.Error(), "Did you mean") } func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("legacy", "", "old flag") - _ = cmd.Flags().MarkDeprecated("legacy", "use --new instead") + require.NoError(t, cmd.Flags().MarkDeprecated("legacy", "use --new instead")) - err := errors.New("unknown flag: --legac") - got := suggestFlagFromError(cmd, err) + parseErr := parseUnknownFlag(cmd, []string{"--legac"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) assert.NotContains(t, got.Error(), "Did you mean") } @@ -85,10 +111,14 @@ func TestSuggestFlagFromError_InheritedFlags(t *testing.T) { parent.PersistentFlags().String("profile", "", "auth profile") child := &cobra.Command{Use: "child"} + child.RunE = func(cmd *cobra.Command, args []string) error { return nil } parent.AddCommand(child) - err := errors.New("unknown flag: --profil") - got := suggestFlagFromError(child, err) + parent.SetArgs([]string{"child", "--profil"}) + parseErr := parent.Execute() + require.Error(t, parseErr) + + got := suggestFlagFromError(child, parseErr) assert.Contains(t, got.Error(), `Did you mean "--profile"?`) } @@ -96,74 +126,91 @@ func TestSuggestFlagFromError_NonFlagError(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().String("output", "", "output format") - err := errors.New("flag needs an argument: --output") + err := errors.New("some other error") got := suggestFlagFromError(cmd, err) assert.Equal(t, err.Error(), got.Error()) } -func TestSuggestFlagFromError_CobraErrorFormats(t *testing.T) { - tests := []struct { - name string - errMsg string - flags map[string]string - contains string - }{ - { - name: "long flag with double dash", - errMsg: "unknown flag: --outpu", - flags: map[string]string{"output": ""}, - contains: `"--output"`, - }, - { - name: "shorthand with no matching flags", - errMsg: "unknown shorthand flag: 'x' in -x", - flags: map[string]string{}, - contains: "unknown shorthand flag: 'x' in -x", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &cobra.Command{Use: "test"} - for name, usage := range tt.flags { - cmd.Flags().String(name, "", usage) - } - err := errors.New(tt.errMsg) - got := suggestFlagFromError(cmd, err) - assert.Contains(t, got.Error(), tt.contains) - }) - } -} - func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { parent := &cobra.Command{Use: "parent"} parent.PersistentFlags().String("target", "", "deployment target") child := &cobra.Command{Use: "child"} child.Flags().String("target", "", "deployment target") + child.RunE = func(cmd *cobra.Command, args []string) error { return nil } parent.AddCommand(child) - err := errors.New("unknown flag: --targe") - got := suggestFlagFromError(child, err) + parent.SetArgs([]string{"child", "--targe"}) + parseErr := parent.Execute() + require.Error(t, parseErr) - // Should suggest once, not panic or produce duplicate suggestions. + got := suggestFlagFromError(child, parseErr) assert.Contains(t, got.Error(), `Did you mean "--target"?`) } -func TestSuggestFlagFromError_EmptyFlagName(t *testing.T) { +func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) { cmd := &cobra.Command{Use: "test"} - cmd.Flags().String("output", "", "output format") - err := errors.New("unknown flag: --") - got := suggestFlagFromError(cmd, err) - assert.Equal(t, err.Error(), got.Error()) + cmd.Flags().StringP("output", "o", "", "output format") + + parseErr := parseUnknownFlag(cmd, []string{"-z"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") } -func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) { +func TestSuggestFlagFromError_ShorthandDeprecatedStillSuggestsLongFlag(t *testing.T) { cmd := &cobra.Command{Use: "test"} cmd.Flags().StringP("output", "o", "", "output format") + require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead")) - err := errors.New("unknown shorthand flag: 'z' in -z") - got := suggestFlagFromError(cmd, err) + parseErr := parseUnknownFlag(cmd, []string{"--outpu"}) + require.Error(t, parseErr) + + // The long flag should still be suggested even though the shorthand is deprecated. + got := suggestFlagFromError(cmd, parseErr) + assert.Contains(t, got.Error(), `Did you mean "--output"?`) +} + +func TestSuggestFlagFromError_ShorthandDeprecatedExcludedFromShorthandSuggestions(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead")) + + parseErr := parseUnknownFlag(cmd, []string{"-O"}) + require.Error(t, parseErr) + + // The deprecated shorthand should NOT be suggested. + got := suggestFlagFromError(cmd, parseErr) assert.NotContains(t, got.Error(), "Did you mean") - assert.Equal(t, err.Error(), got.Error()) +} + +func TestSuggestFlagFromError_TieBreakingEquidistantFlags(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + // "ab" and "ac" are both distance 1 from "aa" + cmd.Flags().String("ab", "", "") + cmd.Flags().String("ac", "", "") + + parseErr := parseUnknownFlag(cmd, []string{"--aa"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + // Both are equidistant; we accept whichever is returned (order depends on + // flag iteration) but a suggestion must be present. + assert.Contains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_IntegrationThroughFlagErrorFunc(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + cmd.SetFlagErrorFunc(flagErrorFunc) + cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil } + cmd.SetArgs([]string{"--outpu"}) + + err := cmd.Execute() + require.Error(t, err) + + assert.Contains(t, err.Error(), `Did you mean "--output"?`) + // flagErrorFunc also appends usage + assert.Contains(t, err.Error(), "Usage:") }