diff --git a/cobra.go b/cobra.go index 79a433c..075b49e 100644 --- a/cobra.go +++ b/cobra.go @@ -19,7 +19,7 @@ package cobra import ( "bytes" "fmt" - flag "github.com/ogier/pflag" + flag "github.com/spf13/pflag" "os" "strings" ) @@ -252,7 +252,7 @@ func (c *Command) PersistentFlags() *flag.FlagSet { } c.pflags.SetOutput(c.flagErrorBuf) } - return c.flags + return c.pflags } // Intended for use in testing @@ -265,22 +265,11 @@ func (c *Command) ResetFlags() { } func (c *Command) HasFlags() bool { - return hasFlags(c.flags) + return c.Flags().HasFlags() } func (c *Command) HasPersistentFlags() bool { - return hasFlags(c.pflags) -} - -// Is this set of flags not empty -func hasFlags(f *flag.FlagSet) bool { - if f == nil { - return false - } - if f.NFlag() != 0 { - return true - } - return false + return c.PersistentFlags().HasFlags() } // Climbs up the command tree looking for matching flag @@ -308,10 +297,8 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) { // Parses persistent flag tree & local flags func (c *Command) ParseFlags(args []string) (err error) { - err = c.ParsePersistentFlags(args) - if err != nil { - return err - } + c.mergePersistentFlags() + err = c.Flags().Parse(args) if err != nil { return err @@ -319,6 +306,25 @@ func (c *Command) ParseFlags(args []string) (err error) { return nil } +func (c *Command) mergePersistentFlags() { + var rmerge func(x *Command) + + rmerge = func(x *Command) { + if x.HasPersistentFlags() { + x.PersistentFlags().VisitAll(func(f *flag.Flag) { + if c.Flags().Lookup(f.Name) == nil { + c.Flags().AddFlag(f) + } + }) + } + if x.HasParent() { + rmerge(x.parent) + } + } + + rmerge(c) +} + // Climbs up the command tree parsing flags from top to bottom func (c *Command) ParsePersistentFlags(args []string) (err error) { if !c.HasParent() || (c.parent.HasPersistentFlags() && c.parent.PersistentFlags().Parsed()) {