support goaway

This commit is contained in:
iamqizhao
2016-07-20 18:48:49 -07:00
parent 0e86f69ef3
commit 873cc272c2
10 changed files with 136 additions and 25 deletions

View File

@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() {
if t.Err() == transport.ErrConnDrain {
ac.mu.Unlock()
ac.tearDown(errConnDrain)
ac.cc.newAddrConn(ac.addr, true)
return
}
ac.state = TransientFailure

View File

@ -385,6 +385,12 @@ func toRPCErr(err error) error {
desc: e.Desc,
}
case transport.ConnectionError:
if err == transport.ErrConnDrain {
return &rpcError{
code: codes.Unavailable,
desc: e.Desc,
}
}
return &rpcError{
code: codes.Internal,
desc: e.Desc,

View File

@ -92,6 +92,8 @@ type Server struct {
mu sync.Mutex // guards following
lis map[net.Listener]bool
conns map[io.Closer]bool
drain bool
cv *sync.Cond
m map[string]*service // service name -> service info
events trace.EventLog
}
@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
s.cv = sync.NewCond(&s.mu)
if EnableTracing {
_, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
if s.conns == nil || s.drain {
return false
}
s.conns[c] = true
@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, c)
s.cv.Signal()
}
}
@ -766,14 +770,14 @@ func (s *Server) Stop() {
s.mu.Lock()
listeners := s.lis
s.lis = nil
cs := s.conns
st := s.conns
s.conns = nil
s.mu.Unlock()
for lis := range listeners {
lis.Close()
}
for c := range cs {
for c := range st {
c.Close()
}
@ -785,6 +789,28 @@ func (s *Server) Stop() {
s.mu.Unlock()
}
// GracefulStop stops the server gracefully. It marks the server as draining
// (so addConn rejects new connections), closes all listeners, asks every
// open transport to go away, and then blocks until every connection has
// been removed from the server.
//
// Calling GracefulStop while a drain is already in progress, or after the
// server has been stopped, returns immediately.
func (s *Server) GracefulStop() {
	s.mu.Lock()
	// Deferred so the mutex is released even if a transport call panics.
	defer s.mu.Unlock()
	// A second drain would re-send GoAway to every transport and then wait
	// on s.cv, which removeConn signals for a single waiter only; a nil
	// conns map means Stop already tore everything down.
	if s.drain || s.conns == nil {
		return
	}
	s.drain = true
	for lis := range s.lis {
		lis.Close()
	}
	for c := range s.conns {
		c.(transport.ServerTransport).GoAway()
	}
	// cv.Wait releases s.mu while blocked, which lets removeConn delete
	// entries and signal us.
	for len(s.conns) != 0 {
		s.cv.Wait()
	}
	s.lis = nil
	s.conns = nil
	if s.events != nil {
		s.events.Finish()
		s.events = nil
	}
}
func init() {
internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns()

View File

@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done():
err := s.Context().Err()
cs.finish(err)

View File

@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) {
}
}
// TestServerGoAway runs the graceful-stop scenario against every test
// environment except "handler-tls".
func TestServerGoAway(t *testing.T) {
	defer leakCheck(t)()
	for _, testEnv := range listTestEnv() {
		// NOTE(review): presumably skipped because the handler-based server
		// transport implements GoAway as a no-op — confirm.
		if testEnv.name != "handler-tls" {
			testServerGoAway(t, testEnv)
		}
	}
}
// testServerGoAway checks client behavior while the server drains.
//
// It first issues one RPC that must succeed, then starts GracefulStop in a
// goroutine and keeps issuing FailFast(false) RPCs with a 10ms deadline
// until one of them fails. After that, a call made without the
// FailFast(false) option must fail with codes.Unavailable. Finally it waits
// for GracefulStop to return.
func testServerGoAway(t *testing.T, e env) {
	te := newTest(t, e)
	te.userAgent = testAppUA
	te.declareLogNoise(
		"transport: http2Client.notifyError got notified that the client transport was broken EOF",
		"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
		"grpc: Conn.resetTransport failed to create client transport: connection error",
		"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
	)
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	cc := te.clientConn()
	tc := testpb.NewTestServiceClient(cc)
	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
	}
	ch := make(chan struct{})
	go func() {
		te.srv.GracefulStop()
		close(ch)
	}()
	// Keep calling until an RPC fails.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
		_, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false))
		// Release the timer right away instead of dropping the CancelFunc;
		// a defer here would pile up until the function returns.
		cancel()
		if err != nil {
			break
		}
	}
	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable {
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable)
	}
	// GracefulStop must return before the test finishes.
	<-ch
	awaitNewConnLogOutput()
}
func testFailFast(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA

View File

@ -72,6 +72,11 @@ type resetStream struct {
// item is a no-op method with an empty body.
// NOTE(review): presumably a marker that lets *resetStream be queued on the
// transport's control buffer — confirm against the item interface.
func (*resetStream) item() {}
// goAway is the field-less control item that http2Server.GoAway queues; the
// server controller reacts to it by writing a GOAWAY frame.
type goAway struct{}

// item is a no-op method with an empty body.
func (*goAway) item() {}
// flushIO is a field-less control item; the server controller handles it by
// flushing the framer's pending writes.
type flushIO struct{}

View File

@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() {
}
}
// GoAway is a no-op for the handler-based server transport: nothing is sent
// to the client and no transport state is changed.
// NOTE(review): Server.GracefulStop waits until every connection has been
// removed, so connections served by this transport are never asked to
// drain — confirm this is intended.
func (ht *serverHandlerTransport) GoAway() {
}
// mapRecvMsgError converts the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg.
// In particular, it can only be:

View File

@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s := &Stream{
id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
@ -220,6 +221,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
}
return s
@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more.
func (t *http2Client) Close() (err error) {
t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing {
t.mu.Unlock()
return
}
if t.state == reachable {
close(t.errorChan)
}
t.state = closing
t.mu.Unlock()
close(t.shutdownChan)
@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
// handleGoAway processes a GOAWAY frame received from the server. While the
// transport is reachable it records the last stream ID reported by the
// server, sets the transport error to ErrConnDrain and closes errorChan.
// A frame received in any other state is ignored.
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.state != reachable {
		return
	}
	t.goAwayID = f.LastStreamID
	t.err = ErrConnDrain
	close(t.errorChan)
	// TODO: streams with IDs above f.LastStreamID are not notified
	// individually through their goAway channel.
	// NOTE(review): t.state is left as reachable, so a second GOAWAY frame
	// or a later Close would close errorChan again — confirm that a state
	// transition is not needed here.
}

View File

@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.recvCompress = state.encoding
s.method = state.method
t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return
}
if t.state != reachable {
t.mu.Unlock()
return
}
if uint32(len(t.activeStreams)) >= t.maxStreams {
t.mu.Unlock()
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return
}
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s
t.mu.Unlock()
@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
id := frame.Header().StreamID
t.mu.Lock()
if id%2 != 1 || id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
t.Close()
break
}
t.maxStreamID = id
t.mu.Unlock()
t.operateHeaders(frame, handle)
case *http2.DataFrame:
t.handleData(frame)
@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame)
case *http2.GoAwayFrame:
t.Close()
break
default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
@ -675,6 +686,12 @@ func (t *http2Server) controller() {
}
case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO:
t.framer.flushWrite()
case *ping:
@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) {
// RemoteAddr reports the remote network address of the underlying
// connection.
func (t *http2Server) RemoteAddr() net.Addr {
	peer := t.conn.RemoteAddr()
	return peer
}
// GoAway queues a goAway item on the control buffer. The controller handles
// that item by switching the transport to the draining state and writing a
// GOAWAY frame that carries the highest stream ID accepted so far.
func (t *http2Server) GoAway() {
	req := new(goAway)
	t.controlBuf.put(req)
}

View File

@ -53,10 +53,6 @@ import (
"google.golang.org/grpc/metadata"
)
var (
// ErrDrain is the error built with ConnectionErrorf whose description says
// that the server stopped accepting new RPCs.
ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs")
)
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err())
case <-r.goAway:
return 0, ErrConnDrain
return 0, ErrStreamDrain
case i := <-r.recv.get():
r.recv.load()
m := i.(*recvMsg)
@ -478,6 +474,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
// GoAway asks the transport to stop accepting new streams so that it can be drained.
GoAway()
}
// StreamErrorf creates an StreamError with the specified error code and description.
@ -509,6 +508,7 @@ func (e ConnectionError) Error() string {
var (
	// ErrConnClosing indicates that the transport is closing.
	ErrConnClosing = ConnectionError{Desc: "transport is closing"}
	// ErrConnDrain indicates that the transport is being drained.
	ErrConnDrain = ConnectionError{Desc: "transport is being drained"}
	// ErrStreamDrain is the stream-level error, carrying codes.Unavailable,
	// that is returned when a stream receives the goAway signal.
	ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
// StreamError is an error that only affects one stream within a connection.
@ -536,7 +536,7 @@ func ContextErr(err error) StreamError {
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it returns the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrConnDrain.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
}
return 0, io.EOF
case <-goAway:
return 0, ErrConnDrain
return 0, ErrStreamDrain
case <-closing:
return 0, ErrConnClosing
case i := <-proceed: