Merge remote-tracking branch 'upstream/master' into status_interop_test

This commit is contained in:
Mark D. Roth
2016-08-26 11:19:16 -07:00
41 changed files with 2249 additions and 628 deletions

View File

@ -1,17 +1,18 @@
language: go language: go
go: go:
- 1.5.3 - 1.5.4
- 1.6 - 1.6.3
- 1.7
go_import_path: google.golang.org/grpc
before_install: before_install:
- go get github.com/axw/gocov/gocov - go get -u golang.org/x/tools/cmd/goimports github.com/golang/lint/golint github.com/axw/gocov/gocov github.com/mattn/goveralls golang.org/x/tools/cmd/cover
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script: script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf" | grep -vF .pb.go:' # https://github.com/golang/protobuf/issues/214
- make test testrace - make test testrace

View File

@ -28,5 +28,5 @@ See [API documentation](https://godoc.org/google.golang.org/grpc) for package an
Status Status
------ ------
Beta release GA

View File

@ -40,7 +40,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/naming" "google.golang.org/grpc/naming"
"google.golang.org/grpc/transport"
) )
// Address represents a server the client connects to. // Address represents a server the client connects to.
@ -94,10 +93,10 @@ type Balancer interface {
// instead of blocking. // instead of blocking.
// //
// The function returns put which is called once the rpc has completed or failed. // The function returns put which is called once the rpc has completed or failed.
// put can collect and report RPC stats to a remote load balancer. gRPC internals // put can collect and report RPC stats to a remote load balancer.
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
// //
// TODO: Add other non-recoverable errors? // This function should only return the errors Balancer cannot recover by itself.
// gRPC internals will fail the RPC if an error is returned.
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
// Notify returns a channel that is used by gRPC internals to watch the addresses // Notify returns a channel that is used by gRPC internals to watch the addresses
// gRPC needs to connect. The addresses might be from a name resolver or remote // gRPC needs to connect. The addresses might be from a name resolver or remote
@ -158,7 +157,7 @@ type roundRobin struct {
func (rr *roundRobin) watchAddrUpdates() error { func (rr *roundRobin) watchAddrUpdates() error {
updates, err := rr.w.Next() updates, err := rr.w.Next()
if err != nil { if err != nil {
grpclog.Println("grpc: the naming watcher stops working due to %v.", err) grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err)
return err return err
} }
rr.mu.Lock() rr.mu.Lock()
@ -166,6 +165,7 @@ func (rr *roundRobin) watchAddrUpdates() error {
for _, update := range updates { for _, update := range updates {
addr := Address{ addr := Address{
Addr: update.Addr, Addr: update.Addr,
Metadata: update.Metadata,
} }
switch update.Op { switch update.Op {
case naming.Add: case naming.Add:
@ -298,8 +298,19 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
} }
} }
} }
// There is no address available. Wait on rr.waitCh. if !opts.BlockingWait {
// TODO(zhaoq): Handle the case when opts.BlockingWait is false. if len(rr.addrs) == 0 {
rr.mu.Unlock()
err = fmt.Errorf("there is no address available")
return
}
// Returns the next addr on rr.addrs for failfast RPCs.
addr = rr.addrs[rr.next].addr
rr.next++
rr.mu.Unlock()
return
}
// Wait on rr.waitCh for non-failfast RPCs.
if rr.waitCh == nil { if rr.waitCh == nil {
ch = make(chan struct{}) ch = make(chan struct{})
rr.waitCh = ch rr.waitCh = ch
@ -310,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
err = transport.ContextErr(ctx.Err()) err = ctx.Err()
return return
case <-ch: case <-ch:
rr.mu.Lock() rr.mu.Lock()

View File

@ -239,11 +239,11 @@ func TestCloseWithPendingRPC(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
} }
// Remove the server. // Remove the server.
updates := []*naming.Update{&naming.Update{ updates := []*naming.Update{{
Op: naming.Delete, Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}
@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies. // Loop until the above update applies.
for { for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done() defer wg.Done()
var reply string var reply string
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -287,7 +287,7 @@ func TestGetOnWaitChannel(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
// Remove all servers so that all upcoming RPCs will block on waitCh. // Remove all servers so that all upcoming RPCs will block on waitCh.
updates := []*naming.Update{&naming.Update{ updates := []*naming.Update{{
Op: naming.Delete, Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}
@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) {
for { for {
var reply string var reply string
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -305,12 +305,12 @@ func TestGetOnWaitChannel(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
} }
}() }()
// Add a connected server to get the above RPC through. // Add a connected server to get the above RPC through.
updates = []*naming.Update{&naming.Update{ updates = []*naming.Update{{
Op: naming.Add, Op: naming.Add,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}
@ -320,3 +320,119 @@ func TestGetOnWaitChannel(t *testing.T) {
cc.Close() cc.Close()
servers[0].stop() servers[0].stop()
} }
func TestOneServerDown(t *testing.T) {
// Start 2 servers.
numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
})
r.w.inject(updates)
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, kill server[0].
servers[0].stop()
wg.Done()
}()
// All non-failfast RPCs should not block because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is killed around the same time to make it racy between balancer and gRPC internals.
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
wg.Done()
}()
}
wg.Wait()
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}
func TestOneAddressRemoval(t *testing.T) {
// Start 2 servers.
numServers := 2
servers, r := startServers(t, numServers, math.MaxUint32)
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
if err != nil {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Add servers[1] to the service discovery.
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[1].port,
})
r.w.inject(updates)
req := "port"
var reply string
// Loop until servers[1] is up
for {
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
break
}
time.Sleep(10 * time.Millisecond)
}
var wg sync.WaitGroup
numRPC := 100
sleepDuration := 10 * time.Millisecond
wg.Add(1)
go func() {
time.Sleep(sleepDuration)
// After sleepDuration, delete server[0].
var updates []*naming.Update
updates = append(updates, &naming.Update{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
})
r.w.inject(updates)
wg.Done()
}()
// All non-failfast RPCs should not fail because there's at least one connection available.
for i := 0; i < numRPC; i++ {
wg.Add(1)
go func() {
var reply string
time.Sleep(sleepDuration)
// After sleepDuration, invoke RPC.
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
}
wg.Done()
}()
}
wg.Wait()
cc.Close()
for i := 0; i < numServers; i++ {
servers[i].stop()
}
}

View File

@ -58,7 +58,7 @@ func closeLoopUnary() {
for i := 0; i < *maxConcurrentRPCs; i++ { for i := 0; i < *maxConcurrentRPCs; i++ {
go func() { go func() {
for _ = range ch { for range ch {
start := time.Now() start := time.Now()
unaryCaller(tc) unaryCaller(tc)
elapse := time.Since(start) elapse := time.Since(start)

View File

@ -133,8 +133,8 @@ func (h *Histogram) Clear() {
h.SumOfSquares = 0 h.SumOfSquares = 0
h.Min = math.MaxInt64 h.Min = math.MaxInt64
h.Max = math.MinInt64 h.Max = math.MinInt64
for _, v := range h.Buckets { for i := range h.Buckets {
v.Count = 0 h.Buckets[i].Count = 0
} }
} }

View File

@ -60,7 +60,7 @@ type byteBufCodec struct {
func (byteBufCodec) Marshal(v interface{}) ([]byte, error) { func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.(*[]byte) b, ok := v.(*[]byte)
if !ok { if !ok {
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte") return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
} }
return *b, nil return *b, nil
} }
@ -68,7 +68,7 @@ func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
func (byteBufCodec) Unmarshal(data []byte, v interface{}) error { func (byteBufCodec) Unmarshal(data []byte, v interface{}) error {
b, ok := v.(*[]byte) b, ok := v.(*[]byte)
if !ok { if !ok {
return fmt.Errorf("failed to marshal: %v is not type of *[]byte") return fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
} }
*b = data *b = data
return nil return nil
@ -138,8 +138,6 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er
return err return err
} }
} }
return nil
} }
func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error { func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error {
@ -191,13 +189,11 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
return err return err
} }
} }
return nil
} }
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) { func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
grpclog.Printf("core count: %v", runtime.NumCPU()) grpclog.Printf("core count: %v", runtime.NumCPU())
return &testpb.CoreResponse{int32(runtime.NumCPU())}, nil return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil
} }
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) { func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {

43
call.go
View File

@ -36,6 +36,7 @@ package grpc
import ( import (
"bytes" "bytes"
"io" "io"
"math"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -51,13 +52,20 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
var err error var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header() c.headerMD, err = stream.Header()
if err != nil { if err != nil {
return err return err
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
} }
defer func() { defer func() {
if err != nil { if err != nil {
// If err is connection error, t will be closed, no need to close stream here.
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err) t.CloseStream(stream, err)
} }
@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
} }
err = t.Write(stream, outBuf, opts) err = t.Write(stream, outBuf, opts)
if err != nil { // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
// recvResponse to get the final status.
if err != nil && err != io.EOF {
return nil, err return nil, err
} }
// Sent successfully. // Sent successfully.
@ -101,7 +113,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke is called by generated code. Also users can call Invoke directly when it // Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases. // is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
var c callInfo c := defaultCallInfo
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {
return toRPCErr(err) return toRPCErr(err)
@ -155,20 +167,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if err == ErrClientConnClosing { if _, ok := err.(*rpcError); ok {
return Errorf(codes.FailedPrecondition, "%v", err) return err
} }
if _, ok := err.(transport.StreamError); ok { if err == errConnClosing || err == errConnUnavailable {
return toRPCErr(err)
}
if _, ok := err.(transport.ConnectionError); ok {
if c.failFast { if c.failFast {
return toRPCErr(err) return Errorf(codes.Unavailable, "%v", err)
} }
}
// All the remaining cases are treated as retryable.
continue continue
} }
// All the other errors are treated as Internal errors.
return Errorf(codes.Internal, "%v", err)
}
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }
@ -178,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { // Retry a non-failfast RPC when
// i) there is a connection error; or
// ii) the server started to drain before this RPC was initiated.
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
@ -186,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
continue continue
} }
t.CloseStream(stream, err)
return toRPCErr(err) return toRPCErr(err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {

View File

@ -81,7 +81,7 @@ type testStreamHandler struct {
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
p := &parser{r: s} p := &parser{r: s}
for { for {
pf, req, err := p.recvMsg() pf, req, err := p.recvMsg(math.MaxInt32)
if err == io.EOF { if err == io.EOF {
break break
} }
@ -234,7 +234,7 @@ func TestInvokeLargeErr(t *testing.T) {
var reply string var reply string
req := "hello" req := "hello"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(rpcError); !ok { if _, ok := err.(*rpcError); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
var reply string var reply string
req := "weird error" req := "weird error"
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
if _, ok := err.(rpcError); !ok { if _, ok := err.(*rpcError); !ok {
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
} }
if got, want := ErrorDesc(err), weirdError; got != want { if got, want := ErrorDesc(err), weirdError; got != want {
@ -276,3 +276,18 @@ func TestInvokeCancel(t *testing.T) {
cc.Close() cc.Close()
server.stop() server.stop()
} }
// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
// on a closed client will terminate.
func TestInvokeCancelClosedNonFailFast(t *testing.T) {
server, cc := setUp(t, 0, math.MaxUint32)
var reply string
cc.Close()
req := "hello"
ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := Invoke(ctx, "/foo/bar", &req, &reply, cc, FailFast(false)); err == nil {
t.Fatalf("canceled invoke on closed connection should fail")
}
server.stop()
}

View File

@ -43,7 +43,6 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/transport" "google.golang.org/grpc/transport"
@ -68,12 +67,14 @@ var (
// errCredentialsConflict indicates that grpc.WithTransportCredentials() // errCredentialsConflict indicates that grpc.WithTransportCredentials()
// and grpc.WithInsecure() are both called for a connection. // and grpc.WithInsecure() are both called for a connection.
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)") errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
// errNetworkIP indicates that the connection is down due to some network I/O error. // errNetworkIO indicates that the connection is down due to some network I/O error.
errNetworkIO = errors.New("grpc: failed with network I/O error") errNetworkIO = errors.New("grpc: failed with network I/O error")
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
errConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
// errConnClosing indicates that the connection is closing. // errConnClosing indicates that the connection is closing.
errConnClosing = errors.New("grpc: the connection is closing") errConnClosing = errors.New("grpc: the connection is closing")
// errConnUnavailable indicates that the connection is unavailable.
errConnUnavailable = errors.New("grpc: the connection is unavailable")
errNoAddr = errors.New("grpc: there is no address available to dial") errNoAddr = errors.New("grpc: there is no address available to dial")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption {
} }
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
return f(addr, deadline.Sub(time.Now()))
}
return f(addr, 0)
}
} }
} }
@ -209,36 +215,58 @@ func WithUserAgent(s string) DialOption {
} }
} }
// Dial creates a client connection the given target. // Dial creates a client connection to the given target.
func Dial(target string, opts ...DialOption) (*ClientConn, error) { func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return DialContext(context.Background(), target, opts...)
}
// DialContext creates a client connection to the given target. ctx can be used to
// cancel or expire the pending connecting. Once this function returns, the
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
// to terminate all the pending operations after this function returns.
// This is the EXPERIMENTAL API.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{ cc := &ClientConn{
target: target, target: target,
conns: make(map[Address]*addrConn), conns: make(map[Address]*addrConn),
} }
cc.ctx, cc.cancel = context.WithCancel(context.Background())
defer func() {
select {
case <-ctx.Done():
conn, err = nil, ctx.Err()
default:
}
if err != nil {
cc.Close()
}
}()
for _, opt := range opts { for _, opt := range opts {
opt(&cc.dopts) opt(&cc.dopts)
} }
// Set defaults.
if cc.dopts.codec == nil { if cc.dopts.codec == nil {
// Set the default codec.
cc.dopts.codec = protoCodec{} cc.dopts.codec = protoCodec{}
} }
if cc.dopts.bs == nil { if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig cc.dopts.bs = DefaultBackoffConfig
} }
cc.balancer = cc.dopts.balancer
if cc.balancer == nil {
cc.balancer = RoundRobin(nil)
}
if err := cc.balancer.Start(target); err != nil {
return nil, err
}
var ( var (
ok bool ok bool
addrs []Address addrs []Address
) )
ch := cc.balancer.Notify() if cc.dopts.balancer == nil {
// Connect to target directly if balancer is nil.
addrs = append(addrs, Address{Addr: target})
} else {
if err := cc.dopts.balancer.Start(target); err != nil {
return nil, err
}
ch := cc.dopts.balancer.Notify()
if ch == nil { if ch == nil {
// There is no name resolver installed. // There is no name resolver installed.
addrs = append(addrs, Address{Addr: target}) addrs = append(addrs, Address{Addr: target})
@ -248,10 +276,11 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
return nil, errNoAddr return nil, errNoAddr
} }
} }
}
waitC := make(chan error, 1) waitC := make(chan error, 1)
go func() { go func() {
for _, a := range addrs { for _, a := range addrs {
if err := cc.newAddrConn(a, false); err != nil { if err := cc.resetAddrConn(a, false, nil); err != nil {
waitC <- err waitC <- err
return return
} }
@ -263,15 +292,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
timeoutCh = time.After(cc.dopts.timeout) timeoutCh = time.After(cc.dopts.timeout)
} }
select { select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-waitC: case err := <-waitC:
if err != nil { if err != nil {
cc.Close()
return nil, err return nil, err
} }
case <-timeoutCh: case <-timeoutCh:
cc.Close()
return nil, ErrClientConnTimeout return nil, ErrClientConnTimeout
} }
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
// The lbWatcher goroutine will not be created.
if ok { if ok {
go cc.lbWatcher() go cc.lbWatcher()
} }
@ -318,8 +349,10 @@ func (s ConnectivityState) String() string {
// ClientConn represents a client connection to an RPC server. // ClientConn represents a client connection to an RPC server.
type ClientConn struct { type ClientConn struct {
ctx context.Context
cancel context.CancelFunc
target string target string
balancer Balancer
authority string authority string
dopts dialOptions dopts dialOptions
@ -328,7 +361,7 @@ type ClientConn struct {
} }
func (cc *ClientConn) lbWatcher() { func (cc *ClientConn) lbWatcher() {
for addrs := range cc.balancer.Notify() { for addrs := range cc.dopts.balancer.Notify() {
var ( var (
add []Address // Addresses need to setup connections. add []Address // Addresses need to setup connections.
del []*addrConn // Connections need to tear down. del []*addrConn // Connections need to tear down.
@ -349,11 +382,12 @@ func (cc *ClientConn) lbWatcher() {
} }
if !keep { if !keep {
del = append(del, c) del = append(del, c)
delete(cc.conns, c.addr)
} }
} }
cc.mu.Unlock() cc.mu.Unlock()
for _, a := range add { for _, a := range add {
cc.newAddrConn(a, true) cc.resetAddrConn(a, true, nil)
} }
for _, c := range del { for _, c := range del {
c.tearDown(errConnDrain) c.tearDown(errConnDrain)
@ -361,13 +395,17 @@ func (cc *ClientConn) lbWatcher() {
} }
} }
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error { // resetAddrConn creates an addrConn for addr and adds it to cc.conns.
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
// If tearDownErr is nil, errConnDrain will be used instead.
func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
ac := &addrConn{ ac := &addrConn{
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
shutdownChan: make(chan struct{}),
} }
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
ac.stateCV = sync.NewCond(&ac.mu)
if EnableTracing { if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
@ -385,26 +423,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
} }
} }
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called. // Track ac in cc. This needs to be done before any getTransport(...) is called.
ac.cc.mu.Lock() cc.mu.Lock()
if ac.cc.conns == nil { if cc.conns == nil {
ac.cc.mu.Unlock() cc.mu.Unlock()
return ErrClientConnClosing return ErrClientConnClosing
} }
stale := ac.cc.conns[ac.addr] stale := cc.conns[ac.addr]
ac.cc.conns[ac.addr] = ac cc.conns[ac.addr] = ac
ac.cc.mu.Unlock() cc.mu.Unlock()
if stale != nil { if stale != nil {
// There is an addrConn alive on ac.addr already. This could be due to // There is an addrConn alive on ac.addr already. This could be due to
// i) stale's Close is undergoing; // 1) a buggy Balancer notifies duplicated Addresses;
// ii) a buggy Balancer notifies duplicated Addresses. // 2) goaway was received, a new ac will replace the old ac.
// The old ac should be deleted from cc.conns, but the
// underlying transport should drain rather than close.
if tearDownErr == nil {
// tearDownErr is nil if resetAddrConn is called by
// 1) Dial
// 2) lbWatcher
// In both cases, the stale ac should drain, not close.
stale.tearDown(errConnDrain) stale.tearDown(errConnDrain)
} else {
stale.tearDown(tearDownErr)
}
} }
ac.stateCV = sync.NewCond(&ac.mu)
// skipWait may overwrite the decision in ac.dopts.block. // skipWait may overwrite the decision in ac.dopts.block.
if ac.dopts.block && !skipWait { if ac.dopts.block && !skipWait {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
if err != errConnClosing {
// Tear down ac and delete it from cc.conns.
cc.mu.Lock()
delete(cc.conns, ac.addr)
cc.mu.Unlock()
ac.tearDown(err) ac.tearDown(err)
}
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return e.Origin()
}
return err return err
} }
// Start to monitor the error status of transport. // Start to monitor the error status of transport.
@ -414,7 +470,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
go func() { go func() {
if err := ac.resetTransport(false); err != nil { if err := ac.resetTransport(false); err != nil {
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err) grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err) ac.tearDown(err)
}
return return
} }
ac.transportMonitor() ac.transportMonitor()
@ -424,25 +483,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
// TODO(zhaoq): Implement fail-fast logic. var (
addr, put, err := cc.balancer.Get(ctx, opts) ac *addrConn
ok bool
put func()
)
if cc.dopts.balancer == nil {
// If balancer is nil, there should be only one addrConn available.
cc.mu.RLock()
if cc.conns == nil {
cc.mu.RUnlock()
return nil, nil, toRPCErr(ErrClientConnClosing)
}
for _, ac = range cc.conns {
// Break after the first iteration to get the first addrConn.
ok = true
break
}
cc.mu.RUnlock()
} else {
var (
addr Address
err error
)
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, toRPCErr(err)
} }
cc.mu.RLock() cc.mu.RLock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.RUnlock() cc.mu.RUnlock()
return nil, nil, ErrClientConnClosing return nil, nil, toRPCErr(ErrClientConnClosing)
} }
ac, ok := cc.conns[addr] ac, ok = cc.conns[addr]
cc.mu.RUnlock() cc.mu.RUnlock()
}
if !ok { if !ok {
if put != nil { if put != nil {
put() put()
} }
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, errConnClosing
} }
t, err := ac.wait(ctx) t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
@ -454,6 +536,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
// Close tears down the ClientConn and all underlying connections. // Close tears down the ClientConn and all underlying connections.
func (cc *ClientConn) Close() error { func (cc *ClientConn) Close() error {
cc.cancel()
cc.mu.Lock() cc.mu.Lock()
if cc.conns == nil { if cc.conns == nil {
cc.mu.Unlock() cc.mu.Unlock()
@ -462,7 +546,9 @@ func (cc *ClientConn) Close() error {
conns := cc.conns conns := cc.conns
cc.conns = nil cc.conns = nil
cc.mu.Unlock() cc.mu.Unlock()
cc.balancer.Close() if cc.dopts.balancer != nil {
cc.dopts.balancer.Close()
}
for _, ac := range conns { for _, ac := range conns {
ac.tearDown(ErrClientConnClosing) ac.tearDown(ErrClientConnClosing)
} }
@ -471,10 +557,12 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address. // addrConn is a network connection to a given address.
type addrConn struct { type addrConn struct {
ctx context.Context
cancel context.CancelFunc
cc *ClientConn cc *ClientConn
addr Address addr Address
dopts dialOptions dopts dialOptions
shutdownChan chan struct{}
events trace.EventLog events trace.EventLog
mu sync.Mutex mu sync.Mutex
@ -485,6 +573,9 @@ type addrConn struct {
// due to timeout. // due to timeout.
ready chan struct{} ready chan struct{}
transport transport.ClientTransport transport transport.ClientTransport
// The reason this addrConn is torn down.
tearDownErr error
} }
// printf records an event in ac's event log, unless ac has been closed. // printf records an event in ac's event log, unless ac has been closed.
@ -540,8 +631,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
} }
func (ac *addrConn) resetTransport(closeTransport bool) error { func (ac *addrConn) resetTransport(closeTransport bool) error {
var retries int for retries := 0; ; retries++ {
for {
ac.mu.Lock() ac.mu.Lock()
ac.printf("connecting") ac.printf("connecting")
if ac.state == Shutdown { if ac.state == Shutdown {
@ -561,13 +651,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close() t.Close()
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime timeout := minConnectTimeout
if sleepTime < minConnectTimeout { if timeout < sleepTime {
ac.dopts.copts.Timeout = minConnectTimeout timeout = sleepTime
} }
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now() connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
if err != nil { if err != nil {
cancel()
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
return err
}
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
@ -582,17 +679,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.ready = nil ac.ready = nil
} }
ac.mu.Unlock() ac.mu.Unlock()
sleepTime -= time.Since(connectTime)
if sleepTime < 0 {
sleepTime = 0
}
closeTransport = false closeTransport = false
select { select {
case <-time.After(sleepTime): case <-time.After(sleepTime - time.Since(connectTime)):
case <-ac.shutdownChan: case <-ac.ctx.Done():
return ac.ctx.Err()
} }
retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
continue continue
} }
ac.mu.Lock() ac.mu.Lock()
@ -610,7 +702,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
ac.down = ac.cc.balancer.Up(ac.addr) if ac.cc.dopts.balancer != nil {
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
}
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
@ -624,14 +718,42 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // This is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.ctx.Done():
select {
case <-t.Error():
t.Close()
default:
}
return
case <-t.GoAway():
// If GoAway happens without any network I/O error, ac is closed without shutting down the
// underlying transport (the transport will be closed when all the pending RPCs finished or
// failed.).
// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
// are closed.
// In both cases, a new ac is created.
select {
case <-t.Error():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
default:
ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
}
return return
case <-t.Error(): case <-t.Error():
select {
case <-ac.ctx.Done():
t.Close()
return
case <-t.GoAway():
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
return
default:
}
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac has been shutdown.
ac.mu.Unlock() ac.mu.Unlock()
return return
} }
@ -643,25 +765,41 @@ func (ac *addrConn) transportMonitor() {
ac.printf("transport exiting: %v", err) ac.printf("transport exiting: %v", err)
ac.mu.Unlock() ac.mu.Unlock()
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err) grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
if err != errConnClosing {
// Keep this ac in cc.conns, to get the reason it's torn down.
ac.tearDown(err)
}
return return
} }
} }
} }
} }
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { // iv) transport is in TransientFailure and there's no balancer/failfast is true.
func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
for { for {
ac.mu.Lock() ac.mu.Lock()
switch { switch {
case ac.state == Shutdown: case ac.state == Shutdown:
if failfast || !hasBalancer {
// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
err := ac.tearDownErr
ac.mu.Unlock()
return nil, err
}
ac.mu.Unlock() ac.mu.Unlock()
return nil, errConnClosing return nil, errConnClosing
case ac.state == Ready: case ac.state == Ready:
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
default: case ac.state == TransientFailure:
if failfast || hasBalancer {
ac.mu.Unlock()
return nil, errConnUnavailable
}
}
ready := ac.ready ready := ac.ready
if ready == nil { if ready == nil {
ready = make(chan struct{}) ready = make(chan struct{})
@ -670,36 +808,39 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
ac.mu.Unlock() ac.mu.Unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, transport.ContextErr(ctx.Err()) return nil, toRPCErr(ctx.Err())
// Wait until the new transport is ready or failed. // Wait until the new transport is ready or failed.
case <-ready: case <-ready:
} }
} }
} }
}
// tearDown starts to tear down the addrConn. // tearDown starts to tear down the addrConn.
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
// some edge cases (e.g., the caller opens and closes many addrConn's in a // some edge cases (e.g., the caller opens and closes many addrConn's in a
// tight loop. // tight loop.
// tearDown doesn't remove ac from ac.cc.conns.
func (ac *addrConn) tearDown(err error) { func (ac *addrConn) tearDown(err error) {
ac.cancel()
ac.mu.Lock() ac.mu.Lock()
defer func() { defer ac.mu.Unlock()
ac.mu.Unlock()
ac.cc.mu.Lock()
if ac.cc.conns != nil {
delete(ac.cc.conns, ac.addr)
}
ac.cc.mu.Unlock()
}()
if ac.state == Shutdown {
return
}
ac.state = Shutdown
if ac.down != nil { if ac.down != nil {
ac.down(downErrorf(false, false, "%v", err)) ac.down(downErrorf(false, false, "%v", err))
ac.down = nil ac.down = nil
} }
if err == errConnDrain && ac.transport != nil {
// GracefulClose(...) may be executed multiple times when
// i) receiving multiple GoAway frames from the server; or
// ii) there are concurrent name resolver/Balancer triggered
// address removal and GoAway.
ac.transport.GracefulClose()
}
if ac.state == Shutdown {
return
}
ac.state = Shutdown
ac.tearDownErr = err
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
if ac.events != nil { if ac.events != nil {
ac.events.Finish() ac.events.Finish()
@ -709,15 +850,8 @@ func (ac *addrConn) tearDown(err error) {
close(ac.ready) close(ac.ready)
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil && err != errConnDrain {
if err == errConnDrain {
ac.transport.GracefulClose()
} else {
ac.transport.Close() ac.transport.Close()
} }
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)
}
return return
} }

View File

@ -37,6 +37,8 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth" "google.golang.org/grpc/credentials/oauth"
) )
@ -67,13 +69,21 @@ func TestTLSDialTimeout(t *testing.T) {
} }
} }
func TestDialContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
}
}
func TestCredentialsMisuse(t *testing.T) { func TestCredentialsMisuse(t *testing.T) {
tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil { if err != nil {
t.Fatalf("Failed to create authenticator %v", err) t.Fatalf("Failed to create authenticator %v", err)
} }
// Two conflicting credential configurations // Two conflicting credential configurations
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict { if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
} }
rpcCreds, err := oauth.NewJWTAccessFromKey(nil) rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
@ -81,7 +91,7 @@ func TestCredentialsMisuse(t *testing.T) {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
// security info on insecure connection // security info on insecure connection
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing { if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
} }
} }
@ -123,4 +133,5 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
if actual != *expected { if actual != *expected {
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected) t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
} }
conn.Close()
} }

