From 19ded2395157c11fc46624a5038f1551f7a2bc44 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Tue, 10 May 2016 19:29:44 -0700 Subject: [PATCH] graceful close and test --- balancer.go | 8 +-- clientconn.go | 123 ++++++++++++------------------------ transport/http2_client.go | 15 ++++- transport/transport.go | 1 + transport/transport_test.go | 58 ++++++++++++++--- 5 files changed, 108 insertions(+), 97 deletions(-) diff --git a/balancer.go b/balancer.go index 998c7373..78e4d331 100644 --- a/balancer.go +++ b/balancer.go @@ -44,7 +44,7 @@ type roundRobin struct { pending int } -func (rr *roundRobin) Up(addr Address) func() { +func (rr *roundRobin) Up(addr Address) func(error) { rr.mu.Lock() defer rr.mu.Unlock() for _, a := range rr.addrs { @@ -59,12 +59,12 @@ func (rr *roundRobin) Up(addr Address) func() { rr.waitCh = nil } } - return func() { - rr.down(addr) + return func(err error) { + rr.down(addr, err) } } -func (rr *roundRobin) down(addr Address) { +func (rr *roundRobin) down(addr Address, err error) { rr.mu.Lock() defer rr.mu.Unlock() for i, a := range rr.addrs { diff --git a/clientconn.go b/clientconn.go index 312e07a9..75e59d07 100644 --- a/clientconn.go +++ b/clientconn.go @@ -206,7 +206,7 @@ func WithUserAgent(s string) DialOption { func Dial(target string, opts ...DialOption) (*ClientConn, error) { cc := &ClientConn{ target: target, - infos: make(map[Address]*addrInfo), + conns: make(map[Address]*addrConn), } for _, opt := range opts { opt(&cc.dopts) @@ -235,9 +235,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { return nil, err } cc.mu.Lock() - cc.infos[addr] = &addrInfo{ - ac: ac, - } + cc.conns[addr] = ac cc.mu.Unlock() } else { w, err := cc.dopts.resolver.Resolve(cc.target) @@ -299,10 +297,6 @@ func (s ConnectivityState) String() string { } } -type addrInfo struct { - ac *addrConn -} - // ClientConn represents a client connection to an RPC service. type ClientConn struct { target string @@ -312,7 +306,7 @@ type ClientConn struct { dopts dialOptions mu sync.RWMutex - infos map[Address]*addrInfo + conns map[Address]*addrConn } func (cc *ClientConn) watchAddrUpdates() error { @@ -328,7 +322,7 @@ func (cc *ClientConn) watchAddrUpdates() error { Addr: update.Addr, Metadata: update.Metadata, } - if _, ok := cc.infos[addr]; ok { + if _, ok := cc.conns[addr]; ok { cc.mu.Unlock() grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) continue @@ -340,9 +334,7 @@ func (cc *ClientConn) watchAddrUpdates() error { return err } cc.mu.Lock() - cc.infos[addr] = &addrInfo{ - ac: ac, - } + cc.conns[addr] = ac cc.mu.Unlock() case naming.Delete: cc.mu.Lock() @@ -350,15 +342,16 @@ func (cc *ClientConn) watchAddrUpdates() error { Addr: update.Addr, Metadata: update.Metadata, } - i, ok := cc.infos[addr] + ac, ok := cc.conns[addr] if !ok { cc.mu.Unlock() grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr) continue } - delete(cc.infos, addr) + delete(cc.conns, addr) cc.mu.Unlock() - i.ac.startDrain() + ac.tearDown(ErrConnDrain) + //ac.startDrain() default: grpclog.Println("Unknown update.Op ", update.Op) } @@ -367,16 +360,10 @@ func (cc *ClientConn) watchAddrUpdates() error { } func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) { - /* - if cc.target == "" { - return nil, ErrUnspecTarget - } - */ c := &addrConn{ - cc: cc, - addr: addr, - dopts: cc.dopts, - //resetChan: make(chan int, 1), + cc: cc, + addr: addr, + dopts: cc.dopts, shutdownChan: make(chan struct{}), } if EnableTracing { @@ -415,7 +402,6 @@ func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) { c.tearDown(err) return } - grpclog.Println("DEBUG ugh here resetTransport") c.transportMonitor() }() } @@ -428,17 +414,17 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo return nil, nil, err } cc.mu.RLock() - if cc.infos == nil { + if cc.conns == nil { cc.mu.RUnlock() return nil, nil, ErrClientConnClosing } - info, ok := cc.infos[addr] + ac, ok := cc.conns[addr] cc.mu.RUnlock() if !ok { put() return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") } - t, err := info.ac.wait(ctx) + t, err := ac.wait(ctx) if err != nil { put() return nil, nil, err @@ -446,47 +432,31 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo return t, put, nil } -/* -// State returns the connectivity state of cc. -// This is EXPERIMENTAL API. -func (cc *ClientConn) State() (ConnectivityState, error) { - return cc.dopts.picker.State() -} - -// WaitForStateChange blocks until the state changes to something other than the sourceState. -// It returns the new state or error. -// This is EXPERIMENTAL API. -func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { - return cc.dopts.picker.WaitForStateChange(ctx, sourceState) -} -*/ - // Close starts to tear down the ClientConn. func (cc *ClientConn) Close() error { cc.mu.Lock() - if cc.infos == nil { + if cc.conns == nil { cc.mu.Unlock() return ErrClientConnClosing } - infos := cc.infos - cc.infos = nil + conns := cc.conns + cc.conns = nil cc.mu.Unlock() cc.balancer.Close() if cc.watcher != nil { cc.watcher.Close() } - for _, i := range infos { - i.ac.tearDown(ErrClientClosing) + for _, ac := range conns { + ac.tearDown(ErrClientConnClosing) } return nil } // addrConn is a network connection to a given address. type addrConn struct { - cc *ClientConn - addr Address - dopts dialOptions - //resetChan chan int + cc *ClientConn + addr Address + dopts dialOptions shutdownChan chan struct{} events trace.EventLog @@ -494,13 +464,13 @@ type addrConn struct { state ConnectivityState stateCV *sync.Cond down func(error) // the handler called when a connection is down. - drain bool // ready is closed and becomes nil when a new transport is up or failed // due to timeout. ready chan struct{} transport transport.ClientTransport } +/* func (ac *addrConn) startDrain() { ac.mu.Lock() t := ac.transport @@ -510,8 +480,9 @@ func (ac *addrConn) startDrain() { ac.down = nil } ac.mu.Unlock() - t.GracefulClose() + ac.tearDown(ErrConnDrain) } +*/ // printf records an event in ac's event log, unless ac has been closed. // REQUIRES ac.mu is held. @@ -576,10 +547,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { ac.mu.Unlock() return errConnClosing } - if ac.drain { - ac.mu.Unlock() - return nil - } + /* + if ac.drain { + ac.mu.Unlock() + return nil + } + */ if ac.down != nil { ac.down(ErrNetworkIO) ac.down = nil @@ -613,7 +586,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { copts.Timeout = timeout } connectTime := time.Now() - grpclog.Println("DEBUG reach inside resetTransport 1") newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts) if err != nil { ac.mu.Lock() @@ -639,7 +611,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { ac.mu.Lock() ac.errorf("connection timeout") ac.mu.Unlock() - ac.tearDown(ErrClientTimeout) + ac.tearDown(ErrClientConnTimeout) return ErrClientConnTimeout } closeTransport = false @@ -649,7 +621,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { continue } ac.mu.Lock() - grpclog.Println("DEBUG reach inside resetTransport 2") ac.printf("ready") if ac.state == Shutdown { // ac.tearDown(...) has been invoked. @@ -657,7 +628,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { newTransport.Close() return errConnClosing } - grpclog.Println("DEBUG reach inside resetTransport 3: ", ac.addr) ac.state = Ready ac.stateCV.Broadcast() ac.transport = newTransport @@ -683,12 +653,6 @@ func (ac *addrConn) transportMonitor() { // the addrConn is idle (i.e., no RPC in flight). case <-ac.shutdownChan: return - /* - case <-ac.resetChan: - if !ac.reconnect() { - return - } - */ case <-t.Error(): ac.mu.Lock() if ac.state == Shutdown { @@ -706,18 +670,6 @@ func (ac *addrConn) transportMonitor() { grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) return } - /* - if !ac.reconnect() { - return - } - */ - /* - // Tries to drain reset signal if there is any since it is out-dated. - select { - case <-ac.resetChan: - default: - } - */ } } } @@ -751,8 +703,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) } } -// tearDown starts to tear down the Conn. Returns errConnClosing if -// it has been closed (mostly due to dial time-out). +// tearDown starts to tear down the Conn. // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // some edge cases (e.g., the caller opens and closes many addrConn's in a // tight loop. @@ -777,7 +728,11 @@ func (ac *addrConn) tearDown(err error) { ac.ready = nil } if ac.transport != nil { - ac.transport.Close() + if err == ErrConnDrain { + ac.transport.GracefulClose() + } else { + ac.transport.Close() + } } if ac.shutdownChan != nil { close(ac.shutdownChan) diff --git a/transport/http2_client.go b/transport/http2_client.go index 0fc6a668..4027614b 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -403,6 +403,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) { updateStreams = true } delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { + t.mu.Unlock() + t.Close() + return + } t.mu.Unlock() if updateStreams { t.streamsQuota.add(1) @@ -468,8 +473,16 @@ func (t *http2Client) Close() (err error) { func (t *http2Client) GracefulClose() error { t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Graceful close on a closed transport") + } + if t.state == draining { + t.mu.Unlock() + return nil + } + t.state = draining active := len(t.activeStreams) - t.activeStreams = nil t.mu.Unlock() if active == 0 { return t.Close() diff --git a/transport/transport.go b/transport/transport.go index b85c0ac4..230e215d 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -321,6 +321,7 @@ const ( reachable transportState = iota unreachable closing + draining ) // NewServerTransport creates a ServerTransport with conn or non-nil error diff --git a/transport/transport_test.go b/transport/transport_test.go index d63dba31..6ebec452 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) { defer wg.Done() s, err := ct.NewStream(context.Background(), callHdr) if err != nil { - t.Errorf("failed to open stream: %v", err) + t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { - t.Errorf("failed to send data: %v", err) + t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponseLarge)) - _, recvErr := io.ReadFull(s, p) - if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) { - t.Errorf("Error: %v, want ; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge)) + if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) { + t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, ", err, p, expectedResponse) } - _, recvErr = io.ReadFull(s, p) - if recvErr != io.EOF { - t.Errorf("Error: %v; want ", recvErr) + if _, err = io.ReadFull(s, p); err != io.EOF { + t.Errorf("Failed to complete the stream %v; want ", err) } }() } @@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) { server.stop() } +func TestGracefulClose(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, normal) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Small", + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) + } + if err = ct.GracefulClose(); err != nil { + t.Fatalf("%v.GracefulClose() = %v, want ", ct, err) + } + var wg sync.WaitGroup + // Expect the failure for all the follow-up streams because ct has been closed gracefully. + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) + } + }() + } + opts := Options{ + Last: true, + Delay: false, + } + // The stream which was created before graceful close can still proceed. + if err := ct.Write(s, expectedRequest, &opts); err != nil { + t.Fatalf("%v.Write(_, _, _) = %v, want ", ct, err) + } + p := make([]byte, len(expectedResponse)) + if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) { + t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, ", err, p, expectedResponse) + } + if _, err = io.ReadFull(s, p); err != io.EOF { + t.Fatalf("Failed to complete the stream %v; want ", err) + } + wg.Wait() + ct.Close() + server.stop() +} + func TestLargeMessageSuspension(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, suspended) callHdr := &CallHdr{