Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions command_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@
// arguments are parsed according to the Flag and Command
// definitions and the matching Action functions are run.
func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {
_, deferErr = cmd.run(ctx, osArgs)
return
}

func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context, deferErr error) {
tracef("running with arguments %[1]q (cmd=%[2]q)", osArgs, cmd.Name)
cmd.setupDefaults(osArgs)

Expand All @@ -102,7 +107,7 @@
if cmd.parent == nil {
if cmd.ReadArgsFromStdin {
if args, err := cmd.parseArgsFromStdin(); err != nil {
return err
return ctx, err

Check warning on line 110 in command_run.go

View check run for this annotation

Codecov / codecov/patch

command_run.go#L110

Added line #L110 was not covered by tests
} else {
osArgs = append(osArgs, args...)
}
Expand Down Expand Up @@ -132,7 +137,7 @@
var rargs Args = &stringSliceArgs{v: osArgs}
for _, f := range cmd.allFlags() {
if err := f.PreParse(); err != nil {
return err
return ctx, err
}
}

Expand All @@ -149,7 +154,7 @@
tracef("using post-parse arguments %[1]q (cmd=%[2]q)", args, cmd.Name)

if checkCompletions(ctx, cmd) {
return nil
return ctx, nil
}

if err != nil {
Expand All @@ -160,7 +165,7 @@
if cmd.OnUsageError != nil {
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
err = cmd.handleExitCoder(ctx, err)
return err
return ctx, err
}
fmt.Fprintf(cmd.Root().ErrWriter, "Incorrect Usage: %s\n\n", err.Error())
if cmd.Suggest {
Expand All @@ -182,23 +187,23 @@
}
}

return err
return ctx, err
}

if cmd.checkHelp() {
return helpCommandAction(ctx, cmd)
return ctx, helpCommandAction(ctx, cmd)
} else {
tracef("no help is wanted (cmd=%[1]q)", cmd.Name)
}

if cmd.parent == nil && !cmd.HideVersion && checkVersion(cmd) {
ShowVersion(cmd)
return nil
return ctx, nil
}

for _, flag := range cmd.allFlags() {
if err := flag.PostParse(); err != nil {
return err
return ctx, err
}
}

Expand All @@ -219,7 +224,7 @@
for _, grp := range cmd.MutuallyExclusiveFlags {
if err := grp.check(cmd); err != nil {
_ = ShowSubcommandHelp(cmd)
return err
return ctx, err
}
}

Expand Down Expand Up @@ -262,7 +267,12 @@
// If a subcommand has been resolved, let it handle the remaining execution.
if subCmd != nil {
tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name)
return subCmd.Run(ctx, cmd.Args().Slice())

// It is important that we overwrite the ctx variable in the current
// function so any defer'd functions use the new context returned
// from the sub command.
ctx, err = subCmd.run(ctx, cmd.Args().Slice())
Comment thread
dearchap marked this conversation as resolved.
return ctx, err
}

// This code path is the innermost command execution. Here we actually
Expand All @@ -282,7 +292,7 @@
}
if bctx, err := cmd.Before(ctx, cmd); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
return ctx, deferErr
} else if bctx != nil {
ctx = bctx
}
Expand All @@ -294,14 +304,14 @@
tracef("running flag actions (cmd=%[1]q)", cmd.Name)
if err := cmd.runFlagActions(ctx); err != nil {
deferErr = cmd.handleExitCoder(ctx, err)
return deferErr
return ctx, deferErr
}
}

if err := cmd.checkAllRequiredFlags(); err != nil {
cmd.isInError = true
_ = ShowSubcommandHelp(cmd)
return err
return ctx, err
}

// Run the command action.
Expand All @@ -317,7 +327,7 @@
err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil)
}
err = cmd.handleExitCoder(ctx, err)
return err
return ctx, err
}
}
cmd.parsedArgs = &stringSliceArgs{v: rargs}
Expand All @@ -329,5 +339,5 @@
}

tracef("returning deferErr (cmd=%[1]q) %[2]q", cmd.Name, deferErr)
return deferErr
return ctx, deferErr
}
132 changes: 132 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,138 @@ func TestCommand_Run_BeforeReturnNewContext(t *testing.T) {
require.Equal(t, "bval", receivedValFromAction)
}

type ctxKey string

// ctxCollector is a small helper to collect context values.
type ctxCollector struct {
// keys are the keys to check the context for.
keys []ctxKey

// m maps from function name to context name to value.
m map[string]map[ctxKey]string
}

func (cc *ctxCollector) collect(ctx context.Context, fnName string) {
if cc.m == nil {
cc.m = make(map[string]map[ctxKey]string)
}

if _, ok := cc.m[fnName]; !ok {
cc.m[fnName] = make(map[ctxKey]string)
}

for _, k := range cc.keys {
if val := ctx.Value(k); val != nil {
cc.m[fnName][k] = val.(string)
}
}
}

func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) {
bkey := ctxKey("bkey")
bkey2 := ctxKey("bkey2")

cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}}
cmd := &Command{
Name: "bar",
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
return context.WithValue(ctx, bkey, "bval"), nil
},
After: func(ctx context.Context, cmd *Command) error {
cc.collect(ctx, "bar.After")
return nil
},
Commands: []*Command{
{
Name: "baz",
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
return context.WithValue(ctx, bkey2, "bval2"), nil
},
Action: func(ctx context.Context, cmd *Command) error {
cc.collect(ctx, "baz.Action")
return nil
},
After: func(ctx context.Context, cmd *Command) error {
cc.collect(ctx, "baz.After")
return nil
},
},
},
}

require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "baz"}))
expected := map[string]map[ctxKey]string{
"bar.After": {
bkey: "bval",
bkey2: "bval2",
},
"baz.Action": {
bkey: "bval",
bkey2: "bval2",
},
"baz.After": {
bkey: "bval",
bkey2: "bval2",
},
}
require.Equal(t, expected, cc.m)
}

func TestCommand_Run_FlagActionContext(t *testing.T) {
bkey := ctxKey("bkey")
bkey2 := ctxKey("bkey2")

cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}}
cmd := &Command{
Name: "bar",
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
return context.WithValue(ctx, bkey, "bval"), nil
},
Flags: []Flag{
&StringFlag{
Name: "foo",
Action: func(ctx context.Context, cmd *Command, _ string) error {
cc.collect(ctx, "bar.foo.Action")
return nil
},
},
},
Commands: []*Command{
{
Name: "baz",
Before: func(ctx context.Context, cmd *Command) (context.Context, error) {
return context.WithValue(ctx, bkey2, "bval2"), nil
},
Flags: []Flag{
&StringFlag{
Name: "goo",
Action: func(ctx context.Context, cmd *Command, _ string) error {
cc.collect(ctx, "baz.goo.Action")
return nil
},
},
},
Action: func(ctx context.Context, cmd *Command) error {
return nil
},
},
},
}

require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "--foo", "value", "baz", "--goo", "value"}))
expected := map[string]map[ctxKey]string{
"bar.foo.Action": {
bkey: "bval",
bkey2: "bval2",
},
"baz.goo.Action": {
bkey: "bval",
bkey2: "bval2",
},
}
require.Equal(t, expected, cc.m)
}

func TestCommand_OnUsageError_hasCommandContext(t *testing.T) {
cmd := &Command{
Name: "bar",
Expand Down