diff --git a/util/ctx/ctxio.go b/util/ctx/ctxio.go new file mode 100644 index 000000000..56057cdd2 --- /dev/null +++ b/util/ctx/ctxio.go @@ -0,0 +1,110 @@ +package ctxutil + +import ( + "io" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +type ioret struct { + n int + err error +} + +type Writer interface { + io.Writer +} + +type ctxWriter struct { + w io.Writer + ctx context.Context +} + +// NewWriter wraps a writer to make it respect given Context. +// If there is a blocking write, the returned Writer will return +// whenever the context is cancelled (the return values are n=0 +// and err=ctx.Err().) +// +// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying +// write-- there is no way to do that with the standard go io +// interface. So the read and write _will_ happen or hang. So, use +// this sparingly, make sure to cancel the read or write as necesary +// (e.g. closing a connection whose context is up, etc.) +// +// Furthermore, in order to protect your memory from being read +// _after_ you've cancelled the context, this io.Writer will +// first make a **copy** of the buffer. +func NewWriter(ctx context.Context, w io.Writer) *ctxWriter { + if ctx == nil { + ctx = context.Background() + } + return &ctxWriter{ctx: ctx, w: w} +} + +func (w *ctxWriter) Write(buf []byte) (int, error) { + buf2 := make([]byte, len(buf)) + copy(buf2, buf) + + c := make(chan ioret) + + go func() { + n, err := w.w.Write(buf2) + c <- ioret{n, err} + close(c) + }() + + select { + case r := <-c: + return r.n, r.err + case <-w.ctx.Done(): + return 0, w.ctx.Err() + } +} + +type Reader interface { + io.Reader +} + +type ctxReader struct { + r io.Reader + ctx context.Context +} + +// NewReader wraps a reader to make it respect given Context. +// If there is a blocking read, the returned Reader will return +// whenever the context is cancelled (the return values are n=0 +// and err=ctx.Err().) +// +// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying +// write-- there is no way to do that with the standard go io +// interface. So the read and write _will_ happen or hang. So, use +// this sparingly, make sure to cancel the read or write as necesary +// (e.g. closing a connection whose context is up, etc.) +// +// Furthermore, in order to protect your memory from being read +// _before_ you've cancelled the context, this io.Reader will +// allocate a buffer of the same size, and **copy** into the client's +// if the read succeeds in time. +func NewReader(ctx context.Context, r io.Reader) *ctxReader { + return &ctxReader{ctx: ctx, r: r} +} + +func (r *ctxReader) Read(buf []byte) (int, error) { + buf2 := make([]byte, len(buf)) + + c := make(chan ioret) + + go func() { + n, err := r.r.Read(buf2) + c <- ioret{n, err} + close(c) + }() + + select { + case ret := <-c: + copy(buf, buf2) + return ret.n, ret.err + case <-r.ctx.Done(): + return 0, r.ctx.Err() + } +} diff --git a/util/ctx/ctxio_test.go b/util/ctx/ctxio_test.go new file mode 100644 index 000000000..4104fb4a0 --- /dev/null +++ b/util/ctx/ctxio_test.go @@ -0,0 +1,273 @@ +package ctxutil + +import ( + "bytes" + "io" + "testing" + "time" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestReader(t *testing.T) { + buf := []byte("abcdef") + buf2 := make([]byte, 3) + r := NewReader(context.Background(), bytes.NewReader(buf)) + + // read first half + n, err := r.Read(buf2) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf2) != string(buf[:3]) { + t.Error("incorrect contents") + } + + // read second half + n, err = r.Read(buf2) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf2) != string(buf[3:6]) { + t.Error("incorrect contents") + } + + // read more. + n, err = r.Read(buf2) + if n != 0 { + t.Error("n should be 0", n) + } + if err != io.EOF { + t.Error("should be EOF", err) + } +} + +func TestWriter(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(context.Background(), &buf) + + // write three + n, err := w.Write([]byte("abc")) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf.Bytes()) != string("abc") { + t.Error("incorrect contents") + } + + // write three more + n, err = w.Write([]byte("def")) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf.Bytes()) != string("abcdef") { + t.Error("incorrect contents") + } +} + +func TestReaderCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + r := NewReader(ctx, piper) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + pipew.Write([]byte("abcdefghij")) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf) != "abcdefghij" { + t.Error("read contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to read") + } + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop reading after cancel") + } +} + +func TestWriterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + w := NewWriter(ctx, pipew) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := w.Write([]byte("abcdefghij")) + done <- ioret{n, err} + }() + + piper.Read(buf) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf) != "abcdefghij" { + t.Error("write contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to write") + } + + go func() { + n, err := w.Write([]byte("abcdefghij")) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop writing after cancel") + } +} + +func TestReadPostCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + r := NewReader(ctx, piper) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop reading after cancel") + } + + pipew.Write([]byte("abcdefghij")) + + if !bytes.Equal(buf, make([]byte, len(buf))) { + t.Fatal("buffer should have not been written to") + } +} + +func TestWritePostCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + w := NewWriter(ctx, pipew) + + buf := []byte("abcdefghij") + buf2 := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := w.Write(buf) + done <- ioret{n, err} + }() + + piper.Read(buf2) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf2) != "abcdefghij" { + t.Error("write contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to write") + } + + go func() { + n, err := w.Write(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop writing after cancel") + } + + copy(buf, []byte("aaaaaaaaaa")) + + piper.Read(buf2) + + if string(buf2) == "aaaaaaaaaa" { + t.Error("buffer was read from after ctx cancel") + } else if string(buf2) != "abcdefghij" { + t.Error("write contents differ from expected") + } +}