Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions cmd/root/flag_suggestions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package root

import (
"errors"
"fmt"
"strings"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

const maxSuggestionDistance = 2

// levenshteinDistance computes the edit distance between two strings.
func levenshteinDistance(a, b string) int {
if len(a) == 0 {
return len(b)
}
if len(b) == 0 {
return len(a)
}

// Use a single row for the DP table.
prev := make([]int, len(b)+1)
for j := range len(b) + 1 {
prev[j] = j
}

for i := range len(a) {
curr := make([]int, len(b)+1)
curr[0] = i + 1
for j := range len(b) {
cost := 1
if a[i] == b[j] {
cost = 0
}
curr[j+1] = min(
curr[j]+1, // insertion
prev[j+1]+1, // deletion
prev[j]+cost, // substitution
)
}
prev = curr
}

return prev[len(b)]
}

// suggestFlagFromError inspects the error from Cobra for unknown-flag errors.
// If a close match is found among the command's flags, it returns an enhanced error
// with a "Did you mean" suggestion appended. Otherwise it returns the original error.
func suggestFlagFromError(cmd *cobra.Command, err error) error {
var notExist *pflag.NotExistError
if !errors.As(err, &notExist) {
return err
}

flagName := notExist.GetSpecifiedName()
isShorthand := notExist.GetSpecifiedShortnames() != ""

if isShorthand {
return suggestShorthandFlag(cmd, err, flagName)
}

return suggestLongFlag(cmd, err, flagName)
}

// suggestLongFlag suggests a matching long flag name for an unknown long flag error.
func suggestLongFlag(cmd *cobra.Command, original error, flagName string) error {
if flagName == "" {
return original
}

best, bestDist := findClosestFlag(cmd, flagName)
if best == "" || bestDist > maxSuggestionDistance {
return original
}

return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best)
}

// suggestShorthandFlag suggests a matching shorthand for an unknown shorthand flag error.
func suggestShorthandFlag(cmd *cobra.Command, original error, flagName string) error {
if flagName == "" {
return original
}
ch := string(flagName[0])

best := findClosestShorthand(cmd, ch)
if best == "" {
return original
}

return fmt.Errorf("%w\n\nDid you mean \"-%s\"?", original, best)
}

// findClosestFlag returns the closest non-hidden, non-deprecated long flag name
// and its edit distance from the given misspelled name.
func findClosestFlag(cmd *cobra.Command, name string) (string, int) {
best := ""
bestDist := maxSuggestionDistance + 1

seen := map[string]bool{}
check := func(f *pflag.Flag) {
if f.Hidden || f.Deprecated != "" {
return
}
if seen[f.Name] {
return
}
seen[f.Name] = true

d := levenshteinDistance(name, f.Name)
if d < bestDist {
bestDist = d
best = f.Name
}
}

cmd.Flags().VisitAll(check)
cmd.InheritedFlags().VisitAll(check)

return best, bestDist
}

// findClosestShorthand returns a case-insensitive exact match for the given
// shorthand character. Levenshtein is not useful for single characters because
// any two distinct characters always have distance 1.
func findClosestShorthand(cmd *cobra.Command, ch string) string {
best := ""
seen := map[string]bool{}
check := func(f *pflag.Flag) {
if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" || f.Shorthand == "" {
return
}
if seen[f.Shorthand] {
return
}
seen[f.Shorthand] = true
if strings.EqualFold(ch, f.Shorthand) {
best = f.Shorthand
}
}
cmd.Flags().VisitAll(check)
cmd.InheritedFlags().VisitAll(check)
return best
}
216 changes: 216 additions & 0 deletions cmd/root/flag_suggestions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package root

import (
"errors"
"fmt"
"testing"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// parseUnknownFlag triggers Cobra's flag parsing on args and returns the error.
// The command is set up with DisableFlagParsing=false (default) and a
// RunE that does nothing, so the only errors come from flag parsing.
func parseUnknownFlag(cmd *cobra.Command, args []string) error {
cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil }
cmd.SetArgs(args)
return cmd.Execute()
}

func TestLevenshteinDistance(t *testing.T) {
tests := []struct {
a, b string
want int
}{
{"", "", 0},
{"abc", "abc", 0},
{"", "abc", 3},
{"abc", "", 3},
{"kitten", "sitting", 3},
{"output", "outpu", 1}, // deletion
{"output", "ouptut", 2}, // transposition = 2 edits
{"output", "outpux", 1}, // substitution
{"output", "outputx", 1}, // insertion
}

for _, tt := range tests {
t.Run(fmt.Sprintf("%s_%s", tt.a, tt.b), func(t *testing.T) {
assert.Equal(t, tt.want, levenshteinDistance(tt.a, tt.b))
})
}
}

func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

