diff --git a/.travis.yml b/.travis.yml index 9bc2c127..88511623 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,17 +1,18 @@ language: go go: - - 1.5.3 - - 1.6 + - 1.5.4 + - 1.6.3 + - 1.7 + +go_import_path: google.golang.org/grpc before_install: - - go get github.com/axw/gocov/gocov - - go get github.com/mattn/goveralls - - go get golang.org/x/tools/cmd/cover - -install: - - mkdir -p "$GOPATH/src/google.golang.org" - - mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc" + - go get -u golang.org/x/tools/cmd/goimports github.com/golang/lint/golint github.com/axw/gocov/gocov github.com/mattn/goveralls golang.org/x/tools/cmd/cover script: + - '! gofmt -s -d -l . 2>&1 | read' + - '! goimports -l . | read' + - '! golint ./... | grep -vE "(_string|\.pb)\.go:"' + - '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf" | grep -vF .pb.go:' # https://github.com/golang/protobuf/issues/214 - make test testrace diff --git a/README.md b/README.md index 90e9453d..660658be 100644 --- a/README.md +++ b/README.md @@ -28,5 +28,5 @@ See [API documentation](https://godoc.org/google.golang.org/grpc) for package an Status ------ -Beta release +GA diff --git a/balancer.go b/balancer.go index c298ae91..419e2146 100644 --- a/balancer.go +++ b/balancer.go @@ -40,7 +40,6 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" - "google.golang.org/grpc/transport" ) // Address represents a server the client connects to. @@ -94,10 +93,10 @@ type Balancer interface { // instead of blocking. // // The function returns put which is called once the rpc has completed or failed. - // put can collect and report RPC stats to a remote load balancer. gRPC internals - // will try to call this again if err is non-nil (unless err is ErrClientConnClosing). + // put can collect and report RPC stats to a remote load balancer. // - // TODO: Add other non-recoverable errors? + // This function should only return the errors Balancer cannot recover by itself. + // gRPC internals will fail the RPC if an error is returned. Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) // Notify returns a channel that is used by gRPC internals to watch the addresses // gRPC needs to connect. The addresses might be from a name resolver or remote @@ -158,14 +157,15 @@ type roundRobin struct { func (rr *roundRobin) watchAddrUpdates() error { updates, err := rr.w.Next() if err != nil { - grpclog.Println("grpc: the naming watcher stops working due to %v.", err) + grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) return err } rr.mu.Lock() defer rr.mu.Unlock() for _, update := range updates { addr := Address{ - Addr: update.Addr, + Addr: update.Addr, + Metadata: update.Metadata, } switch update.Op { case naming.Add: @@ -298,8 +298,19 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad } } } - // There is no address available. Wait on rr.waitCh. - // TODO(zhaoq): Handle the case when opts.BlockingWait is false. + if !opts.BlockingWait { + if len(rr.addrs) == 0 { + rr.mu.Unlock() + err = fmt.Errorf("there is no address available") + return + } + // Returns the next addr on rr.addrs for failfast RPCs. + addr = rr.addrs[rr.next].addr + rr.next++ + rr.mu.Unlock() + return + } + // Wait on rr.waitCh for non-failfast RPCs. 
if rr.waitCh == nil { ch = make(chan struct{}) rr.waitCh = ch @@ -310,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad for { select { case <-ctx.Done(): - err = transport.ContextErr(ctx.Err()) + err = ctx.Err() return case <-ch: rr.mu.Lock() diff --git a/balancer_test.go b/balancer_test.go index 9d8d2bcd..48f8b27d 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -239,11 +239,11 @@ func TestCloseWithPendingRPC(t *testing.T) { t.Fatalf("Failed to create ClientConn: %v", err) } var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) } // Remove the server. - updates := []*naming.Update{&naming.Update{ + updates := []*naming.Update{{ Op: naming.Delete, Addr: "127.0.0.1:" + servers[0].port, }} @@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) { // Loop until the above update applies. for { ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) } }() @@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) { defer wg.Done() var reply string time.Sleep(5 * time.Millisecond) - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) } }() @@ -287,7 +287,7 @@ func TestGetOnWaitChannel(t *testing.T) { t.Fatalf("Failed to create ClientConn: %v", err) } // Remove all servers so that all upcoming RPCs will block on waitCh. - updates := []*naming.Update{&naming.Update{ + updates := []*naming.Update{{ Op: naming.Delete, Addr: "127.0.0.1:" + servers[0].port, }} @@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) { for { var reply string ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { + if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { break } time.Sleep(10 * time.Millisecond) @@ -305,12 +305,12 @@ func TestGetOnWaitChannel(t *testing.T) { go func() { defer wg.Done() var reply string - if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { + if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want ", err) } }() // Add a connected server to get the above RPC through. 
- updates = []*naming.Update{&naming.Update{ + updates = []*naming.Update{{ Op: naming.Add, Addr: "127.0.0.1:" + servers[0].port, }} @@ -320,3 +320,119 @@ func TestGetOnWaitChannel(t *testing.T) { cc.Close() servers[0].stop() } + +func TestOneServerDown(t *testing.T) { + // Start 2 servers. + numServers := 2 + servers, r := startServers(t, numServers, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + // Add servers[1] to the service discovery. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[1].port, + }) + r.w.inject(updates) + req := "port" + var reply string + // Loop until servers[1] is up + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + + var wg sync.WaitGroup + numRPC := 100 + sleepDuration := 10 * time.Millisecond + wg.Add(1) + go func() { + time.Sleep(sleepDuration) + // After sleepDuration, kill server[0]. + servers[0].stop() + wg.Done() + }() + + // All non-failfast RPCs should not block because there's at least one connection available. + for i := 0; i < numRPC; i++ { + wg.Add(1) + go func() { + time.Sleep(sleepDuration) + // After sleepDuration, invoke RPC. + // server[0] is killed around the same time to make it racy between balancer and gRPC internals. + Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false)) + wg.Done() + }() + } + wg.Wait() + cc.Close() + for i := 0; i < numServers; i++ { + servers[i].stop() + } +} + +func TestOneAddressRemoval(t *testing.T) { + // Start 2 servers. + numServers := 2 + servers, r := startServers(t, numServers, math.MaxUint32) + cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{})) + if err != nil { + t.Fatalf("Failed to create ClientConn: %v", err) + } + // Add servers[1] to the service discovery. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Add, + Addr: "127.0.0.1:" + servers[1].port, + }) + r.w.inject(updates) + req := "port" + var reply string + // Loop until servers[1] is up + for { + if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port { + break + } + time.Sleep(10 * time.Millisecond) + } + + var wg sync.WaitGroup + numRPC := 100 + sleepDuration := 10 * time.Millisecond + wg.Add(1) + go func() { + time.Sleep(sleepDuration) + // After sleepDuration, delete server[0]. + var updates []*naming.Update + updates = append(updates, &naming.Update{ + Op: naming.Delete, + Addr: "127.0.0.1:" + servers[0].port, + }) + r.w.inject(updates) + wg.Done() + }() + + // All non-failfast RPCs should not fail because there's at least one connection available. + for i := 0; i < numRPC; i++ { + wg.Add(1) + go func() { + var reply string + time.Sleep(sleepDuration) + // After sleepDuration, invoke RPC. + // server[0] is removed around the same time to make it racy between balancer and gRPC internals. 
+ if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil { + t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) + } + wg.Done() + }() + } + wg.Wait() + cc.Close() + for i := 0; i < numServers; i++ { + servers[i].stop() + } +} diff --git a/benchmark/client/main.go b/benchmark/client/main.go index 5dfbe6a3..c63d5212 100644 --- a/benchmark/client/main.go +++ b/benchmark/client/main.go @@ -58,7 +58,7 @@ func closeLoopUnary() { for i := 0; i < *maxConcurrentRPCs; i++ { go func() { - for _ = range ch { + for range ch { start := time.Now() unaryCaller(tc) elapse := time.Since(start) diff --git a/benchmark/stats/histogram.go b/benchmark/stats/histogram.go index 099bcd65..918beadc 100644 --- a/benchmark/stats/histogram.go +++ b/benchmark/stats/histogram.go @@ -133,8 +133,8 @@ func (h *Histogram) Clear() { h.SumOfSquares = 0 h.Min = math.MaxInt64 h.Max = math.MinInt64 - for _, v := range h.Buckets { - v.Count = 0 + for i := range h.Buckets { + h.Buckets[i].Count = 0 } } diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go index c8815b0e..17c52519 100644 --- a/benchmark/worker/main.go +++ b/benchmark/worker/main.go @@ -60,7 +60,7 @@ type byteBufCodec struct { func (byteBufCodec) Marshal(v interface{}) ([]byte, error) { b, ok := v.(*[]byte) if !ok { - return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte") + return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) } return *b, nil } @@ -68,7 +68,7 @@ func (byteBufCodec) Marshal(v interface{}) ([]byte, error) { func (byteBufCodec) Unmarshal(data []byte, v interface{}) error { b, ok := v.(*[]byte) if !ok { - return fmt.Errorf("failed to marshal: %v is not type of *[]byte") + return fmt.Errorf("failed to marshal: %v is not type of *[]byte", v) } *b = data return nil @@ -138,8 +138,6 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er return err } } - - return nil } func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error { @@ -191,13 +189,11 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er return err } } - - return nil } func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) { grpclog.Printf("core count: %v", runtime.NumCPU()) - return &testpb.CoreResponse{int32(runtime.NumCPU())}, nil + return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil } func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) { diff --git a/call.go b/call.go index d6d993b4..fea07998 100644 --- a/call.go +++ b/call.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "time" "golang.org/x/net/context" @@ -51,13 +52,20 @@ import ( func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. 
var err error + defer func() { + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + t.CloseStream(stream, err) + } + } + }() c.headerMD, err = stream.Header() if err != nil { return err } p := &parser{r: stream} for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil { if err == io.EOF { break } @@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } defer func() { if err != nil { + // If err is connection error, t will be closed, no need to close stream here. if _, ok := err.(transport.ConnectionError); !ok { t.CloseStream(stream, err) } @@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) - if err != nil { + // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method + // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following + // recvResponse to get the final status. + if err != nil && err != io.EOF { return nil, err } // Sent successfully. @@ -101,7 +113,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { - var c callInfo + c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { return toRPCErr(err) @@ -155,19 +167,17 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if err == ErrClientConnClosing { - return Errorf(codes.FailedPrecondition, "%v", err) + if _, ok := err.(*rpcError); ok { + return err } - if _, ok := err.(transport.StreamError); ok { - return toRPCErr(err) - } - if _, ok := err.(transport.ConnectionError); ok { + if err == errConnClosing || err == errConnUnavailable { if c.failFast { - return toRPCErr(err) + return Errorf(codes.Unavailable, "%v", err) } + continue } - // All the remaining cases are treated as retryable. - continue + // All the other errors are treated as Internal errors. + return Errorf(codes.Internal, "%v", err) } if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) @@ -178,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + // Retry a non-failfast RPC when + // i) there is a connection error; or + // ii) the server started to drain before this RPC was initiated. 
+ if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } @@ -186,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - // Receive the response err = recvResponse(cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() put = nil } - if _, ok := err.(transport.ConnectionError); ok { + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } continue } - t.CloseStream(stream, err) return toRPCErr(err) } if c.traceInfo.tr != nil { diff --git a/call_test.go b/call_test.go index 380bf872..64976d7b 100644 --- a/call_test.go +++ b/call_test.go @@ -81,7 +81,7 @@ type testStreamHandler struct { func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { p := &parser{r: s} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(math.MaxInt32) if err == io.EOF { break } @@ -234,7 +234,7 @@ func TestInvokeLargeErr(t *testing.T) { var reply string req := "hello" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(rpcError); !ok { + if _, ok := err.(*rpcError); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { @@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { var reply string req := "weird error" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(rpcError); !ok { + if _, ok := err.(*rpcError); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if got, want := ErrorDesc(err), weirdError; got != want { @@ -276,3 +276,18 @@ func TestInvokeCancel(t *testing.T) { cc.Close() server.stop() } + +// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC +// on a closed client will terminate. +func TestInvokeCancelClosedNonFailFast(t *testing.T) { + server, cc := setUp(t, 0, math.MaxUint32) + var reply string + cc.Close() + req := "hello" + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := Invoke(ctx, "/foo/bar", &req, &reply, cc, FailFast(false)); err == nil { + t.Fatalf("canceled invoke on closed connection should fail") + } + server.stop() +} diff --git a/clientconn.go b/clientconn.go index 9b9c78d0..1d3b46c6 100644 --- a/clientconn.go +++ b/clientconn.go @@ -43,7 +43,6 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/transport" @@ -68,13 +67,15 @@ var ( // errCredentialsConflict indicates that grpc.WithTransportCredentials() // and grpc.WithInsecure() are both called for a connection. errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") - // errNetworkIP indicates that the connection is down due to some network I/O error. + // errNetworkIO indicates that the connection is down due to some network I/O error. errNetworkIO = errors.New("grpc: failed with network I/O error") // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. errConnDrain = errors.New("grpc: the connection is drained") // errConnClosing indicates that the connection is closing. 
errConnClosing = errors.New("grpc: the connection is closing") - errNoAddr = errors.New("grpc: there is no address available to dial") + // errConnUnavailable indicates that the connection is unavailable. + errConnUnavailable = errors.New("grpc: the connection is unavailable") + errNoAddr = errors.New("grpc: there is no address available to dial") // minimum time to give a connection to complete minConnectTimeout = 20 * time.Second ) @@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption { } // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. -func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { +func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { return func(o *dialOptions) { - o.copts.Dialer = f + o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) { + if deadline, ok := ctx.Deadline(); ok { + return f(addr, deadline.Sub(time.Now())) + } + return f(addr, 0) + } } } @@ -209,49 +215,72 @@ func WithUserAgent(s string) DialOption { } } -// Dial creates a client connection the given target. +// Dial creates a client connection to the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { + return DialContext(context.Background(), target, opts...) +} + +// DialContext creates a client connection to the given target. ctx can be used to +// cancel or expire the pending connecting. Once this function returns, the +// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close +// to terminate all the pending operations after this function returns. +// This is the EXPERIMENTAL API. +func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { cc := &ClientConn{ target: target, conns: make(map[Address]*addrConn), } + cc.ctx, cc.cancel = context.WithCancel(context.Background()) + defer func() { + select { + case <-ctx.Done(): + conn, err = nil, ctx.Err() + default: + } + + if err != nil { + cc.Close() + } + }() + for _, opt := range opts { opt(&cc.dopts) } + + // Set defaults. if cc.dopts.codec == nil { - // Set the default codec. cc.dopts.codec = protoCodec{} } - if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } - cc.balancer = cc.dopts.balancer - if cc.balancer == nil { - cc.balancer = RoundRobin(nil) - } - if err := cc.balancer.Start(target); err != nil { - return nil, err - } var ( ok bool addrs []Address ) - ch := cc.balancer.Notify() - if ch == nil { - // There is no name resolver installed. + if cc.dopts.balancer == nil { + // Connect to target directly if balancer is nil. addrs = append(addrs, Address{Addr: target}) } else { - addrs, ok = <-ch - if !ok || len(addrs) == 0 { - return nil, errNoAddr + if err := cc.dopts.balancer.Start(target); err != nil { + return nil, err + } + ch := cc.dopts.balancer.Notify() + if ch == nil { + // There is no name resolver installed. 
+ addrs = append(addrs, Address{Addr: target}) + } else { + addrs, ok = <-ch + if !ok || len(addrs) == 0 { + return nil, errNoAddr + } } } waitC := make(chan error, 1) go func() { for _, a := range addrs { - if err := cc.newAddrConn(a, false); err != nil { + if err := cc.resetAddrConn(a, false, nil); err != nil { waitC <- err return } @@ -263,15 +292,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { timeoutCh = time.After(cc.dopts.timeout) } select { + case <-ctx.Done(): + return nil, ctx.Err() case err := <-waitC: if err != nil { - cc.Close() return nil, err } case <-timeoutCh: - cc.Close() return nil, ErrClientConnTimeout } + // If balancer is nil or balancer.Notify() is nil, ok will be false here. + // The lbWatcher goroutine will not be created. if ok { go cc.lbWatcher() } @@ -318,8 +349,10 @@ func (s ConnectivityState) String() string { // ClientConn represents a client connection to an RPC server. type ClientConn struct { + ctx context.Context + cancel context.CancelFunc + target string - balancer Balancer authority string dopts dialOptions @@ -328,7 +361,7 @@ type ClientConn struct { } func (cc *ClientConn) lbWatcher() { - for addrs := range cc.balancer.Notify() { + for addrs := range cc.dopts.balancer.Notify() { var ( add []Address // Addresses need to setup connections. del []*addrConn // Connections need to tear down. @@ -349,11 +382,12 @@ func (cc *ClientConn) lbWatcher() { } if !keep { del = append(del, c) + delete(cc.conns, c.addr) } } cc.mu.Unlock() for _, a := range add { - cc.newAddrConn(a, true) + cc.resetAddrConn(a, true, nil) } for _, c := range del { c.tearDown(errConnDrain) @@ -361,13 +395,17 @@ func (cc *ClientConn) lbWatcher() { } } -func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { +// resetAddrConn creates an addrConn for addr and adds it to cc.conns. +// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason. +// If tearDownErr is nil, errConnDrain will be used instead. +func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error { ac := &addrConn{ - cc: cc, - addr: addr, - dopts: cc.dopts, - shutdownChan: make(chan struct{}), + cc: cc, + addr: addr, + dopts: cc.dopts, } + ac.ctx, ac.cancel = context.WithCancel(cc.ctx) + ac.stateCV = sync.NewCond(&ac.mu) if EnableTracing { ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) } @@ -385,26 +423,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { } } } - // Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called. - ac.cc.mu.Lock() - if ac.cc.conns == nil { - ac.cc.mu.Unlock() + // Track ac in cc. This needs to be done before any getTransport(...) is called. + cc.mu.Lock() + if cc.conns == nil { + cc.mu.Unlock() return ErrClientConnClosing } - stale := ac.cc.conns[ac.addr] - ac.cc.conns[ac.addr] = ac - ac.cc.mu.Unlock() + stale := cc.conns[ac.addr] + cc.conns[ac.addr] = ac + cc.mu.Unlock() if stale != nil { // There is an addrConn alive on ac.addr already. This could be due to - // i) stale's Close is undergoing; - // ii) a buggy Balancer notifies duplicated Addresses. - stale.tearDown(errConnDrain) + // 1) a buggy Balancer notifies duplicated Addresses; + // 2) goaway was received, a new ac will replace the old ac. + // The old ac should be deleted from cc.conns, but the + // underlying transport should drain rather than close. 
+ if tearDownErr == nil { + // tearDownErr is nil if resetAddrConn is called by + // 1) Dial + // 2) lbWatcher + // In both cases, the stale ac should drain, not close. + stale.tearDown(errConnDrain) + } else { + stale.tearDown(tearDownErr) + } } - ac.stateCV = sync.NewCond(&ac.mu) // skipWait may overwrite the decision in ac.dopts.block. if ac.dopts.block && !skipWait { if err := ac.resetTransport(false); err != nil { - ac.tearDown(err) + if err != errConnClosing { + // Tear down ac and delete it from cc.conns. + cc.mu.Lock() + delete(cc.conns, ac.addr) + cc.mu.Unlock() + ac.tearDown(err) + } + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return e.Origin() + } return err } // Start to monitor the error status of transport. @@ -414,7 +470,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { go func() { if err := ac.resetTransport(false); err != nil { grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) - ac.tearDown(err) + if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. + ac.tearDown(err) + } return } ac.transportMonitor() @@ -424,25 +483,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { } func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { - // TODO(zhaoq): Implement fail-fast logic. - addr, put, err := cc.balancer.Get(ctx, opts) - if err != nil { - return nil, nil, err - } - cc.mu.RLock() - if cc.conns == nil { + var ( + ac *addrConn + ok bool + put func() + ) + if cc.dopts.balancer == nil { + // If balancer is nil, there should be only one addrConn available. + cc.mu.RLock() + if cc.conns == nil { + cc.mu.RUnlock() + return nil, nil, toRPCErr(ErrClientConnClosing) + } + for _, ac = range cc.conns { + // Break after the first iteration to get the first addrConn. + ok = true + break + } + cc.mu.RUnlock() + } else { + var ( + addr Address + err error + ) + addr, put, err = cc.dopts.balancer.Get(ctx, opts) + if err != nil { + return nil, nil, toRPCErr(err) + } + cc.mu.RLock() + if cc.conns == nil { + cc.mu.RUnlock() + return nil, nil, toRPCErr(ErrClientConnClosing) + } + ac, ok = cc.conns[addr] cc.mu.RUnlock() - return nil, nil, ErrClientConnClosing } - ac, ok := cc.conns[addr] - cc.mu.RUnlock() if !ok { if put != nil { put() } - return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") + return nil, nil, errConnClosing } - t, err := ac.wait(ctx) + t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait) if err != nil { if put != nil { put() @@ -454,6 +536,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) // Close tears down the ClientConn and all underlying connections. func (cc *ClientConn) Close() error { + cc.cancel() + cc.mu.Lock() if cc.conns == nil { cc.mu.Unlock() @@ -462,7 +546,9 @@ func (cc *ClientConn) Close() error { conns := cc.conns cc.conns = nil cc.mu.Unlock() - cc.balancer.Close() + if cc.dopts.balancer != nil { + cc.dopts.balancer.Close() + } for _, ac := range conns { ac.tearDown(ErrClientConnClosing) } @@ -471,11 +557,13 @@ func (cc *ClientConn) Close() error { // addrConn is a network connection to a given address. 
type addrConn struct { - cc *ClientConn - addr Address - dopts dialOptions - shutdownChan chan struct{} - events trace.EventLog + ctx context.Context + cancel context.CancelFunc + + cc *ClientConn + addr Address + dopts dialOptions + events trace.EventLog mu sync.Mutex state ConnectivityState @@ -485,6 +573,9 @@ type addrConn struct { // due to timeout. ready chan struct{} transport transport.ClientTransport + + // The reason this addrConn is torn down. + tearDownErr error } // printf records an event in ac's event log, unless ac has been closed. @@ -540,8 +631,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti } func (ac *addrConn) resetTransport(closeTransport bool) error { - var retries int - for { + for retries := 0; ; retries++ { ac.mu.Lock() ac.printf("connecting") if ac.state == Shutdown { @@ -561,13 +651,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { t.Close() } sleepTime := ac.dopts.bs.backoff(retries) - ac.dopts.copts.Timeout = sleepTime - if sleepTime < minConnectTimeout { - ac.dopts.copts.Timeout = minConnectTimeout + timeout := minConnectTimeout + if timeout < sleepTime { + timeout = sleepTime } + ctx, cancel := context.WithTimeout(ac.ctx, timeout) connectTime := time.Now() - newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) + newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts) if err != nil { + cancel() + + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return err + } + grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. @@ -582,17 +679,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { ac.ready = nil } ac.mu.Unlock() - sleepTime -= time.Since(connectTime) - if sleepTime < 0 { - sleepTime = 0 - } closeTransport = false select { - case <-time.After(sleepTime): - case <-ac.shutdownChan: + case <-time.After(sleepTime - time.Since(connectTime)): + case <-ac.ctx.Done(): + return ac.ctx.Err() } - retries++ - grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) continue } ac.mu.Lock() @@ -610,7 +702,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { close(ac.ready) ac.ready = nil } - ac.down = ac.cc.balancer.Up(ac.addr) + if ac.cc.dopts.balancer != nil { + ac.down = ac.cc.dopts.balancer.Up(ac.addr) + } ac.mu.Unlock() return nil } @@ -624,14 +718,42 @@ func (ac *addrConn) transportMonitor() { t := ac.transport ac.mu.Unlock() select { - // shutdownChan is needed to detect the teardown when + // This is needed to detect the teardown when // the addrConn is idle (i.e., no RPC in flight). - case <-ac.shutdownChan: + case <-ac.ctx.Done(): + select { + case <-t.Error(): + t.Close() + default: + } + return + case <-t.GoAway(): + // If GoAway happens without any network I/O error, ac is closed without shutting down the + // underlying transport (the transport will be closed when all the pending RPCs finished or + // failed.). + // If GoAway and some network I/O error happen concurrently, ac and its underlying transport + // are closed. + // In both cases, a new ac is created. 
+ select { + case <-t.Error(): + ac.cc.resetAddrConn(ac.addr, true, errNetworkIO) + default: + ac.cc.resetAddrConn(ac.addr, true, errConnDrain) + } return case <-t.Error(): + select { + case <-ac.ctx.Done(): + t.Close() + return + case <-t.GoAway(): + ac.cc.resetAddrConn(ac.addr, true, errNetworkIO) + return + default: + } ac.mu.Lock() if ac.state == Shutdown { - // ac.tearDown(...) has been invoked. + // ac has been shutdown. ac.mu.Unlock() return } @@ -643,38 +765,53 @@ func (ac *addrConn) transportMonitor() { ac.printf("transport exiting: %v", err) ac.mu.Unlock() grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) + if err != errConnClosing { + // Keep this ac in cc.conns, to get the reason it's torn down. + ac.tearDown(err) + } return } } } } -// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. -func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { +// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or +// iv) transport is in TransientFailure and there's no balancer/failfast is true. +func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) { for { ac.mu.Lock() switch { case ac.state == Shutdown: + if failfast || !hasBalancer { + // RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr. + err := ac.tearDownErr + ac.mu.Unlock() + return nil, err + } ac.mu.Unlock() return nil, errConnClosing case ac.state == Ready: ct := ac.transport ac.mu.Unlock() return ct, nil - default: - ready := ac.ready - if ready == nil { - ready = make(chan struct{}) - ac.ready = ready - } - ac.mu.Unlock() - select { - case <-ctx.Done(): - return nil, transport.ContextErr(ctx.Err()) - // Wait until the new transport is ready or failed. - case <-ready: + case ac.state == TransientFailure: + if failfast || hasBalancer { + ac.mu.Unlock() + return nil, errConnUnavailable } } + ready := ac.ready + if ready == nil { + ready = make(chan struct{}) + ac.ready = ready + } + ac.mu.Unlock() + select { + case <-ctx.Done(): + return nil, toRPCErr(ctx.Err()) + // Wait until the new transport is ready or failed. + case <-ready: + } } } @@ -682,24 +819,28 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // some edge cases (e.g., the caller opens and closes many addrConn's in a // tight loop. +// tearDown doesn't remove ac from ac.cc.conns. func (ac *addrConn) tearDown(err error) { + ac.cancel() + ac.mu.Lock() - defer func() { - ac.mu.Unlock() - ac.cc.mu.Lock() - if ac.cc.conns != nil { - delete(ac.cc.conns, ac.addr) - } - ac.cc.mu.Unlock() - }() - if ac.state == Shutdown { - return - } - ac.state = Shutdown + defer ac.mu.Unlock() if ac.down != nil { ac.down(downErrorf(false, false, "%v", err)) ac.down = nil } + if err == errConnDrain && ac.transport != nil { + // GracefulClose(...) may be executed multiple times when + // i) receiving multiple GoAway frames from the server; or + // ii) there are concurrent name resolver/Balancer triggered + // address removal and GoAway. 
+ ac.transport.GracefulClose() + } + if ac.state == Shutdown { + return + } + ac.state = Shutdown + ac.tearDownErr = err ac.stateCV.Broadcast() if ac.events != nil { ac.events.Finish() @@ -709,15 +850,8 @@ func (ac *addrConn) tearDown(err error) { close(ac.ready) ac.ready = nil } - if ac.transport != nil { - if err == errConnDrain { - ac.transport.GracefulClose() - } else { - ac.transport.Close() - } - } - if ac.shutdownChan != nil { - close(ac.shutdownChan) + if ac.transport != nil && err != errConnDrain { + ac.transport.Close() } return } diff --git a/clientconn_test.go b/clientconn_test.go index 29db8bfc..c49548dc 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -37,6 +37,8 @@ import ( "testing" "time" + "golang.org/x/net/context" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/oauth" ) @@ -67,13 +69,21 @@ func TestTLSDialTimeout(t *testing.T) { } } +func TestDialContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled { + t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled) + } +} + func TestCredentialsMisuse(t *testing.T) { tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { t.Fatalf("Failed to create authenticator %v", err) } // Two conflicting credential configurations - if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict { + if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) } rpcCreds, err := oauth.NewJWTAccessFromKey(nil) @@ -81,7 +91,7 @@ func TestCredentialsMisuse(t *testing.T) { t.Fatalf("Failed to create credentials %v", err) } // security info on insecure connection - if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { + if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) } } @@ -123,4 +133,5 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt if actual != *expected { t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected) } + conn.Close() } diff --git a/credentials/credentials.go b/credentials/credentials.go index 23fe63ea..3f17b706 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -44,7 +44,6 @@ import ( "io/ioutil" "net" "strings" - "time" "golang.org/x/net/context" ) @@ -93,11 +92,12 @@ type TransportCredentials interface { // ClientHandshake does the authentication handshake specified by the corresponding // authentication protocol on rawConn for clients. It returns the authenticated // connection and the corresponding auth information about the connection. - ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) + // Implementations must use the provided context to implement timely cancellation. 
+ ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error) // ServerHandshake does the authentication handshake for servers. It returns // the authenticated connection and the corresponding auth information about // the connection. - ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) + ServerHandshake(net.Conn) (net.Conn, AuthInfo, error) // Info provides the ProtocolInfo of this TransportCredentials. Info() ProtocolInfo } @@ -116,7 +116,7 @@ func (t TLSInfo) AuthType() string { // tlsCreds is the credentials required for authenticating a connection using TLS. type tlsCreds struct { // TLS configuration - config tls.Config + config *tls.Config } func (c tlsCreds) Info() ProtocolInfo { @@ -136,40 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool { return true } -type timeoutError struct{} - -func (timeoutError) Error() string { return "credentials: Dial timed out" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - -func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) { - // borrow some code from tls.DialWithDialer - var errChannel chan error - if timeout != 0 { - errChannel = make(chan error, 2) - time.AfterFunc(timeout, func() { - errChannel <- timeoutError{} - }) - } - if c.config.ServerName == "" { +func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { + // use local cfg to avoid clobbering ServerName if using multiple endpoints + cfg := cloneTLSConfig(c.config) + if cfg.ServerName == "" { colonPos := strings.LastIndex(addr, ":") if colonPos == -1 { colonPos = len(addr) } - c.config.ServerName = addr[:colonPos] + cfg.ServerName = addr[:colonPos] } - conn := tls.Client(rawConn, &c.config) - if timeout == 0 { - err = conn.Handshake() - } else { - go func() { - errChannel <- conn.Handshake() - }() - err = <-errChannel - } - if err != nil { - rawConn.Close() - return nil, nil, err + conn := tls.Client(rawConn, cfg) + errChannel := make(chan error, 1) + go func() { + errChannel <- conn.Handshake() + }() + select { + case err := <-errChannel: + if err != nil { + return nil, nil, err + } + case <-ctx.Done(): + return nil, nil, ctx.Err() } // TODO(zhaoq): Omit the auth info for client now. It is more for // information than anything else. @@ -177,9 +165,8 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D } func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { - conn := tls.Server(rawConn, &c.config) + conn := tls.Server(rawConn, c.config) if err := conn.Handshake(); err != nil { - rawConn.Close() return nil, nil, err } return conn, TLSInfo{conn.ConnectionState()}, nil @@ -187,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { - tc := &tlsCreds{*c} + tc := &tlsCreds{cloneTLSConfig(c)} tc.config.NextProtos = alpnProtoStr return tc } diff --git a/credentials/credentials_util_go17.go b/credentials/credentials_util_go17.go new file mode 100644 index 00000000..9647b9ec --- /dev/null +++ b/credentials/credentials_util_go17.go @@ -0,0 +1,76 @@ +// +build go1.7 + +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package credentials + +import ( + "crypto/tls" +) + +// cloneTLSConfig returns a shallow clone of the exported +// fields of cfg, ignoring the unexported sync.Once, which +// contains a mutex and must not be copied. +// +// If cfg is nil, a new zero tls.Config is returned. +// +// TODO replace this function with official clone function. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + SessionTicketsDisabled: cfg.SessionTicketsDisabled, + SessionTicketKey: cfg.SessionTicketKey, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, + } +} diff --git a/credentials/credentials_util_pre_go17.go b/credentials/credentials_util_pre_go17.go new file mode 100644 index 00000000..09b8d12c --- /dev/null +++ b/credentials/credentials_util_pre_go17.go @@ -0,0 +1,74 @@ +// +build !go1.7 + +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package credentials + +import ( + "crypto/tls" +) + +// cloneTLSConfig returns a shallow clone of the exported +// fields of cfg, ignoring the unexported sync.Once, which +// contains a mutex and must not be copied. +// +// If cfg is nil, a new zero tls.Config is returned. +// +// TODO replace this function with official clone function. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return &tls.Config{ + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + SessionTicketsDisabled: cfg.SessionTicketsDisabled, + SessionTicketKey: cfg.SessionTicketKey, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + } +} diff --git a/examples/gotutorial.md b/examples/gotutorial.md index 39833fdf..25c0a2df 100644 --- a/examples/gotutorial.md +++ b/examples/gotutorial.md @@ -28,12 +28,12 @@ Then change your current directory to `grpc-go/examples/route_guide`: $ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide ``` -You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](examples/). +You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](https://github.com/grpc/grpc-go/tree/master/examples/). ## Defining the service -Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [`examples/route_guide/proto/route_guide.proto`](examples/route_guide/proto/route_guide.proto). 
+Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto). To define a service, you specify a named `service` in your .proto file: diff --git a/examples/route_guide/client/client.go b/examples/route_guide/client/client.go index f84352c8..fff6398d 100644 --- a/examples/route_guide/client/client.go +++ b/examples/route_guide/client/client.go @@ -115,12 +115,12 @@ func runRecordRoute(client pb.RouteGuideClient) { // runRouteChat receives a sequence of route notes, while sending notes for various locations. func runRouteChat(client pb.RouteGuideClient) { notes := []*pb.RouteNote{ - {&pb.Point{0, 1}, "First message"}, - {&pb.Point{0, 2}, "Second message"}, - {&pb.Point{0, 3}, "Third message"}, - {&pb.Point{0, 1}, "Fourth message"}, - {&pb.Point{0, 2}, "Fifth message"}, - {&pb.Point{0, 3}, "Sixth message"}, + {&pb.Point{Latitude: 0, Longitude: 1}, "First message"}, + {&pb.Point{Latitude: 0, Longitude: 2}, "Second message"}, + {&pb.Point{Latitude: 0, Longitude: 3}, "Third message"}, + {&pb.Point{Latitude: 0, Longitude: 1}, "Fourth message"}, + {&pb.Point{Latitude: 0, Longitude: 2}, "Fifth message"}, + {&pb.Point{Latitude: 0, Longitude: 3}, "Sixth message"}, } stream, err := client.RouteChat(context.Background()) if err != nil { @@ -153,7 +153,7 @@ func runRouteChat(client pb.RouteGuideClient) { func randomPoint(r *rand.Rand) *pb.Point { lat := (r.Int31n(180) - 90) * 1e7 long := (r.Int31n(360) - 180) * 1e7 - return &pb.Point{lat, long} + return &pb.Point{Latitude: lat, Longitude: long} } func main() { @@ -186,13 +186,16 @@ func main() { client := pb.NewRouteGuideClient(conn) // Looking for a valid feature - printFeature(client, &pb.Point{409146138, -746188906}) + printFeature(client, &pb.Point{Latitude: 409146138, Longitude: -746188906}) // Feature missing. - printFeature(client, &pb.Point{0, 0}) + printFeature(client, &pb.Point{Latitude: 0, Longitude: 0}) // Looking for features between 40, -75 and 42, -73. - printFeatures(client, &pb.Rectangle{&pb.Point{400000000, -750000000}, &pb.Point{420000000, -730000000}}) + printFeatures(client, &pb.Rectangle{ + Lo: &pb.Point{Latitude: 400000000, Longitude: -750000000}, + Hi: &pb.Point{Latitude: 420000000, Longitude: -730000000}, + }) // RecordRoute runRecordRoute(client) diff --git a/examples/route_guide/server/server.go b/examples/route_guide/server/server.go index c8be4970..5932722b 100644 --- a/examples/route_guide/server/server.go +++ b/examples/route_guide/server/server.go @@ -79,7 +79,7 @@ func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb } } // No feature was found, return an unnamed feature - return &pb.Feature{"", point}, nil + return &pb.Feature{Location: point}, nil } // ListFeatures lists all features contained within the given bounding Rectangle. diff --git a/health/health.go b/health/health.go index f74fd69b..34255298 100644 --- a/health/health.go +++ b/health/health.go @@ -11,19 +11,22 @@ import ( healthpb "google.golang.org/grpc/health/grpc_health_v1" ) -type HealthServer struct { +// Server implements `service Health`. 
+type Server struct { mu sync.Mutex - // statusMap stores the serving status of the services this HealthServer monitors. + // statusMap stores the serving status of the services this Server monitors. statusMap map[string]healthpb.HealthCheckResponse_ServingStatus } -func NewHealthServer() *HealthServer { - return &HealthServer{ +// NewServer returns a new Server. +func NewServer() *Server { + return &Server{ statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus), } } -func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { +// Check implements `service Health`. +func (s *Server) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { s.mu.Lock() defer s.mu.Unlock() if in.Service == "" { @@ -42,7 +45,7 @@ func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckReques // SetServingStatus is called when need to reset the serving status of a service // or insert a new service entry into the statusMap. -func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) { +func (s *Server) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) { s.mu.Lock() s.statusMap[service] = status s.mu.Unlock() diff --git a/metadata/metadata.go b/metadata/metadata.go index 52070dbe..954c0f77 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) { // DecodeKeyValue returns the original key and value corresponding to the // encoded data in k, v. +// If k is a binary header and v contains comma, v is split on comma before decoded, +// and the decoded v will be joined with comma before returned. func DecodeKeyValue(k, v string) (string, string, error) { if !strings.HasSuffix(k, binHdrSuffix) { return k, v, nil } - val, err := base64.StdEncoding.DecodeString(v) - if err != nil { - return "", "", err + vvs := strings.Split(v, ",") + for i, vv := range vvs { + val, err := base64.StdEncoding.DecodeString(vv) + if err != nil { + return "", "", err + } + vvs[i] = string(val) } - return k, string(val), nil + return k, strings.Join(vvs, ","), nil } // MD is a mapping from metadata keys to values. Users should use the following diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 02e6ba51..99e86820 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -74,6 +74,8 @@ func TestDecodeKeyValue(t *testing.T) { {"a", "abc", "a", "abc", nil}, {"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil}, {"key-bin", "woA=", "key-bin", binaryValue, nil}, + {"a", "abc,efg", "a", "abc,efg", nil}, + {"key-bin", "Zm9vAGJhcg==,Zm9vAGJhcg==", "key-bin", "foo\x00bar,foo\x00bar", nil}, } { k, v, err := DecodeKeyValue(test.kin, test.vin) if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) { diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index efc6bc88..686090aa 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -70,7 +70,7 @@ import ( type serverReflectionServer struct { s *grpc.Server // TODO add more cache if necessary - serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo() + serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo() } // Register registers the server reflection service on the given gRPC server. 
@@ -214,19 +214,19 @@ func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interfac return nil, fmt.Errorf("unknown symbol: %v", name) } - // Search for method in info. + // Search the method name in info.Methods. var found bool for _, m := range info.Methods { - if m == name[pos+1:] { + if m.Name == name[pos+1:] { found = true break } } - if !found { - return nil, fmt.Errorf("unknown symbol: %v", name) + if found { + return info.Metadata, nil } - return info.Metadata, nil + return nil, fmt.Errorf("unknown symbol: %v", name) } // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, @@ -253,7 +253,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( // Metadata not valid. enc, ok := meta.([]byte) if !ok { - return nil, fmt.Errorf("invalid file descriptor for symbol: %v") + return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name) } fd, err = s.decodeFileDesc(enc) diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index aeb31e14..ca9610e2 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -273,6 +273,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe }{ {"grpc.testing.SearchService", fdTestByte}, {"grpc.testing.SearchService.Search", fdTestByte}, + {"grpc.testing.SearchService.StreamingSearch", fdTestByte}, {"grpc.testing.SearchResponse", fdTestByte}, {"grpc.testing.ToBeExtened", fdProto2Byte}, } { diff --git a/rpc_util.go b/rpc_util.go index 080ebb14..35ac9cc7 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -141,6 +141,8 @@ type callInfo struct { traceInfo traceInfo // in trace.go } +var defaultCallInfo = callInfo{failFast: true} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption { }) } +// FailFast configures the action to take when an RPC is attempted on broken +// connections or unreachable servers. If failfast is true, the RPC will fail +// immediately. Otherwise, the RPC client will block the call until a +// connection is available (or the call is canceled or times out) and will retry +// the call if it fails due to a transient error. Please refer to +// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md +func FailFast(failFast bool) CallOption { + return beforeCall(func(c *callInfo) error { + c.failFast = failFast + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 @@ -212,7 +227,7 @@ type parser struct { // No other error values or types must be returned, which also means // that the underlying io.Reader must not return an incompatible // error. -func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { +func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) { if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } @@ -223,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { if length == 0 { return pf, nil, nil } + if length > uint32(maxMsgSize) { + return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize) + } // TODO(bradfitz,zhaoq): garbage. 
reuse buffer after proto decoding instead // of making it for each message: msg = make([]byte, int(length)) @@ -293,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { - pf, d, err := p.recvMsg() +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error { + pf, d, err := p.recvMsg(maxMsgSize) if err != nil { return err } @@ -304,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if pf == compressionMade { d, err = dc.Do(bytes.NewReader(d)) if err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) } } + if len(d) > maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with java + // implementation. + return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize) + } if err := c.Unmarshal(d, m); err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) + return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } return nil } @@ -319,7 +342,7 @@ type rpcError struct { desc string } -func (e rpcError) Error() string { +func (e *rpcError) Error() string { return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) } @@ -329,7 +352,7 @@ func Code(err error) codes.Code { if err == nil { return codes.OK } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.code } return codes.Unknown @@ -341,7 +364,7 @@ func ErrorDesc(err error) string { if err == nil { return "" } - if e, ok := err.(rpcError); ok { + if e, ok := err.(*rpcError); ok { return e.desc } return err.Error() @@ -353,7 +376,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { if c == codes.OK { return nil } - return rpcError{ + return &rpcError{ code: c, desc: fmt.Sprintf(format, a...), } @@ -362,18 +385,37 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { // toRPCErr converts an error into a rpcError. 
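[Editor's aside, not part of the patch] Because Errorf, Code and ErrorDesc now work with a *rpcError pointer, two errors created from the same code and description are no longer equal under ==, which is why the tests in this diff compare codes and descriptions rather than error values. A minimal sketch using only the exported helpers shown above:

package main

import (
	"fmt"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
)

func main() {
	want := grpc.Errorf(codes.NotFound, "unknown service")
	got := grpc.Errorf(codes.NotFound, "unknown service")

	// Two distinct *rpcError values: direct comparison reports false.
	fmt.Println(got == want) // false

	// Compare the observable parts instead.
	fmt.Println(grpc.Code(got) == grpc.Code(want) && grpc.ErrorDesc(got) == grpc.ErrorDesc(want)) // true
}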
func toRPCErr(err error) error { switch e := err.(type) { - case rpcError: + case *rpcError: return err case transport.StreamError: - return rpcError{ + return &rpcError{ code: e.Code, desc: e.Desc, } case transport.ConnectionError: - return rpcError{ + return &rpcError{ code: codes.Internal, desc: e.Desc, } + default: + switch err { + case context.DeadlineExceeded: + return &rpcError{ + code: codes.DeadlineExceeded, + desc: err.Error(), + } + case context.Canceled: + return &rpcError{ + code: codes.Canceled, + desc: err.Error(), + } + case ErrClientConnClosing: + return &rpcError{ + code: codes.FailedPrecondition, + desc: err.Error(), + } + } + } return Errorf(codes.Unknown, "%v", err) } diff --git a/rpc_util_test.go b/rpc_util_test.go index f6327f13..8a813c62 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -36,6 +36,7 @@ package grpc import ( "bytes" "io" + "math" "reflect" "testing" @@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) { } { buf := bytes.NewReader(test.p) parser := &parser{r: buf} - pt, b, err := parser.recvMsg() + pt, b, err := parser.recvMsg(math.MaxInt32) if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { - t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) + t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) } } } @@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) { {compressionNone, []byte("d")}, } for i, want := range wantRecvs { - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { - t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, ", + t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, ", i, p, pt, data, err, want.pt, want.data) } } - pt, data, err := parser.recvMsg() + pt, data, err := parser.recvMsg(math.MaxInt32) if err != io.EOF { - t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v", + t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v", len(wantRecvs), p, pt, data, err, io.EOF) } } @@ -149,13 +150,17 @@ func TestToRPCErr(t *testing.T) { // input errIn error // outputs - errOut error + errOut *rpcError }{ - {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")}, - {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)}, + {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)}, + {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)}, } { err := toRPCErr(test.errIn) - if err != test.errOut { + rpcErr, ok := err.(*rpcError) + if !ok { + t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{}) + } + if *rpcErr != *test.errOut { t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } } @@ -178,6 +183,18 @@ func TestContextErr(t *testing.T) { } } +func TestErrorsWithSameParameters(t *testing.T) { + const description = "some description" + e1 := Errorf(codes.AlreadyExists, description).(*rpcError) + e2 := Errorf(codes.AlreadyExists, description).(*rpcError) + if e1 == e2 { + t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) + } + if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) { + t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) + } +} + // 
bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bytes. func bmEncode(b *testing.B, mSize int) { diff --git a/server.go b/server.go index 6782009e..1ed8aac9 100644 --- a/server.go +++ b/server.go @@ -89,9 +89,13 @@ type service struct { type Server struct { opts options - mu sync.Mutex // guards following - lis map[net.Listener]bool - conns map[io.Closer]bool + mu sync.Mutex // guards following + lis map[net.Listener]bool + conns map[io.Closer]bool + drain bool + // A CondVar to let GracefulStop() block until all the pending RPCs are finished + // and all the transports go away. + cv *sync.Cond m map[string]*service // service name -> service info events trace.EventLog } @@ -101,12 +105,15 @@ type options struct { codec Codec cp Compressor dc Decompressor + maxMsgSize int unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } +var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit + // A ServerOption sets options. type ServerOption func(*options) @@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption { } } -// RPCCompressor returns a ServerOption that sets a compressor for outbound message. +// RPCCompressor returns a ServerOption that sets a compressor for outbound messages. func RPCCompressor(cp Compressor) ServerOption { return func(o *options) { o.cp = cp } } -// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages. func RPCDecompressor(dc Decompressor) ServerOption { return func(o *options) { o.dc = dc } } +// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound messages. +// If this is not set, gRPC uses the default 4MB. +func MaxMsgSize(m int) ServerOption { + return func(o *options) { + o.maxMsgSize = m + } +} + // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption { // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { var opts options + opts.maxMsgSize = defaultMaxMsgSize for _, o := range opt { o(&opts) } @@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server { conns: make(map[io.Closer]bool), m: make(map[string]*service), } + s.cv = sync.NewCond(&s.mu) if EnableTracing { _, file, line, _ := runtime.Caller(1) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) @@ -245,28 +262,45 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) { s.m[sd.ServiceName] = srv } -// ServiceInfo contains method names and metadata for a service. +// MethodInfo contains the information of an RPC including its method name and type. +type MethodInfo struct { + // Name is the method name only, without the service name or package name. + Name string + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool +} + +// ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service. type ServiceInfo struct { - // Methods are method names only, without the service name or package name.
- Methods []string + Methods []MethodInfo // Metadata is the metadata specified in ServiceDesc when registering service. Metadata interface{} } // GetServiceInfo returns a map from service names to ServiceInfo. // Service names include the package names, in the form of .. -func (s *Server) GetServiceInfo() map[string]*ServiceInfo { - ret := make(map[string]*ServiceInfo) +func (s *Server) GetServiceInfo() map[string]ServiceInfo { + ret := make(map[string]ServiceInfo) for n, srv := range s.m { - methods := make([]string, 0, len(srv.md)+len(srv.sd)) + methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) for m := range srv.md { - methods = append(methods, m) + methods = append(methods, MethodInfo{ + Name: m, + IsClientStream: false, + IsServerStream: false, + }) } - for m := range srv.sd { - methods = append(methods, m) + for m, d := range srv.sd { + methods = append(methods, MethodInfo{ + Name: m, + IsClientStream: d.ClientStreams, + IsServerStream: d.ServerStreams, + }) } - ret[n] = &ServiceInfo{ + ret[n] = ServiceInfo{ Methods: methods, Metadata: srv.mdata, } @@ -303,9 +337,11 @@ func (s *Server) Serve(lis net.Listener) error { s.lis[lis] = true s.mu.Unlock() defer func() { - lis.Close() s.mu.Lock() - delete(s.lis, lis) + if s.lis != nil && s.lis[lis] { + lis.Close() + delete(s.lis, lis) + } s.mu.Unlock() }() for { @@ -449,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() - if s.conns == nil { + if s.conns == nil || s.drain { return false } s.conns[c] = true @@ -461,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) { defer s.mu.Unlock() if s.conns != nil { delete(s.conns, c) + s.cv.Signal() } } @@ -501,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } p := &parser{r: stream} for { - pf, req, err := p.recvMsg() + pf, req, err := p.recvMsg(s.opts.maxMsgSize) if err == io.EOF { // The entire stream is done (for unary RPC only). return err @@ -511,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err != nil { switch err := err.(type) { + case *rpcError: + if err := t.WriteStatus(stream, err.code, err.desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: @@ -550,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } } + if len(req) > s.opts.maxMsgSize { + // TODO: Revisit the error code. Currently keep it consistent with + // java implementation. + statusCode = codes.Internal + statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize) + } if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } @@ -560,7 +607,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
} reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { statusCode = err.code statusDesc = err.desc } else { @@ -609,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - trInfo: trInfo, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxMsgSize: s.opts.maxMsgSize, + trInfo: trInfo, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) @@ -645,7 +693,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) } if appErr != nil { - if err, ok := appErr.(rpcError); ok { + if err, ok := appErr.(*rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc } else if err, ok := appErr.(transport.StreamError); ok { @@ -747,14 +795,16 @@ func (s *Server) Stop() { s.mu.Lock() listeners := s.lis s.lis = nil - cs := s.conns + st := s.conns s.conns = nil + // interrupt GracefulStop if Stop and GracefulStop are called concurrently. + s.cv.Signal() s.mu.Unlock() for lis := range listeners { lis.Close() } - for c := range cs { + for c := range st { c.Close() } @@ -766,6 +816,32 @@ func (s *Server) Stop() { s.mu.Unlock() } +// GracefulStop stops the gRPC server gracefully. It stops the server from accepting new +// connections and RPCs and blocks until all the pending RPCs are finished. +func (s *Server) GracefulStop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.drain == true || s.conns == nil { + return + } + s.drain = true + for lis := range s.lis { + lis.Close() + } + s.lis = nil + for c := range s.conns { + c.(transport.ServerTransport).Drain() + } + for len(s.conns) != 0 { + s.cv.Wait() + } + s.conns = nil + if s.events != nil { + s.events.Finish() + s.events = nil + } +} + func init() { internal.TestingCloseConns = func(arg interface{}) { arg.(*Server).testingCloseConns() diff --git a/server_test.go b/server_test.go index 7c1e54dd..23838806 100644 --- a/server_test.go +++ b/server_test.go @@ -79,7 +79,7 @@ func TestGetServiceInfo(t *testing.T) { { StreamName: "EmptyStream", Handler: nil, - ServerStreams: true, + ServerStreams: false, ClientStreams: true, }, }, @@ -90,17 +90,24 @@ func TestGetServiceInfo(t *testing.T) { server.RegisterService(&testSd, &testServer{}) info := server.GetServiceInfo() - want := map[string]*ServiceInfo{ - "grpc.testing.EmptyService": &ServiceInfo{ - Methods: []string{ - "EmptyCall", - "EmptyStream", - }, + want := map[string]ServiceInfo{ + "grpc.testing.EmptyService": { + Methods: []MethodInfo{ + { + Name: "EmptyCall", + IsClientStream: false, + IsServerStream: false, + }, + { + Name: "EmptyStream", + IsClientStream: true, + IsServerStream: false, + }}, Metadata: []int{0, 2, 1, 3}, }, } if !reflect.DeepEqual(info, want) { - t.Errorf("GetServiceInfo() = %q, want %q", info, want) + t.Errorf("GetServiceInfo() = %+v, want %+v", info, want) } } diff --git a/stream.go b/stream.go index 25be4b81..51df3f01 100644 --- a/stream.go +++ b/stream.go @@ -37,6 +37,7 @@ import ( "bytes" "errors" "io" + "math" "sync" "time" @@ -84,12 +85,9 @@ type ClientStream interface { // Header returns the header metadata received from the server if there // is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error) - // Trailer returns the trailer metadata from the server. It must be called - // after stream.Recv() returns non-nil error (including io.EOF) for - // bi-directional streaming and server streaming or stream.CloseAndRecv() - // returns for client streaming in order to receive trailer metadata if - // present. Otherwise, it could returns an empty MD even though trailer - // is present. + // Trailer returns the trailer metadata from the server, if there is any. + // It must only be called after stream.CloseAndRecv has returned, or + // stream.Recv has returned a non-nil error (including io.EOF). Trailer() metadata.MD // CloseSend closes the send direction of the stream. It closes the stream // when non-nil error is met. @@ -99,19 +97,17 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. -func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport - err error + s *transport.Stream put func() ) - // TODO(zhaoq): CallOption is omitted. Add support when it is needed. - gopts := BalancerGetOptions{ - BlockingWait: false, - } - t, put, err = cc.getTransport(ctx, gopts) - if err != nil { - return nil, toRPCErr(err) + c := defaultCallInfo + for _, o := range opts { + if err := o.before(&c); err != nil { + return nil, toRPCErr(err) + } } callHdr := &transport.CallHdr{ Host: cc.authority, @@ -121,41 +117,98 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } + var trInfo traceInfo + if EnableTracing { + trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) + trInfo.firstLine.client = true + if deadline, ok := ctx.Deadline(); ok { + trInfo.firstLine.deadline = deadline.Sub(time.Now()) + } + trInfo.tr.LazyLog(&trInfo.firstLine, false) + ctx = trace.NewContext(ctx, trInfo.tr) + defer func() { + if err != nil { + // Need to call tr.finish() if error is returned. + // Because tr will not be returned to caller. + trInfo.tr.LazyPrintf("RPC: [%v]", err) + trInfo.tr.SetError() + trInfo.tr.Finish() + } + }() + } + gopts := BalancerGetOptions{ + BlockingWait: !c.failFast, + } + for { + t, put, err = cc.getTransport(ctx, gopts) + if err != nil { + // TODO(zhaoq): Probably revisit the error handling. + if _, ok := err.(*rpcError); ok { + return nil, err + } + if err == errConnClosing || err == errConnUnavailable { + if c.failFast { + return nil, Errorf(codes.Unavailable, "%v", err) + } + continue + } + // All the other errors are treated as Internal errors. 
+ return nil, Errorf(codes.Internal, "%v", err) + } + + s, err = t.NewStream(ctx, callHdr) + if err != nil { + if put != nil { + put() + put = nil + } + if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if c.failFast { + return nil, toRPCErr(err) + } + continue + } + return nil, toRPCErr(err) + } + break + } cs := &clientStream{ - desc: desc, - put: put, - codec: cc.dopts.codec, - cp: cc.dopts.cp, - dc: cc.dopts.dc, + opts: opts, + c: c, + desc: desc, + codec: cc.dopts.codec, + cp: cc.dopts.cp, + dc: cc.dopts.dc, + + put: put, + t: t, + s: s, + p: &parser{r: s}, + tracing: EnableTracing, + trInfo: trInfo, } if cc.dopts.cp != nil { - callHdr.SendCompress = cc.dopts.cp.Type() cs.cbuf = new(bytes.Buffer) } - if cs.tracing { - cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) - cs.trInfo.firstLine.client = true - if deadline, ok := ctx.Deadline(); ok { - cs.trInfo.firstLine.deadline = deadline.Sub(time.Now()) - } - cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false) - ctx = trace.NewContext(ctx, cs.trInfo.tr) - } - s, err := t.NewStream(ctx, callHdr) - if err != nil { - cs.finish(err) - return nil, toRPCErr(err) - } - cs.t = t - cs.s = s - cs.p = &parser{r: s} - // Listen on ctx.Done() to detect cancellation when there is no pending - // I/O operations on this stream. + // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination + // when there is no pending I/O operations on this stream. go func() { select { case <-t.Error(): // Incur transport error, simply exit. + case <-s.Done(): + // TODO: The trace of the RPC is terminated here when there is no pending + // I/O, which is probably not the optimal solution. + if s.StatusCode() == codes.OK { + cs.finish(nil) + } else { + cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) + } + cs.closeTransportStream(nil) + case <-s.GoAway(): + cs.finish(errConnDrain) + cs.closeTransportStream(errConnDrain) case <-s.Context().Done(): err := s.Context().Err() cs.finish(err) @@ -167,6 +220,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // clientStream implements a client side Stream. type clientStream struct { + opts []CallOption + c callInfo t transport.ClientTransport s *transport.Stream p *parser @@ -216,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { cs.finish(err) } - if err == nil || err == io.EOF { + if err == nil { + return + } + if err == io.EOF { + // Specialize the process for server streaming. SendMesg is only called + // once when creating the stream object. io.EOF needs to be skipped when + // the rpc is early finished (before the stream object is created.). + // TODO: It is probably better to move this into the generated code. + if !cs.desc.ClientStreams && cs.desc.ServerStreams { + err = nil + } return } if _, ok := err.(transport.ConnectionError); !ok { @@ -237,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -256,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return } // Special handling for client streaming rpc. 
- err = recv(cs.p, cs.codec, cs.s, cs.dc, m) + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) cs.closeTransportStream(err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -291,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) { } }() if err == nil || err == io.EOF { - return + return nil } if _, ok := err.(transport.ConnectionError); !ok { cs.closeTransportStream(err) @@ -312,15 +377,18 @@ func (cs *clientStream) closeTransportStream(err error) { } func (cs *clientStream) finish(err error) { - if !cs.tracing { - return - } cs.mu.Lock() defer cs.mu.Unlock() + for _, o := range cs.opts { + o.after(&cs.c) + } if cs.put != nil { cs.put() cs.put = nil } + if !cs.tracing { + return + } if cs.trInfo.tr != nil { if err == nil || err == io.EOF { cs.trInfo.tr.LazyPrintf("RPC: [OK]") @@ -354,6 +422,7 @@ type serverStream struct { cp Compressor dc Decompressor cbuf *bytes.Buffer + maxMsgSize int statusCode codes.Code statusDesc string trInfo *traceInfo @@ -420,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, ss.s, ss.dc, m) + return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize) } diff --git a/stress/client/main.go b/stress/client/main.go index bb665e98..4579aab4 100644 --- a/stress/client/main.go +++ b/stress/client/main.go @@ -162,7 +162,7 @@ func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.Metri defer s.mutex.RUnlock() for name, gauge := range s.gauges { - if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}); err != nil { + if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil { return err } } @@ -175,7 +175,7 @@ func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*met defer s.mutex.RUnlock() if g, ok := s.gauges[in.Name]; ok { - return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{g.get()}}, nil + return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil } return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) } diff --git a/test/end2end_test.go b/test/end2end_test.go index b539584b..a97a8a47 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -90,7 +90,8 @@ var ( var raceMode bool // set by race_test.go in race mode type testServer struct { - security string // indicate the authentication protocol used by this server. + security string // indicate the authentication protocol used by this server. + earlyFail bool // whether to error out the execution of a service handler prematurely. } func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { @@ -219,6 +220,9 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput } p := in.GetPayload().GetBody() sum += len(p) + if s.earlyFail { + return grpc.Errorf(codes.NotFound, "not found") + } } } @@ -296,41 +300,33 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" -func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("unix", addr, timeout) -} - type env struct { name string network string // The type of network such as tcp, unix, etc. 
- dialer func(addr string, timeout time.Duration) (net.Conn, error) security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS + balancer bool // whether to use balancer } func (e env) runnable() bool { - if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") { + if runtime.GOOS == "windows" && e.network == "unix" { return false } return true } -func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { - if e.dialer != nil { - return e.dialer - } - return func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } +func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout(e.network, addr, timeout) } var ( - tcpClearEnv = env{name: "tcp-clear", network: "tcp"} - tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} - unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} - unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} - handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} - allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} + tcpClearEnv = env{name: "tcp-clear", network: "tcp", balancer: true} + tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: true} + unixClearEnv = env{name: "unix-clear", network: "unix", balancer: true} + unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls", balancer: true} + handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: true} + noBalancerEnv = env{name: "no-balancer", network: "tcp", security: "tls", balancer: false} + allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv, noBalancerEnv} ) var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.") @@ -367,8 +363,9 @@ type test struct { // Configurable knobs, after newTest returns: testServer testpb.TestServiceServer // nil means none - healthServer *health.HealthServer // nil means disabled + healthServer *health.Server // nil means disabled maxStream uint32 + maxMsgSize int userAgent string clientCompression bool serverCompression bool @@ -404,10 +401,9 @@ func (te *test) tearDown() { // modify it before calling its startServer and clientConn methods. func newTest(t *testing.T, e env) *test { te := &test{ - t: t, - e: e, - testServer: &testServer{security: e.security}, - maxStream: math.MaxUint32, + t: t, + e: e, + maxStream: math.MaxUint32, } te.ctx, te.cancel = context.WithCancel(context.Background()) return te @@ -415,10 +411,14 @@ func newTest(t *testing.T, e env) *test { // startServer starts a gRPC server listening. Callers should defer a // call to te.tearDown to clean up. 
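[Editor's aside, not part of the patch] The reworked startServer below forwards te.maxMsgSize to the new grpc.MaxMsgSize server option. A minimal standalone sketch of that option; the address, stream limit, and 1 MB cap are arbitrary:

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// Inbound messages larger than 1 MB are rejected with codes.Internal,
	// instead of being accepted up to the 4 MB default.
	s := grpc.NewServer(
		grpc.MaxConcurrentStreams(100),
		grpc.MaxMsgSize(1024*1024),
	)
	log.Fatal(s.Serve(lis))
}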
-func (te *test) startServer() { +func (te *test) startServer(ts testpb.TestServiceServer) { + te.testServer = ts e := te.e te.t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} + if te.maxMsgSize > 0 { + sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) + } if te.serverCompression { sopts = append(sopts, grpc.RPCCompressor(grpc.NewGZIPCompressor()), @@ -441,12 +441,17 @@ func (te *test) startServer() { if err != nil { te.t.Fatalf("Failed to listen: %v", err) } - if te.e.security == "tls" { + switch te.e.security { + case "tls": creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") if err != nil { te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) + case "clientAlwaysFailCred": + sopts = append(sopts, grpc.Creds(clientAlwaysFailCred{})) + case "clientTimeoutCreds": + sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{})) } s := grpc.NewServer(sopts...) te.srv = s @@ -489,15 +494,23 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.e.security == "tls" { + switch te.e.security { + case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { + case "clientAlwaysFailCred": + opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{})) + case "clientTimeoutCreds": + opts = append(opts, grpc.WithTransportCredentials(&clientTimeoutCreds{})) + default: opts = append(opts, grpc.WithInsecure()) } + if te.e.balancer { + opts = append(opts, grpc.WithBalancer(grpc.RoundRobin(nil))) + } var err error te.cc, err = grpc.Dial(te.srvAddr, opts...) 
if err != nil { @@ -511,9 +524,7 @@ func (te *test) declareLogNoise(phrases ...string) { } func (te *test) withServerTester(fn func(st *serverTester)) { - var c net.Conn - var err error - c, err = te.e.getDialer()(te.srvAddr, 10*time.Second) + c, err := te.e.dialer(te.srvAddr, 10*time.Second) if err != nil { te.t.Fatal(err) } @@ -541,11 +552,283 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { te.userAgent = testAppUA te.declareLogNoise( "transport: http2Client.notifyError got notified that the client transport was broken EOF", - "grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing", - "grpc: Conn.resetTransport failed to create client transport: connection error", - "grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", ) - te.startServer() + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + te.srv.Stop() + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) + } + awaitNewConnLogOutput() +} + +func TestServerGracefulStopIdempotent(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testServerGracefulStopIdempotent(t, e) + } +} + +func testServerGracefulStopIdempotent(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + for i := 0; i < 3; i++ { + te.srv.GracefulStop() + } +} + +func TestServerGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testServerGoAway(t, e) + } +} + +func testServerGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // Finish an RPC to make sure the connection is good. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + // Loop until the server side GoAway signal is propagated to the client. + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + // A new RPC should fail. 
+ if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable && grpc.Code(err) != codes.Internal { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s or %s", err, codes.Unavailable, codes.Internal) + } + <-ch + awaitNewConnLogOutput() +} + +func TestServerGoAwayPendingRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testServerGoAwayPendingRPC(t, e) + } +} + +func testServerGoAwayPendingRPC(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithCancel(context.Background()) + stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + // Finish an RPC to make sure the connection is good. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + // Loop until the server side GoAway signal is propagated to the client. + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + // The existing RPC should be still good to proceed. + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = _, %v, want _, ", stream, err) + } + cancel() + <-ch + awaitNewConnLogOutput() +} + +func TestConcurrentClientConnCloseAndServerGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testConcurrentClientConnCloseAndServerGoAway(t, e) + } +} + +func testConcurrentClientConnCloseAndServerGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) + } + ch := make(chan struct{}) + // Close ClientConn and Server concurrently. 
+ go func() { + te.srv.GracefulStop() + close(ch) + }() + go func() { + cc.Close() + }() + <-ch +} + +func TestConcurrentServerStopAndGoAway(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testConcurrentServerStopAndGoAway(t, e) + } +} + +func testConcurrentServerStopAndGoAway(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + ) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + stream, err := tc.FullDuplexCall(context.Background(), grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + // Finish an RPC to make sure the connection is good. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _, _) = _, %v, want _, ", tc, err) + } + ch := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(ch) + }() + // Loop until the server side GoAway signal is propagated to the client. + for { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { + continue + } + break + } + // Stop the server and close all the connections. + te.srv.Stop() + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(100)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if err := stream.Send(req); err == nil { + if _, err := stream.Recv(); err == nil { + t.Fatalf("%v.Recv() = _, %v, want _, ", stream, err) + } + } + <-ch + awaitNewConnLogOutput() +} + +func TestFailFast(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testFailFast(t, e) + } +} + +func testFailFast(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.declareLogNoise( + "transport: http2Client.notifyError got notified that the client transport was broken EOF", + "grpc: addrConn.transportMonitor exits due to: grpc: the connection is closing", + "grpc: addrConn.resetTransport failed to create client transport: connection error", + ) + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -553,14 +836,25 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } + // Stop the server and tear down all the exisiting connections. te.srv.Stop() - // Set -1 as the timeout to make sure if transportMonitor gets error - // notification in time the failure path of the 1st invoke of - // ClientConn.wait hits the deadline exceeded error. 
- ctx, _ := context.WithTimeout(context.Background(), -1) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) + // Loop until the server teardown is propagated to the client. + for { + _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}) + if grpc.Code(err) == codes.Unavailable { + break + } + fmt.Printf("%v.EmptyCall(_, _) = _, %v", tc, err) + time.Sleep(10 * time.Millisecond) } + // The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Unavailable { + t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Unavailable) + } + awaitNewConnLogOutput() } @@ -582,10 +876,10 @@ func TestHealthCheckOnSuccess(t *testing.T) { func testHealthCheckOnSuccess(t *testing.T, e env) { te := newTest(t, e) - hs := health.NewHealthServer() + hs := health.NewServer() hs.SetServingStatus("grpc.health.v1.Health", 1) te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -608,14 +902,15 @@ func testHealthCheckOnFailure(t *testing.T, e env) { "Failed to dial ", "grpc: the client connection is closing; please retry", ) - hs := health.NewHealthServer() + hs := health.NewServer() hs.SetServingStatus("grpc.health.v1.HealthCheck", 1) te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() - if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { + wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") + if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) } awaitNewConnLogOutput() @@ -634,10 +929,10 @@ func TestHealthCheckOff(t *testing.T) { func testHealthCheckOff(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); err != want { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -651,9 +946,9 @@ func TestHealthCheckServingStatus(t *testing.T) { func testHealthCheckServingStatus(t *testing.T, e env) { te := newTest(t, e) - hs := health.NewHealthServer() + hs := health.NewServer() te.healthServer = hs - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -664,7 +959,8 @@ func testHealthCheckServingStatus(t *testing.T, e env) { if out.Status != healthpb.HealthCheckResponse_SERVING { t.Fatalf("Got the serving status %v, want SERVING", out.Status) } - if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); err != grpc.Errorf(codes.NotFound, "unknown service") { + wantErr := grpc.Errorf(codes.NotFound, 
"unknown service") + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound) } hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) @@ -695,7 +991,7 @@ func TestErrorChanNoIO(t *testing.T) { func testErrorChanNoIO(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -714,7 +1010,7 @@ func TestEmptyUnaryWithUserAgent(t *testing.T) { func testEmptyUnaryWithUserAgent(t *testing.T, e env) { te := newTest(t, e) te.userAgent = testAppUA - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -740,13 +1036,13 @@ func TestFailedEmptyUnary(t *testing.T) { func testFailedEmptyUnary(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) ctx := metadata.NewContext(context.Background(), testMetadata) wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent") - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != wantErr { + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !equalErrors(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } } @@ -760,7 +1056,7 @@ func TestLargeUnary(t *testing.T) { func testLargeUnary(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -788,6 +1084,65 @@ func testLargeUnary(t *testing.T, e env) { } } +func TestExceedMsgLimit(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testExceedMsgLimit(t, e) + } +} + +func testExceedMsgLimit(t *testing.T, e env) { + te := newTest(t, e) + te.maxMsgSize = 1024 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + argSize := int32(te.maxMsgSize + 1) + const respSize = 1 + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %d", err, codes.Internal) + } + + stream, err := tc.FullDuplexCall(te.ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(1), + }, + } + + spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1)) + if err != nil { + t.Fatal(err) + } + + sreq := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: spayload, + } + if err := stream.Send(sreq); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) + } + if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { + t.Fatalf("%v.Recv() = _, %v, want _, error code: %d", stream, err, codes.Internal) + } +} + func TestMetadataUnaryRPC(t *testing.T) { defer leakCheck(t)() for _, e := range 
listTestEnv() { @@ -797,7 +1152,7 @@ func TestMetadataUnaryRPC(t *testing.T) { func testMetadataUnaryRPC(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -843,7 +1198,7 @@ func TestMalformedHTTP2Metadata(t *testing.T) { func testMalformedHTTP2Metadata(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -879,7 +1234,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup ResponseSize: proto.Int32(respSize), Payload: payload, } - reply, err := tc.UnaryCall(context.Background(), req) + reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false)) if err != nil { t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) return @@ -905,7 +1260,7 @@ func TestRetry(t *testing.T) { func testRetry(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("transport: http2Client.notifyError got notified that the client transport was broken") - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -954,7 +1309,7 @@ func TestRPCTimeout(t *testing.T) { // TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. func testRPCTimeout(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -991,7 +1346,7 @@ func TestCancel(t *testing.T) { func testCancel(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("grpc: the client connection is closing; please retry") - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1029,7 +1384,7 @@ func testCancelNoIO(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("http2Client.notifyError got notified that the client transport was broken") te.maxStream = 1 // Only allows 1 live stream per server transport. - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1099,16 +1454,13 @@ func TestNoService(t *testing.T) { func testNoService(t *testing.T, e env) { te := newTest(t, e) - te.testServer = nil // register nothing - te.startServer() + te.startServer(nil) defer te.tearDown() cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - // Make sure setting ack has been sent. 
- time.Sleep(20 * time.Millisecond) - stream, err := tc.FullDuplexCall(te.ctx) + stream, err := tc.FullDuplexCall(te.ctx, grpc.FailFast(false)) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } @@ -1126,7 +1478,7 @@ func TestPingPong(t *testing.T) { func testPingPong(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1186,7 +1538,7 @@ func TestMetadataStreamingRPC(t *testing.T) { func testMetadataStreamingRPC(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1257,7 +1609,7 @@ func TestServerStreaming(t *testing.T) { func testServerStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1312,7 +1664,7 @@ func TestFailedServerStreaming(t *testing.T) { func testFailedServerStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1331,14 +1683,15 @@ func testFailedServerStreaming(t *testing.T, e env) { if err != nil { t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) } - if _, err := stream.Recv(); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { - t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + wantErr := grpc.Errorf(codes.DataLoss, "got extra metadata") + if _, err := stream.Recv(); !equalErrors(err, wantErr) { + t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, wantErr) } } // concurrentSendServer is a TestServiceServer whose // StreamingOutputCall makes ten serial Send calls, sending payloads -// "0".."9", inclusive. TestServerStreaming_Concurrent verifies they +// "0".."9", inclusive. TestServerStreamingConcurrent verifies they // were received in the correct order, and that there were no races. // // All other TestServiceServer methods crash if called. @@ -1358,17 +1711,16 @@ func (s concurrentSendServer) StreamingOutputCall(args *testpb.StreamingOutputCa } // Tests doing a bunch of concurrent streaming output calls. 
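[Editor's aside, not part of the patch] Most client calls in these tests now pass the new FailFast call option so they wait for a ready connection instead of failing immediately. A minimal sketch of a non-fail-fast unary call; the target address is arbitrary and the testpb import path is assumed to be the generated test package used by this file:

package main

import (
	"log"
	"time"

	"golang.org/x/net/context"
	"google.golang.org/grpc"
	testpb "google.golang.org/grpc/test/grpc_testing"
)

func main() {
	cc, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("Dial failed: %v", err)
	}
	defer cc.Close()
	tc := testpb.NewTestServiceClient(cc)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// Block until a connection is available (or the deadline expires) rather
	// than failing fast while the connection is still being established.
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil {
		log.Printf("EmptyCall: %v", err)
	}
}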
-func TestServerStreaming_Concurrent(t *testing.T) { +func TestServerStreamingConcurrent(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testServerStreaming_Concurrent(t, e) + testServerStreamingConcurrent(t, e) } } -func testServerStreaming_Concurrent(t *testing.T, e env) { +func testServerStreamingConcurrent(t *testing.T, e env) { te := newTest(t, e) - te.testServer = concurrentSendServer{} - te.startServer() + te.startServer(concurrentSendServer{}) defer te.tearDown() cc := te.clientConn() @@ -1426,7 +1778,7 @@ func TestClientStreaming(t *testing.T) { func testClientStreaming(t *testing.T, e env) { te := newTest(t, e) - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1459,6 +1811,49 @@ func testClientStreaming(t *testing.T, e env) { } } +func TestClientStreamingError(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testClientStreamingError(t, e) + } +} + +func testClientStreamingError(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, earlyFail: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + stream, err := tc.StreamingInputCall(te.ctx) + if err != nil { + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want ", tc, err) + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 1) + if err != nil { + t.Fatal(err) + } + + req := &testpb.StreamingInputCallRequest{ + Payload: payload, + } + // The 1st request should go through. + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + for { + if err := stream.Send(req); err != io.EOF { + continue + } + if _, err := stream.CloseAndRecv(); grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.CloseAndRecv() = %v, want error %d", stream, err, codes.NotFound) + } + break + } +} + func TestExceedMaxStreamsLimit(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -1471,10 +1866,10 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { te.declareLogNoise( "http2Client.notifyError got notified that the client transport was broken", "Conn.resetTransport failed to create client transport", - "grpc: the client connection is closing", + "grpc: the connection is closing", ) te.maxStream = 1 // Only allows 1 live stream per server transport. - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1512,10 +1907,10 @@ func testStreamsQuotaRecovery(t *testing.T, e env) { te.declareLogNoise( "http2Client.notifyError got notified that the client transport was broken", "Conn.resetTransport failed to create client transport", - "grpc: the client connection is closing", + "grpc: the connection is closing", ) te.maxStream = 1 // Allows 1 live stream. 
- te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() cc := te.clientConn() @@ -1566,7 +1961,7 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { te := newTest(t, e) te.serverCompression = false te.clientCompression = true - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1622,7 +2017,7 @@ func testCompressOK(t *testing.T, e env) { te := newTest(t, e) te.serverCompression = true te.clientCompression = true - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1685,7 +2080,7 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf func testUnaryServerInterceptor(t *testing.T, e env) { te := newTest(t, e) te.unaryInt = errInjector - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1716,7 +2111,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ func testStreamServerInterceptor(t *testing.T, e env) { te := newTest(t, e) te.streamInt = fullDuplexOnly - te.startServer() + te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) @@ -1771,21 +2166,21 @@ func (s *funcServer) StreamingInputCall(stream testpb.TestService_StreamingInput return s.streamingInputCall(stream) } -func TestClientRequestBodyError_UnexpectedEOF(t *testing.T) { +func TestClientRequestBodyErrorUnexpectedEOF(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testClientRequestBodyError_UnexpectedEOF(t, e) + testClientRequestBodyErrorUnexpectedEOF(t, e) } } -func testClientRequestBodyError_UnexpectedEOF(t *testing.T, e env) { +func testClientRequestBodyErrorUnexpectedEOF(t *testing.T, e env) { te := newTest(t, e) - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { errUnexpectedCall := errors.New("unexpected call func server method") t.Error(errUnexpectedCall) return nil, errUnexpectedCall }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") @@ -1795,22 +2190,22 @@ func testClientRequestBodyError_UnexpectedEOF(t *testing.T, e env) { }) } -func TestClientRequestBodyError_CloseAfterLength(t *testing.T) { +func TestClientRequestBodyErrorCloseAfterLength(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testClientRequestBodyError_CloseAfterLength(t, e) + testClientRequestBodyErrorCloseAfterLength(t, e) } } -func testClientRequestBodyError_CloseAfterLength(t *testing.T, e env) { +func testClientRequestBodyErrorCloseAfterLength(t *testing.T, e env) { te := newTest(t, e) te.declareLogNoise("Server.processUnaryRPC failed to write status") - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { errUnexpectedCall := errors.New("unexpected call func server method") t.Error(errUnexpectedCall) return nil, errUnexpectedCall }} - te.startServer() + te.startServer(ts) 
defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") @@ -1820,27 +2215,27 @@ func testClientRequestBodyError_CloseAfterLength(t *testing.T, e env) { }) } -func TestClientRequestBodyError_Cancel(t *testing.T) { +func TestClientRequestBodyErrorCancel(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testClientRequestBodyError_Cancel(t, e) + testClientRequestBodyErrorCancel(t, e) } } -func testClientRequestBodyError_Cancel(t *testing.T, e env) { +func testClientRequestBodyErrorCancel(t *testing.T, e env) { te := newTest(t, e) gotCall := make(chan bool, 1) - te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { gotCall <- true return new(testpb.SimpleResponse), nil }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") // Say we have 5 bytes coming, but cancel it instead. - st.writeData(1, false, []byte{0, 0, 0, 0, 5}) st.writeRSTStream(1, http2.ErrCodeCancel) + st.writeData(1, false, []byte{0, 0, 0, 0, 5}) // Verify we didn't a call yet. select { @@ -1857,22 +2252,22 @@ func testClientRequestBodyError_Cancel(t *testing.T, e env) { }) } -func TestClientRequestBodyError_Cancel_StreamingInput(t *testing.T) { +func TestClientRequestBodyErrorCancelStreamingInput(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { - testClientRequestBodyError_Cancel_StreamingInput(t, e) + testClientRequestBodyErrorCancelStreamingInput(t, e) } } -func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { +func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { te := newTest(t, e) recvErr := make(chan error, 1) - te.testServer = &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { _, err := stream.Recv() recvErr <- err return nil }} - te.startServer() + te.startServer(ts) defer te.tearDown() te.withServerTester(func(st *serverTester) { st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall") @@ -1887,11 +2282,111 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { t.Fatal("timeout waiting for error") } if se, ok := got.(transport.StreamError); !ok || se.Code != codes.Canceled { - t.Errorf("error = %#v; want transport.StreamError with code Canceled") + t.Errorf("error = %#v; want transport.StreamError with code Canceled", got) } }) } +const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" + +var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg) + +type clientAlwaysFailCred struct{} + +func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errClientAlwaysFailCred +} +func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + +func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: 
"clientAlwaysFailCred", balancer: true}) + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + var ( + err error + opts []grpc.DialOption + ) + opts = append(opts, grpc.WithTransportCredentials(clientAlwaysFailCred{}), grpc.WithBlock()) + te.cc, err = grpc.Dial(te.srvAddr, opts...) + if err != errClientAlwaysFailCred { + te.t.Fatalf("Dial(%q) = %v, want %v", te.srvAddr, err, errClientAlwaysFailCred) + } +} + +func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) + } +} + +func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) + } +} + +func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) + te.startServer(&testServer{security: "clientAlwaysFailCred"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); !strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) + } +} + +type clientTimeoutCreds struct { + timeoutReturned bool +} + +func (c *clientTimeoutCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if !c.timeoutReturned { + c.timeoutReturned = true + return nil, nil, context.DeadlineExceeded + } + return rawConn, nil, nil +} +func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} + +func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { + te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false}) + te.userAgent = testAppUA + te.startServer(&testServer{security: "clientTimeoutCreds"}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + // This unary call should succeed, because ClientHandshake will succeed for the second time. + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want ", err) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. 
It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { @@ -1909,6 +2404,7 @@ func interestingGoroutines() (gs []string) { if stack == "" || strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "testing.tRunner(") || strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") || strings.Contains(stack, "created by google3/base/go/log.init") || @@ -1936,8 +2432,8 @@ func leakCheck(t testing.TB) func() { } return func() { // Loop, waiting for goroutines to shut down. - // Wait up to 5 seconds, but finish as quickly as possible. - deadline := time.Now().Add(5 * time.Second) + // Wait up to 10 seconds, but finish as quickly as possible. + deadline := time.Now().Add(10 * time.Second) for { var leaked []string for _, g := range interestingGoroutines() { @@ -2082,3 +2578,7 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) { } return fw.dst.Write(p) } + +func equalErrors(l, r error) bool { + return grpc.Code(l) == grpc.Code(r) && grpc.ErrorDesc(l) == grpc.ErrorDesc(r) +} diff --git a/transport/control.go b/transport/control.go index 7e9bdf33..4ef0830b 100644 --- a/transport/control.go +++ b/transport/control.go @@ -72,6 +72,11 @@ type resetStream struct { func (*resetStream) item() {} +type goAway struct { +} + +func (*goAway) item() {} + type flushIO struct { } diff --git a/transport/go16.go b/transport/go16.go new file mode 100644 index 00000000..ee1c46ba --- /dev/null +++ b/transport/go16.go @@ -0,0 +1,46 @@ +// +build go1.6,!go1.7 + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + + "golang.org/x/net/context" +) + +// dialContext connects to the address on the named network. 
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) +} diff --git a/transport/go17.go b/transport/go17.go new file mode 100644 index 00000000..356f13ff --- /dev/null +++ b/transport/go17.go @@ -0,0 +1,46 @@ +// +build go1.7 + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + + "golang.org/x/net/context" +) + +// dialContext connects to the address on the named network. 
+func dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, address) +} diff --git a/transport/handler_server.go b/transport/handler_server.go index 00d3855f..30e21ac0 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } if v := r.Header.Get("grpc-timeout"); v != "" { - to, err := timeoutDecode(v) + to, err := decodeTimeout(v) if err != nil { return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) } @@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, h := ht.rw.Header() h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) if statusDesc != "" { - h.Set("Grpc-Message", statusDesc) + h.Set("Grpc-Message", encodeGrpcMessage(statusDesc)) } if md := s.Trailer(); len(md) > 0 { for k, vv := range md { @@ -312,7 +312,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { Addr: ht.RemoteAddr(), } if req.TLS != nil { - pr.AuthInfo = credentials.TLSInfo{*req.TLS} + pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} } ctx = metadata.NewContext(ctx, ht.headerMD) ctx = peer.NewContext(ctx, pr) @@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() { } } +func (ht *serverHandlerTransport) Drain() { + panic("Drain() is not implemented") +} + // mapRecvMsgError returns the non-nil err into the appropriate // error value as expected by callers of *grpc.parser.recvMsg. // In particular, in can only be: diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 1fee72ff..84fc917f 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message"}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, - "Grpc-Message": {msg}, + "Grpc-Message": {encodeGrpcMessage(msg)}, } if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) @@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message"}, "Grpc-Status": {"4"}, - "Grpc-Message": {"too slow"}, + "Grpc-Message": {encodeGrpcMessage("too slow")}, } if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) diff --git a/transport/http2_client.go b/transport/http2_client.go index 227686d4..5e7c8c25 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -35,6 +35,7 @@ package transport import ( "bytes" + "fmt" "io" "math" "net" @@ -71,6 +72,9 @@ type http2Client struct { shutdownChan chan struct{} // errorChan is closed to notify the I/O error to the caller. errorChan chan struct{} + // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) + // that the server sent GoAway on this transport. + goAway chan struct{} framer *framer hBuf *bytes.Buffer // the buffer for HPACK encoding @@ -97,41 +101,49 @@ type http2Client struct { maxStreams int // the per-stream outbound flow control window size set by the peer. streamSendQuota uint32 + // goAwayID records the Last-Stream-ID in the GoAway frame from the server. 
+ goAwayID uint32 + // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. + prevGoAwayID uint32 +} + +func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) { + if fn != nil { + return fn(ctx, addr) + } + return dialContext(ctx, "tcp", addr) } // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { - if opts.Dialer == nil { - // Set the default Dialer. - opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) - } - } +func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" - startT := time.Now() - timeout := opts.Timeout - conn, connErr := opts.Dialer(addr, timeout) + conn, connErr := dial(opts.Dialer, ctx, addr) if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr) } - var authInfo credentials.AuthInfo - if opts.TransportCredentials != nil { - scheme = "https" - if timeout > 0 { - timeout -= time.Since(startT) - } - conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout) - } - if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) - } - defer func() { + // Any further errors will close the underlying connection + defer func(conn net.Conn) { if err != nil { conn.Close() } - }() + }(conn) + var authInfo credentials.AuthInfo + if creds := opts.TransportCredentials; creds != nil { + scheme = "https" + conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn) + } + if connErr != nil { + // Credentials handshake error is not a temporary error (unless the error + // was the connection closing or deadline exceeded). 
+ var temp bool + switch connErr { + case io.EOF, context.DeadlineExceeded: + temp = true + } + return nil, ConnectionErrorf(temp, connErr, "transport: %v", connErr) + } ua := primaryUA if opts.UserAgent != "" { ua = opts.UserAgent + " " + ua @@ -147,6 +159,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e writableChan: make(chan int, 1), shutdownChan: make(chan struct{}), errorChan: make(chan struct{}), + goAway: make(chan struct{}), framer: newFramer(conn), hBuf: &buf, hEnc: hpack.NewEncoder(&buf), @@ -168,26 +181,29 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { - err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + err = t.framer.writeSettings(true, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(initialWindowSize), + }) } else { err = t.framer.writeSettings(true) } if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } go t.controller() @@ -199,6 +215,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ id: t.nextID, + done: make(chan struct{}), + goAway: make(chan struct{}), method: callHdr.Method, sendCompress: callHdr.SendCompress, buf: newRecvBuffer(), @@ -213,8 +231,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // Make a stream be able to cancel the pending operations by itself. 
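ClientHandshake now receives the dial context, matching the signature the test credentials earlier in this diff implement. A minimal sketch of a credentials type honoring that context (deadlineAwareCreds is a hypothetical name, and the surrounding package's imports of net, context and credentials are assumed):

type deadlineAwareCreds struct{}

func (deadlineAwareCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	// Respect the dial context: give up immediately if it is already done.
	select {
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	default:
	}
	return rawConn, nil, nil
}

func (deadlineAwareCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	return rawConn, nil, nil
}

func (deadlineAwareCreds) Info() credentials.ProtocolInfo {
	return credentials.ProtocolInfo{}
}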
s.ctx, s.cancel = context.WithCancel(ctx) s.dec = &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, + ctx: s.ctx, + goAway: s.goAway, + recv: s.buf, } return s } @@ -268,6 +287,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.mu.Unlock() return nil, ErrConnClosing } + if t.state == draining { + t.mu.Unlock() + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -275,7 +298,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea checkStreamsQuota := t.streamsQuota != nil t.mu.Unlock() if checkStreamsQuota { - sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) + sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire()) if err != nil { return nil, err } @@ -284,7 +307,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.streamsQuota.add(sq - 1) } } - if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { // Return the quota back now because there is no stream returned to the caller. if _, ok := err.(StreamError); ok && checkStreamsQuota { t.streamsQuota.add(1) @@ -292,6 +315,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, err } t.mu.Lock() + if t.state == draining { + t.mu.Unlock() + if checkStreamsQuota { + t.streamsQuota.add(1) + } + // Need to make t writable again so that the rpc in flight can still proceed. + t.writableChan <- 0 + return nil, ErrStreamDrain + } if t.state != reachable { t.mu.Unlock() return nil, ErrConnClosing @@ -326,7 +358,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) } if timeout > 0 { - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) } for k, v := range authData { // Capital header names are illegal in HTTP/2. @@ -381,7 +413,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } t.writableChan <- 0 @@ -400,22 +432,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) { if t.streamsQuota != nil { updateStreams = true } - if t.state == draining && len(t.activeStreams) == 1 { + delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { // The transport is draining and s is the last live stream on t. t.mu.Unlock() t.Close() return } - delete(t.activeStreams, s.id) t.mu.Unlock() if updateStreams { t.streamsQuota.add(1) } - // In case stream sending and receiving are invoked in separate - // goroutines (e.g., bi-directional streaming), the caller needs - // to call cancel on the stream to interrupt the blocking on - // other goroutines. - s.cancel() s.mu.Lock() if q := s.fc.resetPendingData(); q > 0 { if n := t.fc.onRead(q); n > 0 { @@ -442,13 +469,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) { // accessed any more. 
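With the draining state, NewStream now fails with ErrStreamDrain rather than ErrConnClosing when the server has announced GOAWAY. A hedged sketch of a caller distinguishing the two (newStreamWithDrainCheck is hypothetical; the real retry logic lives in gRPC internals outside this diff):

func newStreamWithDrainCheck(ctx context.Context, ct ClientTransport, callHdr *CallHdr) (*Stream, error) {
	s, err := ct.NewStream(ctx, callHdr)
	switch err {
	case nil:
		return s, nil
	case ErrStreamDrain:
		// The server sent GOAWAY: this transport accepts no new RPCs, but the
		// request was never seen by the server and may be retried elsewhere.
		return nil, err
	case ErrConnClosing:
		// The transport is going away entirely.
		return nil, err
	default:
		return nil, err
	}
}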
func (t *http2Client) Close() (err error) { t.mu.Lock() - if t.state == reachable { - close(t.errorChan) - } if t.state == closing { t.mu.Unlock() return } + if t.state == reachable || t.state == draining { + close(t.errorChan) + } t.state = closing t.mu.Unlock() close(t.shutdownChan) @@ -472,10 +499,35 @@ func (t *http2Client) Close() (err error) { func (t *http2Client) GracefulClose() error { t.mu.Lock() - if t.state == closing { + switch t.state { + case unreachable: + // The server may close the connection concurrently. t is not available for + // any streams. Close it now. + t.mu.Unlock() + t.Close() + return nil + case closing: t.mu.Unlock() return nil } + // Notify the streams which were initiated after the server sent GOAWAY. + select { + case <-t.goAway: + n := t.prevGoAwayID + if n == 0 && t.nextID > 1 { + n = t.nextID - 2 + } + m := t.goAwayID + 2 + if m == 2 { + m = 1 + } + for i := m; i <= n; i += 2 { + if s, ok := t.activeStreams[i]; ok { + close(s.goAway) + } + } + default: + } if t.state == draining { t.mu.Unlock() return nil @@ -501,15 +553,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. - tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { - if _, ok := err.(StreamError); ok { + if _, ok := err.(StreamError); ok || err == io.EOF { t.sendQuotaPool.cancel() } return err @@ -541,8 +593,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // Indicate there is a writer who is about to write a data frame. t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the transport. - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { - if _, ok := err.(StreamError); ok { + if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil { + if _, ok := err.(StreamError); ok || err == io.EOF { // Return the connection quota back. t.sendQuotaPool.add(len(p)) } @@ -575,7 +627,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -590,11 +642,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Lock() if s.state != streamDone { - if s.state == streamReadDone { - s.state = streamDone - } else { - s.state = streamWriteDone - } + s.state = streamWriteDone } s.mu.Unlock() return nil @@ -627,7 +675,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf("%v", err)) + t.notifyError(ConnectionErrorf(true, err, "%v", err)) return } // Select the right stream to dispatch. 
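GracefulClose and the new goAway channel are meant to be driven by the transport's owner. A rough sketch of such a monitor loop (monitorTransport is hypothetical; the real owner is gRPC's connection code, which is not part of these hunks):

// monitorTransport reacts to the server draining or the connection breaking.
func monitorTransport(t ClientTransport) {
	select {
	case <-t.GoAway():
		// Server is draining: stop issuing new RPCs on t and let the
		// in-flight streams finish before the transport closes itself.
		t.GracefulClose()
	case <-t.Error():
		// The transport is broken; tear it down and reconnect elsewhere.
		t.Close()
	}
}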
@@ -652,6 +700,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { s.state = streamDone s.statusCode = codes.Internal s.statusDesc = err.Error() + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) @@ -669,13 +718,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) { // the read direction is closed, and set the status appropriately. if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { s.mu.Lock() - if s.state == streamWriteDone { - s.state = streamDone - } else { - s.state = streamReadDone + if s.state == streamDone { + s.mu.Unlock() + return } + s.state = streamDone s.statusCode = codes.Internal s.statusDesc = "server closed the stream without sending trailers" + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -701,6 +751,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) s.statusCode = codes.Unknown } + s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode) + close(s.done) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -725,7 +777,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) { } func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { - // TODO(zhaoq): GoAwayFrame handler to be implemented + t.mu.Lock() + if t.state == reachable || t.state == draining { + if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { + t.mu.Unlock() + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + return + } + select { + case <-t.goAway: + id := t.goAwayID + // t.goAway has been closed (i.e.,multiple GoAways). + if id < f.LastStreamID { + t.mu.Unlock() + t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + return + } + t.prevGoAwayID = id + t.goAwayID = f.LastStreamID + t.mu.Unlock() + return + default: + } + t.goAwayID = f.LastStreamID + close(t.goAway) + } + t.mu.Unlock() } func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { @@ -777,11 +854,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if len(state.mdata) > 0 { s.trailer = state.mdata } - s.state = streamDone s.statusCode = state.statusCode s.statusDesc = state.statusDesc + close(s.done) + s.state = streamDone s.mu.Unlock() - s.write(recvMsg{err: io.EOF}) } @@ -934,13 +1011,22 @@ func (t *http2Client) Error() <-chan struct{} { return t.errorChan } +func (t *http2Client) GoAway() <-chan struct{} { + return t.goAway +} + func (t *http2Client) notifyError(err error) { t.mu.Lock() - defer t.mu.Unlock() // make sure t.errorChan is closed only once. 
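handleGoAway above only records the GOAWAY IDs; notifying the affected streams happens in GracefulClose. As a worked example of that arithmetic: if the client has opened streams 1 through 9 (so t.nextID is 11) and the server's first GOAWAY carries LastStreamID 5, GracefulClose computes m = 5+2 = 7 and n = 11-2 = 9 and closes the goAway channel of streams 7 and 9, i.e. exactly the streams the server will never process and which can safely be retried.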
+ if t.state == draining { + t.mu.Unlock() + t.Close() + return + } if t.state == reachable { t.state = unreachable close(t.errorChan) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) } + t.mu.Unlock() } diff --git a/transport/http2_server.go b/transport/http2_server.go index 1c4d5852..16010d55 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -100,18 +100,23 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI if maxStreams == 0 { maxStreams = math.MaxUint32 } else { - settings = append(settings, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) + settings = append(settings, http2.Setting{ + ID: http2.SettingMaxConcurrentStreams, + Val: maxStreams, + }) } if initialWindowSize != defaultWindowSize { - settings = append(settings, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + settings = append(settings, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, err, "transport: %v", err) } } var buf bytes.Buffer @@ -137,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI } // operateHeader takes action on the decoded headers. -func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, @@ -200,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) return } + if s.id%2 != 1 || s.id <= t.maxStreamID { + t.mu.Unlock() + // illegal gRPC stream id. + grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id) + return true + } + t.maxStreamID = s.id s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) t.activeStreams[s.id] = s t.mu.Unlock() @@ -207,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.updateWindow(s, uint32(n)) } handle(s) + return } // HandleStreams receives incoming streams using the given handler. This is @@ -226,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } frame, err := t.framer.readFrame() + if err == io.EOF || err == io.ErrUnexpectedEOF { + t.Close() + return + } if err != nil { grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) t.Close() @@ -252,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { t.controlBuf.put(&resetStream{se.StreamID, se.Code}) continue } + if err == io.EOF || err == io.ErrUnexpectedEOF { + t.Close() + return + } + grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) t.Close() return } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - id := frame.Header().StreamID - if id%2 != 1 || id <= t.maxStreamID { - // illegal gRPC stream id. 
- grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) + if t.operateHeaders(frame, handle) { t.Close() break } - t.maxStreamID = id - t.operateHeaders(frame, handle) case *http2.DataFrame: t.handleData(frame) case *http2.RSTStreamFrame: @@ -277,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { case *http2.WindowUpdateFrame: t.handleWindowUpdate(frame) case *http2.GoAwayFrame: - break + // TODO: Handle GoAway from the client appropriately. default: grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) } @@ -359,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { // Received the end of stream from the client. s.mu.Lock() if s.state != streamDone { - if s.state == streamWriteDone { - s.state = streamDone - } else { - s.state = streamReadDone - } + s.state = streamReadDone } s.mu.Unlock() s.write(recvMsg{err: io.EOF}) @@ -435,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e } if err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } } return nil @@ -450,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { } s.headerOk = true s.mu.Unlock() - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -490,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s headersSent = true } s.mu.Unlock() - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -503,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s Name: "grpc-status", Value: strconv.Itoa(int(statusCode)), }) - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)}) // Attach the trailer metadata. for k, v := range s.trailer { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. @@ -539,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } s.mu.Unlock() if writeHeaderFrame { - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } t.hBuf.Reset() @@ -555,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeHeaders(false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } t.writableChan <- 0 } @@ -567,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { size := http2MaxFrameLen s.sendQuotaPool.add(0) // Wait until the stream has some quota to send the data. - sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire()) if err != nil { return err } t.sendQuotaPool.add(0) // Wait until the transport has some quota to send the data. 
- tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire()) if err != nil { if _, ok := err.(StreamError); ok { t.sendQuotaPool.cancel() @@ -599,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { t.framer.adjustNumWriters(1) // Got some quota. Try to acquire writing privilege on the // transport. - if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { if _, ok := err.(StreamError); ok { // Return the connection quota back. t.sendQuotaPool.add(ps) @@ -629,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -674,6 +687,17 @@ func (t *http2Server) controller() { } case *resetStream: t.framer.writeRSTStream(true, i.streamID, i.code) + case *goAway: + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + // The transport is closing. + return + } + sid := t.maxStreamID + t.state = draining + t.mu.Unlock() + t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil) case *flushIO: t.framer.flushWrite() case *ping: @@ -719,6 +743,9 @@ func (t *http2Server) Close() (err error) { func (t *http2Server) closeStream(s *Stream) { t.mu.Lock() delete(t.activeStreams, s.id) + if t.state == draining && len(t.activeStreams) == 0 { + defer t.Close() + } t.mu.Unlock() // In case stream sending and receiving are invoked in separate // goroutines (e.g., bi-directional streaming), cancel needs to be @@ -741,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) RemoteAddr() net.Addr { return t.conn.RemoteAddr() } + +func (t *http2Server) Drain() { + t.controlBuf.put(&goAway{}) +} diff --git a/transport/http_util.go b/transport/http_util.go index f2e23dce..79da5126 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -35,6 +35,7 @@ package transport import ( "bufio" + "bytes" "fmt" "io" "net" @@ -52,7 +53,7 @@ import ( const ( // The primary user agent - primaryUA = "grpc-go/0.11" + primaryUA = "grpc-go/1.0" // http2MaxFrameLen specifies the max length of a HTTP2 frame. http2MaxFrameLen = 16384 // 16KB frame // http://http2.github.io/http2-spec/#SettingValues @@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { } d.statusCode = codes.Code(code) case "grpc-message": - d.statusDesc = f.Value + d.statusDesc = decodeGrpcMessage(f.Value) case "grpc-timeout": d.timeoutSet = true var err error - d.timeout, err = timeoutDecode(f.Value) + d.timeout, err = decodeTimeout(f.Value) if err != nil { d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) return @@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 { } // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. 
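The timeout helpers renamed just below keep the grpc-timeout wire format of an ASCII integer plus a unit suffix. A small round-trip sketch (timeoutHeaderRoundTrip is a hypothetical helper; the time import of the surrounding package is assumed):

func timeoutHeaderRoundTrip() (time.Duration, error) {
	// encodeTimeout renders a duration as digits followed by a unit
	// character (n, u, m, S, M or H); decodeTimeout reverses it, possibly
	// losing precision when a coarser unit had to be chosen.
	wire := encodeTimeout(250 * time.Millisecond)
	return decodeTimeout(wire)
}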
-func timeoutEncode(t time.Duration) string { +func encodeTimeout(t time.Duration) string { if d := div(t, time.Nanosecond); d <= maxTimeoutValue { return strconv.FormatInt(d, 10) + "n" } @@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string { return strconv.FormatInt(div(t, time.Hour), 10) + "H" } -func timeoutDecode(s string) (time.Duration, error) { +func decodeTimeout(s string) (time.Duration, error) { size := len(s) if size < 2 { return 0, fmt.Errorf("transport: timeout string is too short: %q", s) @@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) { return d * time.Duration(t), nil } +const ( + spaceByte = ' ' + tildaByte = '~' + percentByte = '%' +) + +// encodeGrpcMessage is used to encode status code in header field +// "grpc-message". +// It checks to see if each individual byte in msg is an +// allowable byte, and then either percent encoding or passing it through. +// When percent encoding, the byte is converted into hexadecimal notation +// with a '%' prepended. +func encodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if !(c >= spaceByte && c < tildaByte && c != percentByte) { + return encodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func encodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c >= spaceByte && c < tildaByte && c != percentByte { + buf.WriteByte(c) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", c)) + } + } + return buf.String() +} + +// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage. +func decodeGrpcMessage(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + if msg[i] == percentByte && i+2 < lenMsg { + return decodeGrpcMessageUnchecked(msg) + } + } + return msg +} + +func decodeGrpcMessageUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c == percentByte && i+2 < lenMsg { + parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8) + if err != nil { + buf.WriteByte(c) + } else { + buf.WriteByte(byte(parsed)) + i += 2 + } + } else { + buf.WriteByte(c) + } + } + return buf.String() +} + type framer struct { numWriters int32 reader io.Reader diff --git a/transport/http_util_test.go b/transport/http_util_test.go index 279acbc5..41bf5477 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) { if err != nil { t.Fatalf("failed to parse duration string %s: %v", test.in, err) } - out := timeoutEncode(d) + out := encodeTimeout(d) if out != test.out { t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) } @@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) { {"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")}, {"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")}, } { - d, err := timeoutDecode(test.s) + d, err := decodeTimeout(test.s) if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) { t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err) } @@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) { } } } + +func TestEncodeGrpcMessage(t *testing.T) { + for _, tt := range []struct { + input string + expected string + }{ + {"", ""}, + {"Hello", "Hello"}, + {"my favorite character is \u0000", "my favorite character is %00"}, + {"my 
favorite character is %", "my favorite character is %25"}, + } { + actual := encodeGrpcMessage(tt.input) + if tt.expected != actual { + t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected) + } + } +} + +func TestDecodeGrpcMessage(t *testing.T) { + for _, tt := range []struct { + input string + expected string + }{ + {"", ""}, + {"Hello", "Hello"}, + {"H%61o", "Hao"}, + {"H%6", "H%6"}, + {"%G0", "%G0"}, + } { + actual := decodeGrpcMessage(tt.input) + if tt.expected != actual { + t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected) + } + } +} diff --git a/transport/pre_go16.go b/transport/pre_go16.go new file mode 100644 index 00000000..33d91c17 --- /dev/null +++ b/transport/pre_go16.go @@ -0,0 +1,51 @@ +// +build !go1.6 + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "net" + "time" + + "golang.org/x/net/context" +) + +// dialContext connects to the address on the named network. +func dialContext(ctx context.Context, network, address string) (net.Conn, error) { + var dialer net.Dialer + if deadline, ok := ctx.Deadline(); ok { + dialer.Timeout = deadline.Sub(time.Now()) + } + return dialer.Dial(network, address) +} diff --git a/transport/transport.go b/transport/transport.go index d4c220a0..d59e5113 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -44,7 +44,6 @@ import ( "io" "net" "sync" - "time" "golang.org/x/net/context" "golang.org/x/net/trace" @@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item { // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - ctx context.Context - recv *recvBuffer - last *bytes.Reader // Stores the remaining data in the previous calls. - err error + ctx context.Context + goAway chan struct{} + recv *recvBuffer + last *bytes.Reader // Stores the remaining data in the previous calls. + err error } // Read reads the next len(p) bytes from last. 
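A short usage sketch of the percent-encoding helpers added above (grpcMessageRoundTrip is only an illustrative wrapper): bytes outside the printable ASCII range, and '%' itself, are emitted as %XX so that an arbitrary status description survives the HTTP/2 header field.

func grpcMessageRoundTrip() string {
	// The raw message contains a newline, which the framer would reject;
	// encodeGrpcMessage turns it into %0A and decodeGrpcMessage restores it.
	encoded := encodeGrpcMessage("proto: field \"name\" is required\n")
	return decodeGrpcMessage(encoded)
}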
If last is drained, it tries to @@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): return 0, ContextErr(r.ctx.Err()) + case <-r.goAway: + return 0, ErrStreamDrain case i := <-r.recv.get(): r.recv.load() m := i.(*recvMsg) @@ -158,7 +160,7 @@ const ( streamActive streamState = iota streamWriteDone // EndStream sent streamReadDone // EndStream received - streamDone // sendDone and recvDone or RSTStreamFrame is sent or received. + streamDone // the entire stream is finished. ) // Stream represents an RPC in the transport layer. @@ -169,6 +171,10 @@ type Stream struct { // ctx is the associated context of the stream. ctx context.Context cancel context.CancelFunc + // done is closed when the final status arrives. + done chan struct{} + // goAway is closed when the server sent GoAways signal before this stream was initiated. + goAway chan struct{} // method records the associated RPC method of the stream. method string recvCompress string @@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) { s.sendCompress = str } +// Done returns a chanel which is closed when it receives the final status +// from the server. +func (s *Stream) Done() <-chan struct{} { + return s.done +} + +// GoAway returns a channel which is closed when the server sent GoAways signal +// before this stream was initiated. +func (s *Stream) GoAway() <-chan struct{} { + return s.goAway +} + // Header acquires the key-value pairs of header metadata once it // is available. It blocks until i) the metadata is ready or ii) there is no // header metadata or iii) the stream is cancelled/expired. @@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) { select { case <-s.ctx.Done(): return nil, ContextErr(s.ctx.Err()) + case <-s.goAway: + return nil, ErrStreamDrain case <-s.headerChan: return s.header.Copy(), nil } @@ -335,19 +355,17 @@ type ConnectOptions struct { // UserAgent is the application user agent. UserAgent string // Dialer specifies how to dial a network address. - Dialer func(string, time.Duration) (net.Conn, error) + Dialer func(context.Context, string) (net.Conn, error) // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials - // Timeout specifies the timeout for dialing a ClientTransport. - Timeout time.Duration } // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { - return newHTTP2Client(target, opts) +func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) { + return newHTTP2Client(ctx, target, opts) } // Options provides additional hints and information for message @@ -417,6 +435,11 @@ type ClientTransport interface { // and create a new one) in error case. It should not return nil // once the transport is initiated. Error() <-chan struct{} + + // GoAway returns a channel that is closed when ClientTranspor + // receives the draining signal from the server (e.g., GOAWAY frame in + // HTTP/2). + GoAway() <-chan struct{} } // ServerTransport is the common interface for all gRPC server-side transport @@ -448,6 +471,9 @@ type ServerTransport interface { // RemoteAddr returns the remote network address. 
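With ConnectOptions.Timeout removed, dial cancellation now flows through the context passed to NewClientTransport, and a custom Dialer receives that same context. A hedged sketch of the new call shape (dialTransportSketch and the 20-second bound are illustrative; net, time and context imports are assumed):

func dialTransportSketch(addr string) (ClientTransport, error) {
	// The context bounds connection establishment: dialing plus the
	// credentials handshake, replacing the old Timeout field.
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()
	return NewClientTransport(ctx, addr, ConnectOptions{
		Dialer: func(ctx context.Context, addr string) (net.Conn, error) {
			// Falls back to the per-Go-version dialContext added in this diff.
			return dialContext(ctx, "tcp", addr)
		},
	})
}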
RemoteAddr() net.Addr + + // Drain notifies the client this ServerTransport stops accepting new RPCs. + Drain() } // StreamErrorf creates an StreamError with the specified error code and description. @@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } // ConnectionErrorf creates an ConnectionError with the specified error description. -func ConnectionErrorf(format string, a ...interface{}) ConnectionError { +func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { return ConnectionError{ Desc: fmt.Sprintf(format, a...), + temp: temp, + err: e, } } @@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError { // entire connection and the retry of all the active streams. type ConnectionError struct { Desc string + temp bool + err error } func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } -// ErrConnClosing indicates that the transport is closing. -var ErrConnClosing = ConnectionError{Desc: "transport is closing"} +// Temporary indicates if this connection error is temporary or fatal. +func (e ConnectionError) Temporary() bool { + return e.temp +} + +// Origin returns the original error of this connection error. +func (e ConnectionError) Origin() error { + // Never return nil error here. + // If the original error is nil, return itself. + if e.err == nil { + return e + } + return e.err +} + +var ( + // ErrConnClosing indicates that the transport is closing. + ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} + // ErrStreamDrain indicates that the stream is rejected by the server because + // the server stops accepting new RPCs. + ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs") +) // StreamError is an error that only affects one stream within a connection. type StreamError struct { @@ -501,12 +551,25 @@ func ContextErr(err error) StreamError { // wait blocks until it can receive from ctx.Done, closing, or proceed. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. +// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise +// it return the StreamError for ctx.Err. +// If it receives from goAway, it returns 0, ErrStreamDrain. // If it receives from closing, it returns 0, ErrConnClosing. // If it receives from proceed, it returns the received integer, nil. -func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { +func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) { select { case <-ctx.Done(): return 0, ContextErr(ctx.Err()) + case <-done: + // User cancellation has precedence. 
+ select { + case <-ctx.Done(): + return 0, ContextErr(ctx.Err()) + default: + } + return 0, io.EOF + case <-goAway: + return 0, ErrStreamDrain case <-closing: return 0, ErrConnClosing case i := <-proceed: diff --git a/transport/transport_test.go b/transport/transport_test.go index 6ebec452..2c2f1f66 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -39,7 +39,6 @@ import ( "io" "math" "net" - "reflect" "strconv" "sync" "testing" @@ -75,7 +74,7 @@ const ( normal hType = iota suspended misbehaved - malformedStatus + encodingRequiredStatus ) func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { @@ -111,27 +110,34 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { if !ok { t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) } - size := 1 - if s.Method() == "foo.MaxFrame" { - size = http2MaxFrameLen - } - // Drain the client side stream flow control window. var sent int - for sent <= initialWindowSize { + p := make([]byte, http2MaxFrameLen) + for sent < initialWindowSize { <-conn.writableChan - if err := conn.framer.writeData(true, s.id, false, make([]byte, size)); err != nil { + n := initialWindowSize - sent + // The last message may be smaller than http2MaxFrameLen + if n <= http2MaxFrameLen { + if s.Method() == "foo.Connection" { + // Violate connection level flow control window of client but do not + // violate any stream level windows. + p = make([]byte, n) + } else { + // Violate stream level flow control window of client. + p = make([]byte, n+1) + } + } + if err := conn.framer.writeData(true, s.id, false, p); err != nil { conn.writableChan <- 0 break } conn.writableChan <- 0 - sent += size + sent += len(p) } } -func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) { - // raw newline is not accepted by http2 framer and a http2.StreamError is - // generated. - h.t.WriteStatus(s, codes.Internal, "\n") +func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) { + // raw newline is not accepted by http2 framer so it must be encoded. + h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) } // start starts server. Other goroutines should block on s.readyChan for further operations. 
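The Stream.Done/Stream.GoAway accessors and the Temporary/Origin methods on ConnectionError added above give callers enough to classify failures. Two hypothetical helpers as a sketch (classifyConnError and waitForStreamEnd are not part of this change):

// classifyConnError reports whether the error is a temporary connection
// error worth retrying on a new connection, and exposes the underlying
// cause via Origin (which never returns nil).
func classifyConnError(err error) (retry bool, cause error) {
	if ce, ok := err.(ConnectionError); ok {
		return ce.Temporary(), ce.Origin()
	}
	return false, err
}

// waitForStreamEnd blocks until the stream has a final status or the server
// asked, via GOAWAY, for it to be retried elsewhere.
func waitForStreamEnd(s *Stream) error {
	select {
	case <-s.Done():
		return nil
	case <-s.GoAway():
		return ErrStreamDrain
	}
}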
@@ -179,9 +185,9 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { go transport.HandleStreams(func(s *Stream) { go h.handleStreamMisbehave(t, s) }) - case malformedStatus: + case encodingRequiredStatus: go transport.HandleStreams(func(s *Stream) { - go h.handleStreamMalformedStatus(t, s) + go h.handleStreamEncodingRequiredStatus(t, s) }) default: go transport.HandleStreams(func(s *Stream) { @@ -221,7 +227,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client ct ClientTransport connErr error ) - ct, connErr = NewClientTransport(addr, &ConnectOptions{}) + ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{}) if connErr != nil { t.Fatalf("failed to create transport: %v", connErr) } @@ -252,7 +258,7 @@ func TestClientSendAndReceive(t *testing.T) { Last: true, Delay: false, } - if err := ct.Write(s1, expectedRequest, &opts); err != nil { + if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("failed to send data: %v", err) } p := make([]byte, len(expectedResponse)) @@ -289,7 +295,7 @@ func performOneRPC(ct ClientTransport) { Last: true, Delay: false, } - if err := ct.Write(s, expectedRequest, &opts); err == nil { + if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF { time.Sleep(5 * time.Millisecond) // The following s.Recv()'s could error out because the // underlying transport is gone. @@ -333,7 +339,7 @@ func TestLargeMessage(t *testing.T) { if err != nil { t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) } - if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { + if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponseLarge)) @@ -369,8 +375,8 @@ func TestGracefulClose(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) + if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { + t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) } }() } @@ -379,7 +385,7 @@ func TestGracefulClose(t *testing.T) { Delay: false, } // The stream which was created before graceful close can still proceed. - if err := ct.Write(s, expectedRequest, &opts); err != nil { + if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { t.Fatalf("%v.Write(_, _, _) = %v, want ", ct, err) } p := make([]byte, len(expectedResponse)) @@ -409,7 +415,7 @@ func TestLargeMessageSuspension(t *testing.T) { // Write should not be done successfully due to flow control. 
 	err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false})
 	expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
-	if err == nil || err != expectedErr {
+	if err != expectedErr {
 		t.Fatalf("Write got %v, want %v", err, expectedErr)
 	}
 	ct.Close()
@@ -433,14 +439,21 @@ func TestMaxStreams(t *testing.T) {
 	}
 	done := make(chan struct{})
 	ch := make(chan int)
+	ready := make(chan struct{})
 	go func() {
 		for {
 			select {
 			case <-time.After(5 * time.Millisecond):
-				ch <- 0
+				select {
+				case ch <- 0:
+				case <-ready:
+					return
+				}
 			case <-time.After(5 * time.Second):
 				close(done)
 				return
+			case <-ready:
+				return
 			}
 		}
 	}()
@@ -467,6 +480,7 @@ func TestMaxStreams(t *testing.T) {
 		}
 		cc.mu.Unlock()
 	}
+	close(ready)
 	// Close the pending stream so that the streams quota becomes available for the next new stream.
 	ct.CloseStream(s, nil)
 	select {
@@ -546,6 +560,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) {
 	case <-time.After(5 * time.Second):
 		t.Fatalf("Failed to cancel the context of the sever side stream.")
 	}
+	server.stop()
 }
 
 func TestServerWithMisbehavedClient(t *testing.T) {
@@ -652,7 +667,7 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 	server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
 	callHdr := &CallHdr{
 		Host:   "localhost",
-		Method: "foo",
+		Method: "foo.Stream",
 	}
 	conn, ok := ct.(*http2Client)
 	if !ok {
@@ -663,7 +678,8 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to open stream: %v", err)
 	}
-	if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
+	d := make([]byte, 1)
+	if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
 		t.Fatalf("Failed to write: %v", err)
 	}
 	// Read without window update.
@@ -685,17 +701,15 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 	}
 	// Test the logic for the violation of the connection flow control window size restriction.
 	//
-	// Generate enough streams to drain the connection window.
-	callHdr = &CallHdr{
-		Host:   "localhost",
-		Method: "foo.MaxFrame",
-	}
+	// Generate enough streams to drain the connection window. Make the server flood traffic
+	// to violate the connection-level flow control window.
+	callHdr.Method = "foo.Connection"
 	for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
 		s, err := ct.NewStream(context.Background(), callHdr)
 		if err != nil {
 			break
 		}
-		if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
+		if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil {
 			break
 		}
 	}
@@ -705,8 +719,13 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 	server.stop()
 }
 
-func TestMalformedStatus(t *testing.T) {
-	server, ct := setUp(t, 0, math.MaxUint32, malformedStatus)
+var (
+	encodingTestStatusCode = codes.Internal
+	encodingTestStatusDesc = "\n"
+)
+
+func TestEncodingRequiredStatus(t *testing.T) {
+	server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
 	callHdr := &CallHdr{
 		Host:   "localhost",
 		Method: "foo",
@@ -719,24 +738,26 @@
 		Last:  true,
 		Delay: false,
 	}
-	if err := ct.Write(s, expectedRequest, &opts); err != nil {
+	if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
 		t.Fatalf("Failed to write the request: %v", err)
 	}
 	p := make([]byte, http2MaxFrameLen)
-	expectedErr := StreamErrorf(codes.Internal, "invalid header field value \"\\n\"")
-	if _, err = s.dec.Read(p); err != expectedErr {
-		t.Fatalf("Read the err %v, want %v", err, expectedErr)
+	if _, err := s.dec.Read(p); err != io.EOF {
+		t.Fatalf("Read got error %v, want %v", err, io.EOF)
+	}
+	if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
+		t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
 	}
 	ct.Close()
 	server.stop()
 }
 
 func TestStreamContext(t *testing.T) {
-	expectedStream := Stream{}
-	ctx := newContextWithStream(context.Background(), &expectedStream)
+	expectedStream := &Stream{}
+	ctx := newContextWithStream(context.Background(), expectedStream)
 	s, ok := StreamFromContext(ctx)
-	if !ok || !reflect.DeepEqual(expectedStream, *s) {
-		t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream)
+	if !ok || expectedStream != s {
+		t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
 	}
 }
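The TestMaxStreams change earlier introduces a ready channel so the helper goroutine that feeds ch can exit as soon as the test no longer needs it, instead of lingering until its 5-second timeout; close(ready) both unblocks a pending send and stops the loop. A small, self-contained sketch of that shutdown pattern follows; startTicker and all other names are illustrative, not part of this patch.

package main

import (
	"fmt"
	"time"
)

// startTicker sends on ch every interval until ready is closed. Closing ready
// both unblocks a pending send and terminates the loop, mirroring the nested
// select { case ch <- 0: case <-ready: } shape used in the test above.
func startTicker(ch chan<- int, ready <-chan struct{}, interval time.Duration) {
	go func() {
		for {
			select {
			case <-time.After(interval):
				select {
				case ch <- 0:
				case <-ready:
					return
				}
			case <-ready:
				return
			}
		}
	}()
}

func main() {
	ch := make(chan int)
	ready := make(chan struct{})
	startTicker(ch, ready, 5*time.Millisecond)
	// Consume a few ticks, then signal the helper goroutine to stop.
	for i := 0; i < 3; i++ {
		<-ch
	}
	close(ready)
	fmt.Println("ticker stopped")
}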