diff --git a/server.go b/server.go index a52d7a3d..f6516216 100644 --- a/server.go +++ b/server.go @@ -479,8 +479,12 @@ func (s *Server) Serve(lis net.Listener) error { s.serveWG.Add(1) defer func() { s.serveWG.Done() - // Block until Stop or GracefulStop is ready for us to return. - <-s.done + select { + // Stop or GracefulStop called; block until done and return nil. + case <-s.quit: + <-s.done + default: + } }() s.lis[lis] = true @@ -526,7 +530,6 @@ func (s *Server) Serve(lis net.Listener) error { s.printf("done serving; Accept = %v", err) s.mu.Unlock() - // If Stop or GracefulStop is called, return nil. select { case <-s.quit: return nil diff --git a/test/end2end_test.go b/test/end2end_test.go index 21c0cee3..7feea666 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -5771,3 +5771,51 @@ func testCompressorRegister(t *testing.T, e env) { t.Fatalf("%v.Recv() = %v, want ", stream, err) } } + +func TestServeExitsWhenListenerClosed(t *testing.T) { + defer leakcheck.Check(t) + + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + s := grpc.NewServer() + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + + done := make(chan struct{}) + go func() { + s.Serve(lis) + close(done) + }() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial server: %v", err) + } + defer cc.Close() + c := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("Failed to send test RPC to server: %v", err) + } + + if err := lis.Close(); err != nil { + t.Fatalf("Failed to close listener: %v", err) + } + const timeout = 5 * time.Second + timer := time.NewTimer(timeout) + select { + case <-done: + return + case <-timer.C: + t.Fatalf("Serve did not return after %v", timeout) + } +}