graceful close and test

This commit is contained in:
iamqizhao
2016-05-10 19:29:44 -07:00
parent 64ed38ebed
commit 19ded23951
5 changed files with 108 additions and 97 deletions

View File

@ -44,7 +44,7 @@ type roundRobin struct {
pending int pending int
} }
func (rr *roundRobin) Up(addr Address) func() { func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock() rr.mu.Lock()
defer rr.mu.Unlock() defer rr.mu.Unlock()
for _, a := range rr.addrs { for _, a := range rr.addrs {
@ -59,12 +59,12 @@ func (rr *roundRobin) Up(addr Address) func() {
rr.waitCh = nil rr.waitCh = nil
} }
} }
return func() { return func(err error) {
rr.down(addr) rr.down(addr, err)
} }
} }
func (rr *roundRobin) down(addr Address) { func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock() rr.mu.Lock()
defer rr.mu.Unlock() defer rr.mu.Unlock()
for i, a := range rr.addrs { for i, a := range rr.addrs {

View File

@ -206,7 +206,7 @@ func WithUserAgent(s string) DialOption {
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
infos: make(map[Address]*addrInfo), conns: make(map[Address]*addrConn),
} }
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
@ -235,9 +235,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return nil, err return nil, err
} }
cc.mu.Lock() cc.mu.Lock()
cc.infos[addr] = &addrInfo{ cc.conns[addr] = ac
ac: ac,
}
cc.mu.Unlock() cc.mu.Unlock()
} else { } else {
w, err := cc.dopts.resolver.Resolve(cc.target) w, err := cc.dopts.resolver.Resolve(cc.target)
@ -299,10 +297,6 @@ func (s ConnectivityState) String() string {
} }
} }
type addrInfo struct {
ac *addrConn
}
// ClientConn represents a client connection to an RPC service. // ClientConn represents a client connection to an RPC service.
type ClientConn struct { type ClientConn struct {
target string target string
@ -312,7 +306,7 @@ type ClientConn struct {
dopts dialOptions dopts dialOptions
mu sync.RWMutex mu sync.RWMutex
infos map[Address]*addrInfo conns map[Address]*addrConn
} }
func (cc *ClientConn) watchAddrUpdates() error { func (cc *ClientConn) watchAddrUpdates() error {
@ -328,7 +322,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
Addr: update.Addr, Addr: update.Addr,
Metadata: update.Metadata, Metadata: update.Metadata,
} }
if _, ok := cc.infos[addr]; ok { if _, ok := cc.conns[addr]; ok {
cc.mu.Unlock() cc.mu.Unlock()
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr) grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
continue continue
@ -340,9 +334,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
return err return err
} }
cc.mu.Lock() cc.mu.Lock()
cc.infos[addr] = &addrInfo{ cc.conns[addr] = ac
ac: ac,
}
cc.mu.Unlock() cc.mu.Unlock()
case naming.Delete: case naming.Delete:
cc.mu.Lock() cc.mu.Lock()
@ -350,15 +342,16 @@ func (cc *ClientConn) watchAddrUpdates() error {
Addr: update.Addr, Addr: update.Addr,
Metadata: update.Metadata, Metadata: update.Metadata,
} }
i, ok := cc.infos[addr] ac, ok := cc.conns[addr]
if !ok { if !ok {
cc.mu.Unlock() cc.mu.Unlock()
grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr) grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
continue continue
} }
delete(cc.infos, addr) delete(cc.conns, addr)
cc.mu.Unlock() cc.mu.Unlock()
i.ac.startDrain() ac.tearDown(ErrConnDrain)
//ac.startDrain()
default: default:
grpclog.Println("Unknown update.Op ", update.Op) grpclog.Println("Unknown update.Op ", update.Op)
} }
@ -367,16 +360,10 @@ func (cc *ClientConn) watchAddrUpdates() error {
} }
func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) { func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
/*
if cc.target == "" {
return nil, ErrUnspecTarget
}
*/
c := &addrConn{ c := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
//resetChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
} }
if EnableTracing { if EnableTracing {
@ -415,7 +402,6 @@ func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
c.tearDown(err) c.tearDown(err)
return return
} }
grpclog.Println("DEBUG ugh here resetTransport")
c.transportMonitor() c.transportMonitor()
}() }()
} }
@ -428,17 +414,17 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
return nil, nil, err return nil, nil, err
} }
cc.mu.RLock() cc.mu.RLock()
if cc.infos == nil { if cc.conns == nil {
cc.mu.RUnlock() cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing return nil, nil, ErrClientConnClosing
} }
info, ok := cc.infos[addr] ac, ok := cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
if !ok { if !ok {
put() put()
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
} }
t, err := info.ac.wait(ctx) t, err := ac.wait(ctx)
if err != nil { if err != nil {
put() put()
return nil, nil, err return nil, nil, err
@ -446,37 +432,22 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
return t, put, nil return t, put, nil
} }
/*
// State returns the connectivity state of cc.
// This is EXPERIMENTAL API.
func (cc *ClientConn) State() (ConnectivityState, error) {
return cc.dopts.picker.State()
}
// WaitForStateChange blocks until the state changes to something other than the sourceState.
// It returns the new state or error.
// This is EXPERIMENTAL API.
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
}
*/
// Close starts to tear down the ClientConn. // Close starts to tear down the ClientConn.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
cc.mu.Lock() cc.mu.Lock()
if cc.infos == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
return ErrClientConnClosing return ErrClientConnClosing
} }
infos := cc.infos conns := cc.conns
cc.infos = nil cc.conns = nil
cc.mu.Unlock() cc.mu.Unlock()
cc.balancer.Close() cc.balancer.Close()
if cc.watcher != nil { if cc.watcher != nil {
cc.watcher.Close() cc.watcher.Close()
} }
for _, i := range infos { for _, ac := range conns {
i.ac.tearDown(ErrClientClosing) ac.tearDown(ErrClientConnClosing)
} }
return nil return nil
} }
@ -486,7 +457,6 @@ type addrConn struct {
cc *ClientConn cc *ClientConn
addr Address addr Address
dopts dialOptions dopts dialOptions
//resetChan chan int
shutdownChan chan struct{} shutdownChan chan struct{}
events trace.EventLog events trace.EventLog
@ -494,13 +464,13 @@ type addrConn struct {
state ConnectivityState state ConnectivityState
stateCV *sync.Cond stateCV *sync.Cond
down func(error) // the handler called when a connection is down. down func(error) // the handler called when a connection is down.
drain bool
// ready is closed and becomes nil when a new transport is up or failed // ready is closed and becomes nil when a new transport is up or failed
// due to timeout. // due to timeout.
ready chan struct{} ready chan struct{}
transport transport.ClientTransport transport transport.ClientTransport
} }
/*
func (ac *addrConn) startDrain() { func (ac *addrConn) startDrain() {
ac.mu.Lock() ac.mu.Lock()
t := ac.transport t := ac.transport
@ -510,8 +480,9 @@ func (ac *addrConn) startDrain() {
ac.down = nil ac.down = nil
} }
ac.mu.Unlock() ac.mu.Unlock()
t.GracefulClose() ac.tearDown(ErrConnDrain)
} }
*/
// printf records an event in ac's event log, unless ac has been closed. // printf records an event in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held. // REQUIRES ac.mu is held.
@ -576,10 +547,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.mu.Unlock() ac.mu.Unlock()
return errConnClosing return errConnClosing
} }
/*
if ac.drain { if ac.drain {
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
*/
if ac.down != nil { if ac.down != nil {
ac.down(ErrNetworkIO) ac.down(ErrNetworkIO)
ac.down = nil ac.down = nil
@ -613,7 +586,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
copts.Timeout = timeout copts.Timeout = timeout
} }
connectTime := time.Now() connectTime := time.Now()
grpclog.Println("DEBUG reach inside resetTransport 1")
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts) newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts)
if err != nil { if err != nil {
ac.mu.Lock() ac.mu.Lock()
@ -639,7 +611,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.mu.Lock() ac.mu.Lock()
ac.errorf("connection timeout") ac.errorf("connection timeout")
ac.mu.Unlock() ac.mu.Unlock()
ac.tearDown(ErrClientTimeout) ac.tearDown(ErrClientConnTimeout)
return ErrClientConnTimeout return ErrClientConnTimeout
} }
closeTransport = false closeTransport = false
@ -649,7 +621,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
continue continue
} }
ac.mu.Lock() ac.mu.Lock()
grpclog.Println("DEBUG reach inside resetTransport 2")
ac.printf("ready") ac.printf("ready")
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
@ -657,7 +628,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
newTransport.Close() newTransport.Close()
return errConnClosing return errConnClosing
} }
grpclog.Println("DEBUG reach inside resetTransport 3: ", ac.addr)
ac.state = Ready ac.state = Ready
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
ac.transport = newTransport ac.transport = newTransport
@ -683,12 +653,6 @@ func (ac *addrConn) transportMonitor() {
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.shutdownChan:
return return
/*
case <-ac.resetChan:
if !ac.reconnect() {
return
}
*/
case <-t.Error(): case <-t.Error():
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
@ -706,18 +670,6 @@ func (ac *addrConn) transportMonitor() {
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
return return
} }
/*
if !ac.reconnect() {
return
}
*/
/*
// Tries to drain reset signal if there is any since it is out-dated.
select {
case <-ac.resetChan:
default:
}
*/
} }
} }
} }
@ -751,8 +703,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
} }
} }
// tearDown starts to tear down the Conn. Returns errConnClosing if // tearDown starts to tear down the Conn.
// it has been closed (mostly due to dial time-out).
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a // some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop. // tight loop.
@ -777,8 +728,12 @@ func (ac *addrConn) tearDown(err error) {
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil {
if err == ErrConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close() ac.transport.Close()
} }
}
if ac.shutdownChan != nil { if ac.shutdownChan != nil {
close(ac.shutdownChan) close(ac.shutdownChan)
} }

