Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .claude/settings.local.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
"Bash(git commit:*)",
"Bash(gh pr view:*)",
"Bash(grep:*)",
"Bash(earthly +earthly-linux-amd64:*)"
"Bash(earthly +earthly-linux-amd64:*)",
"Bash(go version:*)",
"Bash(go vet:*)"
]
}
}
99 changes: 99 additions & 0 deletions util/flagutil/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,115 @@ package flagutil

import (
"context"
"math"
"os"
"reflect"
"regexp"
"strings"

"github.com/EarthBuild/earthbuild/ast/commandflag"
"github.com/EarthBuild/earthbuild/ast/spec"
"github.com/EarthBuild/earthbuild/util/hint"
"github.com/EarthBuild/earthbuild/util/stringutil"
"github.com/agext/levenshtein"
"github.com/pkg/errors"

"github.com/jessevdk/go-flags"
"github.com/urfave/cli/v2"
)

// extractFlagNames extracts all long flag names from a struct using reflection.
func extractFlagNames(data any) []string {
if data == nil {
return nil
}

v := reflect.ValueOf(data)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil
}

t := v.Type()
var flagNames []string
for i := range t.NumField() {
if longTag := t.Field(i).Tag.Get("long"); longTag != "" {
flagNames = append(flagNames, longTag)
}
}
return flagNames
}

// findClosestFlag finds the most similar flag name to the given unknown flag.
// Returns the suggested flag and whether a good suggestion was found.
func findClosestFlag(unknownFlag string, validFlags []string) (string, bool) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bool is redundant. The empty string indicates that no closest match was found.

if len(validFlags) == 0 {
return "", false
}

// Remove leading dashes from the unknown flag for comparison
unknownFlag = strings.TrimLeft(unknownFlag, "-")

bestMatch := ""
bestDistance := math.MaxInt

for _, validFlag := range validFlags {
if distance := levenshtein.Distance(unknownFlag, validFlag, nil); distance < bestDistance {
bestDistance = distance
bestMatch = validFlag
}
}

// Only suggest if the distance is reasonable (less than half the length of the unknown flag).
// This prevents suggesting completely unrelated flags.
// Allow at least 2 character difference for short flags.
maxDistance := max(len(unknownFlag)/2, 2)
if bestDistance <= maxDistance {
return bestMatch, true
}
return "", false
}

// suggestFlagIfUnknown checks if the error is about an unknown flag and adds a suggestion if possible.
func suggestFlagIfUnknown(err error, data any) error {
if err == nil {
return nil
}

unknownFlag, ok := extractUnknownFlagFromError(err)
if !ok {
return err
}

suggestion, found := findClosestFlag(unknownFlag, extractFlagNames(data))
if !found {
return err
}

return hint.Wrapf(err, "Did you mean '--%s'?", suggestion)
}

// unknownFlagRegexp matches the flag name in go-flags error messages like "unknown flag `flag-name'".
var unknownFlagRegexp = regexp.MustCompile("`([^']+)'")

// extractUnknownFlagFromError extracts the flag name from an "unknown flag" error.
// Uses type assertion to check for the specific error type from go-flags library.
func extractUnknownFlagFromError(err error) (string, bool) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, bool is redundant

var flagErr *flags.Error
if !errors.As(err, &flagErr) || flagErr.Type != flags.ErrUnknownFlag {
return "", false
}

matches := unknownFlagRegexp.FindStringSubmatch(flagErr.Message)
if len(matches) < 2 {
return "", false
}

return matches[1], true
}

// ArgumentModFunc accepts a flagName which corresponds to the long flag name, and a pointer
// to a flag value. The pointer is nil if no flag was given.
// the function returns a new pointer set to nil if one wants to pretend as if no value was given,
Expand Down Expand Up @@ -81,6 +178,8 @@ func ParseArgsWithValueModifierAndOptions(
if parserOptions&flags.PrintErrors != flags.None {
p.WriteHelp(os.Stderr)
}
// Try to provide helpful suggestions for unknown flags
err = suggestFlagIfUnknown(err, data)
return nil, err
}
if modFuncErr != nil {
Expand Down
139 changes: 139 additions & 0 deletions util/flagutil/parse_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package flagutil

import (
"errors"
"reflect"
"testing"

"github.com/EarthBuild/earthbuild/util/hint"
"github.com/jessevdk/go-flags"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
Expand Down Expand Up @@ -127,3 +130,139 @@ func TestNegativeParseParams(t *testing.T) {
assert.Error(t, err)
}
}

func TestExtractFlagNames(t *testing.T) {
t.Parallel()

type TestOpts struct {
KeepTs bool `long:"keep-ts"`
KeepOwn bool `long:"keep-own"`
IfExists bool `long:"if-exists"`
Force bool `long:"force"`
NoTag bool // no long tag, should be ignored
}

opts := &TestOpts{}
flags := extractFlagNames(opts)

expected := []string{"keep-ts", "keep-own", "if-exists", "force"}
if len(flags) != len(expected) {
t.Errorf("extractFlagNames returned %d flags; want %d", len(flags), len(expected))
}

// Check that all expected flags are present
flagMap := make(map[string]bool)
for _, f := range flags {
flagMap[f] = true
}
for _, exp := range expected {
if !flagMap[exp] {
t.Errorf("extractFlagNames missing expected flag: %s", exp)
}
}
}

func TestFindClosestFlag(t *testing.T) {
t.Parallel()

validFlags := []string{"keep-ts", "keep-own", "if-exists", "symlink-no-follow", "force"}

tests := []struct {
unknownFlag string
expectedMatch string
shouldFind bool
description string
}{
{"if-exist", "if-exists", true, "missing final 's'"},
{"--if-exist", "if-exists", true, "with leading dashes"},
{"keep-t", "keep-ts", true, "shortened version"},
{"forc", "force", true, "missing final 'e'"},
{"completely-different", "", false, "no close match"},
{"xyz", "", false, "very short and different"},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
t.Parallel()

match, found := findClosestFlag(tt.unknownFlag, validFlags)
if found != tt.shouldFind {
t.Errorf("findClosestFlag(%q) found=%v; want %v (%s)", tt.unknownFlag, found, tt.shouldFind, tt.description)
}
if found && match != tt.expectedMatch {
t.Errorf("findClosestFlag(%q) = %q; want %q (%s)", tt.unknownFlag, match, tt.expectedMatch, tt.description)
}
})
}
}

func TestSuggestFlagIfUnknown(t *testing.T) {
t.Parallel()

type TestOpts struct {
KeepTs bool `long:"keep-ts"`
KeepOwn bool `long:"keep-own"`
IfExists bool `long:"if-exists"`
Force bool `long:"force"`
}

opts := &TestOpts{}

tests := []struct {
inputError error
shouldHaveHint bool
expectedHint string
description string
}{
{
&flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `if-exist'"},
true,
"Did you mean '--if-exists'?",
"typo in if-exists flag",
},
{
&flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `keep-t'"},
true,
"Did you mean '--keep-ts'?",
"shortened keep-ts flag",
},
{
errors.New("some other error"),
false,
"",
"non-flag error should pass through",
},
{
&flags.Error{Type: flags.ErrUnknownFlag, Message: "unknown flag `completely-wrong-flag'"},
false,
"",
"flag too different to suggest",
},
}

for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
t.Parallel()

result := suggestFlagIfUnknown(tt.inputError, opts)

// Check if the result is a hint.Error
hintErr, isHintErr := result.(*hint.Error)

if tt.shouldHaveHint {
if !isHintErr {
t.Errorf("%s: expected hint error, got regular error: %v", tt.description, result)
return
}
hintText := hintErr.Hint()
if hintText != tt.expectedHint+"\n" {
t.Errorf("%s: hint = %q; want %q", tt.description, hintText, tt.expectedHint+"\n")
}
} else {
if isHintErr {
t.Errorf("%s: expected regular error, got hint error: %v", tt.description, result)
}
}
})
}
}
Loading