From 6340c9ec1d6d34fcee24dcb01f5df665f5839539 Mon Sep 17 00:00:00 2001 From: "Matt T. Proud" Date: Mon, 23 Feb 2015 11:32:45 +0100 Subject: [PATCH] Gracefully deny supplemental transport shutdowns. This commit ensures that transport shutdowns do not panic on supplemental shutdowns, even if users should not attempt multiple shutdowns. This is done to make the surface for users a little more forgiving. The _transport suffix in these implementation filenames are dropped since they are already part of the transport package, which makes the specification both redundant and adds stutter. TEST=``go test ./...`` --- ...p2_client_transport.go => http2_client.go} | 5 ++ ...p2_server_transport.go => http2_server.go} | 4 +- transport/transport.go | 2 +- transport/transport_test.go | 78 +++++++++++++++---- 4 files changed, 72 insertions(+), 17 deletions(-) rename transport/{http2_client_transport.go => http2_client.go} (99%) rename transport/{http2_server_transport.go => http2_server.go} (99%) diff --git a/transport/http2_client_transport.go b/transport/http2_client.go similarity index 99% rename from transport/http2_client_transport.go rename to transport/http2_client.go index f2147509..9b2840a1 100644 --- a/transport/http2_client_transport.go +++ b/transport/http2_client.go @@ -35,6 +35,7 @@ package transport import ( "bytes" + "errors" "io" "log" "math" @@ -315,6 +316,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. func (t *http2Client) Close() (err error) { t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Close() was already called") + } t.state = closing t.mu.Unlock() close(t.shutdownChan) diff --git a/transport/http2_server_transport.go b/transport/http2_server.go similarity index 99% rename from transport/http2_server_transport.go rename to transport/http2_server.go index 633e3f51..53a09bf3 100644 --- a/transport/http2_server_transport.go +++ b/transport/http2_server.go @@ -79,7 +79,7 @@ type http2Server struct { // sendQuotaPool provides flow control to outbound message. sendQuotaPool *quotaPool - mu sync.Mutex + mu sync.Mutex // guard the following state transportState activeStreams map[uint32]*Stream // Inbound quota for flow control @@ -570,7 +570,7 @@ func (t *http2Server) Close() (err error) { t.mu.Lock() if t.state == closing { t.mu.Unlock() - return + return errors.New("transport: Close() was already called") } t.state = closing streams := t.activeStreams diff --git a/transport/transport.go b/transport/transport.go index 2cd211e6..406d581f 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -186,7 +186,7 @@ type Stream struct { // The key-value map of trailer metadata. trailer metadata.MD - mu sync.RWMutex + mu sync.RWMutex // guard the following // headerOK becomes true from the first header is about to send. headerOk bool state streamState diff --git a/transport/transport_test.go b/transport/transport_test.go index bac71160..e768f7a5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -158,12 +158,12 @@ func (s *server) Wait(t *testing.T, timeout time.Duration) { } func (s *server) Close() { + // Keep consistent with closeServer(). s.lis.Close() s.mu.Lock() for c := range s.conns { c.Close() } - s.conns = nil s.mu.Unlock() } @@ -227,8 +227,8 @@ func TestClientSendAndReceive(t *testing.T) { if recvErr != io.EOF { t.Fatalf("Error: %v; want ", recvErr) } - ct.Close() - server.Close() + closeClient(ct, t) + closeServer(server, t) } func TestClientErrorNotify(t *testing.T) { @@ -245,10 +245,10 @@ func TestClientErrorNotify(t *testing.T) { t.Fatalf("wrong stream id: %d", s.id) } // Tear down the server. - go server.Close() + go closeServer(server, t) // ct.reader should detect the error and activate ct.Error(). <-ct.Error() - ct.Close() + closeClient(ct, t) } func performOneRPC(ct ClientTransport) { @@ -281,11 +281,11 @@ func TestClientMix(t *testing.T) { s, ct := setUp(t, true, 0, math.MaxUint32, false) go func(s *server) { time.Sleep(5 * time.Second) - s.Close() + closeServer(s, t) }(s) - go func(t ClientTransport) { + go func(ct ClientTransport) { <-ct.Error() - ct.Close() + closeClient(ct, t) }(ct) for i := 0; i < 1000; i++ { time.Sleep(10 * time.Millisecond) @@ -296,8 +296,8 @@ func TestClientMix(t *testing.T) { func TestExceedMaxStreamsLimit(t *testing.T) { server, ct := setUp(t, true, 0, 1, false) defer func() { - ct.Close() - server.Close() + closeClient(ct, t) + closeServer(server, t) }() callHdr := &CallHdr{ Host: "localhost", @@ -371,8 +371,8 @@ func TestLargeMessage(t *testing.T) { }() } wg.Wait() - ct.Close() - server.Close() + closeClient(ct, t) + closeServer(server, t) } func TestLargeMessageSuspension(t *testing.T) { @@ -393,8 +393,8 @@ func TestLargeMessageSuspension(t *testing.T) { if err == nil || err != expectedErr { t.Fatalf("Write got %v, want %v", err, expectedErr) } - ct.Close() - server.Close() + closeClient(ct, t) + closeServer(server, t) } func TestStreamContext(t *testing.T) { @@ -405,3 +405,53 @@ func TestStreamContext(t *testing.T) { t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream) } } + +// closeClient shuts down the ClientTransport and reports any errors to the +// test framework and terminates the current test case. +func closeClient(ct ClientTransport, t *testing.T) { + if err := ct.Close(); err != nil { + t.Fatalf("ct.Close() = %v, want ", err) + } +} + +// closeServerWithErr shuts down the testing server, closing the associated +// transports. It returns the first error it encounters, if any. +func closeServerWithErr(s *server) error { + // Keep consistent with s.Close(). + s.lis.Close() + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.conns { + if err := c.Close(); err != nil { + return err + } + } + return nil +} + +// closeServer shuts down the and testing server, closing the associated +// transport. It reports any errors to the test framework and terminates the +// current test case. +func closeServer(s *server, t *testing.T) { + if err := closeServerWithErr(s); err != nil { + t.Fatalf("server.Close() = %v, want ", err) + } +} + +func TestClientServerDuplicatedClose(t *testing.T) { + server, ct := setUp(t, true, 0, math.MaxUint32, false) + if err := ct.Close(); err != nil { + t.Fatalf("ct.Close() = %v, want ", err) + } + if err := ct.Close(); err == nil { + // Duplicated closes should gracefully issue an error. + t.Fatalf("ct.Close() = , want non-nil") + } + if err := closeServerWithErr(server); err != nil { + t.Fatalf("closeServerWithErr(server) = %v, want ", err) + } + if err := closeServerWithErr(server); err == nil { + // Duplicated closes should gracefully issue an error. + t.Fatalf("closeServerWithErr(server) = , want non-nil") + } +}