diff --git a/.claude/settings.local.json b/.claude/settings.local.json index d18abe0cb3..e05e6a8a57 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -8,7 +8,9 @@ "Bash(git commit:*)", "Bash(gh pr view:*)", "Bash(grep:*)", - "Bash(earthly +earthly-linux-amd64:*)" + "Bash(earthly +earthly-linux-amd64:*)", + "Bash(go version:*)", + "Bash(go vet:*)" ] } } diff --git a/util/flagutil/parse.go b/util/flagutil/parse.go index 24243d5b32..8b6f830639 100644 --- a/util/flagutil/parse.go +++ b/util/flagutil/parse.go @@ -2,18 +2,115 @@ package flagutil import ( "context" + "math" "os" + "reflect" + "regexp" "strings" "github.com/EarthBuild/earthbuild/ast/commandflag" "github.com/EarthBuild/earthbuild/ast/spec" + "github.com/EarthBuild/earthbuild/util/hint" "github.com/EarthBuild/earthbuild/util/stringutil" + "github.com/agext/levenshtein" "github.com/pkg/errors" "github.com/jessevdk/go-flags" "github.com/urfave/cli/v2" ) +// extractFlagNames extracts all long flag names from a struct using reflection. +func extractFlagNames(data any) []string { + if data == nil { + return nil + } + + v := reflect.ValueOf(data) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return nil + } + + t := v.Type() + var flagNames []string + for i := range t.NumField() { + if longTag := t.Field(i).Tag.Get("long"); longTag != "" { + flagNames = append(flagNames, longTag) + } + } + return flagNames +} + +// findClosestFlag finds the most similar flag name to the given unknown flag. +// Returns the suggested flag and whether a good suggestion was found. +func findClosestFlag(unknownFlag string, validFlags []string) (string, bool) { + if len(validFlags) == 0 { + return "", false + } + + // Remove leading dashes from the unknown flag for comparison + unknownFlag = strings.TrimLeft(unknownFlag, "-") + + bestMatch := "" + bestDistance := math.MaxInt + + for _, validFlag := range validFlags { + if distance := levenshtein.Distance(unknownFlag, validFlag, nil); distance < bestDistance { + bestDistance = distance + bestMatch = validFlag + } + } + + // Only suggest if the distance is reasonable (less than half the length of the unknown flag). + // This prevents suggesting completely unrelated flags. + // Allow at least 2 character difference for short flags. + maxDistance := max(len(unknownFlag)/2, 2) + if bestDistance <= maxDistance { + return bestMatch, true + } + return "", false +} + +// suggestFlagIfUnknown checks if the error is about an unknown flag and adds a suggestion if possible. +func suggestFlagIfUnknown(err error, data any) error { + if err == nil { + return nil + } + + unknownFlag, ok := extractUnknownFlagFromError(err) + if !ok { + return err + } + + suggestion, found := findClosestFlag(unknownFlag, extractFlagNames(data)) + if !found { + return err + } + + return hint.Wrapf(err, "Did you mean '--%s'?", suggestion) +} + +// unknownFlagRegexp matches the flag name in go-flags error messages like "unknown flag `flag-name'". +var unknownFlagRegexp = regexp.MustCompile("`([^']+)'") + +// extractUnknownFlagFromError extracts the flag name from an "unknown flag" error. +// Uses type assertion to check for the specific error type from go-flags library. +func extractUnknownFlagFromError(err error) (string, bool) { + var flagErr *flags.Error + if !errors.As(err, &flagErr) || flagErr.Type != flags.ErrUnknownFlag { + return "", false + } + + matches := unknownFlagRegexp.FindStringSubmatch(flagErr.Message) + if len(matches) < 2 { + return "", false + } + + return matches[1], true +} + // ArgumentModFunc accepts a flagName which corresponds to the long flag name, and a pointer // to a flag value. The pointer is nil if no flag was given. // the function returns a new pointer set to nil if one wants to pretend as if no value was given, @@ -81,6 +178,8 @@ func ParseArgsWithValueModifierAndOptions( if parserOptions&flags.PrintErrors != flags.None { p.WriteHelp(os.Stderr) } + // Try to provide helpful suggestions for unknown flags + err = suggestFlagIfUnknown(err, data) return nil, err } if modFuncErr != nil { diff --git a/util/flagutil/parse_test.go b/util/flagutil/parse_test.go index c9b29c657f..322f329a13 100644 --- a/util/flagutil/parse_test.go +++ b/util/flagutil/parse_test.go @@ -1,9 +1,12 @@ package flagutil import ( + "errors" "reflect" "testing" + "github.com/EarthBuild/earthbuild/util/hint" + "github.com/jessevdk/go-flags" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/cli/v2" @@ -127,3 +130,139 @@ func TestNegativeParseParams(t *testing.T) { assert.Error(t, err) } } + +func TestExtractFlagNames(t *testing.T) { + t.Parallel() + + type TestOpts struct { + KeepTs bool `long:"keep-ts"` + KeepOwn bool `long:"keep-own"` + IfExists bool `long:"if-exists"` + Force bool `long:"force"` + NoTag bool // no long tag, should be ignored + } + + opts := &TestOpts{} + flags := extractFlagNames(opts) + + expected := []string{"keep-ts", "keep-own", "if-exists", "force"} + if len(flags) != len(expected) { + t.Errorf("extractFlagNames returned %d flags; want %d", len(flags), len(expected)) + } + + // Check that all expected flags are present + flagMap := make(map[string]bool) + for _, f := range flags { + flagMap[f] = true + } + for _, exp := range expected { + if !flagMap[exp] { + t.Errorf("extractFlagNames missing expected flag: %s", exp) + } + } +} + +func TestFindClosestFlag(t *testing.T) { + t.Parallel() + + validFlags := []string{"keep-ts", "keep-own", "if-exists", "symlink-no-follow", "force"} + + tests := []struct { + unknownFlag string + expectedMatch string + shouldFind bool + description string + }{ + {"if-exist", "if-exists", true, "missing final 's'"}, + {"--if-exist", "if-exists", true, "with leading dashes"}, + {"keep-t", "keep-ts", true, "shortened version"}, + {"forc", "force", true, "missing final 'e'"}, + {"completely-different", "", false, "no close match"}, + {"xyz", "", false, "very short and different"}, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + t.Parallel() + + match, found := findClosestFlag(tt.unknownFlag, validFlags) + if found != tt.shouldFind { + t.Errorf("findClosestFlag(%q) found=%v; want %v (%s)", tt.unknownFlag, found, tt.shouldFind, tt.description) + } + if found && match != tt.expectedMatch { + t.Errorf("findClosestFlag(%q) = %q; want %q (%s)", tt.unknownFlag, match, tt.expectedMatch, tt.description) + } + }) + } +} + +func TestSuggestFlagIfUnknown(t *testing.T) { + t.Parallel() + + type TestOpts struct { + KeepTs bool `long:"keep-ts"` + KeepOwn bool `long:"keep-own"` + IfExists bool `long:"if-exists"` + Force bool `long:"force"` + } + + opts := &TestOpts{} + + tests := []struct { + inputError error + shouldHaveHint bool + expectedHint string + description string + }{ + { + &flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `if-exist'"}, + true, + "Did you mean '--if-exists'?", + "typo in if-exists flag", + }, + { + &flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `keep-t'"}, + true, + "Did you mean '--keep-ts'?", + "shortened keep-ts flag", + }, + { + errors.New("some other error"), + false, + "", + "non-flag error should pass through", + }, + { + &flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `completely-wrong-flag'"}, + false, + "", + "flag too different to suggest", + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + t.Parallel() + + result := suggestFlagIfUnknown(tt.inputError, opts) + + // Check if the result is a hint.Error + hintErr, isHintErr := result.(*hint.Error) + + if tt.shouldHaveHint { + if !isHintErr { + t.Errorf("%s: expected hint error, got regular error: %v", tt.description, result) + return + } + hintText := hintErr.Hint() + if hintText != tt.expectedHint+"\n" { + t.Errorf("%s: hint = %q; want %q", tt.description, hintText, tt.expectedHint+"\n") + } + } else { + if isHintErr { + t.Errorf("%s: expected regular error, got hint error: %v", tt.description, result) + } + } + }) + } +}