mirror of
https://github.com/ipfs/kubo.git
synced 2025-06-29 17:36:38 +08:00
ctxio -- io with a context.
This commit introduces a reader and writer that respect contexts. Warning: careful how you use them. Returning leaves a goroutine reading until the read finishes.
This commit is contained in:
110
util/ctx/ctxio.go
Normal file
110
util/ctx/ctxio.go
Normal file
@ -0,0 +1,110 @@
|
||||
package ctxutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
|
||||
)
|
||||
|
||||
type ioret struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
type Writer interface {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
type ctxWriter struct {
|
||||
w io.Writer
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewWriter wraps a writer to make it respect given Context.
|
||||
// If there is a blocking write, the returned Writer will return
|
||||
// whenever the context is cancelled (the return values are n=0
|
||||
// and err=ctx.Err().)
|
||||
//
|
||||
// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
|
||||
// write-- there is no way to do that with the standard go io
|
||||
// interface. So the read and write _will_ happen or hang. So, use
|
||||
// this sparingly, make sure to cancel the read or write as necesary
|
||||
// (e.g. closing a connection whose context is up, etc.)
|
||||
//
|
||||
// Furthermore, in order to protect your memory from being read
|
||||
// _after_ you've cancelled the context, this io.Writer will
|
||||
// first make a **copy** of the buffer.
|
||||
func NewWriter(ctx context.Context, w io.Writer) *ctxWriter {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return &ctxWriter{ctx: ctx, w: w}
|
||||
}
|
||||
|
||||
func (w *ctxWriter) Write(buf []byte) (int, error) {
|
||||
buf2 := make([]byte, len(buf))
|
||||
copy(buf2, buf)
|
||||
|
||||
c := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := w.w.Write(buf2)
|
||||
c <- ioret{n, err}
|
||||
close(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case r := <-c:
|
||||
return r.n, r.err
|
||||
case <-w.ctx.Done():
|
||||
return 0, w.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
type ctxReader struct {
|
||||
r io.Reader
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewReader wraps a reader to make it respect given Context.
|
||||
// If there is a blocking read, the returned Reader will return
|
||||
// whenever the context is cancelled (the return values are n=0
|
||||
// and err=ctx.Err().)
|
||||
//
|
||||
// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying
|
||||
// write-- there is no way to do that with the standard go io
|
||||
// interface. So the read and write _will_ happen or hang. So, use
|
||||
// this sparingly, make sure to cancel the read or write as necesary
|
||||
// (e.g. closing a connection whose context is up, etc.)
|
||||
//
|
||||
// Furthermore, in order to protect your memory from being read
|
||||
// _before_ you've cancelled the context, this io.Reader will
|
||||
// allocate a buffer of the same size, and **copy** into the client's
|
||||
// if the read succeeds in time.
|
||||
func NewReader(ctx context.Context, r io.Reader) *ctxReader {
|
||||
return &ctxReader{ctx: ctx, r: r}
|
||||
}
|
||||
|
||||
func (r *ctxReader) Read(buf []byte) (int, error) {
|
||||
buf2 := make([]byte, len(buf))
|
||||
|
||||
c := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := r.r.Read(buf2)
|
||||
c <- ioret{n, err}
|
||||
close(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case ret := <-c:
|
||||
copy(buf, buf2)
|
||||
return ret.n, ret.err
|
||||
case <-r.ctx.Done():
|
||||
return 0, r.ctx.Err()
|
||||
}
|
||||
}
|
273
util/ctx/ctxio_test.go
Normal file
273
util/ctx/ctxio_test.go
Normal file
@ -0,0 +1,273 @@
|
||||
package ctxutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
|
||||
)
|
||||
|
||||
func TestReader(t *testing.T) {
|
||||
buf := []byte("abcdef")
|
||||
buf2 := make([]byte, 3)
|
||||
r := NewReader(context.Background(), bytes.NewReader(buf))
|
||||
|
||||
// read first half
|
||||
n, err := r.Read(buf2)
|
||||
if n != 3 {
|
||||
t.Error("n should be 3")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("should have no error")
|
||||
}
|
||||
if string(buf2) != string(buf[:3]) {
|
||||
t.Error("incorrect contents")
|
||||
}
|
||||
|
||||
// read second half
|
||||
n, err = r.Read(buf2)
|
||||
if n != 3 {
|
||||
t.Error("n should be 3")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("should have no error")
|
||||
}
|
||||
if string(buf2) != string(buf[3:6]) {
|
||||
t.Error("incorrect contents")
|
||||
}
|
||||
|
||||
// read more.
|
||||
n, err = r.Read(buf2)
|
||||
if n != 0 {
|
||||
t.Error("n should be 0", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Error("should be EOF", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(context.Background(), &buf)
|
||||
|
||||
// write three
|
||||
n, err := w.Write([]byte("abc"))
|
||||
if n != 3 {
|
||||
t.Error("n should be 3")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("should have no error")
|
||||
}
|
||||
if string(buf.Bytes()) != string("abc") {
|
||||
t.Error("incorrect contents")
|
||||
}
|
||||
|
||||
// write three more
|
||||
n, err = w.Write([]byte("def"))
|
||||
if n != 3 {
|
||||
t.Error("n should be 3")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("should have no error")
|
||||
}
|
||||
if string(buf.Bytes()) != string("abcdef") {
|
||||
t.Error("incorrect contents")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReaderCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
piper, pipew := io.Pipe()
|
||||
r := NewReader(ctx, piper)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
done := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := r.Read(buf)
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
pipew.Write([]byte("abcdefghij"))
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 10 {
|
||||
t.Error("ret.n should be 10", ret.n)
|
||||
}
|
||||
if ret.err != nil {
|
||||
t.Error("ret.err should be nil", ret.err)
|
||||
}
|
||||
if string(buf) != "abcdefghij" {
|
||||
t.Error("read contents differ")
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to read")
|
||||
}
|
||||
|
||||
go func() {
|
||||
n, err := r.Read(buf)
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 0 {
|
||||
t.Error("ret.n should be 0", ret.n)
|
||||
}
|
||||
if ret.err == nil {
|
||||
t.Error("ret.err should be ctx error", ret.err)
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to stop reading after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
piper, pipew := io.Pipe()
|
||||
w := NewWriter(ctx, pipew)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
done := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := w.Write([]byte("abcdefghij"))
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
piper.Read(buf)
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 10 {
|
||||
t.Error("ret.n should be 10", ret.n)
|
||||
}
|
||||
if ret.err != nil {
|
||||
t.Error("ret.err should be nil", ret.err)
|
||||
}
|
||||
if string(buf) != "abcdefghij" {
|
||||
t.Error("write contents differ")
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to write")
|
||||
}
|
||||
|
||||
go func() {
|
||||
n, err := w.Write([]byte("abcdefghij"))
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 0 {
|
||||
t.Error("ret.n should be 0", ret.n)
|
||||
}
|
||||
if ret.err == nil {
|
||||
t.Error("ret.err should be ctx error", ret.err)
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to stop writing after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPostCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
piper, pipew := io.Pipe()
|
||||
r := NewReader(ctx, piper)
|
||||
|
||||
buf := make([]byte, 10)
|
||||
done := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := r.Read(buf)
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 0 {
|
||||
t.Error("ret.n should be 0", ret.n)
|
||||
}
|
||||
if ret.err == nil {
|
||||
t.Error("ret.err should be ctx error", ret.err)
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to stop reading after cancel")
|
||||
}
|
||||
|
||||
pipew.Write([]byte("abcdefghij"))
|
||||
|
||||
if !bytes.Equal(buf, make([]byte, len(buf))) {
|
||||
t.Fatal("buffer should have not been written to")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePostCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
piper, pipew := io.Pipe()
|
||||
w := NewWriter(ctx, pipew)
|
||||
|
||||
buf := []byte("abcdefghij")
|
||||
buf2 := make([]byte, 10)
|
||||
done := make(chan ioret)
|
||||
|
||||
go func() {
|
||||
n, err := w.Write(buf)
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
piper.Read(buf2)
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 10 {
|
||||
t.Error("ret.n should be 10", ret.n)
|
||||
}
|
||||
if ret.err != nil {
|
||||
t.Error("ret.err should be nil", ret.err)
|
||||
}
|
||||
if string(buf2) != "abcdefghij" {
|
||||
t.Error("write contents differ")
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to write")
|
||||
}
|
||||
|
||||
go func() {
|
||||
n, err := w.Write(buf)
|
||||
done <- ioret{n, err}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case ret := <-done:
|
||||
if ret.n != 0 {
|
||||
t.Error("ret.n should be 0", ret.n)
|
||||
}
|
||||
if ret.err == nil {
|
||||
t.Error("ret.err should be ctx error", ret.err)
|
||||
}
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
t.Fatal("failed to stop writing after cancel")
|
||||
}
|
||||
|
||||
copy(buf, []byte("aaaaaaaaaa"))
|
||||
|
||||
piper.Read(buf2)
|
||||
|
||||
if string(buf2) == "aaaaaaaaaa" {
|
||||
t.Error("buffer was read from after ctx cancel")
|
||||
} else if string(buf2) != "abcdefghij" {
|
||||
t.Error("write contents differ from expected")
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user