View File

@ -403,6 +403,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
updateStreams = true updateStreams = true
} }
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
t.mu.Unlock()
t.Close()
return
}
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {
t.streamsQuota.add(1) t.streamsQuota.add(1)
@ -468,8 +473,16 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return errors.New("transport: Graceful close on a closed transport")
}
if t.state == draining {
t.mu.Unlock()
return nil
}
t.state = draining
active := len(t.activeStreams) active := len(t.activeStreams)
t.activeStreams = nil
t.mu.Unlock() t.mu.Unlock()
if active == 0 { if active == 0 {
return t.Close() return t.Close()

View File

@ -321,6 +321,7 @@ const (
reachable transportState = iota reachable transportState = iota
unreachable unreachable
closing closing
draining
) )
// NewServerTransport creates a ServerTransport with conn or non-nil error // NewServerTransport creates a ServerTransport with conn or non-nil error

View File

@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) {
defer wg.Done() defer wg.Done()
s, err := ct.NewStream(context.Background(), callHdr) s, err := ct.NewStream(context.Background(), callHdr)
if err != nil { if err != nil {
t.Errorf("failed to open stream: %v", err) t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
} }
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("failed to send data: %v", err) t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
} }
p := make([]byte, len(expectedResponseLarge)) p := make([]byte, len(expectedResponseLarge))
_, recvErr := io.ReadFull(s, p) if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
} }
_, recvErr = io.ReadFull(s, p) if _, err = io.ReadFull(s, p); err != io.EOF {
if recvErr != io.EOF { t.Errorf("Failed to complete the stream %v; want <EOF>", err)
t.Errorf("Error: %v; want <EOF>", recvErr)
} }
}() }()
} }
@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) {
server.stop() server.stop()
} }
func TestGracefulClose(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err = ct.GracefulClose(); err != nil {
t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
}
var wg sync.WaitGroup
// Expect the failure for all the follow-up streams because ct has been closed gracefully.
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
}
}()
}
opts := Options{
Last: true,
Delay: false,
}
// The stream which was created before graceful close can still proceed.
if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = io.ReadFull(s, p); err != io.EOF {
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
}
wg.Wait()
ct.Close()
server.stop()
}
func TestLargeMessageSuspension(t *testing.T) { func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended) server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{ callHdr := &CallHdr{