fix: don't remove flag value that matches subcommand name (#1781)

When the command searches args to find the arg matching a
particular subcommand name, it needs to ignore flag values,
as it is possible that the value for a flag might match
the name of the sub command.

This change improves argsMinusFirstX() to ignore flag values
when it searches for the X to exclude from the result.
This commit is contained in:
Brian Pursley 2022-11-07 23:12:02 -05:00 committed by GitHub
parent cc7e235fc2
commit 6b0bd3076c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 8 deletions

View File

@ -655,15 +655,39 @@ Loop:
// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like // argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
func argsMinusFirstX(args []string, x string) []string { // Special care needs to be taken not to remove a flag value.
for i, y := range args { func (c *Command) argsMinusFirstX(args []string, x string) []string {
if x == y { if len(args) == 0 {
return args
}
c.mergePersistentFlags()
flags := c.Flags()
Loop:
for pos := 0; pos < len(args); pos++ {
s := args[pos]
switch {
case s == "--":
// -- means we have reached the end of the parseable args. Break out of the loop now.
break Loop
case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags):
fallthrough
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
// over the next arg, because that is the value of this flag.
pos++
continue
case !strings.HasPrefix(s, "-"):
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
// return the args, excluding the one at this position.
if s == x {
ret := []string{} ret := []string{}
ret = append(ret, args[:i]...) ret = append(ret, args[:pos]...)
ret = append(ret, args[i+1:]...) ret = append(ret, args[pos+1:]...)
return ret return ret
} }
} }
}
return args return args
} }
@ -686,7 +710,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
cmd := c.findNext(nextSubCmd) cmd := c.findNext(nextSubCmd)
if cmd != nil { if cmd != nil {
return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
} }
return c, innerArgs return c, innerArgs
} }

View File

@ -2603,3 +2603,92 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {
checkStringContains(t, output, HelpFlag) checkStringContains(t, output, HelpFlag)
checkStringOmits(t, output, VersionFlag) checkStringOmits(t, output, VersionFlag)
} }
func TestFind(t *testing.T) {
var foo, bar string
root := &Command{
Use: "root",
}
root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "")
root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "")
child := &Command{
Use: "child",
}
root.AddCommand(child)
testCases := []struct {
args []string
expectedFoundArgs []string
}{
{
[]string{"child"},
[]string{},
},
{
[]string{"child", "child"},
[]string{"child"},
},
{
[]string{"child", "foo", "child", "bar", "child", "baz", "child"},
[]string{"foo", "child", "bar", "child", "baz", "child"},
},
{
[]string{"-f", "child", "child"},
[]string{"-f", "child"},
},
{
[]string{"child", "-f", "child"},
[]string{"-f", "child"},
},
{
[]string{"-b", "child", "child"},
[]string{"-b", "child"},
},
{
[]string{"child", "-b", "child"},
[]string{"-b", "child"},
},
{
[]string{"child", "-b"},
[]string{"-b"},
},
{
[]string{"-b", "-f", "child", "child"},
[]string{"-b", "-f", "child"},
},
{
[]string{"-f", "child", "-b", "something", "child"},
[]string{"-f", "child", "-b", "something"},
},
{
[]string{"-f", "child", "child", "-b"},
[]string{"-f", "child", "-b"},
},
{
[]string{"-f=child", "-b=something", "child"},
[]string{"-f=child", "-b=something"},
},
{
[]string{"--foo", "child", "--bar", "something", "child"},
[]string{"--foo", "child", "--bar", "something"},
},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) {
cmd, foundArgs, err := root.Find(tc.args)
if err != nil {
t.Fatal(err)
}
if cmd != child {
t.Fatal("Expected cmd to be child, but it was not")
}
if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) {
t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs)
}
})
}
}