diff --git a/command.go b/command.go index fa9b9e5..96b7f35 100644 --- a/command.go +++ b/command.go @@ -251,7 +251,23 @@ func (c *Command) resetChildrensParents() { } } -func stripFlags(args []string) []string { +// Test if the named flag is a boolean flag. +func isBooleanFlag(name string, f *flag.FlagSet) bool { + flag := f.Lookup(name) + if flag == nil { + return false + } + return flag.Value.Type() == "bool" +} + +// Test if the named flag is a boolean flag. +func isBooleanShortFlag(name string, f *flag.FlagSet) bool { + result := false + f.VisitAll(func (f *flag.Flag) { if f.Shorthand == name && f.Value.Type() == "bool" { result = true } }) + return result +} + +func stripFlags(args []string, c *Command) []string { if len(args) < 1 { return args } @@ -259,6 +275,7 @@ func stripFlags(args []string) []string { commands := []string{} inQuote := false + inFlag := false for _, y := range args { if !inQuote { switch { @@ -266,8 +283,16 @@ func stripFlags(args []string) []string { inQuote = true case strings.Contains(y, "=\""): inQuote = true + case strings.HasPrefix(y, "--") && !strings.Contains(y, "="): + // TODO: this isn't quite right, we should really check ahead for 'true' or 'false' + inFlag = !isBooleanFlag(y[2:], c.Flags()) + case strings.HasPrefix(y, "-") && !strings.Contains(y, "=") && len(y) == 2 && !isBooleanShortFlag(y[1:], c.Flags()): + inFlag = true + case inFlag: + inFlag = false case !strings.HasPrefix(y, "-"): commands = append(commands, y) + inFlag = false } } @@ -305,7 +330,7 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) { innerfind = func(c *Command, args []string) (*Command, []string) { if len(args) > 0 && c.HasSubCommands() { - argsWOflags := stripFlags(args) + argsWOflags := stripFlags(args, c) if len(argsWOflags) > 0 { matches := make([]*Command, 0) for _, cmd := range c.commands { diff --git a/command_test.go b/command_test.go new file mode 100644 index 0000000..ae66522 --- /dev/null +++ b/command_test.go @@ -0,0 +1,81 @@ +package cobra + +import ( + "reflect" + "testing" +) + +func TestStripFlags(t *testing.T) { + tests := []struct { + input []string + output []string + }{ + { + []string{"foo", "bar"}, + []string{"foo", "bar"}, + }, + { + []string{"foo", "--bar", "-b"}, + []string{"foo"}, + }, + { + []string{"-b", "foo", "--bar", "bar"}, + []string{}, + }, + { + []string{"-i10", "echo"}, + []string{"echo"}, + }, + { + []string{"-i=10", "echo"}, + []string{"echo"}, + }, + { + []string{"--int=100", "echo"}, + []string{"echo"}, + }, + { + []string{"-ib", "echo", "-bfoo", "baz"}, + []string{"echo", "baz"}, + }, + { + []string{"-i=baz", "bar", "-i", "foo", "blah"}, + []string{"bar", "blah"}, + }, + { + []string{"--int=baz", "-bbar", "-i", "foo", "blah"}, + []string{"blah"}, + }, + { + []string{"--cat", "bar", "-i", "foo", "blah"}, + []string{"bar", "blah"}, + }, + { + []string{"-c", "bar", "-i", "foo", "blah"}, + []string{"bar", "blah"}, + }, + } + + cmdPrint := &Command{ + Use: "print [string to print]", + Short: "Print anything to the screen", + Long: `an utterly useless command for testing.`, + Run: func(cmd *Command, args []string) { + tp = args + }, + } + + var flagi int + var flagstr string + var flagbool bool + cmdPrint.Flags().IntVarP(&flagi, "int", "i", 345, "help message for flag int") + cmdPrint.Flags().StringVarP(&flagstr, "bar", "b", "bar", "help message for flag string") + cmdPrint.Flags().BoolVarP(&flagbool, "cat", "c", false, "help message for flag bool") + + for _, test := range tests { + output := stripFlags(test.input, cmdPrint) + if !reflect.DeepEqual(test.output, output) { + t.Errorf("expected: %v, got: %v", test.output, output) + } + } +}