diff --git a/commands/command.go b/commands/command.go index afbe650ec..1f15fdab4 100644 --- a/commands/command.go +++ b/commands/command.go @@ -3,7 +3,6 @@ package commands import ( "errors" "fmt" - "io" "reflect" "strings" @@ -195,9 +194,6 @@ func (c *Command) GetOptions(path []string) (map[string]Option, error) { } func (c *Command) CheckArguments(req Request) error { - - // TODO: check file arguments - args := req.Arguments() // count required argument definitions @@ -217,13 +213,14 @@ func (c *Command) CheckArguments(req Request) error { } // the value for this argument definition. can be nil if it wasn't provided by the caller - var v interface{} + v, found := "", false if valueIndex < len(args) { v = args[valueIndex] + found = true valueIndex++ } - err := checkArgValue(v, argDef) + err := checkArgValue(v, found, argDef) if err != nil { return err } @@ -231,7 +228,7 @@ func (c *Command) CheckArguments(req Request) error { // any additional values are for the variadic arg definition if argDef.Variadic && valueIndex < len(args)-1 { for _, val := range args[valueIndex:] { - err := checkArgValue(val, argDef) + err := checkArgValue(val, true, argDef) if err != nil { return err } @@ -248,26 +245,9 @@ func (c *Command) Subcommand(id string) *Command { } // checkArgValue returns an error if a given arg value is not valid for the given Argument -func checkArgValue(v interface{}, def Argument) error { - if v == nil { - if def.Required { - return fmt.Errorf("Argument '%s' is required", def.Name) - } - - return nil - } - - if def.Type == ArgFile { - _, ok := v.(io.Reader) - if !ok { - return fmt.Errorf("Argument '%s' isn't valid", def.Name) - } - - } else if def.Type == ArgString { - _, ok := v.(string) - if !ok { - return fmt.Errorf("Argument '%s' must be a string", def.Name) - } +func checkArgValue(v string, found bool, def Argument) error { + if !found && def.Required { + return fmt.Errorf("Argument '%s' is required", def.Name) } return nil