diff --git a/test/gracefulstop_test.go b/test/gracefulstop_test.go index 7ac12b09..7f8b207b 100644 --- a/test/gracefulstop_test.go +++ b/test/gracefulstop_test.go @@ -20,6 +20,7 @@ package test import ( "fmt" + "io" "net" "sync" "testing" @@ -86,39 +87,19 @@ func (d *delayListener) Dial(to time.Duration) (net.Conn, error) { return d.cc, nil } -func (d *delayListener) clientWriteCalledChan() <-chan struct{} { - return d.cc.writeCalledChan() -} - type delayConn struct { net.Conn - blockRead chan struct{} - mu sync.Mutex - writeCalled chan struct{} + blockRead chan struct{} } -func (d *delayConn) writeCalledChan() <-chan struct{} { - d.mu.Lock() - defer d.mu.Unlock() - d.writeCalled = make(chan struct{}) - return d.writeCalled -} func (d *delayConn) allowRead() { close(d.blockRead) } + func (d *delayConn) Read(b []byte) (n int, err error) { <-d.blockRead return d.Conn.Read(b) } -func (d *delayConn) Write(b []byte) (n int, err error) { - d.mu.Lock() - if d.writeCalled != nil { - close(d.writeCalled) - d.writeCalled = nil - } - d.mu.Unlock() - return d.Conn.Write(b) -} func TestGracefulStop(t *testing.T) { defer leakcheck.Check(t) @@ -148,10 +129,16 @@ func TestGracefulStop(t *testing.T) { allowCloseCh: make(chan struct{}), } d := func(_ string, to time.Duration) (net.Conn, error) { return dlis.Dial(to) } + serverGotReq := make(chan struct{}) ss := &stubServer{ - emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + close(serverGotReq) + _, err := stream.Recv() + if err != nil { + return err + } + return stream.Send(&testpb.StreamingOutputCallResponse{}) }, } s := grpc.NewServer() @@ -190,25 +177,31 @@ func TestGracefulStop(t *testing.T) { cancel() client := testpb.NewTestServiceClient(cc) defer cc.Close() - dlis.allowClose() - wcch := dlis.clientWriteCalledChan() - go func() { - // 5. Allow the client to read the GoAway. The RPC should complete - // successfully. - <-wcch - dlis.allowClientRead() - }() - // 4. Send an RPC on the new connection. // The server would send a GOAWAY first, but we are delaying the server's // writes for now until the client writes more than the preface. ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) - if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { - t.Fatalf("EmptyCall() = %v; want ", err) + stream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("FullDuplexCall= _, %v; want _, ", err) + } + go func() { + // 5. Allow the client to read the GoAway. The RPC should complete + // successfully. + <-serverGotReq + dlis.allowClientRead() + }() + if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil { + t.Fatalf("stream.Send(_) = %v, want ", err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("stream.Recv() = _, %v, want _, ", err) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err) } - // 5. happens above, then we finish the call. cancel() wg.Wait() diff --git a/transport/controlbuf.go b/transport/controlbuf.go index e147cd51..b215b3a4 100644 --- a/transport/controlbuf.go +++ b/transport/controlbuf.go @@ -361,44 +361,37 @@ func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimato const minBatchSize = 1000 // run should be run in a separate goroutine. -func (l *loopyWriter) run() { - var ( - it interface{} - err error - isEmpty bool - ) - defer func() { - errorf("transport: loopyWriter.run returning. Err: %v", err) - }() +func (l *loopyWriter) run() error { for { - it, err = l.cbuf.get(true) + it, err := l.cbuf.get(true) if err != nil { - return + return err } if err = l.handle(it); err != nil { - return + return err } if _, err = l.processData(); err != nil { - return + return err } gosched := true hasdata: for { - it, err = l.cbuf.get(false) + it, err := l.cbuf.get(false) if err != nil { - return + return err } if it != nil { if err = l.handle(it); err != nil { - return + return err } if _, err = l.processData(); err != nil { - return + return err } continue hasdata } - if isEmpty, err = l.processData(); err != nil { - return + isEmpty, err := l.processData() + if err != nil { + return err } if !isEmpty { continue hasdata diff --git a/transport/http2_client.go b/transport/http2_client.go index 1fdabd95..edf4d6cb 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -295,8 +295,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne t.framer.writer.Flush() go func() { t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst) - t.loopy.run() - t.conn.Close() + err := t.loopy.run() + errorf("transport: loopyWriter.run returning. Err: %v", err) + // If it's a connection error, let reader goroutine handle it + // since there might be data in the buffers. + + if _, ok := err.(net.Error); !ok { + t.conn.Close() + } close(t.writerDone) }() if t.kp.Time != infinity { diff --git a/transport/http2_server.go b/transport/http2_server.go index 8b93e222..347edbda 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -273,7 +273,8 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err go func() { t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst) t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler - t.loopy.run() + err := t.loopy.run() + errorf("transport: loopyWriter.run returning. Err: %v", err) t.conn.Close() close(t.writerDone) }()