View File

@ -44,7 +44,6 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -93,11 +92,12 @@ type TransportCredentials interface {
// ClientHandshake does the authentication handshake specified by the corresponding // ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients. It returns the authenticated // authentication protocol on rawConn for clients. It returns the authenticated
// connection and the corresponding auth information about the connection. // connection and the corresponding auth information about the connection.
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) // Implementations must use the provided context to implement timely cancellation.
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
// ServerHandshake does the authentication handshake for servers. It returns // ServerHandshake does the authentication handshake for servers. It returns
// the authenticated connection and the corresponding auth information about // the authenticated connection and the corresponding auth information about
// the connection. // the connection.
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
// Info provides the ProtocolInfo of this TransportCredentials. // Info provides the ProtocolInfo of this TransportCredentials.
Info() ProtocolInfo Info() ProtocolInfo
} }
@ -116,7 +116,7 @@ func (t TLSInfo) AuthType() string {
// tlsCreds is the credentials required for authenticating a connection using TLS. // tlsCreds is the credentials required for authenticating a connection using TLS.
type tlsCreds struct { type tlsCreds struct {
// TLS configuration // TLS configuration
config tls.Config config *tls.Config
} }
func (c tlsCreds) Info() ProtocolInfo { func (c tlsCreds) Info() ProtocolInfo {
@ -136,50 +136,37 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
return true return true
} }
type timeoutError struct{} func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
func (timeoutError) Error() string { return "credentials: Dial timed out" } cfg := cloneTLSConfig(c.config)
func (timeoutError) Timeout() bool { return true } if cfg.ServerName == "" {
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
errChannel = make(chan error, 2)
time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
})
}
if c.config.ServerName == "" {
colonPos := strings.LastIndex(addr, ":") colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 { if colonPos == -1 {
colonPos = len(addr) colonPos = len(addr)
} }
c.config.ServerName = addr[:colonPos] cfg.ServerName = addr[:colonPos]
} }
conn := tls.Client(rawConn, &c.config) conn := tls.Client(rawConn, cfg)
if timeout == 0 { errChannel := make(chan error, 1)
err = conn.Handshake()
} else {
go func() { go func() {
errChannel <- conn.Handshake() errChannel <- conn.Handshake()
}() }()
err = <-errChannel select {
} case err := <-errChannel:
if err != nil { if err != nil {
rawConn.Close()
return nil, nil, err return nil, nil, err
} }
case <-ctx.Done():
return nil, nil, ctx.Err()
}
// TODO(zhaoq): Omit the auth info for client now. It is more for // TODO(zhaoq): Omit the auth info for client now. It is more for
// information than anything else. // information than anything else.
return conn, nil, nil return conn, nil, nil
} }
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
conn := tls.Server(rawConn, &c.config) conn := tls.Server(rawConn, c.config)
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, nil, err return nil, nil, err
} }
return conn, TLSInfo{conn.ConnectionState()}, nil return conn, TLSInfo{conn.ConnectionState()}, nil
@ -187,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS. // NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials { func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{*c} tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr tc.config.NextProtos = alpnProtoStr
return tc return tc
} }

View File

@ -0,0 +1,76 @@
// +build go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}

View File

@ -0,0 +1,74 @@
// +build !go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -28,12 +28,12 @@ Then change your current directory to `grpc-go/examples/route_guide`:
$ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide $ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide
``` ```
You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](examples/). You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](https://github.com/grpc/grpc-go/tree/master/examples/).
## Defining the service ## Defining the service
Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [`examples/route_guide/proto/route_guide.proto`](examples/route_guide/proto/route_guide.proto). Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto).
To define a service, you specify a named `service` in your .proto file: To define a service, you specify a named `service` in your .proto file:

View File

@ -115,12 +115,12 @@ func runRecordRoute(client pb.RouteGuideClient) {
// runRouteChat receives a sequence of route notes, while sending notes for various locations. // runRouteChat receives a sequence of route notes, while sending notes for various locations.
func runRouteChat(client pb.RouteGuideClient) { func runRouteChat(client pb.RouteGuideClient) {
notes := []*pb.RouteNote{ notes := []*pb.RouteNote{
{&pb.Point{0, 1}, "First message"}, {&pb.Point{Latitude: 0, Longitude: 1}, "First message"},
{&pb.Point{0, 2}, "Second message"}, {&pb.Point{Latitude: 0, Longitude: 2}, "Second message"},
{&pb.Point{0, 3}, "Third message"}, {&pb.Point{Latitude: 0, Longitude: 3}, "Third message"},
{&pb.Point{0, 1}, "Fourth message"}, {&pb.Point{Latitude: 0, Longitude: 1}, "Fourth message"},
{&pb.Point{0, 2}, "Fifth message"}, {&pb.Point{Latitude: 0, Longitude: 2}, "Fifth message"},
{&pb.Point{0, 3}, "Sixth message"}, {&pb.Point{Latitude: 0, Longitude: 3}, "Sixth message"},
} }
stream, err := client.RouteChat(context.Background()) stream, err := client.RouteChat(context.Background())
if err != nil { if err != nil {
@ -153,7 +153,7 @@ func runRouteChat(client pb.RouteGuideClient) {
func randomPoint(r *rand.Rand) *pb.Point { func randomPoint(r *rand.Rand) *pb.Point {
lat := (r.Int31n(180) - 90) * 1e7 lat := (r.Int31n(180) - 90) * 1e7
long := (r.Int31n(360) - 180) * 1e7 long := (r.Int31n(360) - 180) * 1e7
return &pb.Point{lat, long} return &pb.Point{Latitude: lat, Longitude: long}
} }
func main() { func main() {
@ -186,13 +186,16 @@ func main() {
client := pb.NewRouteGuideClient(conn) client := pb.NewRouteGuideClient(conn)
// Looking for a valid feature // Looking for a valid feature
printFeature(client, &pb.Point{409146138, -746188906}) printFeature(client, &pb.Point{Latitude: 409146138, Longitude: -746188906})
// Feature missing. // Feature missing.
printFeature(client, &pb.Point{0, 0}) printFeature(client, &pb.Point{Latitude: 0, Longitude: 0})
// Looking for features between 40, -75 and 42, -73. // Looking for features between 40, -75 and 42, -73.
printFeatures(client, &pb.Rectangle{&pb.Point{400000000, -750000000}, &pb.Point{420000000, -730000000}}) printFeatures(client, &pb.Rectangle{
Lo: &pb.Point{Latitude: 400000000, Longitude: -750000000},
Hi: &pb.Point{Latitude: 420000000, Longitude: -730000000},
})
// RecordRoute // RecordRoute
runRecordRoute(client) runRecordRoute(client)

View File

@ -79,7 +79,7 @@ func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb
} }
} }
// No feature was found, return an unnamed feature // No feature was found, return an unnamed feature
return &pb.Feature{"", point}, nil return &pb.Feature{Location: point}, nil
} }
// ListFeatures lists all features contained within the given bounding Rectangle. // ListFeatures lists all features contained within the given bounding Rectangle.

View File

@ -11,19 +11,22 @@ import (
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
) )
type HealthServer struct { // Server implements `service Health`.
type Server struct {
mu sync.Mutex mu sync.Mutex
// statusMap stores the serving status of the services this HealthServer monitors. // statusMap stores the serving status of the services this Server monitors.
statusMap map[string]healthpb.HealthCheckResponse_ServingStatus statusMap map[string]healthpb.HealthCheckResponse_ServingStatus
} }
func NewHealthServer() *HealthServer { // NewServer returns a new Server.
return &HealthServer{ func NewServer() *Server {
return &Server{
statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus), statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus),
} }
} }
func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { // Check implements `service Health`.
func (s *Server) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if in.Service == "" { if in.Service == "" {
@ -42,7 +45,7 @@ func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckReques
// SetServingStatus is called when need to reset the serving status of a service // SetServingStatus is called when need to reset the serving status of a service
// or insert a new service entry into the statusMap. // or insert a new service entry into the statusMap.
func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) { func (s *Server) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
s.mu.Lock() s.mu.Lock()
s.statusMap[service] = status s.statusMap[service] = status
s.mu.Unlock() s.mu.Unlock()

View File

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the // DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v. // encoded data in k, v.
// If k is a binary header and v contains comma, v is split on comma before decoded,
// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) { func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) { if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil return k, v, nil
} }
val, err := base64.StdEncoding.DecodeString(v) vvs := strings.Split(v, ",")
for i, vv := range vvs {
val, err := base64.StdEncoding.DecodeString(vv)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
return k, string(val), nil vvs[i] = string(val)
}
return k, strings.Join(vvs, ","), nil
} }
// MD is a mapping from metadata keys to values. Users should use the following // MD is a mapping from metadata keys to values. Users should use the following

View File

@ -74,6 +74,8 @@ func TestDecodeKeyValue(t *testing.T) {
{"a", "abc", "a", "abc", nil}, {"a", "abc", "a", "abc", nil},
{"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil}, {"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil},
{"key-bin", "woA=", "key-bin", binaryValue, nil}, {"key-bin", "woA=", "key-bin", binaryValue, nil},
{"a", "abc,efg", "a", "abc,efg", nil},
{"key-bin", "Zm9vAGJhcg==,Zm9vAGJhcg==", "key-bin", "foo\x00bar,foo\x00bar", nil},
} { } {
k, v, err := DecodeKeyValue(test.kin, test.vin) k, v, err := DecodeKeyValue(test.kin, test.vin)
if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) { if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {

View File

@ -70,7 +70,7 @@ import (
type serverReflectionServer struct { type serverReflectionServer struct {
s *grpc.Server s *grpc.Server
// TODO add more cache if necessary // TODO add more cache if necessary
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo() serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
} }
// Register registers the server reflection service on the given gRPC server. // Register registers the server reflection service on the given gRPC server.
@ -214,19 +214,19 @@ func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interfac
return nil, fmt.Errorf("unknown symbol: %v", name) return nil, fmt.Errorf("unknown symbol: %v", name)
} }
// Search for method in info. // Search the method name in info.Methods.
var found bool var found bool
for _, m := range info.Methods { for _, m := range info.Methods {
if m == name[pos+1:] { if m.Name == name[pos+1:] {
found = true found = true
break break
} }
} }
if !found { if found {
return nil, fmt.Errorf("unknown symbol: %v", name) return info.Metadata, nil
} }
return info.Metadata, nil return nil, fmt.Errorf("unknown symbol: %v", name)
} }
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
@ -253,7 +253,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
// Metadata not valid. // Metadata not valid.
enc, ok := meta.([]byte) enc, ok := meta.([]byte)
if !ok { if !ok {
return nil, fmt.Errorf("invalid file descriptor for symbol: %v") return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name)
} }
fd, err = s.decodeFileDesc(enc) fd, err = s.decodeFileDesc(enc)

View File

