From 873cc272c2e52e8e58d96193376d19760012863b Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Wed, 20 Jul 2016 18:48:49 -0700 Subject: [PATCH] support goaway --- clientconn.go | 1 + rpc_util.go | 6 +++++ server.go | 32 +++++++++++++++++++++--- stream.go | 3 +++ test/end2end_test.go | 49 +++++++++++++++++++++++++++++++++++++ transport/control.go | 5 ++++ transport/handler_server.go | 3 +++ transport/http2_client.go | 27 +++++++++----------- transport/http2_server.go | 21 ++++++++++++++++ transport/transport.go | 14 +++++------ 10 files changed, 136 insertions(+), 25 deletions(-) diff --git a/clientconn.go b/clientconn.go index f6b39e4a..4933b554 100644 --- a/clientconn.go +++ b/clientconn.go @@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() { if t.Err() == transport.ErrConnDrain { ac.mu.Unlock() ac.tearDown(errConnDrain) + ac.cc.newAddrConn(ac.addr, true) return } ac.state = TransientFailure diff --git a/rpc_util.go b/rpc_util.go index d6287175..173018e4 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -385,6 +385,12 @@ func toRPCErr(err error) error { desc: e.Desc, } case transport.ConnectionError: + if err == transport.ErrConnDrain { + return &rpcError{ + code: codes.Unavailable, + desc: e.Desc, + } + } return &rpcError{ code: codes.Internal, desc: e.Desc, diff --git a/server.go b/server.go index a2b2b94d..34e69102 100644 --- a/server.go +++ b/server.go @@ -92,6 +92,8 @@ type Server struct { mu sync.Mutex // guards following lis map[net.Listener]bool conns map[io.Closer]bool + drain bool + cv *sync.Cond m map[string]*service // service name -> service info events trace.EventLog } @@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server { conns: make(map[io.Closer]bool), m: make(map[string]*service), } + s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil { + if s.conns == nil || s.drain { return false } s.conns[c] = true @@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) { defer s.mu.Unlock() if s.conns != nil { delete(s.conns, c) + s.cv.Signal() } } @@ -766,14 +770,14 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - cs := s.conns + st := s.conns s.conns = nil s.mu.Unlock() for lis := range listeners { lis.Close() } - for c := range cs { + for c := range st { c.Close() } @@ -785,6 +789,28 @@ func (s *Server) Stop() { s.mu.Unlock() } +func (s *Server) GracefulStop() { + s.mu.Lock() + s.drain = true + for lis := range s.lis { + lis.Close() + } + for c := range s.conns { + c.(transport.ServerTransport).GoAway() + } + for len(s.conns) != 0 { + s.cv.Wait() + } + s.lis = nil + s.conns = nil + if s.events != nil { + s.events.Finish() + s.events = nil + } + s.mu.Unlock() + +} + func init() { internal.TestingCloseConns = func(arg interface{}) { arg.(*Server).testingCloseConns() diff --git a/stream.go b/stream.go index dfa224b4..deb8663c 100644 --- a/stream.go +++ b/stream.go @@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) } cs.closeTransportStream(nil) + case <-s.GoAway(): + cs.finish(errConnDrain) + cs.closeTransportStream(errConnDrain) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index fdac5815..721e6706 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) { } } +func TestServerGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + //if e.name != "tcp-clear" { + // continue + //} + testServerGoAway(t, e) + } +} + +func testServerGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", + "grpc: Conn.resetTransport failed to create client transport: connection error", + "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + <-ch + awaitNewConnLogOutput() +} + func testFailFast(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA diff --git a/transport/control.go b/transport/control.go index 7e9bdf33..4ef0830b 100644 --- a/transport/control.go +++ b/transport/control.go @@ -72,6 +72,11 @@ type resetStream struct { func (*resetStream) item() {} +type goAway struct { +} + +func (*goAway) item() {} + type flushIO struct { } diff --git a/transport/handler_server.go b/transport/handler_server.go index 4b0d5252..723bf5b0 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() { } } +func (ht *serverHandlerTransport) GoAway() { +} + // mapRecvMsgError returns the non-nil err into the appropriate // error value as expected by callers of *grpc.parser.recvMsg. // In particular, in can only be: diff --git a/transport/http2_client.go b/transport/http2_client.go index bde77f11..2ec703f0 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { s := &Stream{ id: t.nextID, done: make(chan struct{}), + goAway: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -219,8 +220,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // Make a stream be able to cancel the pending operations by itself. s.ctx, s.cancel = context.WithCancel(ctx) s.dec = &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, + ctx: s.ctx, + goAway: s.goAway, + recv: s.buf, } return s } @@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. func (t *http2Client) Close() (err error) { t.mu.Lock() - if t.state == reachable { - close(t.errorChan) - } if t.state == closing { t.mu.Unlock() return } + if t.state == reachable { + close(t.errorChan) + } t.state = closing t.mu.Unlock() close(t.shutdownChan) @@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { t.mu.Lock() - t.goAwayID = f.LastStreamID - t.err = ErrDrain - close(t.errorChan) - - // Notify the streams which were initiated after the server sent GOAWAY. - //for i := f.LastStreamID + 2; i < t.nextID; i += 2 { - // if s, ok := t.activeStreams[i]; ok { - // close(s.goAway) - // } - //} + if t.state == reachable { + t.goAwayID = f.LastStreamID + t.err = ErrConnDrain + close(t.errorChan) + } t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index 2467630a..d7cab4fe 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.recvCompress = state.encoding s.method = state.method t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + return + } if t.state != reachable { t.mu.Unlock() return } + if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } + s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() @@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { switch frame := frame.(type) { case *http2.MetaHeadersFrame: id := frame.Header().StreamID + t.mu.Lock() if id%2 != 1 || id <= t.maxStreamID { + t.mu.Unlock() // illegal gRPC stream id. grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) t.Close() break } t.maxStreamID = id + t.mu.Unlock() t.operateHeaders(frame, handle) case *http2.DataFrame: t.handleData(frame) @@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { case *http2.WindowUpdateFrame: t.handleWindowUpdate(frame) case *http2.GoAwayFrame: + t.Close() break default: grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) @@ -675,6 +686,12 @@ func (t *http2Server) controller() { } case *resetStream: t.framer.writeRSTStream(true, i.streamID, i.code) + case *goAway: + t.mu.Lock() + sid := t.maxStreamID + t.state = draining + t.mu.Unlock() + t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil) case *flushIO: t.framer.flushWrite() case *ping: @@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) RemoteAddr() net.Addr { return t.conn.RemoteAddr() } + +func (t *http2Server) GoAway() { + t.controlBuf.put(&goAway{}) +} diff --git a/transport/transport.go b/transport/transport.go index 4a7b83c5..2372f322 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -53,10 +53,6 @@ import ( "google.golang.org/grpc/metadata" ) -var ( - ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs") -) - // recvMsg represents the received msg from the transport. All transport // protocol specific info has been removed. type recvMsg struct { @@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) case <-r.goAway: - return 0, ErrConnDrain + return 0, ErrStreamDrain case i := <-r.recv.get(): r.recv.load() m := i.(*recvMsg) @@ -478,6 +474,9 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. RemoteAddr() net.Addr + + // GoAway ... + GoAway() } // StreamErrorf creates an StreamError with the specified error code and description. @@ -509,6 +508,7 @@ func (e ConnectionError) Error() string { var ( ErrConnClosing = ConnectionError{Desc: "transport is closing"} ErrConnDrain = ConnectionError{Desc: "transport is being drained"} + ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf") ) // StreamError is an error that only affects one stream within a connection. @@ -536,7 +536,7 @@ func ContextErr(err error) StreamError { // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise // it return the StreamError for ctx.Err. -// If it receives from goAway, it returns 0, ErrConnDrain. +// If it receives from goAway, it returns 0, ErrStreamDrain. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { @@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <- } return 0, io.EOF case <-goAway: - return 0, ErrConnDrain + return 0, ErrStreamDrain case <-closing: return 0, ErrConnClosing case i := <-proceed: