diff --git a/command_run.go b/command_run.go index fb40a7c47a..c1d450cf32 100644 --- a/command_run.go +++ b/command_run.go @@ -91,6 +91,11 @@ outer: // arguments are parsed according to the Flag and Command // definitions and the matching Action functions are run. func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { + _, deferErr = cmd.run(ctx, osArgs) + return +} + +func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context, deferErr error) { tracef("running with arguments %[1]q (cmd=%[2]q)", osArgs, cmd.Name) cmd.setupDefaults(osArgs) @@ -102,7 +107,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { if cmd.parent == nil { if cmd.ReadArgsFromStdin { if args, err := cmd.parseArgsFromStdin(); err != nil { - return err + return ctx, err } else { osArgs = append(osArgs, args...) } @@ -132,7 +137,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { var rargs Args = &stringSliceArgs{v: osArgs} for _, f := range cmd.allFlags() { if err := f.PreParse(); err != nil { - return err + return ctx, err } } @@ -149,7 +154,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { tracef("using post-parse arguments %[1]q (cmd=%[2]q)", args, cmd.Name) if checkCompletions(ctx, cmd) { - return nil + return ctx, nil } if err != nil { @@ -160,7 +165,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { if cmd.OnUsageError != nil { err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) err = cmd.handleExitCoder(ctx, err) - return err + return ctx, err } fmt.Fprintf(cmd.Root().ErrWriter, "Incorrect Usage: %s\n\n", err.Error()) if cmd.Suggest { @@ -182,23 +187,23 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } } - return err + return ctx, err } if cmd.checkHelp() { - return helpCommandAction(ctx, cmd) + return ctx, helpCommandAction(ctx, cmd) } else { tracef("no help is wanted (cmd=%[1]q)", cmd.Name) } if cmd.parent == nil && !cmd.HideVersion && checkVersion(cmd) { ShowVersion(cmd) - return nil + return ctx, nil } for _, flag := range cmd.allFlags() { if err := flag.PostParse(); err != nil { - return err + return ctx, err } } @@ -219,7 +224,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { for _, grp := range cmd.MutuallyExclusiveFlags { if err := grp.check(cmd); err != nil { _ = ShowSubcommandHelp(cmd) - return err + return ctx, err } } @@ -262,7 +267,12 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { // If a subcommand has been resolved, let it handle the remaining execution. if subCmd != nil { tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name) - return subCmd.Run(ctx, cmd.Args().Slice()) + + // It is important that we overwrite the ctx variable in the current + // function so any defer'd functions use the new context returned + // from the sub command. + ctx, err = subCmd.run(ctx, cmd.Args().Slice()) + return ctx, err } // This code path is the innermost command execution. Here we actually @@ -282,7 +292,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } if bctx, err := cmd.Before(ctx, cmd); err != nil { deferErr = cmd.handleExitCoder(ctx, err) - return deferErr + return ctx, deferErr } else if bctx != nil { ctx = bctx } @@ -294,14 +304,14 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { tracef("running flag actions (cmd=%[1]q)", cmd.Name) if err := cmd.runFlagActions(ctx); err != nil { deferErr = cmd.handleExitCoder(ctx, err) - return deferErr + return ctx, deferErr } } if err := cmd.checkAllRequiredFlags(); err != nil { cmd.isInError = true _ = ShowSubcommandHelp(cmd) - return err + return ctx, err } // Run the command action. @@ -317,7 +327,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) } err = cmd.handleExitCoder(ctx, err) - return err + return ctx, err } } cmd.parsedArgs = &stringSliceArgs{v: rargs} @@ -329,5 +339,5 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } tracef("returning deferErr (cmd=%[1]q) %[2]q", cmd.Name, deferErr) - return deferErr + return ctx, deferErr } diff --git a/command_test.go b/command_test.go index 3a7978a3d9..b44ed07932 100644 --- a/command_test.go +++ b/command_test.go @@ -351,6 +351,138 @@ func TestCommand_Run_BeforeReturnNewContext(t *testing.T) { require.Equal(t, "bval", receivedValFromAction) } +type ctxKey string + +// ctxCollector is a small helper to collect context values. +type ctxCollector struct { + // keys are the keys to check the context for. + keys []ctxKey + + // m maps from function name to context name to value. + m map[string]map[ctxKey]string +} + +func (cc *ctxCollector) collect(ctx context.Context, fnName string) { + if cc.m == nil { + cc.m = make(map[string]map[ctxKey]string) + } + + if _, ok := cc.m[fnName]; !ok { + cc.m[fnName] = make(map[ctxKey]string) + } + + for _, k := range cc.keys { + if val := ctx.Value(k); val != nil { + cc.m[fnName][k] = val.(string) + } + } +} + +func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { + bkey := ctxKey("bkey") + bkey2 := ctxKey("bkey2") + + cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}} + cmd := &Command{ + Name: "bar", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey, "bval"), nil + }, + After: func(ctx context.Context, cmd *Command) error { + cc.collect(ctx, "bar.After") + return nil + }, + Commands: []*Command{ + { + Name: "baz", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey2, "bval2"), nil + }, + Action: func(ctx context.Context, cmd *Command) error { + cc.collect(ctx, "baz.Action") + return nil + }, + After: func(ctx context.Context, cmd *Command) error { + cc.collect(ctx, "baz.After") + return nil + }, + }, + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "baz"})) + expected := map[string]map[ctxKey]string{ + "bar.After": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.Action": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.After": { + bkey: "bval", + bkey2: "bval2", + }, + } + require.Equal(t, expected, cc.m) +} + +func TestCommand_Run_FlagActionContext(t *testing.T) { + bkey := ctxKey("bkey") + bkey2 := ctxKey("bkey2") + + cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}} + cmd := &Command{ + Name: "bar", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey, "bval"), nil + }, + Flags: []Flag{ + &StringFlag{ + Name: "foo", + Action: func(ctx context.Context, cmd *Command, _ string) error { + cc.collect(ctx, "bar.foo.Action") + return nil + }, + }, + }, + Commands: []*Command{ + { + Name: "baz", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey2, "bval2"), nil + }, + Flags: []Flag{ + &StringFlag{ + Name: "goo", + Action: func(ctx context.Context, cmd *Command, _ string) error { + cc.collect(ctx, "baz.goo.Action") + return nil + }, + }, + }, + Action: func(ctx context.Context, cmd *Command) error { + return nil + }, + }, + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "--foo", "value", "baz", "--goo", "value"})) + expected := map[string]map[ctxKey]string{ + "bar.foo.Action": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.goo.Action": { + bkey: "bval", + bkey2: "bval2", + }, + } + require.Equal(t, expected, cc.m) +} + func TestCommand_OnUsageError_hasCommandContext(t *testing.T) { cmd := &Command{ Name: "bar",