@ -273,6 +273,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe
}{ }{
{"grpc.testing.SearchService", fdTestByte}, {"grpc.testing.SearchService", fdTestByte},
{"grpc.testing.SearchService.Search", fdTestByte}, {"grpc.testing.SearchService.Search", fdTestByte},
{"grpc.testing.SearchService.StreamingSearch", fdTestByte},
{"grpc.testing.SearchResponse", fdTestByte}, {"grpc.testing.SearchResponse", fdTestByte},
{"grpc.testing.ToBeExtened", fdProto2Byte}, {"grpc.testing.ToBeExtened", fdProto2Byte},
} { } {

View File

@ -141,6 +141,8 @@ type callInfo struct {
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
} }
var defaultCallInfo = callInfo{failFast: true}
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from
// a Call after it completes. // a Call after it completes.
type CallOption interface { type CallOption interface {
@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption {
}) })
} }
// FailFast configures the action to take when an RPC is attempted on broken
// connections or unreachable servers. If failfast is true, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will retry
// the call if it fails due to a transient error. Please refer to
// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md
func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error {
c.failFast = failFast
return nil
})
}
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8
@ -212,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means // No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible // that the underlying io.Reader must not return an incompatible
// error. // error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil { if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err return 0, nil, err
} }
@ -223,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 { if length == 0 {
return pf, nil, nil return pf, nil, nil
} }
if length > uint32(maxMsgSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message: // of making it for each message:
msg = make([]byte, int(length)) msg = make([]byte, int(length))
@ -293,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil return nil
} }
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error { func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg() pf, d, err := p.recvMsg(maxMsgSize)
if err != nil { if err != nil {
return err return err
} }
@ -304,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade { if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d)) d, err = dc.Do(bytes.NewReader(d))
if err != nil { if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
if len(d) > maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
} }
return nil return nil
} }
@ -319,7 +342,7 @@ type rpcError struct {
desc string desc string
} }
func (e rpcError) Error() string { func (e *rpcError) Error() string {
return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc) return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc)
} }
@ -329,7 +352,7 @@ func Code(err error) codes.Code {
if err == nil { if err == nil {
return codes.OK return codes.OK
} }
if e, ok := err.(rpcError); ok { if e, ok := err.(*rpcError); ok {
return e.code return e.code
} }
return codes.Unknown return codes.Unknown
@ -341,7 +364,7 @@ func ErrorDesc(err error) string {
if err == nil { if err == nil {
return "" return ""
} }
if e, ok := err.(rpcError); ok { if e, ok := err.(*rpcError); ok {
return e.desc return e.desc
} }
return err.Error() return err.Error()
@ -353,7 +376,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
if c == codes.OK { if c == codes.OK {
return nil return nil
} }
return rpcError{ return &rpcError{
code: c, code: c,
desc: fmt.Sprintf(format, a...), desc: fmt.Sprintf(format, a...),
} }
@ -362,18 +385,37 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
// toRPCErr converts an error into a rpcError. // toRPCErr converts an error into a rpcError.
func toRPCErr(err error) error { func toRPCErr(err error) error {
switch e := err.(type) { switch e := err.(type) {
case rpcError: case *rpcError:
return err return err
case transport.StreamError: case transport.StreamError:
return rpcError{ return &rpcError{
code: e.Code, code: e.Code,
desc: e.Desc, desc: e.Desc,
} }
case transport.ConnectionError: case transport.ConnectionError:
return rpcError{ return &rpcError{
code: codes.Internal, code: codes.Internal,
desc: e.Desc, desc: e.Desc,
} }
default:
switch err {
case context.DeadlineExceeded:
return &rpcError{
code: codes.DeadlineExceeded,
desc: err.Error(),
}
case context.Canceled:
return &rpcError{
code: codes.Canceled,
desc: err.Error(),
}
case ErrClientConnClosing:
return &rpcError{
code: codes.FailedPrecondition,
desc: err.Error(),
}
}
} }
return Errorf(codes.Unknown, "%v", err) return Errorf(codes.Unknown, "%v", err)
} }

View File

@ -36,6 +36,7 @@ package grpc
import ( import (
"bytes" "bytes"
"io" "io"
"math"
"reflect" "reflect"
"testing" "testing"
@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) {
} { } {
buf := bytes.NewReader(test.p) buf := bytes.NewReader(test.p)
parser := &parser{r: buf} parser := &parser{r: buf}
pt, b, err := parser.recvMsg() pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
} }
} }
} }
@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) {
{compressionNone, []byte("d")}, {compressionNone, []byte("d")},
} }
for i, want := range wantRecvs { for i, want := range wantRecvs {
pt, data, err := parser.recvMsg() pt, data, err := parser.recvMsg(math.MaxInt32)
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, <nil>", t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
i, p, pt, data, err, want.pt, want.data) i, p, pt, data, err, want.pt, want.data)
} }
} }
pt, data, err := parser.recvMsg() pt, data, err := parser.recvMsg(math.MaxInt32)
if err != io.EOF { if err != io.EOF {
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v", t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
len(wantRecvs), p, pt, data, err, io.EOF) len(wantRecvs), p, pt, data, err, io.EOF)
} }
} }
@ -149,13 +150,17 @@ func TestToRPCErr(t *testing.T) {
// input // input
errIn error errIn error
// outputs // outputs
errOut error errOut *rpcError
}{ }{
{transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")}, {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)},
{transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)}, {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)},
} { } {
err := toRPCErr(test.errIn) err := toRPCErr(test.errIn)
if err != test.errOut { rpcErr, ok := err.(*rpcError)
if !ok {
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{})
}
if *rpcErr != *test.errOut {
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
} }
} }
@ -178,6 +183,18 @@ func TestContextErr(t *testing.T) {
} }
} }
func TestErrorsWithSameParameters(t *testing.T) {
const description = "some description"
e1 := Errorf(codes.AlreadyExists, description).(*rpcError)
e2 := Errorf(codes.AlreadyExists, description).(*rpcError)
if e1 == e2 {
t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
}
if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) {
t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
}
}
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bmEncode benchmarks encoding a Protocol Buffer message containing mSize
// bytes. // bytes.
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {

114
server.go
View File

@ -92,6 +92,10 @@ type Server struct {
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
conns map[io.Closer]bool conns map[io.Closer]bool
drain bool
// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
// and all the transport goes away.
cv *sync.Cond
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
} }
@ -101,12 +105,15 @@ type options struct {
codec Codec codec Codec
cp Compressor cp Compressor
dc Decompressor dc Decompressor
maxMsgSize int
unaryInt UnaryServerInterceptor unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor streamInt StreamServerInterceptor
maxConcurrentStreams uint32 maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server useHandlerImpl bool // use http.Handler-based server
} }
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
// A ServerOption sets options. // A ServerOption sets options.
type ServerOption func(*options) type ServerOption func(*options)
@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
} }
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound message. // RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return func(o *options) {
o.cp = cp o.cp = cp
} }
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message. // RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return func(o *options) {
o.dc = dc o.dc = dc
} }
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.maxMsgSize = m
}
}
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet. // started to accept requests yet.
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
var opts options var opts options
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt { for _, o := range opt {
o(&opts) o(&opts)
} }
@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
conns: make(map[io.Closer]bool), conns: make(map[io.Closer]bool),
m: make(map[string]*service), m: make(map[string]*service),
} }
s.cv = sync.NewCond(&s.mu)
if EnableTracing { if EnableTracing {
_, file, line, _ := runtime.Caller(1) _, file, line, _ := runtime.Caller(1)
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
@ -245,28 +262,45 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) {
s.m[sd.ServiceName] = srv s.m[sd.ServiceName] = srv
} }
// ServiceInfo contains method names and metadata for a service. // MethodInfo contains the information of an RPC including its method name and type.
type MethodInfo struct {
// Name is the method name only, without the service name or package name.
Name string
// IsClientStream indicates whether the RPC is a client streaming RPC.
IsClientStream bool
// IsServerStream indicates whether the RPC is a server streaming RPC.
IsServerStream bool
}
// ServiceInfo contains unary RPC method info, streaming RPC methid info and metadata for a service.
type ServiceInfo struct { type ServiceInfo struct {
// Methods are method names only, without the service name or package name. Methods []MethodInfo
Methods []string
// Metadata is the metadata specified in ServiceDesc when registering service. // Metadata is the metadata specified in ServiceDesc when registering service.
Metadata interface{} Metadata interface{}
} }
// GetServiceInfo returns a map from service names to ServiceInfo. // GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>. // Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo { func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]*ServiceInfo) ret := make(map[string]ServiceInfo)
for n, srv := range s.m { for n, srv := range s.m {
methods := make([]string, 0, len(srv.md)+len(srv.sd)) methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md { for m := range srv.md {
methods = append(methods, m) methods = append(methods, MethodInfo{
Name: m,
IsClientStream: false,
IsServerStream: false,
})
} }
for m := range srv.sd { for m, d := range srv.sd {
methods = append(methods, m) methods = append(methods, MethodInfo{
Name: m,
IsClientStream: d.ClientStreams,
IsServerStream: d.ServerStreams,
})
} }
ret[n] = &ServiceInfo{ ret[n] = ServiceInfo{
Methods: methods, Methods: methods,
Metadata: srv.mdata, Metadata: srv.mdata,
} }
@ -303,9 +337,11 @@ func (s *Server) Serve(lis net.Listener) error {
s.lis[lis] = true s.lis[lis] = true
s.mu.Unlock() s.mu.Unlock()
defer func() { defer func() {
lis.Close()
s.mu.Lock() s.mu.Lock()
if s.lis != nil && s.lis[lis] {
lis.Close()
delete(s.lis, lis) delete(s.lis, lis)
}
s.mu.Unlock() s.mu.Unlock()
}() }()
for { for {
@ -449,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
func (s *Server) addConn(c io.Closer) bool { func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns == nil { if s.conns == nil || s.drain {
return false return false
} }
s.conns[c] = true s.conns[c] = true
@ -461,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns != nil { if s.conns != nil {
delete(s.conns, c) delete(s.conns, c)
s.cv.Signal()
} }
} }
@ -501,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
p := &parser{r: stream} p := &parser{r: stream}
for { for {
pf, req, err := p.recvMsg() pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -511,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
if err != nil { if err != nil {
switch err := err.(type) { switch err := err.(type) {
case *rpcError:
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
case transport.ConnectionError: case transport.ConnectionError:
// Nothing to do here. // Nothing to do here.
case transport.StreamError: case transport.StreamError:
@ -550,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err return err
} }
} }
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil { if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err return err
} }
@ -560,7 +607,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(*rpcError); ok {
statusCode = err.code statusCode = err.code
statusDesc = err.desc statusDesc = err.desc
} else { } else {
@ -615,6 +662,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
codec: s.opts.codec, codec: s.opts.codec,
cp: s.opts.cp, cp: s.opts.cp,
dc: s.opts.dc, dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo, trInfo: trInfo,
} }
if ss.cp != nil { if ss.cp != nil {
@ -645,7 +693,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler) appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
} }
if appErr != nil { if appErr != nil {
if err, ok := appErr.(rpcError); ok { if err, ok := appErr.(*rpcError); ok {
ss.statusCode = err.code ss.statusCode = err.code
ss.statusDesc = err.desc ss.statusDesc = err.desc
} else if err, ok := appErr.(transport.StreamError); ok { } else if err, ok := appErr.(transport.StreamError); ok {
@ -747,14 +795,16 @@ func (s *Server) Stop() {
s.mu.Lock() s.mu.Lock()
listeners := s.lis listeners := s.lis
s.lis = nil s.lis = nil
cs := s.conns st := s.conns
s.conns = nil s.conns = nil
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
s.cv.Signal()
s.mu.Unlock() s.mu.Unlock()
for lis := range listeners { for lis := range listeners {
lis.Close() lis.Close()
} }
for c := range cs { for c := range st {
c.Close() c.Close()
} }
@ -766,6 +816,32 @@ func (s *Server) Stop() {
s.mu.Unlock() s.mu.Unlock()
} }
// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
// connections and RPCs and blocks until all the pending RPCs are finished.
func (s *Server) GracefulStop() {
s.mu.Lock()
defer s.mu.Unlock()
if s.drain == true || s.conns == nil {
return
}
s.drain = true
for lis := range s.lis {
lis.Close()
}
s.lis = nil
for c := range s.conns {
c.(transport.ServerTransport).Drain()
}
for len(s.conns) != 0 {
s.cv.Wait()
}
s.conns = nil
if s.events != nil {
s.events.Finish()
s.events = nil
}
}
func init() { func init() {
internal.TestingCloseConns = func(arg interface{}) { internal.TestingCloseConns = func(arg interface{}) {
arg.(*Server).testingCloseConns() arg.(*Server).testingCloseConns()

View File

@ -79,7 +79,7 @@ func TestGetServiceInfo(t *testing.T) {
{ {
StreamName: "EmptyStream", StreamName: "EmptyStream",
Handler: nil, Handler: nil,
ServerStreams: true, ServerStreams: false,
ClientStreams: true, ClientStreams: true,
}, },
}, },
@ -90,17 +90,24 @@ func TestGetServiceInfo(t *testing.T) {
server.RegisterService(&testSd, &testServer{}) server.RegisterService(&testSd, &testServer{})
info := server.GetServiceInfo() info := server.GetServiceInfo()
want := map[string]*ServiceInfo{ want := map[string]ServiceInfo{
"grpc.testing.EmptyService": &ServiceInfo{ "grpc.testing.EmptyService": {
Methods: []string{ Methods: []MethodInfo{
"EmptyCall", {
"EmptyStream", Name: "EmptyCall",
IsClientStream: false,
IsServerStream: false,
}, },
{
Name: "EmptyStream",
IsClientStream: true,
IsServerStream: false,
}},
Metadata: []int{0, 2, 1, 3}, Metadata: []int{0, 2, 1, 3},
}, },
} }
if !reflect.DeepEqual(info, want) { if !reflect.DeepEqual(info, want) {
t.Errorf("GetServiceInfo() = %q, want %q", info, want) t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
} }
} }

155
stream.go
View File

@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"errors" "errors"
"io" "io"
"math"
"sync" "sync"
"time" "time"
@ -84,12 +85,9 @@ type ClientStream interface {
// Header returns the header metadata received from the server if there // Header returns the header metadata received from the server if there
// is any. It blocks if the metadata is not ready to read. // is any. It blocks if the metadata is not ready to read.
Header() (metadata.MD, error) Header() (metadata.MD, error)
// Trailer returns the trailer metadata from the server. It must be called // Trailer returns the trailer metadata from the server, if there is any.
// after stream.Recv() returns non-nil error (including io.EOF) for // It must only be called after stream.CloseAndRecv has returned, or
// bi-directional streaming and server streaming or stream.CloseAndRecv() // stream.Recv has returned a non-nil error (including io.EOF).
// returns for client streaming in order to receive trailer metadata if
// present. Otherwise, it could returns an empty MD even though trailer
// is present.
Trailer() metadata.MD Trailer() metadata.MD
// CloseSend closes the send direction of the stream. It closes the stream // CloseSend closes the send direction of the stream. It closes the stream
// when non-nil error is met. // when non-nil error is met.
@ -99,20 +97,18 @@ type ClientStream interface {
// NewClientStream creates a new Stream for the client side. This is called // NewClientStream creates a new Stream for the client side. This is called
// by generated code. // by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
var ( var (
t transport.ClientTransport t transport.ClientTransport
err error s *transport.Stream
put func() put func()
) )
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. c := defaultCallInfo
gopts := BalancerGetOptions{ for _, o := range opts {
BlockingWait: false, if err := o.before(&c); err != nil {
}
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
return nil, toRPCErr(err) return nil, toRPCErr(err)
} }
}
callHdr := &transport.CallHdr{ callHdr := &transport.CallHdr{
Host: cc.authority, Host: cc.authority,
Method: method, Method: method,
@ -121,41 +117,98 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
var trInfo traceInfo
if EnableTracing {
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
trInfo.firstLine.deadline = deadline.Sub(time.Now())
}
trInfo.tr.LazyLog(&trInfo.firstLine, false)
ctx = trace.NewContext(ctx, trInfo.tr)
defer func() {
if err != nil {
// Need to call tr.finish() if error is returned.
// Because tr will not be returned to caller.
trInfo.tr.LazyPrintf("RPC: [%v]", err)
trInfo.tr.SetError()
trInfo.tr.Finish()
}
}()
}
gopts := BalancerGetOptions{
BlockingWait: !c.failFast,
}
for {
t, put, err = cc.getTransport(ctx, gopts)
if err != nil {
// TODO(zhaoq): Probably revisit the error handling.
if _, ok := err.(*rpcError); ok {
return nil, err
}
if err == errConnClosing || err == errConnUnavailable {
if c.failFast {
return nil, Errorf(codes.Unavailable, "%v", err)
}
continue
}
// All the other errors are treated as Internal errors.
return nil, Errorf(codes.Internal, "%v", err)
}
s, err = t.NewStream(ctx, callHdr)
if err != nil {
if put != nil {
put()
put = nil
}
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast {
return nil, toRPCErr(err)
}
continue
}
return nil, toRPCErr(err)
}
break
}
cs := &clientStream{ cs := &clientStream{
opts: opts,
c: c,
desc: desc, desc: desc,
put: put,
codec: cc.dopts.codec, codec: cc.dopts.codec,
cp: cc.dopts.cp, cp: cc.dopts.cp,
dc: cc.dopts.dc, dc: cc.dopts.dc,
put: put,
t: t,
s: s,
p: &parser{r: s},
tracing: EnableTracing, tracing: EnableTracing,
trInfo: trInfo,
} }
if cc.dopts.cp != nil { if cc.dopts.cp != nil {
callHdr.SendCompress = cc.dopts.cp.Type()
cs.cbuf = new(bytes.Buffer) cs.cbuf = new(bytes.Buffer)
} }
if cs.tracing { // Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) // when there is no pending I/O operations on this stream.
cs.trInfo.firstLine.client = true
if deadline, ok := ctx.Deadline(); ok {
cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
}
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
ctx = trace.NewContext(ctx, cs.trInfo.tr)
}
s, err := t.NewStream(ctx, callHdr)
if err != nil {
cs.finish(err)
return nil, toRPCErr(err)
}
cs.t = t
cs.s = s
cs.p = &parser{r: s}
// Listen on ctx.Done() to detect cancellation when there is no pending
// I/O operations on this stream.
go func() { go func() {
select { select {
case <-t.Error(): case <-t.Error():
// Incur transport error, simply exit. // Incur transport error, simply exit.
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
if s.StatusCode() == codes.OK {
cs.finish(nil)
} else {
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
}
cs.closeTransportStream(nil)
case <-s.GoAway():
cs.finish(errConnDrain)
cs.closeTransportStream(errConnDrain)
case <-s.Context().Done(): case <-s.Context().Done():
err := s.Context().Err() err := s.Context().Err()
cs.finish(err) cs.finish(err)
@ -167,6 +220,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption
c callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
@ -216,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if err != nil { if err != nil {
cs.finish(err) cs.finish(err)
} }
if err == nil || err == io.EOF { if err == nil {
return
}
if err == io.EOF {
// Specialize the process for server streaming. SendMesg is only called
// once when creating the stream object. io.EOF needs to be skipped when
// the rpc is early finished (before the stream object is created.).
// TODO: It is probably better to move this into the generated code.
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
err = nil
}
return return
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
@ -237,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
func (cs *clientStream) RecvMsg(m interface{}) (err error) { func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() { defer func() {
// err != nil indicates the termination of the stream. // err != nil indicates the termination of the stream.
if err != nil { if err != nil {
@ -256,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return return
} }
// Special handling for client streaming rpc. // Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m) err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err) cs.closeTransportStream(err)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -291,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
} }
}() }()
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
return return nil
} }
if _, ok := err.(transport.ConnectionError); !ok { if _, ok := err.(transport.ConnectionError); !ok {
cs.closeTransportStream(err) cs.closeTransportStream(err)
@ -312,15 +377,18 @@ func (cs *clientStream) closeTransportStream(err error) {
} }
func (cs *clientStream) finish(err error) { func (cs *clientStream) finish(err error) {
if !cs.tracing {
return
}
cs.mu.Lock() cs.mu.Lock()
defer cs.mu.Unlock() defer cs.mu.Unlock()
for _, o := range cs.opts {
o.after(&cs.c)
}
if cs.put != nil { if cs.put != nil {
cs.put() cs.put()
cs.put = nil cs.put = nil
} }
if !cs.tracing {
return
}
if cs.trInfo.tr != nil { if cs.trInfo.tr != nil {
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
cs.trInfo.tr.LazyPrintf("RPC: [OK]") cs.trInfo.tr.LazyPrintf("RPC: [OK]")
@ -354,6 +422,7 @@ type serverStream struct {
cp Compressor cp Compressor
dc Decompressor dc Decompressor
cbuf *bytes.Buffer cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code statusCode codes.Code
statusDesc string statusDesc string
trInfo *traceInfo trInfo *traceInfo
@ -420,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock() ss.mu.Unlock()
} }
}() }()
return recv(ss.p, ss.codec, ss.s, ss.dc, m) return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
} }

View File

@ -162,7 +162,7 @@ func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.Metri
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
for name, gauge := range s.gauges { for name, gauge := range s.gauges {
if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}); err != nil { if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
return err return err
} }
} }
@ -175,7 +175,7 @@ func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*met
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
if g, ok := s.gauges[in.Name]; ok { if g, ok := s.gauges[in.Name]; ok {
return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{g.get()}}, nil return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
} }
return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name) return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
} }

File diff suppressed because it is too large Load Diff

View File

@ -72,6 +72,11 @@ type resetStream struct {
func (*resetStream) item() {} func (*resetStream) item() {}
type goAway struct {
}
func (*goAway) item() {}
type flushIO struct { type flushIO struct {
} }

46
transport/go16.go Normal file
View File

@ -0,0 +1,46 @@
// +build go1.6,!go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
}

46
transport/go17.go Normal file
View File

@ -0,0 +1,46 @@
// +build go1.7
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return (&net.Dialer{}).DialContext(ctx, network, address)
}

View File

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v) to, err := decodeTimeout(v)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
} }
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header() h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" { if statusDesc != "" {
h.Set("Grpc-Message", statusDesc) h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {
@ -312,7 +312,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
Addr: ht.RemoteAddr(), Addr: ht.RemoteAddr(),
} }
if req.TLS != nil { if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{*req.TLS} pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
} }
ctx = metadata.NewContext(ctx, ht.headerMD) ctx = metadata.NewContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr) ctx = peer.NewContext(ctx, pr)
@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
} }
} }
func (ht *serverHandlerTransport) Drain() {
panic("Drain() is not implemented")
}
// mapRecvMsgError returns the non-nil err into the appropriate // mapRecvMsgError returns the non-nil err into the appropriate
// error value as expected by callers of *grpc.parser.recvMsg. // error value as expected by callers of *grpc.parser.recvMsg.
// In particular, in can only be: // In particular, in can only be:

