From 6d00909120c77b54b0c9974a4e20ffc540901b98 Mon Sep 17 00:00:00 2001 From: Lukas Malkmus Date: Mon, 3 May 2021 18:33:57 +0200 Subject: [PATCH] Pass context to completion (#1265) --- command.go | 11 ++++++++++- command_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ completions.go | 1 + completions_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) diff --git a/command.go b/command.go index ce94d40..5c85c89 100644 --- a/command.go +++ b/command.go @@ -887,7 +887,8 @@ func (c *Command) preRun() { } // ExecuteContext is the same as Execute(), but sets the ctx on the command. -// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle functions. +// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs +// functions. func (c *Command) ExecuteContext(ctx context.Context) error { c.ctx = ctx return c.Execute() @@ -901,6 +902,14 @@ func (c *Command) Execute() error { return err } +// ExecuteContextC is the same as ExecuteC(), but sets the ctx on the command. +// Retrieve ctx by calling cmd.Context() inside your *Run lifecycle or ValidArgs +// functions. +func (c *Command) ExecuteContextC(ctx context.Context) (*Command, error) { + c.ctx = ctx + return c.ExecuteC() +} + // ExecuteC executes the command. func (c *Command) ExecuteC() (cmd *Command, err error) { if c.ctx == nil { diff --git a/command_test.go b/command_test.go index 9640fc5..583cb02 100644 --- a/command_test.go +++ b/command_test.go @@ -42,6 +42,17 @@ func executeCommandC(root *Command, args ...string) (c *Command, output string, return c, buf.String(), err } +func executeCommandWithContextC(ctx context.Context, root *Command, args ...string) (c *Command, output string, err error) { + buf := new(bytes.Buffer) + root.SetOut(buf) + root.SetErr(buf) + root.SetArgs(args) + + c, err = root.ExecuteContextC(ctx) + + return c, buf.String(), err +} + func resetCommandLineFlagSet() { pflag.CommandLine = pflag.NewFlagSet(os.Args[0], pflag.ExitOnError) } @@ -178,6 +189,35 @@ func TestExecuteContext(t *testing.T) { } } +func TestExecuteContextC(t *testing.T) { + ctx := context.TODO() + + ctxRun := func(cmd *Command, args []string) { + if cmd.Context() != ctx { + t.Errorf("Command %q must have context when called with ExecuteContext", cmd.Use) + } + } + + rootCmd := &Command{Use: "root", Run: ctxRun, PreRun: ctxRun} + childCmd := &Command{Use: "child", Run: ctxRun, PreRun: ctxRun} + granchildCmd := &Command{Use: "grandchild", Run: ctxRun, PreRun: ctxRun} + + childCmd.AddCommand(granchildCmd) + rootCmd.AddCommand(childCmd) + + if _, _, err := executeCommandWithContextC(ctx, rootCmd, ""); err != nil { + t.Errorf("Root command must not fail: %+v", err) + } + + if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child"); err != nil { + t.Errorf("Subcommand must not fail: %+v", err) + } + + if _, _, err := executeCommandWithContextC(ctx, rootCmd, "child", "grandchild"); err != nil { + t.Errorf("Command child must not fail: %+v", err) + } +} + func TestExecute_NoContext(t *testing.T) { run := func(cmd *Command, args []string) { if cmd.Context() != context.Background() { diff --git a/completions.go b/completions.go index fea2c6f..28d7dd0 100644 --- a/completions.go +++ b/completions.go @@ -221,6 +221,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Unable to find the real command. E.g., someInvalidCmd return c, []string{}, ShellCompDirectiveDefault, fmt.Errorf("Unable to find a command for arguments: %v", trimmedArgs) } + finalCmd.ctx = c.ctx // Check if we are doing flag value completion before parsing the flags. // This is important because if we are completing a flag value, we need to also diff --git a/completions_test.go b/completions_test.go index 603c409..3e16bd0 100644 --- a/completions_test.go +++ b/completions_test.go @@ -2,6 +2,7 @@ package cobra import ( "bytes" + "context" "strings" "testing" ) @@ -1203,6 +1204,48 @@ func TestFlagDirFilterCompletionInGo(t *testing.T) { } } +func TestValidArgsFuncCmdContext(t *testing.T) { + validArgsFunc := func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + ctx := cmd.Context() + + if ctx == nil { + t.Error("Received nil context in completion func") + } else if ctx.Value("testKey") != "123" { + t.Error("Received invalid context") + } + + return nil, ShellCompDirectiveDefault + } + + rootCmd := &Command{ + Use: "root", + Run: emptyRun, + } + childCmd := &Command{ + Use: "childCmd", + ValidArgsFunction: validArgsFunc, + Run: emptyRun, + } + rootCmd.AddCommand(childCmd) + + //nolint:golint,staticcheck // We can safely use a basic type as key in tests. + ctx := context.WithValue(context.Background(), "testKey", "123") + + // Test completing an empty string on the childCmd + _, output, err := executeCommandWithContextC(ctx, rootCmd, ShellCompNoDescRequestCmd, "childCmd", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := strings.Join([]string{ + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } +} + func TestValidArgsFuncSingleCmd(t *testing.T) { rootCmd := &Command{ Use: "root",