graceful close and test
balancer.go

@@ -44,7 +44,7 @@ type roundRobin struct {
 	pending int
 }

-func (rr *roundRobin) Up(addr Address) func() {
+func (rr *roundRobin) Up(addr Address) func(error) {
 	rr.mu.Lock()
 	defer rr.mu.Unlock()
 	for _, a := range rr.addrs {
@@ -59,12 +59,12 @@ func (rr *roundRobin) Up(addr Address) func() {
 			rr.waitCh = nil
 		}
 	}
-	return func() {
-		rr.down(addr)
+	return func(err error) {
+		rr.down(addr, err)
 	}
 }

-func (rr *roundRobin) down(addr Address) {
+func (rr *roundRobin) down(addr Address, err error) {
 	rr.mu.Lock()
 	defer rr.mu.Unlock()
 	for i, a := range rr.addrs {
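With this change, Up returns a func(error) instead of a func(), so the connection layer can report why an address went down when it invokes the teardown callback. A minimal runnable sketch of the new shape (miniBalancer, Address, and errNetworkIO are illustrative stand-ins, not the grpc types):

package main

import (
	"errors"
	"fmt"
	"sync"
)

// Address stands in for the grpc naming address type.
type Address struct{ Addr string }

var errNetworkIO = errors.New("connection lost")

// miniBalancer mimics the shape roundRobin has after this diff: Up returns
// a teardown callback that carries the error that took the address down.
type miniBalancer struct {
	mu    sync.Mutex
	addrs []Address
}

func (b *miniBalancer) Up(addr Address) func(error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.addrs = append(b.addrs, addr)
	return func(err error) { b.down(addr, err) }
}

func (b *miniBalancer) down(addr Address, err error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	// Remove the address from rotation and record why it left.
	for i, a := range b.addrs {
		if a == addr {
			b.addrs = append(b.addrs[:i], b.addrs[i+1:]...)
			break
		}
	}
	fmt.Printf("address %s down: %v\n", addr.Addr, err)
}

func main() {
	b := &miniBalancer{}
	down := b.Up(Address{Addr: "127.0.0.1:50051"})
	// When the transport fails, the caller can now report the cause.
	down(errNetworkIO)
}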
clientconn.go
@@ -206,7 +206,7 @@ func WithUserAgent(s string) DialOption {
 func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 	cc := &ClientConn{
 		target: target,
-		infos:  make(map[Address]*addrInfo),
+		conns:  make(map[Address]*addrConn),
 	}
 	for _, opt := range opts {
 		opt(&cc.dopts)
@@ -235,9 +235,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 			return nil, err
 		}
 		cc.mu.Lock()
-		cc.infos[addr] = &addrInfo{
-			ac: ac,
-		}
+		cc.conns[addr] = ac
 		cc.mu.Unlock()
 	} else {
 		w, err := cc.dopts.resolver.Resolve(cc.target)
@@ -299,10 +297,6 @@ func (s ConnectivityState) String() string {
 	}
 }

-type addrInfo struct {
-	ac *addrConn
-}
-
 // ClientConn represents a client connection to an RPC service.
 type ClientConn struct {
 	target string
@@ -312,7 +306,7 @@ type ClientConn struct {
 	dopts dialOptions

 	mu    sync.RWMutex
-	infos map[Address]*addrInfo
+	conns map[Address]*addrConn
 }

 func (cc *ClientConn) watchAddrUpdates() error {
@@ -328,7 +322,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
 				Addr:     update.Addr,
 				Metadata: update.Metadata,
 			}
-			if _, ok := cc.infos[addr]; ok {
+			if _, ok := cc.conns[addr]; ok {
 				cc.mu.Unlock()
 				grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
 				continue
@@ -340,9 +334,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
 				return err
 			}
 			cc.mu.Lock()
-			cc.infos[addr] = &addrInfo{
-				ac: ac,
-			}
+			cc.conns[addr] = ac
 			cc.mu.Unlock()
 		case naming.Delete:
 			cc.mu.Lock()
@@ -350,15 +342,16 @@ func (cc *ClientConn) watchAddrUpdates() error {
 				Addr:     update.Addr,
 				Metadata: update.Metadata,
 			}
-			i, ok := cc.infos[addr]
+			ac, ok := cc.conns[addr]
 			if !ok {
 				cc.mu.Unlock()
 				grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
 				continue
 			}
-			delete(cc.infos, addr)
+			delete(cc.conns, addr)
 			cc.mu.Unlock()
-			i.ac.startDrain()
+			ac.tearDown(ErrConnDrain)
+			//ac.startDrain()
 		default:
 			grpclog.Println("Unknown update.Op ", update.Op)
 		}
@@ -367,16 +360,10 @@ func (cc *ClientConn) watchAddrUpdates() error {
 	}
 }

 func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
-	/*
-		if cc.target == "" {
-			return nil, ErrUnspecTarget
-		}
-	*/
 	c := &addrConn{
-		cc:    cc,
-		addr:  addr,
-		dopts: cc.dopts,
-		//resetChan: make(chan int, 1),
+		cc:           cc,
+		addr:         addr,
+		dopts:        cc.dopts,
+		shutdownChan: make(chan struct{}),
 	}
 	if EnableTracing {
@@ -415,7 +402,6 @@ func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
 			c.tearDown(err)
 			return
 		}
-		grpclog.Println("DEBUG ugh here resetTransport")
 		c.transportMonitor()
 	}()
 }
@@ -428,17 +414,17 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
 		return nil, nil, err
 	}
 	cc.mu.RLock()
-	if cc.infos == nil {
+	if cc.conns == nil {
 		cc.mu.RUnlock()
 		return nil, nil, ErrClientConnClosing
 	}
-	info, ok := cc.infos[addr]
+	ac, ok := cc.conns[addr]
 	cc.mu.RUnlock()
 	if !ok {
 		put()
 		return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
 	}
-	t, err := info.ac.wait(ctx)
+	t, err := ac.wait(ctx)
 	if err != nil {
 		put()
 		return nil, nil, err
@@ -446,47 +432,31 @@ func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTranspo
 	return t, put, nil
 }

-/*
-// State returns the connectivity state of cc.
-// This is EXPERIMENTAL API.
-func (cc *ClientConn) State() (ConnectivityState, error) {
-	return cc.dopts.picker.State()
-}
-
-// WaitForStateChange blocks until the state changes to something other than the sourceState.
-// It returns the new state or error.
-// This is EXPERIMENTAL API.
-func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
-	return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
-}
-*/
-
 // Close starts to tear down the ClientConn.
 func (cc *ClientConn) Close() error {
 	cc.mu.Lock()
-	if cc.infos == nil {
+	if cc.conns == nil {
 		cc.mu.Unlock()
 		return ErrClientConnClosing
 	}
-	infos := cc.infos
-	cc.infos = nil
+	conns := cc.conns
+	cc.conns = nil
 	cc.mu.Unlock()
 	cc.balancer.Close()
 	if cc.watcher != nil {
 		cc.watcher.Close()
 	}
-	for _, i := range infos {
-		i.ac.tearDown(ErrClientClosing)
+	for _, ac := range conns {
+		ac.tearDown(ErrClientConnClosing)
 	}
 	return nil
 }

 // addrConn is a network connection to a given address.
 type addrConn struct {
-	cc    *ClientConn
-	addr  Address
-	dopts dialOptions
-	//resetChan chan int
+	cc           *ClientConn
+	addr         Address
+	dopts        dialOptions
+	shutdownChan chan struct{}
 	events trace.EventLog

@@ -494,13 +464,13 @@ type addrConn struct {
 	state   ConnectivityState
 	stateCV *sync.Cond
 	down    func(error) // the handler called when a connection is down.
-	drain   bool
 	// ready is closed and becomes nil when a new transport is up or failed
 	// due to timeout.
 	ready     chan struct{}
 	transport transport.ClientTransport
 }

+/*
 func (ac *addrConn) startDrain() {
 	ac.mu.Lock()
 	t := ac.transport
@@ -510,8 +480,9 @@ func (ac *addrConn) startDrain() {
 		ac.down = nil
 	}
 	ac.mu.Unlock()
 	t.GracefulClose()
+	ac.tearDown(ErrConnDrain)
 }
+*/

 // printf records an event in ac's event log, unless ac has been closed.
 // REQUIRES ac.mu is held.
@@ -576,10 +547,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			ac.mu.Unlock()
 			return errConnClosing
 		}
-		if ac.drain {
-			ac.mu.Unlock()
-			return nil
-		}
+		/*
+			if ac.drain {
+				ac.mu.Unlock()
+				return nil
+			}
+		*/
 		if ac.down != nil {
 			ac.down(ErrNetworkIO)
 			ac.down = nil
@@ -613,7 +586,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			copts.Timeout = timeout
 		}
 		connectTime := time.Now()
-		grpclog.Println("DEBUG reach inside resetTransport 1")
 		newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts)
 		if err != nil {
 			ac.mu.Lock()
@@ -639,7 +611,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 				ac.mu.Lock()
 				ac.errorf("connection timeout")
 				ac.mu.Unlock()
-				ac.tearDown(ErrClientTimeout)
+				ac.tearDown(ErrClientConnTimeout)
 				return ErrClientConnTimeout
 			}
 			closeTransport = false
@@ -649,7 +621,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			continue
 		}
 		ac.mu.Lock()
-		grpclog.Println("DEBUG reach inside resetTransport 2")
 		ac.printf("ready")
 		if ac.state == Shutdown {
 			// ac.tearDown(...) has been invoked.
@@ -657,7 +628,6 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
 			newTransport.Close()
 			return errConnClosing
 		}
-		grpclog.Println("DEBUG reach inside resetTransport 3: ", ac.addr)
 		ac.state = Ready
 		ac.stateCV.Broadcast()
 		ac.transport = newTransport
@@ -683,12 +653,6 @@ func (ac *addrConn) transportMonitor() {
 		// the addrConn is idle (i.e., no RPC in flight).
 		case <-ac.shutdownChan:
 			return
-		/*
-			case <-ac.resetChan:
-				if !ac.reconnect() {
-					return
-				}
-		*/
 		case <-t.Error():
 			ac.mu.Lock()
 			if ac.state == Shutdown {
@@ -706,18 +670,6 @@ func (ac *addrConn) transportMonitor() {
 				grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
 				return
 			}
-			/*
-				if !ac.reconnect() {
-					return
-				}
-			*/
-			/*
-				// Tries to drain reset signal if there is any since it is out-dated.
-				select {
-				case <-ac.resetChan:
-				default:
-				}
-			*/
 		}
 	}
 }
@@ -751,8 +703,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
 	}
 }

-// tearDown starts to tear down the Conn. Returns errConnClosing if
-// it has been closed (mostly due to dial time-out).
+// tearDown starts to tear down the Conn.
 // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 // some edge cases (e.g., the caller opens and closes many addrConn's in a
 // tight loop.
@@ -777,7 +728,11 @@ func (ac *addrConn) tearDown(err error) {
 		ac.ready = nil
 	}
 	if ac.transport != nil {
-		ac.transport.Close()
+		if err == ErrConnDrain {
+			ac.transport.GracefulClose()
+		} else {
+			ac.transport.Close()
+		}
 	}
 	if ac.shutdownChan != nil {
 		close(ac.shutdownChan)
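The key behavioral change in clientconn.go is the branch at the end of tearDown: the teardown reason now selects between draining the transport (when the name resolver deletes an address) and hard-closing it. A condensed runnable model of that decision (fakeTransport and the error values are stand-ins for the real grpc types):

package main

import (
	"errors"
	"fmt"
)

// Stand-ins for the real grpc error values; illustrative only.
var (
	errConnDrain         = errors.New("grpc: the connection is drained")
	errClientConnClosing = errors.New("grpc: the client connection is closing")
)

type fakeTransport struct{}

func (fakeTransport) GracefulClose() { fmt.Println("graceful: refuse new streams, let active RPCs finish") }
func (fakeTransport) Close()         { fmt.Println("hard close: abort all streams now") }

// tearDown mirrors the branch added above: the teardown reason selects
// between a drain and an immediate close.
func tearDown(t fakeTransport, err error) {
	if err == errConnDrain {
		t.GracefulClose()
		return
	}
	t.Close()
}

func main() {
	tearDown(fakeTransport{}, errConnDrain)         // e.g. the resolver deleted this address
	tearDown(fakeTransport{}, errClientConnClosing) // e.g. ClientConn.Close()
}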
transport/http2_client.go

@@ -403,6 +403,11 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 		updateStreams = true
 	}
 	delete(t.activeStreams, s.id)
+	if t.state == draining && len(t.activeStreams) == 0 {
+		t.mu.Unlock()
+		t.Close()
+		return
+	}
 	t.mu.Unlock()
 	if updateStreams {
 		t.streamsQuota.add(1)
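CloseStream now completes a graceful close once the last active stream is removed while the transport is draining. A small runnable model of that bookkeeping (conn is a stand-in, not the real http2Client):

package main

import (
	"fmt"
	"sync"
)

// conn models the draining bookkeeping: once draining, removing the last
// active stream triggers the final close.
type conn struct {
	mu       sync.Mutex
	draining bool
	active   map[int]struct{}
}

func (c *conn) closeStream(id int) {
	c.mu.Lock()
	delete(c.active, id)
	if c.draining && len(c.active) == 0 {
		c.mu.Unlock()
		c.close() // last stream gone: finish the graceful close
		return
	}
	c.mu.Unlock()
}

func (c *conn) close() { fmt.Println("transport closed") }

func main() {
	c := &conn{active: map[int]struct{}{1: {}, 2: {}}, draining: true}
	c.closeStream(1) // one stream still active; transport stays open
	c.closeStream(2) // prints "transport closed"
}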
@@ -468,8 +473,16 @@ func (t *http2Client) Close() (err error) {

 func (t *http2Client) GracefulClose() error {
 	t.mu.Lock()
 	if t.state == closing {
 		t.mu.Unlock()
 		return errors.New("transport: Graceful close on a closed transport")
 	}
+	if t.state == draining {
+		t.mu.Unlock()
+		return nil
+	}
+	t.state = draining
+	active := len(t.activeStreams)
+	t.activeStreams = nil
+	t.mu.Unlock()
+	if active == 0 {
+		return t.Close()
transport/transport.go

@@ -321,6 +321,7 @@ const (
 	reachable transportState = iota
 	unreachable
 	closing
+	draining
 )

 // NewServerTransport creates a ServerTransport with conn or non-nil error
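Putting the two transport changes together: draining is a new state between reachable and closing, and GracefulClose errors on an already-closed transport, no-ops when already draining, and closes immediately when there are no active streams. A runnable model of that state machine (xport is a stand-in for http2Client):

package main

import (
	"errors"
	"fmt"
)

type transportState int

const (
	reachable transportState = iota
	unreachable
	closing
	draining
)

type xport struct {
	state  transportState
	active int // number of active streams
}

// gracefulClose mirrors the new method: error if already closed, no-op if
// already draining, immediate close when the transport is idle.
func (t *xport) gracefulClose() error {
	switch t.state {
	case closing:
		return errors.New("transport: Graceful close on a closed transport")
	case draining:
		return nil
	}
	t.state = draining
	if t.active == 0 {
		return t.close() // idle: close right away
	}
	return nil // otherwise CloseStream finishes the job when the last stream ends
}

func (t *xport) close() error {
	t.state = closing
	fmt.Println("closed")
	return nil
}

func main() {
	t := &xport{state: reachable}
	fmt.Println(t.gracefulClose()) // idle transport: prints "closed", then <nil>
	fmt.Println(t.gracefulClose()) // already closed: returns the error
}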
transport/transport_test.go

@@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) {
 			defer wg.Done()
 			s, err := ct.NewStream(context.Background(), callHdr)
 			if err != nil {
-				t.Errorf("failed to open stream: %v", err)
+				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
 			}
 			if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
-				t.Errorf("failed to send data: %v", err)
+				t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
 			}
 			p := make([]byte, len(expectedResponseLarge))
-			_, recvErr := io.ReadFull(s, p)
-			if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) {
-				t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
+			if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+				t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
 			}
-			_, recvErr = io.ReadFull(s, p)
-			if recvErr != io.EOF {
-				t.Errorf("Error: %v; want <EOF>", recvErr)
+			if _, err = io.ReadFull(s, p); err != io.EOF {
+				t.Errorf("Failed to complete the stream %v; want <EOF>", err)
 			}
 		}()
 	}
@@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) {
 	server.stop()
 }

+func TestGracefulClose(t *testing.T) {
+	server, ct := setUp(t, 0, math.MaxUint32, normal)
+	callHdr := &CallHdr{
+		Host:   "localhost",
+		Method: "foo.Small",
+	}
+	s, err := ct.NewStream(context.Background(), callHdr)
+	if err != nil {
+		t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
+	}
+	if err = ct.GracefulClose(); err != nil {
+		t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
+	}
+	var wg sync.WaitGroup
+	// Expect the failure for all the follow-up streams because ct has been closed gracefully.
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
+				t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
+			}
+		}()
+	}
+	opts := Options{
+		Last:  true,
+		Delay: false,
+	}
+	// The stream which was created before graceful close can still proceed.
+	if err := ct.Write(s, expectedRequest, &opts); err != nil {
+		t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
+	}
+	p := make([]byte, len(expectedResponse))
+	if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
+		t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
+	}
+	if _, err = io.ReadFull(s, p); err != io.EOF {
+		t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
+	}
+	wg.Wait()
+	ct.Close()
+	server.stop()
+}
+
 func TestLargeMessageSuspension(t *testing.T) {
 	server, ct := setUp(t, 0, math.MaxUint32, suspended)
 	callHdr := &CallHdr{