diff --git a/call.go b/call.go index 0df314d8..a8b6dcfd 100644 --- a/call.go +++ b/call.go @@ -52,6 +52,13 @@ import ( func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. var err error + defer func() { + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + t.CloseStream(stream, err) + } + } + }() c.headerMD, err = stream.Header() if err != nil { return err @@ -191,20 +198,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - // Receive the response err = recvResponse(cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } - t.CloseStream(stream, err) return toRPCErr(err) } if c.traceInfo.tr != nil { diff --git a/clientconn.go b/clientconn.go index 214fb900..6e018133 100644 --- a/clientconn.go +++ b/clientconn.go @@ -68,7 +68,7 @@ var ( // errCredentialsConflict indicates that grpc.WithTransportCredentials() // and grpc.WithInsecure() are both called for a connection. errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") - // errNetworkIP indicates that the connection is down due to some network I/O error. + // errNetworkIO indicates that the connection is down due to some network I/O error. errNetworkIO = errors.New("grpc: failed with network I/O error") // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. errConnDrain = errors.New("grpc: the connection is drained") @@ -196,9 +196,11 @@ func WithTimeout(d time.Duration) DialOption { } // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. -func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption { +func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { return func(o *dialOptions) { - o.copts.Dialer = f + o.copts.Dialer = func(addr string, timeout time.Duration, _ <-chan struct{}) (net.Conn, error) { + return f(addr, timeout) + } } } @@ -365,6 +367,7 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { addr: addr, dopts: cc.dopts, } + ac.stateCV = sync.NewCond(&ac.mu) ac.dopts.copts.Cancel = make(chan struct{}) if EnableTracing { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) @@ -398,7 +401,6 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { // ii) a buggy Balancer notifies duplicated Addresses. stale.tearDown(errConnDrain) } - ac.stateCV = sync.NewCond(&ac.mu) // skipWait may overwrite the decision in ac.dopts.block. if ac.dopts.block && !skipWait { if err := ac.resetTransport(false); err != nil { @@ -624,15 +626,41 @@ func (ac *addrConn) transportMonitor() { // Cancel is needed to detect the teardown when // the addrConn is idle (i.e., no RPC in flight). case <-ac.dopts.copts.Cancel: + select { + case <-t.Error(): + t.Close() + default: + } return case <-t.GoAway(): - ac.tearDown(errConnDrain) + // If GoAway happens without any network I/O error, ac is closed without shutting down the + // underlying transport (the transport will be closed when all the pending RPCs finished or + // failed.). + // If GoAway and some network I/O error happen concurrently, ac and its underlying transport + // are closed. + // In both cases, a new ac is created. + select { + case <-t.Error(): + ac.tearDown(errNetworkIO) + default: + ac.tearDown(errConnDrain) + } ac.cc.newAddrConn(ac.addr, true) return case <-t.Error(): + select { + case <-ac.dopts.copts.Cancel: + t.Close() + return + case <-t.GoAway(): + ac.tearDown(errNetworkIO) + ac.cc.newAddrConn(ac.addr, true) + return + default: + } ac.mu.Lock() if ac.state == Shutdown { - // ac.tearDown(...) has been invoked. + // ac has been shutdown. ac.mu.Unlock() return } diff --git a/examples/gotutorial.md b/examples/gotutorial.md index 39833fdf..25c0a2df 100644 --- a/examples/gotutorial.md +++ b/examples/gotutorial.md @@ -28,12 +28,12 @@ Then change your current directory to `grpc-go/examples/route_guide`: $ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide ``` -You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](examples/). +You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](https://github.com/grpc/grpc-go/tree/master/examples/). ## Defining the service -Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [`examples/route_guide/proto/route_guide.proto`](examples/route_guide/proto/route_guide.proto). +Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto). To define a service, you specify a named `service` in your .proto file: diff --git a/server.go b/server.go index 6a5c1c14..90c265fd 100644 --- a/server.go +++ b/server.go @@ -793,6 +793,8 @@ func (s *Server) Stop() { s.lis = nil st := s.conns s.conns = nil + // interrupt GracefulStop if Stop and GracefulStop are called concurrently. + s.cv.Signal() s.mu.Unlock() for lis := range listeners { @@ -815,20 +817,20 @@ func (s *Server) Stop() { func (s *Server) GracefulStop() { s.mu.Lock() if s.drain == true || s.conns == nil { - s.mu.Lock() + s.mu.Unlock() return } s.drain = true for lis := range s.lis { lis.Close() } + s.lis = nil for c := range s.conns { c.(transport.ServerTransport).Drain() } for len(s.conns) != 0 { s.cv.Wait() } - s.lis = nil s.conns = nil if s.events != nil { s.events.Finish() diff --git a/stream.go b/stream.go index f06f137d..66bfad81 100644 --- a/stream.go +++ b/stream.go @@ -166,7 +166,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { cs.finish(err) return nil, toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index 04a6c950..c1953e56 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -314,13 +314,7 @@ func (e env) runnable() bool { return true } -func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) { - // NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this - // to be written as - // - // `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)` - // - // but that would break compatibility with earlier Go versions. +func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout(e.network, addr, timeout) } @@ -515,7 +509,7 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil) + c, err := te.e.dialer(te.srvAddr, 10*time.Second) if err != nil { te.t.Fatal(err) } @@ -563,6 +557,27 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { awaitNewConnLogOutput() } +func TestServerGracefulStopIdempotent(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testServerGracefulStopIdempotent(t, e) + } +} + +func testServerGracefulStopIdempotent(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + for i := 0; i < 3; i++ { + te.srv.GracefulStop() + } +} + func TestServerGoAway(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -576,12 +591,6 @@ func TestServerGoAway(t *testing.T) { func testServerGoAway(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA - te.declareLogNoise( - "transport: http2Client.notifyError got notified that the client transport was broken EOF", - "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", - "grpc: Conn.resetTransport failed to create client transport: connection error", - "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", - ) te.startServer(&testServer{security: e.security}) defer te.tearDown() @@ -684,6 +693,115 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { awaitNewConnLogOutput() } +func TestConcurrentClientConnCloseAndServerGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testConcurrentClientConnCloseAndServerGoAway(t, e) + } +} + +func testConcurrentClientConnCloseAndServerGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) + } + ch := make(chan struct{}) + // Close ClientConn and Server concurrently. + go func() { + te.srv.GracefulStop() + close(ch) + }() + go func() { + cc.Close() + }() + <-ch +} + +func TestConcurrentServerStopAndGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testConcurrentServerStopAndGoAway(t, e) + } +} + +func testConcurrentServerStopAndGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + // Finish an RPC to make sure the connection is good. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + // Loop until the server side GoAway signal is propagated to the client. + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + // Stop the server and close all the connections. + te.srv.Stop() + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(req); err == nil { + if _, err := stream.Recv(); err == nil { + t.Fatalf("%v.Recv() = _, %v, want _, ", stream, err) + } + } + <-ch + awaitNewConnLogOutput() +} + func TestFailFast(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -2199,8 +2317,8 @@ func leakCheck(t testing.TB) func() { } return func() { // Loop, waiting for goroutines to shut down. - // Wait up to 5 seconds, but finish as quickly as possible. - deadline := time.Now().Add(5 * time.Second) + // Wait up to 10 seconds, but finish as quickly as possible. + deadline := time.Now().Add(10 * time.Second) for { var leaked []string for _, g := range interestingGoroutines() { diff --git a/transport/http2_client.go b/transport/http2_client.go index bcfcdf0a..6a34c7e5 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -284,6 +284,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.mu.Unlock() return nil, ErrConnClosing } + if t.state == draining { + t.mu.Unlock() + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -308,6 +312,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, err } t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + if checkStreamsQuota { + t.streamsQuota.add(1) + } + // Need to make t writable again so that the rpc in flight can still proceed. + t.writableChan <- 0 + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -457,7 +470,7 @@ func (t *http2Client) Close() (err error) { t.mu.Unlock() return } - if t.state == reachable { + if t.state == reachable || t.state == draining { close(t.errorChan) } t.state = closing @@ -483,7 +496,7 @@ func (t *http2Client) Close() (err error) { func (t *http2Client) GracefulClose() error { t.mu.Lock() - if t.state == closing { + if t.state == closing || t.state == unreachable { t.mu.Unlock() return nil } @@ -994,11 +1007,16 @@ func (t *http2Client) GoAway() <-chan struct{} { func (t *http2Client) notifyError(err error) { t.mu.Lock() - defer t.mu.Unlock() // make sure t.errorChan is closed only once. + if t.state == draining { + t.mu.Unlock() + t.Close() + return + } if t.state == reachable { t.state = unreachable close(t.errorChan) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) } + t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index 38715c59..b44a9d53 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -681,6 +681,11 @@ func (t *http2Server) controller() { t.framer.writeRSTStream(true, i.streamID, i.code) case *goAway: t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + // The transport is closing. + return + } sid := t.maxStreamID t.state = draining t.mu.Unlock() diff --git a/transport/transport.go b/transport/transport.go index 86c8fcd6..10b4a2e2 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -240,6 +240,8 @@ func (s *Stream) Header() (metadata.MD, error) { select { case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) + case <-s.goAway: + return nil, ErrStreamDrain case <-s.headerChan: return s.header.Copy(), nil } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6f9cc50c..f4af68b7 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -368,8 +368,8 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrConnClosing) + if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) } }() }