graceful close and test

Author: iamqizhao
Date:   2016-05-10 19:29:44 -07:00
Parent: 64ed38ebed
Commit: 19ded23951

5 changed files with 108 additions and 97 deletions

View File

@@ -44,7 +44,7 @@ type roundRobin struct {
pending int
}
-func (rr *roundRobin) Up(addr Address) func() {
+func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for _, a := range rr.addrs {
@@ -59,12 +59,12 @@ func (rr *roundRobin) Up(addr Address) func() {
rr.waitCh = nil
}
}
-return func() {
-rr.down(addr)
+return func(err error) {
+rr.down(addr, err)
}
}
-func (rr *roundRobin) down(addr Address) {
+func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for i, a := range rr.addrs {

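The balancer change above threads a reason through the teardown path: Up now returns a func(error) instead of func(), so the caller can report why an address went down (drain versus network failure). A minimal runnable sketch of the new shape, using hypothetical stand-in types rather than the real grpc-go ones:

```go
package main

import "fmt"

// Address and roundRobin are stand-ins for the real types; only the
// Up/down signatures from the diff are meant to match.
type Address struct{ Addr string }

type roundRobin struct{}

// Up mirrors the new signature: the returned closure now carries the
// error explaining why the address went down.
func (rr *roundRobin) Up(addr Address) func(error) {
	fmt.Println("up:", addr.Addr)
	return func(err error) {
		fmt.Println("down:", addr.Addr, "reason:", err)
	}
}

func main() {
	rr := &roundRobin{}
	down := rr.Up(Address{Addr: "localhost:50051"})
	// Later, when the connection drains or breaks, report the reason:
	down(fmt.Errorf("connection draining"))
}
```

Passing the error into the down callback lets the balancer distinguish a graceful drain from a hard failure without extra bookkeeping.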
View File

@@ -206,7 +206,7 @@ func WithUserAgent(s string) DialOption {
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{
target: target,
-infos: make(map[Address]*addrInfo),
+conns: make(map[Address]*addrConn),
}
for _, opt := range opts {
opt(&cc.dopts)
@@ -235,9 +235,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return nil, err
}
cc.mu.Lock()
-cc.infos[addr] = &addrInfo{
-ac: ac,
-}
+cc.conns[addr] = ac
cc.mu.Unlock()
} else {
w, err := cc.dopts.resolver.Resolve(cc.target)
@@ -299,10 +297,6 @@ func (s ConnectivityState) String() string {
}
}
-type addrInfo struct {
-ac *addrConn
-}
// ClientConn represents a client connection to an RPC service.
type ClientConn struct {
target string
@@ -312,7 +306,7 @@ type ClientConn struct {
dopts dialOptions
mu sync.RWMutex
-infos map[Address]*addrInfo
+conns map[Address]*addrConn
}
func (cc *ClientConn) watchAddrUpdates() error {
@@ -328,7 +322,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
Addr: update.Addr,
Metadata: update.Metadata,
}
-if _, ok := cc.infos[addr]; ok {
+if _, ok := cc.conns[addr]; ok {
cc.mu.Unlock()
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
continue
@@ -340,9 +334,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
return err
}
cc.mu.Lock()
-cc.infos[addr] = &addrInfo{
-ac: ac,
-}
+cc.conns[addr] = ac
cc.mu.Unlock()
case naming.Delete:
cc.mu.Lock()
@@ -350,15 +342,16 @@ func (cc *ClientConn) watchAddrUpdates() error {
Addr: update.Addr,
Metadata: update.Metadata,
}
-i, ok := cc.infos[addr]
+ac, ok := cc.conns[addr]
if !ok {
cc.mu.Unlock()
grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
continue
}
-delete(cc.infos, addr)
+delete(cc.conns, addr)
cc.mu.Unlock()
-i.ac.startDrain()
+ac.tearDown(ErrConnDrain)
+//ac.startDrain()
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
@@ -367,16 +360,10 @@ func (cc *ClientConn) watchAddrUpdates() error {
}
func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
-/*
-if cc.target == "" {
-return nil, ErrUnspecTarget
-}
-*/
c := &addrConn{
-cc: cc,
-addr: addr,
-dopts: cc.dopts,
-//resetChan: make(chan int, 1),
+cc: cc,
+addr: addr,
+dopts: cc.dopts,
+shutdownChan: make(chan struct{}),
}
if EnableTracing {
@@ -415,7 +402,6 @@ func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
c.tearDown(err)
return
}
grpclog.Println("DEBUG ugh here resetTransport")
c.transportMonitor()
}()
}
@@ -428,17 +414,17 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
return nil, nil, err
}
cc.mu.RLock()
-if cc.infos == nil {
+if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing
}
-info, ok := cc.infos[addr]
+ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
put()
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
}
-t, err := info.ac.wait(ctx)
+t, err := ac.wait(ctx)
if err != nil {
put()
return nil, nil, err
@@ -446,47 +432,31 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
return t, put, nil
}
-/*
-// State returns the connectivity state of cc.
-// This is EXPERIMENTAL API.
-func (cc *ClientConn) State() (ConnectivityState, error) {
-return cc.dopts.picker.State()
-}
-// WaitForStateChange blocks until the state changes to something other than the sourceState.
-// It returns the new state or error.
-// This is EXPERIMENTAL API.
-func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
-return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
-}
-*/
// Close starts to tear down the ClientConn.
func (cc *ClientConn) Close() error {
cc.mu.Lock()
-if cc.infos == nil {
+if cc.conns == nil {
cc.mu.Unlock()
return ErrClientConnClosing
}
-infos := cc.infos
-cc.infos = nil
+conns := cc.conns
+cc.conns = nil
cc.mu.Unlock()
cc.balancer.Close()
if cc.watcher != nil {
cc.watcher.Close()
}
-for _, i := range infos {
-i.ac.tearDown(ErrClientClosing)
+for _, ac := range conns {
+ac.tearDown(ErrClientConnClosing)
}
return nil
}
// addrConn is a network connection to a given address.
type addrConn struct {
-cc *ClientConn
-addr Address
-dopts dialOptions
-//resetChan chan int
+cc *ClientConn
+addr Address
+dopts dialOptions
+shutdownChan chan struct{}
events trace.EventLog
@@ -494,13 +464,13 @@ type addrConn struct {
state ConnectivityState
stateCV *sync.Cond
down func(error) // the handler called when a connection is down.
-drain bool
// ready is closed and becomes nil when a new transport is up or failed
// due to timeout.
ready chan struct{}
transport transport.ClientTransport
}
+/*
func (ac *addrConn) startDrain() {
ac.mu.Lock()
t := ac.transport
@@ -510,8 +480,9 @@ func (ac *addrConn) startDrain() {
ac.down = nil
}
ac.mu.Unlock()
t.GracefulClose()
ac.tearDown(ErrConnDrain)
}
+*/
// printf records an event in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held.
@@ -576,10 +547,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.mu.Unlock()
return errConnClosing
}
-if ac.drain {
-ac.mu.Unlock()
-return nil
-}
+/*
+if ac.drain {
+ac.mu.Unlock()
+return nil
+}
+*/
if ac.down != nil {
ac.down(ErrNetworkIO)
ac.down = nil
@@ -613,7 +586,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
copts.Timeout = timeout
}
connectTime := time.Now()
grpclog.Println("DEBUG reach inside resetTransport 1")
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts)
if err != nil {
ac.mu.Lock()
@@ -639,7 +611,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.mu.Lock()
ac.errorf("connection timeout")
ac.mu.Unlock()
-ac.tearDown(ErrClientTimeout)
+ac.tearDown(ErrClientConnTimeout)
return ErrClientConnTimeout
}
closeTransport = false
@@ -649,7 +621,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
continue
}
ac.mu.Lock()
grpclog.Println("DEBUG reach inside resetTransport 2")
ac.printf("ready")
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
@@ -657,7 +628,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
newTransport.Close()
return errConnClosing
}
grpclog.Println("DEBUG reach inside resetTransport 3: ", ac.addr)
ac.state = Ready
ac.stateCV.Broadcast()
ac.transport = newTransport
@@ -683,12 +653,6 @@ func (ac *addrConn) transportMonitor() {
// the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan:
return
-/*
-case <-ac.resetChan:
-if !ac.reconnect() {
-return
-}
-*/
case <-t.Error():
ac.mu.Lock()
if ac.state == Shutdown {
@@ -706,18 +670,6 @@ func (ac *addrConn) transportMonitor() {
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
return
}
-/*
-if !ac.reconnect() {
-return
-}
-*/
-/*
-// Tries to drain reset signal if there is any since it is out-dated.
-select {
-case <-ac.resetChan:
-default:
-}
-*/
}
}
}
@@ -751,8 +703,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
}
}
-// tearDown starts to tear down the Conn. Returns errConnClosing if
-// it has been closed (mostly due to dial time-out).
+// tearDown starts to tear down the Conn.
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop.
@@ -777,7 +728,11 @@ func (ac *addrConn) tearDown(err error) {
ac.ready = nil
}
if ac.transport != nil {
-ac.transport.Close()
+if err == ErrConnDrain {
+ac.transport.GracefulClose()
+} else {
+ac.transport.Close()
+}
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)

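The clientconn.go changes replace the dedicated startDrain path with a sentinel error: the naming.Delete case calls ac.tearDown(ErrConnDrain), and tearDown (last hunk above) closes the transport gracefully for that error and hard otherwise. A minimal sketch of that branch, assuming stand-in types (the real addrConn and transport.ClientTransport carry far more state):

```go
package main

import (
	"errors"
	"fmt"
)

// ErrConnDrain matches the sentinel used in the diff; transportLike is a
// stand-in for transport.ClientTransport.
var ErrConnDrain = errors.New("grpc: the connection is drained")

type transportLike interface {
	Close() error
	GracefulClose() error
}

type fakeTransport struct{}

func (fakeTransport) Close() error         { fmt.Println("Close: abort in-flight streams"); return nil }
func (fakeTransport) GracefulClose() error { fmt.Println("GracefulClose: let streams finish"); return nil }

// tearDown mirrors the branch added above: a drain closes gracefully,
// anything else (client shutdown, connect timeout) closes immediately.
func tearDown(t transportLike, err error) {
	if t == nil {
		return
	}
	if err == ErrConnDrain {
		t.GracefulClose()
		return
	}
	t.Close()
}

func main() {
	tearDown(fakeTransport{}, ErrConnDrain)          // resolver deleted the address
	tearDown(fakeTransport{}, errors.New("network")) // hard failure
}
```

Routing the drain through tearDown keeps a single teardown entry point, which is why the old startDrain body is commented out rather than kept as a second path.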
View File

@@ -403,6 +403,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
updateStreams = true
}
delete(t.activeStreams, s.id)
+if t.state == draining && len(t.activeStreams) == 0 {
+t.mu.Unlock()
+t.Close()
+return
+}
t.mu.Unlock()
if updateStreams {
t.streamsQuota.add(1)
@@ -468,8 +473,16 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
if t.state == closing {
+t.mu.Unlock()
return errors.New("transport: Graceful close on a closed transport")
}
+if t.state == draining {
+t.mu.Unlock()
+return nil
+}
+t.state = draining
+active := len(t.activeStreams)
+t.activeStreams = nil
t.mu.Unlock()
if active == 0 {
return t.Close()

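GracefulClose, as amended above, is a small state machine: closing rejects the call, draining makes it a no-op, and otherwise the transport stops admitting new streams and closes either immediately (nothing in flight) or from the CloseStream hook when the last active stream finishes. A reduced runnable model of that logic, simplified in that it keeps the stream map instead of nil-ing it as the diff does:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

type state int

const (
	reachable state = iota
	draining
	closing
)

type conn struct {
	mu      sync.Mutex
	state   state
	streams map[int]struct{}
}

func (c *conn) Close() error {
	c.mu.Lock()
	c.state = closing
	c.mu.Unlock()
	fmt.Println("transport closed")
	return nil
}

func (c *conn) GracefulClose() error {
	c.mu.Lock()
	if c.state == closing {
		c.mu.Unlock()
		return errors.New("transport: Graceful close on a closed transport")
	}
	if c.state == draining {
		c.mu.Unlock()
		return nil // idempotent, as in the diff
	}
	c.state = draining
	active := len(c.streams)
	c.mu.Unlock()
	if active == 0 {
		return c.Close() // nothing in flight: close right away
	}
	return nil // otherwise CloseStream closes it later
}

func (c *conn) CloseStream(id int) {
	c.mu.Lock()
	delete(c.streams, id)
	if c.state == draining && len(c.streams) == 0 {
		c.mu.Unlock()
		c.Close() // the last in-flight stream is done
		return
	}
	c.mu.Unlock()
}

func main() {
	c := &conn{streams: map[int]struct{}{1: {}}}
	c.GracefulClose() // stays open: stream 1 still in flight
	c.CloseStream(1)  // now the transport closes
}
```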
View File

@@ -321,6 +321,7 @@ const (
reachable transportState = iota
unreachable
closing
+draining
)
// NewServerTransport creates a ServerTransport with conn or non-nil error

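For reference, the transportState enum as it reads after this one-line hunk; the trailing comments are explanatory additions, not part of the source:

```go
package transport

type transportState int

const (
	reachable   transportState = iota // live: accepts new streams
	unreachable                       // connectivity lost, may recover
	closing                           // tearing down: reject all work
	draining                          // new in this commit: no new streams, finish in-flight ones
)
```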
View File

@@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) {
defer wg.Done()
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Errorf("failed to open stream: %v", err)
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("failed to send data: %v", err)
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
-_, recvErr := io.ReadFull(s, p)
-if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) {
-t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
+if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponseLarge)
}
-_, recvErr = io.ReadFull(s, p)
-if recvErr != io.EOF {
-t.Errorf("Error: %v; want <EOF>", recvErr)
+if _, err = io.ReadFull(s, p); err != io.EOF {
+t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
}
@@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) {
server.stop()
}
+func TestGracefulClose(t *testing.T) {
+server, ct := setUp(t, 0, math.MaxUint32, normal)
+callHdr := &CallHdr{
+Host: "localhost",
+Method: "foo.Small",
+}
+s, err := ct.NewStream(context.Background(), callHdr)
+if err != nil {
+t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
+}
+if err = ct.GracefulClose(); err != nil {
+t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
+}
+var wg sync.WaitGroup
+// Expect the failure for all the follow-up streams because ct has been closed gracefully.
+for i := 0; i < 100; i++ {
+wg.Add(1)
+go func() {
+defer wg.Done()
+if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
}
}()
}
opts := Options{
Last: true,
Delay: false,
}
// The stream which was created before graceful close can still proceed.
if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = io.ReadFull(s, p); err != io.EOF {
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
}
wg.Wait()
ct.Close()
server.stop()
}
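TestGracefulClose pins the contract down: the stream opened before GracefulClose completes normally, while the 100 streams attempted afterwards all fail with ErrConnClosing. A toy, self-contained illustration of the same contract (the real test drives an actual HTTP/2 transport through the file's setUp helper):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
)

var errConnClosing = errors.New("transport: the connection is closing")

// toy captures just the contract the test checks: NewStream fails once
// the transport is draining; work started earlier is unaffected.
type toy struct{ draining atomic.Bool }

func (t *toy) GracefulClose() { t.draining.Store(true) }

func (t *toy) NewStream() error {
	if t.draining.Load() {
		return errConnClosing
	}
	return nil
}

func main() {
	tr := &toy{}
	_ = tr.NewStream() // opened before the drain: allowed
	tr.GracefulClose()
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := tr.NewStream(); err != errConnClosing {
				fmt.Println("unexpected:", err)
			}
		}()
	}
	wg.Wait()
	fmt.Println("all follow-up streams rejected")
}
```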
func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{