diff --git a/commands/cli/parse.go b/commands/cli/parse.go index 2b985a681..b8a105dbc 100644 --- a/commands/cli/parse.go +++ b/commands/cli/parse.go @@ -45,7 +45,14 @@ func Parse(input []string, roots ...*cmds.Command) (cmds.Request, *cmds.Command, return nil, nil, err } - return cmds.NewRequest(path, opts, args, cmd), root, nil + req := cmds.NewRequest(path, opts, args, cmd) + + err = cmd.CheckArguments(req) + if err != nil { + return nil, nil, err + } + + return req, root, nil } // parsePath gets the command path from the command line input @@ -108,8 +115,6 @@ func parseOptions(input []string) (map[string]interface{}, []string, error) { return opts, args, nil } -// Note that the argument handling here is dumb, it does not do any error-checking. -// (Arguments are further processed when the request is passed to the command to run) func parseArgs(stringArgs []string, cmd *cmds.Command) ([]interface{}, error) { var argDef cmds.Argument args := make([]interface{}, len(stringArgs)) diff --git a/commands/command.go b/commands/command.go index ea3dc2e5d..4844f1a5f 100644 --- a/commands/command.go +++ b/commands/command.go @@ -3,6 +3,7 @@ package commands import ( "errors" "fmt" + "io" "strings" u "github.com/jbenet/go-ipfs/util" @@ -58,7 +59,7 @@ func (c *Command) Call(req Request) Response { return res } - err = req.CheckArguments(cmd.Arguments) + err = cmd.CheckArguments(req) if err != nil { res.SetError(err, ErrClient) return res @@ -138,6 +139,48 @@ func (c *Command) GetOptions(path []string) (map[string]Option, error) { return optionsMap, nil } +func (c *Command) CheckArguments(req Request) error { + var argDef Argument + args := req.Arguments() + + var length int + if len(args) > len(c.Arguments) { + length = len(args) + } else { + length = len(c.Arguments) + } + + for i := 0; i < length; i++ { + var arg interface{} + if len(args) > i { + arg = args[i] + } + + if i < len(c.Arguments) { + argDef = c.Arguments[i] + } else if !argDef.Variadic { + return fmt.Errorf("Expected %v arguments, got %v", len(c.Arguments), len(args)) + } + + if argDef.Required && arg == nil { + return fmt.Errorf("Argument '%s' is required", argDef.Name) + } + if argDef.Type == ArgFile { + _, ok := arg.(io.Reader) + if !ok { + return fmt.Errorf("Argument '%s' isn't valid", argDef.Name) + } + } else if argDef.Type == ArgString { + _, ok := arg.(string) + if !ok { + return fmt.Errorf("Argument '%s' must be a string", argDef.Name) + } + } + } + + return nil +} + // Subcommand returns the subcommand with the given id func (c *Command) Subcommand(id string) *Command { return c.Subcommands[id] diff --git a/commands/http/parse.go b/commands/http/parse.go index b30ef0bfd..bb2f8348a 100644 --- a/commands/http/parse.go +++ b/commands/http/parse.go @@ -51,7 +51,14 @@ func Parse(r *http.Request, root *cmds.Command) (cmds.Request, error) { } } - return cmds.NewRequest(path, opts, args, cmd), nil + req := cmds.NewRequest(path, opts, args, cmd) + + err = cmd.CheckArguments(req) + if err != nil { + return nil, err + } + + return req, nil } func parseOptions(r *http.Request) (map[string]interface{}, []string) { diff --git a/commands/request.go b/commands/request.go index ea21d4720..c26f1773b 100644 --- a/commands/request.go +++ b/commands/request.go @@ -2,7 +2,6 @@ package commands import ( "fmt" - "io" "reflect" "strconv" @@ -29,7 +28,6 @@ type Request interface { SetContext(Context) Command() *Command - CheckArguments(args []Argument) error ConvertOptions(options map[string]Option) error } @@ -103,48 +101,6 @@ var converters = map[reflect.Kind]converter{ }, } -// MAYBE_TODO: maybe this should be a Command method? (taking a Request as a param) -func (r *request) CheckArguments(args []Argument) error { - var argDef Argument - - var length int - if len(r.arguments) > len(args) { - length = len(r.arguments) - } else { - length = len(args) - } - - for i := 0; i < length; i++ { - var arg interface{} - if len(r.arguments) > i { - arg = r.arguments[i] - } - - if i < len(args) { - argDef = args[i] - } else if !argDef.Variadic { - return fmt.Errorf("Expected %v arguments, got %v", len(args), len(r.arguments)) - } - - if argDef.Required && arg == nil { - return fmt.Errorf("Argument '%s' is required", argDef.Name) - } - if argDef.Type == ArgFile { - _, ok := arg.(io.Reader) - if !ok { - return fmt.Errorf("Argument '%s' isn't valid", argDef.Name) - } - } else if argDef.Type == ArgString { - _, ok := arg.(string) - if !ok { - return fmt.Errorf("Argument '%s' must be a string", argDef.Name) - } - } - } - - return nil -} - func (r *request) ConvertOptions(options map[string]Option) error { converted := make(map[string]interface{})