support goaway
This commit is contained in:
@ -635,6 +635,7 @@ func (ac *addrConn) transportMonitor() {
|
|||||||
if t.Err() == transport.ErrConnDrain {
|
if t.Err() == transport.ErrConnDrain {
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
ac.tearDown(errConnDrain)
|
ac.tearDown(errConnDrain)
|
||||||
|
ac.cc.newAddrConn(ac.addr, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ac.state = TransientFailure
|
ac.state = TransientFailure
|
||||||
|
@ -385,6 +385,12 @@ func toRPCErr(err error) error {
|
|||||||
desc: e.Desc,
|
desc: e.Desc,
|
||||||
}
|
}
|
||||||
case transport.ConnectionError:
|
case transport.ConnectionError:
|
||||||
|
if err == transport.ErrConnDrain {
|
||||||
|
return &rpcError{
|
||||||
|
code: codes.Unavailable,
|
||||||
|
desc: e.Desc,
|
||||||
|
}
|
||||||
|
}
|
||||||
return &rpcError{
|
return &rpcError{
|
||||||
code: codes.Internal,
|
code: codes.Internal,
|
||||||
desc: e.Desc,
|
desc: e.Desc,
|
||||||
|
32
server.go
32
server.go
@ -92,6 +92,8 @@ type Server struct {
|
|||||||
mu sync.Mutex // guards following
|
mu sync.Mutex // guards following
|
||||||
lis map[net.Listener]bool
|
lis map[net.Listener]bool
|
||||||
conns map[io.Closer]bool
|
conns map[io.Closer]bool
|
||||||
|
drain bool
|
||||||
|
cv *sync.Cond
|
||||||
m map[string]*service // service name -> service info
|
m map[string]*service // service name -> service info
|
||||||
events trace.EventLog
|
events trace.EventLog
|
||||||
}
|
}
|
||||||
@ -186,6 +188,7 @@ func NewServer(opt ...ServerOption) *Server {
|
|||||||
conns: make(map[io.Closer]bool),
|
conns: make(map[io.Closer]bool),
|
||||||
m: make(map[string]*service),
|
m: make(map[string]*service),
|
||||||
}
|
}
|
||||||
|
s.cv = sync.NewCond(&s.mu)
|
||||||
if EnableTracing {
|
if EnableTracing {
|
||||||
_, file, line, _ := runtime.Caller(1)
|
_, file, line, _ := runtime.Caller(1)
|
||||||
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
|
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
|
||||||
@ -468,7 +471,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
|
|||||||
func (s *Server) addConn(c io.Closer) bool {
|
func (s *Server) addConn(c io.Closer) bool {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.conns == nil {
|
if s.conns == nil || s.drain {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
s.conns[c] = true
|
s.conns[c] = true
|
||||||
@ -480,6 +483,7 @@ func (s *Server) removeConn(c io.Closer) {
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.conns != nil {
|
if s.conns != nil {
|
||||||
delete(s.conns, c)
|
delete(s.conns, c)
|
||||||
|
s.cv.Signal()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -766,14 +770,14 @@ func (s *Server) Stop() {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
listeners := s.lis
|
listeners := s.lis
|
||||||
s.lis = nil
|
s.lis = nil
|
||||||
cs := s.conns
|
st := s.conns
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
for lis := range listeners {
|
for lis := range listeners {
|
||||||
lis.Close()
|
lis.Close()
|
||||||
}
|
}
|
||||||
for c := range cs {
|
for c := range st {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -785,6 +789,28 @@ func (s *Server) Stop() {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) GracefulStop() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.drain = true
|
||||||
|
for lis := range s.lis {
|
||||||
|
lis.Close()
|
||||||
|
}
|
||||||
|
for c := range s.conns {
|
||||||
|
c.(transport.ServerTransport).GoAway()
|
||||||
|
}
|
||||||
|
for len(s.conns) != 0 {
|
||||||
|
s.cv.Wait()
|
||||||
|
}
|
||||||
|
s.lis = nil
|
||||||
|
s.conns = nil
|
||||||
|
if s.events != nil {
|
||||||
|
s.events.Finish()
|
||||||
|
s.events = nil
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
internal.TestingCloseConns = func(arg interface{}) {
|
internal.TestingCloseConns = func(arg interface{}) {
|
||||||
arg.(*Server).testingCloseConns()
|
arg.(*Server).testingCloseConns()
|
||||||
|
@ -195,6 +195,9 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
|
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
|
||||||
}
|
}
|
||||||
cs.closeTransportStream(nil)
|
cs.closeTransportStream(nil)
|
||||||
|
case <-s.GoAway():
|
||||||
|
cs.finish(errConnDrain)
|
||||||
|
cs.closeTransportStream(errConnDrain)
|
||||||
case <-s.Context().Done():
|
case <-s.Context().Done():
|
||||||
err := s.Context().Err()
|
err := s.Context().Err()
|
||||||
cs.finish(err)
|
cs.finish(err)
|
||||||
|
@ -572,6 +572,55 @@ func TestFailFast(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServerGoAway(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
if e.name == "handler-tls" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
//if e.name != "tcp-clear" {
|
||||||
|
// continue
|
||||||
|
//}
|
||||||
|
testServerGoAway(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testServerGoAway(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.userAgent = testAppUA
|
||||||
|
te.declareLogNoise(
|
||||||
|
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
|
||||||
|
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error",
|
||||||
|
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
|
||||||
|
)
|
||||||
|
te.startServer(&testServer{security: e.security})
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
cc := te.clientConn()
|
||||||
|
tc := testpb.NewTestServiceClient(cc)
|
||||||
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
|
||||||
|
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
|
||||||
|
}
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
te.srv.GracefulStop()
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err == nil || grpc.Code(err) != codes.Unavailable {
|
||||||
|
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, error code: %d", err, codes.Unavailable)
|
||||||
|
}
|
||||||
|
<-ch
|
||||||
|
awaitNewConnLogOutput()
|
||||||
|
}
|
||||||
|
|
||||||
func testFailFast(t *testing.T, e env) {
|
func testFailFast(t *testing.T, e env) {
|
||||||
te := newTest(t, e)
|
te := newTest(t, e)
|
||||||
te.userAgent = testAppUA
|
te.userAgent = testAppUA
|
||||||
|
@ -72,6 +72,11 @@ type resetStream struct {
|
|||||||
|
|
||||||
func (*resetStream) item() {}
|
func (*resetStream) item() {}
|
||||||
|
|
||||||
|
type goAway struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*goAway) item() {}
|
||||||
|
|
||||||
type flushIO struct {
|
type flushIO struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -370,6 +370,9 @@ func (ht *serverHandlerTransport) runStream() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ht *serverHandlerTransport) GoAway() {
|
||||||
|
}
|
||||||
|
|
||||||
// mapRecvMsgError returns the non-nil err into the appropriate
|
// mapRecvMsgError returns the non-nil err into the appropriate
|
||||||
// error value as expected by callers of *grpc.parser.recvMsg.
|
// error value as expected by callers of *grpc.parser.recvMsg.
|
||||||
// In particular, in can only be:
|
// In particular, in can only be:
|
||||||
|
@ -205,6 +205,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|||||||
s := &Stream{
|
s := &Stream{
|
||||||
id: t.nextID,
|
id: t.nextID,
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
|
goAway: make(chan struct{}),
|
||||||
method: callHdr.Method,
|
method: callHdr.Method,
|
||||||
sendCompress: callHdr.SendCompress,
|
sendCompress: callHdr.SendCompress,
|
||||||
buf: newRecvBuffer(),
|
buf: newRecvBuffer(),
|
||||||
@ -219,8 +220,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|||||||
// Make a stream be able to cancel the pending operations by itself.
|
// Make a stream be able to cancel the pending operations by itself.
|
||||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
s.dec = &recvBufferReader{
|
s.dec = &recvBufferReader{
|
||||||
ctx: s.ctx,
|
ctx: s.ctx,
|
||||||
recv: s.buf,
|
goAway: s.goAway,
|
||||||
|
recv: s.buf,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@ -443,13 +445,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|||||||
// accessed any more.
|
// accessed any more.
|
||||||
func (t *http2Client) Close() (err error) {
|
func (t *http2Client) Close() (err error) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if t.state == reachable {
|
|
||||||
close(t.errorChan)
|
|
||||||
}
|
|
||||||
if t.state == closing {
|
if t.state == closing {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if t.state == reachable {
|
||||||
|
close(t.errorChan)
|
||||||
|
}
|
||||||
t.state = closing
|
t.state = closing
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
close(t.shutdownChan)
|
close(t.shutdownChan)
|
||||||
@ -732,16 +734,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
|
|||||||
|
|
||||||
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
t.goAwayID = f.LastStreamID
|
if t.state == reachable {
|
||||||
t.err = ErrDrain
|
t.goAwayID = f.LastStreamID
|
||||||
close(t.errorChan)
|
t.err = ErrConnDrain
|
||||||
|
close(t.errorChan)
|
||||||
// Notify the streams which were initiated after the server sent GOAWAY.
|
}
|
||||||
//for i := f.LastStreamID + 2; i < t.nextID; i += 2 {
|
|
||||||
// if s, ok := t.activeStreams[i]; ok {
|
|
||||||
// close(s.goAway)
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,15 +196,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
s.recvCompress = state.encoding
|
s.recvCompress = state.encoding
|
||||||
s.method = state.method
|
s.method = state.method
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
|
if t.state == draining {
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
||||||
|
return
|
||||||
|
}
|
||||||
if t.state != reachable {
|
if t.state != reachable {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if uint32(len(t.activeStreams)) >= t.maxStreams {
|
if uint32(len(t.activeStreams)) >= t.maxStreams {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
||||||
t.activeStreams[s.id] = s
|
t.activeStreams[s.id] = s
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
@ -263,13 +270,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
switch frame := frame.(type) {
|
switch frame := frame.(type) {
|
||||||
case *http2.MetaHeadersFrame:
|
case *http2.MetaHeadersFrame:
|
||||||
id := frame.Header().StreamID
|
id := frame.Header().StreamID
|
||||||
|
t.mu.Lock()
|
||||||
if id%2 != 1 || id <= t.maxStreamID {
|
if id%2 != 1 || id <= t.maxStreamID {
|
||||||
|
t.mu.Unlock()
|
||||||
// illegal gRPC stream id.
|
// illegal gRPC stream id.
|
||||||
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
|
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
|
||||||
t.Close()
|
t.Close()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.maxStreamID = id
|
t.maxStreamID = id
|
||||||
|
t.mu.Unlock()
|
||||||
t.operateHeaders(frame, handle)
|
t.operateHeaders(frame, handle)
|
||||||
case *http2.DataFrame:
|
case *http2.DataFrame:
|
||||||
t.handleData(frame)
|
t.handleData(frame)
|
||||||
@ -282,6 +292,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
case *http2.WindowUpdateFrame:
|
case *http2.WindowUpdateFrame:
|
||||||
t.handleWindowUpdate(frame)
|
t.handleWindowUpdate(frame)
|
||||||
case *http2.GoAwayFrame:
|
case *http2.GoAwayFrame:
|
||||||
|
t.Close()
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
|
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
|
||||||
@ -675,6 +686,12 @@ func (t *http2Server) controller() {
|
|||||||
}
|
}
|
||||||
case *resetStream:
|
case *resetStream:
|
||||||
t.framer.writeRSTStream(true, i.streamID, i.code)
|
t.framer.writeRSTStream(true, i.streamID, i.code)
|
||||||
|
case *goAway:
|
||||||
|
t.mu.Lock()
|
||||||
|
sid := t.maxStreamID
|
||||||
|
t.state = draining
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
|
||||||
case *flushIO:
|
case *flushIO:
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
case *ping:
|
case *ping:
|
||||||
@ -742,3 +759,7 @@ func (t *http2Server) closeStream(s *Stream) {
|
|||||||
func (t *http2Server) RemoteAddr() net.Addr {
|
func (t *http2Server) RemoteAddr() net.Addr {
|
||||||
return t.conn.RemoteAddr()
|
return t.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *http2Server) GoAway() {
|
||||||
|
t.controlBuf.put(&goAway{})
|
||||||
|
}
|
||||||
|
@ -53,10 +53,6 @@ import (
|
|||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrDrain = ConnectionErrorf("transport: Server stopped accepting new RPCs")
|
|
||||||
)
|
|
||||||
|
|
||||||
// recvMsg represents the received msg from the transport. All transport
|
// recvMsg represents the received msg from the transport. All transport
|
||||||
// protocol specific info has been removed.
|
// protocol specific info has been removed.
|
||||||
type recvMsg struct {
|
type recvMsg struct {
|
||||||
@ -147,7 +143,7 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
|||||||
case <-r.ctx.Done():
|
case <-r.ctx.Done():
|
||||||
return 0, ContextErr(r.ctx.Err())
|
return 0, ContextErr(r.ctx.Err())
|
||||||
case <-r.goAway:
|
case <-r.goAway:
|
||||||
return 0, ErrConnDrain
|
return 0, ErrStreamDrain
|
||||||
case i := <-r.recv.get():
|
case i := <-r.recv.get():
|
||||||
r.recv.load()
|
r.recv.load()
|
||||||
m := i.(*recvMsg)
|
m := i.(*recvMsg)
|
||||||
@ -478,6 +474,9 @@ type ServerTransport interface {
|
|||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
// RemoteAddr returns the remote network address.
|
||||||
RemoteAddr() net.Addr
|
RemoteAddr() net.Addr
|
||||||
|
|
||||||
|
// GoAway ...
|
||||||
|
GoAway()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamErrorf creates an StreamError with the specified error code and description.
|
// StreamErrorf creates an StreamError with the specified error code and description.
|
||||||
@ -509,6 +508,7 @@ func (e ConnectionError) Error() string {
|
|||||||
var (
|
var (
|
||||||
ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
||||||
ErrConnDrain = ConnectionError{Desc: "transport is being drained"}
|
ErrConnDrain = ConnectionError{Desc: "transport is being drained"}
|
||||||
|
ErrStreamDrain = StreamErrorf(codes.Unavailable, "afjlalf")
|
||||||
)
|
)
|
||||||
|
|
||||||
// StreamError is an error that only affects one stream within a connection.
|
// StreamError is an error that only affects one stream within a connection.
|
||||||
@ -536,7 +536,7 @@ func ContextErr(err error) StreamError {
|
|||||||
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
|
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
|
||||||
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
|
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
|
||||||
// it return the StreamError for ctx.Err.
|
// it return the StreamError for ctx.Err.
|
||||||
// If it receives from goAway, it returns 0, ErrConnDrain.
|
// If it receives from goAway, it returns 0, ErrStreamDrain.
|
||||||
// If it receives from closing, it returns 0, ErrConnClosing.
|
// If it receives from closing, it returns 0, ErrConnClosing.
|
||||||
// If it receives from proceed, it returns the received integer, nil.
|
// If it receives from proceed, it returns the received integer, nil.
|
||||||
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
|
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
|
||||||
@ -552,7 +552,7 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
|
|||||||
}
|
}
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
case <-goAway:
|
case <-goAway:
|
||||||
return 0, ErrConnDrain
|
return 0, ErrStreamDrain
|
||||||
case <-closing:
|
case <-closing:
|
||||||
return 0, ErrConnClosing
|
return 0, ErrConnClosing
|
||||||
case i := <-proceed:
|
case i := <-proceed:
|
||||||
|
Reference in New Issue
Block a user