enforce required flags (#502)
This commit is contained in:
@ -438,3 +438,51 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
|
||||
t.Fatalf("wrong command %q expected %q", c.Name(), sub.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequiredFlags(t *testing.T) {
|
||||
c := &Command{Use: "c", Run: func(*Command, []string) {}}
|
||||
output := new(bytes.Buffer)
|
||||
c.SetOutput(output)
|
||||
c.Flags().String("foo1", "", "required foo1")
|
||||
c.MarkFlagRequired("foo1")
|
||||
c.Flags().String("foo2", "", "required foo2")
|
||||
c.MarkFlagRequired("foo2")
|
||||
c.Flags().String("bar", "", "optional bar")
|
||||
|
||||
expected := fmt.Sprintf("Required flag(s) %q, %q have/has not been set", "foo1", "foo2")
|
||||
|
||||
if err := c.Execute(); err != nil {
|
||||
if err.Error() != expected {
|
||||
t.Errorf("expected %v, got %v", expected, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistentRequiredFlags(t *testing.T) {
|
||||
parent := &Command{Use: "parent", Run: func(*Command, []string) {}}
|
||||
output := new(bytes.Buffer)
|
||||
parent.SetOutput(output)
|
||||
parent.PersistentFlags().String("foo1", "", "required foo1")
|
||||
parent.MarkPersistentFlagRequired("foo1")
|
||||
parent.PersistentFlags().String("foo2", "", "required foo2")
|
||||
parent.MarkPersistentFlagRequired("foo2")
|
||||
parent.Flags().String("foo3", "", "optional foo3")
|
||||
|
||||
child := &Command{Use: "child", Run: func(*Command, []string) {}}
|
||||
child.Flags().String("bar1", "", "required bar1")
|
||||
child.MarkFlagRequired("bar1")
|
||||
child.Flags().String("bar2", "", "required bar2")
|
||||
child.MarkFlagRequired("bar2")
|
||||
child.Flags().String("bar3", "", "optional bar3")
|
||||
|
||||
parent.AddCommand(child)
|
||||
parent.SetArgs([]string{"child"})
|
||||
|
||||
expected := fmt.Sprintf("Required flag(s) %q, %q, %q, %q have/has not been set", "bar1", "bar2", "foo1", "foo2")
|
||||
|
||||
if err := parent.Execute(); err != nil {
|
||||
if err.Error() != expected {
|
||||
t.Errorf("expected %v, got %v", expected, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user