Update spf13/cobra and spf13/pflag

This commit is contained in:
Eric Paris 2015-04-29 11:11:43 -04:00
parent 8fa21ebd62
commit b7217a33ab
9 changed files with 237 additions and 62 deletions

4
Godeps/Godeps.json generated
View File

@ -484,11 +484,11 @@
}, },
{ {
"ImportPath": "github.com/spf13/cobra", "ImportPath": "github.com/spf13/cobra",
"Rev": "9cb5e8502924a8ff1cce18a9348b61995d7b4fde" "Rev": "69e5f196b5d30673deb07a2221d89cf62e4b74ae"
}, },
{ {
"ImportPath": "github.com/spf13/pflag", "ImportPath": "github.com/spf13/pflag",
"Rev": "60d4c375939ff7ba397a84117d5281256abb298f" "Rev": "d4ebabf889f7b016ffcb5f74c350e1e5424b2094"
}, },
{ {
"ImportPath": "github.com/stretchr/objx", "ImportPath": "github.com/stretchr/objx",

View File

@ -145,9 +145,9 @@ A flag can also be assigned locally which will only apply to that specific comma
### Remove a command from its parent ### Remove a command from its parent
Removing a command is not a common action is simple program but it allows 3rd parties to customize an existing command tree. Removing a command is not a common action in simple programs but it allows 3rd parties to customize an existing command tree.
In this exemple, we remove the existing `VersionCmd` command of an existing root command, and we replace it by our own version. In this example, we remove the existing `VersionCmd` command of an existing root command, and we replace it by our own version.
mainlib.RootCmd.RemoveCommand(mainlib.VersionCmd) mainlib.RootCmd.RemoveCommand(mainlib.VersionCmd)
mainlib.RootCmd.AddCommand(versionCmd) mainlib.RootCmd.AddCommand(versionCmd)

View File

@ -483,6 +483,32 @@ func TestInvalidSubCommandFlags(t *testing.T) {
} }
func TestSubCommandArgEvaluation(t *testing.T) {
cmd := initializeWithRootCmd()
first := &Command{
Use: "first",
Run: func(cmd *Command, args []string) {
},
}
cmd.AddCommand(first)
second := &Command{
Use: "second",
Run: func(cmd *Command, args []string) {
fmt.Fprintf(cmd.Out(), "%v", args)
},
}
first.AddCommand(second)
result := simpleTester(cmd, "first second first third")
expectedOutput := fmt.Sprintf("%v", []string{"first third"})
if result.Output != expectedOutput {
t.Errorf("exptected %v, got %v", expectedOutput, result.Output)
}
}
func TestPersistentFlags(t *testing.T) { func TestPersistentFlags(t *testing.T) {
fullSetupTest("echo -s something -p more here") fullSetupTest("echo -s something -p more here")

View File

@ -328,15 +328,18 @@ func stripFlags(args []string, c *Command) []string {
return commands return commands
} }
func argsMinusX(args []string, x string) []string { // argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
newargs := []string{} // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
func argsMinusFirstX(args []string, x string) []string {
for _, y := range args { for i, y := range args {
if x != y { if x == y {
newargs = append(newargs, y) ret := []string{}
ret = append(ret, args[:i]...)
ret = append(ret, args[i+1:]...)
return ret
} }
} }
return newargs return args
} }
// find the target command given the args and command tree // find the target command given the args and command tree
@ -359,7 +362,7 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) {
matches := make([]*Command, 0) matches := make([]*Command, 0)
for _, cmd := range c.commands { for _, cmd := range c.commands {
if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match
return innerfind(cmd, argsMinusX(args, argsWOflags[0])) return innerfind(cmd, argsMinusFirstX(args, argsWOflags[0]))
} else if EnablePrefixMatching { } else if EnablePrefixMatching {
if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match
matches = append(matches, cmd) matches = append(matches, cmd)
@ -374,7 +377,7 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) {
// only accept a single prefix match - multiple matches would be ambiguous // only accept a single prefix match - multiple matches would be ambiguous
if len(matches) == 1 { if len(matches) == 1 {
return innerfind(matches[0], argsMinusX(args, argsWOflags[0])) return innerfind(matches[0], argsMinusFirstX(args, argsWOflags[0]))
} }
} }
} }

