diff --git a/cmd/root/flag_suggestions.go b/cmd/root/flag_suggestions.go new file mode 100644 index 0000000000..effef1fcca --- /dev/null +++ b/cmd/root/flag_suggestions.go @@ -0,0 +1,147 @@ +package root + +import ( + "errors" + "fmt" + "strings" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +const maxSuggestionDistance = 2 + +// levenshteinDistance computes the edit distance between two strings. +func levenshteinDistance(a, b string) int { + if len(a) == 0 { + return len(b) + } + if len(b) == 0 { + return len(a) + } + + // Use a single row for the DP table. + prev := make([]int, len(b)+1) + for j := range len(b) + 1 { + prev[j] = j + } + + for i := range len(a) { + curr := make([]int, len(b)+1) + curr[0] = i + 1 + for j := range len(b) { + cost := 1 + if a[i] == b[j] { + cost = 0 + } + curr[j+1] = min( + curr[j]+1, // insertion + prev[j+1]+1, // deletion + prev[j]+cost, // substitution + ) + } + prev = curr + } + + return prev[len(b)] +} + +// suggestFlagFromError inspects the error from Cobra for unknown-flag errors. +// If a close match is found among the command's flags, it returns an enhanced error +// with a "Did you mean" suggestion appended. Otherwise it returns the original error. +func suggestFlagFromError(cmd *cobra.Command, err error) error { + var notExist *pflag.NotExistError + if !errors.As(err, ¬Exist) { + return err + } + + flagName := notExist.GetSpecifiedName() + isShorthand := notExist.GetSpecifiedShortnames() != "" + + if isShorthand { + return suggestShorthandFlag(cmd, err, flagName) + } + + return suggestLongFlag(cmd, err, flagName) +} + +// suggestLongFlag suggests a matching long flag name for an unknown long flag error. +func suggestLongFlag(cmd *cobra.Command, original error, flagName string) error { + if flagName == "" { + return original + } + + best, bestDist := findClosestFlag(cmd, flagName) + if best == "" || bestDist > maxSuggestionDistance { + return original + } + + return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best) +} + +// suggestShorthandFlag suggests a matching shorthand for an unknown shorthand flag error. +func suggestShorthandFlag(cmd *cobra.Command, original error, flagName string) error { + if flagName == "" { + return original + } + ch := string(flagName[0]) + + best := findClosestShorthand(cmd, ch) + if best == "" { + return original + } + + return fmt.Errorf("%w\n\nDid you mean \"-%s\"?", original, best) +} + +// findClosestFlag returns the closest non-hidden, non-deprecated long flag name +// and its edit distance from the given misspelled name. +func findClosestFlag(cmd *cobra.Command, name string) (string, int) { + best := "" + bestDist := maxSuggestionDistance + 1 + + seen := map[string]bool{} + check := func(f *pflag.Flag) { + if f.Hidden || f.Deprecated != "" { + return + } + if seen[f.Name] { + return + } + seen[f.Name] = true + + d := levenshteinDistance(name, f.Name) + if d < bestDist { + bestDist = d + best = f.Name + } + } + + cmd.Flags().VisitAll(check) + cmd.InheritedFlags().VisitAll(check) + + return best, bestDist +} + +// findClosestShorthand returns a case-insensitive exact match for the given +// shorthand character. Levenshtein is not useful for single characters because +// any two distinct characters always have distance 1. +func findClosestShorthand(cmd *cobra.Command, ch string) string { + best := "" + seen := map[string]bool{} + check := func(f *pflag.Flag) { + if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" || f.Shorthand == "" { + return + } + if seen[f.Shorthand] { + return + } + seen[f.Shorthand] = true + if strings.EqualFold(ch, f.Shorthand) { + best = f.Shorthand + } + } + cmd.Flags().VisitAll(check) + cmd.InheritedFlags().VisitAll(check) + return best +} diff --git a/cmd/root/flag_suggestions_test.go b/cmd/root/flag_suggestions_test.go new file mode 100644 index 0000000000..6403e6dd30 --- /dev/null +++ b/cmd/root/flag_suggestions_test.go @@ -0,0 +1,216 @@ +package root + +import ( + "errors" + "fmt" + "testing" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// parseUnknownFlag triggers Cobra's flag parsing on args and returns the error. +// The command is set up with DisableFlagParsing=false (default) and a +// RunE that does nothing, so the only errors come from flag parsing. +func parseUnknownFlag(cmd *cobra.Command, args []string) error { + cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil } + cmd.SetArgs(args) + return cmd.Execute() +} + +func TestLevenshteinDistance(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"", "", 0}, + {"abc", "abc", 0}, + {"", "abc", 3}, + {"abc", "", 3}, + {"kitten", "sitting", 3}, + {"output", "outpu", 1}, // deletion + {"output", "ouptut", 2}, // transposition = 2 edits + {"output", "outpux", 1}, // substitution + {"output", "outputx", 1}, // insertion + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%s", tt.a, tt.b), func(t *testing.T) { + assert.Equal(t, tt.want, levenshteinDistance(tt.a, tt.b)) + }) + } +} + +func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + err := &pflag.NotExistError{} + // Parse "--outpu" to get a real error + parseErr := parseUnknownFlag(cmd, []string{"--outpu"}) + require.Error(t, parseErr) + + // Extract the pflag error from the cobra wrapping + require.ErrorAs(t, parseErr, &err) + + got := suggestFlagFromError(cmd, parseErr) + assert.Contains(t, got.Error(), `Did you mean "--output"?`) + assert.Contains(t, got.Error(), "unknown flag: --outpu") +} + +func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + parseErr := parseUnknownFlag(cmd, []string{"--zzzzzzz"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + + parseErr := parseUnknownFlag(cmd, []string{"-O"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.Contains(t, got.Error(), `Did you mean "-o"?`) +} + +func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("secret", "", "secret flag") + require.NoError(t, cmd.Flags().MarkHidden("secret")) + + parseErr := parseUnknownFlag(cmd, []string{"--secre"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("legacy", "", "old flag") + require.NoError(t, cmd.Flags().MarkDeprecated("legacy", "use --new instead")) + + parseErr := parseUnknownFlag(cmd, []string{"--legac"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_InheritedFlags(t *testing.T) { + parent := &cobra.Command{Use: "parent"} + parent.PersistentFlags().String("profile", "", "auth profile") + + child := &cobra.Command{Use: "child"} + child.RunE = func(cmd *cobra.Command, args []string) error { return nil } + parent.AddCommand(child) + + parent.SetArgs([]string{"child", "--profil"}) + parseErr := parent.Execute() + require.Error(t, parseErr) + + got := suggestFlagFromError(child, parseErr) + assert.Contains(t, got.Error(), `Did you mean "--profile"?`) +} + +func TestSuggestFlagFromError_NonFlagError(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + + err := errors.New("some other error") + got := suggestFlagFromError(cmd, err) + assert.Equal(t, err.Error(), got.Error()) +} + +func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) { + parent := &cobra.Command{Use: "parent"} + parent.PersistentFlags().String("target", "", "deployment target") + + child := &cobra.Command{Use: "child"} + child.Flags().String("target", "", "deployment target") + child.RunE = func(cmd *cobra.Command, args []string) error { return nil } + parent.AddCommand(child) + + parent.SetArgs([]string{"child", "--targe"}) + parseErr := parent.Execute() + require.Error(t, parseErr) + + got := suggestFlagFromError(child, parseErr) + assert.Contains(t, got.Error(), `Did you mean "--target"?`) +} + +func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + + parseErr := parseUnknownFlag(cmd, []string{"-z"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_ShorthandDeprecatedStillSuggestsLongFlag(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead")) + + parseErr := parseUnknownFlag(cmd, []string{"--outpu"}) + require.Error(t, parseErr) + + // The long flag should still be suggested even though the shorthand is deprecated. + got := suggestFlagFromError(cmd, parseErr) + assert.Contains(t, got.Error(), `Did you mean "--output"?`) +} + +func TestSuggestFlagFromError_ShorthandDeprecatedExcludedFromShorthandSuggestions(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().StringP("output", "o", "", "output format") + require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead")) + + parseErr := parseUnknownFlag(cmd, []string{"-O"}) + require.Error(t, parseErr) + + // The deprecated shorthand should NOT be suggested. + got := suggestFlagFromError(cmd, parseErr) + assert.NotContains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_TieBreakingEquidistantFlags(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + // "ab" and "ac" are both distance 1 from "aa" + cmd.Flags().String("ab", "", "") + cmd.Flags().String("ac", "", "") + + parseErr := parseUnknownFlag(cmd, []string{"--aa"}) + require.Error(t, parseErr) + + got := suggestFlagFromError(cmd, parseErr) + // Both are equidistant; we accept whichever is returned (order depends on + // flag iteration) but a suggestion must be present. + assert.Contains(t, got.Error(), "Did you mean") +} + +func TestSuggestFlagFromError_IntegrationThroughFlagErrorFunc(t *testing.T) { + cmd := &cobra.Command{Use: "test"} + cmd.Flags().String("output", "", "output format") + cmd.SetFlagErrorFunc(flagErrorFunc) + cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil } + cmd.SetArgs([]string{"--outpu"}) + + err := cmd.Execute() + require.Error(t, err) + + assert.Contains(t, err.Error(), `Did you mean "--output"?`) + // flagErrorFunc also appends usage + assert.Contains(t, err.Error(), "Usage:") +} diff --git a/cmd/root/root.go b/cmd/root/root.go index ca47a23d3a..f10768f21c 100644 --- a/cmd/root/root.go +++ b/cmd/root/root.go @@ -97,8 +97,10 @@ func New(ctx context.Context) *cobra.Command { return cmd } -// Wrap flag errors to include the usage string. +// flagErrorFunc wraps flag errors to include the usage string and, for unknown +// flags, a "Did you mean" suggestion based on Levenshtein distance. func flagErrorFunc(c *cobra.Command, err error) error { + err = suggestFlagFromError(c, err) return fmt.Errorf("%w\n\n%s", err, c.UsageString()) }