Merge pull request #690 from iamqizhao/master

Introduce Balancer as the client load balancing solution
This commit is contained in:
Menghan Li
2016-06-02 10:16:10 -07:00
14 changed files with 1195 additions and 790 deletions

340
balancer.go Normal file

@ -0,0 +1,340 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"fmt"
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
// Address represents a server the client connects to.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Address struct {
// Addr is the server address on which a connection will be established.
Addr string
// Metadata is the information associated with Addr, which may be used
// to make load balancing decisions.
Metadata interface{}
}
// BalancerGetOptions configures a Get call.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type BalancerGetOptions struct {
// BlockingWait specifies whether Get should block when there is no
// connected address.
BlockingWait bool
}
// Balancer chooses network addresses for RPCs.
// This is the EXPERIMENTAL API and may be changed or extended in the future.
type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will
// be called when dialing.
Start(target string) error
// Up informs the Balancer that gRPC has a connection to the server at
// addr. It returns down which is called once the connection to addr gets
// lost or closed.
// TODO: It is not clear how to construct and take advantage of a meaningful error
// parameter for down. Realistic demands are needed to guide the design.
Up(addr Address) (down func(error))
// Get gets the address of a server for the RPC corresponding to ctx.
// i) If it returns a connected address, gRPC internals issues the RPC on the
// connection to this address;
// ii) If it returns an address on which the connection is under construction
// (initiated by Notify(...)) but not connected, gRPC internals
// * fails RPC if the RPC is fail-fast and connection is in the TransientFailure or
// Shutdown state;
// or
// * issues RPC on the connection otherwise.
// iii) If it returns an address on which the connection does not exist, gRPC
// internals treats it as an error and will fail the corresponding RPC.
//
// Therefore, the following is the recommended rule when writing a custom Balancer.
// If opts.BlockingWait is true, it should return a connected address or
// block if there is no connected address. It should respect the timeout or
// cancellation of ctx when blocking. If opts.BlockingWait is false (for fail-fast
// RPCs), it should return an address it has notified via Notify(...) immediately
// instead of blocking.
//
// The function returns put, which is called once the RPC has completed or failed.
// put can collect and report RPC stats to a remote load balancer. gRPC internals
// will call Get again if err is non-nil (unless err is ErrClientConnClosing).
//
// TODO: Add other non-recoverable errors?
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
// Notify returns a channel that is used by gRPC internals to watch the addresses
// gRPC needs to connect to. The addresses might come from a name resolver or a
// remote load balancer. gRPC internals will compare them with the existing connected
// addresses. If an address the Balancer notified is not among the existing connected
// addresses, gRPC starts to connect to that address. If an address among the existing
// connected addresses is not in the notification list, the corresponding connection
// is shut down gracefully. Otherwise, no action is taken. Note that the Address
// slice must be the full list of addresses which should be connected to.
// It is NOT a delta.
Notify() <-chan []Address
// Close shuts down the balancer.
Close() error
}
// downErr implements net.Error. It is constructed by gRPC internals and passed to the down
// call of Balancer.
type downErr struct {
timeout bool
temporary bool
desc string
}
func (e downErr) Error() string { return e.desc }
func (e downErr) Timeout() bool { return e.timeout }
func (e downErr) Temporary() bool { return e.temporary }
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
return downErr{
timeout: timeout,
temporary: temporary,
desc: fmt.Sprintf(format, a...),
}
}
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
// name resolution updates and updates the available addresses accordingly.
func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r}
}
type roundRobin struct {
r naming.Resolver
w naming.Watcher
open []Address // all the addresses the client should potentially connect to
mu sync.Mutex
addrCh chan []Address // the channel to notify gRPC internals of the list of addresses the client should connect to.
connected []Address // all the connected addresses
next int // index of the next address to return for Get()
waitCh chan struct{} // the channel to block when there is no connected address available
done bool // The Balancer is closed.
}
func (rr *roundRobin) watchAddrUpdates() error {
updates, err := rr.w.Next()
if err != nil {
grpclog.Println("grpc: the naming watcher stops working due to %v.", err)
return err
}
rr.mu.Lock()
defer rr.mu.Unlock()
for _, update := range updates {
addr := Address{
Addr: update.Addr,
}
switch update.Op {
case naming.Add:
var exist bool
for _, v := range rr.open {
if addr == v {
exist = true
grpclog.Println("grpc: The name resolver wanted to add an existing address: ", addr)
break
}
}
if exist {
continue
}
rr.open = append(rr.open, addr)
case naming.Delete:
for i, v := range rr.open {
if v == addr {
copy(rr.open[i:], rr.open[i+1:])
rr.open = rr.open[:len(rr.open)-1]
break
}
}
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
// Make a copy of rr.open and write it onto rr.addrCh so that gRPC internals gets notified.
open := make([]Address, len(rr.open))
copy(open, rr.open)
if rr.done {
return ErrClientConnClosing
}
rr.addrCh <- open
return nil
}
func (rr *roundRobin) Start(target string) error {
if rr.r == nil {
// If there is no name resolver installed, there is no need to
// do name resolution. In this case, rr.addrCh stays nil.
return nil
}
w, err := rr.r.Resolve(target)
if err != nil {
return err
}
rr.w = w
rr.addrCh = make(chan []Address)
go func() {
for {
if err := rr.watchAddrUpdates(); err != nil {
return
}
}
}()
return nil
}
// Up appends addr to the end of rr.connected and sends notification if there
// are pending Get() calls.
func (rr *roundRobin) Up(addr Address) func(error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for _, a := range rr.connected {
if a == addr {
return nil
}
}
rr.connected = append(rr.connected, addr)
if len(rr.connected) == 1 {
// addr is the only address available. Notify the Get() callers that are blocking.
if rr.waitCh != nil {
close(rr.waitCh)
rr.waitCh = nil
}
}
return func(err error) {
rr.down(addr, err)
}
}
// down removes addr from rr.connected and moves the remaining addrs forward.
func (rr *roundRobin) down(addr Address, err error) {
rr.mu.Lock()
defer rr.mu.Unlock()
for i, a := range rr.connected {
if a == addr {
copy(rr.connected[i:], rr.connected[i+1:])
rr.connected = rr.connected[:len(rr.connected)-1]
return
}
}
}
// Get returns the next addr in the rotation.
func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) {
var ch chan struct{}
rr.mu.Lock()
if rr.done {
rr.mu.Unlock()
err = ErrClientConnClosing
return
}
if rr.next >= len(rr.connected) {
rr.next = 0
}
if len(rr.connected) > 0 {
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
return
}
// There is no address available. Wait on rr.waitCh.
// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
if rr.waitCh == nil {
ch = make(chan struct{})
rr.waitCh = ch
} else {
ch = rr.waitCh
}
rr.mu.Unlock()
for {
select {
case <-ctx.Done():
err = transport.ContextErr(ctx.Err())
return
case <-ch:
rr.mu.Lock()
if rr.done {
rr.mu.Unlock()
err = ErrClientConnClosing
return
}
if len(rr.connected) == 0 {
// The newly added addr got removed by down() again.
if rr.waitCh == nil {
ch = make(chan struct{})
rr.waitCh = ch
} else {
ch = rr.waitCh
}
rr.mu.Unlock()
continue
}
if rr.next >= len(rr.connected) {
rr.next = 0
}
addr = rr.connected[rr.next]
rr.next++
rr.mu.Unlock()
return
}
}
}
func (rr *roundRobin) Notify() <-chan []Address {
return rr.addrCh
}
func (rr *roundRobin) Close() error {
rr.mu.Lock()
defer rr.mu.Unlock()
rr.done = true
if rr.w != nil {
rr.w.Close()
}
if rr.waitCh != nil {
close(rr.waitCh)
rr.waitCh = nil
}
if rr.addrCh != nil {
close(rr.addrCh)
}
return nil
}
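For orientation, here is a minimal sketch of how a client could opt into the new balancer at dial time. It uses only options that appear in this change (WithBalancer, RoundRobin, WithInsecure); the target address is a placeholder, and a real deployment would usually pass a naming.Resolver to RoundRobin instead of nil.

package main

import (
	"log"

	"google.golang.org/grpc"
)

func main() {
	// With a nil Resolver, RoundRobin skips name resolution and the
	// ClientConn simply dials the target below; supplying a resolver
	// enables the Notify-driven connection management shown above.
	conn, err := grpc.Dial("localhost:50051",
		grpc.WithBalancer(grpc.RoundRobin(nil)),
		grpc.WithInsecure(),
	)
	if err != nil {
		log.Fatalf("grpc.Dial failed: %v", err)
	}
	defer conn.Close()
	// Unary and streaming RPCs issued on conn now go through the balancer.
}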

322
balancer_test.go Normal file

@ -0,0 +1,322 @@
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"fmt"
"math"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/naming"
)
type testWatcher struct {
// the channel to receive name resolution updates
update chan *naming.Update
// the side channel to get to know how many updates are in a batch
side chan int
// the channel to notify the update injector that the update reading is done
readDone chan int
}
func (w *testWatcher) Next() (updates []*naming.Update, err error) {
n := <-w.side
if n == 0 {
return nil, fmt.Errorf("w.side is closed")
}
for i := 0; i < n; i++ {
u := <-w.update
if u != nil {
updates = append(updates, u)
}
}
w.readDone <- 0
return
}
func (w *testWatcher) Close() {
}
// inject sends name resolution updates to the testWatcher.
func (w *testWatcher) inject(updates []*naming.Update) {
w.side <- len(updates)
for _, u := range updates {
w.update <- u
}
<-w.readDone
}
type testNameResolver struct {
w *testWatcher
addr string
}
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
r.w = &testWatcher{
update: make(chan *naming.Update, 1),
side: make(chan int, 1),
readDone: make(chan int),
}
r.w.side <- 1
r.w.update <- &naming.Update{
Op: naming.Add,
Addr: r.addr,
}
go func() {
<-r.w.readDone
}()
return r.w, nil
}
func startServers(t *testing.T, numServers int, maxStreams uint32) ([]*server, *testNameResolver) {
var servers []*server
for i := 0; i < numServers; i++ {
s := newTestServer()
servers = append(servers, s)
go s.start(t, 0, maxStreams)
s.wait(t, 2*time.Second)
}
// Point to servers[0]
addr := "127.0.0.1:" + servers[0].port
return servers, &testNameResolver{
addr: addr,
}
}
func TestNameDiscovery(t *testing.T) {
// Start 2 servers on 2 ports.
numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
req := "port"
var reply string
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[0].port {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Inject the name resolution change to remove servers[0] and add servers[1].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
})
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
})
r.w.inject(updates)
// Loop until the RPCs in flight talk to servers[1].
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}
func TestEmptyAddrs(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, reply = %q, want %q, <nil>", err, reply, expectedResponse)
}
// Inject name resolution change to remove the server so that there is no address
// available after that.
u := &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
}
r.w.inject([]*naming.Update{u})
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
break
}
}
cc.Close()
servers[0].stop()
}
func TestRoundRobin(t *testing.T) {
// Start 3 servers on 3 ports.
numServers := 3
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Add servers[1] to the service discovery.
u := &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
}
r.w.inject([]*naming.Update{u})
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Add servers[2] to the service discovery.
u = &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
}
r.w.inject([]*naming.Update{u})
// Loop until servers[2] is up.
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[2].port {
break
}
time.Sleep(10 * time.Millisecond)
}
// Check that the incoming RPCs are served in a round-robin manner.
for i := 0; i < 10; i++ {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || ErrorDesc(err) != servers[i%numServers].port {
t.Fatalf("Index %d: Invoke(_, _, _, _, _) = %v, want %s", i, err, servers[i%numServers].port)
}
}
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}
func TestCloseWithPendingRPC(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
updates := []*naming.Update{&naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
}}
r.w.inject(updates)
// Loop until the above update applies.
for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
break
}
time.Sleep(10 * time.Millisecond)
}
// Issue 2 RPCs which should be completed with error status once cc is closed.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
go func() {
defer wg.Done()
var reply string
time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
}()
time.Sleep(5 * time.Millisecond)
cc.Close()
wg.Wait()
servers[0].stop()
}
func TestGetOnWaitChannel(t *testing.T) {
servers, r := startServers(t, 1, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Remove all servers so that all upcoming RPCs will block on waitCh.
updates := []*naming.Update{&naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
}}
r.w.inject(updates)
for {
var reply string
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
break
}
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
}()
// Add a connected server to get the above RPC through.
updates = []*naming.Update{&naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[0].port,
}}
r.w.inject(updates)
// Wait until the above RPC succeeds.
wg.Wait()
cc.Close()
servers[0].stop()
}
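As a companion to the testWatcher above, the following is a hedged sketch of the smallest useful naming.Resolver: one that resolves every target to a fixed address list, returns the full set on the first Next call (as the Watcher contract requires), and then blocks until closed. The staticResolver and staticWatcher names are illustrative, not part of this change.

package staticnaming

import (
	"errors"

	"google.golang.org/grpc/naming"
)

// staticResolver resolves any target to a fixed set of addresses.
type staticResolver struct {
	addrs []string
}

// staticWatcher delivers the full address set once and then blocks until Close.
type staticWatcher struct {
	updates chan []*naming.Update
	stop    chan struct{}
}

func (r *staticResolver) Resolve(target string) (naming.Watcher, error) {
	w := &staticWatcher{
		updates: make(chan []*naming.Update, 1),
		stop:    make(chan struct{}),
	}
	var us []*naming.Update
	for _, a := range r.addrs {
		us = append(us, &naming.Update{Op: naming.Add, Addr: a})
	}
	w.updates <- us // buffered; the first Next call drains it
	return w, nil
}

func (w *staticWatcher) Next() ([]*naming.Update, error) {
	select {
	case us := <-w.updates:
		return us, nil
	case <-w.stop:
		// Return an error only when the watcher cannot recover, per the
		// naming.Watcher documentation updated in this change.
		return nil, errors.New("static watcher closed")
	}
}

func (w *staticWatcher) Close() {
	close(w.stop)
}

Such a resolver could then be plugged into the round-robin balancer with grpc.Dial(target, grpc.WithBalancer(grpc.RoundRobin(&staticResolver{addrs: ...})), ...).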

70
call.go

@ -132,19 +132,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Last: true,
Delay: false,
}
var (
lastErr error // record the error that happened
)
for {
var (
err error
t transport.ClientTransport
stream *transport.Stream
// Record the put handler from Balancer.Get(...). It is called once the
// RPC has completed or failed.
put func()
)
// TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs.
if lastErr != nil && c.failFast {
return toRPCErr(lastErr)
}
// TODO(zhaoq): Need a formal spec of fail-fast.
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
@ -152,39 +149,66 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
}
t, err = cc.dopts.picker.Pick(ctx)
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
if lastErr != nil {
// This was a retry; return the error from the last attempt.
return toRPCErr(lastErr)
// TODO(zhaoq): Probably revisit the error handling.
if err == ErrClientConnClosing {
return Errorf(codes.FailedPrecondition, "%v", err)
}
return toRPCErr(err)
if _, ok := err.(transport.StreamError); ok {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
}
}
// All the remaining cases are treated as retryable.
continue
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
continue
if put != nil {
put()
put = nil
}
if lastErr != nil {
return toRPCErr(lastErr)
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
}
continue
}
return toRPCErr(err)
}
// Receive the response
lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok {
continue
err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil {
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast {
return toRPCErr(err)
}
continue
}
t.CloseStream(stream, err)
return toRPCErr(err)
}
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true)
}
t.CloseStream(stream, lastErr)
if lastErr != nil {
return toRPCErr(lastErr)
t.CloseStream(stream, nil)
if put != nil {
put()
put = nil
}
return Errorf(stream.StatusCode(), "%s", stream.StatusDesc())
}
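To make the new acquire/release discipline in Invoke easier to follow, here is a hedged, standalone sketch of the contract (getAndCall and do are illustrative names, not gRPC API): fetch an address and a put callback from the Balancer, run the attempt, and call put exactly once when the attempt finishes so the balancer can account for the RPC.

package putdemo

import (
	"golang.org/x/net/context"

	"google.golang.org/grpc"
)

// getAndCall mirrors, in simplified form, what the Invoke changes above do:
// every successful Get must be paired with exactly one put call.
func getAndCall(ctx context.Context, b grpc.Balancer, do func(grpc.Address) error) error {
	addr, put, err := b.Get(ctx, grpc.BalancerGetOptions{BlockingWait: true})
	if err != nil {
		return err
	}
	if put != nil {
		defer put()
	}
	return do(addr)
}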

View File

@ -74,7 +74,8 @@ func (testCodec) String() string {
}
type testStreamHandler struct {
t transport.ServerTransport
port string
t transport.ServerTransport
}
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
@ -106,6 +107,11 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
h.t.WriteStatus(s, codes.Internal, "")
return
}
if v == "port" {
h.t.WriteStatus(s, codes.Internal, h.port)
return
}
if v != expectedRequest {
h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr))
return
@ -160,7 +166,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) {
}
st, err := transport.NewServerTransport("http2", conn, maxStreams, nil)
if err != nil {
return
continue
}
s.mu.Lock()
if s.conns == nil {
@ -170,7 +176,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) {
}
s.conns[st] = true
s.mu.Unlock()
h := &testStreamHandler{st}
h := &testStreamHandler{
port: s.port,
t: st,
}
go st.HandleStreams(func(s *transport.Stream) {
go h.handleStream(t, s)
})

View File

@ -43,28 +43,34 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport"
)
var (
// ErrUnspecTarget indicates that the target address is unspecified.
ErrUnspecTarget = errors.New("grpc: target is unspecified")
// ErrNoTransportSecurity indicates that there is no transport security
// ErrClientConnClosing indicates that the operation is illegal because
// the ClientConn is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// errNoTransportSecurity indicates that there is no transport security
// being set for ClientConn. Users should either set one or explicitly
// call WithInsecure DialOption to disable security.
ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
// ErrCredentialsMisuse indicates that users want to transmit security information
errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
// errCredentialsMisuse indicates that users want to transmit security information
// (e.g., oauth2 token) which requires secure connection on an insecure
// connection.
ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
// ErrClientConnClosing indicates that the operation is illegal because
// the session is closing.
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the connection could not be
errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
// errClientConnTimeout indicates that the connection could not be
// established or re-established within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
errClientConnTimeout = errors.New("grpc: timed out trying to connect")
// errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing")
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
)
@ -76,7 +82,7 @@ type dialOptions struct {
cp Compressor
dc Decompressor
bs backoffStrategy
picker Picker
balancer Balancer
block bool
insecure bool
copts transport.ConnectOptions
@ -108,10 +114,10 @@ func WithDecompressor(dc Decompressor) DialOption {
}
}
// WithPicker returns a DialOption which sets a picker for connection selection.
func WithPicker(p Picker) DialOption {
// WithBalancer returns a DialOption which sets a load balancer.
func WithBalancer(b Balancer) DialOption {
return func(o *dialOptions) {
o.picker = p
o.balancer = b
}
}
@ -201,6 +207,7 @@ func WithUserAgent(s string) DialOption {
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc := &ClientConn{
target: target,
conns: make(map[Address]*addrConn),
}
for _, opt := range opts {
opt(&cc.dopts)
@ -214,14 +221,33 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
cc.dopts.bs = DefaultBackoffConfig
}
if cc.dopts.picker == nil {
cc.dopts.picker = &unicastPicker{
target: target,
}
cc.balancer = cc.dopts.balancer
if cc.balancer == nil {
cc.balancer = RoundRobin(nil)
}
if err := cc.dopts.picker.Init(cc); err != nil {
if err := cc.balancer.Start(target); err != nil {
return nil, err
}
ch := cc.balancer.Notify()
if ch == nil {
// There is no name resolver installed.
addr := Address{Addr: target}
if err := cc.newAddrConn(addr, false); err != nil {
return nil, err
}
} else {
addrs, ok := <-ch
if !ok || len(addrs) == 0 {
return nil, fmt.Errorf("grpc: there is no address available to dial")
}
for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil {
return nil, err
}
}
go cc.lbWatcher()
}
colonPos := strings.LastIndex(target, ":")
if colonPos == -1 {
colonPos = len(target)
@ -263,193 +289,268 @@ func (s ConnectivityState) String() string {
}
}
// ClientConn represents a client connection to an RPC service.
// ClientConn represents a client connection to an RPC server.
type ClientConn struct {
target string
balancer Balancer
authority string
dopts dialOptions
mu sync.RWMutex
conns map[Address]*addrConn
}
// State returns the connectivity state of cc.
// This is EXPERIMENTAL API.
func (cc *ClientConn) State() (ConnectivityState, error) {
return cc.dopts.picker.State()
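// lbWatcher watches the Notify channel of cc's balancer and manages connections
// accordingly: it dials addresses that newly appear in a notified list and
// gracefully tears down connections whose addresses have disappeared from it.
// Each notified slice is the full target set, not a delta.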
func (cc *ClientConn) lbWatcher() {
for addrs := range cc.balancer.Notify() {
var (
add []Address // Addresses that need new connections.
del []*addrConn // Connections that need to be torn down.
)
cc.mu.Lock()
for _, a := range addrs {
if _, ok := cc.conns[a]; !ok {
add = append(add, a)
}
}
for k, c := range cc.conns {
var keep bool
for _, a := range addrs {
if k == a {
keep = true
break
}
}
if !keep {
del = append(del, c)
}
}
cc.mu.Unlock()
for _, a := range add {
cc.newAddrConn(a, true)
}
for _, c := range del {
c.tearDown(errConnDrain)
}
}
}
// WaitForStateChange blocks until the state changes to something other than the sourceState.
// It returns the new state or error.
// This is EXPERIMENTAL API.
func (cc *ClientConn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return cc.dopts.picker.WaitForStateChange(ctx, sourceState)
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac := &addrConn{
cc: cc,
addr: addr,
dopts: cc.dopts,
shutdownChan: make(chan struct{}),
}
if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
if !ac.dopts.insecure {
var ok bool
for _, cd := range ac.dopts.copts.AuthOptions {
if _, ok = cd.(credentials.TransportAuthenticator); ok {
break
}
}
if !ok {
return errNoTransportSecurity
}
} else {
for _, cd := range ac.dopts.copts.AuthOptions {
if cd.RequireTransportSecurity() {
return errCredentialsMisuse
}
}
}
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock()
if ac.cc.conns == nil {
ac.cc.mu.Unlock()
return ErrClientConnClosing
}
stale := ac.cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac
ac.cc.mu.Unlock()
if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to
// i) the stale connection's teardown still being in progress;
// ii) a buggy Balancer notifying duplicate Addresses.
stale.tearDown(errConnDrain)
}
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil {
ac.tearDown(err)
return err
}
// Start to monitor the error status of transport.
go ac.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
ac.tearDown(err)
return
}
ac.transportMonitor()
}()
}
return nil
}
// Close starts to tear down the ClientConn.
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
// TODO(zhaoq): Implement fail-fast logic.
addr, put, err := cc.balancer.Get(ctx, opts)
if err != nil {
return nil, nil, err
}
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing
}
ac, ok := cc.conns[addr]
cc.mu.RUnlock()
if !ok {
if put != nil {
put()
}
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
}
t, err := ac.wait(ctx)
if err != nil {
if put != nil {
put()
}
return nil, nil, err
}
return t, put, nil
}
// Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error {
return cc.dopts.picker.Close()
cc.mu.Lock()
if cc.conns == nil {
cc.mu.Unlock()
return ErrClientConnClosing
}
conns := cc.conns
cc.conns = nil
cc.mu.Unlock()
cc.balancer.Close()
for _, ac := range conns {
ac.tearDown(ErrClientConnClosing)
}
return nil
}
// Conn is a client connection to a single destination.
type Conn struct {
target string
// addrConn is a network connection to a given address.
type addrConn struct {
cc *ClientConn
addr Address
dopts dialOptions
resetChan chan int
shutdownChan chan struct{}
events trace.EventLog
mu sync.Mutex
state ConnectivityState
stateCV *sync.Cond
down func(error) // the handler called when a connection is down.
// ready is closed and becomes nil when a new transport is up or failed
// due to timeout.
ready chan struct{}
transport transport.ClientTransport
}
// NewConn creates a Conn.
func NewConn(cc *ClientConn) (*Conn, error) {
if cc.target == "" {
return nil, ErrUnspecTarget
}
c := &Conn{
target: cc.target,
dopts: cc.dopts,
resetChan: make(chan int, 1),
shutdownChan: make(chan struct{}),
}
if EnableTracing {
c.events = trace.NewEventLog("grpc.ClientConn", c.target)
}
if !c.dopts.insecure {
var ok bool
for _, cd := range c.dopts.copts.AuthOptions {
if _, ok = cd.(credentials.TransportAuthenticator); ok {
break
}
}
if !ok {
return nil, ErrNoTransportSecurity
}
} else {
for _, cd := range c.dopts.copts.AuthOptions {
if cd.RequireTransportSecurity() {
return nil, ErrCredentialsMisuse
}
}
}
c.stateCV = sync.NewCond(&c.mu)
if c.dopts.block {
if err := c.resetTransport(false); err != nil {
c.Close()
return nil, err
}
// Start to monitor the error status of transport.
go c.transportMonitor()
} else {
// Start a goroutine connecting to the server asynchronously.
go func() {
if err := c.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err)
c.Close()
return
}
c.transportMonitor()
}()
}
return c, nil
}
// printf records an event in cc's event log, unless cc has been closed.
// REQUIRES cc.mu is held.
func (cc *Conn) printf(format string, a ...interface{}) {
if cc.events != nil {
cc.events.Printf(format, a...)
// printf records an event in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held.
func (ac *addrConn) printf(format string, a ...interface{}) {
if ac.events != nil {
ac.events.Printf(format, a...)
}
}
// errorf records an error in cc's event log, unless cc has been closed.
// REQUIRES cc.mu is held.
func (cc *Conn) errorf(format string, a ...interface{}) {
if cc.events != nil {
cc.events.Errorf(format, a...)
// errorf records an error in ac's event log, unless ac has been closed.
// REQUIRES ac.mu is held.
func (ac *addrConn) errorf(format string, a ...interface{}) {
if ac.events != nil {
ac.events.Errorf(format, a...)
}
}
// State returns the connectivity state of the Conn
func (cc *Conn) State() ConnectivityState {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.state
// getState returns the connectivity state of the Conn
func (ac *addrConn) getState() ConnectivityState {
ac.mu.Lock()
defer ac.mu.Unlock()
return ac.state
}
// WaitForStateChange blocks until the state changes to something other than the sourceState.
func (cc *Conn) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
cc.mu.Lock()
defer cc.mu.Unlock()
if sourceState != cc.state {
return cc.state, nil
// waitForStateChange blocks until the state changes to something other than the sourceState.
func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
ac.mu.Lock()
defer ac.mu.Unlock()
if sourceState != ac.state {
return ac.state, nil
}
done := make(chan struct{})
var err error
go func() {
select {
case <-ctx.Done():
cc.mu.Lock()
ac.mu.Lock()
err = ctx.Err()
cc.stateCV.Broadcast()
cc.mu.Unlock()
ac.stateCV.Broadcast()
ac.mu.Unlock()
case <-done:
}
}()
defer close(done)
for sourceState == cc.state {
cc.stateCV.Wait()
for sourceState == ac.state {
ac.stateCV.Wait()
if err != nil {
return cc.state, err
return ac.state, err
}
}
return cc.state, nil
return ac.state, nil
}
// NotifyReset tries to signal the underlying transport needs to be reset due to
// for example a name resolution change in flight.
func (cc *Conn) NotifyReset() {
select {
case cc.resetChan <- 0:
default:
}
}
func (cc *Conn) resetTransport(closeTransport bool) error {
func (ac *addrConn) resetTransport(closeTransport bool) error {
var retries int
start := time.Now()
for {
cc.mu.Lock()
cc.printf("connecting")
if cc.state == Shutdown {
// cc.Close() has been invoked.
cc.mu.Unlock()
return ErrClientConnClosing
ac.mu.Lock()
ac.printf("connecting")
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return errConnClosing
}
cc.state = Connecting
cc.stateCV.Broadcast()
cc.mu.Unlock()
if closeTransport {
cc.transport.Close()
if ac.down != nil {
ac.down(downErrorf(false, true, "%v", errNetworkIO))
ac.down = nil
}
ac.state = Connecting
ac.stateCV.Broadcast()
t := ac.transport
ac.mu.Unlock()
if closeTransport && t != nil {
t.Close()
}
// Adjust timeout for the current try.
copts := cc.dopts.copts
copts := ac.dopts.copts
if copts.Timeout < 0 {
cc.Close()
return ErrClientConnTimeout
ac.tearDown(errClientConnTimeout)
return errClientConnTimeout
}
if copts.Timeout > 0 {
copts.Timeout -= time.Since(start)
if copts.Timeout <= 0 {
cc.Close()
return ErrClientConnTimeout
ac.tearDown(errClientConnTimeout)
return errClientConnTimeout
}
}
sleepTime := cc.dopts.bs.backoff(retries)
sleepTime := ac.dopts.bs.backoff(retries)
timeout := sleepTime
if timeout < minConnectTimeout {
timeout = minConnectTimeout
@ -458,133 +559,116 @@ func (cc *Conn) resetTransport(closeTransport bool) error {
copts.Timeout = timeout
}
connectTime := time.Now()
addr, err := cc.dopts.picker.PickAddr()
var newTransport transport.ClientTransport
if err == nil {
newTransport, err = transport.NewClientTransport(addr, &copts)
}
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &copts)
if err != nil {
cc.mu.Lock()
if cc.state == Shutdown {
// cc.Close() has been invoked.
cc.mu.Unlock()
return ErrClientConnClosing
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return errConnClosing
}
cc.errorf("transient failure: %v", err)
cc.state = TransientFailure
cc.stateCV.Broadcast()
if cc.ready != nil {
close(cc.ready)
cc.ready = nil
ac.errorf("transient failure: %v", err)
ac.state = TransientFailure
ac.stateCV.Broadcast()
if ac.ready != nil {
close(ac.ready)
ac.ready = nil
}
cc.mu.Unlock()
ac.mu.Unlock()
sleepTime -= time.Since(connectTime)
if sleepTime < 0 {
sleepTime = 0
}
// Fail early before falling into sleep.
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
cc.mu.Lock()
cc.errorf("connection timeout")
cc.mu.Unlock()
cc.Close()
return ErrClientConnTimeout
if ac.dopts.copts.Timeout > 0 && ac.dopts.copts.Timeout < sleepTime+time.Since(start) {
ac.mu.Lock()
ac.errorf("connection timeout")
ac.mu.Unlock()
ac.tearDown(errClientConnTimeout)
return errClientConnTimeout
}
closeTransport = false
select {
case <-time.After(sleepTime):
case <-cc.shutdownChan:
case <-ac.shutdownChan:
}
retries++
grpclog.Printf("grpc: Conn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue
}
cc.mu.Lock()
cc.printf("ready")
if cc.state == Shutdown {
// cc.Close() has been invoked.
cc.mu.Unlock()
ac.mu.Lock()
ac.printf("ready")
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
newTransport.Close()
return ErrClientConnClosing
return errConnClosing
}
cc.state = Ready
cc.stateCV.Broadcast()
cc.transport = newTransport
if cc.ready != nil {
close(cc.ready)
cc.ready = nil
ac.state = Ready
ac.stateCV.Broadcast()
ac.transport = newTransport
if ac.ready != nil {
close(ac.ready)
ac.ready = nil
}
cc.mu.Unlock()
ac.down = ac.cc.balancer.Up(ac.addr)
ac.mu.Unlock()
return nil
}
}
func (cc *Conn) reconnect() bool {
cc.mu.Lock()
if cc.state == Shutdown {
// cc.Close() has been invoked.
cc.mu.Unlock()
return false
}
cc.state = TransientFailure
cc.stateCV.Broadcast()
cc.mu.Unlock()
if err := cc.resetTransport(true); err != nil {
// The ClientConn is closing.
cc.mu.Lock()
cc.printf("transport exiting: %v", err)
cc.mu.Unlock()
grpclog.Printf("grpc: Conn.transportMonitor exits due to: %v", err)
return false
}
return true
}
// Run in a goroutine to track the error in transport and create the
// new transport if an error happens. It returns when the channel is closing.
func (cc *Conn) transportMonitor() {
func (ac *addrConn) transportMonitor() {
for {
ac.mu.Lock()
t := ac.transport
ac.mu.Unlock()
select {
// shutdownChan is needed to detect the teardown when
// the ClientConn is idle (i.e., no RPC in flight).
case <-cc.shutdownChan:
// the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan:
return
case <-cc.resetChan:
if !cc.reconnect() {
case <-t.Error():
ac.mu.Lock()
if ac.state == Shutdown {
// ac.tearDown(...) has been invoked.
ac.mu.Unlock()
return
}
case <-cc.transport.Error():
if !cc.reconnect() {
ac.state = TransientFailure
ac.stateCV.Broadcast()
ac.mu.Unlock()
if err := ac.resetTransport(true); err != nil {
ac.mu.Lock()
ac.printf("transport exiting: %v", err)
ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
return
}
// Tries to drain reset signal if there is any since it is out-dated.
select {
case <-cc.resetChan:
default:
}
}
}
}
// Wait blocks until i) the new transport is up or ii) ctx is done or iii) cc is closed.
func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed.
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
for {
cc.mu.Lock()
ac.mu.Lock()
switch {
case cc.state == Shutdown:
cc.mu.Unlock()
return nil, ErrClientConnClosing
case cc.state == Ready:
ct := cc.transport
cc.mu.Unlock()
case ac.state == Shutdown:
ac.mu.Unlock()
return nil, errConnClosing
case ac.state == Ready:
ct := ac.transport
ac.mu.Unlock()
return ct, nil
default:
ready := cc.ready
ready := ac.ready
if ready == nil {
ready = make(chan struct{})
cc.ready = ready
ac.ready = ready
}
cc.mu.Unlock()
ac.mu.Unlock()
select {
case <-ctx.Done():
return nil, transport.ContextErr(ctx.Err())
@ -595,32 +679,46 @@ func (cc *Conn) Wait(ctx context.Context) (transport.ClientTransport, error) {
}
}
// Close starts to tear down the Conn. Returns ErrClientConnClosing if
// it has been closed (mostly due to dial time-out).
// tearDown starts to tear down the addrConn.
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many ClientConn's in a
// some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop.
func (cc *Conn) Close() error {
cc.mu.Lock()
defer cc.mu.Unlock()
if cc.state == Shutdown {
return ErrClientConnClosing
func (ac *addrConn) tearDown(err error) {
ac.mu.Lock()
defer func() {
ac.mu.Unlock()
ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
}
cc.state = Shutdown
cc.stateCV.Broadcast()
if cc.events != nil {
cc.events.Finish()
cc.events = nil
ac.state = Shutdown
if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err))
ac.down = nil
}
if cc.ready != nil {
close(cc.ready)
cc.ready = nil
ac.stateCV.Broadcast()
if ac.events != nil {
ac.events.Finish()
ac.events = nil
}
if cc.transport != nil {
cc.transport.Close()
if ac.ready != nil {
close(ac.ready)
ac.ready = nil
}
if cc.shutdownChan != nil {
close(cc.shutdownChan)
if ac.transport != nil {
if err == errConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close()
}
}
return nil
if ac.shutdownChan != nil {
close(ac.shutdownChan)
}
return
}

View File

@ -47,8 +47,8 @@ func TestDialTimeout(t *testing.T) {
if err == nil {
conn.Close()
}
if err != ErrClientConnTimeout {
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
if err != errClientConnTimeout {
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout)
}
}
@ -61,8 +61,8 @@ func TestTLSDialTimeout(t *testing.T) {
if err == nil {
conn.Close()
}
if err != ErrClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout)
if err != errClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout)
}
}
@ -72,12 +72,12 @@ func TestCredentialsMisuse(t *testing.T) {
t.Fatalf("Failed to create credentials %v", err)
}
// Two conflicting credential configurations
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse)
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
}
// security info on insecure connection
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse)
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
}
}

View File

@ -66,7 +66,8 @@ type Resolver interface {
// Watcher watches for the updates on the specified target.
type Watcher interface {
// Next blocks until an update or error happens. It may return one or more
// updates. The first call should get the full set of the results.
// updates. The first call should get the full set of the results. It should
// return an error if and only if Watcher cannot recover.
Next() ([]*Update, error)
// Close closes the Watcher.
Close()

243
picker.go

@ -1,243 +0,0 @@
/*
*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"container/list"
"fmt"
"sync"
"golang.org/x/net/context"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
)
// Picker picks a Conn for RPC requests.
// This is EXPERIMENTAL and please do not implement your own Picker for now.
type Picker interface {
// Init does initial processing for the Picker, e.g., initiate some connections.
Init(cc *ClientConn) error
// Pick blocks until either a transport.ClientTransport is ready for the upcoming RPC
// or some error happens.
Pick(ctx context.Context) (transport.ClientTransport, error)
// PickAddr picks a peer address for connecting. This will be called repeated for
// connecting/reconnecting.
PickAddr() (string, error)
// State returns the connectivity state of the underlying connections.
State() (ConnectivityState, error)
// WaitForStateChange blocks until the state changes to something other than
// the sourceState. It returns the new state or error.
WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error)
// Close closes all the Conn's owned by this Picker.
Close() error
}
// unicastPicker is the default Picker which is used when there is no custom Picker
// specified by users. It always picks the same Conn.
type unicastPicker struct {
target string
conn *Conn
}
func (p *unicastPicker) Init(cc *ClientConn) error {
c, err := NewConn(cc)
if err != nil {
return err
}
p.conn = c
return nil
}
func (p *unicastPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
return p.conn.Wait(ctx)
}
func (p *unicastPicker) PickAddr() (string, error) {
return p.target, nil
}
func (p *unicastPicker) State() (ConnectivityState, error) {
return p.conn.State(), nil
}
func (p *unicastPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return p.conn.WaitForStateChange(ctx, sourceState)
}
func (p *unicastPicker) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
// unicastNamingPicker picks an address from a name resolver to set up the connection.
type unicastNamingPicker struct {
cc *ClientConn
resolver naming.Resolver
watcher naming.Watcher
mu sync.Mutex
// The list of the addresses are obtained from watcher.
addrs *list.List
// It tracks the current picked addr by PickAddr(). The next PickAddr may
// push it forward on addrs.
pickedAddr *list.Element
conn *Conn
}
// NewUnicastNamingPicker creates a Picker to pick addresses from a name resolver
// to connect.
func NewUnicastNamingPicker(r naming.Resolver) Picker {
return &unicastNamingPicker{
resolver: r,
addrs: list.New(),
}
}
type addrInfo struct {
addr string
// Set to true if this addrInfo needs to be deleted in the next PickAddrr() call.
deleting bool
}
// processUpdates calls Watcher.Next() once and processes the obtained updates.
func (p *unicastNamingPicker) processUpdates() error {
updates, err := p.watcher.Next()
if err != nil {
return err
}
for _, update := range updates {
switch update.Op {
case naming.Add:
p.mu.Lock()
p.addrs.PushBack(&addrInfo{
addr: update.Addr,
})
p.mu.Unlock()
// Initial connection setup
if p.conn == nil {
conn, err := NewConn(p.cc)
if err != nil {
return err
}
p.conn = conn
}
case naming.Delete:
p.mu.Lock()
for e := p.addrs.Front(); e != nil; e = e.Next() {
if update.Addr == e.Value.(*addrInfo).addr {
if e == p.pickedAddr {
// Do not remove the element now if it is the current picked
// one. We leave the deletion to the next PickAddr() call.
e.Value.(*addrInfo).deleting = true
// Notify Conn to close it. All the live RPCs on this connection
// will be aborted.
p.conn.NotifyReset()
} else {
p.addrs.Remove(e)
}
}
}
p.mu.Unlock()
default:
grpclog.Println("Unknown update.Op ", update.Op)
}
}
return nil
}
// monitor runs in a standalone goroutine to keep watching name resolution updates until the watcher
// is closed.
func (p *unicastNamingPicker) monitor() {
for {
if err := p.processUpdates(); err != nil {
return
}
}
}
func (p *unicastNamingPicker) Init(cc *ClientConn) error {
w, err := p.resolver.Resolve(cc.target)
if err != nil {
return err
}
p.watcher = w
p.cc = cc
// Get the initial name resolution.
if err := p.processUpdates(); err != nil {
return err
}
go p.monitor()
return nil
}
func (p *unicastNamingPicker) Pick(ctx context.Context) (transport.ClientTransport, error) {
return p.conn.Wait(ctx)
}
func (p *unicastNamingPicker) PickAddr() (string, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.pickedAddr == nil {
p.pickedAddr = p.addrs.Front()
} else {
pa := p.pickedAddr
p.pickedAddr = pa.Next()
if pa.Value.(*addrInfo).deleting {
p.addrs.Remove(pa)
}
if p.pickedAddr == nil {
p.pickedAddr = p.addrs.Front()
}
}
if p.pickedAddr == nil {
return "", fmt.Errorf("there is no address available to pick")
}
return p.pickedAddr.Value.(*addrInfo).addr, nil
}
func (p *unicastNamingPicker) State() (ConnectivityState, error) {
return 0, fmt.Errorf("State() is not supported for unicastNamingPicker")
}
func (p *unicastNamingPicker) WaitForStateChange(ctx context.Context, sourceState ConnectivityState) (ConnectivityState, error) {
return 0, fmt.Errorf("WaitForStateChange is not supported for unicastNamingPciker")
}
func (p *unicastNamingPicker) Close() error {
p.watcher.Close()
p.conn.Close()
return nil
}

View File

@ -1,188 +0,0 @@
/*
*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package grpc
import (
"fmt"
"math"
"testing"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc/naming"
)
type testWatcher struct {
// the channel to receives name resolution updates
update chan *naming.Update
// the side channel to get to know how many updates in a batch
side chan int
// the channel to notifiy update injector that the update reading is done
readDone chan int
}
func (w *testWatcher) Next() (updates []*naming.Update, err error) {
n := <-w.side
if n == 0 {
return nil, fmt.Errorf("w.side is closed")
}
for i := 0; i < n; i++ {
u := <-w.update
if u != nil {
updates = append(updates, u)
}
}
w.readDone <- 0
return
}
func (w *testWatcher) Close() {
}
func (w *testWatcher) inject(updates []*naming.Update) {
w.side <- len(updates)
for _, u := range updates {
w.update <- u
}
<-w.readDone
}
type testNameResolver struct {
w *testWatcher
addr string
}
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
r.w = &testWatcher{
update: make(chan *naming.Update, 1),
side: make(chan int, 1),
readDone: make(chan int),
}
r.w.side <- 1
r.w.update <- &naming.Update{
Op: naming.Add,
Addr: r.addr,
}
go func() {
<-r.w.readDone
}()
return r.w, nil
}
func startServers(t *testing.T, numServers, port int, maxStreams uint32) ([]*server, *testNameResolver) {
var servers []*server
for i := 0; i < numServers; i++ {
s := newTestServer()
servers = append(servers, s)
go s.start(t, port, maxStreams)
s.wait(t, 2*time.Second)
}
// Point to server1
addr := "127.0.0.1:" + servers[0].port
return servers, &testNameResolver{
addr: addr,
}
}
func TestNameDiscovery(t *testing.T) {
// Start 3 servers on 3 ports.
servers, r := startServers(t, 3, 0, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
// Inject name resolution change to point to the second server now.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
})
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
})
r.w.inject(updates)
servers[0].stop()
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
// Add another server address (server#3) to name resolution
updates = nil
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[2].port,
})
r.w.inject(updates)
// Stop server#2. The library should direct to server#3 automatically.
servers[1].stop()
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
cc.Close()
servers[2].stop()
}
func TestEmptyAddrs(t *testing.T) {
servers, r := startServers(t, 1, 0, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithPicker(NewUnicastNamingPicker(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil || reply != expectedResponse {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
}
// Inject name resolution change to remove the server address so that there is no address
// available after that.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
})
r.w.inject(updates)
// Loop until the above updates apply.
for {
time.Sleep(10 * time.Millisecond)
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); err != nil {
break
}
}
cc.Close()
servers[0].stop()
}
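The side/update/readDone channels above form a small handshake: inject blocks until Next has consumed the whole batch, which keeps injected updates and test assertions ordered. A minimal sketch of that handshake, illustrative only and not part of the test suite, using just the fake types defined in this file:
// Illustrative only: how inject and Next synchronize.
func exampleInjectHandshake() {
	r := &testNameResolver{addr: "127.0.0.1:10000"}
	w, _ := r.Resolve("foo.bar.com") // seeds a single naming.Add update for r.addr
	if first, _ := w.Next(); len(first) != 1 {
		panic("expected the seeded update")
	}
	done := make(chan struct{})
	go func() {
		// inject returns only after Next has read the whole batch.
		r.w.inject([]*naming.Update{{Op: naming.Add, Addr: "127.0.0.1:10001"}})
		close(done)
	}()
	if second, _ := w.Next(); len(second) != 1 {
		panic("expected the injected update")
	}
	<-done
}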

View File

@ -103,12 +103,16 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
var (
t transport.ClientTransport
err error
put func()
)
t, err = cc.dopts.picker.Pick(ctx)
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
gopts := BalancerGetOptions{
BlockingWait: false,
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
return nil, toRPCErr(err)
}
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
callHdr := &transport.CallHdr{
Host: cc.authority,
Method: method,
@ -119,6 +123,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
cs := &clientStream{
desc: desc,
put: put,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
@ -174,6 +179,7 @@ type clientStream struct {
tracing bool // set to EnableTracing when the clientStream is created.
mu sync.Mutex
put func()
closed bool
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
// and is set to nil when the clientStream's finish method is called.
@ -311,6 +317,10 @@ func (cs *clientStream) finish(err error) {
}
cs.mu.Lock()
defer cs.mu.Unlock()
if cs.put != nil {
cs.put()
cs.put = nil
}
if cs.trInfo.tr != nil {
if err == nil || err == io.EOF {
cs.trInfo.tr.LazyPrintf("RPC: [OK]")

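Taken together, the hunks above replace the old picker.Pick call with a (transport, put) pair obtained from the ClientConn and released in clientStream.finish. A minimal sketch of that contract, assuming package-internal names exactly as they appear in the diff (the balancer wiring inside getTransport is not shown here):
// Illustrative only: the get/put pairing these hunks establish inside package grpc.
func newStreamSketch(ctx context.Context, cc *ClientConn) (*clientStream, error) {
	gopts := BalancerGetOptions{BlockingWait: false}
	t, put, err := cc.getTransport(ctx, gopts)
	if err != nil {
		return nil, toRPCErr(err)
	}
	_ = t // the real code opens the HTTP/2 stream on t; elided here
	cs := &clientStream{put: put}
	// finish later releases the picked address exactly once (see the hunk above):
	//   cs.mu.Lock(); if cs.put != nil { cs.put(); cs.put = nil }; cs.mu.Unlock()
	return cs, nil
}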
View File

@ -162,7 +162,6 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
return nil, fmt.Errorf("Unknown server name %q", serverName)
}
}
// Simulate some service delay.
time.Sleep(time.Second)
@ -339,15 +338,16 @@ func TestReconnectTimeout(t *testing.T) {
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := tc.UnaryCall(ctx, req); err == nil {
t.Errorf("TestService/UnaryCall(_, _) = _, <nil>, want _, non-nil")
return
}
}()
// Block until reconnect times out.
<-waitC
if err := conn.Close(); err != grpc.ErrClientConnClosing {
t.Fatalf("%v.Close() = %v, want %v", conn, err, grpc.ErrClientConnClosing)
if err := conn.Close(); err != nil {
t.Fatalf("%v.Close() = %v, want <nil>", conn, err)
}
}
@ -441,14 +441,17 @@ type test struct {
func (te *test) tearDown() {
if te.cancel != nil {
te.cancel()
te.cancel = nil
}
te.srv.Stop()
if te.cc != nil {
te.cc.Close()
te.cc = nil
}
if te.restoreLogs != nil {
te.restoreLogs()
te.restoreLogs = nil
}
te.srv.Stop()
}
// newTest returns a new test using the provided testing.T and
@ -590,6 +593,7 @@ func TestTimeoutOnDeadServer(t *testing.T) {
func testTimeoutOnDeadServer(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
@ -601,37 +605,17 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
}
ctx, _ = context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
}
if state, err := cc.State(); err != nil || state != grpc.Ready {
t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
}
ctx, _ = context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != context.DeadlineExceeded {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, %v", grpc.Ready, err, context.DeadlineExceeded)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
}
te.srv.Stop()
// Set the timeout to -1 so that, if transportMonitor receives the error
// notification in time, the failure path of the first invocation of
// ClientConn.wait hits the deadline-exceeded error.
ctx, _ = context.WithTimeout(context.Background(), -1)
ctx, _ := context.WithTimeout(context.Background(), -1)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(%v, _) = _, error %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
}
ctx, _ = context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
}
if state, err := cc.State(); err != nil || (state != grpc.Connecting && state != grpc.TransientFailure) {
t.Fatalf("cc.State() = %s, %v, want %s or %s, <nil>", state, err, grpc.Connecting, grpc.TransientFailure)
}
cc.Close()
awaitNewConnLogOutput()
}
@ -789,23 +773,6 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
defer te.tearDown()
cc := te.clientConn()
// Wait until cc is connected.
ctx, _ := context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Idle); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Idle, err)
}
ctx, _ = context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Connecting); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Connecting, err)
}
if state, err := cc.State(); err != nil || state != grpc.Ready {
t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Ready)
}
ctx, _ = context.WithTimeout(context.Background(), time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err == nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, <nil>, want _, %v", grpc.Ready, context.DeadlineExceeded)
}
tc := testpb.NewTestServiceClient(cc)
var header metadata.MD
reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Header(&header))
@ -817,15 +784,6 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) {
}
te.srv.Stop()
cc.Close()
ctx, _ = context.WithTimeout(context.Background(), 5*time.Second)
if _, err := cc.WaitForStateChange(ctx, grpc.Ready); err != nil {
t.Fatalf("cc.WaitForStateChange(_, %s) = _, %v, want _, <nil>", grpc.Ready, err)
}
if state, err := cc.State(); err != nil || state != grpc.Shutdown {
t.Fatalf("cc.State() = %s, %v, want %s, <nil>", state, err, grpc.Shutdown)
}
}
func TestFailedEmptyUnary(t *testing.T) {
@ -1007,7 +965,6 @@ func testRetry(t *testing.T, e env) {
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
var wg sync.WaitGroup
numRPC := 1000
@ -1073,9 +1030,8 @@ func testRPCTimeout(t *testing.T, e env) {
}
for i := -1; i <= 10; i++ {
ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond)
reply, err := tc.UnaryCall(ctx, req)
if grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf(`TestService/UnaryCallv(_, _) = %v, %v; want <nil>, error code: %d`, reply, err, codes.DeadlineExceeded)
if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want <nil>, error code: %d", err, codes.DeadlineExceeded)
}
}
}
@ -1111,12 +1067,9 @@ func testCancel(t *testing.T, e env) {
}
ctx, cancel := context.WithCancel(context.Background())
time.AfterFunc(1*time.Millisecond, cancel)
reply, err := tc.UnaryCall(ctx, req)
if grpc.Code(err) != codes.Canceled {
t.Fatalf(`TestService/UnaryCall(_, _) = %v, %v; want <nil>, error code: %d`, reply, err, codes.Canceled)
if r, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Canceled {
t.Fatalf("TestService/UnaryCall(_, _) = %v, %v; want _, error code: %d", r, err, codes.Canceled)
}
cc.Close()
awaitNewConnLogOutput()
}

View File

@ -35,7 +35,6 @@ package transport
import (
"bytes"
"errors"
"io"
"math"
"net"
@ -272,6 +271,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
}
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return nil, ErrConnClosing
}
if t.state != reachable {
t.mu.Unlock()
return nil, ErrConnClosing
@ -397,9 +400,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
func (t *http2Client) CloseStream(s *Stream, err error) {
var updateStreams bool
t.mu.Lock()
if t.activeStreams == nil {
t.mu.Unlock()
return
}
if t.streamsQuota != nil {
updateStreams = true
}
if t.state == draining && len(t.activeStreams) == 1 {
// The transport is draining and s is the last live stream on t.
t.mu.Unlock()
t.Close()
return
}
delete(t.activeStreams, s.id)
t.mu.Unlock()
if updateStreams {
@ -441,7 +454,7 @@ func (t *http2Client) Close() (err error) {
}
if t.state == closing {
t.mu.Unlock()
return errors.New("transport: Close() was already called")
return
}
t.state = closing
t.mu.Unlock()
@ -464,6 +477,25 @@ func (t *http2Client) Close() (err error) {
return
}
func (t *http2Client) GracefulClose() error {
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return nil
}
if t.state == draining {
t.mu.Unlock()
return nil
}
t.state = draining
active := len(t.activeStreams)
t.mu.Unlock()
if active == 0 {
return t.Close()
}
return nil
}
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil.
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later

View File

@ -321,6 +321,7 @@ const (
reachable transportState = iota
unreachable
closing
draining
)
// NewServerTransport creates a ServerTransport with conn or non-nil error
@ -391,6 +392,10 @@ type ClientTransport interface {
// is called only once.
Close() error
// GracefulClose starts to tear down the transport. It stops accepting
// new RPCs and waits for the pending RPCs to complete.
GracefulClose() error
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, data []byte, opts *Options) error

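With the new draining state and the GracefulClose method above, a caller can drain a transport instead of tearing it down abruptly. A minimal caller-side sketch, assuming an import of google.golang.org/grpc/transport; the teardown helper is hypothetical and not part of the package:
// Illustrative only: prefer draining when pending RPCs should be allowed to finish.
func teardown(ct transport.ClientTransport, graceful bool) error {
	if graceful {
		// GracefulClose stops new streams; per the http2Client changes above,
		// the transport closes itself once its last active stream completes.
		return ct.GracefulClose()
	}
	return ct.Close()
}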
View File

@ -331,19 +331,17 @@ func TestLargeMessage(t *testing.T) {
defer wg.Done()
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Errorf("failed to open stream: %v", err)
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("failed to send data: %v", err)
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
_, recvErr := io.ReadFull(s, p)
if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
_, recvErr = io.ReadFull(s, p)
if recvErr != io.EOF {
t.Errorf("Error: %v; want <EOF>", recvErr)
if _, err = io.ReadFull(s, p); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
}
@ -352,6 +350,50 @@ func TestLargeMessage(t *testing.T) {
server.stop()
}
func TestGracefulClose(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err = ct.GracefulClose(); err != nil {
t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
}
var wg sync.WaitGroup
// Expect all follow-up streams to fail because ct has already been closed gracefully.
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
}
}()
}
opts := Options{
Last: true,
Delay: false,
}
// The stream which was created before graceful close can still proceed.
if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponse))
if _, err := io.ReadFull(s, p); err != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("io.ReadFull(_, %v) = _, %v, want %v, <nil>", err, p, expectedResponse)
}
if _, err = io.ReadFull(s, p); err != io.EOF {
t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
}
wg.Wait()
ct.Close()
server.stop()
}
func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{