View File

@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {msg}, "Grpc-Message": {encodeGrpcMessage(msg)},
} }
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {"4"}, "Grpc-Status": {"4"},
"Grpc-Message": {"too slow"}, "Grpc-Message": {encodeGrpcMessage("too slow")},
} }
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"math" "math"
"net" "net"
@ -71,6 +72,9 @@ type http2Client struct {
shutdownChan chan struct{} shutdownChan chan struct{}
// errorChan is closed to notify the I/O error to the caller. // errorChan is closed to notify the I/O error to the caller.
errorChan chan struct{} errorChan chan struct{}
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}
framer *framer framer *framer
hBuf *bytes.Buffer // the buffer for HPACK encoding hBuf *bytes.Buffer // the buffer for HPACK encoding
@ -97,41 +101,49 @@ type http2Client struct {
maxStreams int maxStreams int
// the per-stream outbound flow control window size set by the peer. // the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32 streamSendQuota uint32
// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
goAwayID uint32
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
}
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
if fn != nil {
return fn(ctx, addr)
}
return dialContext(ctx, "tcp", addr)
} }
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
scheme := "http" scheme := "http"
startT := time.Now() conn, connErr := dial(opts.Dialer, ctx, addr)
timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout)
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
} }
var authInfo credentials.AuthInfo // Any further errors will close the underlying connection
if opts.TransportCredentials != nil { defer func(conn net.Conn) {
scheme = "https"
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
defer func() {
if err != nil { if err != nil {
conn.Close() conn.Close()
} }
}() }(conn)
var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
}
if connErr != nil {
// Credentials handshake error is not a temporary error (unless the error
// was the connection closing or deadline exceeded).
var temp bool
switch connErr {
case io.EOF, context.DeadlineExceeded:
temp = true
}
return nil, ConnectionErrorf(temp, connErr, "transport: %v", connErr)
}
ua := primaryUA ua := primaryUA
if opts.UserAgent != "" { if opts.UserAgent != "" {
ua = opts.UserAgent + " " + ua ua = opts.UserAgent + " " + ua
@ -147,6 +159,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
shutdownChan: make(chan struct{}), shutdownChan: make(chan struct{}),
errorChan: make(chan struct{}), errorChan: make(chan struct{}),
goAway: make(chan struct{}),
framer: newFramer(conn), framer: newFramer(conn),
hBuf: &buf, hBuf: &buf,
hEnc: hpack.NewEncoder(&buf), hEnc: hpack.NewEncoder(&buf),
@ -168,26 +181,29 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
n, err := t.conn.Write(clientPreface) n, err := t.conn.Write(clientPreface)
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
if n != len(clientPreface) { if n != len(clientPreface) {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
} }
if initialWindowSize != defaultWindowSize { if initialWindowSize != defaultWindowSize {
err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) err = t.framer.writeSettings(true, http2.Setting{
ID: http2.SettingInitialWindowSize,
Val: uint32(initialWindowSize),
})
} else { } else {
err = t.framer.writeSettings(true) err = t.framer.writeSettings(true)
} }
if err != nil { if err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close() t.Close()
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
go t.controller() go t.controller()
@ -199,6 +215,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id. // TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{ s := &Stream{
id: t.nextID, id: t.nextID,
done: make(chan struct{}),
goAway: make(chan struct{}),
method: callHdr.Method, method: callHdr.Method,
sendCompress: callHdr.SendCompress, sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(), buf: newRecvBuffer(),
@ -214,6 +232,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
s.ctx, s.cancel = context.WithCancel(ctx) s.ctx, s.cancel = context.WithCancel(ctx)
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
goAway: s.goAway,
recv: s.buf, recv: s.buf,
} }
return s return s
@ -268,6 +287,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
} }
if t.state == draining {
t.mu.Unlock()
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -275,7 +298,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
checkStreamsQuota := t.streamsQuota != nil checkStreamsQuota := t.streamsQuota != nil
t.mu.Unlock() t.mu.Unlock()
if checkStreamsQuota { if checkStreamsQuota {
sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -284,7 +307,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.streamsQuota.add(sq - 1) t.streamsQuota.add(sq - 1)
} }
} }
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
// Return the quota back now because there is no stream returned to the caller. // Return the quota back now because there is no stream returned to the caller.
if _, ok := err.(StreamError); ok && checkStreamsQuota { if _, ok := err.(StreamError); ok && checkStreamsQuota {
t.streamsQuota.add(1) t.streamsQuota.add(1)
@ -292,6 +315,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, err return nil, err
} }
t.mu.Lock() t.mu.Lock()
if t.state == draining {
t.mu.Unlock()
if checkStreamsQuota {
t.streamsQuota.add(1)
}
// Need to make t writable again so that the rpc in flight can still proceed.
t.writableChan <- 0
return nil, ErrStreamDrain
}
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
return nil, ErrConnClosing return nil, ErrConnClosing
@ -326,7 +358,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
} }
if timeout > 0 { if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
} }
for k, v := range authData { for k, v := range authData {
// Capital header names are illegal in HTTP/2. // Capital header names are illegal in HTTP/2.
@ -381,7 +413,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
if err != nil { if err != nil {
t.notifyError(err) t.notifyError(err)
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
t.writableChan <- 0 t.writableChan <- 0
@ -400,22 +432,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
if t.streamsQuota != nil { if t.streamsQuota != nil {
updateStreams = true updateStreams = true
} }
if t.state == draining && len(t.activeStreams) == 1 { delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t. // The transport is draining and s is the last live stream on t.
t.mu.Unlock() t.mu.Unlock()
t.Close() t.Close()
return return
} }
delete(t.activeStreams, s.id)
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {
t.streamsQuota.add(1) t.streamsQuota.add(1)
} }
// In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), the caller needs
// to call cancel on the stream to interrupt the blocking on
// other goroutines.
s.cancel()
s.mu.Lock() s.mu.Lock()
if q := s.fc.resetPendingData(); q > 0 { if q := s.fc.resetPendingData(); q > 0 {
if n := t.fc.onRead(q); n > 0 { if n := t.fc.onRead(q); n > 0 {
@ -442,13 +469,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
// accessed any more. // accessed any more.
func (t *http2Client) Close() (err error) { func (t *http2Client) Close() (err error) {
t.mu.Lock() t.mu.Lock()
if t.state == reachable {
close(t.errorChan)
}
if t.state == closing { if t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return return
} }
if t.state == reachable || t.state == draining {
close(t.errorChan)
}
t.state = closing t.state = closing
t.mu.Unlock() t.mu.Unlock()
close(t.shutdownChan) close(t.shutdownChan)
@ -472,10 +499,35 @@ func (t *http2Client) Close() (err error) {
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() error {
t.mu.Lock() t.mu.Lock()
if t.state == closing { switch t.state {
case unreachable:
// The server may close the connection concurrently. t is not available for
// any streams. Close it now.
t.mu.Unlock()
t.Close()
return nil
case closing:
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
// Notify the streams which were initiated after the server sent GOAWAY.
select {
case <-t.goAway:
n := t.prevGoAwayID
if n == 0 && t.nextID > 1 {
n = t.nextID - 2
}
m := t.goAwayID + 2
if m == 2 {
m = 1
}
for i := m; i <= n; i += 2 {
if s, ok := t.activeStreams[i]; ok {
close(s.goAway)
}
}
default:
}
if t.state == draining { if t.state == draining {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
@ -501,15 +553,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
} }
return err return err
@ -541,8 +593,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// Indicate there is a writer who is about to write a data frame. // Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport. // Got some quota. Try to acquire writing privilege on the transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok || err == io.EOF {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(len(p)) t.sendQuotaPool.add(len(p))
} }
@ -575,7 +627,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
// invoked. // invoked.
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
t.notifyError(err) t.notifyError(err)
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -590,12 +642,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamReadDone {
s.state = streamDone
} else {
s.state = streamWriteDone s.state = streamWriteDone
} }
}
s.mu.Unlock() s.mu.Unlock()
return nil return nil
} }
@ -627,7 +675,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := len(f.Data()) size := len(f.Data())
if err := t.fc.onData(uint32(size)); err != nil { if err := t.fc.onData(uint32(size)); err != nil {
t.notifyError(ConnectionErrorf("%v", err)) t.notifyError(ConnectionErrorf(true, err, "%v", err))
return return
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
@ -652,6 +700,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
s.state = streamDone s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = err.Error() s.statusDesc = err.Error()
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
@ -669,13 +718,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// the read direction is closed, and set the status appropriately. // the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
s.mu.Lock() s.mu.Lock()
if s.state == streamWriteDone { if s.state == streamDone {
s.state = streamDone s.mu.Unlock()
} else { return
s.state = streamReadDone
} }
s.state = streamDone
s.statusCode = codes.Internal s.statusCode = codes.Internal
s.statusDesc = "server closed the stream without sending trailers" s.statusDesc = "server closed the stream without sending trailers"
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -701,6 +751,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
s.statusCode = codes.Unknown s.statusCode = codes.Unknown
} }
s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
close(s.done)
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -725,7 +777,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
} }
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// TODO(zhaoq): GoAwayFrame handler to be implemented t.mu.Lock()
if t.state == reachable || t.state == draining {
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
return
}
select {
case <-t.goAway:
id := t.goAwayID
// t.goAway has been closed (i.e.,multiple GoAways).
if id < f.LastStreamID {
t.mu.Unlock()
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
return
}
t.prevGoAwayID = id
t.goAwayID = f.LastStreamID
t.mu.Unlock()
return
default:
}
t.goAwayID = f.LastStreamID
close(t.goAway)
}
t.mu.Unlock()
} }
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
@ -777,11 +854,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.trailer = state.mdata s.trailer = state.mdata
} }
s.state = streamDone
s.statusCode = state.statusCode s.statusCode = state.statusCode
s.statusDesc = state.statusDesc s.statusDesc = state.statusDesc
close(s.done)
s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -934,13 +1011,22 @@ func (t *http2Client) Error() <-chan struct{} {
return t.errorChan return t.errorChan
} }
func (t *http2Client) GoAway() <-chan struct{} {
return t.goAway
}
func (t *http2Client) notifyError(err error) { func (t *http2Client) notifyError(err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
// make sure t.errorChan is closed only once. // make sure t.errorChan is closed only once.
if t.state == draining {
t.mu.Unlock()
t.Close()
return
}
if t.state == reachable { if t.state == reachable {
t.state = unreachable t.state = unreachable
close(t.errorChan) close(t.errorChan)
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
} }
t.mu.Unlock()
} }

