Update the cobra and pflag libraries with upstream fixes.

This commit is contained in:
Brendan Burns 2015-03-13 14:21:42 -07:00
parent 232796db04
commit f98cd70c9c
6 changed files with 173 additions and 89 deletions

4
Godeps/Godeps.json generated
View File

@ -357,11 +357,11 @@
}, },
{ {
"ImportPath": "github.com/spf13/cobra", "ImportPath": "github.com/spf13/cobra",
"Rev": "f8e1ec56bdd7494d309c69681267859a6bfb7549" "Rev": "9e7273d5469dd5e04a35fd8823ba510117448c0b"
}, },
{ {
"ImportPath": "github.com/spf13/pflag", "ImportPath": "github.com/spf13/pflag",
"Rev": "370c3171201099fa6b466db45c8a032cbce33d8d" "Rev": "11b7cf8387a31f278486eaad758162830eca8c73"
}, },
{ {
"ImportPath": "github.com/stretchr/objx", "ImportPath": "github.com/stretchr/objx",

View File

@ -11,11 +11,14 @@ var _ = fmt.Println
var tp, te, tt, t1 []string var tp, te, tt, t1 []string
var flagb1, flagb2, flagb3, flagbr bool var flagb1, flagb2, flagb3, flagbr bool
var flags1, flags2, flags3 string var flags1, flags2a, flags2b, flags3 string
var flagi1, flagi2, flagi3, flagir int var flagi1, flagi2, flagi3, flagir int
var globalFlag1 bool var globalFlag1 bool
var flagEcho, rootcalled bool var flagEcho, rootcalled bool
const strtwoParentHelp = "help message for parent flag strtwo"
const strtwoChildHelp = "help message for child flag strtwo"
var cmdPrint = &Command{ var cmdPrint = &Command{
Use: "print [string to print]", Use: "print [string to print]",
Short: "Print anything to the screen", Short: "Print anything to the screen",
@ -72,11 +75,12 @@ func flagInit() {
cmdRootNoRun.ResetFlags() cmdRootNoRun.ResetFlags()
cmdRootSameName.ResetFlags() cmdRootSameName.ResetFlags()
cmdRootWithRun.ResetFlags() cmdRootWithRun.ResetFlags()
cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp)
cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone")
cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo")
cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree")
cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone")
cmdTimes.PersistentFlags().StringVarP(&flags2, "strtwo", "t", "two", "help message for flag strtwo") cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp)
cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree")
cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone")
cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo")
@ -377,10 +381,21 @@ func TestChildCommandFlags(t *testing.T) {
t.Errorf("invalid flag should generate error") t.Errorf("invalid flag should generate error")
} }
if !strings.Contains(r.Output, "intone=123") { if !strings.Contains(r.Output, "unknown shorthand flag") {
t.Errorf("Wrong error message displayed, \n %s", r.Output) t.Errorf("Wrong error message displayed, \n %s", r.Output)
} }
// Testing with persistent flag overwritten by child
noRRSetupTest("echo times --strtwo=child one two")
if flags2b != "child" {
t.Errorf("flag value should be child, %s given", flags2b)
}
if flags2a != "two" {
t.Errorf("unset flag should have default value, expecting two, given %s", flags2a)
}
// Testing flag with invalid input // Testing flag with invalid input
r = noRRSetupTest("echo -i10E") r = noRRSetupTest("echo -i10E")
@ -437,6 +452,13 @@ func TestHelpCommand(t *testing.T) {
checkResultContains(t, r, cmdTimes.Long) checkResultContains(t, r, cmdTimes.Long)
} }
func TestChildCommandHelp(t *testing.T) {
c := noRRSetupTest("print --help")
checkResultContains(t, c, strtwoParentHelp)
r := noRRSetupTest("echo times --help")
checkResultContains(t, r, strtwoChildHelp)
}
func TestRunnableRootCommand(t *testing.T) { func TestRunnableRootCommand(t *testing.T) {
fullSetupTest("") fullSetupTest("")
@ -486,6 +508,26 @@ func TestRootHelp(t *testing.T) {
} }
func TestFlagAccess(t *testing.T) {
initialize()
local := cmdTimes.LocalFlags()
inherited := cmdTimes.InheritedFlags()
for _, f := range []string{"inttwo", "strtwo", "booltwo"} {
if local.Lookup(f) == nil {
t.Errorf("LocalFlags expected to contain %s, Got: nil", f)
}
}
if inherited.Lookup("strone") == nil {
t.Errorf("InheritedFlags expected to contain strone, Got: nil")
}
if inherited.Lookup("strtwo") != nil {
t.Errorf("InheritedFlags shouldn not contain overwritten flag strtwo")
}
}
func TestRootNoCommandHelp(t *testing.T) { func TestRootNoCommandHelp(t *testing.T) {
x := rootOnlySetupTest("--help") x := rootOnlySetupTest("--help")

View File

@ -46,6 +46,8 @@ type Command struct {
flags *flag.FlagSet flags *flag.FlagSet
// Set of flags childrens of this command will inherit // Set of flags childrens of this command will inherit
pflags *flag.FlagSet pflags *flag.FlagSet
// Flags that are declared specifically by this command (not inherited).
lflags *flag.FlagSet
// Run runs the command. // Run runs the command.
// The args are the arguments after the command name. // The args are the arguments after the command name.
Run func(cmd *Command, args []string) Run func(cmd *Command, args []string)
@ -218,8 +220,8 @@ Available Commands: {{range .Commands}}{{if .Runnable}}
{{end}} {{end}}
{{ if .HasLocalFlags}}Flags: {{ if .HasLocalFlags}}Flags:
{{.LocalFlags.FlagUsages}}{{end}} {{.LocalFlags.FlagUsages}}{{end}}
{{ if .HasAnyPersistentFlags}}Global Flags: {{ if .HasInheritedFlags}}Global Flags:
{{.AllPersistentFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }} {{.InheritedFlags.FlagUsages}}{{end}}{{if .HasParent}}{{if and (gt .Commands 0) (gt .Parent.Commands 1) }}
Additional help topics: {{if gt .Commands 0 }}{{range .Commands}}{{if not .Runnable}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if gt .Parent.Commands 1 }}{{range .Parent.Commands}}{{if .Runnable}}{{if not (eq .Name $cmd.Name) }}{{end}} Additional help topics: {{if gt .Commands 0 }}{{range .Commands}}{{if not .Runnable}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if gt .Parent.Commands 1 }}{{range .Parent.Commands}}{{if .Runnable}}{{if not (eq .Name $cmd.Name) }}{{end}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{end}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{end}}
{{end}}{{ if .HasSubCommands }} {{end}}{{ if .HasSubCommands }}
@ -249,7 +251,27 @@ func (c *Command) resetChildrensParents() {
} }
} }
func stripFlags(args []string) []string { // Test if the named flag is a boolean flag.
func isBooleanFlag(name string, f *flag.FlagSet) bool {
flag := f.Lookup(name)
if flag == nil {
return false
}
return flag.Value.Type() == "bool"
}
// Test if the named flag is a boolean flag.
func isBooleanShortFlag(name string, f *flag.FlagSet) bool {
result := false
f.VisitAll(func(f *flag.Flag) {
if f.Shorthand == name && f.Value.Type() == "bool" {
result = true
}
})
return result
}
func stripFlags(args []string, c *Command) []string {
if len(args) < 1 { if len(args) < 1 {
return args return args
} }
@ -257,6 +279,7 @@ func stripFlags(args []string) []string {
commands := []string{} commands := []string{}
inQuote := false inQuote := false
inFlag := false
for _, y := range args { for _, y := range args {
if !inQuote { if !inQuote {
switch { switch {
@ -264,8 +287,16 @@ func stripFlags(args []string) []string {
inQuote = true inQuote = true
case strings.Contains(y, "=\""): case strings.Contains(y, "=\""):
inQuote = true inQuote = true
case strings.HasPrefix(y, "--") && !strings.Contains(y, "="):
// TODO: this isn't quite right, we should really check ahead for 'true' or 'false'
inFlag = !isBooleanFlag(y[2:], c.Flags())
case strings.HasPrefix(y, "-") && !strings.Contains(y, "=") && len(y) == 2 && !isBooleanShortFlag(y[1:], c.Flags()):
inFlag = true
case inFlag:
inFlag = false
case !strings.HasPrefix(y, "-"): case !strings.HasPrefix(y, "-"):
commands = append(commands, y) commands = append(commands, y)
inFlag = false
} }
} }
@ -303,7 +334,7 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) {
innerfind = func(c *Command, args []string) (*Command, []string) { innerfind = func(c *Command, args []string) (*Command, []string) {
if len(args) > 0 && c.HasSubCommands() { if len(args) > 0 && c.HasSubCommands() {
argsWOflags := stripFlags(args) argsWOflags := stripFlags(args, c)
if len(argsWOflags) > 0 { if len(argsWOflags) > 0 {
matches := make([]*Command, 0) matches := make([]*Command, 0)
for _, cmd := range c.commands { for _, cmd := range c.commands {
@ -372,7 +403,10 @@ func (c *Command) execute(a []string) (err error) {
} }
err = c.ParseFlags(a) err = c.ParseFlags(a)
if err == flag.ErrHelp {
c.Help()
return nil
}
if err != nil { if err != nil {
// We're writing subcommand usage to root command's error buffer to have it displayed to the user // We're writing subcommand usage to root command's error buffer to have it displayed to the user
r := c.Root() r := c.Root()
@ -460,14 +494,15 @@ func (c *Command) Execute() (err error) {
if e != nil { if e != nil {
// Flags parsing had an error. // Flags parsing had an error.
// If an error happens here, we have to report it to the user // If an error happens here, we have to report it to the user
c.Println(c.errorMsgFromParse()) c.Println(e.Error())
// If an error happens search also for subcommand info about that // If an error happens search also for subcommand info about that
if c.cmdErrorBuf != nil && c.cmdErrorBuf.Len() > 0 { if c.cmdErrorBuf != nil && c.cmdErrorBuf.Len() > 0 {
c.Println(c.cmdErrorBuf.String()) c.Println(c.cmdErrorBuf.String())
} else { } else {
c.Usage() c.Usage()
} }
return e err = e
return
} else { } else {
// If help is called, regardless of other flags, we print that // If help is called, regardless of other flags, we print that
if c.helpFlagVal { if c.helpFlagVal {
@ -491,10 +526,14 @@ func (c *Command) Execute() (err error) {
} }
if err != nil { if err != nil {
if err == flag.ErrHelp {
c.Help()
} else {
c.Println("Error:", err.Error()) c.Println("Error:", err.Error())
c.Printf("%v: invalid command %#q\n", c.Root().Name(), os.Args[1:]) c.Printf("%v: invalid command %#q\n", c.Root().Name(), os.Args[1:])
c.Printf("Run '%v help' for usage\n", c.Root().Name()) c.Printf("Run '%v help' for usage\n", c.Root().Name())
} }
}
return return
} }
@ -553,6 +592,40 @@ func (c *Command) AddCommand(cmds ...*Command) {
} }
} }
// AddCommand removes one or more commands from a parent command.
func (c *Command) RemoveCommand(cmds ...*Command) {
commands := []*Command{}
main:
for _, command := range c.commands {
for _, cmd := range cmds {
if command == cmd {
command.parent = nil
continue main
}
}
commands = append(commands, command)
}
c.commands = commands
// recompute all lengths
c.commandsMaxUseLen = 0
c.commandsMaxCommandPathLen = 0
c.commandsMaxNameLen = 0
for _, command := range c.commands {
usageLen := len(command.Use)
if usageLen > c.commandsMaxUseLen {
c.commandsMaxUseLen = usageLen
}
commandPathLen := len(command.CommandPath())
if commandPathLen > c.commandsMaxCommandPathLen {
c.commandsMaxCommandPathLen = commandPathLen
}
nameLen := len(command.Name())
if nameLen > c.commandsMaxNameLen {
c.commandsMaxNameLen = nameLen
}
}
}
// Convenience method to Print to the defined output // Convenience method to Print to the defined output
func (c *Command) Print(i ...interface{}) { func (c *Command) Print(i ...interface{}) {
fmt.Fprint(c.Out(), i...) fmt.Fprint(c.Out(), i...)
@ -726,14 +799,9 @@ func (c *Command) LocalFlags() *flag.FlagSet {
c.mergePersistentFlags() c.mergePersistentFlags()
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) local := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
allPersistent := c.AllPersistentFlags() c.lflags.VisitAll(func(f *flag.Flag) {
c.Flags().VisitAll(func(f *flag.Flag) {
if allPersistent.Lookup(f.Name) == nil {
local.AddFlag(f) local.AddFlag(f)
}
}) })
return local return local
} }
@ -741,15 +809,16 @@ func (c *Command) LocalFlags() *flag.FlagSet {
func (c *Command) InheritedFlags() *flag.FlagSet { func (c *Command) InheritedFlags() *flag.FlagSet {
c.mergePersistentFlags() c.mergePersistentFlags()
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError) inherited := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
local := c.LocalFlags()
var rmerge func(x *Command) var rmerge func(x *Command)
rmerge = func(x *Command) { rmerge = func(x *Command) {
if x.HasPersistentFlags() { if x.HasPersistentFlags() {
x.PersistentFlags().VisitAll(func(f *flag.Flag) { x.PersistentFlags().VisitAll(func(f *flag.Flag) {
if local.Lookup(f.Name) == nil { if inherited.Lookup(f.Name) == nil && local.Lookup(f.Name) == nil {
local.AddFlag(f) inherited.AddFlag(f)
} }
}) })
} }
@ -762,23 +831,12 @@ func (c *Command) InheritedFlags() *flag.FlagSet {
rmerge(c.parent) rmerge(c.parent)
} }
return local return inherited
} }
// All Flags which were not inherited from parent commands // All Flags which were not inherited from parent commands
func (c *Command) NonInheritedFlags() *flag.FlagSet { func (c *Command) NonInheritedFlags() *flag.FlagSet {
c.mergePersistentFlags() return c.LocalFlags()
local := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
inheritedFlags := c.InheritedFlags()
c.Flags().VisitAll(func(f *flag.Flag) {
if inheritedFlags.Lookup(f.Name) == nil {
local.AddFlag(f)
}
})
return local
} }
// Get the Persistent FlagSet specifically set in the current command // Get the Persistent FlagSet specifically set in the current command
@ -793,29 +851,6 @@ func (c *Command) PersistentFlags() *flag.FlagSet {
return c.pflags return c.pflags
} }
// Get the Persistent FlagSet traversing the Command hierarchy
func (c *Command) AllPersistentFlags() *flag.FlagSet {
allPersistent := flag.NewFlagSet(c.Name(), flag.ContinueOnError)
var visit func(x *Command)
visit = func(x *Command) {
if x.HasPersistentFlags() {
x.PersistentFlags().VisitAll(func(f *flag.Flag) {
if allPersistent.Lookup(f.Name) == nil {
allPersistent.AddFlag(f)
}
})
}
if x.HasParent() {
visit(x.parent)
}
}
visit(c)
return allPersistent
}
// For use in testing // For use in testing
func (c *Command) ResetFlags() { func (c *Command) ResetFlags() {
c.flagErrorBuf = new(bytes.Buffer) c.flagErrorBuf = new(bytes.Buffer)
@ -836,16 +871,15 @@ func (c *Command) HasPersistentFlags() bool {
return c.PersistentFlags().HasFlags() return c.PersistentFlags().HasFlags()
} }
// Does the command hierarchy contain persistent flags
func (c *Command) HasAnyPersistentFlags() bool {
return c.AllPersistentFlags().HasFlags()
}
// Does the command has flags specifically declared locally // Does the command has flags specifically declared locally
func (c *Command) HasLocalFlags() bool { func (c *Command) HasLocalFlags() bool {
return c.LocalFlags().HasFlags() return c.LocalFlags().HasFlags()
} }
func (c *Command) HasInheritedFlags() bool {
return c.InheritedFlags().HasFlags()
}
// Climbs up the command tree looking for matching flag // Climbs up the command tree looking for matching flag
func (c *Command) Flag(name string) (flag *flag.Flag) { func (c *Command) Flag(name string) (flag *flag.Flag) {
flag = c.Flags().Lookup(name) flag = c.Flags().Lookup(name)
@ -873,16 +907,7 @@ func (c *Command) persistentFlag(name string) (flag *flag.Flag) {
func (c *Command) ParseFlags(args []string) (err error) { func (c *Command) ParseFlags(args []string) (err error) {
c.mergePersistentFlags() c.mergePersistentFlags()
err = c.Flags().Parse(args) err = c.Flags().Parse(args)
return
// The upstream library adds spaces to the error
// response regardless of success.
// Handling it here until fixing upstream
if len(strings.TrimSpace(c.flagErrorBuf.String())) > 1 {
return fmt.Errorf("%s", c.flagErrorBuf.String())
}
//always return nil because upstream library is inconsistent & we always check the error buffer anyway
return nil
} }
func (c *Command) Parent() *Command { func (c *Command) Parent() *Command {
@ -892,6 +917,19 @@ func (c *Command) Parent() *Command {
func (c *Command) mergePersistentFlags() { func (c *Command) mergePersistentFlags() {
var rmerge func(x *Command) var rmerge func(x *Command)
// Save the set of local flags
if c.lflags == nil {
c.lflags = flag.NewFlagSet(c.Name(), flag.ContinueOnError)
if c.flagErrorBuf == nil {
c.flagErrorBuf = new(bytes.Buffer)
}
c.lflags.SetOutput(c.flagErrorBuf)
addtolocal := func(f *flag.Flag) {
c.lflags.AddFlag(f)
}
c.Flags().VisitAll(addtolocal)
c.PersistentFlags().VisitAll(addtolocal)
}
rmerge = func(x *Command) { rmerge = func(x *Command) {
if x.HasPersistentFlags() { if x.HasPersistentFlags() {
x.PersistentFlags().VisitAll(func(f *flag.Flag) { x.PersistentFlags().VisitAll(func(f *flag.Flag) {

View File

@ -9,7 +9,7 @@ import (
"strconv" "strconv"
"testing" "testing"
. "github.com/ogier/pflag" . "github.com/spf13/pflag"
) )
// This value can be a boolean ("true", "false") or "maybe" // This value can be a boolean ("true", "false") or "maybe"

View File

@ -11,7 +11,7 @@ import (
"strings" "strings"
"time" "time"
flag "github.com/ogier/pflag" flag "github.com/spf13/pflag"
) )
// Example 1: A single string flag called "species" with default value "gopher". // Example 1: A single string flag called "species" with default value "gopher".

View File

@ -498,6 +498,7 @@ func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error)
if len(args) == 0 { if len(args) == 0 {
return return
} }
return
} }
if alreadythere { if alreadythere {
if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() { if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() {
@ -551,6 +552,9 @@ func (f *FlagSet) parseArgs(args []string) (err error) {
} else { } else {
args, err = f.parseShortArg(s, args) args, err = f.parseShortArg(s, args)
} }
if err != nil {
return
}
} }
return return
} }