diff --git a/command/command.go b/command/command.go index f7b8be5c..934f6db4 100644 --- a/command/command.go +++ b/command/command.go @@ -23,12 +23,22 @@ func (c *Commands) Register(cmdstr string, cf cmdfunc) { c.cmds[cmdstr] = cf } +// Find will look up the command function for the given command input. +// If it cannot find the command it will defualt to noCmdAvailable(). +// If the command is an empty string it will replay the last command. func (c *Commands) Find(cmdstr string) cmdfunc { cmd, ok := c.cmds[cmdstr] if !ok { + if cmdstr == "" { + return nullCommand + } + return noCmdAvailable } + // Allow to replay last command + c.cmds[""] = cmd + return cmd } @@ -40,3 +50,7 @@ func exitFunc() error { os.Exit(0) return nil } + +func nullCommand() error { + return nil +} diff --git a/command/command_test.go b/command/command_test.go index 4ad609d8..95d1aa51 100644 --- a/command/command_test.go +++ b/command/command_test.go @@ -37,3 +37,32 @@ func TestCommandRegister(t *testing.T) { t.Fatal("wrong command output") } } + +func TestCommandReplay(t *testing.T) { + cmds := Commands{make(map[string]cmdfunc)} + cmds.Register("foo", func() error { return fmt.Errorf("registered command") }) + cmd := cmds.Find("foo") + + err := cmd() + if err.Error() != "registered command" { + t.Fatal("wrong command output") + } + + cmd = cmds.Find("") + err = cmd() + if err.Error() != "registered command" { + t.Fatal("wrong command output") + } +} + +func TestCommandReplayWithoutPreviousCommand(t *testing.T) { + var ( + cmds = Commands{make(map[string]cmdfunc)} + cmd = cmds.Find("") + err = cmd() + ) + + if err != nil { + t.Error("Null command not returned", err) + } +}