Create and utilize mergePersistentFlags method

This commit is contained in:
spf13 2013-09-10 18:26:17 -04:00
parent ed6206272d
commit 061ba30a84

View File

@ -19,7 +19,7 @@ package cobra
import (
"bytes"
"fmt"
flag "github.com/ogier/pflag"
flag "github.com/spf13/pflag"
"os"
"strings"
)
@ -252,7 +252,7 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
}
c.pflags.SetOutput(c.flagErrorBuf)
}
return c.flags
return c.pflags
}
// Intended for use in testing
@ -265,22 +265,11 @@ func (c *Command) ResetFlags() {
}
func (c *Command) HasFlags() bool {
return hasFlags(c.flags)
return c.Flags().HasFlags()
}
func (c *Command) HasPersistentFlags() bool {
return hasFlags(c.pflags)
}
// Is this set of flags not empty
func hasFlags(f *flag.FlagSet) bool {
if f == nil {
return false
}
if f.NFlag() != 0 {
return true
}
return false
return c.PersistentFlags().HasFlags()
}
// Climbs up the command tree looking for matching flag
@ -308,10 +297,8 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
// Parses persistent flag tree & local flags
func (c *Command) ParseFlags(args []string) (err error) {
err = c.ParsePersistentFlags(args)
if err != nil {
return err
}
c.mergePersistentFlags()
err = c.Flags().Parse(args)
if err != nil {
return err
@ -319,6 +306,25 @@ func (c *Command) ParseFlags(args []string) (err error) {
return nil
}
func (c *Command) mergePersistentFlags() {
var rmerge func(x *Command)
rmerge = func(x *Command) {
if x.HasPersistentFlags() {
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
if c.Flags().Lookup(f.Name) == nil {
c.Flags().AddFlag(f)
}
})
}
if x.HasParent() {
rmerge(x.parent)
}
}
rmerge(c)
}
// Climbs up the command tree parsing flags from top to bottom
func (c *Command) ParsePersistentFlags(args []string) (err error) {
if !c.HasParent() || (c.parent.HasPersistentFlags() && c.parent.PersistentFlags().Parsed()) {