err := &pflag.NotExistError{}
// Parse "--outpu" to get a real error
parseErr := parseUnknownFlag(cmd, []string{"--outpu"})
require.Error(t, parseErr)

// Extract the pflag error from the cobra wrapping
require.ErrorAs(t, parseErr, &err)

got := suggestFlagFromError(cmd, parseErr)
assert.Contains(t, got.Error(), `Did you mean "--output"?`)
assert.Contains(t, got.Error(), "unknown flag: --outpu")
}

func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

parseErr := parseUnknownFlag(cmd, []string{"--zzzzzzz"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")

parseErr := parseUnknownFlag(cmd, []string{"-O"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
assert.Contains(t, got.Error(), `Did you mean "-o"?`)
}

func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("secret", "", "secret flag")
require.NoError(t, cmd.Flags().MarkHidden("secret"))

parseErr := parseUnknownFlag(cmd, []string{"--secre"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("legacy", "", "old flag")
require.NoError(t, cmd.Flags().MarkDeprecated("legacy", "use --new instead"))

parseErr := parseUnknownFlag(cmd, []string{"--legac"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_InheritedFlags(t *testing.T) {
parent := &cobra.Command{Use: "parent"}
parent.PersistentFlags().String("profile", "", "auth profile")

child := &cobra.Command{Use: "child"}
child.RunE = func(cmd *cobra.Command, args []string) error { return nil }
parent.AddCommand(child)

parent.SetArgs([]string{"child", "--profil"})
parseErr := parent.Execute()
require.Error(t, parseErr)

got := suggestFlagFromError(child, parseErr)
assert.Contains(t, got.Error(), `Did you mean "--profile"?`)
}

func TestSuggestFlagFromError_NonFlagError(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")

err := errors.New("some other error")
got := suggestFlagFromError(cmd, err)
assert.Equal(t, err.Error(), got.Error())
}

func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) {
parent := &cobra.Command{Use: "parent"}
parent.PersistentFlags().String("target", "", "deployment target")

child := &cobra.Command{Use: "child"}
child.Flags().String("target", "", "deployment target")
child.RunE = func(cmd *cobra.Command, args []string) error { return nil }
parent.AddCommand(child)

parent.SetArgs([]string{"child", "--targe"})
parseErr := parent.Execute()
require.Error(t, parseErr)

got := suggestFlagFromError(child, parseErr)
assert.Contains(t, got.Error(), `Did you mean "--target"?`)
}

func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")

parseErr := parseUnknownFlag(cmd, []string{"-z"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_ShorthandDeprecatedStillSuggestsLongFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")
require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead"))

parseErr := parseUnknownFlag(cmd, []string{"--outpu"})
require.Error(t, parseErr)

// The long flag should still be suggested even though the shorthand is deprecated.
got := suggestFlagFromError(cmd, parseErr)
assert.Contains(t, got.Error(), `Did you mean "--output"?`)
}

func TestSuggestFlagFromError_ShorthandDeprecatedExcludedFromShorthandSuggestions(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("output", "o", "", "output format")
require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead"))

parseErr := parseUnknownFlag(cmd, []string{"-O"})
require.Error(t, parseErr)

// The deprecated shorthand should NOT be suggested.
got := suggestFlagFromError(cmd, parseErr)
assert.NotContains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_TieBreakingEquidistantFlags(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
// "ab" and "ac" are both distance 1 from "aa"
cmd.Flags().String("ab", "", "")
cmd.Flags().String("ac", "", "")

parseErr := parseUnknownFlag(cmd, []string{"--aa"})
require.Error(t, parseErr)

got := suggestFlagFromError(cmd, parseErr)
// Both are equidistant; we accept whichever is returned (order depends on
// flag iteration) but a suggestion must be present.
assert.Contains(t, got.Error(), "Did you mean")
}

func TestSuggestFlagFromError_IntegrationThroughFlagErrorFunc(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().String("output", "", "output format")
cmd.SetFlagErrorFunc(flagErrorFunc)
cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil }
cmd.SetArgs([]string{"--outpu"})

err := cmd.Execute()
require.Error(t, err)

assert.Contains(t, err.Error(), `Did you mean "--output"?`)
// flagErrorFunc also appends usage
assert.Contains(t, err.Error(), "Usage:")
}
4 changes: 3 additions & 1 deletion cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ func New(ctx context.Context) *cobra.Command {
return cmd
}

// Wrap flag errors to include the usage string.
// flagErrorFunc wraps flag errors to include the usage string and, for unknown
// flags, a "Did you mean" suggestion based on Levenshtein distance.
func flagErrorFunc(c *cobra.Command, err error) error {
err = suggestFlagFromError(c, err)
return fmt.Errorf("%w\n\n%s", err, c.UsageString())
}

Expand Down
Loading