support goaway

This commit is contained in:
iamqizhao
2016-07-20 18:48:49 -07:00
parent 0e86f69ef3
commit 873cc272c2
10 changed files with 136 additions and 25 deletions

View File

@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() {
if t.Err() == transport.ErrConnDrain { if t.Err() == transport.ErrConnDrain {
ac.mu.Unlock() ac.mu.Unlock()
ac.tearDown(errConnDrain) ac.tearDown(errConnDrain)
ac.cc.newAddrConn(ac.addr, true)
return return
} }
ac.state = TransientFailure ac.state = TransientFailure

View File

@ -385,6 +385,12 @@ func toRPCErr(err error) error {
desc: e.Desc, desc: e.Desc,
} }
case transport.ConnectionError: case transport.ConnectionError:
if err == transport.ErrConnDrain {
return &rpcError{
code: codes.Unavailable,
desc: e.Desc,
}
}
return &rpcError{ return &rpcError{
code: codes.Internal, code: codes.Internal,
desc: e.Desc, desc: e.Desc,

View File

@ -92,6 +92,8 @@ type Server struct {
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
conns map[io.Closer]bool conns map[io.Closer]bool
drain bool
cv *sync.Cond
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
} }
@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool), conns: make(map[io.Closer]bool),
m: make(map[string]*service), m: make(map[string]*service),
} }
s.cv = sync.NewCond(&s.mu)
if EnableTracing { if EnableTracing {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool { func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns == nil { if s.conns == nil || s.drain {
return false return false
} }
s.conns[c] = true s.conns[c] = true
@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns != nil { if s.conns != nil {
delete(s.conns, c) delete(s.conns, c)
s.cv.Signal()
} }
} }
@ -766,14 +770,14 @@ func (s *Server) Stop() {
s.mu.Lock() s.mu.Lock()
listeners := s.lis listeners := s.lis
s.lis = nil s.lis = nil
cs := s.conns st := s.conns
s.conns = nil s.conns = nil
s.mu.Unlock() s.mu.Unlock()
for lis := range listeners { for lis := range listeners {
lis.Close() lis.Close()
} }
for c := range cs { for c := range st {
c.Close() c.Close()
} }
@ -785,6 +789,28 @@ func (s *Server) Stop() {
s.mu.Unlock() s.mu.Unlock()
} }
func (s *Server) GracefulStop() {
s.mu.Lock()
s.drain = true
for lis := range s.lis {
lis.Close()
}
for c := range s.conns {
c.(transport.ServerTransport).GoAway()
}
for len(s.conns) != 0 {
s.cv.Wait()
}
s.lis = nil
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
s.mu.Unlock()
}
func init() { func init() {
internal.TestingCloseConns = func(arg interface{}) { internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns() arg.(*Server).testingCloseConns()

View File

@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
} }
cs.closeTransportStream(nil) cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done(): case <-s.Context().Done():
err := s.Context().Err() err := s.Context().Err()
cs.finish(err) cs.finish(err)

View File

@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) {
} }
} }
func TestServerGoAway(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
//if e.name != "tcp-clear" {
// continue
//}
testServerGoAway(t, e)
}
}
func testServerGoAway(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
"grpc: Conn.resetTransport failed to create client transport: connection error",
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
}
ch := make(chan struct{})
go func() {
te.srv.GracefulStop()
close(ch)
}()
for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
continue
}
break
}
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable)
}
<-ch
awaitNewConnLogOutput()
}
func testFailFast(t *testing.T, e env) { func testFailFast(t *testing.T, e env) {
te := newTest(t, e) te := newTest(t, e)
te.userAgent = testAppUA te.userAgent = testAppUA

View File

@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {} func (*resetStream) item() {}
type goAway struct {
}
func (*goAway) item() {}
type flushIO struct { type flushIO struct {
} }

View File

@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) GoAway() {
}
// mapRecvMsgError returns the non-nil err into the appropriate // mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg. // error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be: // In particular, in can only be:

View File

@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
done: make(chan struct{}), done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(), buf: newRecvBuffer(),
@ -219,8 +220,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// Make a stream be able to cancel the pending operations by itself. // Make a stream be able to cancel the pending operations by itself.
s.ctx, s.cancel = context.WithCancel(ctx) s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, goAway: s.goAway,
recv: s.buf,
} }
return s return s
} }
@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)
@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.mu.Lock() t.mu.Lock()
t.goAwayID = f.LastStreamID if t.state == reachable {
t.err = ErrDrain t.goAwayID = f.LastStreamID
close(t.errorChan) t.err = ErrConnDrain
close(t.errorChan)
// Notify the streams which were initiated after the server sent GOAWAY. }
//for i := f.LastStreamID + 2; i < t.nextID; i += 2 {
// if s, ok := t.activeStreams[i]; ok {
// close(s.goAway)
// }
//}
t.mu.Unlock() t.mu.Unlock()
} }

View File

@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.recvCompress = state.encoding s.recvCompress = state.encoding
s.method = state.method s.method = state.method
t.mu.Lock() t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if uint32(len(t.activeStreams)) >= t.maxStreams { if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return return
} }
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
id := frame.Header().StreamID id := frame.Header().StreamID
t.mu.Lock()
if id%2 != 1 || id <= t.maxStreamID { if id%2 != 1 || id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id. // illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
t.Close() t.Close()
break break
} }
t.maxStreamID = id t.maxStreamID = id
t.mu.Unlock()
t.operateHeaders(frame, handle) t.operateHeaders(frame, handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame) t.handleWindowUpdate(frame)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
t.Close()
break break
default: default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
@ -675,6 +686,12 @@ func (t *http2Server) controller() {
} }
case *resetStream: case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code) t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
case *ping: case *ping:
@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr() return t.conn.RemoteAddr()
} }
func (t *http2Server) GoAway() {
t.controlBuf.put(&goAway{})
}

View File

@ -53,10 +53,6 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
) )
var (
ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs")
)
// recvMsg represents the received msg from the transport. All transport // recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed. // protocol specific info has been removed.
type recvMsg struct { type recvMsg struct {
@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case <-r.goAway: case <-r.goAway:
return 0, ErrConnDrain return 0, ErrStreamDrain
case i := <-r.recv.get(): case i := <-r.recv.get():
r.recv.load() r.recv.load()
m := i.(*recvMsg) m := i.(*recvMsg)
@ -478,6 +474,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
RemoteAddr() net.Addr RemoteAddr() net.Addr
// GoAway ...
GoAway()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // StreamErrorf creates an StreamError with the specified error code and description.
@ -509,6 +508,7 @@ func (e ConnectionError) Error() string {
var ( var (
ErrConnClosing = ConnectionError{Desc: "transport is closing"} ErrConnClosing = ConnectionError{Desc: "transport is closing"}
ErrConnDrain = ConnectionError{Desc: "transport is being drained"} ErrConnDrain = ConnectionError{Desc: "transport is being drained"}
ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf")
) )
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.
@ -536,7 +536,7 @@ func ContextErr(err error) StreamError {
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise // If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err. // it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrConnDrain. // If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing. // If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil. // If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
} }
return 0, io.EOF return 0, io.EOF
case <-goAway: case <-goAway:
return 0, ErrConnDrain return 0, ErrStreamDrain
case <-closing: case <-closing:
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed: