diff --git a/balancer.go b/balancer.go new file mode 100644 index 00000000..348bf975 --- /dev/null +++ b/balancer.go @@ -0,0 +1,340 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "fmt" + "sync" + + "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/naming" + "google.golang.org/grpc/transport" +) + +// Address represents a server the client connects to. +// This is the EXPERIMENTAL API and may be changed or extended in the future. +type Address struct { + // Addr is the server address on which a connection will be established. + Addr string + // Metadata is the information associated with Addr, which may be used + // to make load balancing decision. + Metadata interface{} +} + +// BalancerGetOptions configures a Get call. +// This is the EXPERIMENTAL API and may be changed or extended in the future. +type BalancerGetOptions struct { + // BlockingWait specifies whether Get should block when there is no + // connected address. + BlockingWait bool +} + +// Balancer chooses network addresses for RPCs. +// This is the EXPERIMENTAL API and may be changed or extended in the future. +type Balancer interface { + // Start does the initialization work to bootstrap a Balancer. For example, + // this function may start the name resolution and watch the updates. It will + // be called when dialing. + Start(target string) error + // Up informs the Balancer that gRPC has a connection to the server at + // addr. It returns down which is called once the connection to addr gets + // lost or closed. + // TODO: It is not clear how to construct and take advantage the meaningful error + // parameter for down. Need realistic demands to guide. + Up(addr Address) (down func(error)) + // Get gets the address of a server for the RPC corresponding to ctx. 
+ // i) If it returns a connected address, gRPC internals issues the RPC on the + // connection to this address; + // ii) If it returns an address on which the connection is under construction + // (initiated by Notify(...)) but not connected, gRPC internals + // * fails RPC if the RPC is fail-fast and connection is in the TransientFailure or + // Shutdown state; + // or + // * issues RPC on the connection otherwise. + // iii) If it returns an address on which the connection does not exist, gRPC + // internals treats it as an error and will fail the corresponding RPC. + // + // Therefore, the following is the recommended rule when writing a custom Balancer. + // If opts.BlockingWait is true, it should return a connected address or + // block if there is no connected address. It should respect the timeout or + // cancellation of ctx when blocking. If opts.BlockingWait is false (for fail-fast + // RPCs), it should return an address it has notified via Notify(...) immediately + // instead of blocking. + // + // The function returns put which is called once the rpc has completed or failed. + // put can collect and report RPC stats to a remote load balancer. gRPC internals + // will try to call this again if err is non-nil (unless err is ErrClientConnClosing). + // + // TODO: Add other non-recoverable errors? + Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) + // Notify returns a channel that is used by gRPC internals to watch the addresses + // gRPC needs to connect. The addresses might be from a name resolver or remote + // load balancer. gRPC internals will compare it with the existing connected + // addresses. If the address Balancer notified is not in the existing connected + // addresses, gRPC starts to connect the address. If an address in the existing + // connected addresses is not in the notification list, the corresponding connection + // is shutdown gracefully. Otherwise, there are no operations to take. Note that + // the Address slice must be the full list of the Addresses which should be connected. + // It is NOT delta. + Notify() <-chan []Address + // Close shuts down the balancer. + Close() error +} + +// downErr implements net.Error. It is constructed by gRPC internals and passed to the down +// call of Balancer. +type downErr struct { + timeout bool + temporary bool + desc string +} + +func (e downErr) Error() string { return e.desc } +func (e downErr) Timeout() bool { return e.timeout } +func (e downErr) Temporary() bool { return e.temporary } + +func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr { + return downErr{ + timeout: timeout, + temporary: temporary, + desc: fmt.Sprintf(format, a...), + } +} + +// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch +// the name resolution updates and updates the addresses available correspondingly. +func RoundRobin(r naming.Resolver) Balancer { + return &roundRobin{r: r} +} + +type roundRobin struct { + r naming.Resolver + w naming.Watcher + open []Address // all the addresses the client should potentially connect + mu sync.Mutex + addrCh chan []Address // the channel to notify gRPC internals the list of addresses the client should connect to. + connected []Address // all the connected addresses + next int // index of the next address to return for Get() + waitCh chan struct{} // the channel to block when there is no connected address available + done bool // The Balancer is closed. 
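+	//
+	// waitCh coordinates Get() and Up(): Get() creates it when there is no
+	// connected address and blocks on it; Up() closes it once the first
+	// address becomes connected, waking all blocked Get() callers. A client
+	// typically installs this balancer with a DialOption, e.g.
+	// Dial(target, WithBalancer(RoundRobin(resolver))) -- the resolver here is
+	// illustrative and may be nil when no name resolution is wanted.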
+}
+func (rr *roundRobin) watchAddrUpdates() error {
+	updates, err := rr.w.Next()
+	if err != nil {
+		grpclog.Printf("grpc: the naming watcher stops working due to %v.", err)
+		return err
+	}
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+	for _, update := range updates {
+		addr := Address{
+			Addr: update.Addr,
+		}
+		switch update.Op {
+		case naming.Add:
+			var exist bool
+			for _, v := range rr.open {
+				if addr == v {
+					exist = true
+					grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
+					break
+				}
+			}
+			if exist {
+				continue
+			}
+			rr.open = append(rr.open, addr)
+		case naming.Delete:
+			for i, v := range rr.open {
+				if v == addr {
+					copy(rr.open[i:], rr.open[i+1:])
+					rr.open = rr.open[:len(rr.open)-1]
+					break
+				}
+			}
+		default:
+			grpclog.Println("Unknown update.Op ", update.Op)
+		}
+	}
+	// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals get notified.
+	open := make([]Address, len(rr.open), len(rr.open))
+	copy(open, rr.open)
+	if rr.done {
+		return ErrClientConnClosing
+	}
+	rr.addrCh <- open
+	return nil
+}
+
+func (rr *roundRobin) Start(target string) error {
+	if rr.r == nil {
+		// If there is no name resolver installed, there is no need to do
+		// name resolution. In this case, rr.addrCh stays nil.
+		return nil
+	}
+	w, err := rr.r.Resolve(target)
+	if err != nil {
+		return err
+	}
+	rr.w = w
+	rr.addrCh = make(chan []Address)
+	go func() {
+		for {
+			if err := rr.watchAddrUpdates(); err != nil {
+				return
+			}
+		}
+	}()
+	return nil
+}
+
+// Up appends addr to the end of rr.connected and sends notification if there
+// are pending Get() calls.
+func (rr *roundRobin) Up(addr Address) func(error) {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+	for _, a := range rr.connected {
+		if a == addr {
+			return nil
+		}
+	}
+	rr.connected = append(rr.connected, addr)
+	if len(rr.connected) == 1 {
+		// addr is the only one available. Notify the Get() callers who are blocking.
+		if rr.waitCh != nil {
+			close(rr.waitCh)
+			rr.waitCh = nil
+		}
+	}
+	return func(err error) {
+		rr.down(addr, err)
+	}
+}
+
+// down removes addr from rr.connected and moves the remaining addrs forward.
+func (rr *roundRobin) down(addr Address, err error) {
+	rr.mu.Lock()
+	defer rr.mu.Unlock()
+	for i, a := range rr.connected {
+		if a == addr {
+			copy(rr.connected[i:], rr.connected[i+1:])
+			rr.connected = rr.connected[:len(rr.connected)-1]
+			return
+		}
+	}
+}
+
+// Get returns the next addr in the rotation.
+func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
+	var ch chan struct{}
+	rr.mu.Lock()
+	if rr.done {
+		rr.mu.Unlock()
+		err = ErrClientConnClosing
+		return
+	}
+	if rr.next >= len(rr.connected) {
+		rr.next = 0
+	}
+	if len(rr.connected) > 0 {
+		addr = rr.connected[rr.next]
+		rr.next++
+		rr.mu.Unlock()
+		return
+	}
+	// There is no address available. Wait on rr.waitCh.
+	// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
+	if rr.waitCh == nil {
+		ch = make(chan struct{})
+		rr.waitCh = ch
+	} else {
+		ch = rr.waitCh
+	}
+	rr.mu.Unlock()
+	for {
+		select {
+		case <-ctx.Done():
+			err = transport.ContextErr(ctx.Err())
+			return
+		case <-ch:
+			rr.mu.Lock()
+			if rr.done {
+				rr.mu.Unlock()
+				err = ErrClientConnClosing
+				return
+			}
+			if len(rr.connected) == 0 {
+				// The newly added addr got removed by down() again.
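+				// Re-create the wait channel and keep blocking until Up()
+				// reports a connected address again.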
+ if rr.waitCh == nil { + ch = make(chan struct{}) + rr.waitCh = ch + } else { + ch = rr.waitCh + } + rr.mu.Unlock() + continue + } + if rr.next >= len(rr.connected) { + rr.next = 0 + } + addr = rr.connected[rr.next] + rr.next++ + rr.mu.Unlock() + return + } + } +} + +func (rr *roundRobin) Notify() <-chan []Address { + return rr.addrCh +} + +func (rr *roundRobin) Close() error { + rr.mu.Lock() + defer rr.mu.Unlock() + rr.done = true + if rr.w != nil { + rr.w.Close() + } + if rr.waitCh != nil { + close(rr.waitCh) + rr.waitCh = nil + } + if rr.addrCh != nil { + close(rr.addrCh) + } + return nil +} diff --git a/balancer_test.go b/balancer_test.go new file mode 100644 index 00000000..9d8d2bcd --- /dev/null +++ b/balancer_test.go @@ -0,0 +1,322 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "fmt" + "math" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/naming" +) + +type testWatcher struct { + // the channel to receives name resolution updates + update chan *naming.Update + // the side channel to get to know how many updates in a batch + side chan int + // the channel to notifiy update injector that the update reading is done + readDone chan int +} + +func (w *testWatcher) Next() (updates []*naming.Update, err error) { + n := <-w.side + if n == 0 { + return nil, fmt.Errorf("w.side is closed") + } + for i := 0; i < n; i++ { + u := <-w.update + if u != nil { + updates = append(updates, u) + } + } + w.readDone <- 0 + return +} + +func (w *testWatcher) Close() { +} + +// Inject naming resolution updates to the testWatcher. 
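+// It follows the protocol Next() expects: the batch size is sent on the side
+// channel first, then each update, and inject blocks on readDone until Next()
+// has consumed the whole batch.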
+func (w *testWatcher) inject(updates []*naming.Update) { + w.side <- len(updates) + for _, u := range updates { + w.update <- u + } + <-w.readDone +} + +type testNameResolver struct { + w *testWatcher + addr string +} + +func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { + r.w = &testWatcher{ + update: make(chan *naming.Update, 1), + side: make(chan int, 1), + readDone: make(chan int), + } + r.w.side <- 1 + r.w.update <- &naming.Update{ + Op: naming.Add, + Addr: r.addr, + } + go func() { + <-r.w.readDone + }() + return r.w, nil +} + +func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver) { + var servers []*server + for i := 0; i < numServers; i++ { + s := newTestServer() + servers = append(servers, s) + go s.start(t, 0, maxStreams) + s.wait(t, 2*time.Second) + } + // Point to server[0] + addr := "127.0.0.1:" + servers[0].port + return servers, &testNameResolver{ + addr: addr, + } +} + +func TestNameDiscovery(t *testing.T) { + // Start 2 servers on 2 ports. + numServers := 2 + servers, r := startServers(t, numServers, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + req := "port" + var reply string + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) + } + // Inject the name resolution change to remove servers[0] and add servers[1]. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Delete, + Addr: "127.0.0.1:" + servers[0].port, + }) + updates = append(updates, &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[1].port, + }) + r.w.inject(updates) + // Loop until the rpcs in flight talks to servers[1]. + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + cc.Close() + for i := 0; i < numServers; i++ { + servers[i].stop() + } +} + +func TestEmptyAddrs(t *testing.T) { + servers, r := startServers(t, 1, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, ", err, reply, expectedResponse) + } + // Inject name resolution change to remove the server so that there is no address + // available after that. + u := &naming.Update{ + Op: naming.Delete, + Addr: "127.0.0.1:" + servers[0].port, + } + r.w.inject([]*naming.Update{u}) + // Loop until the above updates apply. + for { + time.Sleep(10 * time.Millisecond) + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil { + break + } + } + cc.Close() + servers[0].stop() +} + +func TestRoundRobin(t *testing.T) { + // Start 3 servers on 3 ports. 
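+	// Only servers[0] is known to the resolver at dial time; servers[1] and
+	// servers[2] are added below via injected naming updates before the
+	// round-robin rotation is verified.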
+ numServers := 3 + servers, r := startServers(t, numServers, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + // Add servers[1] to the service discovery. + u := &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[1].port, + } + r.w.inject([]*naming.Update{u}) + req := "port" + var reply string + // Loop until servers[1] is up + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + // Add server2[2] to the service discovery. + u = &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[2].port, + } + r.w.inject([]*naming.Update{u}) + // Loop until both servers[2] are up. + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port { + break + } + time.Sleep(10 * time.Millisecond) + } + // Check the incoming RPCs served in a round-robin manner. + for i := 0; i < 10; i++ { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port { + t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port) + } + } + cc.Close() + for i := 0; i < numServers; i++ { + servers[i].stop() + } +} + +func TestCloseWithPendingRPC(t *testing.T) { + servers, r := startServers(t, 1, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) + } + // Remove the server. + updates := []*naming.Update{&naming.Update{ + Op: naming.Delete, + Addr: "127.0.0.1:" + servers[0].port, + }} + r.w.inject(updates) + // Loop until the above update applies. + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + break + } + time.Sleep(10 * time.Millisecond) + } + // Issue 2 RPCs which should be completed with error status once cc is closed. + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + }() + go func() { + defer wg.Done() + var reply string + time.Sleep(5 * time.Millisecond) + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + }() + time.Sleep(5 * time.Millisecond) + cc.Close() + wg.Wait() + servers[0].stop() +} + +func TestGetOnWaitChannel(t *testing.T) { + servers, r := startServers(t, 1, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + // Remove all servers so that all upcoming RPCs will block on waitCh. 
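+	// The short-deadline RPC below is retried until it fails with
+	// DeadlineExceeded, confirming the delete was applied and no connected
+	// address remains before the blocking RPC is issued.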
+ updates := []*naming.Update{&naming.Update{ + Op: naming.Delete, + Addr: "127.0.0.1:" + servers[0].port, + }} + r.w.inject(updates) + for { + var reply string + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + break + } + time.Sleep(10 * time.Millisecond) + } + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var reply string + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) + } + }() + // Add a connected server to get the above RPC through. + updates = []*naming.Update{&naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[0].port, + }} + r.w.inject(updates) + // Wait until the above RPC succeeds. + wg.Wait() + cc.Close() + servers[0].stop() +} diff --git a/call.go b/call.go index 9d0fc8ee..d6d993b4 100644 --- a/call.go +++ b/call.go @@ -132,19 +132,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli Last: true, Delay: false, } - var ( - lastErr error // record the error that happened - ) for { var ( err error t transport.ClientTransport stream *transport.Stream + // Record the put handler from Balancer.Get(...). It is called once the + // RPC has completed or failed. + put func() ) - // TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs. - if lastErr != nil && c.failFast { - return toRPCErr(lastErr) - } + // TODO(zhaoq): Need a formal spec of fail-fast. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, @@ -152,39 +149,66 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } - t, err = cc.dopts.picker.Pick(ctx) + gopts := BalancerGetOptions{ + BlockingWait: !c.failFast, + } + t, put, err = cc.getTransport(ctx, gopts) if err != nil { - if lastErr != nil { - // This was a retry; return the error from the last attempt. - return toRPCErr(lastErr) + // TODO(zhaoq): Probably revisit the error handling. + if err == ErrClientConnClosing { + return Errorf(codes.FailedPrecondition, "%v", err) } - return toRPCErr(err) + if _, ok := err.(transport.StreamError); ok { + return toRPCErr(err) + } + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + } + // All the remaining cases are treated as retryable. 
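+			// Retry: the next loop iteration goes back through Balancer.Get
+			// (via getTransport) and may pick a different address.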
+ continue } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) if err != nil { - if _, ok := err.(transport.ConnectionError); ok { - lastErr = err - continue + if put != nil { + put() + put = nil } - if lastErr != nil { - return toRPCErr(lastErr) + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + continue } return toRPCErr(err) } // Receive the response - lastErr = recvResponse(cc.dopts, t, &c, stream, reply) - if _, ok := lastErr.(transport.ConnectionError); ok { - continue + err = recvResponse(cc.dopts, t, &c, stream, reply) + if err != nil { + if put != nil { + put() + put = nil + } + if _, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + continue + } + t.CloseStream(stream, err) + return toRPCErr(err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } - t.CloseStream(stream, lastErr) - if lastErr != nil { - return toRPCErr(lastErr) + t.CloseStream(stream, nil) + if put != nil { + put() + put = nil } return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) } diff --git a/call_test.go b/call_test.go index 7d01f457..380bf872 100644 --- a/call_test.go +++ b/call_test.go @@ -74,7 +74,8 @@ func (testCodec) String() string { } type testStreamHandler struct { - t transport.ServerTransport + port string + t transport.ServerTransport } func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { @@ -106,6 +107,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { h.t.WriteStatus(s, codes.Internal, "") return } + if v == "port" { + h.t.WriteStatus(s, codes.Internal, h.port) + return + } + if v != expectedRequest { h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr)) return @@ -160,7 +166,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { } st, err := transport.NewServerTransport("http2", conn, maxStreams, nil) if err != nil { - return + continue } s.mu.Lock() if s.conns == nil { @@ -170,7 +176,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { } s.conns[st] = true s.mu.Unlock() - h := &testStreamHandler{st} + h := &testStreamHandler{ + port: s.port, + t: st, + } go st.HandleStreams(func(s *transport.Stream) { go h.handleStream(t, s) }) diff --git a/clientconn.go b/clientconn.go index ebf99023..0d640b47 100644 --- a/clientconn.go +++ b/clientconn.go @@ -43,28 +43,34 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/transport" ) var ( - // ErrUnspecTarget indicates that the target address is unspecified. - ErrUnspecTarget = errors.New("grpc: target is unspecified") - // ErrNoTransportSecurity indicates that there is no transport security + // ErrClientConnClosing indicates that the operation is illegal because + // the ClientConn is closing. + ErrClientConnClosing = errors.New("grpc: the client connection is closing") + + // errNoTransportSecurity indicates that there is no transport security // being set for ClientConn. Users should either set one or explicitly // call WithInsecure DialOption to disable security. 
- ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") - // ErrCredentialsMisuse indicates that users want to transmit security information + errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") + // errCredentialsMisuse indicates that users want to transmit security information // (e.g., oauth2 token) which requires secure connection on an insecure // connection. - ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") - // ErrClientConnClosing indicates that the operation is illegal because - // the session is closing. - ErrClientConnClosing = errors.New("grpc: the client connection is closing") - // ErrClientConnTimeout indicates that the connection could not be + errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") + // errClientConnTimeout indicates that the connection could not be // established or re-established within the specified timeout. - ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") + errClientConnTimeout = errors.New("grpc: timed out trying to connect") + // errNetworkIP indicates that the connection is down due to some network I/O error. + errNetworkIO = errors.New("grpc: failed with network I/O error") + // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. + errConnDrain = errors.New("grpc: the connection is drained") + // errConnClosing indicates that the connection is closing. + errConnClosing = errors.New("grpc: the connection is closing") // minimum time to give a connection to complete minConnectTimeout = 20 * time.Second ) @@ -76,7 +82,7 @@ type dialOptions struct { cp Compressor dc Decompressor bs backoffStrategy - picker Picker + balancer Balancer block bool insecure bool copts transport.ConnectOptions @@ -108,10 +114,10 @@ func WithDecompressor(dc Decompressor) DialOption { } } -// WithPicker returns a DialOption which sets a picker for connection selection. -func WithPicker(p Picker) DialOption { +// WithBalancer returns a DialOption which sets a load balancer. +func WithBalancer(b Balancer) DialOption { return func(o *dialOptions) { - o.picker = p + o.balancer = b } } @@ -201,6 +207,7 @@ func WithUserAgent(s string) DialOption { func Dial(target string, opts ...DialOption) (*ClientConn, error) { cc := &ClientConn{ target: target, + conns: make(map[Address]*addrConn), } for _, opt := range opts { opt(&cc.dopts) @@ -214,14 +221,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { cc.dopts.bs = DefaultBackoffConfig } - if cc.dopts.picker == nil { - cc.dopts.picker = &unicastPicker{ - target: target, - } + cc.balancer = cc.dopts.balancer + if cc.balancer == nil { + cc.balancer = RoundRobin(nil) } - if err := cc.dopts.picker.Init(cc); err != nil { + if err := cc.balancer.Start(target); err != nil { return nil, err } + ch := cc.balancer.Notify() + if ch == nil { + // There is no name resolver installed. 
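+		// Connect directly to the dial target; the balancer's Notify channel
+		// is nil, so no lbWatcher goroutine is needed.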
+ addr := Address{Addr: target} + if err := cc.newAddrConn(addr, false); err != nil { + return nil, err + } + } else { + addrs, ok := <-ch + if !ok || len(addrs) == 0 { + return nil, fmt.Errorf("grpc: there is no address available to dial") + } + for _, a := range addrs { + if err := cc.newAddrConn(a, false); err != nil { + return nil, err + } + } + go cc.lbWatcher() + } + colonPos := strings.LastIndex(target, ":") if colonPos == -1 { colonPos = len(target) @@ -263,193 +289,268 @@ func (s ConnectivityState) String() string { } } -// ClientConn represents a client connection to an RPC service. +// ClientConn represents a client connection to an RPC server. type ClientConn struct { target string + balancer Balancer authority string dopts dialOptions + + mu sync.RWMutex + conns map[Address]*addrConn } -// State returns the connectivity state of cc. -// This is EXPERIMENTAL API. -func (cc *ClientConn) State() (ConnectivityState, error) { - return cc.dopts.picker.State() +func (cc *ClientConn) lbWatcher() { + for addrs := range cc.balancer.Notify() { + var ( + add []Address // Addresses need to setup connections. + del []*addrConn // Connections need to tear down. + ) + cc.mu.Lock() + for _, a := range addrs { + if _, ok := cc.conns[a]; !ok { + add = append(add, a) + } + } + for k, c := range cc.conns { + var keep bool + for _, a := range addrs { + if k == a { + keep = true + break + } + } + if !keep { + del = append(del, c) + } + } + cc.mu.Unlock() + for _, a := range add { + cc.newAddrConn(a, true) + } + for _, c := range del { + c.tearDown(errConnDrain) + } + } } -// WaitForStateChange blocks until the state changes to something other than the sourceState. -// It returns the new state or error. -// This is EXPERIMENTAL API. -func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { - return cc.dopts.picker.WaitForStateChange(ctx, sourceState) +func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { + ac := &addrConn{ + cc: cc, + addr: addr, + dopts: cc.dopts, + shutdownChan: make(chan struct{}), + } + if EnableTracing { + ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) + } + if !ac.dopts.insecure { + var ok bool + for _, cd := range ac.dopts.copts.AuthOptions { + if _, ok = cd.(credentials.TransportAuthenticator); ok { + break + } + } + if !ok { + return errNoTransportSecurity + } + } else { + for _, cd := range ac.dopts.copts.AuthOptions { + if cd.RequireTransportSecurity() { + return errCredentialsMisuse + } + } + } + // Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called. + ac.cc.mu.Lock() + if ac.cc.conns == nil { + ac.cc.mu.Unlock() + return ErrClientConnClosing + } + stale := ac.cc.conns[ac.addr] + ac.cc.conns[ac.addr] = ac + ac.cc.mu.Unlock() + if stale != nil { + // There is an addrConn alive on ac.addr already. This could be due to + // i) stale's Close is undergoing; + // ii) a buggy Balancer notifies duplicated Addresses. + stale.tearDown(errConnDrain) + } + ac.stateCV = sync.NewCond(&ac.mu) + // skipWait may overwrite the decision in ac.dopts.block. + if ac.dopts.block && !skipWait { + if err := ac.resetTransport(false); err != nil { + ac.tearDown(err) + return err + } + // Start to monitor the error status of transport. + go ac.transportMonitor() + } else { + // Start a goroutine connecting to the server asynchronously. 
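+		// On this non-blocking path a dial failure is logged and the addrConn
+		// is torn down rather than reported to the caller; once connected,
+		// transportMonitor takes over reconnecting after transport errors.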
+ go func() { + if err := ac.resetTransport(false); err != nil { + grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) + ac.tearDown(err) + return + } + ac.transportMonitor() + }() + } + return nil } -// Close starts to tear down the ClientConn. +func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { + // TODO(zhaoq): Implement fail-fast logic. + addr, put, err := cc.balancer.Get(ctx, opts) + if err != nil { + return nil, nil, err + } + cc.mu.RLock() + if cc.conns == nil { + cc.mu.RUnlock() + return nil, nil, ErrClientConnClosing + } + ac, ok := cc.conns[addr] + cc.mu.RUnlock() + if !ok { + if put != nil { + put() + } + return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + } + t, err := ac.wait(ctx) + if err != nil { + if put != nil { + put() + } + return nil, nil, err + } + return t, put, nil +} + +// Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() error { - return cc.dopts.picker.Close() + cc.mu.Lock() + if cc.conns == nil { + cc.mu.Unlock() + return ErrClientConnClosing + } + conns := cc.conns + cc.conns = nil + cc.mu.Unlock() + cc.balancer.Close() + for _, ac := range conns { + ac.tearDown(ErrClientConnClosing) + } + return nil } -// Conn is a client connection to a single destination. -type Conn struct { - target string +// addrConn is a network connection to a given address. +type addrConn struct { + cc *ClientConn + addr Address dopts dialOptions - resetChan chan int shutdownChan chan struct{} events trace.EventLog mu sync.Mutex state ConnectivityState stateCV *sync.Cond + down func(error) // the handler called when a connection is down. // ready is closed and becomes nil when a new transport is up or failed // due to timeout. ready chan struct{} transport transport.ClientTransport } -// NewConn creates a Conn. -func NewConn(cc *ClientConn) (*Conn, error) { - if cc.target == "" { - return nil, ErrUnspecTarget - } - c := &Conn{ - target: cc.target, - dopts: cc.dopts, - resetChan: make(chan int, 1), - shutdownChan: make(chan struct{}), - } - if EnableTracing { - c.events = trace.NewEventLog("grpc.ClientConn", c.target) - } - if !c.dopts.insecure { - var ok bool - for _, cd := range c.dopts.copts.AuthOptions { - if _, ok = cd.(credentials.TransportAuthenticator); ok { - break - } - } - if !ok { - return nil, ErrNoTransportSecurity - } - } else { - for _, cd := range c.dopts.copts.AuthOptions { - if cd.RequireTransportSecurity() { - return nil, ErrCredentialsMisuse - } - } - } - c.stateCV = sync.NewCond(&c.mu) - if c.dopts.block { - if err := c.resetTransport(false); err != nil { - c.Close() - return nil, err - } - // Start to monitor the error status of transport. - go c.transportMonitor() - } else { - // Start a goroutine connecting to the server asynchronously. - go func() { - if err := c.resetTransport(false); err != nil { - grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err) - c.Close() - return - } - c.transportMonitor() - }() - } - return c, nil -} - -// printf records an event in cc's event log, unless cc has been closed. -// REQUIRES cc.mu is held. -func (cc *Conn) printf(format string, a ...interface{}) { - if cc.events != nil { - cc.events.Printf(format, a...) +// printf records an event in ac's event log, unless ac has been closed. +// REQUIRES ac.mu is held. 
+func (ac *addrConn) printf(format string, a ...interface{}) { + if ac.events != nil { + ac.events.Printf(format, a...) } } -// errorf records an error in cc's event log, unless cc has been closed. -// REQUIRES cc.mu is held. -func (cc *Conn) errorf(format string, a ...interface{}) { - if cc.events != nil { - cc.events.Errorf(format, a...) +// errorf records an error in ac's event log, unless ac has been closed. +// REQUIRES ac.mu is held. +func (ac *addrConn) errorf(format string, a ...interface{}) { + if ac.events != nil { + ac.events.Errorf(format, a...) } } -// State returns the connectivity state of the Conn -func (cc *Conn) State() ConnectivityState { - cc.mu.Lock() - defer cc.mu.Unlock() - return cc.state +// getState returns the connectivity state of the Conn +func (ac *addrConn) getState() ConnectivityState { + ac.mu.Lock() + defer ac.mu.Unlock() + return ac.state } -// WaitForStateChange blocks until the state changes to something other than the sourceState. -func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { - cc.mu.Lock() - defer cc.mu.Unlock() - if sourceState != cc.state { - return cc.state, nil +// waitForStateChange blocks until the state changes to something other than the sourceState. +func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { + ac.mu.Lock() + defer ac.mu.Unlock() + if sourceState != ac.state { + return ac.state, nil } done := make(chan struct{}) var err error go func() { select { case <-ctx.Done(): - cc.mu.Lock() + ac.mu.Lock() err = ctx.Err() - cc.stateCV.Broadcast() - cc.mu.Unlock() + ac.stateCV.Broadcast() + ac.mu.Unlock() case <-done: } }() defer close(done) - for sourceState == cc.state { - cc.stateCV.Wait() + for sourceState == ac.state { + ac.stateCV.Wait() if err != nil { - return cc.state, err + return ac.state, err } } - return cc.state, nil + return ac.state, nil } -// NotifyReset tries to signal the underlying transport needs to be reset due to -// for example a name resolution change in flight. -func (cc *Conn) NotifyReset() { - select { - case cc.resetChan <- 0: - default: - } -} - -func (cc *Conn) resetTransport(closeTransport bool) error { +func (ac *addrConn) resetTransport(closeTransport bool) error { var retries int start := time.Now() for { - cc.mu.Lock() - cc.printf("connecting") - if cc.state == Shutdown { - // cc.Close() has been invoked. - cc.mu.Unlock() - return ErrClientConnClosing + ac.mu.Lock() + ac.printf("connecting") + if ac.state == Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() + return errConnClosing } - cc.state = Connecting - cc.stateCV.Broadcast() - cc.mu.Unlock() - if closeTransport { - cc.transport.Close() + if ac.down != nil { + ac.down(downErrorf(false, true, "%v", errNetworkIO)) + ac.down = nil + } + ac.state = Connecting + ac.stateCV.Broadcast() + t := ac.transport + ac.mu.Unlock() + if closeTransport && t != nil { + t.Close() } // Adjust timeout for the current try. 
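+		// The remaining dial deadline shrinks by the time already spent, while
+		// each individual attempt is still given at least minConnectTimeout to
+		// complete.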
- copts := cc.dopts.copts + copts := ac.dopts.copts if copts.Timeout < 0 { - cc.Close() - return ErrClientConnTimeout + ac.tearDown(errClientConnTimeout) + return errClientConnTimeout } if copts.Timeout > 0 { copts.Timeout -= time.Since(start) if copts.Timeout <= 0 { - cc.Close() - return ErrClientConnTimeout + ac.tearDown(errClientConnTimeout) + return errClientConnTimeout } } - sleepTime := cc.dopts.bs.backoff(retries) + sleepTime := ac.dopts.bs.backoff(retries) timeout := sleepTime if timeout < minConnectTimeout { timeout = minConnectTimeout @@ -458,133 +559,116 @@ func (cc *Conn) resetTransport(closeTransport bool) error { copts.Timeout = timeout } connectTime := time.Now() - addr, err := cc.dopts.picker.PickAddr() - var newTransport transport.ClientTransport - if err == nil { - newTransport, err = transport.NewClientTransport(addr, &copts) - } + newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts) if err != nil { - cc.mu.Lock() - if cc.state == Shutdown { - // cc.Close() has been invoked. - cc.mu.Unlock() - return ErrClientConnClosing + ac.mu.Lock() + if ac.state == Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() + return errConnClosing } - cc.errorf("transient failure: %v", err) - cc.state = TransientFailure - cc.stateCV.Broadcast() - if cc.ready != nil { - close(cc.ready) - cc.ready = nil + ac.errorf("transient failure: %v", err) + ac.state = TransientFailure + ac.stateCV.Broadcast() + if ac.ready != nil { + close(ac.ready) + ac.ready = nil } - cc.mu.Unlock() + ac.mu.Unlock() sleepTime -= time.Since(connectTime) if sleepTime < 0 { sleepTime = 0 } // Fail early before falling into sleep. - if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) { - cc.mu.Lock() - cc.errorf("connection timeout") - cc.mu.Unlock() - cc.Close() - return ErrClientConnTimeout + if ac.dopts.copts.Timeout > 0 && ac.dopts.copts.Timeout < sleepTime+time.Since(start) { + ac.mu.Lock() + ac.errorf("connection timeout") + ac.mu.Unlock() + ac.tearDown(errClientConnTimeout) + return errClientConnTimeout } closeTransport = false select { case <-time.After(sleepTime): - case <-cc.shutdownChan: + case <-ac.shutdownChan: } retries++ - grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target) + grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) continue } - cc.mu.Lock() - cc.printf("ready") - if cc.state == Shutdown { - // cc.Close() has been invoked. - cc.mu.Unlock() + ac.mu.Lock() + ac.printf("ready") + if ac.state == Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() newTransport.Close() - return ErrClientConnClosing + return errConnClosing } - cc.state = Ready - cc.stateCV.Broadcast() - cc.transport = newTransport - if cc.ready != nil { - close(cc.ready) - cc.ready = nil + ac.state = Ready + ac.stateCV.Broadcast() + ac.transport = newTransport + if ac.ready != nil { + close(ac.ready) + ac.ready = nil } - cc.mu.Unlock() + ac.down = ac.cc.balancer.Up(ac.addr) + ac.mu.Unlock() return nil } } -func (cc *Conn) reconnect() bool { - cc.mu.Lock() - if cc.state == Shutdown { - // cc.Close() has been invoked. - cc.mu.Unlock() - return false - } - cc.state = TransientFailure - cc.stateCV.Broadcast() - cc.mu.Unlock() - if err := cc.resetTransport(true); err != nil { - // The ClientConn is closing. 
- cc.mu.Lock() - cc.printf("transport exiting: %v", err) - cc.mu.Unlock() - grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err) - return false - } - return true -} - // Run in a goroutine to track the error in transport and create the // new transport if an error happens. It returns when the channel is closing. -func (cc *Conn) transportMonitor() { +func (ac *addrConn) transportMonitor() { for { + ac.mu.Lock() + t := ac.transport + ac.mu.Unlock() select { // shutdownChan is needed to detect the teardown when - // the ClientConn is idle (i.e., no RPC in flight). - case <-cc.shutdownChan: + // the addrConn is idle (i.e., no RPC in flight). + case <-ac.shutdownChan: return - case <-cc.resetChan: - if !cc.reconnect() { + case <-t.Error(): + ac.mu.Lock() + if ac.state == Shutdown { + // ac.tearDown(...) has been invoked. + ac.mu.Unlock() return } - case <-cc.transport.Error(): - if !cc.reconnect() { + ac.state = TransientFailure + ac.stateCV.Broadcast() + ac.mu.Unlock() + if err := ac.resetTransport(true); err != nil { + ac.mu.Lock() + ac.printf("transport exiting: %v", err) + ac.mu.Unlock() + grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) return } - // Tries to drain reset signal if there is any since it is out-dated. - select { - case <-cc.resetChan: - default: - } } } } -// Wait blocks until i) the new transport is up or ii) ctx is done or iii) cc is closed. -func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) { +// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. +func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { for { - cc.mu.Lock() + ac.mu.Lock() switch { - case cc.state == Shutdown: - cc.mu.Unlock() - return nil, ErrClientConnClosing - case cc.state == Ready: - ct := cc.transport - cc.mu.Unlock() + case ac.state == Shutdown: + ac.mu.Unlock() + return nil, errConnClosing + case ac.state == Ready: + ct := ac.transport + ac.mu.Unlock() return ct, nil default: - ready := cc.ready + ready := ac.ready if ready == nil { ready = make(chan struct{}) - cc.ready = ready + ac.ready = ready } - cc.mu.Unlock() + ac.mu.Unlock() select { case <-ctx.Done(): return nil, transport.ContextErr(ctx.Err()) @@ -595,32 +679,46 @@ func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) { } } -// Close starts to tear down the Conn. Returns ErrClientConnClosing if -// it has been closed (mostly due to dial time-out). +// tearDown starts to tear down the addrConn. // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in -// some edge cases (e.g., the caller opens and closes many ClientConn's in a +// some edge cases (e.g., the caller opens and closes many addrConn's in a // tight loop. 
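+// tearDown closes the transport gracefully when the reason is errConnDrain so
+// that in-flight RPCs can complete; any other reason closes it immediately.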
-func (cc *Conn) Close() error { - cc.mu.Lock() - defer cc.mu.Unlock() - if cc.state == Shutdown { - return ErrClientConnClosing +func (ac *addrConn) tearDown(err error) { + ac.mu.Lock() + defer func() { + ac.mu.Unlock() + ac.cc.mu.Lock() + if ac.cc.conns != nil { + delete(ac.cc.conns, ac.addr) + } + ac.cc.mu.Unlock() + }() + if ac.state == Shutdown { + return } - cc.state = Shutdown - cc.stateCV.Broadcast() - if cc.events != nil { - cc.events.Finish() - cc.events = nil + ac.state = Shutdown + if ac.down != nil { + ac.down(downErrorf(false, false, "%v", err)) + ac.down = nil } - if cc.ready != nil { - close(cc.ready) - cc.ready = nil + ac.stateCV.Broadcast() + if ac.events != nil { + ac.events.Finish() + ac.events = nil } - if cc.transport != nil { - cc.transport.Close() + if ac.ready != nil { + close(ac.ready) + ac.ready = nil } - if cc.shutdownChan != nil { - close(cc.shutdownChan) + if ac.transport != nil { + if err == errConnDrain { + ac.transport.GracefulClose() + } else { + ac.transport.Close() + } } - return nil + if ac.shutdownChan != nil { + close(ac.shutdownChan) + } + return } diff --git a/clientconn_test.go b/clientconn_test.go index b44f5b29..09d7f110 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -47,8 +47,8 @@ func TestDialTimeout(t *testing.T) { if err == nil { conn.Close() } - if err != ErrClientConnTimeout { - t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + if err != errClientConnTimeout { + t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout) } } @@ -61,8 +61,8 @@ func TestTLSDialTimeout(t *testing.T) { if err == nil { conn.Close() } - if err != ErrClientConnTimeout { - t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + if err != errClientConnTimeout { + t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout) } } @@ -72,12 +72,12 @@ func TestCredentialsMisuse(t *testing.T) { t.Fatalf("Failed to create credentials %v", err) } // Two conflicting credential configurations - if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse { - t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse) + if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { + t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) } // security info on insecure connection - if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse { - t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse) + if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse { + t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse) } } diff --git a/naming/naming.go b/naming/naming.go index 06605607..c2e0871e 100644 --- a/naming/naming.go +++ b/naming/naming.go @@ -66,7 +66,8 @@ type Resolver interface { // Watcher watches for the updates on the specified target. type Watcher interface { // Next blocks until an update or error happens. It may return one or more - // updates. The first call should get the full set of the results. + // updates. The first call should get the full set of the results. 
It should + // return an error if and only if Watcher cannot recover. Next() ([]*Update, error) // Close closes the Watcher. Close() diff --git a/picker.go b/picker.go deleted file mode 100644 index 50f315b4..00000000 --- a/picker.go +++ /dev/null @@ -1,243 +0,0 @@ -/* - * - * Copyright 2014, Google Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google Inc. nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - */ - -package grpc - -import ( - "container/list" - "fmt" - "sync" - - "golang.org/x/net/context" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/naming" - "google.golang.org/grpc/transport" -) - -// Picker picks a Conn for RPC requests. -// This is EXPERIMENTAL and please do not implement your own Picker for now. -type Picker interface { - // Init does initial processing for the Picker, e.g., initiate some connections. - Init(cc *ClientConn) error - // Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC - // or some error happens. - Pick(ctx context.Context) (transport.ClientTransport, error) - // PickAddr picks a peer address for connecting. This will be called repeated for - // connecting/reconnecting. - PickAddr() (string, error) - // State returns the connectivity state of the underlying connections. - State() (ConnectivityState, error) - // WaitForStateChange blocks until the state changes to something other than - // the sourceState. It returns the new state or error. - WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) - // Close closes all the Conn's owned by this Picker. - Close() error -} - -// unicastPicker is the default Picker which is used when there is no custom Picker -// specified by users. It always picks the same Conn. 
-type unicastPicker struct { - target string - conn *Conn -} - -func (p *unicastPicker) Init(cc *ClientConn) error { - c, err := NewConn(cc) - if err != nil { - return err - } - p.conn = c - return nil -} - -func (p *unicastPicker) Pick(ctx context.Context) (transport.ClientTransport, error) { - return p.conn.Wait(ctx) -} - -func (p *unicastPicker) PickAddr() (string, error) { - return p.target, nil -} - -func (p *unicastPicker) State() (ConnectivityState, error) { - return p.conn.State(), nil -} - -func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { - return p.conn.WaitForStateChange(ctx, sourceState) -} - -func (p *unicastPicker) Close() error { - if p.conn != nil { - return p.conn.Close() - } - return nil -} - -// unicastNamingPicker picks an address from a name resolver to set up the connection. -type unicastNamingPicker struct { - cc *ClientConn - resolver naming.Resolver - watcher naming.Watcher - mu sync.Mutex - // The list of the addresses are obtained from watcher. - addrs *list.List - // It tracks the current picked addr by PickAddr(). The next PickAddr may - // push it forward on addrs. - pickedAddr *list.Element - conn *Conn -} - -// NewUnicastNamingPicker creates a Picker to pick addresses from a name resolver -// to connect. -func NewUnicastNamingPicker(r naming.Resolver) Picker { - return &unicastNamingPicker{ - resolver: r, - addrs: list.New(), - } -} - -type addrInfo struct { - addr string - // Set to true if this addrInfo needs to be deleted in the next PickAddrr() call. - deleting bool -} - -// processUpdates calls Watcher.Next() once and processes the obtained updates. -func (p *unicastNamingPicker) processUpdates() error { - updates, err := p.watcher.Next() - if err != nil { - return err - } - for _, update := range updates { - switch update.Op { - case naming.Add: - p.mu.Lock() - p.addrs.PushBack(&addrInfo{ - addr: update.Addr, - }) - p.mu.Unlock() - // Initial connection setup - if p.conn == nil { - conn, err := NewConn(p.cc) - if err != nil { - return err - } - p.conn = conn - } - case naming.Delete: - p.mu.Lock() - for e := p.addrs.Front(); e != nil; e = e.Next() { - if update.Addr == e.Value.(*addrInfo).addr { - if e == p.pickedAddr { - // Do not remove the element now if it is the current picked - // one. We leave the deletion to the next PickAddr() call. - e.Value.(*addrInfo).deleting = true - // Notify Conn to close it. All the live RPCs on this connection - // will be aborted. - p.conn.NotifyReset() - } else { - p.addrs.Remove(e) - } - } - } - p.mu.Unlock() - default: - grpclog.Println("Unknown update.Op ", update.Op) - } - } - return nil -} - -// monitor runs in a standalone goroutine to keep watching name resolution updates until the watcher -// is closed. -func (p *unicastNamingPicker) monitor() { - for { - if err := p.processUpdates(); err != nil { - return - } - } -} - -func (p *unicastNamingPicker) Init(cc *ClientConn) error { - w, err := p.resolver.Resolve(cc.target) - if err != nil { - return err - } - p.watcher = w - p.cc = cc - // Get the initial name resolution. 
- if err := p.processUpdates(); err != nil { - return err - } - go p.monitor() - return nil -} - -func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) { - return p.conn.Wait(ctx) -} - -func (p *unicastNamingPicker) PickAddr() (string, error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.pickedAddr == nil { - p.pickedAddr = p.addrs.Front() - } else { - pa := p.pickedAddr - p.pickedAddr = pa.Next() - if pa.Value.(*addrInfo).deleting { - p.addrs.Remove(pa) - } - if p.pickedAddr == nil { - p.pickedAddr = p.addrs.Front() - } - } - if p.pickedAddr == nil { - return "", fmt.Errorf("there is no address available to pick") - } - return p.pickedAddr.Value.(*addrInfo).addr, nil -} - -func (p *unicastNamingPicker) State() (ConnectivityState, error) { - return 0, fmt.Errorf("State() is not supported for unicastNamingPicker") -} - -func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) { - return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPciker") -} - -func (p *unicastNamingPicker) Close() error { - p.watcher.Close() - p.conn.Close() - return nil -} diff --git a/picker_test.go b/picker_test.go deleted file mode 100644 index dd29497b..00000000 --- a/picker_test.go +++ /dev/null @@ -1,188 +0,0 @@ -/* - * - * Copyright 2014, Google Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above - * copyright notice, this list of conditions and the following disclaimer - * in the documentation and/or other materials provided with the - * distribution. - * * Neither the name of Google Inc. nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - */ - -package grpc - -import ( - "fmt" - "math" - "testing" - "time" - - "golang.org/x/net/context" - "google.golang.org/grpc/naming" -) - -type testWatcher struct { - // the channel to receives name resolution updates - update chan *naming.Update - // the side channel to get to know how many updates in a batch - side chan int - // the channel to notifiy update injector that the update reading is done - readDone chan int -} - -func (w *testWatcher) Next() (updates []*naming.Update, err error) { - n := <-w.side - if n == 0 { - return nil, fmt.Errorf("w.side is closed") - } - for i := 0; i < n; i++ { - u := <-w.update - if u != nil { - updates = append(updates, u) - } - } - w.readDone <- 0 - return -} - -func (w *testWatcher) Close() { -} - -func (w *testWatcher) inject(updates []*naming.Update) { - w.side <- len(updates) - for _, u := range updates { - w.update <- u - } - <-w.readDone -} - -type testNameResolver struct { - w *testWatcher - addr string -} - -func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { - r.w = &testWatcher{ - update: make(chan *naming.Update, 1), - side: make(chan int, 1), - readDone: make(chan int), - } - r.w.side <- 1 - r.w.update <- &naming.Update{ - Op: naming.Add, - Addr: r.addr, - } - go func() { - <-r.w.readDone - }() - return r.w, nil -} - -func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*server, *testNameResolver) { - var servers []*server - for i := 0; i < numServers; i++ { - s := newTestServer() - servers = append(servers, s) - go s.start(t, port, maxStreams) - s.wait(t, 2*time.Second) - } - // Point to server1 - addr := "127.0.0.1:" + servers[0].port - return servers, &testNameResolver{ - addr: addr, - } -} - -func TestNameDiscovery(t *testing.T) { - // Start 3 servers on 3 ports. - servers, r := startServers(t, 3, 0, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) - if err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } - // Inject name resolution change to point to the second server now. - var updates []*naming.Update - updates = append(updates, &naming.Update{ - Op: naming.Delete, - Addr: "127.0.0.1:" + servers[0].port, - }) - updates = append(updates, &naming.Update{ - Op: naming.Add, - Addr: "127.0.0.1:" + servers[1].port, - }) - r.w.inject(updates) - servers[0].stop() - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } - // Add another server address (server#3) to name resolution - updates = nil - updates = append(updates, &naming.Update{ - Op: naming.Add, - Addr: "127.0.0.1:" + servers[2].port, - }) - r.w.inject(updates) - // Stop server#2. The library should direct to server#3 automatically. 
- servers[1].stop() - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } - cc.Close() - servers[2].stop() -} - -func TestEmptyAddrs(t *testing.T) { - servers, r := startServers(t, 1, 0, math.MaxUint32) - cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) - if err != nil { - t.Fatalf("Failed to create ClientConn: %v", err) - } - var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) - } - // Inject name resolution change to remove the server address so that there is no address - // available after that. - var updates []*naming.Update - updates = append(updates, &naming.Update{ - Op: naming.Delete, - Addr: "127.0.0.1:" + servers[0].port, - }) - r.w.inject(updates) - // Loop until the above updates apply. - for { - time.Sleep(10 * time.Millisecond) - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil { - break - } - } - cc.Close() - servers[0].stop() -} diff --git a/stream.go b/stream.go index 565fc3cd..de125d5b 100644 --- a/stream.go +++ b/stream.go @@ -103,12 +103,16 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth var ( t transport.ClientTransport err error + put func() ) - t, err = cc.dopts.picker.Pick(ctx) + // TODO(zhaoq): CallOption is omitted. Add support when it is needed. + gopts := BalancerGetOptions{ + BlockingWait: false, + } + t, put, err = cc.getTransport(ctx, gopts) if err != nil { return nil, toRPCErr(err) } - // TODO(zhaoq): CallOption is omitted. Add support when it is needed. callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, @@ -119,6 +123,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } cs := &clientStream{ desc: desc, + put: put, codec: cc.dopts.codec, cp: cc.dopts.cp, dc: cc.dopts.dc, @@ -174,6 +179,7 @@ type clientStream struct { tracing bool // set to EnableTracing when the clientStream is created. mu sync.Mutex + put func() closed bool // trInfo.tr is set when the clientStream is created (if EnableTracing is true), // and is set to nil when the clientStream's finish method is called. @@ -311,6 +317,10 @@ func (cs *clientStream) finish(err error) { } cs.mu.Lock() defer cs.mu.Unlock() + if cs.put != nil { + cs.put() + cs.put = nil + } if cs.trInfo.tr != nil { if err == nil || err == io.EOF { cs.trInfo.tr.LazyPrintf("RPC: [OK]") diff --git a/test/end2end_test.go b/test/end2end_test.go index 09bcc392..30e5b8dc 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -162,7 +162,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* return nil, fmt.Errorf("Unknown server name %q", serverName) } } - // Simulate some service delay. time.Sleep(time.Second) @@ -339,15 +338,16 @@ func TestReconnectTimeout(t *testing.T) { ResponseSize: proto.Int32(respSize), Payload: payload, } - if _, err := tc.UnaryCall(context.Background(), req); err == nil { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.UnaryCall(ctx, req); err == nil { t.Errorf("TestService/UnaryCall(_, _) = _, , want _, non-nil") return } }() // Block until reconnect times out. 
 <-waitC
- if err := conn.Close(); err != grpc.ErrClientConnClosing {
- t.Fatalf("%v.Close() = %v, want %v", conn, err, grpc.ErrClientConnClosing)
+ if err := conn.Close(); err != nil {
+ t.Fatalf("%v.Close() = %v, want <nil>", conn, err)
 }
 }

@@ -441,14 +441,17 @@ type test struct {
 func (te *test) tearDown() {
 if te.cancel != nil {
 te.cancel()
+ te.cancel = nil
 }
- te.srv.Stop()
 if te.cc != nil {
 te.cc.Close()
+ te.cc = nil
 }
 if te.restoreLogs != nil {
 te.restoreLogs()
+ te.restoreLogs = nil
 }
+ te.srv.Stop()
 }

 // newTest returns a new test using the provided testing.T and
@@ -590,6 +593,7 @@ func TestTimeoutOnDeadServer(t *testing.T) {

 func testTimeoutOnDeadServer(t *testing.T, e env) {
 te := newTest(t, e)
+ te.userAgent = testAppUA
 te.declareLogNoise(
 "transport: http2Client.notifyError got notified that the client transport was broken EOF",
 "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
@@ -601,37 +605,17 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
 cc := te.clientConn()
 tc := testpb.NewTestServiceClient(cc)
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
- }
- ctx, _ = context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
- }
- if state, err := cc.State(); err != nil || state != grpc.Ready {
- t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
- }
- ctx, _ = context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != context.DeadlineExceeded {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, %v", grpc.Ready, err, context.DeadlineExceeded)
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+ t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
 }
 te.srv.Stop()
 // Set -1 as the timeout to make sure if transportMonitor gets error
 // notification in time the failure path of the 1st invoke of
 // ClientConn.wait hits the deadline exceeded error.
- ctx, _ = context.WithTimeout(context.Background(), -1)
+ ctx, _ := context.WithTimeout(context.Background(), -1)
 if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded {
- t.Fatalf("TestService/EmptyCall(%v, _) = _, error %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
+ t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
 }
- ctx, _ = context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
- }
- if state, err := cc.State(); err != nil || (state != grpc.Connecting && state != grpc.TransientFailure) {
- t.Fatalf("cc.State() = %s, %v, want %s or %s, <nil>", state, err, grpc.Connecting, grpc.TransientFailure)
- }
- cc.Close()

 awaitNewConnLogOutput()
 }

@@ -789,23 +773,6 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
 defer te.tearDown()

 cc := te.clientConn()
-
- // Wait until cc is connected.
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
- }
- ctx, _ = context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
- }
- if state, err := cc.State(); err != nil || state != grpc.Ready {
- t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
- }
- ctx, _ = context.WithTimeout(context.Background(), time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err == nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, <nil>, want _, %v", grpc.Ready, context.DeadlineExceeded)
- }
 tc := testpb.NewTestServiceClient(cc)
 var header metadata.MD
 reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Header(&header))
@@ -817,15 +784,6 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
 }

 te.srv.Stop()
- cc.Close()
-
- ctx, _ = context.WithTimeout(context.Background(), 5*time.Second)
- if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
- t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
- }
- if state, err := cc.State(); err != nil || state != grpc.Shutdown {
- t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Shutdown)
- }
 }

 func TestFailedEmptyUnary(t *testing.T) {
@@ -1007,7 +965,6 @@ func testRetry(t *testing.T, e env) {

 cc := te.clientConn()
 tc := testpb.NewTestServiceClient(cc)
-
 var wg sync.WaitGroup

 numRPC := 1000
@@ -1073,9 +1030,8 @@ func testRPCTimeout(t *testing.T, e env) {
 }
 for i := -1; i <= 10; i++ {
 ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond)
- reply, err := tc.UnaryCall(ctx, req)
- if grpc.Code(err) != codes.DeadlineExceeded {
- t.Fatalf(`TestService/UnaryCallv(_, _) = %v, %v; want <nil>, error code: %d`, reply, err, codes.DeadlineExceeded)
+ if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.DeadlineExceeded {
+ t.Fatalf("TestService/UnaryCall(_, _) = _, %v; want <nil>, error code: %d", err, codes.DeadlineExceeded)
 }
 }
 }
@@ -1111,12 +1067,9 @@ func testCancel(t *testing.T, e env) {
 }
 ctx, cancel := context.WithCancel(context.Background())
 time.AfterFunc(1*time.Millisecond, cancel)
- reply, err := tc.UnaryCall(ctx, req)
- if grpc.Code(err) != codes.Canceled {
- t.Fatalf(`TestService/UnaryCall(_, _) = %v, %v; want <nil>, error code: %d`, reply, err, codes.Canceled)
+ if r, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Canceled {
+ t.Fatalf("TestService/UnaryCall(_, _) = %v, %v; want _, error code: %d", r, err, codes.Canceled)
 }
-
- cc.Close()
- awaitNewConnLogOutput()
 }

diff --git a/transport/http2_client.go b/transport/http2_client.go
index 459d14d6..e624f8da 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -35,7 +35,6 @@ package transport

 import (
 "bytes"
- "errors"
 "io"
 "math"
 "net"
@@ -272,6 +271,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 }
 }
 t.mu.Lock()
+ if t.activeStreams == nil {
+ t.mu.Unlock()
+ return nil, ErrConnClosing
+ }
 if t.state != reachable {
 t.mu.Unlock()
 return nil, ErrConnClosing
@@ -397,9 +400,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 func (t *http2Client) CloseStream(s *Stream, err error) {
 var updateStreams bool
 t.mu.Lock()
+ if t.activeStreams == nil {
+ t.mu.Unlock()
+ return
+ }
 if t.streamsQuota != nil {
 updateStreams = true
 }
+ if t.state == draining && len(t.activeStreams) == 1 {
+ // The transport is draining and s is the last live stream on t.
+ t.mu.Unlock()
+ t.Close()
+ return
+ }
 delete(t.activeStreams, s.id)
 t.mu.Unlock()
 if updateStreams {
@@ -441,7 +454,7 @@ func (t *http2Client) Close() (err error) {
 }
 if t.state == closing {
 t.mu.Unlock()
- return errors.New("transport: Close() was already called")
+ return
 }
 t.state = closing
 t.mu.Unlock()
@@ -464,6 +477,25 @@ func (t *http2Client) Close() (err error) {
 return
 }

+func (t *http2Client) GracefulClose() error {
+ t.mu.Lock()
+ if t.state == closing {
+ t.mu.Unlock()
+ return nil
+ }
+ if t.state == draining {
+ t.mu.Unlock()
+ return nil
+ }
+ t.state = draining
+ active := len(t.activeStreams)
+ t.mu.Unlock()
+ if active == 0 {
+ return t.Close()
+ }
+ return nil
+}
+
 // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
 // should proceed only if Write returns nil.
 // TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
diff --git a/transport/transport.go b/transport/transport.go
index 87fdf532..1c9af545 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -321,6 +321,7 @@ const (
 reachable transportState = iota
 unreachable
 closing
+ draining
 )

 // NewServerTransport creates a ServerTransport with conn or non-nil error
@@ -391,6 +392,10 @@ type ClientTransport interface {
 // is called only once.
 Close() error

+ // GracefulClose starts to tear down the transport. It stops accepting
+ // new RPCs and waits for the completion of the pending RPCs.
+ GracefulClose() error
+
 // Write sends the data for the given stream. A nil stream indicates
 // the write is to be performed on the transport as a whole.
 Write(s *Stream, data []byte, opts *Options) error
diff --git a/transport/transport_test.go b/transport/transport_test.go
index d63dba31..6ebec452 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) {
 defer wg.Done()
 s, err := ct.NewStream(context.Background(), callHdr)
 if err != nil {
- t.Errorf("failed to open stream: %v", err)
+ t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
 }
 if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
- t.Errorf("failed to send data: %v", err)
+ t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
 }
 p := make([]byte, len(expectedResponseLarge))
- _, recvErr := io.ReadFull(s, p)
- if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) {
- t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
+ if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
+ t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", p, err, expectedResponseLarge)
 }
- _, recvErr = io.ReadFull(s, p)
- if recvErr != io.EOF {
- t.Errorf("Error: %v; want <EOF>", recvErr)
+ if _, err = io.ReadFull(s, p); err != io.EOF {
+ t.Errorf("Failed to complete the stream %v; want <EOF>", err)
 }
 }()
 }
@@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) {
 server.stop()
 }

+func TestGracefulClose(t *testing.T) {
+ server, ct := setUp(t, 0, math.MaxUint32, normal)
+ callHdr := &CallHdr{
+ Host: "localhost",
+ Method: "foo.Small",
+ }
+ s, err := ct.NewStream(context.Background(), callHdr)
+ if err != nil {
+ t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
+ }
+ if err = ct.GracefulClose(); err != nil {
+ t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
+ }
+ var wg sync.WaitGroup
+ // Expect failure for all follow-up streams because ct has been closed gracefully.
+ for i := 0; i < 100; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
+ t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrConnClosing)
+ }
+ }()
+ }
+ opts := Options{
+ Last: true,
+ Delay: false,
+ }
+ // The stream which was created before graceful close can still proceed.
+ if err := ct.Write(s, expectedRequest, &opts); err != nil {
+ t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
+ }
+ p := make([]byte, len(expectedResponse))
+ if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
+ t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", p, err, expectedResponse)
+ }
+ if _, err = io.ReadFull(s, p); err != io.EOF {
+ t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
+ }
+ wg.Wait()
+ ct.Close()
+ server.stop()
+}
+
 func TestLargeMessageSuspension(t *testing.T) {
 server, ct := setUp(t, 0, math.MaxUint32, suspended)
 callHdr := &CallHdr{