updated cobra and pflag

This commit is contained in:
Anastasis Andronidis 2015-05-20 21:55:29 +02:00
parent ee82d469c6
commit a1ea3df0f1
10 changed files with 331 additions and 100 deletions

4
Godeps/Godeps.json generated
View File

@ -413,11 +413,11 @@
}, },
{ {
"ImportPath": "github.com/spf13/cobra", "ImportPath": "github.com/spf13/cobra",
"Rev": "bba56042cf767e329430e7c7f68c3f9f640b4b8b" "Rev": "8f5946caaeeff40a98d67f60c25e89c3525038a3"
}, },
{ {
"ImportPath": "github.com/spf13/pflag", "ImportPath": "github.com/spf13/pflag",
"Rev": "d4ebabf889f7b016ffcb5f74c350e1e5424b2094" "Rev": "b91b2a94780f4e6b4d3b0c12fd9b5f4b05b1aa45"
}, },
{ {
"ImportPath": "github.com/stretchr/objx", "ImportPath": "github.com/stretchr/objx",

View File

@ -4,15 +4,18 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"os" "os"
"reflect"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"github.com/spf13/pflag"
) )
var _ = fmt.Println var _ = fmt.Println
var _ = os.Stderr var _ = os.Stderr
var tp, te, tt, t1 []string var tp, te, tt, t1, tr []string
var rootPersPre, echoPre, echoPersPre, timesPersPre []string var rootPersPre, echoPre, echoPersPre, timesPersPre []string
var flagb1, flagb2, flagb3, flagbr, flagbp bool var flagb1, flagb2, flagb3, flagbr, flagbp bool
var flags1, flags2a, flags2b, flags3 string var flags1, flags2a, flags2b, flags3 string
@ -99,6 +102,7 @@ var cmdRootWithRun = &Command{
Short: "The root can run it's own function", Short: "The root can run it's own function",
Long: "The root description for help", Long: "The root description for help",
Run: func(cmd *Command, args []string) { Run: func(cmd *Command, args []string) {
tr = args
rootcalled = true rootcalled = true
}, },
} }
@ -181,7 +185,7 @@ func initializeWithSameName() *Command {
func initializeWithRootCmd() *Command { func initializeWithRootCmd() *Command {
cmdRootWithRun.ResetCommands() cmdRootWithRun.ResetCommands()
tt, tp, te, rootcalled = nil, nil, nil, false tt, tp, te, tr, rootcalled = nil, nil, nil, nil, false
flagInit() flagInit()
cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot") cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot")
cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot") cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot")
@ -481,7 +485,7 @@ func TestChildCommandFlags(t *testing.T) {
t.Errorf("invalid input should generate error") t.Errorf("invalid input should generate error")
} }
if !strings.Contains(r.Output, "invalid argument \"10E\" for -i10E") { if !strings.Contains(r.Output, "invalid argument \"10E\" for i10E") {
t.Errorf("Wrong error message displayed, \n %s", r.Output) t.Errorf("Wrong error message displayed, \n %s", r.Output)
} }
} }
@ -494,7 +498,7 @@ func TestTrailingCommandFlags(t *testing.T) {
} }
} }
func TestInvalidSubCommandFlags(t *testing.T) { func TestInvalidSubcommandFlags(t *testing.T) {
cmd := initializeWithRootCmd() cmd := initializeWithRootCmd()
cmd.AddCommand(cmdTimes) cmd.AddCommand(cmdTimes)
@ -508,7 +512,7 @@ func TestInvalidSubCommandFlags(t *testing.T) {
} }
func TestSubCommandArgEvaluation(t *testing.T) { func TestSubcommandArgEvaluation(t *testing.T) {
cmd := initializeWithRootCmd() cmd := initializeWithRootCmd()
first := &Command{ first := &Command{
@ -777,7 +781,7 @@ func TestFlagsBeforeCommand(t *testing.T) {
// With parsing error properly reported // With parsing error properly reported
x = fullSetupTest("-i10E echo") x = fullSetupTest("-i10E echo")
if !strings.Contains(x.Output, "invalid argument \"10E\" for -i10E") { if !strings.Contains(x.Output, "invalid argument \"10E\" for i10E") {
t.Errorf("Wrong error message displayed, \n %s", x.Output) t.Errorf("Wrong error message displayed, \n %s", x.Output)
} }
@ -819,6 +823,31 @@ func TestRemoveCommand(t *testing.T) {
} }
} }
func TestCommandWithoutSubcommands(t *testing.T) {
c := initializeWithRootCmd()
x := simpleTester(c, "")
if x.Error != nil {
t.Errorf("Calling command without subcommands should not have error: %v", x.Error)
return
}
}
func TestCommandWithoutSubcommandsWithArg(t *testing.T) {
c := initializeWithRootCmd()
expectedArgs := []string{"arg"}
x := simpleTester(c, "arg")
if x.Error != nil {
t.Errorf("Calling command without subcommands but with arg should not have error: %v", x.Error)
return
}
if !reflect.DeepEqual(expectedArgs, tr) {
t.Errorf("Calling command without subcommands but with arg has wrong args: expected: %v, actual: %v", expectedArgs, tr)
return
}
}
func TestReplaceCommandWithRemove(t *testing.T) { func TestReplaceCommandWithRemove(t *testing.T) {
versionUsed = 0 versionUsed = 0
c := initializeWithRootCmd() c := initializeWithRootCmd()
@ -886,3 +915,28 @@ func TestPeristentPreRunPropagation(t *testing.T) {
t.Error("RootCmd PersistentPreRun not called but should have been") t.Error("RootCmd PersistentPreRun not called but should have been")
} }
} }
func TestGlobalNormFuncPropagation(t *testing.T) {
normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {
return pflag.NormalizedName(name)
}
rootCmd := initialize()
rootCmd.SetGlobalNormalizationFunc(normFunc)
if reflect.ValueOf(normFunc) != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()) {
t.Error("rootCmd seems to have a wrong normalization function")
}
// First add the cmdEchoSub to cmdPrint
cmdPrint.AddCommand(cmdEchoSub)
if cmdPrint.GlobalNormalizationFunc() != nil && cmdEchoSub.GlobalNormalizationFunc() != nil {
t.Error("cmdPrint and cmdEchoSub should had no normalization functions")
}
// Now add cmdPrint to rootCmd
rootCmd.AddCommand(cmdPrint)
if reflect.ValueOf(cmdPrint.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() ||
reflect.ValueOf(cmdEchoSub.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() {
t.Error("cmdPrint and cmdEchoSub should had the normalization function of rootCmd")
}
}

View File

@ -18,13 +18,14 @@ package cobra
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/inconshreveable/mousetrap"
flag "github.com/spf13/pflag"
"io" "io"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"time" "time"
"github.com/inconshreveable/mousetrap"
flag "github.com/spf13/pflag"
) )
// Command is just that, a command for your application. // Command is just that, a command for your application.
@ -93,6 +94,8 @@ type Command struct {
helpFunc func(*Command, []string) // Help can be defined by application helpFunc func(*Command, []string) // Help can be defined by application
helpCommand *Command // The help command helpCommand *Command // The help command
helpFlagVal bool helpFlagVal bool
// The global normalization function that we can use on every pFlag set and children commands
globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName
} }
// os.Args[1:] by default, if desired, can be overridden // os.Args[1:] by default, if desired, can be overridden
@ -151,6 +154,19 @@ func (c *Command) SetHelpTemplate(s string) {
c.helpTemplate = s c.helpTemplate = s
} }
// SetGlobalNormalizationFunc sets a normalization function to all flag sets and also to child commands.
// The user should not have a cyclic dependency on commands.
func (c *Command) SetGlobalNormalizationFunc(n func(f *flag.FlagSet, name string) flag.NormalizedName) {
c.Flags().SetNormalizeFunc(n)
c.PersistentFlags().SetNormalizeFunc(n)
c.LocalFlags().SetNormalizeFunc(n)
c.globNormFunc = n
for _, command := range c.commands {
command.SetGlobalNormalizationFunc(n)
}
}
func (c *Command) UsageFunc() (f func(*Command) error) { func (c *Command) UsageFunc() (f func(*Command) error) {
if c.usageFunc != nil { if c.usageFunc != nil {
return c.usageFunc return c.usageFunc
@ -360,25 +376,28 @@ func argsMinusFirstX(args []string, x string) []string {
// find the target command given the args and command tree // find the target command given the args and command tree
// Meant to be run on the highest node. Only searches down. // Meant to be run on the highest node. Only searches down.
func (c *Command) Find(arrs []string) (*Command, []string, error) { func (c *Command) Find(args []string) (*Command, []string, error) {
if c == nil { if c == nil {
return nil, nil, fmt.Errorf("Called find() on a nil Command") return nil, nil, fmt.Errorf("Called find() on a nil Command")
} }
if len(arrs) == 0 { // If there are no arguments, return the root command. If the root has no
return c.Root(), arrs, nil // subcommands, args reflects arguments that should actually be passed to
// the root command, so also return the root command.
if len(args) == 0 || !c.Root().HasSubCommands() {
return c.Root(), args, nil
} }
var innerfind func(*Command, []string) (*Command, []string) var innerfind func(*Command, []string) (*Command, []string)
innerfind = func(c *Command, args []string) (*Command, []string) { innerfind = func(c *Command, innerArgs []string) (*Command, []string) {
if len(args) > 0 && c.HasSubCommands() { if len(innerArgs) > 0 && c.HasSubCommands() {
argsWOflags := stripFlags(args, c) argsWOflags := stripFlags(innerArgs, c)
if len(argsWOflags) > 0 { if len(argsWOflags) > 0 {
matches := make([]*Command, 0) matches := make([]*Command, 0)
for _, cmd := range c.commands { for _, cmd := range c.commands {
if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match
return innerfind(cmd, argsMinusFirstX(args, argsWOflags[0])) return innerfind(cmd, argsMinusFirstX(innerArgs, argsWOflags[0]))
} else if EnablePrefixMatching { } else if EnablePrefixMatching {
if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match
matches = append(matches, cmd) matches = append(matches, cmd)
@ -393,18 +412,18 @@ func (c *Command) Find(arrs []string) (*Command, []string, error) {
// only accept a single prefix match - multiple matches would be ambiguous // only accept a single prefix match - multiple matches would be ambiguous
if len(matches) == 1 { if len(matches) == 1 {
return innerfind(matches[0], argsMinusFirstX(args, argsWOflags[0])) return innerfind(matches[0], argsMinusFirstX(innerArgs, argsWOflags[0]))
} }
} }
} }
return c, args return c, innerArgs
} }
commandFound, a := innerfind(c, arrs) commandFound, a := innerfind(c, args)
// If we matched on the root, but we asked for a subcommand, return an error // If we matched on the root, but we asked for a subcommand, return an error
if commandFound.Name() == c.Name() && len(stripFlags(arrs, c)) > 0 && commandFound.Name() != arrs[0] { if commandFound.Name() == c.Name() && len(stripFlags(args, c)) > 0 && commandFound.Name() != args[0] {
return nil, a, fmt.Errorf("unknown command %q", a[0]) return nil, a, fmt.Errorf("unknown command %q", a[0])
} }
@ -606,6 +625,10 @@ func (c *Command) AddCommand(cmds ...*Command) {
if nameLen > c.commandsMaxNameLen { if nameLen > c.commandsMaxNameLen {
c.commandsMaxNameLen = nameLen c.commandsMaxNameLen = nameLen
} }
// If glabal normalization function exists, update all children
if c.globNormFunc != nil {
x.SetGlobalNormalizationFunc(c.globNormFunc)
}
c.commands = append(c.commands, x) c.commands = append(c.commands, x)
} }
} }
@ -830,6 +853,11 @@ func (c *Command) HasParent() bool {
return c.parent != nil return c.parent != nil
} }
// GlobalNormalizationFunc returns the global normalization function or nil if doesn't exists
func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) flag.NormalizedName {
return c.globNormFunc
}
// Get the complete FlagSet that applies to this command (local and persistent declared here and by all parents) // Get the complete FlagSet that applies to this command (local and persistent declared here and by all parents)
func (c *Command) Flags() *flag.FlagSet { func (c *Command) Flags() *flag.FlagSet {
if c.flags == nil { if c.flags == nil {

View File

@ -47,6 +47,10 @@ func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() }
func GenMarkdown(cmd *Command, out *bytes.Buffer) { func GenMarkdown(cmd *Command, out *bytes.Buffer) {
GenMarkdownCustom(cmd, out, func(s string) string { return s })
}
func GenMarkdownCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) {
name := cmd.CommandPath() name := cmd.CommandPath()
short := cmd.Short short := cmd.Short
@ -78,7 +82,7 @@ func GenMarkdown(cmd *Command, out *bytes.Buffer) {
pname := parent.CommandPath() pname := parent.CommandPath()
link := pname + ".md" link := pname + ".md"
link = strings.Replace(link, " ", "_", -1) link = strings.Replace(link, " ", "_", -1)
fmt.Fprintf(out, "* [%s](%s)\t - %s\n", pname, link, parent.Short) fmt.Fprintf(out, "* [%s](%s)\t - %s\n", pname, linkHandler(link), parent.Short)
} }
children := cmd.Commands() children := cmd.Commands()
@ -91,7 +95,7 @@ func GenMarkdown(cmd *Command, out *bytes.Buffer) {
cname := name + " " + child.Name() cname := name + " " + child.Name()
link := cname + ".md" link := cname + ".md"
link = strings.Replace(link, " ", "_", -1) link = strings.Replace(link, " ", "_", -1)
fmt.Fprintf(out, "* [%s](%s)\t - %s\n", cname, link, child.Short) fmt.Fprintf(out, "* [%s](%s)\t - %s\n", cname, linkHandler(link), child.Short)
} }
fmt.Fprintf(out, "\n") fmt.Fprintf(out, "\n")
} }
@ -100,13 +104,18 @@ func GenMarkdown(cmd *Command, out *bytes.Buffer) {
} }
func GenMarkdownTree(cmd *Command, dir string) { func GenMarkdownTree(cmd *Command, dir string) {
for _, c := range cmd.Commands() { identity := func(s string) string { return s }
GenMarkdownTree(c, dir) emptyStr := func(s string) string { return "" }
} GenMarkdownTreeCustom(cmd, dir, emptyStr, identity)
}
func GenMarkdownTreeCustom(cmd *Command, dir string, filePrepender func(string) string, linkHandler func(string) string) {
for _, c := range cmd.Commands() {
GenMarkdownTreeCustom(c, dir, filePrepender, linkHandler)
}
out := new(bytes.Buffer) out := new(bytes.Buffer)
GenMarkdown(cmd, out) GenMarkdownCustom(cmd, out, linkHandler)
filename := cmd.CommandPath() filename := cmd.CommandPath()
filename = dir + strings.Replace(filename, " ", "_", -1) + ".md" filename = dir + strings.Replace(filename, " ", "_", -1) + ".md"
@ -116,6 +125,11 @@ func GenMarkdownTree(cmd *Command, dir string) {
os.Exit(1) os.Exit(1)
} }
defer outFile.Close() defer outFile.Close()
_, err = outFile.WriteString(filePrepender(filename))
if err != nil {
fmt.Println(err)
os.Exit(1)
}
_, err = outFile.Write(out.Bytes()) _, err = outFile.Write(out.Bytes())
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)

View File

@ -25,7 +25,7 @@ This will generate a whole series of files, one for each command in the tree, in
## Generate markdown docs for a single command ## Generate markdown docs for a single command
You may wish to have more control over the output, or only generate for a single command, instead of the entire command tree. If this is the case you may prefer to `GenMarkdown()` instead of `GenMarkdownTree` You may wish to have more control over the output, or only generate for a single command, instead of the entire command tree. If this is the case you may prefer to `GenMarkdown` instead of `GenMarkdownTree`
```go ```go
out := new(bytes.Buffer) out := new(bytes.Buffer)
@ -33,3 +33,49 @@ You may wish to have more control over the output, or only generate for a single
``` ```
This will write the markdown doc for ONLY "cmd" into the out, buffer. This will write the markdown doc for ONLY "cmd" into the out, buffer.
## Customize the output
Both `GenMarkdown` and `GenMarkdownTree` have alternate versions with callbacks to get some control of the output:
```go
func GenMarkdownTreeCustom(cmd *Command, dir string, filePrepender func(string) string, linkHandler func(string) string) {
//...
}
```
```go
func GenMarkdownCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) {
//...
}
```
The `filePrepender` will prepend the return value given the full filepath to the rendered Markdown file. A common use case is to add front matter to use the generated documentation with [Hugo](http://gohugo.io/):
```go
const fmTemplate = `---
date: %s
title: "%s"
slug: %s
url: %s
---
`
filePrepender := func(filename string) string {
now := time.Now().Format(time.RFC3339)
name := filepath.Base(filename)
base := strings.TrimSuffix(name, path.Ext(name))
url := "/commands/" + strings.ToLower(base) + "/"
return fmt.Sprintf(fmTemplate, now, strings.Replace(base, "_", " ", -1), base, url)
}
```
The `linkHandler` can be used to customize the rendered internal links to the commands, given a filename:
```go
linkHandler := func(name string) string {
base := strings.TrimSuffix(name, path.Ext(name))
return "/commands/" + strings.ToLower(base) + "/"
}
```

View File

@ -0,0 +1,8 @@
sudo: false
language: go
go:
- 1.3
- 1.4
- tip

View File

@ -1,3 +1,5 @@
[![Build Status](https://travis-ci.org/spf13/pflag.svg?branch=master)](https://travis-ci.org/spf13/pflag)
## Description ## Description
pflag is a drop-in replacement for Go's flag package, implementing pflag is a drop-in replacement for Go's flag package, implementing
@ -143,6 +145,40 @@ Boolean flags (in their long form) accept 1, 0, t, f, true, false,
TRUE, FALSE, True, False. TRUE, FALSE, True, False.
Duration flags accept any input valid for time.ParseDuration. Duration flags accept any input valid for time.ParseDuration.
## Mutating or "Normalizing" Flag names
It is possible to set a custom flag name 'normalization function.' It allows flag names to be mutated both when created in the code and when used on the command line to some 'normalized' form. The 'normalized' form is used for comparison. Two examples of using the custom normalization func follow.
**Example #1**: You want -, _, and . in flags to compare the same. aka --my-flag == --my_flag == --my.flag
```go
func wordSepNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName {
from := []string{"-", "_"}
to := "."
for _, sep := range from {
name = strings.Replace(name, sep, to, -1)
}
return pflag.NormalizedName(name)
}
myFlagSet.SetNormalizeFunc(wordSepNormalizeFunc)
```
**Example #2**: You want to alias two flags. aka --old-flag-name == --new-flag-name
```go
func aliasNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName {
switch name {
case "old-flag-name":
name = "new-flag-name"
break
}
return pflag.NormalizedName(name)
}
myFlagSet.SetNormalizeFunc(aliasNormalizeFunc)
```
## More info ## More info
You can see the full reference documentation of the pflag package You can see the full reference documentation of the pflag package

View File

@ -2,14 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package pflag_test package pflag
import ( import (
"bytes"
"fmt" "fmt"
"strconv" "strconv"
"testing" "testing"
. "github.com/spf13/pflag"
) )
// This value can be a boolean ("true", "false") or "maybe" // This value can be a boolean ("true", "false") or "maybe"
@ -156,8 +155,9 @@ func TestImplicitFalse(t *testing.T) {
func TestInvalidValue(t *testing.T) { func TestInvalidValue(t *testing.T) {
var tristate triStateValue var tristate triStateValue
f := setUpFlagSet(&tristate) f := setUpFlagSet(&tristate)
args := []string{"--tristate=invalid"} var buf bytes.Buffer
_, err := parseReturnStderr(t, f, args) f.SetOutput(&buf)
err := f.Parse([]string{"--tristate=invalid"})
if err == nil { if err == nil {
t.Fatal("expected an error but did not get any, tristate has value", tristate) t.Fatal("expected an error but did not get any, tristate has value", tristate)
} }

View File

@ -184,7 +184,9 @@ func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedNam
f.normalizeNameFunc = n f.normalizeNameFunc = n
for k, v := range f.formal { for k, v := range f.formal {
delete(f.formal, k) delete(f.formal, k)
f.formal[f.normalizeFlagName(string(k))] = v nname := f.normalizeFlagName(string(k))
f.formal[nname] = v
v.Name = string(nname)
} }
} }
@ -421,7 +423,10 @@ func (f *FlagSet) VarP(value Value, name, shorthand, usage string) {
} }
func (f *FlagSet) AddFlag(flag *Flag) { func (f *FlagSet) AddFlag(flag *Flag) {
_, alreadythere := f.formal[f.normalizeFlagName(flag.Name)] // Call normalizeFlagName function only once
var normalizedFlagName NormalizedName = f.normalizeFlagName(flag.Name)
_, alreadythere := f.formal[normalizedFlagName]
if alreadythere { if alreadythere {
msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name)
fmt.Fprintln(f.out(), msg) fmt.Fprintln(f.out(), msg)
@ -430,7 +435,9 @@ func (f *FlagSet) AddFlag(flag *Flag) {
if f.formal == nil { if f.formal == nil {
f.formal = make(map[NormalizedName]*Flag) f.formal = make(map[NormalizedName]*Flag)
} }
f.formal[f.normalizeFlagName(flag.Name)] = flag
flag.Name = string(normalizedFlagName)
f.formal[normalizedFlagName] = flag
if len(flag.Shorthand) == 0 { if len(flag.Shorthand) == 0 {
return return
@ -505,10 +512,6 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error {
func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) { func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) {
a = args a = args
if len(s) == 2 { // "--" terminates the flags
f.args = append(f.args, args...)
return
}
name := s[2:] name := s[2:]
if len(name) == 0 || name[0] == '-' || name[0] == '=' { if len(name) == 0 || name[0] == '-' || name[0] == '=' {
err = f.failf("bad flag syntax: %s", s) err = f.failf("bad flag syntax: %s", s)
@ -516,75 +519,74 @@ func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error)
} }
split := strings.SplitN(name, "=", 2) split := strings.SplitN(name, "=", 2)
name = split[0] name = split[0]
m := f.formal flag, alreadythere := f.formal[f.normalizeFlagName(name)]
flag, alreadythere := m[f.normalizeFlagName(name)] // BUG
if !alreadythere { if !alreadythere {
if name == "help" { // special case for nice help message. if name == "help" { // special case for nice help message.
f.usage() f.usage()
return args, ErrHelp return a, ErrHelp
} }
err = f.failf("unknown flag: --%s", name) err = f.failf("unknown flag: --%s", name)
return return
} }
var value string
if len(split) == 1 { if len(split) == 1 {
if bv, ok := flag.Value.(boolFlag); !ok || !bv.IsBoolFlag() { if bv, ok := flag.Value.(boolFlag); !ok || !bv.IsBoolFlag() {
err = f.failf("flag needs an argument: %s", s) err = f.failf("flag needs an argument: %s", s)
return return
} }
f.setFlag(flag, "true", s) value = "true"
} else { } else {
if e := f.setFlag(flag, split[1], s); e != nil { value = split[1]
err = e }
err = f.setFlag(flag, value, s)
return
}
func (f *FlagSet) parseSingleShortArg(shorthands string, args []string) (outShorts string, outArgs []string, err error) {
outArgs = args
outShorts = shorthands[1:]
c := shorthands[0]
flag, alreadythere := f.shorthands[c]
if !alreadythere {
if c == 'h' { // special case for nice help message.
f.usage()
err = ErrHelp
return return
} }
//TODO continue on error
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
return
} }
return args, nil var value string
if len(shorthands) > 2 && shorthands[1] == '=' {
value = shorthands[2:]
outShorts = ""
} else if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() {
value = "true"
} else if len(shorthands) > 1 {
value = shorthands[1:]
outShorts = ""
} else if len(args) > 0 {
value = args[0]
outArgs = args[1:]
} else {
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return
}
err = f.setFlag(flag, value, shorthands)
return
} }
func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) { func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) {
a = args a = args
shorthands := s[1:] shorthands := s[1:]
for i := 0; i < len(shorthands); i++ { for len(shorthands) > 0 {
c := shorthands[i] shorthands, a, err = f.parseSingleShortArg(shorthands, args)
flag, alreadythere := f.shorthands[c] if err != nil {
if !alreadythere {
if c == 'h' { // special case for nice help message.
f.usage()
err = ErrHelp
return
}
//TODO continue on error
err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands)
if len(args) == 0 {
return
}
return return
} }
if alreadythere {
if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() {
f.setFlag(flag, "true", s)
continue
}
if i < len(shorthands)-1 {
v := strings.TrimPrefix(shorthands[i+1:], "=")
if e := f.setFlag(flag, v, s); e != nil {
err = e
return
}
break
}
if len(args) == 0 {
err = f.failf("flag needs an argument: %q in -%s", c, shorthands)
return
}
if e := f.setFlag(flag, args[0], s); e != nil {
err = e
return
}
}
a = args[1:]
break // should be unnecessary
} }
return return
@ -605,12 +607,11 @@ func (f *FlagSet) parseArgs(args []string) (err error) {
} }
if s[1] == '-' { if s[1] == '-' {
args, err = f.parseLongArg(s, args) if len(s) == 2 { // "--" terminates the flags
f.args = append(f.args, args...)
if len(s) == 2 {
// stop parsing after --
break break
} }
args, err = f.parseLongArg(s, args)
} else { } else {
args, err = f.parseShortArg(s, args) args, err = f.parseShortArg(s, args)
} }

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package pflag_test package pflag
import ( import (
"bytes" "bytes"
@ -14,19 +14,18 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
. "github.com/spf13/pflag"
) )
var ( var (
test_bool = Bool("test_bool", false, "bool value") test_bool = Bool("test_bool", false, "bool value")
test_int = Int("test_int", 0, "int value") test_int = Int("test_int", 0, "int value")
test_int64 = Int64("test_int64", 0, "int64 value") test_int64 = Int64("test_int64", 0, "int64 value")
test_uint = Uint("test_uint", 0, "uint value") test_uint = Uint("test_uint", 0, "uint value")
test_uint64 = Uint64("test_uint64", 0, "uint64 value") test_uint64 = Uint64("test_uint64", 0, "uint64 value")
test_string = String("test_string", "0", "string value") test_string = String("test_string", "0", "string value")
test_float64 = Float64("test_float64", 0, "float64 value") test_float64 = Float64("test_float64", 0, "float64 value")
test_duration = Duration("test_duration", 0, "time.Duration value") test_duration = Duration("test_duration", 0, "time.Duration value")
normalizeFlagNameInvocations = 0
) )
func boolString(s string) string { func boolString(s string) string {
@ -186,6 +185,7 @@ func TestShorthand(t *testing.T) {
boolaFlag := f.BoolP("boola", "a", false, "bool value") boolaFlag := f.BoolP("boola", "a", false, "bool value")
boolbFlag := f.BoolP("boolb", "b", false, "bool2 value") boolbFlag := f.BoolP("boolb", "b", false, "bool2 value")
boolcFlag := f.BoolP("boolc", "c", false, "bool3 value") boolcFlag := f.BoolP("boolc", "c", false, "bool3 value")
booldFlag := f.BoolP("boold", "d", false, "bool4 value")
stringaFlag := f.StringP("stringa", "s", "0", "string value") stringaFlag := f.StringP("stringa", "s", "0", "string value")
stringzFlag := f.StringP("stringz", "z", "0", "string value") stringzFlag := f.StringP("stringz", "z", "0", "string value")
extra := "interspersed-argument" extra := "interspersed-argument"
@ -196,6 +196,7 @@ func TestShorthand(t *testing.T) {
"-cs", "-cs",
"hello", "hello",
"-z=something", "-z=something",
"-d=true",
"--", "--",
notaflag, notaflag,
} }
@ -215,6 +216,9 @@ func TestShorthand(t *testing.T) {
if *boolcFlag != true { if *boolcFlag != true {
t.Error("boolc flag should be true, is ", *boolcFlag) t.Error("boolc flag should be true, is ", *boolcFlag)
} }
if *booldFlag != true {
t.Error("boold flag should be true, is ", *booldFlag)
}
if *stringaFlag != "hello" { if *stringaFlag != "hello" {
t.Error("stringa flag should be `hello`, is ", *stringaFlag) t.Error("stringa flag should be `hello`, is ", *stringaFlag)
} }
@ -251,6 +255,8 @@ func replaceSeparators(name string, from []string, to string) string {
func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName { func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName {
seps := []string{"-", "_"} seps := []string{"-", "_"}
name = replaceSeparators(name, seps, ".") name = replaceSeparators(name, seps, ".")
normalizeFlagNameInvocations++
return NormalizedName(name) return NormalizedName(name)
} }
@ -343,6 +349,31 @@ func TestCustomNormalizedNames(t *testing.T) {
} }
} }
// Every flag we add, the name (displayed also in usage) should normalized
func TestNormalizationFuncShouldChangeFlagName(t *testing.T) {
// Test normalization after addition
f := NewFlagSet("normalized", ContinueOnError)
f.Bool("valid_flag", false, "bool value")
if f.Lookup("valid_flag").Name != "valid_flag" {
t.Error("The new flag should have the name 'valid_flag' instead of ", f.Lookup("valid_flag").Name)
}
f.SetNormalizeFunc(wordSepNormalizeFunc)
if f.Lookup("valid_flag").Name != "valid.flag" {
t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name)
}
// Test normalization before addition
f = NewFlagSet("normalized", ContinueOnError)
f.SetNormalizeFunc(wordSepNormalizeFunc)
f.Bool("valid_flag", false, "bool value")
if f.Lookup("valid_flag").Name != "valid.flag" {
t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name)
}
}
// Declare a user-defined flag type. // Declare a user-defined flag type.
type flagVar []string type flagVar []string
@ -571,3 +602,16 @@ func TestDeprecatedFlagUsageNormalized(t *testing.T) {
t.Errorf("usageMsg not printed when using a deprecated flag!") t.Errorf("usageMsg not printed when using a deprecated flag!")
} }
} }
// Name normalization function should be called only once on flag addition
func TestMultipleNormalizeFlagNameInvocations(t *testing.T) {
normalizeFlagNameInvocations = 0
f := NewFlagSet("normalized", ContinueOnError)
f.SetNormalizeFunc(wordSepNormalizeFunc)
f.Bool("with_under_flag", false, "bool value")
if normalizeFlagNameInvocations != 1 {
t.Fatal("Expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations)
}
}