From d2e79470cce0a0cf23852cf13eca9f950b5e1064 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 19 Jul 2016 16:29:13 -0700 Subject: [PATCH 1/9] client goaway support --- clientconn.go | 7 +++++- stream.go | 2 +- transport/http2_client.go | 37 +++++++++++++++++++++++++------ transport/http2_server.go | 12 +++++----- transport/transport.go | 44 +++++++++++++++++++++++++++++++------ transport/transport_test.go | 8 +++---- 6 files changed, 84 insertions(+), 26 deletions(-) diff --git a/clientconn.go b/clientconn.go index c3c7691d..f6b39e4a 100644 --- a/clientconn.go +++ b/clientconn.go @@ -625,13 +625,18 @@ func (ac *addrConn) transportMonitor() { // the addrConn is idle (i.e., no RPC in flight). case <-ac.shutdownChan: return - case <-t.Error(): + case <-t.Done(): ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() return } + if t.Err() == transport.ErrConnDrain { + ac.mu.Unlock() + ac.tearDown(errConnDrain) + return + } ac.state = TransientFailure ac.stateCV.Broadcast() ac.mu.Unlock() diff --git a/stream.go b/stream.go index a182e077..2940cbbd 100644 --- a/stream.go +++ b/stream.go @@ -184,7 +184,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // when there is no pending I/O operations on this stream. go func() { select { - case <-t.Error(): + case <-t.Done(): // Incur transport error, simply exit. case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending diff --git a/transport/http2_client.go b/transport/http2_client.go index 4f22be09..bde77f11 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -71,6 +71,7 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} + err error framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -97,6 +98,7 @@ type http2Client struct { maxStreams int // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 + goAwayID uint32 } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -279,7 +281,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea checkStreamsQuota := t.streamsQuota != nil t.mu.Unlock() if checkStreamsQuota { - sq, err := wait(ctx, nil, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -288,7 +290,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.streamsQuota.add(sq - 1) } } - if _, err := wait(ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // Return the quota back now because there is no stream returned to the caller. if _, ok := err.(StreamError); ok && checkStreamsQuota { t.streamsQuota.add(1) @@ -480,6 +482,12 @@ func (t *http2Client) GracefulClose() error { return nil } t.state = draining + // Notify the streams which were initiated after the server sent GOAWAY. + for i := t.goAwayID + 2; i < t.nextID; i += 2 { + if s, ok := t.activeStreams[i]; ok { + close(s.goAway) + } + } active := len(t.activeStreams) t.mu.Unlock() if active == 0 { @@ -500,13 +508,13 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, s.done, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, s.done, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok || err == io.EOF { t.sendQuotaPool.cancel() @@ -540,7 +548,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, s.done, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok || err == io.EOF { // Return the connection quota back. t.sendQuotaPool.add(len(p)) @@ -723,7 +731,18 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { - // TODO(zhaoq): GoAwayFrame handler to be implemented + t.mu.Lock() + t.goAwayID = f.LastStreamID + t.err = ErrDrain + close(t.errorChan) + + // Notify the streams which were initiated after the server sent GOAWAY. + //for i := f.LastStreamID + 2; i < t.nextID; i += 2 { + // if s, ok := t.activeStreams[i]; ok { + // close(s.goAway) + // } + //} + t.mu.Unlock() } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -928,10 +947,14 @@ func (t *http2Client) controller() { } } -func (t *http2Client) Error() <-chan struct{} { +func (t *http2Client) Done() <-chan struct{} { return t.errorChan } +func (t *http2Client) Err() error { + return t.err +} + func (t *http2Client) notifyError(err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index 9e35fdd8..2467630a 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -451,7 +451,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } s.headerOk = true s.mu.Unlock() - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -491,7 +491,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s headersSent = true } s.mu.Unlock() - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -540,7 +540,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Unlock() if writeHeaderFrame { - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -568,13 +568,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, nil, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, nil, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok { t.sendQuotaPool.cancel() @@ -600,7 +600,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the // transport. - if _, err := wait(s.ctx, nil, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok { // Return the connection quota back. t.sendQuotaPool.add(ps) diff --git a/transport/transport.go b/transport/transport.go index 4dab5745..4a7b83c5 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -53,6 +53,10 @@ import ( "google.golang.org/grpc/metadata" ) +var ( + ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs") +) + // recvMsg represents the received msg from the transport. All transport // protocol specific info has been removed. type recvMsg struct { @@ -120,10 +124,11 @@ func (b *recvBuffer) get() <-chan item { // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - ctx context.Context - recv *recvBuffer - last *bytes.Reader // Stores the remaining data in the previous calls. - err error + ctx context.Context + goAway chan struct{} + recv *recvBuffer + last *bytes.Reader // Stores the remaining data in the previous calls. + err error } // Read reads the next len(p) bytes from last. If last is drained, it tries to @@ -141,6 +146,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) + case <-r.goAway: + return 0, ErrConnDrain case i := <-r.recv.get(): r.recv.load() m := i.(*recvMsg) @@ -171,6 +178,8 @@ type Stream struct { cancel context.CancelFunc // done is closed when the final status arrives. done chan struct{} + // goAway + goAway chan struct{} // method records the associated RPC method of the stream. method string recvCompress string @@ -220,6 +229,10 @@ func (s *Stream) Done() <-chan struct{} { return s.done } +func (s *Stream) GoAway() <-chan struct{} { + return s.goAway +} + // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is cancelled/expired. @@ -422,7 +435,18 @@ type ClientTransport interface { // this in order to take action (e.g., close the current transport // and create a new one) in error case. It should not return nil // once the transport is initiated. - Error() <-chan struct{} + //Error() <-chan struct{} + + // Done returns a channel that is closed when some I/O error + // happens or ClientTranspor receives the draining signal from the server + // (e.g., GOAWAY frame in HTTP/2). Typically the caller should have + // a goroutine to monitor this in order to take action (e.g., close + // the current transport and create a new one) in error case. It should + // not return nil once the transport is initiated. + Done() <-chan struct{} + + // Err returns ... + Err() error } // ServerTransport is the common interface for all gRPC server-side transport @@ -482,7 +506,10 @@ func (e ConnectionError) Error() string { } // ErrConnClosing indicates that the transport is closing. -var ErrConnClosing = ConnectionError{Desc: "transport is closing"} +var ( + ErrConnClosing = ConnectionError{Desc: "transport is closing"} + ErrConnDrain = ConnectionError{Desc: "transport is being drained"} +) // StreamError is an error that only affects one stream within a connection. type StreamError struct { @@ -509,9 +536,10 @@ func ContextErr(err error) StreamError { // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise // it return the StreamError for ctx.Err. +// If it receives from goAway, it returns 0, ErrConnDrain. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int) (int, error) { +func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) @@ -523,6 +551,8 @@ func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int default: } return 0, io.EOF + case <-goAway: + return 0, ErrConnDrain case <-closing: return 0, ErrConnClosing case i := <-proceed: diff --git a/transport/transport_test.go b/transport/transport_test.go index ce015da2..a98f27e5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -271,8 +271,8 @@ func TestClientSendAndReceive(t *testing.T) { func TestClientErrorNotify(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) go server.stop() - // ct.reader should detect the error and activate ct.Error(). - <-ct.Error() + // ct.reader should detect the error and activate ct.Done(). + <-ct.Done() ct.Close() } @@ -309,7 +309,7 @@ func TestClientMix(t *testing.T) { s.stop() }(s) go func(ct ClientTransport) { - <-ct.Error() + <-ct.Done() ct.Close() }(ct) for i := 0; i < 1000; i++ { @@ -700,7 +700,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { } } // http2Client.errChan is closed due to connection flow control window size violation. - <-conn.Error() + <-conn.Done() ct.Close() server.stop() } From 873cc272c2e52e8e58d96193376d19760012863b Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Wed, 20 Jul 2016 18:48:49 -0700 Subject: [PATCH 2/9] support goaway --- clientconn.go | 1 + rpc_util.go | 6 +++++ server.go | 32 +++++++++++++++++++++--- stream.go | 3 +++ test/end2end_test.go | 49 +++++++++++++++++++++++++++++++++++++ transport/control.go | 5 ++++ transport/handler_server.go | 3 +++ transport/http2_client.go | 27 +++++++++----------- transport/http2_server.go | 21 ++++++++++++++++ transport/transport.go | 14 +++++------ 10 files changed, 136 insertions(+), 25 deletions(-) diff --git a/clientconn.go b/clientconn.go index f6b39e4a..4933b554 100644 --- a/clientconn.go +++ b/clientconn.go @@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() { if t.Err() == transport.ErrConnDrain { ac.mu.Unlock() ac.tearDown(errConnDrain) + ac.cc.newAddrConn(ac.addr, true) return } ac.state = TransientFailure diff --git a/rpc_util.go b/rpc_util.go index d6287175..173018e4 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -385,6 +385,12 @@ func toRPCErr(err error) error { desc: e.Desc, } case transport.ConnectionError: + if err == transport.ErrConnDrain { + return &rpcError{ + code: codes.Unavailable, + desc: e.Desc, + } + } return &rpcError{ code: codes.Internal, desc: e.Desc, diff --git a/server.go b/server.go index a2b2b94d..34e69102 100644 --- a/server.go +++ b/server.go @@ -92,6 +92,8 @@ type Server struct { mu sync.Mutex // guards following lis map[net.Listener]bool conns map[io.Closer]bool + drain bool + cv *sync.Cond m map[string]*service // service name -> service info events trace.EventLog } @@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server { conns: make(map[io.Closer]bool), m: make(map[string]*service), } + s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil { + if s.conns == nil || s.drain { return false } s.conns[c] = true @@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) { defer s.mu.Unlock() if s.conns != nil { delete(s.conns, c) + s.cv.Signal() } } @@ -766,14 +770,14 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - cs := s.conns + st := s.conns s.conns = nil s.mu.Unlock() for lis := range listeners { lis.Close() } - for c := range cs { + for c := range st { c.Close() } @@ -785,6 +789,28 @@ func (s *Server) Stop() { s.mu.Unlock() } +func (s *Server) GracefulStop() { + s.mu.Lock() + s.drain = true + for lis := range s.lis { + lis.Close() + } + for c := range s.conns { + c.(transport.ServerTransport).GoAway() + } + for len(s.conns) != 0 { + s.cv.Wait() + } + s.lis = nil + s.conns = nil + if s.events != nil { + s.events.Finish() + s.events = nil + } + s.mu.Unlock() + +} + func init() { internal.TestingCloseConns = func(arg interface{}) { arg.(*Server).testingCloseConns() diff --git a/stream.go b/stream.go index dfa224b4..deb8663c 100644 --- a/stream.go +++ b/stream.go @@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) } cs.closeTransportStream(nil) + case <-s.GoAway(): + cs.finish(errConnDrain) + cs.closeTransportStream(errConnDrain) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index fdac5815..721e6706 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) { } } +func TestServerGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + //if e.name != "tcp-clear" { + // continue + //} + testServerGoAway(t, e) + } +} + +func testServerGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + <-ch + awaitNewConnLogOutput() +} + func testFailFast(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA diff --git a/transport/control.go b/transport/control.go index 7e9bdf33..4ef0830b 100644 --- a/transport/control.go +++ b/transport/control.go @@ -72,6 +72,11 @@ type resetStream struct { func (*resetStream) item() {} +type goAway struct { +} + +func (*goAway) item() {} + type flushIO struct { } diff --git a/transport/handler_server.go b/transport/handler_server.go index 4b0d5252..723bf5b0 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() { } } +func (ht *serverHandlerTransport) GoAway() { +} + // mapRecvMsgError returns the non-nil err into the appropriate // error value as expected by callers of *grpc.parser.recvMsg. // In particular, in can only be: diff --git a/transport/http2_client.go b/transport/http2_client.go index bde77f11..2ec703f0 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { s := &Stream{ id: t.nextID, done: make(chan struct{}), + goAway: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -219,8 +220,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // Make a stream be able to cancel the pending operations by itself. s.ctx, s.cancel = context.WithCancel(ctx) s.dec = &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, + ctx: s.ctx, + goAway: s.goAway, + recv: s.buf, } return s } @@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. func (t *http2Client) Close() (err error) { t.mu.Lock() - if t.state == reachable { - close(t.errorChan) - } if t.state == closing { t.mu.Unlock() return } + if t.state == reachable { + close(t.errorChan) + } t.state = closing t.mu.Unlock() close(t.shutdownChan) @@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - t.goAwayID = f.LastStreamID - t.err = ErrDrain - close(t.errorChan) - - // Notify the streams which were initiated after the server sent GOAWAY. - //for i := f.LastStreamID + 2; i < t.nextID; i += 2 { - // if s, ok := t.activeStreams[i]; ok { - // close(s.goAway) - // } - //} + if t.state == reachable { + t.goAwayID = f.LastStreamID + t.err = ErrConnDrain + close(t.errorChan) + } t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index 2467630a..d7cab4fe 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.recvCompress = state.encoding s.method = state.method t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + return + } if t.state != reachable { t.mu.Unlock() return } + if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } + s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() @@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { switch frame := frame.(type) { case *http2.MetaHeadersFrame: id := frame.Header().StreamID + t.mu.Lock() if id%2 != 1 || id <= t.maxStreamID { + t.mu.Unlock() // illegal gRPC stream id. grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) t.Close() break } t.maxStreamID = id + t.mu.Unlock() t.operateHeaders(frame, handle) case *http2.DataFrame: t.handleData(frame) @@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { case *http2.WindowUpdateFrame: t.handleWindowUpdate(frame) case *http2.GoAwayFrame: + t.Close() break default: grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) @@ -675,6 +686,12 @@ func (t *http2Server) controller() { } case *resetStream: t.framer.writeRSTStream(true, i.streamID, i.code) + case *goAway: + t.mu.Lock() + sid := t.maxStreamID + t.state = draining + t.mu.Unlock() + t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil) case *flushIO: t.framer.flushWrite() case *ping: @@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) RemoteAddr() net.Addr { return t.conn.RemoteAddr() } + +func (t *http2Server) GoAway() { + t.controlBuf.put(&goAway{}) +} diff --git a/transport/transport.go b/transport/transport.go index 4a7b83c5..2372f322 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -53,10 +53,6 @@ import ( "google.golang.org/grpc/metadata" ) -var ( - ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs") -) - // recvMsg represents the received msg from the transport. All transport // protocol specific info has been removed. type recvMsg struct { @@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) case <-r.goAway: - return 0, ErrConnDrain + return 0, ErrStreamDrain case i := <-r.recv.get(): r.recv.load() m := i.(*recvMsg) @@ -478,6 +474,9 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. RemoteAddr() net.Addr + + // GoAway ... + GoAway() } // StreamErrorf creates an StreamError with the specified error code and description. @@ -509,6 +508,7 @@ func (e ConnectionError) Error() string { var ( ErrConnClosing = ConnectionError{Desc: "transport is closing"} ErrConnDrain = ConnectionError{Desc: "transport is being drained"} + ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf") ) // StreamError is an error that only affects one stream within a connection. @@ -536,7 +536,7 @@ func ContextErr(err error) StreamError { // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise // it return the StreamError for ctx.Err. -// If it receives from goAway, it returns 0, ErrConnDrain. +// If it receives from goAway, it returns 0, ErrStreamDrain. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { @@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <- } return 0, io.EOF case <-goAway: - return 0, ErrConnDrain + return 0, ErrStreamDrain case <-closing: return 0, ErrConnClosing case i := <-proceed: From 9ad4c58355b60b2cedcc51d35ff7573192a03309 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Thu, 21 Jul 2016 16:19:34 -0700 Subject: [PATCH 3/9] Make it work for streaming --- clientconn.go | 18 +++++--- server.go | 2 +- stream.go | 2 +- test/end2end_test.go | 87 ++++++++++++++++++++++++++++++++----- transport/http2_client.go | 31 ++++++++----- transport/http2_server.go | 3 ++ transport/transport.go | 6 +-- transport/transport_test.go | 8 ++-- 8 files changed, 120 insertions(+), 37 deletions(-) diff --git a/clientconn.go b/clientconn.go index 4933b554..de7d0383 100644 --- a/clientconn.go +++ b/clientconn.go @@ -625,19 +625,23 @@ func (ac *addrConn) transportMonitor() { // the addrConn is idle (i.e., no RPC in flight). case <-ac.shutdownChan: return - case <-t.Done(): + case <-t.GoAway(): + ac.tearDown(errConnDrain) + ac.cc.newAddrConn(ac.addr, true) + return + case <-t.Error(): ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. ac.mu.Unlock() return } - if t.Err() == transport.ErrConnDrain { - ac.mu.Unlock() - ac.tearDown(errConnDrain) - ac.cc.newAddrConn(ac.addr, true) - return - } + //if t.Err() == transport.ErrConnDrain { + // ac.mu.Unlock() + // ac.tearDown(errConnDrain) + // ac.cc.newAddrConn(ac.addr, true) + // return + //} ac.state = TransientFailure ac.stateCV.Broadcast() ac.mu.Unlock() diff --git a/server.go b/server.go index 34e69102..2f847100 100644 --- a/server.go +++ b/server.go @@ -391,6 +391,7 @@ func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInf st.Close() return } + grpclog.Println("DEBUG addConn ... ") s.serveStreams(st) } @@ -808,7 +809,6 @@ func (s *Server) GracefulStop() { s.events = nil } s.mu.Unlock() - } func init() { diff --git a/stream.go b/stream.go index deb8663c..fb7e50f9 100644 --- a/stream.go +++ b/stream.go @@ -184,7 +184,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // when there is no pending I/O operations on this stream. go func() { select { - case <-t.Done(): + case <-t.Error(): // Incur transport error, simply exit. case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending diff --git a/test/end2end_test.go b/test/end2end_test.go index 721e6706..c9b5f539 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -565,22 +565,12 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { awaitNewConnLogOutput() } -func TestFailFast(t *testing.T) { - defer leakCheck(t)() - for _, e := range listTestEnv() { - testFailFast(t, e) - } -} - func TestServerGoAway(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { if e.name == "handler-tls" { continue } - //if e.name != "tcp-clear" { - // continue - //} testServerGoAway(t, e) } } @@ -621,6 +611,83 @@ func testServerGoAway(t *testing.T, e env) { awaitNewConnLogOutput() } +func TestServerGoAwayPendingRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testServerGoAwayPendingRPC(t, e) + } +} + +func testServerGoAwayPendingRPC(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithCancel(context.Background()) + stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("fadjflajdflkaflj") + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } else { + break + } + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want _, ", stream, err) + } + cancel() + <-ch + awaitNewConnLogOutput() +} + +func TestFailFast(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testFailFast(t, e) + } +} + func testFailFast(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA diff --git a/transport/http2_client.go b/transport/http2_client.go index 2ec703f0..71873ef1 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -71,7 +71,8 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} - err error + //err error + goAway chan struct{} framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -149,6 +150,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e writableChan: make(chan int, 1), shutdownChan: make(chan struct{}), errorChan: make(chan struct{}), + goAway: make(chan struct{}), framer: newFramer(conn), hBuf: &buf, hEnc: hpack.NewEncoder(&buf), @@ -408,13 +410,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if t.streamsQuota != nil { updateStreams = true } - if t.state == draining && len(t.activeStreams) == 1 { + delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { // The transport is draining and s is the last live stream on t. t.mu.Unlock() t.Close() return } - delete(t.activeStreams, s.id) t.mu.Unlock() if updateStreams { t.streamsQuota.add(1) @@ -485,10 +487,14 @@ func (t *http2Client) GracefulClose() error { } t.state = draining // Notify the streams which were initiated after the server sent GOAWAY. - for i := t.goAwayID + 2; i < t.nextID; i += 2 { - if s, ok := t.activeStreams[i]; ok { - close(s.goAway) + select { + case <-t.goAway: + for i := t.goAwayID + 2; i < t.nextID; i += 2 { + if s, ok := t.activeStreams[i]; ok { + close(s.goAway) + } } + default: } active := len(t.activeStreams) t.mu.Unlock() @@ -736,8 +742,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() if t.state == reachable { t.goAwayID = f.LastStreamID - t.err = ErrConnDrain - close(t.errorChan) + close(t.goAway) } t.mu.Unlock() } @@ -944,14 +949,18 @@ func (t *http2Client) controller() { } } -func (t *http2Client) Done() <-chan struct{} { +func (t *http2Client) Error() <-chan struct{} { return t.errorChan } -func (t *http2Client) Err() error { - return t.err +func (t *http2Client) GoAway() <-chan struct{} { + return t.goAway } +//func (t *http2Client) Err() error { +// return t.err +//} + func (t *http2Client) notifyError(err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index d7cab4fe..37c9a9ae 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -737,6 +737,9 @@ func (t *http2Server) Close() (err error) { func (t *http2Server) closeStream(s *Stream) { t.mu.Lock() delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { + defer t.Close() + } t.mu.Unlock() // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be diff --git a/transport/transport.go b/transport/transport.go index 2372f322..e592bfe9 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -431,7 +431,7 @@ type ClientTransport interface { // this in order to take action (e.g., close the current transport // and create a new one) in error case. It should not return nil // once the transport is initiated. - //Error() <-chan struct{} + Error() <-chan struct{} // Done returns a channel that is closed when some I/O error // happens or ClientTranspor receives the draining signal from the server @@ -439,10 +439,10 @@ type ClientTransport interface { // a goroutine to monitor this in order to take action (e.g., close // the current transport and create a new one) in error case. It should // not return nil once the transport is initiated. - Done() <-chan struct{} + GoAway() <-chan struct{} // Err returns ... - Err() error + //Err() error } // ServerTransport is the common interface for all gRPC server-side transport diff --git a/transport/transport_test.go b/transport/transport_test.go index a1c1cdd3..5a517e0b 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -271,8 +271,8 @@ func TestClientSendAndReceive(t *testing.T) { func TestClientErrorNotify(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, normal) go server.stop() - // ct.reader should detect the error and activate ct.Done(). - <-ct.Done() + // ct.reader should detect the error and activate ct.Error(). + <-ct.Error() ct.Close() } @@ -309,7 +309,7 @@ func TestClientMix(t *testing.T) { s.stop() }(s) go func(ct ClientTransport) { - <-ct.Done() + <-ct.Error() ct.Close() }(ct) for i := 0; i < 1000; i++ { @@ -709,7 +709,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { } } // http2Client.errChan is closed due to connection flow control window size violation. - <-conn.Done() + <-conn.Error() ct.Close() server.stop() } From 046e606dc53d0c174ff48b5aa7ef3571b5fab850 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Thu, 21 Jul 2016 18:12:01 -0700 Subject: [PATCH 4/9] clean up --- clientconn.go | 6 ------ rpc_util.go | 6 ------ server.go | 13 ++++++++----- test/end2end_test.go | 9 +++++++-- transport/handler_server.go | 1 + transport/http2_client.go | 10 ++++------ transport/http2_server.go | 8 +------- transport/transport.go | 27 +++++++++++++-------------- 8 files changed, 34 insertions(+), 46 deletions(-) diff --git a/clientconn.go b/clientconn.go index de7d0383..33edfa31 100644 --- a/clientconn.go +++ b/clientconn.go @@ -636,12 +636,6 @@ func (ac *addrConn) transportMonitor() { ac.mu.Unlock() return } - //if t.Err() == transport.ErrConnDrain { - // ac.mu.Unlock() - // ac.tearDown(errConnDrain) - // ac.cc.newAddrConn(ac.addr, true) - // return - //} ac.state = TransientFailure ac.stateCV.Broadcast() ac.mu.Unlock() diff --git a/rpc_util.go b/rpc_util.go index 173018e4..d6287175 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -385,12 +385,6 @@ func toRPCErr(err error) error { desc: e.Desc, } case transport.ConnectionError: - if err == transport.ErrConnDrain { - return &rpcError{ - code: codes.Unavailable, - desc: e.Desc, - } - } return &rpcError{ code: codes.Internal, desc: e.Desc, diff --git a/server.go b/server.go index 2f847100..2ad331e1 100644 --- a/server.go +++ b/server.go @@ -89,10 +89,12 @@ type service struct { type Server struct { opts options - mu sync.Mutex // guards following - lis map[net.Listener]bool - conns map[io.Closer]bool - drain bool + mu sync.Mutex // guards following + lis map[net.Listener]bool + conns map[io.Closer]bool + drain bool + // A CondVar to let GracefulStop() blocks until all the pending RPCs are finished + // and all the transport goes away. cv *sync.Cond m map[string]*service // service name -> service info events trace.EventLog @@ -391,7 +393,6 @@ func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInf st.Close() return } - grpclog.Println("DEBUG addConn ... ") s.serveStreams(st) } @@ -790,6 +791,8 @@ func (s *Server) Stop() { s.mu.Unlock() } +// GracefulStop stops the gRPC server gracefully. It stops the server to accept new +// connections and RPCs and blocks until all the pending RPCs are finished. func (s *Server) GracefulStop() { s.mu.Lock() s.drain = true diff --git a/test/end2end_test.go b/test/end2end_test.go index c9b5f539..a3f022a1 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -589,6 +589,7 @@ func testServerGoAway(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) + // Finish an RPC to make sure the connection is good. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } @@ -597,6 +598,7 @@ func testServerGoAway(t *testing.T, e env) { te.srv.GracefulStop() close(ch) }() + // Loop until the server side GoAway signal is propagated to the client. for { ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { @@ -604,6 +606,7 @@ func testServerGoAway(t *testing.T, e env) { } break } + // A new RPC should fail with Unavailable error. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable) } @@ -640,6 +643,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } + // Finish an RPC to make sure the connection is good. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("fadjflajdflkaflj") } @@ -648,13 +652,13 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { te.srv.GracefulStop() close(ch) }() + // Loop until the server side GoAway signal is propagated to the client. for { ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { continue - } else { - break } + break } respParam := []*testpb.ResponseParameters{ { @@ -670,6 +674,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { ResponseParameters: respParam, Payload: payload, } + // The existing RPC should be still good to proceed. if err := stream.Send(req); err != nil { t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) } diff --git a/transport/handler_server.go b/transport/handler_server.go index 723bf5b0..35ccf627 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -371,6 +371,7 @@ func (ht *serverHandlerTransport) runStream() { } func (ht *serverHandlerTransport) GoAway() { + panic("not implemented") } // mapRecvMsgError returns the non-nil err into the appropriate diff --git a/transport/http2_client.go b/transport/http2_client.go index 71873ef1..20eb4bab 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -71,7 +71,8 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} - //err error + // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) + // that the server sent GoAway on this transport. goAway chan struct{} framer *framer @@ -99,7 +100,8 @@ type http2Client struct { maxStreams int // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 - goAwayID uint32 + // goAwayID records the Last-Stream-ID in the GoAway frame from the server. + goAwayID uint32 } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -957,10 +959,6 @@ func (t *http2Client) GoAway() <-chan struct{} { return t.goAway } -//func (t *http2Client) Err() error { -// return t.err -//} - func (t *http2Client) notifyError(err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index 37c9a9ae..e10eedb9 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -196,11 +196,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.recvCompress = state.encoding s.method = state.method t.mu.Lock() - if t.state == draining { - t.mu.Unlock() - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) - return - } if t.state != reachable { t.mu.Unlock() return @@ -292,8 +287,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { case *http2.WindowUpdateFrame: t.handleWindowUpdate(frame) case *http2.GoAwayFrame: - t.Close() - break + // TODO: Handle GoAway from the client appropriately. default: grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) } diff --git a/transport/transport.go b/transport/transport.go index e592bfe9..f94731ed 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -174,7 +174,7 @@ type Stream struct { cancel context.CancelFunc // done is closed when the final status arrives. done chan struct{} - // goAway + // goAway is closed when the server sent GoAways signal before this stream was initiated. goAway chan struct{} // method records the associated RPC method of the stream. method string @@ -221,10 +221,14 @@ func (s *Stream) SetSendCompress(str string) { s.sendCompress = str } +// Done returns a chanel which is closed when it receives the final status +// from the server. func (s *Stream) Done() <-chan struct{} { return s.done } +// GoAway returns a channel which is closed when the server sent GoAways signal +// before this stream was initiated. func (s *Stream) GoAway() <-chan struct{} { return s.goAway } @@ -433,16 +437,10 @@ type ClientTransport interface { // once the transport is initiated. Error() <-chan struct{} - // Done returns a channel that is closed when some I/O error - // happens or ClientTranspor receives the draining signal from the server - // (e.g., GOAWAY frame in HTTP/2). Typically the caller should have - // a goroutine to monitor this in order to take action (e.g., close - // the current transport and create a new one) in error case. It should - // not return nil once the transport is initiated. + // GoAway returns a channel that is closed when ClientTranspor + // receives the draining signal from the server (e.g., GOAWAY frame in + // HTTP/2). GoAway() <-chan struct{} - - // Err returns ... - //Err() error } // ServerTransport is the common interface for all gRPC server-side transport @@ -475,7 +473,7 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. RemoteAddr() net.Addr - // GoAway ... + // GoAway notifies the client this ServerTransport stops accepting new RPCs. GoAway() } @@ -504,11 +502,12 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } -// ErrConnClosing indicates that the transport is closing. var ( + // ErrConnClosing indicates that the transport is closing. ErrConnClosing = ConnectionError{Desc: "transport is closing"} - ErrConnDrain = ConnectionError{Desc: "transport is being drained"} - ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf") + // ErrStreamDrain indicates that the stream is rejected by the server because + // the server stops accepting new RPCs. + ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs") ) // StreamError is an error that only affects one stream within a connection. From e40dc9bff9832711c681ac321f185351b5771162 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Fri, 22 Jul 2016 15:04:47 -0700 Subject: [PATCH 5/9] move the stream id checking into operateHeaders --- transport/http2_server.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/transport/http2_server.go b/transport/http2_server.go index e10eedb9..f0d1f7f6 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -142,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI } // operateHeader takes action on the decoded headers. -func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, @@ -200,13 +200,18 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.mu.Unlock() return } - if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } - + if s.id%2 != 1 || s.id <= t.maxStreamID { + t.mu.Unlock() + // illegal gRPC stream id. + grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id) + return true + } + t.maxStreamID = s.id s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() @@ -214,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.updateWindow(s, uint32(n)) } handle(s) + return } // HandleStreams receives incoming streams using the given handler. This is @@ -264,18 +270,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - id := frame.Header().StreamID - t.mu.Lock() - if id%2 != 1 || id <= t.maxStreamID { - t.mu.Unlock() - // illegal gRPC stream id. - grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) + if t.operateHeaders(frame, handle) { t.Close() break } - t.maxStreamID = id - t.mu.Unlock() - t.operateHeaders(frame, handle) case *http2.DataFrame: t.handleData(frame) case *http2.RSTStreamFrame: From f1e4d3b18070fce70e367a4faff1b89bd48d638f Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 25 Jul 2016 16:35:32 -0700 Subject: [PATCH 6/9] allow multiple GoAways and retrying on illegal streams --- call.go | 5 ++++- clientconn.go | 23 +++++++++++++---------- server.go | 6 +++++- test/end2end_test.go | 1 + transport/handler_server.go | 4 ++-- transport/http2_client.go | 34 ++++++++++++++++++++++++++-------- transport/http2_server.go | 2 +- transport/transport.go | 4 ++-- 8 files changed, 54 insertions(+), 25 deletions(-) diff --git a/call.go b/call.go index d6326ea0..27cf6411 100644 --- a/call.go +++ b/call.go @@ -179,7 +179,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + // Retry a non-failfast RPC when + // i) there is a connection error; or + // ii) the server started to drain before this RPC was initiated. + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } diff --git a/clientconn.go b/clientconn.go index 33edfa31..3206d674 100644 --- a/clientconn.go +++ b/clientconn.go @@ -697,14 +697,21 @@ func (ac *addrConn) tearDown(err error) { } ac.cc.mu.Unlock() }() - if ac.state == Shutdown { - return - } - ac.state = Shutdown if ac.down != nil { ac.down(downErrorf(false, false, "%v", err)) ac.down = nil } + if err == errConnDrain && ac.transport != nil { + // GracefulClose(...) may be executed multiple times when + // i) receiving multiple GoAway frames from the server; or + // ii) there are concurrent name resolver/Balancer triggered + // address removal and GoAway. + ac.transport.GracefulClose() + } + if ac.state == Shutdown { + return + } + ac.state = Shutdown ac.stateCV.Broadcast() if ac.events != nil { ac.events.Finish() @@ -714,12 +721,8 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil { - if err == errConnDrain { - ac.transport.GracefulClose() - } else { - ac.transport.Close() - } + if ac.transport != nil && err != errConnDrain { + ac.transport.Close() } if ac.shutdownChan != nil { close(ac.shutdownChan) diff --git a/server.go b/server.go index 2ad331e1..1a250c79 100644 --- a/server.go +++ b/server.go @@ -795,12 +795,16 @@ func (s *Server) Stop() { // connections and RPCs and blocks until all the pending RPCs are finished. func (s *Server) GracefulStop() { s.mu.Lock() + if s.drain == true || s.conns == nil { + s.mu.Lock() + return + } s.drain = true for lis := range s.lis { lis.Close() } for c := range s.conns { - c.(transport.ServerTransport).GoAway() + c.(transport.ServerTransport).Drain() } for len(s.conns) != 0 { s.cv.Wait() diff --git a/test/end2end_test.go b/test/end2end_test.go index a3f022a1..2cc4ccfc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2114,6 +2114,7 @@ func interestingGoroutines() (gs []string) { if stack == "" || strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "testing.tRunner(") || strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") || strings.Contains(stack, "created by google3/base/go/log.init") || diff --git a/transport/handler_server.go b/transport/handler_server.go index 35ccf627..3d7b15e5 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -370,8 +370,8 @@ func (ht *serverHandlerTransport) runStream() { } } -func (ht *serverHandlerTransport) GoAway() { - panic("not implemented") +func (ht *serverHandlerTransport) Drain() { + panic("Drain() is not implemented") } // mapRecvMsgError returns the non-nil err into the appropriate diff --git a/transport/http2_client.go b/transport/http2_client.go index 20eb4bab..ffde50ae 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -102,6 +102,8 @@ type http2Client struct { streamSendQuota uint32 // goAwayID records the Last-Stream-ID in the GoAway frame from the server. goAwayID uint32 + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 @@ -483,21 +485,25 @@ func (t *http2Client) GracefulClose() error { t.mu.Unlock() return nil } - if t.state == draining { - t.mu.Unlock() - return nil - } - t.state = draining // Notify the streams which were initiated after the server sent GOAWAY. select { case <-t.goAway: - for i := t.goAwayID + 2; i < t.nextID; i += 2 { + n := t.prevGoAwayID + if n == 0 && t.nextID > 1 { + n = t.nextID - 2 + } + for i := t.goAwayID + 2; i <= n; i += 2 { if s, ok := t.activeStreams[i]; ok { close(s.goAway) } } default: } + if t.state == draining { + t.mu.Unlock() + return nil + } + t.state = draining active := len(t.activeStreams) t.mu.Unlock() if active == 0 { @@ -742,9 +748,21 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - if t.state == reachable { + if t.state == reachable || t.state == draining { + if t.goAwayID > 0 && t.goAwayID < f.LastStreamID { + id := t.goAwayID + t.mu.Unlock() + t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + return + } + t.prevGoAwayID = t.goAwayID t.goAwayID = f.LastStreamID - close(t.goAway) + select { + case <-t.goAway: + // t.goAway has been closed (i.e.,multiple GoAways). + default: + close(t.goAway) + } } t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index f0d1f7f6..42f5d9c4 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -755,6 +755,6 @@ func (t *http2Server) RemoteAddr() net.Addr { return t.conn.RemoteAddr() } -func (t *http2Server) GoAway() { +func (t *http2Server) Drain() { t.controlBuf.put(&goAway{}) } diff --git a/transport/transport.go b/transport/transport.go index f94731ed..9dade654 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -473,8 +473,8 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. RemoteAddr() net.Addr - // GoAway notifies the client this ServerTransport stops accepting new RPCs. - GoAway() + // Drain notifies the client this ServerTransport stops accepting new RPCs. + Drain() } // StreamErrorf creates an StreamError with the specified error code and description. From e32c9f5d941ae3c1419baceec5c791e623111109 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 25 Jul 2016 17:07:54 -0700 Subject: [PATCH 7/9] disallow illegal goaway stream id --- transport/http2_client.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transport/http2_client.go b/transport/http2_client.go index ffde50ae..c8894ba5 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -492,7 +492,11 @@ func (t *http2Client) GracefulClose() error { if n == 0 && t.nextID > 1 { n = t.nextID - 2 } - for i := t.goAwayID + 2; i <= n; i += 2 { + m := t.goAwayID + 2 + if m == 2 { + m = 1 + } + for i := m; i <= n; i += 2 { if s, ok := t.activeStreams[i]; ok { close(s.goAway) } @@ -749,7 +753,7 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() if t.state == reachable || t.state == draining { - if t.goAwayID > 0 && t.goAwayID < f.LastStreamID { + if f.LastStreamID > 0 && (f.LastStreamID%2 != 1 || t.goAwayID < f.LastStreamID) { id := t.goAwayID t.mu.Unlock() t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) From d78e09a48d7a20d47ff0f53a370da591c652add4 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 25 Jul 2016 17:24:42 -0700 Subject: [PATCH 8/9] bug fix --- transport/http2_client.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/transport/http2_client.go b/transport/http2_client.go index c8894ba5..6a0a86ee 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -752,21 +752,29 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() + id := t.goAwayID if t.state == reachable || t.state == draining { - if f.LastStreamID > 0 && (f.LastStreamID%2 != 1 || t.goAwayID < f.LastStreamID) { - id := t.goAwayID + if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { t.mu.Unlock() - t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: stream ID %d is even", id)) return } - t.prevGoAwayID = t.goAwayID - t.goAwayID = f.LastStreamID select { case <-t.goAway: // t.goAway has been closed (i.e.,multiple GoAways). + if id < f.LastStreamID { + t.mu.Unlock() + t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + return + } + t.prevGoAwayID = id + t.goAwayID = f.LastStreamID + t.mu.Unlock() + return default: - close(t.goAway) } + t.goAwayID = f.LastStreamID + close(t.goAway) } t.mu.Unlock() } From b12fa98959491cf2e892164b2e70685fc3067413 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Mon, 25 Jul 2016 17:41:47 -0700 Subject: [PATCH 9/9] small touchup --- test/end2end_test.go | 4 ++-- transport/http2_client.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index 2cc4ccfc..cdbc4c55 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -645,7 +645,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { } // Finish an RPC to make sure the connection is good. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { - t.Fatalf("fadjflajdflkaflj") + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) } ch := make(chan struct{}) go func() { @@ -679,7 +679,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) } if _, err := stream.Recv(); err != nil { - t.Fatalf("%v.Recv() = %v, want _, ", stream, err) + t.Fatalf("%v.Recv() = _, %v, want _, ", stream, err) } cancel() <-ch diff --git a/transport/http2_client.go b/transport/http2_client.go index 6a0a86ee..08551c4d 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -752,15 +752,15 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - id := t.goAwayID if t.state == reachable || t.state == draining { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { t.mu.Unlock() - t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: stream ID %d is even", id)) + t.notifyError(ConnectionErrorf("received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) return } select { case <-t.goAway: + id := t.goAwayID // t.goAway has been closed (i.e.,multiple GoAways). if id < f.LastStreamID { t.mu.Unlock()