Merge pull request #52 from matttproud/refactor/constrained-err-scope

Gracefully deny supplemental transport shutdowns.
This commit is contained in:
Qi Zhao
2015-02-25 12:47:55 -08:00
4 changed files with 72 additions and 17 deletions

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"log" "log"
"math" "math"
@ -315,6 +316,10 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return errors.New("transport: Close() was already called")
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)

View File

@ -79,7 +79,7 @@ type http2Server struct {
// sendQuotaPool provides flow control to outbound message. // sendQuotaPool provides flow control to outbound message.
sendQuotaPool *quotaPool sendQuotaPool *quotaPool
mu sync.Mutex mu sync.Mutex // guard the following
state transportState state transportState
activeStreams map[uint32]*Stream activeStreams map[uint32]*Stream
// Inbound quota for flow control // Inbound quota for flow control
@ -570,7 +570,7 @@ func (t *http2Server) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return errors.New("transport: Close() was already called")
} }
t.state = closing t.state = closing
streams := t.activeStreams streams := t.activeStreams

View File

@ -186,7 +186,7 @@ type Stream struct {
// The key-value map of trailer metadata. // The key-value map of trailer metadata.
trailer metadata.MD trailer metadata.MD
mu sync.RWMutex mu sync.RWMutex // guard the following
// headerOK becomes true from the first header is about to send. // headerOK becomes true from the first header is about to send.
headerOk bool headerOk bool
state streamState state streamState

View File

@ -158,12 +158,12 @@ func (s *server) Wait(t *testing.T, timeout time.Duration) {
} }
func (s *server) Close() { func (s *server) Close() {
// Keep consistent with closeServer().
s.lis.Close() s.lis.Close()
s.mu.Lock() s.mu.Lock()
for c := range s.conns { for c := range s.conns {
c.Close() c.Close()
} }
s.conns = nil
s.mu.Unlock() s.mu.Unlock()
} }
@ -227,8 +227,8 @@ func TestClientSendAndReceive(t *testing.T) {
if recvErr != io.EOF { if recvErr != io.EOF {
t.Fatalf("Error: %v; want <EOF>", recvErr) t.Fatalf("Error: %v; want <EOF>", recvErr)
} }
ct.Close() closeClient(ct, t)
server.Close() closeServer(server, t)
} }
func TestClientErrorNotify(t *testing.T) { func TestClientErrorNotify(t *testing.T) {
@ -245,10 +245,10 @@ func TestClientErrorNotify(t *testing.T) {
t.Fatalf("wrong stream id: %d", s.id) t.Fatalf("wrong stream id: %d", s.id)
} }
// Tear down the server. // Tear down the server.
go server.Close() go closeServer(server, t)
// ct.reader should detect the error and activate ct.Error(). // ct.reader should detect the error and activate ct.Error().
<-ct.Error() <-ct.Error()
ct.Close() closeClient(ct, t)
} }
func performOneRPC(ct ClientTransport) { func performOneRPC(ct ClientTransport) {
@ -281,11 +281,11 @@ func TestClientMix(t *testing.T) {
s, ct := setUp(t, true, 0, math.MaxUint32, false) s, ct := setUp(t, true, 0, math.MaxUint32, false)
go func(s *server) { go func(s *server) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
s.Close() closeServer(s, t)
}(s) }(s)
go func(t ClientTransport) { go func(ct ClientTransport) {
<-ct.Error() <-ct.Error()
ct.Close() closeClient(ct, t)
}(ct) }(ct)
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -296,8 +296,8 @@ func TestClientMix(t *testing.T) {
func TestExceedMaxStreamsLimit(t *testing.T) { func TestExceedMaxStreamsLimit(t *testing.T) {
server, ct := setUp(t, true, 0, 1, false) server, ct := setUp(t, true, 0, 1, false)
defer func() { defer func() {
ct.Close() closeClient(ct, t)
server.Close() closeServer(server, t)
}() }()
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
@ -371,8 +371,8 @@ func TestLargeMessage(t *testing.T) {
}() }()
} }
wg.Wait() wg.Wait()
ct.Close() closeClient(ct, t)
server.Close() closeServer(server, t)
} }
func TestLargeMessageSuspension(t *testing.T) { func TestLargeMessageSuspension(t *testing.T) {
@ -393,8 +393,8 @@ func TestLargeMessageSuspension(t *testing.T) {
if err == nil || err != expectedErr { if err == nil || err != expectedErr {
t.Fatalf("Write got %v, want %v", err, expectedErr) t.Fatalf("Write got %v, want %v", err, expectedErr)
} }
ct.Close() closeClient(ct, t)
server.Close() closeServer(server, t)
} }
func TestStreamContext(t *testing.T) { func TestStreamContext(t *testing.T) {
@ -405,3 +405,53 @@ func TestStreamContext(t *testing.T) {
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream) t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream)
} }
} }
// closeClient shuts down the ClientTransport and reports any errors to the
// test framework and terminates the current test case.
func closeClient(ct ClientTransport, t *testing.T) {
if err := ct.Close(); err != nil {
t.Fatalf("ct.Close() = %v, want <nil>", err)
}
}
// closeServerWithErr shuts down the testing server, closing the associated
// transports. It returns the first error it encounters, if any.
func closeServerWithErr(s *server) error {
// Keep consistent with s.Close().
s.lis.Close()
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.conns {
if err := c.Close(); err != nil {
return err
}
}
return nil
}
// closeServer shuts down the and testing server, closing the associated
// transport. It reports any errors to the test framework and terminates the
// current test case.
func closeServer(s *server, t *testing.T) {
if err := closeServerWithErr(s); err != nil {
t.Fatalf("server.Close() = %v, want <nil>", err)
}
}
func TestClientServerDuplicatedClose(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, false)
if err := ct.Close(); err != nil {
t.Fatalf("ct.Close() = %v, want <nil>", err)
}
if err := ct.Close(); err == nil {
// Duplicated closes should gracefully issue an error.
t.Fatalf("ct.Close() = <nil>, want non-nil")
}
if err := closeServerWithErr(server); err != nil {
t.Fatalf("closeServerWithErr(server) = %v, want <nil>", err)
}
if err := closeServerWithErr(server); err == nil {
// Duplicated closes should gracefully issue an error.
t.Fatalf("closeServerWithErr(server) = <nil>, want non-nil")
}
}