Fix possible data loss; Only let reader goroutine handle connection errors. (#1993)

* First commit.

* Post review updates.
This commit is contained in:
mmukhi
2018-05-11 13:51:50 -07:00
committed by GitHub
parent 091a800143
commit 4ab6e31b84
4 changed files with 51 additions and 58 deletions

View File

@ -20,6 +20,7 @@ package test
import ( import (
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"testing" "testing"
@ -86,39 +87,19 @@ func (d *delayListener) Dial(to time.Duration) (net.Conn, error) {
return d.cc, nil return d.cc, nil
} }
func (d *delayListener) clientWriteCalledChan() <-chan struct{} {
return d.cc.writeCalledChan()
}
type delayConn struct { type delayConn struct {
net.Conn net.Conn
blockRead chan struct{} blockRead chan struct{}
mu sync.Mutex
writeCalled chan struct{}
} }
func (d *delayConn) writeCalledChan() <-chan struct{} {
d.mu.Lock()
defer d.mu.Unlock()
d.writeCalled = make(chan struct{})
return d.writeCalled
}
func (d *delayConn) allowRead() { func (d *delayConn) allowRead() {
close(d.blockRead) close(d.blockRead)
} }
func (d *delayConn) Read(b []byte) (n int, err error) { func (d *delayConn) Read(b []byte) (n int, err error) {
<-d.blockRead <-d.blockRead
return d.Conn.Read(b) return d.Conn.Read(b)
} }
func (d *delayConn) Write(b []byte) (n int, err error) {
d.mu.Lock()
if d.writeCalled != nil {
close(d.writeCalled)
d.writeCalled = nil
}
d.mu.Unlock()
return d.Conn.Write(b)
}
func TestGracefulStop(t *testing.T) { func TestGracefulStop(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
@ -148,10 +129,16 @@ func TestGracefulStop(t *testing.T) {
allowCloseCh: make(chan struct{}), allowCloseCh: make(chan struct{}),
} }
d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) } d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) }
serverGotReq := make(chan struct{})
ss := &stubServer{ ss := &stubServer{
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
return &testpb.Empty{}, nil close(serverGotReq)
_, err := stream.Recv()
if err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{})
}, },
} }
s := grpc.NewServer() s := grpc.NewServer()
@ -190,25 +177,31 @@ func TestGracefulStop(t *testing.T) {
cancel() cancel()
client := testpb.NewTestServiceClient(cc) client := testpb.NewTestServiceClient(cc)
defer cc.Close() defer cc.Close()
dlis.allowClose() dlis.allowClose()
wcch := dlis.clientWriteCalledChan()
go func() {
// 5. Allow the client to read the GoAway. The RPC should complete
// successfully.
<-wcch
dlis.allowClientRead()
}()
// 4. Send an RPC on the new connection. // 4. Send an RPC on the new connection.
// The server would send a GOAWAY first, but we are delaying the server's // The server would send a GOAWAY first, but we are delaying the server's
// writes for now until the client writes more than the preface. // writes for now until the client writes more than the preface.
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { stream, err := client.FullDuplexCall(ctx)
t.Fatalf("EmptyCall() = %v; want <nil>", err) if err != nil {
t.Fatalf("FullDuplexCall= _, %v; want _, <nil>", err)
}
go func() {
// 5. Allow the client to read the GoAway. The RPC should complete
// successfully.
<-serverGotReq
dlis.allowClientRead()
}()
if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("stream.Send(_) = %v, want <nil>", err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("stream.Recv() = _, %v, want _, <nil>", err)
}
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err)
} }
// 5. happens above, then we finish the call. // 5. happens above, then we finish the call.
cancel() cancel()
wg.Wait() wg.Wait()

View File

@ -361,44 +361,37 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato
const minBatchSize = 1000 const minBatchSize = 1000
// run should be run in a separate goroutine. // run should be run in a separate goroutine.
func (l *loopyWriter) run() { func (l *loopyWriter) run() error {
var (
it interface{}
err error
isEmpty bool
)
defer func() {
errorf("transport: loopyWriter.run returning. Err: %v", err)
}()
for { for {
it, err = l.cbuf.get(true) it, err := l.cbuf.get(true)
if err != nil { if err != nil {
return return err
} }
if err = l.handle(it); err != nil { if err = l.handle(it); err != nil {
return return err
} }
if _, err = l.processData(); err != nil { if _, err = l.processData(); err != nil {
return return err
} }
gosched := true gosched := true
hasdata: hasdata:
for { for {
it, err = l.cbuf.get(false) it, err := l.cbuf.get(false)
if err != nil { if err != nil {
return return err
} }
if it != nil { if it != nil {
if err = l.handle(it); err != nil { if err = l.handle(it); err != nil {
return return err
} }
if _, err = l.processData(); err != nil { if _, err = l.processData(); err != nil {
return return err
} }
continue hasdata continue hasdata
} }
if isEmpty, err = l.processData(); err != nil { isEmpty, err := l.processData()
return if err != nil {
return err
} }
if !isEmpty { if !isEmpty {
continue hasdata continue hasdata

View File

@ -295,8 +295,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
t.framer.writer.Flush() t.framer.writer.Flush()
go func() { go func() {
t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst)
t.loopy.run() err := t.loopy.run()
t.conn.Close() errorf("transport: loopyWriter.run returning. Err: %v", err)
// If it's a connection error, let reader goroutine handle it
// since there might be data in the buffers.
if _, ok := err.(net.Error); !ok {
t.conn.Close()
}
close(t.writerDone) close(t.writerDone)
}() }()
if t.kp.Time != infinity { if t.kp.Time != infinity {

View File

@ -273,7 +273,8 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
go func() { go func() {
t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst)
t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler
t.loopy.run() err := t.loopy.run()
errorf("transport: loopyWriter.run returning. Err: %v", err)
t.conn.Close() t.conn.Close()
close(t.writerDone) close(t.writerDone)
}() }()