View File

@ -34,7 +34,7 @@ func printOptions(out *bytes.Buffer, cmd *Command, name string) {
parentFlags := cmd.InheritedFlags() parentFlags := cmd.InheritedFlags()
parentFlags.SetOutput(out) parentFlags.SetOutput(out)
if parentFlags.HasFlags() { if parentFlags.HasFlags() {
fmt.Fprintf(out, "### Options inherrited from parent commands\n\n```\n") fmt.Fprintf(out, "### Options inherited from parent commands\n\n```\n")
parentFlags.PrintDefaults() parentFlags.PrintDefaults()
fmt.Fprintf(out, "```\n\n") fmt.Fprintf(out, "```\n\n")
} }

View File

@ -59,6 +59,4 @@ func TestGenMdDoc(t *testing.T) {
if !strings.Contains(found, expected) { if !strings.Contains(found, expected) {
t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found) t.Errorf("Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found)
} }
fmt.Fprintf(os.Stdout, "%s\n", found)
} }

View File

@ -156,7 +156,8 @@ func TestImplicitFalse(t *testing.T) {
func TestInvalidValue(t *testing.T) { func TestInvalidValue(t *testing.T) {
var tristate triStateValue var tristate triStateValue
f := setUpFlagSet(&tristate) f := setUpFlagSet(&tristate)
err := f.Parse([]string{"--tristate=invalid"}) args := []string{"--tristate=invalid"}
_, err := parseReturnStderr(t, f, args)
if err == nil { if err == nil {
t.Fatal("expected an error but did not get any, tristate has value", tristate) t.Fatal("expected an error but did not get any, tristate has value", tristate)
} }

View File

@ -120,9 +120,9 @@ const (
PanicOnError PanicOnError
) )
// normalizedName is a flag name that has been normalized according to rules // NormalizedName is a flag name that has been normalized according to rules
// for the FlagSet (e.g. making '-' and '_' equivalent). // for the FlagSet (e.g. making '-' and '_' equivalent).
type normalizedName string type NormalizedName string
// A FlagSet represents a set of defined flags. // A FlagSet represents a set of defined flags.
type FlagSet struct { type FlagSet struct {
@ -131,17 +131,17 @@ type FlagSet struct {
// a custom error handler. // a custom error handler.
Usage func() Usage func()
name string name string
parsed bool parsed bool
actual map[normalizedName]*Flag actual map[NormalizedName]*Flag
formal map[normalizedName]*Flag formal map[NormalizedName]*Flag
shorthands map[byte]*Flag shorthands map[byte]*Flag
args []string // arguments after flags args []string // arguments after flags
exitOnError bool // does the program exit if there's an error? exitOnError bool // does the program exit if there's an error?
errorHandling ErrorHandling errorHandling ErrorHandling
output io.Writer // nil means stderr; use out() accessor output io.Writer // nil means stderr; use out() accessor
interspersed bool // allow interspersed option/non-option args interspersed bool // allow interspersed option/non-option args
wordSeparators []string normalizeNameFunc func(f *FlagSet, name string) NormalizedName
} }
// A Flag represents the state of a flag. // A Flag represents the state of a flag.
@ -152,6 +152,7 @@ type Flag struct {
Value Value // value as set Value Value // value as set
DefValue string // default value (as text); for usage message DefValue string // default value (as text); for usage message
Changed bool // If the user set the value (or if left to default) Changed bool // If the user set the value (or if left to default)
Deprecated string // If this flag is deprecated, this string is the new or now thing to use
Annotations map[string][]string // used by cobra.Command bash autocomple code Annotations map[string][]string // used by cobra.Command bash autocomple code
} }
@ -164,7 +165,7 @@ type Value interface {
} }
// sortFlags returns the flags as a slice in lexicographical sorted order. // sortFlags returns the flags as a slice in lexicographical sorted order.
func sortFlags(flags map[normalizedName]*Flag) []*Flag { func sortFlags(flags map[NormalizedName]*Flag) []*Flag {
list := make(sort.StringSlice, len(flags)) list := make(sort.StringSlice, len(flags))
i := 0 i := 0
for k := range flags { for k := range flags {
@ -174,18 +175,29 @@ func sortFlags(flags map[normalizedName]*Flag) []*Flag {
list.Sort() list.Sort()
result := make([]*Flag, len(list)) result := make([]*Flag, len(list))
for i, name := range list { for i, name := range list {
result[i] = flags[normalizedName(name)] result[i] = flags[NormalizedName(name)]
} }
return result return result
} }
func (f *FlagSet) normalizeFlagName(name string) normalizedName { func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) {
result := name f.normalizeNameFunc = n
for _, sep := range f.wordSeparators { for k, v := range f.formal {
result = strings.Replace(result, sep, "-", -1) delete(f.formal, k)
f.formal[f.normalizeFlagName(string(k))] = v
} }
// Type convert to indicate normalization has been done. }
return normalizedName(result)
func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName {
if f.normalizeNameFunc != nil {
return f.normalizeNameFunc
}
return func(f *FlagSet, name string) NormalizedName { return NormalizedName(name) }
}
func (f *FlagSet) normalizeFlagName(name string) NormalizedName {
n := f.GetNormalizeFunc()
return n(f, name)
} }
func (f *FlagSet) out() io.Writer { func (f *FlagSet) out() io.Writer {
@ -239,10 +251,20 @@ func (f *FlagSet) Lookup(name string) *Flag {
} }
// lookup returns the Flag structure of the named flag, returning nil if none exists. // lookup returns the Flag structure of the named flag, returning nil if none exists.
func (f *FlagSet) lookup(name normalizedName) *Flag { func (f *FlagSet) lookup(name NormalizedName) *Flag {
return f.formal[name] return f.formal[name]
} }
// Mark a flag deprecated in your program
func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error {
flag := f.Lookup(name)
if flag == nil {
return fmt.Errorf("flag %q does not exist", name)
}
flag.Deprecated = usageMessage
return nil
}
// Lookup returns the Flag structure of the named command-line flag, // Lookup returns the Flag structure of the named command-line flag,
// returning nil if none exists. // returning nil if none exists.
func Lookup(name string) *Flag { func Lookup(name string) *Flag {
@ -261,10 +283,13 @@ func (f *FlagSet) Set(name, value string) error {
return err return err
} }
if f.actual == nil { if f.actual == nil {
f.actual = make(map[normalizedName]*Flag) f.actual = make(map[NormalizedName]*Flag)
} }
f.actual[normalName] = flag f.actual[normalName] = flag
f.lookup(normalName).Changed = true flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
}
return nil return nil
} }
@ -277,6 +302,9 @@ func Set(name, value string) error {
// otherwise, the default values of all defined flags in the set. // otherwise, the default values of all defined flags in the set.
func (f *FlagSet) PrintDefaults() { func (f *FlagSet) PrintDefaults() {
f.VisitAll(func(flag *Flag) { f.VisitAll(func(flag *Flag) {
if len(flag.Deprecated) > 0 {
return
}
format := "--%s=%s: %s\n" format := "--%s=%s: %s\n"
if _, ok := flag.Value.(*stringValue); ok { if _, ok := flag.Value.(*stringValue); ok {
// put quotes on the value // put quotes on the value
@ -295,6 +323,9 @@ func (f *FlagSet) FlagUsages() string {
x := new(bytes.Buffer) x := new(bytes.Buffer)
f.VisitAll(func(flag *Flag) { f.VisitAll(func(flag *Flag) {
if len(flag.Deprecated) > 0 {
return
}
format := "--%s=%s: %s\n" format := "--%s=%s: %s\n"
if _, ok := flag.Value.(*stringValue); ok { if _, ok := flag.Value.(*stringValue); ok {
// put quotes on the value // put quotes on the value
@ -397,7 +428,7 @@ func (f *FlagSet) AddFlag(flag *Flag) {
panic(msg) // Happens only if flags are declared with identical names panic(msg) // Happens only if flags are declared with identical names
} }
if f.formal == nil { if f.formal == nil {
f.formal = make(map[normalizedName]*Flag) f.formal = make(map[NormalizedName]*Flag)
} }
f.formal[f.normalizeFlagName(flag.Name)] = flag f.formal[f.normalizeFlagName(flag.Name)] = flag
@ -462,10 +493,13 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
} }
// mark as visited for Visit() // mark as visited for Visit()
if f.actual == nil { if f.actual == nil {
f.actual = make(map[normalizedName]*Flag) f.actual = make(map[NormalizedName]*Flag)
} }
f.actual[f.normalizeFlagName(flag.Name)] = flag f.actual[f.normalizeFlagName(flag.Name)] = flag
flag.Changed = true flag.Changed = true
if len(flag.Deprecated) > 0 {
fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated)
}
return nil return nil
} }
@ -625,19 +659,6 @@ func SetInterspersed(interspersed bool) {
CommandLine.SetInterspersed(interspersed) CommandLine.SetInterspersed(interspersed)
} }
// SetWordSeparators sets a list of strings to be considerered as word
// separators and normalized for the pruposes of lookups. For example, if this
// is set to {"-", "_", "."} then --foo_bar, --foo-bar, and --foo.bar are
// considered equivalent flags. This must be called before flags are parsed,
// and may only be called once.
func (f *FlagSet) SetWordSeparators(separators []string) {
f.wordSeparators = separators
for k, v := range f.formal {
delete(f.formal, k)
f.formal[f.normalizeFlagName(string(k))] = v
}
}
// Parsed returns true if the command-line flags have been parsed. // Parsed returns true if the command-line flags have been parsed.
func Parsed() bool { func Parsed() bool {
return CommandLine.Parsed() return CommandLine.Parsed()

View File

@ -7,6 +7,7 @@ package pflag_test
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"sort" "sort"
@ -238,14 +239,29 @@ func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t) testParse(NewFlagSet("test", ContinueOnError), t)
} }
func testNormalizedNames(args []string, t *testing.T) { func replaceSeparators(name string, from []string, to string) string {
result := name
for _, sep := range from {
result = strings.Replace(result, sep, to, -1)
}
// Type convert to indicate normalization has been done.
return result
}
func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"}
name = replaceSeparators(name, seps, ".")
return NormalizedName(name)
}
func testWordSepNormalizedNames(args []string, t *testing.T) {
f := NewFlagSet("normalized", ContinueOnError) f := NewFlagSet("normalized", ContinueOnError)
if f.Parsed() { if f.Parsed() {
t.Error("f.Parse() = true before Parse") t.Error("f.Parse() = true before Parse")
} }
withDashFlag := f.Bool("with-dash-flag", false, "bool value") withDashFlag := f.Bool("with-dash-flag", false, "bool value")
// Set this after some flags have been added and before others. // Set this after some flags have been added and before others.
f.SetWordSeparators([]string{"-", "_"}) f.SetNormalizeFunc(wordSepNormalizeFunc)
withUnderFlag := f.Bool("with_under_flag", false, "bool value") withUnderFlag := f.Bool("with_under_flag", false, "bool value")
withBothFlag := f.Bool("with-both_flag", false, "bool value") withBothFlag := f.Bool("with-both_flag", false, "bool value")
if err := f.Parse(args); err != nil { if err := f.Parse(args); err != nil {
@ -265,27 +281,66 @@ func testNormalizedNames(args []string, t *testing.T) {
} }
} }
func TestNormalizedNames(t *testing.T) { func TestWordSepNormalizedNames(t *testing.T) {
args := []string{ args := []string{
"--with-dash-flag", "--with-dash-flag",
"--with-under-flag", "--with-under-flag",
"--with-both-flag", "--with-both-flag",
} }
testNormalizedNames(args, t) testWordSepNormalizedNames(args, t)
args = []string{ args = []string{
"--with_dash_flag", "--with_dash_flag",
"--with_under_flag", "--with_under_flag",
"--with_both_flag", "--with_both_flag",
} }
testNormalizedNames(args, t) testWordSepNormalizedNames(args, t)
args = []string{ args = []string{
"--with-dash_flag", "--with-dash_flag",
"--with-under_flag", "--with-under_flag",
"--with-both_flag", "--with-both_flag",
} }
testNormalizedNames(args, t) testWordSepNormalizedNames(args, t)
}
func aliasAndWordSepFlagNames(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"}
oldName := replaceSeparators("old-valid_flag", seps, ".")
newName := replaceSeparators("valid-flag", seps, ".")
name = replaceSeparators(name, seps, ".")
switch name {
case oldName:
name = newName
break
}
return NormalizedName(name)
}
func TestCustomNormalizedNames(t *testing.T) {
f := NewFlagSet("normalized", ContinueOnError)
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
validFlag := f.Bool("valid-flag", false, "bool value")
f.SetNormalizeFunc(aliasAndWordSepFlagNames)
someOtherFlag := f.Bool("some-other-flag", false, "bool value")
args := []string{"--old_valid_flag", "--some-other_flag"}
if err := f.Parse(args); err != nil {
t.Fatal(err)
}
if *validFlag != true {
t.Errorf("validFlag is %v even though we set the alias --old_valid_falg", *validFlag)
}
if *someOtherFlag != true {
t.Error("someOtherFlag should be true, is ", *someOtherFlag)
}
} }
// Declare a user-defined flag type. // Declare a user-defined flag type.
@ -445,3 +500,74 @@ func TestTermination(t *testing.T) {
t.Errorf("expected argument %q got %q", arg2, f.Args()[1]) t.Errorf("expected argument %q got %q", arg2, f.Args()[1])
} }
} }
func TestDeprecatedFlagInDocs(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("badflag", true, "always true")
f.MarkDeprecated("badflag", "use --good-flag instead")
out := new(bytes.Buffer)
f.SetOutput(out)
f.PrintDefaults()
if strings.Contains(out.String(), "badflag") {
t.Errorf("found deprecated flag in usage!")
}
}
func parseReturnStderr(t *testing.T, f *FlagSet, args []string) (string, error) {
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
err := f.Parse(args)
outC := make(chan string)
// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()
w.Close()
os.Stderr = oldStderr
out := <-outC
return out, err
}
func TestDeprecatedFlagUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("badflag", true, "always true")
usageMsg := "use --good-flag instead"
f.MarkDeprecated("badflag", usageMsg)
args := []string{"--badflag"}
out, err := parseReturnStderr(t, f, args)
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !strings.Contains(out, usageMsg) {
t.Errorf("usageMsg not printed when using a deprecated flag!")
}
}
func TestDeprecatedFlagUsageNormalized(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("bad-double_flag", true, "always true")
f.SetNormalizeFunc(wordSepNormalizeFunc)
usageMsg := "use --good-flag instead"
f.MarkDeprecated("bad_double-flag", usageMsg)
args := []string{"--bad_double_flag"}
out, err := parseReturnStderr(t, f, args)
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !strings.Contains(out, usageMsg) {
t.Errorf("usageMsg not printed when using a deprecated flag!")
}
}