diff --git a/util/do.go b/util/do.go new file mode 100644 index 000000000..248f0eda8 --- /dev/null +++ b/util/do.go @@ -0,0 +1,15 @@ +package util + +import "code.google.com/p/go.net/context" + +func Do(ctx context.Context, f func() error) error { + ch := make(chan error) + go func() { ch <- f() }() + select { + case <-ctx.Done(): + return ctx.Err() + case val := <-ch: + return val + } + return nil +} diff --git a/util/do_test.go b/util/do_test.go new file mode 100644 index 000000000..14861265f --- /dev/null +++ b/util/do_test.go @@ -0,0 +1,42 @@ +package util + +import ( + "errors" + "testing" + + "code.google.com/p/go.net/context" +) + +func TestDoReturnsContextErr(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan struct{}) + err := Do(ctx, func() error { + cancel() + ch <- struct{}{} // won't return + return nil + }) + if err != ctx.Err() { + t.Fail() + } +} + +func TestDoReturnsFuncError(t *testing.T) { + ctx := context.Background() + expected := errors.New("expected to be returned by Do") + err := Do(ctx, func() error { + return expected + }) + if err != expected { + t.Fail() + } +} + +func TestDoReturnsNil(t *testing.T) { + ctx := context.Background() + err := Do(ctx, func() error { + return nil + }) + if err != nil { + t.Fail() + } +}