From bc6938dc0899c4f37b4afa1f7b32c8bc7cdfb618 Mon Sep 17 00:00:00 2001 From: Matt Bell Date: Mon, 3 Nov 2014 18:34:13 -0800 Subject: [PATCH] commands: Cleaned up argument validation --- commands/command.go | 75 +++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/commands/command.go b/commands/command.go index 4844f1a5f..2ecb6ee29 100644 --- a/commands/command.go +++ b/commands/command.go @@ -140,40 +140,37 @@ func (c *Command) GetOptions(path []string) (map[string]Option, error) { } func (c *Command) CheckArguments(req Request) error { - var argDef Argument args := req.Arguments() + argDefs := c.Arguments - var length int - if len(args) > len(c.Arguments) { - length = len(args) - } else { - length = len(c.Arguments) + // if we have more arg values provided than argument definitions, + // and the last arg definition is not variadic (or there are no definitions), return an error + notVariadic := len(argDefs) == 0 || !argDefs[len(argDefs)-1].Variadic + if notVariadic && len(args) > len(argDefs) { + return fmt.Errorf("Expected %v arguments, got %v", len(argDefs), len(args)) } - for i := 0; i < length; i++ { - var arg interface{} - if len(args) > i { - arg = args[i] + // iterate over the arg definitions + for i, argDef := range c.Arguments { + + // the value for this argument definition. can be nil if it wasn't provided by the caller + var v interface{} + if i < len(args) { + v = args[i] } - if i < len(c.Arguments) { - argDef = c.Arguments[i] - } else if !argDef.Variadic { - return fmt.Errorf("Expected %v arguments, got %v", len(c.Arguments), len(args)) + err := checkArgValue(v, argDef) + if err != nil { + return err } - if argDef.Required && arg == nil { - return fmt.Errorf("Argument '%s' is required", argDef.Name) - } - if argDef.Type == ArgFile { - _, ok := arg.(io.Reader) - if !ok { - return fmt.Errorf("Argument '%s' isn't valid", argDef.Name) - } - } else if argDef.Type == ArgString { - _, ok := arg.(string) - if !ok { - return fmt.Errorf("Argument '%s' must be a string", argDef.Name) + // any additional values are for the variadic arg definition + if argDef.Variadic && i < len(args)-1 { + for _, val := range args[i+1:] { + err := checkArgValue(val, argDef) + if err != nil { + return err + } } } } @@ -185,3 +182,29 @@ func (c *Command) CheckArguments(req Request) error { func (c *Command) Subcommand(id string) *Command { return c.Subcommands[id] } + +// checkArgValue returns an error if a given arg value is not valid for the given Argument +func checkArgValue(v interface{}, def Argument) error { + if v == nil { + if def.Required { + return fmt.Errorf("Argument '%s' is required", def.Name) + } + + return nil + } + + if def.Type == ArgFile { + _, ok := v.(io.Reader) + if !ok { + return fmt.Errorf("Argument '%s' isn't valid", def.Name) + } + + } else if def.Type == ArgString { + _, ok := v.(string) + if !ok { + return fmt.Errorf("Argument '%s' must be a string", def.Name) + } + } + + return nil +}