View File

@ -100,18 +100,23 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
if maxStreams == 0 { if maxStreams == 0 {
maxStreams = math.MaxUint32 maxStreams = math.MaxUint32
} else { } else {
settings = append(settings, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) settings = append(settings, http2.Setting{
ID: http2.SettingMaxConcurrentStreams,
Val: maxStreams,
})
} }
if initialWindowSize != defaultWindowSize { if initialWindowSize != defaultWindowSize {
settings = append(settings, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) settings = append(settings, http2.Setting{
ID: http2.SettingInitialWindowSize,
Val: uint32(initialWindowSize)})
} }
if err := framer.writeSettings(true, settings...); err != nil { if err := framer.writeSettings(true, settings...); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
// Adjust the connection flow control window if needed. // Adjust the connection flow control window if needed.
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
if err := framer.writeWindowUpdate(true, 0, delta); err != nil { if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
var buf bytes.Buffer var buf bytes.Buffer
@ -137,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
} }
// operateHeader takes action on the decoded headers. // operateHeader takes action on the decoded headers.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) { func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: frame.Header().StreamID, id: frame.Header().StreamID,
@ -200,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
return return
} }
if s.id%2 != 1 || s.id <= t.maxStreamID {
t.mu.Unlock()
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
return true
}
t.maxStreamID = s.id
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
t.activeStreams[s.id] = s t.activeStreams[s.id] = s
t.mu.Unlock() t.mu.Unlock()
@ -207,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
} }
handle(s) handle(s)
return
} }
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
@ -226,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
if err != nil { if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
@ -252,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code}) t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue continue
} }
if err == io.EOF || err == io.ErrUnexpectedEOF {
t.Close()
return
}
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
return return
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
id := frame.Header().StreamID if t.operateHeaders(frame, handle) {
if id%2 != 1 || id <= t.maxStreamID {
// illegal gRPC stream id.
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
t.Close() t.Close()
break break
} }
t.maxStreamID = id
t.operateHeaders(frame, handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:
@ -277,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame) t.handleWindowUpdate(frame)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
break // TODO: Handle GoAway from the client appropriately.
default: default:
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
} }
@ -359,12 +376,8 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Received the end of stream from the client. // Received the end of stream from the client.
s.mu.Lock() s.mu.Lock()
if s.state != streamDone { if s.state != streamDone {
if s.state == streamWriteDone {
s.state = streamDone
} else {
s.state = streamReadDone s.state = streamReadDone
} }
}
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }
@ -435,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
} }
if err != nil { if err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
} }
return nil return nil
@ -450,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
} }
s.headerOk = true s.headerOk = true
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -490,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
headersSent = true headersSent = true
} }
s.mu.Unlock() s.mu.Unlock()
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -503,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.
@ -539,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
s.mu.Unlock() s.mu.Unlock()
if writeHeaderFrame { if writeHeaderFrame {
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err return err
} }
t.hBuf.Reset() t.hBuf.Reset()
@ -555,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeHeaders(false, p); err != nil { if err := t.framer.writeHeaders(false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
t.writableChan <- 0 t.writableChan <- 0
} }
@ -567,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
size := http2MaxFrameLen size := http2MaxFrameLen
s.sendQuotaPool.add(0) s.sendQuotaPool.add(0)
// Wait until the stream has some quota to send the data. // Wait until the stream has some quota to send the data.
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
if err != nil { if err != nil {
return err return err
} }
t.sendQuotaPool.add(0) t.sendQuotaPool.add(0)
// Wait until the transport has some quota to send the data. // Wait until the transport has some quota to send the data.
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
if err != nil { if err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
t.sendQuotaPool.cancel() t.sendQuotaPool.cancel()
@ -599,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
t.framer.adjustNumWriters(1) t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the // Got some quota. Try to acquire writing privilege on the
// transport. // transport.
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
if _, ok := err.(StreamError); ok { if _, ok := err.(StreamError); ok {
// Return the connection quota back. // Return the connection quota back.
t.sendQuotaPool.add(ps) t.sendQuotaPool.add(ps)
@ -629,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
} }
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
t.Close() t.Close()
return ConnectionErrorf("transport: %v", err) return ConnectionErrorf(true, err, "transport: %v", err)
} }
if t.framer.adjustNumWriters(-1) == 0 { if t.framer.adjustNumWriters(-1) == 0 {
t.framer.flushWrite() t.framer.flushWrite()
@ -674,6 +687,17 @@ func (t *http2Server) controller() {
} }
case *resetStream: case *resetStream:
t.framer.writeRSTStream(true, i.streamID, i.code) t.framer.writeRSTStream(true, i.streamID, i.code)
case *goAway:
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
// The transport is closing.
return
}
sid := t.maxStreamID
t.state = draining
t.mu.Unlock()
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
case *flushIO: case *flushIO:
t.framer.flushWrite() t.framer.flushWrite()
case *ping: case *ping:
@ -719,6 +743,9 @@ func (t *http2Server) Close() (err error) {
func (t *http2Server) closeStream(s *Stream) { func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock() t.mu.Lock()
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 {
defer t.Close()
}
t.mu.Unlock() t.mu.Unlock()
// In case stream sending and receiving are invoked in separate // In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be // goroutines (e.g., bi-directional streaming), cancel needs to be
@ -741,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) RemoteAddr() net.Addr {
return t.conn.RemoteAddr() return t.conn.RemoteAddr()
} }
func (t *http2Server) Drain() {
t.controlBuf.put(&goAway{})
}

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -52,7 +53,7 @@ import (
const ( const (
// The primary user agent // The primary user agent
primaryUA = "grpc-go/0.11" primaryUA = "grpc-go/1.0"
// http2MaxFrameLen specifies the max length of a HTTP2 frame. // http2MaxFrameLen specifies the max length of a HTTP2 frame.
http2MaxFrameLen = 16384 // 16KB frame http2MaxFrameLen = 16384 // 16KB frame
// http://http2.github.io/http2-spec/#SettingValues // http://http2.github.io/http2-spec/#SettingValues
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.statusDesc = f.Value d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error
d.timeout, err = timeoutDecode(f.Value) d.timeout, err = decodeTimeout(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
} }
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string { func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue { if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n" return strconv.FormatInt(d, 10) + "n"
} }
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H" return strconv.FormatInt(div(t, time.Hour), 10) + "H"
} }
func timeoutDecode(s string) (time.Duration, error) { func decodeTimeout(s string) (time.Duration, error) {
size := len(s) size := len(s)
if size < 2 { if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s) return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil return d * time.Duration(t), nil
} }
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type framer struct { type framer struct {
numWriters int32 numWriters int32
reader io.Reader reader io.Reader

View File

@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err) t.Fatalf("failed to parse duration string %s: %v", test.in, err)
} }
out := timeoutEncode(d) out := encodeTimeout(d)
if out != test.out { if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
} }
@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) {
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")}, {"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")}, {"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
} { } {
d, err := timeoutDecode(test.s) d, err := decodeTimeout(test.s)
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) { if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err) t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
} }
@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) {
} }
} }
} }
func TestEncodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"my favorite character is \u0000", "my favorite character is %00"},
{"my favorite character is %", "my favorite character is %25"},
} {
actual := encodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}
func TestDecodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"H%61o", "Hao"},
{"H%6", "H%6"},
{"%G0", "%G0"},
} {
actual := decodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}

51
transport/pre_go16.go Normal file
View File

@ -0,0 +1,51 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
"golang.org/x/net/context"
)
// dialContext connects to the address on the named network.
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
var dialer net.Dialer
if deadline, ok := ctx.Deadline(); ok {
dialer.Timeout = deadline.Sub(time.Now())
}
return dialer.Dial(network, address)
}

View File

