From 547c9145a719406787cc08c8ff894ad19c72fb6b Mon Sep 17 00:00:00 2001 From: Felix Stein Date: Mon, 21 Apr 2025 22:58:15 +0200 Subject: [PATCH 1/2] use correct context in After function with subcommand When adding values to a context in a Before function that value is not present in the After function when a subcommand was run. Without the subcommand the value is present. This fixes that be returning the new context from the subcommand and changing the context in the calling function. Defer'd function now pick up the new context. Fixes #2098 --- command_run.go | 40 ++++++++++------- command_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 15 deletions(-) diff --git a/command_run.go b/command_run.go index fb40a7c47a..c1d450cf32 100644 --- a/command_run.go +++ b/command_run.go @@ -91,6 +91,11 @@ outer: // arguments are parsed according to the Flag and Command // definitions and the matching Action functions are run. func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { + _, deferErr = cmd.run(ctx, osArgs) + return +} + +func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context, deferErr error) { tracef("running with arguments %[1]q (cmd=%[2]q)", osArgs, cmd.Name) cmd.setupDefaults(osArgs) @@ -102,7 +107,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { if cmd.parent == nil { if cmd.ReadArgsFromStdin { if args, err := cmd.parseArgsFromStdin(); err != nil { - return err + return ctx, err } else { osArgs = append(osArgs, args...) } @@ -132,7 +137,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { var rargs Args = &stringSliceArgs{v: osArgs} for _, f := range cmd.allFlags() { if err := f.PreParse(); err != nil { - return err + return ctx, err } } @@ -149,7 +154,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { tracef("using post-parse arguments %[1]q (cmd=%[2]q)", args, cmd.Name) if checkCompletions(ctx, cmd) { - return nil + return ctx, nil } if err != nil { @@ -160,7 +165,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { if cmd.OnUsageError != nil { err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) err = cmd.handleExitCoder(ctx, err) - return err + return ctx, err } fmt.Fprintf(cmd.Root().ErrWriter, "Incorrect Usage: %s\n\n", err.Error()) if cmd.Suggest { @@ -182,23 +187,23 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } } - return err + return ctx, err } if cmd.checkHelp() { - return helpCommandAction(ctx, cmd) + return ctx, helpCommandAction(ctx, cmd) } else { tracef("no help is wanted (cmd=%[1]q)", cmd.Name) } if cmd.parent == nil && !cmd.HideVersion && checkVersion(cmd) { ShowVersion(cmd) - return nil + return ctx, nil } for _, flag := range cmd.allFlags() { if err := flag.PostParse(); err != nil { - return err + return ctx, err } } @@ -219,7 +224,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { for _, grp := range cmd.MutuallyExclusiveFlags { if err := grp.check(cmd); err != nil { _ = ShowSubcommandHelp(cmd) - return err + return ctx, err } } @@ -262,7 +267,12 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { // If a subcommand has been resolved, let it handle the remaining execution. if subCmd != nil { tracef("running sub-command %[1]q with arguments %[2]q (cmd=%[3]q)", subCmd.Name, cmd.Args(), cmd.Name) - return subCmd.Run(ctx, cmd.Args().Slice()) + + // It is important that we overwrite the ctx variable in the current + // function so any defer'd functions use the new context returned + // from the sub command. + ctx, err = subCmd.run(ctx, cmd.Args().Slice()) + return ctx, err } // This code path is the innermost command execution. Here we actually @@ -282,7 +292,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } if bctx, err := cmd.Before(ctx, cmd); err != nil { deferErr = cmd.handleExitCoder(ctx, err) - return deferErr + return ctx, deferErr } else if bctx != nil { ctx = bctx } @@ -294,14 +304,14 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { tracef("running flag actions (cmd=%[1]q)", cmd.Name) if err := cmd.runFlagActions(ctx); err != nil { deferErr = cmd.handleExitCoder(ctx, err) - return deferErr + return ctx, deferErr } } if err := cmd.checkAllRequiredFlags(); err != nil { cmd.isInError = true _ = ShowSubcommandHelp(cmd) - return err + return ctx, err } // Run the command action. @@ -317,7 +327,7 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { err = cmd.OnUsageError(ctx, cmd, err, cmd.parent != nil) } err = cmd.handleExitCoder(ctx, err) - return err + return ctx, err } } cmd.parsedArgs = &stringSliceArgs{v: rargs} @@ -329,5 +339,5 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { } tracef("returning deferErr (cmd=%[1]q) %[2]q", cmd.Name, deferErr) - return deferErr + return ctx, deferErr } diff --git a/command_test.go b/command_test.go index 3a7978a3d9..ece2f75ffd 100644 --- a/command_test.go +++ b/command_test.go @@ -351,6 +351,121 @@ func TestCommand_Run_BeforeReturnNewContext(t *testing.T) { require.Equal(t, "bval", receivedValFromAction) } +func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { + type key string + + bkey := key("bkey") + bkey2 := key("bkey2") + + contextValues := make(map[string]string) + collectContextValues := func(ctx context.Context, name string) { + if val := ctx.Value(bkey); val != nil { + contextValues["bkey in "+name] = val.(string) + } + if val := ctx.Value(bkey2); val != nil { + contextValues["bkey2 in "+name] = val.(string) + } + } + cmd := &Command{ + Name: "bar", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey, "bval"), nil + }, + After: func(ctx context.Context, cmd *Command) error { + collectContextValues(ctx, "bar.After") + return nil + }, + Commands: []*Command{ + { + Name: "baz", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey2, "bval2"), nil + }, + Action: func(ctx context.Context, cmd *Command) error { + collectContextValues(ctx, "baz.Action") + return nil + }, + After: func(ctx context.Context, cmd *Command) error { + collectContextValues(ctx, "baz.After") + return nil + }, + }, + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "baz"})) + expected := map[string]string{ + "bkey in bar.After": "bval", + "bkey2 in bar.After": "bval2", + "bkey in baz.Action": "bval", + "bkey2 in baz.Action": "bval2", + "bkey in baz.After": "bval", + "bkey2 in baz.After": "bval2", + } + require.Equal(t, expected, contextValues) +} + +func TestCommand_Run_FlagActionContext(t *testing.T) { + type key string + + bkey := key("bkey") + bkey2 := key("bkey2") + + contextValues := make(map[string]string) + collectContextValues := func(ctx context.Context, name string) { + if val := ctx.Value(bkey); val != nil { + contextValues["bkey in "+name] = val.(string) + } + if val := ctx.Value(bkey2); val != nil { + contextValues["bkey2 in "+name] = val.(string) + } + } + cmd := &Command{ + Name: "bar", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey, "bval"), nil + }, + Flags: []Flag{ + &StringFlag{ + Name: "foo", + Action: func(ctx context.Context, cmd *Command, _ string) error { + collectContextValues(ctx, "bar.foo.Action") + return nil + }, + }, + }, + Commands: []*Command{ + { + Name: "baz", + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, bkey2, "bval2"), nil + }, + Flags: []Flag{ + &StringFlag{ + Name: "goo", + Action: func(ctx context.Context, cmd *Command, _ string) error { + collectContextValues(ctx, "baz.goo.Action") + return nil + }, + }, + }, + Action: func(ctx context.Context, cmd *Command) error { + return nil + }, + }, + }, + } + + require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "--foo", "value", "baz", "--goo", "value"})) + expected := map[string]string{ + "bkey in bar.foo.Action": "bval", + "bkey2 in bar.foo.Action": "bval2", + "bkey in baz.goo.Action": "bval", + "bkey2 in baz.goo.Action": "bval2", + } + require.Equal(t, expected, contextValues) +} + func TestCommand_OnUsageError_hasCommandContext(t *testing.T) { cmd := &Command{ Name: "bar", From a5cfa4f474bc15704e71452319e11bb215c6f90c Mon Sep 17 00:00:00 2001 From: Felix Stein Date: Wed, 23 Apr 2025 23:05:46 +0200 Subject: [PATCH 2/2] extract some shared test code into ctxCollector --- command_test.go | 103 ++++++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 43 deletions(-) diff --git a/command_test.go b/command_test.go index ece2f75ffd..b44ed07932 100644 --- a/command_test.go +++ b/command_test.go @@ -351,28 +351,45 @@ func TestCommand_Run_BeforeReturnNewContext(t *testing.T) { require.Equal(t, "bval", receivedValFromAction) } -func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { - type key string +type ctxKey string - bkey := key("bkey") - bkey2 := key("bkey2") +// ctxCollector is a small helper to collect context values. +type ctxCollector struct { + // keys are the keys to check the context for. + keys []ctxKey - contextValues := make(map[string]string) - collectContextValues := func(ctx context.Context, name string) { - if val := ctx.Value(bkey); val != nil { - contextValues["bkey in "+name] = val.(string) - } - if val := ctx.Value(bkey2); val != nil { - contextValues["bkey2 in "+name] = val.(string) + // m maps from function name to context name to value. + m map[string]map[ctxKey]string +} + +func (cc *ctxCollector) collect(ctx context.Context, fnName string) { + if cc.m == nil { + cc.m = make(map[string]map[ctxKey]string) + } + + if _, ok := cc.m[fnName]; !ok { + cc.m[fnName] = make(map[ctxKey]string) + } + + for _, k := range cc.keys { + if val := ctx.Value(k); val != nil { + cc.m[fnName][k] = val.(string) } } +} + +func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { + bkey := ctxKey("bkey") + bkey2 := ctxKey("bkey2") + + cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}} cmd := &Command{ Name: "bar", Before: func(ctx context.Context, cmd *Command) (context.Context, error) { return context.WithValue(ctx, bkey, "bval"), nil }, After: func(ctx context.Context, cmd *Command) error { - collectContextValues(ctx, "bar.After") + cc.collect(ctx, "bar.After") return nil }, Commands: []*Command{ @@ -382,11 +399,11 @@ func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { return context.WithValue(ctx, bkey2, "bval2"), nil }, Action: func(ctx context.Context, cmd *Command) error { - collectContextValues(ctx, "baz.Action") + cc.collect(ctx, "baz.Action") return nil }, After: func(ctx context.Context, cmd *Command) error { - collectContextValues(ctx, "baz.After") + cc.collect(ctx, "baz.After") return nil }, }, @@ -394,32 +411,28 @@ func TestCommand_Run_BeforeReturnNewContextSubcommand(t *testing.T) { } require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "baz"})) - expected := map[string]string{ - "bkey in bar.After": "bval", - "bkey2 in bar.After": "bval2", - "bkey in baz.Action": "bval", - "bkey2 in baz.Action": "bval2", - "bkey in baz.After": "bval", - "bkey2 in baz.After": "bval2", + expected := map[string]map[ctxKey]string{ + "bar.After": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.Action": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.After": { + bkey: "bval", + bkey2: "bval2", + }, } - require.Equal(t, expected, contextValues) + require.Equal(t, expected, cc.m) } func TestCommand_Run_FlagActionContext(t *testing.T) { - type key string - - bkey := key("bkey") - bkey2 := key("bkey2") + bkey := ctxKey("bkey") + bkey2 := ctxKey("bkey2") - contextValues := make(map[string]string) - collectContextValues := func(ctx context.Context, name string) { - if val := ctx.Value(bkey); val != nil { - contextValues["bkey in "+name] = val.(string) - } - if val := ctx.Value(bkey2); val != nil { - contextValues["bkey2 in "+name] = val.(string) - } - } + cc := &ctxCollector{keys: []ctxKey{bkey, bkey2}} cmd := &Command{ Name: "bar", Before: func(ctx context.Context, cmd *Command) (context.Context, error) { @@ -429,7 +442,7 @@ func TestCommand_Run_FlagActionContext(t *testing.T) { &StringFlag{ Name: "foo", Action: func(ctx context.Context, cmd *Command, _ string) error { - collectContextValues(ctx, "bar.foo.Action") + cc.collect(ctx, "bar.foo.Action") return nil }, }, @@ -444,7 +457,7 @@ func TestCommand_Run_FlagActionContext(t *testing.T) { &StringFlag{ Name: "goo", Action: func(ctx context.Context, cmd *Command, _ string) error { - collectContextValues(ctx, "baz.goo.Action") + cc.collect(ctx, "baz.goo.Action") return nil }, }, @@ -457,13 +470,17 @@ func TestCommand_Run_FlagActionContext(t *testing.T) { } require.NoError(t, cmd.Run(buildTestContext(t), []string{"bar", "--foo", "value", "baz", "--goo", "value"})) - expected := map[string]string{ - "bkey in bar.foo.Action": "bval", - "bkey2 in bar.foo.Action": "bval2", - "bkey in baz.goo.Action": "bval", - "bkey2 in baz.goo.Action": "bval2", + expected := map[string]map[ctxKey]string{ + "bar.foo.Action": { + bkey: "bval", + bkey2: "bval2", + }, + "baz.goo.Action": { + bkey: "bval", + bkey2: "bval2", + }, } - require.Equal(t, expected, contextValues) + require.Equal(t, expected, cc.m) } func TestCommand_OnUsageError_hasCommandContext(t *testing.T) {