diff --git a/transport/http2_client_transport.go b/transport/http2_client.go
similarity index 99%
rename from transport/http2_client_transport.go
rename to transport/http2_client.go
index f2147509..9b2840a1 100644
--- a/transport/http2_client_transport.go
+++ b/transport/http2_client.go
@@ -35,6 +35,7 @@ package transport
 
 import (
 	"bytes"
+	"errors"
 	"io"
 	"log"
 	"math"
@@ -315,6 +316,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 // accessed any more.
 func (t *http2Client) Close() (err error) {
 	t.mu.Lock()
+	if t.state == closing {
+		t.mu.Unlock()
+		return errors.New("transport: Close() was already called")
+	}
 	t.state = closing
 	t.mu.Unlock()
 	close(t.shutdownChan)
diff --git a/transport/http2_server_transport.go b/transport/http2_server.go
similarity index 99%
rename from transport/http2_server_transport.go
rename to transport/http2_server.go
index 633e3f51..53a09bf3 100644
--- a/transport/http2_server_transport.go
+++ b/transport/http2_server.go
@@ -79,7 +79,7 @@ type http2Server struct {
 	// sendQuotaPool provides flow control to outbound message.
 	sendQuotaPool *quotaPool
 
-	mu            sync.Mutex
+	mu            sync.Mutex // guard the following
 	state         transportState
 	activeStreams map[uint32]*Stream
 	// Inbound quota for flow control
@@ -570,7 +570,7 @@ func (t *http2Server) Close() (err error) {
 	t.mu.Lock()
 	if t.state == closing {
 		t.mu.Unlock()
-		return
+		return errors.New("transport: Close() was already called")
 	}
 	t.state = closing
 	streams := t.activeStreams
diff --git a/transport/transport.go b/transport/transport.go
index 2cd211e6..406d581f 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -186,7 +186,7 @@ type Stream struct {
 	// The key-value map of trailer metadata.
 	trailer metadata.MD
 
-	mu sync.RWMutex
+	mu sync.RWMutex // guard the following
 	// headerOK becomes true from the first header is about to send.
 	headerOk bool
 	state    streamState
diff --git a/transport/transport_test.go b/transport/transport_test.go
index bac71160..e768f7a5 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -158,12 +158,12 @@ func (s *server) Wait(t *testing.T, timeout time.Duration) {
 }
 
 func (s *server) Close() {
+	// Keep consistent with closeServer().
 	s.lis.Close()
 	s.mu.Lock()
 	for c := range s.conns {
 		c.Close()
 	}
-	s.conns = nil
 	s.mu.Unlock()
 }
 
@@ -227,8 +227,8 @@ func TestClientSendAndReceive(t *testing.T) {
 	if recvErr != io.EOF {
 		t.Fatalf("Error: %v; want <EOF>", recvErr)
 	}
-	ct.Close()
-	server.Close()
+	closeClient(ct, t)
+	closeServer(server, t)
 }
 
 func TestClientErrorNotify(t *testing.T) {
@@ -245,10 +245,10 @@ func TestClientErrorNotify(t *testing.T) {
 		t.Fatalf("wrong stream id: %d", s.id)
 	}
 	// Tear down the server.
-	go server.Close()
+	go closeServer(server, t)
 	// ct.reader should detect the error and activate ct.Error().
 	<-ct.Error()
-	ct.Close()
+	closeClient(ct, t)
 }
 
 func performOneRPC(ct ClientTransport) {
@@ -281,11 +281,11 @@ func TestClientMix(t *testing.T) {
 	s, ct := setUp(t, true, 0, math.MaxUint32, false)
 	go func(s *server) {
 		time.Sleep(5 * time.Second)
-		s.Close()
+		closeServer(s, t)
 	}(s)
-	go func(t ClientTransport) {
+	go func(ct ClientTransport) {
 		<-ct.Error()
-		ct.Close()
+		closeClient(ct, t)
 	}(ct)
 	for i := 0; i < 1000; i++ {
 		time.Sleep(10 * time.Millisecond)
@@ -296,8 +296,8 @@ func TestClientMix(t *testing.T) {
 func TestExceedMaxStreamsLimit(t *testing.T) {
 	server, ct := setUp(t, true, 0, 1, false)
 	defer func() {
-		ct.Close()
-		server.Close()
+		closeClient(ct, t)
+		closeServer(server, t)
 	}()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -371,8 +371,8 @@ func TestLargeMessage(t *testing.T) {
 		}()
 	}
 	wg.Wait()
-	ct.Close()
-	server.Close()
+	closeClient(ct, t)
+	closeServer(server, t)
 }
 
 func TestLargeMessageSuspension(t *testing.T) {
@@ -393,8 +393,8 @@ func TestLargeMessageSuspension(t *testing.T) {
 	if err == nil || err != expectedErr {
 		t.Fatalf("Write got %v, want %v", err, expectedErr)
 	}
-	ct.Close()
-	server.Close()
+	closeClient(ct, t)
+	closeServer(server, t)
 }
 
 func TestStreamContext(t *testing.T) {
@@ -405,3 +405,53 @@ func TestStreamContext(t *testing.T) {
 		t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream)
 	}
 }
+
+// closeClient shuts down the ClientTransport and reports any errors to the
+// test framework, terminating the current test case.
+func closeClient(ct ClientTransport, t *testing.T) {
+	if err := ct.Close(); err != nil {
+		t.Fatalf("ct.Close() = %v, want <nil>", err)
+	}
+}
+
+// closeServerWithErr shuts down the testing server, closing the associated
+// transports. It returns the first error it encounters, if any.
+func closeServerWithErr(s *server) error {
+	// Keep consistent with s.Close().
+	s.lis.Close()
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	for c := range s.conns {
+		if err := c.Close(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// closeServer shuts down the testing server, closing the associated
+// transports. It reports any errors to the test framework and terminates the
+// current test case.
+func closeServer(s *server, t *testing.T) {
+	if err := closeServerWithErr(s); err != nil {
+		t.Fatalf("server.Close() = %v, want <nil>", err)
+	}
+}
+
+func TestClientServerDuplicatedClose(t *testing.T) {
+	server, ct := setUp(t, true, 0, math.MaxUint32, false)
+	if err := ct.Close(); err != nil {
+		t.Fatalf("ct.Close() = %v, want <nil>", err)
+	}
+	if err := ct.Close(); err == nil {
+		// Duplicate closes should gracefully return an error.
+		t.Fatalf("ct.Close() = <nil>, want non-nil")
+	}
+	if err := closeServerWithErr(server); err != nil {
+		t.Fatalf("closeServerWithErr(server) = %v, want <nil>", err)
+	}
+	if err := closeServerWithErr(server); err == nil {
+		// Duplicate closes should gracefully return an error.
+		t.Fatalf("closeServerWithErr(server) = <nil>, want non-nil")
+	}
+}
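
For reviewers, a minimal standalone sketch of the guarded-Close pattern that both http2Client.Close() and http2Server.Close() adopt in this patch: checking a mutex-protected state field so that a second Close() returns an error instead of silently proceeding. The fakeTransport type, its fields, and main() below are illustrative stand-ins under that assumption, not the actual grpc-go transport types.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// transportState mirrors the closing state consulted in the patch.
type transportState int

const (
	reachable transportState = iota
	closing
)

// fakeTransport is a hypothetical stand-in for http2Client/http2Server,
// keeping only the mutex-guarded state that Close() inspects.
type fakeTransport struct {
	mu    sync.Mutex // guard the following state
	state transportState
}

// Close transitions the transport to closing exactly once; a duplicate call
// observes the closing state under the lock and returns an error.
func (t *fakeTransport) Close() error {
	t.mu.Lock()
	if t.state == closing {
		t.mu.Unlock()
		return errors.New("transport: Close() was already called")
	}
	t.state = closing
	t.mu.Unlock()
	// ... real transports then tear down streams, channels, and the connection ...
	return nil
}

func main() {
	tr := &fakeTransport{}
	fmt.Println(tr.Close()) // <nil>
	fmt.Println(tr.Close()) // transport: Close() was already called
}

TestClientServerDuplicatedClose in the diff exercises the same behavior end to end through the real client and server transports.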