@ -44,7 +44,6 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
@ -121,6 +120,7 @@ func (b *recvBuffer) get() <-chan item {
// recvBuffer. // recvBuffer.
type recvBufferReader struct { type recvBufferReader struct {
ctx context.Context ctx context.Context
goAway chan struct{}
recv *recvBuffer recv *recvBuffer
last *bytes.Reader // Stores the remaining data in the previous calls. last *bytes.Reader // Stores the remaining data in the previous calls.
err error err error
@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
select { select {
case <-r.ctx.Done(): case <-r.ctx.Done():
return 0, ContextErr(r.ctx.Err()) return 0, ContextErr(r.ctx.Err())
case <-r.goAway:
return 0, ErrStreamDrain
case i := <-r.recv.get(): case i := <-r.recv.get():
r.recv.load() r.recv.load()
m := i.(*recvMsg) m := i.(*recvMsg)
@ -158,7 +160,7 @@ const (
streamActive streamState = iota streamActive streamState = iota
streamWriteDone // EndStream sent streamWriteDone // EndStream sent
streamReadDone // EndStream received streamReadDone // EndStream received
streamDone // sendDone and recvDone or RSTStreamFrame is sent or received. streamDone // the entire stream is finished.
) )
// Stream represents an RPC in the transport layer. // Stream represents an RPC in the transport layer.
@ -169,6 +171,10 @@ type Stream struct {
// ctx is the associated context of the stream. // ctx is the associated context of the stream.
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// done is closed when the final status arrives.
done chan struct{}
// goAway is closed when the server sent GoAways signal before this stream was initiated.
goAway chan struct{}
// method records the associated RPC method of the stream. // method records the associated RPC method of the stream.
method string method string
recvCompress string recvCompress string
@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str s.sendCompress = str
} }
// Done returns a chanel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// GoAway returns a channel which is closed when the server sent GoAways signal
// before this stream was initiated.
func (s *Stream) GoAway() <-chan struct{} {
return s.goAway
}
// Header acquires the key-value pairs of header metadata once it // Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no // is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired. // header metadata or iii) the stream is cancelled/expired.
@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err()) return nil, ContextErr(s.ctx.Err())
case <-s.goAway:
return nil, ErrStreamDrain
case <-s.headerChan: case <-s.headerChan:
return s.header.Copy(), nil return s.header.Copy(), nil
} }
@ -335,19 +355,17 @@ type ConnectOptions struct {
// UserAgent is the application user agent. // UserAgent is the application user agent.
UserAgent string UserAgent string
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(context.Context, string) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
TransportCredentials credentials.TransportCredentials TransportCredentials credentials.TransportCredentials
// Timeout specifies the timeout for dialing a ClientTransport.
Timeout time.Duration
} }
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts) return newHTTP2Client(ctx, target, opts)
} }
// Options provides additional hints and information for message // Options provides additional hints and information for message
@ -417,6 +435,11 @@ type ClientTransport interface {
// and create a new one) in error case. It should not return nil // and create a new one) in error case. It should not return nil
// once the transport is initiated. // once the transport is initiated.
Error() <-chan struct{} Error() <-chan struct{}
// GoAway returns a channel that is closed when ClientTranspor
// receives the draining signal from the server (e.g., GOAWAY frame in
// HTTP/2).
GoAway() <-chan struct{}
} }
// ServerTransport is the common interface for all gRPC server-side transport // ServerTransport is the common interface for all gRPC server-side transport
@ -448,6 +471,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network address.
RemoteAddr() net.Addr RemoteAddr() net.Addr
// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain()
} }
// StreamErrorf creates an StreamError with the specified error code and description. // StreamErrorf creates an StreamError with the specified error code and description.
@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
} }
// ConnectionErrorf creates an ConnectionError with the specified error description. // ConnectionErrorf creates an ConnectionError with the specified error description.
func ConnectionErrorf(format string, a ...interface{}) ConnectionError { func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
return ConnectionError{ return ConnectionError{
Desc: fmt.Sprintf(format, a...), Desc: fmt.Sprintf(format, a...),
temp: temp,
err: e,
} }
} }
@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
// entire connection and the retry of all the active streams. // entire connection and the retry of all the active streams.
type ConnectionError struct { type ConnectionError struct {
Desc string Desc string
temp bool
err error
} }
func (e ConnectionError) Error() string { func (e ConnectionError) Error() string {
return fmt.Sprintf("connection error: desc = %q", e.Desc) return fmt.Sprintf("connection error: desc = %q", e.Desc)
} }
// Temporary indicates if this connection error is temporary or fatal.
func (e ConnectionError) Temporary() bool {
return e.temp
}
// Origin returns the original error of this connection error.
func (e ConnectionError) Origin() error {
// Never return nil error here.
// If the original error is nil, return itself.
if e.err == nil {
return e
}
return e.err
}
var (
// ErrConnClosing indicates that the transport is closing. // ErrConnClosing indicates that the transport is closing.
var ErrConnClosing = ConnectionError{Desc: "transport is closing"} ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
// ErrStreamDrain indicates that the stream is rejected by the server because
// the server stops accepting new RPCs.
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
)
// StreamError is an error that only affects one stream within a connection. // StreamError is an error that only affects one stream within a connection.
type StreamError struct { type StreamError struct {
@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
// wait blocks until it can receive from ctx.Done, closing, or proceed. // wait blocks until it can receive from ctx.Done, closing, or proceed.
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. // If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
// it return the StreamError for ctx.Err.
// If it receives from goAway, it returns 0, ErrStreamDrain.
// If it receives from closing, it returns 0, ErrConnClosing. // If it receives from closing, it returns 0, ErrConnClosing.
// If it receives from proceed, it returns the received integer, nil. // If it receives from proceed, it returns the received integer, nil.
func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return 0, ContextErr(ctx.Err()) return 0, ContextErr(ctx.Err())
case <-done:
// User cancellation has precedence.
select {
case <-ctx.Done():
return 0, ContextErr(ctx.Err())
default:
}
return 0, io.EOF
case <-goAway:
return 0, ErrStreamDrain
case <-closing: case <-closing:
return 0, ErrConnClosing return 0, ErrConnClosing
case i := <-proceed: case i := <-proceed:

View File

@ -39,7 +39,6 @@ import (
"io" "io"
"math" "math"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
@ -75,7 +74,7 @@ const (
normal hType = iota normal hType = iota
suspended suspended
misbehaved misbehaved
malformedStatus encodingRequiredStatus
) )
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
@ -111,27 +110,34 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
if !ok { if !ok {
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
} }
size := 1
if s.Method() == "foo.MaxFrame" {
size = http2MaxFrameLen
}
// Drain the client side stream flow control window.
var sent int var sent int
for sent <= initialWindowSize { p := make([]byte, http2MaxFrameLen)
for sent < initialWindowSize {
<-conn.writableChan <-conn.writableChan
if err := conn.framer.writeData(true, s.id, false, make([]byte, size)); err != nil { n := initialWindowSize - sent
// The last message may be smaller than http2MaxFrameLen
if n <= http2MaxFrameLen {
if s.Method() == "foo.Connection" {
// Violate connection level flow control window of client but do not
// violate any stream level windows.
p = make([]byte, n)
} else {
// Violate stream level flow control window of client.
p = make([]byte, n+1)
}
}
if err := conn.framer.writeData(true, s.id, false, p); err != nil {
conn.writableChan <- 0 conn.writableChan <- 0
break break
} }
conn.writableChan <- 0 conn.writableChan <- 0
sent += size sent += len(p)
} }
} }
func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
// raw newline is not accepted by http2 framer and a http2.StreamError is // raw newline is not accepted by http2 framer so it must be encoded.
// generated. h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
h.t.WriteStatus(s, codes.Internal, "\n")
} }
// start starts server. Other goroutines should block on s.readyChan for further operations. // start starts server. Other goroutines should block on s.readyChan for further operations.
@ -179,9 +185,9 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMisbehave(t, s) go h.handleStreamMisbehave(t, s)
}) })
case malformedStatus: case encodingRequiredStatus:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMalformedStatus(t, s) go h.handleStreamEncodingRequiredStatus(t, s)
}) })
default: default:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
@ -221,7 +227,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
ct ClientTransport ct ClientTransport
connErr error connErr error
) )
ct, connErr = NewClientTransport(addr, &ConnectOptions{}) ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)
} }
@ -252,7 +258,7 @@ func TestClientSendAndReceive(t *testing.T) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s1, expectedRequest, &opts); err != nil { if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("failed to send data: %v", err) t.Fatalf("failed to send data: %v", err)
} }
p := make([]byte, len(expectedResponse)) p := make([]byte, len(expectedResponse))
@ -289,7 +295,7 @@ func performOneRPC(ct ClientTransport) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s, expectedRequest, &opts); err == nil { if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
// The following s.Recv()'s could error out because the // The following s.Recv()'s could error out because the
// underlying transport is gone. // underlying transport is gone.
@ -333,7 +339,7 @@ func TestLargeMessage(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
} }
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
} }
p := make([]byte, len(expectedResponseLarge)) p := make([]byte, len(expectedResponseLarge))
@ -369,8 +375,8 @@ func TestGracefulClose(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing) t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain)
} }
}() }()
} }
@ -379,7 +385,7 @@ func TestGracefulClose(t *testing.T) {
Delay: false, Delay: false,
} }
// The stream which was created before graceful close can still proceed. // The stream which was created before graceful close can still proceed.
if err := ct.Write(s, expectedRequest, &opts); err != nil { if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err) t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
} }
p := make([]byte, len(expectedResponse)) p := make([]byte, len(expectedResponse))
@ -409,7 +415,7 @@ func TestLargeMessageSuspension(t *testing.T) {
// Write should not be done successfully due to flow control. // Write should not be done successfully due to flow control.
err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}) err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false})
expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
if err == nil || err != expectedErr { if err != expectedErr {
t.Fatalf("Write got %v, want %v", err, expectedErr) t.Fatalf("Write got %v, want %v", err, expectedErr)
} }
ct.Close() ct.Close()
@ -433,14 +439,21 @@ func TestMaxStreams(t *testing.T) {
} }
done := make(chan struct{}) done := make(chan struct{})
ch := make(chan int) ch := make(chan int)
ready := make(chan struct{})
go func() { go func() {
for { for {
select { select {
case <-time.After(5 * time.Millisecond): case <-time.After(5 * time.Millisecond):
ch <- 0 select {
case ch <- 0:
case <-ready:
return
}
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
close(done) close(done)
return return
case <-ready:
return
} }
} }
}() }()
@ -467,6 +480,7 @@ func TestMaxStreams(t *testing.T) {
} }
cc.mu.Unlock() cc.mu.Unlock()
} }
close(ready)
// Close the pending stream so that the streams quota becomes available for the next new stream. // Close the pending stream so that the streams quota becomes available for the next new stream.
ct.CloseStream(s, nil) ct.CloseStream(s, nil)
select { select {
@ -546,6 +560,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("Failed to cancel the context of the sever side stream.") t.Fatalf("Failed to cancel the context of the sever side stream.")
} }
server.stop()
} }
func TestServerWithMisbehavedClient(t *testing.T) { func TestServerWithMisbehavedClient(t *testing.T) {
@ -652,7 +667,7 @@ func TestClientWithMisbehavedServer(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, misbehaved) server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo", Method: "foo.Stream",
} }
conn, ok := ct.(*http2Client) conn, ok := ct.(*http2Client)
if !ok { if !ok {
@ -663,7 +678,8 @@ func TestClientWithMisbehavedServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to open stream: %v", err) t.Fatalf("Failed to open stream: %v", err)
} }
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil { d := make([]byte, 1)
if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Fatalf("Failed to write: %v", err) t.Fatalf("Failed to write: %v", err)
} }
// Read without window update. // Read without window update.
@ -685,17 +701,15 @@ func TestClientWithMisbehavedServer(t *testing.T) {
} }
// Test the logic for the violation of the connection flow control window size restriction. // Test the logic for the violation of the connection flow control window size restriction.
// //
// Generate enough streams to drain the connection window. // Generate enough streams to drain the connection window. Make the server flood the traffic
callHdr = &CallHdr{ // to violate flow control window size of the connection.
Host: "localhost", callHdr.Method = "foo.Connection"
Method: "foo.MaxFrame",
}
for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ { for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
s, err := ct.NewStream(context.Background(), callHdr) s, err := ct.NewStream(context.Background(), callHdr)
if err != nil { if err != nil {
break break
} }
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil { if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil {
break break
} }
} }
@ -705,8 +719,13 @@ func TestClientWithMisbehavedServer(t *testing.T) {
server.stop() server.stop()
} }
func TestMalformedStatus(t *testing.T) { var (
server, ct := setUp(t, 0, math.MaxUint32, malformedStatus) encodingTestStatusCode = codes.Internal
encodingTestStatusDesc = "\n"
)
func TestEncodingRequiredStatus(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo", Method: "foo",
@ -719,24 +738,26 @@ func TestMalformedStatus(t *testing.T) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s, expectedRequest, &opts); err != nil { if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("Failed to write the request: %v", err) t.Fatalf("Failed to write the request: %v", err)
} }
p := make([]byte, http2MaxFrameLen) p := make([]byte, http2MaxFrameLen)
expectedErr := StreamErrorf(codes.Internal, "invalid header field value \"\\n\"") if _, err := s.dec.Read(p); err != io.EOF {
if _, err = s.dec.Read(p); err != expectedErr { t.Fatalf("Read got error %v, want %v", err, io.EOF)
t.Fatalf("Read the err %v, want %v", err, expectedErr) }
if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
} }
ct.Close() ct.Close()
server.stop() server.stop()
} }
func TestStreamContext(t *testing.T) { func TestStreamContext(t *testing.T) {
expectedStream := Stream{} expectedStream := &Stream{}
ctx := newContextWithStream(context.Background(), &expectedStream) ctx := newContextWithStream(context.Background(), expectedStream)
s, ok := StreamFromContext(ctx) s, ok := StreamFromContext(ctx)
if !ok || !reflect.DeepEqual(expectedStream, *s) { if !ok || expectedStream != s {
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream) t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
} }
} }