Fix some issues and remove garbage files

iamqizhao
2016-05-16 15:31:00 -07:00
parent 2161303fcd
commit aa532d5baf
8 changed files with 152 additions and 1050 deletions

View File

@@ -43,6 +43,7 @@ type roundRobin struct {
 	addrs  []Address
 	next   int           // index of the next address to return for Get()
 	waitCh chan struct{} // channel to block when there is no address available
+	done   bool          // The Balancer is closed.
 }

 // Up appends addr to the end of rr.addrs and sends notification if there
@@ -84,6 +85,11 @@ func (rr *roundRobin) down(addr Address, err error) {
 func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err error) {
 	var ch chan struct{}
 	rr.mu.Lock()
+	if rr.done {
+		rr.mu.Unlock()
+		err = ErrClientConnClosing
+		return
+	}
 	if rr.next >= len(rr.addrs) {
 		rr.next = 0
 	}
@@ -111,6 +117,11 @@ func (rr *roundRobin) Get(ctx context.Context) (addr Address, put func(), err error) {
 			return
 		case <-ch:
 			rr.mu.Lock()
+			if rr.done {
+				rr.mu.Unlock()
+				err = ErrClientConnClosing
+				return
+			}
 			if len(rr.addrs) == 0 {
 				// The newly added addr got removed by Down() again.
 				rr.mu.Unlock()
@@ -134,5 +145,12 @@ func (rr *roundRobin) put(ctx context.Context, addr Address) {
 }

 func (rr *roundRobin) Close() error {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+	rr.done = true
+	if rr.waitCh != nil {
+		close(rr.waitCh)
+		rr.waitCh = nil
+	}
 	return nil
 }
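The pattern these balancer hunks implement — a done flag set under the mutex, with Close waking any Get blocked on waitCh so the waiter can re-check the flag — can be shown standalone. A minimal sketch, assuming a stripped-down picker type and an illustrative errClosing error (not the actual grpc-go API):

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errClosing = errors.New("balancer is closing")

// picker is a toy stand-in for the roundRobin balancer above.
type picker struct {
	mu     sync.Mutex
	addrs  []string
	next   int
	waitCh chan struct{} // non-nil while a Get is blocked waiting for addresses
	done   bool          // set by Close; checked on every Get path
}

// Get returns the next address, or errClosing once the picker is closed.
func (p *picker) Get() (string, error) {
	p.mu.Lock()
	if p.done { // fast path: already closed
		p.mu.Unlock()
		return "", errClosing
	}
	if len(p.addrs) == 0 {
		if p.waitCh == nil {
			p.waitCh = make(chan struct{})
		}
		ch := p.waitCh
		p.mu.Unlock()
		<-ch // block until an address arrives or Close wakes us
		p.mu.Lock()
		// The wakeup may have come from Close rather than from a new
		// address, so both conditions must be re-checked under the lock.
		if p.done {
			p.mu.Unlock()
			return "", errClosing
		}
		if len(p.addrs) == 0 {
			p.mu.Unlock()
			return "", errors.New("no address available")
		}
	}
	if p.next >= len(p.addrs) {
		p.next = 0
	}
	a := p.addrs[p.next]
	p.next++
	p.mu.Unlock()
	return a, nil
}

// Close marks the picker done and unblocks any pending Get.
func (p *picker) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.done = true
	if p.waitCh != nil {
		close(p.waitCh)
		p.waitCh = nil
	}
}

func main() {
	p := &picker{}
	go p.Close() // races with the blocked Get below; Get must not hang
	_, err := p.Get()
	fmt.Println(err) // "balancer is closing"
}

Checking done only once, before blocking, would miss a Close that raced in while the caller was parked on the channel; that is why the diff adds the check on both paths.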

View File

@@ -1,6 +1,6 @@
 /*
  *
- * Copyright 2014, Google Inc.
+ * Copyright 2016, Google Inc.
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -36,10 +36,12 @@ package grpc
 import (
 	"fmt"
 	"math"
+	"sync"
 	"testing"
 	"time"

 	"golang.org/x/net/context"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/naming"
 )
@@ -100,12 +102,12 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
 	return r.w, nil
 }

-func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*server, *testNameResolver) {
+func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver) {
 	var servers []*server
 	for i := 0; i < numServers; i++ {
 		s := newTestServer()
 		servers = append(servers, s)
-		go s.start(t, port, maxStreams)
+		go s.start(t, 0, maxStreams)
 		s.wait(t, 2*time.Second)
 	}
 	// Point to server[0]
@@ -118,7 +120,7 @@ func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver) {
 func TestNameDiscovery(t *testing.T) {
 	// Start 2 servers on 2 ports.
 	numServers := 2
-	servers, r := startServers(t, numServers, 0, math.MaxUint32)
+	servers, r := startServers(t, numServers, math.MaxUint32)
 	cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
@@ -152,7 +154,7 @@ func TestNameDiscovery(t *testing.T) {
 }

 func TestEmptyAddrs(t *testing.T) {
-	servers, r := startServers(t, 1, 0, math.MaxUint32)
+	servers, r := startServers(t, 1, math.MaxUint32)
 	cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
@@ -184,7 +186,7 @@ func TestEmptyAddrs(t *testing.T) {
 func TestRoundRobin(t *testing.T) {
 	// Start 3 servers on 3 ports.
 	numServers := 3
-	servers, r := startServers(t, numServers, 0, math.MaxUint32)
+	servers, r := startServers(t, numServers, math.MaxUint32)
 	cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
 	if err != nil {
 		t.Fatalf("Failed to create ClientConn: %v", err)
@@ -220,3 +222,89 @@ func TestRoundRobin(t *testing.T) {
 		servers[i].stop()
 	}
 }
+
+func TestCloseWithPendingRPC(t *testing.T) {
+	servers, r := startServers(t, 1, math.MaxUint32)
+	cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("Failed to create ClientConn: %v", err)
+	}
+	var reply string
+	if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
+		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
+	}
+	// Remove the server.
+	updates := []*naming.Update{&naming.Update{
+		Op:   naming.Delete,
+		Addr: "127.0.0.1:" + servers[0].port,
+	}}
+	r.w.inject(updates)
+	for {
+		ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		var reply string
+		if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
+			t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		var reply string
+		time.Sleep(5 * time.Millisecond)
+		if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
+			t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
+		}
+	}()
+	time.Sleep(5 * time.Millisecond)
+	cc.Close()
+	wg.Wait()
+	servers[0].stop()
+}
+
+func TestGetOnWaitChannel(t *testing.T) {
+	servers, r := startServers(t, 1, math.MaxUint32)
+	cc, err := Dial("foo.bar.com", WithNameResolver(r), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
+	if err != nil {
+		t.Fatalf("Failed to create ClientConn: %v", err)
+	}
+	// Remove all servers so that all upcoming RPCs will block on waitCh.
+	updates := []*naming.Update{&naming.Update{
+		Op:   naming.Delete,
+		Addr: "127.0.0.1:" + servers[0].port,
+	}}
+	r.w.inject(updates)
+	for {
+		var reply string
+		ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
+		if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		var reply string
+		if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
+			t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
+		}
+	}()
+	// Add a connected server.
+	updates = []*naming.Update{&naming.Update{
+		Op:   naming.Add,
+		Addr: "127.0.0.1:" + servers[0].port,
+	}}
+	r.w.inject(updates)
+	wg.Wait()
+	cc.Close()
+	servers[0].stop()
+}
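Both new tests use the same wait-for-drain idiom: poll with a short-deadline RPC until it fails with codes.DeadlineExceeded, which signals that the balancer no longer has an address. One nit worth noting is that `ctx, _ := context.WithTimeout(...)` discards the cancel function; a standalone version of the idiom that releases the deadline timer promptly might look like the sketch below (waitUntilDrained and probe are illustrative names, not part of the test suite):

package main

import (
	"fmt"
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
)

// waitUntilDrained issues short-deadline probes until one fails with
// DeadlineExceeded, i.e. until no backend address is available any more.
// probe stands in for the tests' Invoke call.
func waitUntilDrained(probe func(ctx context.Context) error) {
	for {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
		err := probe(ctx)
		cancel() // release the timer instead of discarding the cancel func
		if grpc.Code(err) == codes.DeadlineExceeded {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
}

func main() {
	calls := 0
	waitUntilDrained(func(ctx context.Context) error {
		calls++
		if calls < 3 {
			return nil // backend still reachable
		}
		return grpc.Errorf(codes.DeadlineExceeded, "no address available")
	})
	fmt.Println("drained after", calls, "probes") // drained after 3 probes
}

Note that the check must go through grpc.Code (or, inside package grpc, plain Code): a bare context.DeadlineExceeded error would not carry the gRPC code.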

View File

@@ -166,7 +166,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) {
 		}
 		st, err := transport.NewServerTransport("http2", conn, maxStreams, nil)
 		if err != nil {
-			return
+			continue
 		}
 		s.mu.Lock()
 		if s.conns == nil {
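The one-word change above (return → continue) is the whole fix: a single connection that fails the HTTP/2 handshake should be skipped, not abort the accept loop and silently kill the test server. The loop's shape, reduced to a sketch (acceptLoop and handle are illustrative names, and the transport import path is the era's repo-internal package):

package main

import (
	"net"

	"google.golang.org/grpc/transport"
)

// acceptLoop accepts connections until the listener closes. One failed
// handshake must not terminate the loop.
func acceptLoop(lis net.Listener, maxStreams uint32, handle func(transport.ServerTransport)) {
	for {
		conn, err := lis.Accept()
		if err != nil {
			return // the listener itself is gone; stop serving
		}
		st, err := transport.NewServerTransport("http2", conn, maxStreams, nil)
		if err != nil {
			continue // drop this conn, keep accepting (the fix above)
		}
		go handle(st)
	}
}

func main() {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	go acceptLoop(lis, 100, func(st transport.ServerTransport) { _ = st })
	lis.Close() // Accept returns an error and the loop exits cleanly
}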

View File

@@ -237,13 +237,9 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
 		addr := Address{
 			Addr: cc.target,
 		}
-		ac, err := cc.newAddrConn(addr)
-		if err != nil {
+		if err := cc.newAddrConn(addr); err != nil {
 			return nil, err
 		}
-		cc.mu.Lock()
-		cc.conns[addr] = ac
-		cc.mu.Unlock()
 	} else {
 		w, err := cc.dopts.resolver.Resolve(cc.target)
 		if err != nil {
@@ -335,14 +331,15 @@ func (cc *ClientConn) watchAddrUpdates() error {
 				grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
 				continue
 			}
-			cc.mu.Unlock()
-			ac, err := cc.newAddrConn(addr)
-			if err != nil {
+			cc.mu.RUnlock()
+			if err := cc.newAddrConn(addr); err != nil {
 				return err
 			}
-			cc.mu.Lock()
-			cc.conns[addr] = ac
-			cc.mu.Unlock()
+			/*
+				cc.mu.Lock()
+				cc.conns[addr] = ac
+				cc.mu.Unlock()
+			*/
 		case naming.Delete:
 			cc.mu.Lock()
 			addr := Address{
@@ -355,7 +352,6 @@ func (cc *ClientConn) watchAddrUpdates() error {
 				grpclog.Println("grpc: The name resolver wanted to delete a non-exist address: ", addr)
 				continue
 			}
-			delete(cc.conns, addr)
 			cc.mu.Unlock()
 			ac.tearDown(ErrConnDrain)
 		default:
@@ -365,7 +361,7 @@ func (cc *ClientConn) watchAddrUpdates() error {
 	return nil
 }

-func (cc *ClientConn) newAddrConn(addr Address) (*addrConn, error) {
+func (cc *ClientConn) newAddrConn(addr Address) error {
 	c := &addrConn{
 		cc:   cc,
 		addr: addr,
@@ -383,12 +379,12 @@ func (cc *ClientConn) newAddrConn(addr Address) error {
 			}
 		}
 		if !ok {
-			return nil, ErrNoTransportSecurity
+			return ErrNoTransportSecurity
 		}
 	} else {
 		for _, cd := range c.dopts.copts.AuthOptions {
 			if cd.RequireTransportSecurity() {
-				return nil, ErrCredentialsMisuse
+				return ErrCredentialsMisuse
 			}
 		}
 	}
@@ -396,7 +392,7 @@ func (cc *ClientConn) newAddrConn(addr Address) error {
 	if c.dopts.block {
 		if err := c.resetTransport(false); err != nil {
 			c.tearDown(err)
-			return nil, err
+			return err
 		}
 		// Start to monitor the error status of transport.
 		go c.transportMonitor()
@@ -411,7 +407,7 @@ func (cc *ClientConn) newAddrConn(addr Address) error {
 			c.transportMonitor()
 		}()
 	}
-	return c, nil
+	return nil
 }

 func (cc *ClientConn) getTransport(ctx context.Context) (transport.ClientTransport, func(), error) {
@@ -529,6 +525,13 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) {
 }

 func (ac *addrConn) resetTransport(closeTransport bool) error {
+	ac.cc.mu.Lock()
+	if ac.cc.conns == nil {
+		ac.cc.mu.Unlock()
+		return ErrClientConnClosing
+	}
+	ac.cc.conns[ac.addr] = ac
+	ac.cc.mu.Unlock()
 	var retries int
 	start := time.Now()
 	for {
@@ -692,13 +695,20 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
 	}
 }

-// tearDown starts to tear down the Conn.
+// tearDown starts to tear down the addrConn.
 // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 // some edge cases (e.g., the caller opens and closes many addrConn's in a
 // tight loop.
 func (ac *addrConn) tearDown(err error) {
 	ac.mu.Lock()
-	defer ac.mu.Unlock()
+	defer func() {
+		ac.mu.Unlock()
+		ac.cc.mu.Lock()
+		if ac.cc.conns != nil {
+			delete(ac.cc.conns, ac.addr)
+		}
+		ac.cc.mu.Unlock()
+	}()
 	if ac.down != nil {
 		ac.down(err)
 		ac.down = nil
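Taken together, the clientconn.go hunks give the addrConn bookkeeping a single owner: resetTransport registers the conn in cc.conns (refusing if the map is nil, which marks a closed ClientConn) and tearDown deregisters it, so callers of newAddrConn no longer touch the map at all. A minimal sketch of that nil-map-as-closed pattern in isolation, with illustrative names (registry, errClosing) rather than the grpc-go types:

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errClosing = errors.New("connection registry is closing")

type conn struct{ addr string }

// registry mirrors how cc.conns is used above: a nil map doubles as the
// "owner is closed" sentinel, so no separate closed flag is needed.
type registry struct {
	mu    sync.Mutex
	conns map[string]*conn
}

func newRegistry() *registry {
	return &registry{conns: make(map[string]*conn)}
}

// register refuses late arrivals once close has nilled the map.
func (r *registry) register(c *conn) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.conns == nil {
		return errClosing
	}
	r.conns[c.addr] = c
	return nil
}

// deregister is safe to call during or after close: it tolerates nil.
func (r *registry) deregister(c *conn) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.conns != nil {
		delete(r.conns, c.addr)
	}
}

// close marks the registry closed and returns the survivors for teardown.
func (r *registry) close() map[string]*conn {
	r.mu.Lock()
	defer r.mu.Unlock()
	conns := r.conns
	r.conns = nil
	return conns
}

func main() {
	r := newRegistry()
	c := &conn{addr: "127.0.0.1:1234"}
	fmt.Println(r.register(c)) // <nil>
	r.close()
	fmt.Println(r.register(&conn{addr: "127.0.0.1:5678"})) // registry closing
	r.deregister(c)                                        // no panic after close
}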

View File

@@ -105,7 +105,6 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
 		err error
 		put func()
 	)
-	//t, err = cc.dopts.picker.Pick(ctx)
 	t, put, err = cc.getTransport(ctx)
 	if err != nil {
 		return nil, toRPCErr(err)

View File

@@ -435,14 +435,17 @@ type test struct {
 func (te *test) tearDown() {
 	if te.cancel != nil {
 		te.cancel()
+		te.cancel = nil
 	}
-	te.srv.Stop()
 	if te.cc != nil {
 		te.cc.Close()
+		te.cc = nil
 	}
 	if te.restoreLogs != nil {
 		te.restoreLogs()
+		te.restoreLogs = nil
 	}
+	te.srv.Stop()
 }

 // newTest returns a new test using the provided testing.T and
@@ -625,8 +628,8 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
 	//if state, err := cc.State(); err != nil || (state != grpc.Connecting && state != grpc.TransientFailure) {
 	//	t.Fatalf("cc.State() = %s, %v, want %s or %s, <nil>", state, err, grpc.Connecting, grpc.TransientFailure)
 	//}
-	cc.Close()
-	awaitNewConnLogOutput()
+	//cc.Close()
+	//awaitNewConnLogOutput()
 }

 func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
@@ -1076,7 +1079,7 @@ func testRPCTimeout(t *testing.T, e env) {
 	}
 }

-func TestCancel(t *testing.T) {
+func TestCancelX(t *testing.T) {
 	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
 		testCancel(t, e)
@@ -1111,8 +1114,6 @@ func testCancel(t *testing.T, e env) {
 	if grpc.Code(err) != codes.Canceled {
 		t.Fatalf(`TestService/UnaryCall(_, _) = %v, %v; want <nil>, error code: %d`, reply, err, codes.Canceled)
 	}
-	cc.Close()
-
 	awaitNewConnLogOutput()
 }
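The tearDown reordering makes two distinct fixes: each cleanup field is nilled after use, so a second tearDown call is a no-op, and the server is now stopped last, after the client connection is closed, so in-flight RPCs fail fast instead of racing a dying server. The nil-after-use idiom in miniature (fixture and its fields are illustrative, not the test type's real shape):

package main

import "fmt"

// fixture mirrors the tearDown ordering above.
type fixture struct {
	cancel      func() // context cancel; nilled after use
	closeClient func() // client conn close; nilled after use
	stopServer  func() // runs last, after the client side is gone
}

func (f *fixture) tearDown() {
	if f.cancel != nil {
		f.cancel()
		f.cancel = nil // a second tearDown skips this
	}
	if f.closeClient != nil {
		f.closeClient()
		f.closeClient = nil
	}
	f.stopServer()
}

func main() {
	f := &fixture{
		cancel:      func() { fmt.Println("ctx cancelled") },
		closeClient: func() { fmt.Println("client closed") },
		stopServer:  func() { fmt.Println("server stopped") },
	}
	f.tearDown() // all three fire, client before server
	f.tearDown() // only stopServer fires again; the rest were nilled
}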

File diff suppressed because it is too large

View File

@@ -405,12 +405,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 	if t.streamsQuota != nil {
 		updateStreams = true
 	}
-	delete(t.activeStreams, s.id)
-	if t.state == draining && len(t.activeStreams) == 0 {
+	if t.state == draining && len(t.activeStreams) == 1 {
+		// The transport is draining and s is the last live stream on t.
 		t.mu.Unlock()
 		t.Close()
 		return
 	}
+	delete(t.activeStreams, s.id)
 	t.mu.Unlock()
 	if updateStreams {
 		t.streamsQuota.add(1)
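The CloseStream change flips the order of operations: the draining check now runs while s is still in activeStreams (hence == 1 rather than == 0), and when s is the last live stream the whole transport is handed to t.Close() with the map untouched, so Close observes a consistent picture instead of an already-emptied map. A toy model of that last-stream handoff, under the assumption (true for CloseStream) that each stream is finished exactly once (drainer and its fields are illustrative, not the http2Client API):

package main

import (
	"fmt"
	"sync"
)

// drainer models a transport that must shut itself down when the last
// stream finishes while draining.
type drainer struct {
	mu       sync.Mutex
	draining bool
	streams  map[int]struct{}
}

// finish removes stream id; if it is the last one while draining, the
// whole drainer is closed before any entry is deleted, mirroring the
// CloseStream hunk above.
func (d *drainer) finish(id int) {
	d.mu.Lock()
	if d.draining && len(d.streams) == 1 {
		// id is the last live stream: close everything with the map
		// still intact so close() sees a consistent state.
		d.mu.Unlock()
		d.close()
		return
	}
	delete(d.streams, id)
	d.mu.Unlock()
}

func (d *drainer) close() {
	fmt.Println("transport closed")
}

func main() {
	d := &drainer{
		draining: true,
		streams:  map[int]struct{}{7: {}},
	}
	d.finish(7) // prints "transport closed"
}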