Merge remote-tracking branch 'upstream/master' into status_interop_test
This commit is contained in:
19
.travis.yml
19
.travis.yml
@ -1,17 +1,18 @@
|
|||||||
language: go
|
language: go
|
||||||
|
|
||||||
go:
|
go:
|
||||||
- 1.5.3
|
- 1.5.4
|
||||||
- 1.6
|
- 1.6.3
|
||||||
|
- 1.7
|
||||||
|
|
||||||
|
go_import_path: google.golang.org/grpc
|
||||||
|
|
||||||
before_install:
|
before_install:
|
||||||
- go get github.com/axw/gocov/gocov
|
- go get -u golang.org/x/tools/cmd/goimports github.com/golang/lint/golint github.com/axw/gocov/gocov github.com/mattn/goveralls golang.org/x/tools/cmd/cover
|
||||||
- go get github.com/mattn/goveralls
|
|
||||||
- go get golang.org/x/tools/cmd/cover
|
|
||||||
|
|
||||||
install:
|
|
||||||
- mkdir -p "$GOPATH/src/google.golang.org"
|
|
||||||
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
|
|
||||||
|
|
||||||
script:
|
script:
|
||||||
|
- '! gofmt -s -d -l . 2>&1 | read'
|
||||||
|
- '! goimports -l . | read'
|
||||||
|
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
|
||||||
|
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf" | grep -vF .pb.go:' # https://github.com/golang/protobuf/issues/214
|
||||||
- make test testrace
|
- make test testrace
|
||||||
|
@ -28,5 +28,5 @@ See [API documentation](https://godoc.org/google.golang.org/grpc) for package an
|
|||||||
|
|
||||||
Status
|
Status
|
||||||
------
|
------
|
||||||
Beta release
|
GA
|
||||||
|
|
||||||
|
29
balancer.go
29
balancer.go
@ -40,7 +40,6 @@ import (
|
|||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
"google.golang.org/grpc/naming"
|
"google.golang.org/grpc/naming"
|
||||||
"google.golang.org/grpc/transport"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address represents a server the client connects to.
|
// Address represents a server the client connects to.
|
||||||
@ -94,10 +93,10 @@ type Balancer interface {
|
|||||||
// instead of blocking.
|
// instead of blocking.
|
||||||
//
|
//
|
||||||
// The function returns put which is called once the rpc has completed or failed.
|
// The function returns put which is called once the rpc has completed or failed.
|
||||||
// put can collect and report RPC stats to a remote load balancer. gRPC internals
|
// put can collect and report RPC stats to a remote load balancer.
|
||||||
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
|
|
||||||
//
|
//
|
||||||
// TODO: Add other non-recoverable errors?
|
// This function should only return the errors Balancer cannot recover by itself.
|
||||||
|
// gRPC internals will fail the RPC if an error is returned.
|
||||||
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
|
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
|
||||||
// Notify returns a channel that is used by gRPC internals to watch the addresses
|
// Notify returns a channel that is used by gRPC internals to watch the addresses
|
||||||
// gRPC needs to connect. The addresses might be from a name resolver or remote
|
// gRPC needs to connect. The addresses might be from a name resolver or remote
|
||||||
@ -158,14 +157,15 @@ type roundRobin struct {
|
|||||||
func (rr *roundRobin) watchAddrUpdates() error {
|
func (rr *roundRobin) watchAddrUpdates() error {
|
||||||
updates, err := rr.w.Next()
|
updates, err := rr.w.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Println("grpc: the naming watcher stops working due to %v.", err)
|
grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rr.mu.Lock()
|
rr.mu.Lock()
|
||||||
defer rr.mu.Unlock()
|
defer rr.mu.Unlock()
|
||||||
for _, update := range updates {
|
for _, update := range updates {
|
||||||
addr := Address{
|
addr := Address{
|
||||||
Addr: update.Addr,
|
Addr: update.Addr,
|
||||||
|
Metadata: update.Metadata,
|
||||||
}
|
}
|
||||||
switch update.Op {
|
switch update.Op {
|
||||||
case naming.Add:
|
case naming.Add:
|
||||||
@ -298,8 +298,19 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// There is no address available. Wait on rr.waitCh.
|
if !opts.BlockingWait {
|
||||||
// TODO(zhaoq): Handle the case when opts.BlockingWait is false.
|
if len(rr.addrs) == 0 {
|
||||||
|
rr.mu.Unlock()
|
||||||
|
err = fmt.Errorf("there is no address available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Returns the next addr on rr.addrs for failfast RPCs.
|
||||||
|
addr = rr.addrs[rr.next].addr
|
||||||
|
rr.next++
|
||||||
|
rr.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Wait on rr.waitCh for non-failfast RPCs.
|
||||||
if rr.waitCh == nil {
|
if rr.waitCh == nil {
|
||||||
ch = make(chan struct{})
|
ch = make(chan struct{})
|
||||||
rr.waitCh = ch
|
rr.waitCh = ch
|
||||||
@ -310,7 +321,7 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
err = transport.ContextErr(ctx.Err())
|
err = ctx.Err()
|
||||||
return
|
return
|
||||||
case <-ch:
|
case <-ch:
|
||||||
rr.mu.Lock()
|
rr.mu.Lock()
|
||||||
|
134
balancer_test.go
134
balancer_test.go
@ -239,11 +239,11 @@ func TestCloseWithPendingRPC(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create ClientConn: %v", err)
|
t.Fatalf("Failed to create ClientConn: %v", err)
|
||||||
}
|
}
|
||||||
var reply string
|
var reply string
|
||||||
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
|
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
|
||||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
|
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
|
||||||
}
|
}
|
||||||
// Remove the server.
|
// Remove the server.
|
||||||
updates := []*naming.Update{&naming.Update{
|
updates := []*naming.Update{{
|
||||||
Op: naming.Delete,
|
Op: naming.Delete,
|
||||||
Addr: "127.0.0.1:" + servers[0].port,
|
Addr: "127.0.0.1:" + servers[0].port,
|
||||||
}}
|
}}
|
||||||
@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
|
|||||||
// Loop until the above update applies.
|
// Loop until the above update applies.
|
||||||
for {
|
for {
|
||||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
|
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
var reply string
|
var reply string
|
||||||
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
|
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
|
||||||
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
|
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
var reply string
|
var reply string
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil {
|
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
|
||||||
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
|
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -287,7 +287,7 @@ func TestGetOnWaitChannel(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create ClientConn: %v", err)
|
t.Fatalf("Failed to create ClientConn: %v", err)
|
||||||
}
|
}
|
||||||
// Remove all servers so that all upcoming RPCs will block on waitCh.
|
// Remove all servers so that all upcoming RPCs will block on waitCh.
|
||||||
updates := []*naming.Update{&naming.Update{
|
updates := []*naming.Update{{
|
||||||
Op: naming.Delete,
|
Op: naming.Delete,
|
||||||
Addr: "127.0.0.1:" + servers[0].port,
|
Addr: "127.0.0.1:" + servers[0].port,
|
||||||
}}
|
}}
|
||||||
@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) {
|
|||||||
for {
|
for {
|
||||||
var reply string
|
var reply string
|
||||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded {
|
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
@ -305,12 +305,12 @@ func TestGetOnWaitChannel(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
var reply string
|
var reply string
|
||||||
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil {
|
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
|
||||||
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
|
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
// Add a connected server to get the above RPC through.
|
// Add a connected server to get the above RPC through.
|
||||||
updates = []*naming.Update{&naming.Update{
|
updates = []*naming.Update{{
|
||||||
Op: naming.Add,
|
Op: naming.Add,
|
||||||
Addr: "127.0.0.1:" + servers[0].port,
|
Addr: "127.0.0.1:" + servers[0].port,
|
||||||
}}
|
}}
|
||||||
@ -320,3 +320,119 @@ func TestGetOnWaitChannel(t *testing.T) {
|
|||||||
cc.Close()
|
cc.Close()
|
||||||
servers[0].stop()
|
servers[0].stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOneServerDown(t *testing.T) {
|
||||||
|
// Start 2 servers.
|
||||||
|
numServers := 2
|
||||||
|
servers, r := startServers(t, numServers, math.MaxUint32)
|
||||||
|
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create ClientConn: %v", err)
|
||||||
|
}
|
||||||
|
// Add servers[1] to the service discovery.
|
||||||
|
var updates []*naming.Update
|
||||||
|
updates = append(updates, &naming.Update{
|
||||||
|
Op: naming.Add,
|
||||||
|
Addr: "127.0.0.1:" + servers[1].port,
|
||||||
|
})
|
||||||
|
r.w.inject(updates)
|
||||||
|
req := "port"
|
||||||
|
var reply string
|
||||||
|
// Loop until servers[1] is up
|
||||||
|
for {
|
||||||
|
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numRPC := 100
|
||||||
|
sleepDuration := 10 * time.Millisecond
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
// After sleepDuration, kill server[0].
|
||||||
|
servers[0].stop()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// All non-failfast RPCs should not block because there's at least one connection available.
|
||||||
|
for i := 0; i < numRPC; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
// After sleepDuration, invoke RPC.
|
||||||
|
// server[0] is killed around the same time to make it racy between balancer and gRPC internals.
|
||||||
|
Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
cc.Close()
|
||||||
|
for i := 0; i < numServers; i++ {
|
||||||
|
servers[i].stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOneAddressRemoval(t *testing.T) {
|
||||||
|
// Start 2 servers.
|
||||||
|
numServers := 2
|
||||||
|
servers, r := startServers(t, numServers, math.MaxUint32)
|
||||||
|
cc, err := Dial("foo.bar.com", WithBalancer(RoundRobin(r)), WithBlock(), WithInsecure(), WithCodec(testCodec{}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create ClientConn: %v", err)
|
||||||
|
}
|
||||||
|
// Add servers[1] to the service discovery.
|
||||||
|
var updates []*naming.Update
|
||||||
|
updates = append(updates, &naming.Update{
|
||||||
|
Op: naming.Add,
|
||||||
|
Addr: "127.0.0.1:" + servers[1].port,
|
||||||
|
})
|
||||||
|
r.w.inject(updates)
|
||||||
|
req := "port"
|
||||||
|
var reply string
|
||||||
|
// Loop until servers[1] is up
|
||||||
|
for {
|
||||||
|
if err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err != nil && ErrorDesc(err) == servers[1].port {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numRPC := 100
|
||||||
|
sleepDuration := 10 * time.Millisecond
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
// After sleepDuration, delete server[0].
|
||||||
|
var updates []*naming.Update
|
||||||
|
updates = append(updates, &naming.Update{
|
||||||
|
Op: naming.Delete,
|
||||||
|
Addr: "127.0.0.1:" + servers[0].port,
|
||||||
|
})
|
||||||
|
r.w.inject(updates)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// All non-failfast RPCs should not fail because there's at least one connection available.
|
||||||
|
for i := 0; i < numRPC; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
var reply string
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
// After sleepDuration, invoke RPC.
|
||||||
|
// server[0] is removed around the same time to make it racy between balancer and gRPC internals.
|
||||||
|
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
|
||||||
|
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
cc.Close()
|
||||||
|
for i := 0; i < numServers; i++ {
|
||||||
|
servers[i].stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -58,7 +58,7 @@ func closeLoopUnary() {
|
|||||||
|
|
||||||
for i := 0; i < *maxConcurrentRPCs; i++ {
|
for i := 0; i < *maxConcurrentRPCs; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
for _ = range ch {
|
for range ch {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
unaryCaller(tc)
|
unaryCaller(tc)
|
||||||
elapse := time.Since(start)
|
elapse := time.Since(start)
|
||||||
|
@ -133,8 +133,8 @@ func (h *Histogram) Clear() {
|
|||||||
h.SumOfSquares = 0
|
h.SumOfSquares = 0
|
||||||
h.Min = math.MaxInt64
|
h.Min = math.MaxInt64
|
||||||
h.Max = math.MinInt64
|
h.Max = math.MinInt64
|
||||||
for _, v := range h.Buckets {
|
for i := range h.Buckets {
|
||||||
v.Count = 0
|
h.Buckets[i].Count = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ type byteBufCodec struct {
|
|||||||
func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
|
func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
|
||||||
b, ok := v.(*[]byte)
|
b, ok := v.(*[]byte)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte")
|
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
|
||||||
}
|
}
|
||||||
return *b, nil
|
return *b, nil
|
||||||
}
|
}
|
||||||
@ -68,7 +68,7 @@ func (byteBufCodec) Marshal(v interface{}) ([]byte, error) {
|
|||||||
func (byteBufCodec) Unmarshal(data []byte, v interface{}) error {
|
func (byteBufCodec) Unmarshal(data []byte, v interface{}) error {
|
||||||
b, ok := v.(*[]byte)
|
b, ok := v.(*[]byte)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("failed to marshal: %v is not type of *[]byte")
|
return fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
|
||||||
}
|
}
|
||||||
*b = data
|
*b = data
|
||||||
return nil
|
return nil
|
||||||
@ -138,8 +138,6 @@ func (s *workerServer) RunServer(stream testpb.WorkerService_RunServerServer) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error {
|
func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) error {
|
||||||
@ -191,13 +189,11 @@ func (s *workerServer) RunClient(stream testpb.WorkerService_RunClientServer) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
|
func (s *workerServer) CoreCount(ctx context.Context, in *testpb.CoreRequest) (*testpb.CoreResponse, error) {
|
||||||
grpclog.Printf("core count: %v", runtime.NumCPU())
|
grpclog.Printf("core count: %v", runtime.NumCPU())
|
||||||
return &testpb.CoreResponse{int32(runtime.NumCPU())}, nil
|
return &testpb.CoreResponse{Cores: int32(runtime.NumCPU())}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {
|
func (s *workerServer) QuitWorker(ctx context.Context, in *testpb.Void) (*testpb.Void, error) {
|
||||||
|
43
call.go
43
call.go
@ -36,6 +36,7 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
@ -51,13 +52,20 @@ import (
|
|||||||
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
||||||
// Try to acquire header metadata from the server if there is any.
|
// Try to acquire header metadata from the server if there is any.
|
||||||
var err error
|
var err error
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
if _, ok := err.(transport.ConnectionError); !ok {
|
||||||
|
t.CloseStream(stream, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
c.headerMD, err = stream.Header()
|
c.headerMD, err = stream.Header()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
for {
|
for {
|
||||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
|
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -76,6 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// If err is connection error, t will be closed, no need to close stream here.
|
||||||
if _, ok := err.(transport.ConnectionError); !ok {
|
if _, ok := err.(transport.ConnectionError); !ok {
|
||||||
t.CloseStream(stream, err)
|
t.CloseStream(stream, err)
|
||||||
}
|
}
|
||||||
@ -90,7 +99,10 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||||
}
|
}
|
||||||
err = t.Write(stream, outBuf, opts)
|
err = t.Write(stream, outBuf, opts)
|
||||||
if err != nil {
|
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
|
||||||
|
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
|
||||||
|
// recvResponse to get the final status.
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// Sent successfully.
|
// Sent successfully.
|
||||||
@ -101,7 +113,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
// Invoke is called by generated code. Also users can call Invoke directly when it
|
// Invoke is called by generated code. Also users can call Invoke directly when it
|
||||||
// is really needed in their use cases.
|
// is really needed in their use cases.
|
||||||
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
|
||||||
var c callInfo
|
c := defaultCallInfo
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
if err := o.before(&c); err != nil {
|
if err := o.before(&c); err != nil {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
@ -155,19 +167,17 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
t, put, err = cc.getTransport(ctx, gopts)
|
t, put, err = cc.getTransport(ctx, gopts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(zhaoq): Probably revisit the error handling.
|
// TODO(zhaoq): Probably revisit the error handling.
|
||||||
if err == ErrClientConnClosing {
|
if _, ok := err.(*rpcError); ok {
|
||||||
return Errorf(codes.FailedPrecondition, "%v", err)
|
return err
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.StreamError); ok {
|
if err == errConnClosing || err == errConnUnavailable {
|
||||||
return toRPCErr(err)
|
|
||||||
}
|
|
||||||
if _, ok := err.(transport.ConnectionError); ok {
|
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
return toRPCErr(err)
|
return Errorf(codes.Unavailable, "%v", err)
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
// All the remaining cases are treated as retryable.
|
// All the other errors are treated as Internal errors.
|
||||||
continue
|
return Errorf(codes.Internal, "%v", err)
|
||||||
}
|
}
|
||||||
if c.traceInfo.tr != nil {
|
if c.traceInfo.tr != nil {
|
||||||
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
|
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
|
||||||
@ -178,7 +188,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
put()
|
put()
|
||||||
put = nil
|
put = nil
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); ok {
|
// Retry a non-failfast RPC when
|
||||||
|
// i) there is a connection error; or
|
||||||
|
// ii) the server started to drain before this RPC was initiated.
|
||||||
|
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
@ -186,20 +199,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
}
|
}
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
// Receive the response
|
|
||||||
err = recvResponse(cc.dopts, t, &c, stream, reply)
|
err = recvResponse(cc.dopts, t, &c, stream, reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if put != nil {
|
if put != nil {
|
||||||
put()
|
put()
|
||||||
put = nil
|
put = nil
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); ok {
|
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t.CloseStream(stream, err)
|
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
if c.traceInfo.tr != nil {
|
if c.traceInfo.tr != nil {
|
||||||
|
21
call_test.go
21
call_test.go
@ -81,7 +81,7 @@ type testStreamHandler struct {
|
|||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
||||||
p := &parser{r: s}
|
p := &parser{r: s}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(math.MaxInt32)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -234,7 +234,7 @@ func TestInvokeLargeErr(t *testing.T) {
|
|||||||
var reply string
|
var reply string
|
||||||
req := "hello"
|
req := "hello"
|
||||||
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
|
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
|
||||||
if _, ok := err.(rpcError); !ok {
|
if _, ok := err.(*rpcError); !ok {
|
||||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
|
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
|
||||||
}
|
}
|
||||||
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
|
if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr {
|
||||||
@ -250,7 +250,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) {
|
|||||||
var reply string
|
var reply string
|
||||||
req := "weird error"
|
req := "weird error"
|
||||||
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
|
err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
|
||||||
if _, ok := err.(rpcError); !ok {
|
if _, ok := err.(*rpcError); !ok {
|
||||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
|
t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.")
|
||||||
}
|
}
|
||||||
if got, want := ErrorDesc(err), weirdError; got != want {
|
if got, want := ErrorDesc(err), weirdError; got != want {
|
||||||
@ -276,3 +276,18 @@ func TestInvokeCancel(t *testing.T) {
|
|||||||
cc.Close()
|
cc.Close()
|
||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC
|
||||||
|
// on a closed client will terminate.
|
||||||
|
func TestInvokeCancelClosedNonFailFast(t *testing.T) {
|
||||||
|
server, cc := setUp(t, 0, math.MaxUint32)
|
||||||
|
var reply string
|
||||||
|
cc.Close()
|
||||||
|
req := "hello"
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
if err := Invoke(ctx, "/foo/bar", &req, &reply, cc, FailFast(false)); err == nil {
|
||||||
|
t.Fatalf("canceled invoke on closed connection should fail")
|
||||||
|
}
|
||||||
|
server.stop()
|
||||||
|
}
|
||||||
|
366
clientconn.go
366
clientconn.go
@ -43,7 +43,6 @@ import (
|
|||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"golang.org/x/net/trace"
|
"golang.org/x/net/trace"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
"google.golang.org/grpc/transport"
|
"google.golang.org/grpc/transport"
|
||||||
@ -68,13 +67,15 @@ var (
|
|||||||
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
|
// errCredentialsConflict indicates that grpc.WithTransportCredentials()
|
||||||
// and grpc.WithInsecure() are both called for a connection.
|
// and grpc.WithInsecure() are both called for a connection.
|
||||||
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
|
errCredentialsConflict = errors.New("grpc: transport credentials are set for an insecure connection (grpc.WithTransportCredentials() and grpc.WithInsecure() are both called)")
|
||||||
// errNetworkIP indicates that the connection is down due to some network I/O error.
|
// errNetworkIO indicates that the connection is down due to some network I/O error.
|
||||||
errNetworkIO = errors.New("grpc: failed with network I/O error")
|
errNetworkIO = errors.New("grpc: failed with network I/O error")
|
||||||
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
|
// errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
|
||||||
errConnDrain = errors.New("grpc: the connection is drained")
|
errConnDrain = errors.New("grpc: the connection is drained")
|
||||||
// errConnClosing indicates that the connection is closing.
|
// errConnClosing indicates that the connection is closing.
|
||||||
errConnClosing = errors.New("grpc: the connection is closing")
|
errConnClosing = errors.New("grpc: the connection is closing")
|
||||||
errNoAddr = errors.New("grpc: there is no address available to dial")
|
// errConnUnavailable indicates that the connection is unavailable.
|
||||||
|
errConnUnavailable = errors.New("grpc: the connection is unavailable")
|
||||||
|
errNoAddr = errors.New("grpc: there is no address available to dial")
|
||||||
// minimum time to give a connection to complete
|
// minimum time to give a connection to complete
|
||||||
minConnectTimeout = 20 * time.Second
|
minConnectTimeout = 20 * time.Second
|
||||||
)
|
)
|
||||||
@ -196,9 +197,14 @@ func WithTimeout(d time.Duration) DialOption {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
|
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
|
||||||
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
|
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
|
||||||
return func(o *dialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.copts.Dialer = f
|
o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
return f(addr, deadline.Sub(time.Now()))
|
||||||
|
}
|
||||||
|
return f(addr, 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,49 +215,72 @@ func WithUserAgent(s string) DialOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial creates a client connection the given target.
|
// Dial creates a client connection to the given target.
|
||||||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||||
|
return DialContext(context.Background(), target, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext creates a client connection to the given target. ctx can be used to
|
||||||
|
// cancel or expire the pending connecting. Once this function returns, the
|
||||||
|
// cancellation and expiration of ctx will be noop. Users should call ClientConn.Close
|
||||||
|
// to terminate all the pending operations after this function returns.
|
||||||
|
// This is the EXPERIMENTAL API.
|
||||||
|
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
|
||||||
cc := &ClientConn{
|
cc := &ClientConn{
|
||||||
target: target,
|
target: target,
|
||||||
conns: make(map[Address]*addrConn),
|
conns: make(map[Address]*addrConn),
|
||||||
}
|
}
|
||||||
|
cc.ctx, cc.cancel = context.WithCancel(context.Background())
|
||||||
|
defer func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
conn, err = nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
cc.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&cc.dopts)
|
opt(&cc.dopts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set defaults.
|
||||||
if cc.dopts.codec == nil {
|
if cc.dopts.codec == nil {
|
||||||
// Set the default codec.
|
|
||||||
cc.dopts.codec = protoCodec{}
|
cc.dopts.codec = protoCodec{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cc.dopts.bs == nil {
|
if cc.dopts.bs == nil {
|
||||||
cc.dopts.bs = DefaultBackoffConfig
|
cc.dopts.bs = DefaultBackoffConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
cc.balancer = cc.dopts.balancer
|
|
||||||
if cc.balancer == nil {
|
|
||||||
cc.balancer = RoundRobin(nil)
|
|
||||||
}
|
|
||||||
if err := cc.balancer.Start(target); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var (
|
var (
|
||||||
ok bool
|
ok bool
|
||||||
addrs []Address
|
addrs []Address
|
||||||
)
|
)
|
||||||
ch := cc.balancer.Notify()
|
if cc.dopts.balancer == nil {
|
||||||
if ch == nil {
|
// Connect to target directly if balancer is nil.
|
||||||
// There is no name resolver installed.
|
|
||||||
addrs = append(addrs, Address{Addr: target})
|
addrs = append(addrs, Address{Addr: target})
|
||||||
} else {
|
} else {
|
||||||
addrs, ok = <-ch
|
if err := cc.dopts.balancer.Start(target); err != nil {
|
||||||
if !ok || len(addrs) == 0 {
|
return nil, err
|
||||||
return nil, errNoAddr
|
}
|
||||||
|
ch := cc.dopts.balancer.Notify()
|
||||||
|
if ch == nil {
|
||||||
|
// There is no name resolver installed.
|
||||||
|
addrs = append(addrs, Address{Addr: target})
|
||||||
|
} else {
|
||||||
|
addrs, ok = <-ch
|
||||||
|
if !ok || len(addrs) == 0 {
|
||||||
|
return nil, errNoAddr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
waitC := make(chan error, 1)
|
waitC := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
for _, a := range addrs {
|
for _, a := range addrs {
|
||||||
if err := cc.newAddrConn(a, false); err != nil {
|
if err := cc.resetAddrConn(a, false, nil); err != nil {
|
||||||
waitC <- err
|
waitC <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -263,15 +292,17 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||||||
timeoutCh = time.After(cc.dopts.timeout)
|
timeoutCh = time.After(cc.dopts.timeout)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
case err := <-waitC:
|
case err := <-waitC:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cc.Close()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
case <-timeoutCh:
|
case <-timeoutCh:
|
||||||
cc.Close()
|
|
||||||
return nil, ErrClientConnTimeout
|
return nil, ErrClientConnTimeout
|
||||||
}
|
}
|
||||||
|
// If balancer is nil or balancer.Notify() is nil, ok will be false here.
|
||||||
|
// The lbWatcher goroutine will not be created.
|
||||||
if ok {
|
if ok {
|
||||||
go cc.lbWatcher()
|
go cc.lbWatcher()
|
||||||
}
|
}
|
||||||
@ -318,8 +349,10 @@ func (s ConnectivityState) String() string {
|
|||||||
|
|
||||||
// ClientConn represents a client connection to an RPC server.
|
// ClientConn represents a client connection to an RPC server.
|
||||||
type ClientConn struct {
|
type ClientConn struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
target string
|
target string
|
||||||
balancer Balancer
|
|
||||||
authority string
|
authority string
|
||||||
dopts dialOptions
|
dopts dialOptions
|
||||||
|
|
||||||
@ -328,7 +361,7 @@ type ClientConn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cc *ClientConn) lbWatcher() {
|
func (cc *ClientConn) lbWatcher() {
|
||||||
for addrs := range cc.balancer.Notify() {
|
for addrs := range cc.dopts.balancer.Notify() {
|
||||||
var (
|
var (
|
||||||
add []Address // Addresses need to setup connections.
|
add []Address // Addresses need to setup connections.
|
||||||
del []*addrConn // Connections need to tear down.
|
del []*addrConn // Connections need to tear down.
|
||||||
@ -349,11 +382,12 @@ func (cc *ClientConn) lbWatcher() {
|
|||||||
}
|
}
|
||||||
if !keep {
|
if !keep {
|
||||||
del = append(del, c)
|
del = append(del, c)
|
||||||
|
delete(cc.conns, c.addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
for _, a := range add {
|
for _, a := range add {
|
||||||
cc.newAddrConn(a, true)
|
cc.resetAddrConn(a, true, nil)
|
||||||
}
|
}
|
||||||
for _, c := range del {
|
for _, c := range del {
|
||||||
c.tearDown(errConnDrain)
|
c.tearDown(errConnDrain)
|
||||||
@ -361,13 +395,17 @@ func (cc *ClientConn) lbWatcher() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
// resetAddrConn creates an addrConn for addr and adds it to cc.conns.
|
||||||
|
// If there is an old addrConn for addr, it will be torn down, using tearDownErr as the reason.
|
||||||
|
// If tearDownErr is nil, errConnDrain will be used instead.
|
||||||
|
func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr error) error {
|
||||||
ac := &addrConn{
|
ac := &addrConn{
|
||||||
cc: cc,
|
cc: cc,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
dopts: cc.dopts,
|
dopts: cc.dopts,
|
||||||
shutdownChan: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
|
||||||
|
ac.stateCV = sync.NewCond(&ac.mu)
|
||||||
if EnableTracing {
|
if EnableTracing {
|
||||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||||
}
|
}
|
||||||
@ -385,26 +423,44 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
|
// Track ac in cc. This needs to be done before any getTransport(...) is called.
|
||||||
ac.cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
if ac.cc.conns == nil {
|
if cc.conns == nil {
|
||||||
ac.cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
return ErrClientConnClosing
|
return ErrClientConnClosing
|
||||||
}
|
}
|
||||||
stale := ac.cc.conns[ac.addr]
|
stale := cc.conns[ac.addr]
|
||||||
ac.cc.conns[ac.addr] = ac
|
cc.conns[ac.addr] = ac
|
||||||
ac.cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
if stale != nil {
|
if stale != nil {
|
||||||
// There is an addrConn alive on ac.addr already. This could be due to
|
// There is an addrConn alive on ac.addr already. This could be due to
|
||||||
// i) stale's Close is undergoing;
|
// 1) a buggy Balancer notifies duplicated Addresses;
|
||||||
// ii) a buggy Balancer notifies duplicated Addresses.
|
// 2) goaway was received, a new ac will replace the old ac.
|
||||||
stale.tearDown(errConnDrain)
|
// The old ac should be deleted from cc.conns, but the
|
||||||
|
// underlying transport should drain rather than close.
|
||||||
|
if tearDownErr == nil {
|
||||||
|
// tearDownErr is nil if resetAddrConn is called by
|
||||||
|
// 1) Dial
|
||||||
|
// 2) lbWatcher
|
||||||
|
// In both cases, the stale ac should drain, not close.
|
||||||
|
stale.tearDown(errConnDrain)
|
||||||
|
} else {
|
||||||
|
stale.tearDown(tearDownErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ac.stateCV = sync.NewCond(&ac.mu)
|
|
||||||
// skipWait may overwrite the decision in ac.dopts.block.
|
// skipWait may overwrite the decision in ac.dopts.block.
|
||||||
if ac.dopts.block && !skipWait {
|
if ac.dopts.block && !skipWait {
|
||||||
if err := ac.resetTransport(false); err != nil {
|
if err := ac.resetTransport(false); err != nil {
|
||||||
ac.tearDown(err)
|
if err != errConnClosing {
|
||||||
|
// Tear down ac and delete it from cc.conns.
|
||||||
|
cc.mu.Lock()
|
||||||
|
delete(cc.conns, ac.addr)
|
||||||
|
cc.mu.Unlock()
|
||||||
|
ac.tearDown(err)
|
||||||
|
}
|
||||||
|
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
|
||||||
|
return e.Origin()
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Start to monitor the error status of transport.
|
// Start to monitor the error status of transport.
|
||||||
@ -414,7 +470,10 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
|||||||
go func() {
|
go func() {
|
||||||
if err := ac.resetTransport(false); err != nil {
|
if err := ac.resetTransport(false); err != nil {
|
||||||
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
|
grpclog.Printf("Failed to dial %s: %v; please retry.", ac.addr.Addr, err)
|
||||||
ac.tearDown(err)
|
if err != errConnClosing {
|
||||||
|
// Keep this ac in cc.conns, to get the reason it's torn down.
|
||||||
|
ac.tearDown(err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ac.transportMonitor()
|
ac.transportMonitor()
|
||||||
@ -424,25 +483,48 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
|
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
|
||||||
// TODO(zhaoq): Implement fail-fast logic.
|
var (
|
||||||
addr, put, err := cc.balancer.Get(ctx, opts)
|
ac *addrConn
|
||||||
if err != nil {
|
ok bool
|
||||||
return nil, nil, err
|
put func()
|
||||||
}
|
)
|
||||||
cc.mu.RLock()
|
if cc.dopts.balancer == nil {
|
||||||
if cc.conns == nil {
|
// If balancer is nil, there should be only one addrConn available.
|
||||||
|
cc.mu.RLock()
|
||||||
|
if cc.conns == nil {
|
||||||
|
cc.mu.RUnlock()
|
||||||
|
return nil, nil, toRPCErr(ErrClientConnClosing)
|
||||||
|
}
|
||||||
|
for _, ac = range cc.conns {
|
||||||
|
// Break after the first iteration to get the first addrConn.
|
||||||
|
ok = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cc.mu.RUnlock()
|
||||||
|
} else {
|
||||||
|
var (
|
||||||
|
addr Address
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
addr, put, err = cc.dopts.balancer.Get(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, toRPCErr(err)
|
||||||
|
}
|
||||||
|
cc.mu.RLock()
|
||||||
|
if cc.conns == nil {
|
||||||
|
cc.mu.RUnlock()
|
||||||
|
return nil, nil, toRPCErr(ErrClientConnClosing)
|
||||||
|
}
|
||||||
|
ac, ok = cc.conns[addr]
|
||||||
cc.mu.RUnlock()
|
cc.mu.RUnlock()
|
||||||
return nil, nil, ErrClientConnClosing
|
|
||||||
}
|
}
|
||||||
ac, ok := cc.conns[addr]
|
|
||||||
cc.mu.RUnlock()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
if put != nil {
|
if put != nil {
|
||||||
put()
|
put()
|
||||||
}
|
}
|
||||||
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
|
return nil, nil, errConnClosing
|
||||||
}
|
}
|
||||||
t, err := ac.wait(ctx)
|
t, err := ac.wait(ctx, cc.dopts.balancer != nil, !opts.BlockingWait)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if put != nil {
|
if put != nil {
|
||||||
put()
|
put()
|
||||||
@ -454,6 +536,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
|
|||||||
|
|
||||||
// Close tears down the ClientConn and all underlying connections.
|
// Close tears down the ClientConn and all underlying connections.
|
||||||
func (cc *ClientConn) Close() error {
|
func (cc *ClientConn) Close() error {
|
||||||
|
cc.cancel()
|
||||||
|
|
||||||
cc.mu.Lock()
|
cc.mu.Lock()
|
||||||
if cc.conns == nil {
|
if cc.conns == nil {
|
||||||
cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
@ -462,7 +546,9 @@ func (cc *ClientConn) Close() error {
|
|||||||
conns := cc.conns
|
conns := cc.conns
|
||||||
cc.conns = nil
|
cc.conns = nil
|
||||||
cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
cc.balancer.Close()
|
if cc.dopts.balancer != nil {
|
||||||
|
cc.dopts.balancer.Close()
|
||||||
|
}
|
||||||
for _, ac := range conns {
|
for _, ac := range conns {
|
||||||
ac.tearDown(ErrClientConnClosing)
|
ac.tearDown(ErrClientConnClosing)
|
||||||
}
|
}
|
||||||
@ -471,11 +557,13 @@ func (cc *ClientConn) Close() error {
|
|||||||
|
|
||||||
// addrConn is a network connection to a given address.
|
// addrConn is a network connection to a given address.
|
||||||
type addrConn struct {
|
type addrConn struct {
|
||||||
cc *ClientConn
|
ctx context.Context
|
||||||
addr Address
|
cancel context.CancelFunc
|
||||||
dopts dialOptions
|
|
||||||
shutdownChan chan struct{}
|
cc *ClientConn
|
||||||
events trace.EventLog
|
addr Address
|
||||||
|
dopts dialOptions
|
||||||
|
events trace.EventLog
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
state ConnectivityState
|
state ConnectivityState
|
||||||
@ -485,6 +573,9 @@ type addrConn struct {
|
|||||||
// due to timeout.
|
// due to timeout.
|
||||||
ready chan struct{}
|
ready chan struct{}
|
||||||
transport transport.ClientTransport
|
transport transport.ClientTransport
|
||||||
|
|
||||||
|
// The reason this addrConn is torn down.
|
||||||
|
tearDownErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
// printf records an event in ac's event log, unless ac has been closed.
|
// printf records an event in ac's event log, unless ac has been closed.
|
||||||
@ -540,8 +631,7 @@ func (ac *addrConn) waitForStateChange(ctx context.Context, sourceState Connecti
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ac *addrConn) resetTransport(closeTransport bool) error {
|
func (ac *addrConn) resetTransport(closeTransport bool) error {
|
||||||
var retries int
|
for retries := 0; ; retries++ {
|
||||||
for {
|
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
ac.printf("connecting")
|
ac.printf("connecting")
|
||||||
if ac.state == Shutdown {
|
if ac.state == Shutdown {
|
||||||
@ -561,13 +651,20 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||||||
t.Close()
|
t.Close()
|
||||||
}
|
}
|
||||||
sleepTime := ac.dopts.bs.backoff(retries)
|
sleepTime := ac.dopts.bs.backoff(retries)
|
||||||
ac.dopts.copts.Timeout = sleepTime
|
timeout := minConnectTimeout
|
||||||
if sleepTime < minConnectTimeout {
|
if timeout < sleepTime {
|
||||||
ac.dopts.copts.Timeout = minConnectTimeout
|
timeout = sleepTime
|
||||||
}
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
|
||||||
connectTime := time.Now()
|
connectTime := time.Now()
|
||||||
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
|
newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
if ac.state == Shutdown {
|
if ac.state == Shutdown {
|
||||||
// ac.tearDown(...) has been invoked.
|
// ac.tearDown(...) has been invoked.
|
||||||
@ -582,17 +679,12 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||||||
ac.ready = nil
|
ac.ready = nil
|
||||||
}
|
}
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
sleepTime -= time.Since(connectTime)
|
|
||||||
if sleepTime < 0 {
|
|
||||||
sleepTime = 0
|
|
||||||
}
|
|
||||||
closeTransport = false
|
closeTransport = false
|
||||||
select {
|
select {
|
||||||
case <-time.After(sleepTime):
|
case <-time.After(sleepTime - time.Since(connectTime)):
|
||||||
case <-ac.shutdownChan:
|
case <-ac.ctx.Done():
|
||||||
|
return ac.ctx.Err()
|
||||||
}
|
}
|
||||||
retries++
|
|
||||||
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
@ -610,7 +702,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||||||
close(ac.ready)
|
close(ac.ready)
|
||||||
ac.ready = nil
|
ac.ready = nil
|
||||||
}
|
}
|
||||||
ac.down = ac.cc.balancer.Up(ac.addr)
|
if ac.cc.dopts.balancer != nil {
|
||||||
|
ac.down = ac.cc.dopts.balancer.Up(ac.addr)
|
||||||
|
}
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -624,14 +718,42 @@ func (ac *addrConn) transportMonitor() {
|
|||||||
t := ac.transport
|
t := ac.transport
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
select {
|
select {
|
||||||
// shutdownChan is needed to detect the teardown when
|
// This is needed to detect the teardown when
|
||||||
// the addrConn is idle (i.e., no RPC in flight).
|
// the addrConn is idle (i.e., no RPC in flight).
|
||||||
case <-ac.shutdownChan:
|
case <-ac.ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-t.Error():
|
||||||
|
t.Close()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-t.GoAway():
|
||||||
|
// If GoAway happens without any network I/O error, ac is closed without shutting down the
|
||||||
|
// underlying transport (the transport will be closed when all the pending RPCs finished or
|
||||||
|
// failed.).
|
||||||
|
// If GoAway and some network I/O error happen concurrently, ac and its underlying transport
|
||||||
|
// are closed.
|
||||||
|
// In both cases, a new ac is created.
|
||||||
|
select {
|
||||||
|
case <-t.Error():
|
||||||
|
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
|
||||||
|
default:
|
||||||
|
ac.cc.resetAddrConn(ac.addr, true, errConnDrain)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
case <-t.Error():
|
case <-t.Error():
|
||||||
|
select {
|
||||||
|
case <-ac.ctx.Done():
|
||||||
|
t.Close()
|
||||||
|
return
|
||||||
|
case <-t.GoAway():
|
||||||
|
ac.cc.resetAddrConn(ac.addr, true, errNetworkIO)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
if ac.state == Shutdown {
|
if ac.state == Shutdown {
|
||||||
// ac.tearDown(...) has been invoked.
|
// ac has been shutdown.
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -643,38 +765,53 @@ func (ac *addrConn) transportMonitor() {
|
|||||||
ac.printf("transport exiting: %v", err)
|
ac.printf("transport exiting: %v", err)
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
|
grpclog.Printf("grpc: addrConn.transportMonitor exits due to: %v", err)
|
||||||
|
if err != errConnClosing {
|
||||||
|
// Keep this ac in cc.conns, to get the reason it's torn down.
|
||||||
|
ac.tearDown(err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed.
|
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
|
||||||
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) {
|
// iv) transport is in TransientFailure and there's no balancer/failfast is true.
|
||||||
|
func (ac *addrConn) wait(ctx context.Context, hasBalancer, failfast bool) (transport.ClientTransport, error) {
|
||||||
for {
|
for {
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
switch {
|
switch {
|
||||||
case ac.state == Shutdown:
|
case ac.state == Shutdown:
|
||||||
|
if failfast || !hasBalancer {
|
||||||
|
// RPC is failfast or balancer is nil. This RPC should fail with ac.tearDownErr.
|
||||||
|
err := ac.tearDownErr
|
||||||
|
ac.mu.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
return nil, errConnClosing
|
return nil, errConnClosing
|
||||||
case ac.state == Ready:
|
case ac.state == Ready:
|
||||||
ct := ac.transport
|
ct := ac.transport
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
return ct, nil
|
return ct, nil
|
||||||
default:
|
case ac.state == TransientFailure:
|
||||||
ready := ac.ready
|
if failfast || hasBalancer {
|
||||||
if ready == nil {
|
ac.mu.Unlock()
|
||||||
ready = make(chan struct{})
|
return nil, errConnUnavailable
|
||||||
ac.ready = ready
|
|
||||||
}
|
|
||||||
ac.mu.Unlock()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, transport.ContextErr(ctx.Err())
|
|
||||||
// Wait until the new transport is ready or failed.
|
|
||||||
case <-ready:
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ready := ac.ready
|
||||||
|
if ready == nil {
|
||||||
|
ready = make(chan struct{})
|
||||||
|
ac.ready = ready
|
||||||
|
}
|
||||||
|
ac.mu.Unlock()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, toRPCErr(ctx.Err())
|
||||||
|
// Wait until the new transport is ready or failed.
|
||||||
|
case <-ready:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -682,24 +819,28 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
|
|||||||
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
|
// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
|
||||||
// some edge cases (e.g., the caller opens and closes many addrConn's in a
|
// some edge cases (e.g., the caller opens and closes many addrConn's in a
|
||||||
// tight loop.
|
// tight loop.
|
||||||
|
// tearDown doesn't remove ac from ac.cc.conns.
|
||||||
func (ac *addrConn) tearDown(err error) {
|
func (ac *addrConn) tearDown(err error) {
|
||||||
|
ac.cancel()
|
||||||
|
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
defer func() {
|
defer ac.mu.Unlock()
|
||||||
ac.mu.Unlock()
|
|
||||||
ac.cc.mu.Lock()
|
|
||||||
if ac.cc.conns != nil {
|
|
||||||
delete(ac.cc.conns, ac.addr)
|
|
||||||
}
|
|
||||||
ac.cc.mu.Unlock()
|
|
||||||
}()
|
|
||||||
if ac.state == Shutdown {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ac.state = Shutdown
|
|
||||||
if ac.down != nil {
|
if ac.down != nil {
|
||||||
ac.down(downErrorf(false, false, "%v", err))
|
ac.down(downErrorf(false, false, "%v", err))
|
||||||
ac.down = nil
|
ac.down = nil
|
||||||
}
|
}
|
||||||
|
if err == errConnDrain && ac.transport != nil {
|
||||||
|
// GracefulClose(...) may be executed multiple times when
|
||||||
|
// i) receiving multiple GoAway frames from the server; or
|
||||||
|
// ii) there are concurrent name resolver/Balancer triggered
|
||||||
|
// address removal and GoAway.
|
||||||
|
ac.transport.GracefulClose()
|
||||||
|
}
|
||||||
|
if ac.state == Shutdown {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ac.state = Shutdown
|
||||||
|
ac.tearDownErr = err
|
||||||
ac.stateCV.Broadcast()
|
ac.stateCV.Broadcast()
|
||||||
if ac.events != nil {
|
if ac.events != nil {
|
||||||
ac.events.Finish()
|
ac.events.Finish()
|
||||||
@ -709,15 +850,8 @@ func (ac *addrConn) tearDown(err error) {
|
|||||||
close(ac.ready)
|
close(ac.ready)
|
||||||
ac.ready = nil
|
ac.ready = nil
|
||||||
}
|
}
|
||||||
if ac.transport != nil {
|
if ac.transport != nil && err != errConnDrain {
|
||||||
if err == errConnDrain {
|
ac.transport.Close()
|
||||||
ac.transport.GracefulClose()
|
|
||||||
} else {
|
|
||||||
ac.transport.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ac.shutdownChan != nil {
|
|
||||||
close(ac.shutdownChan)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/oauth"
|
"google.golang.org/grpc/credentials/oauth"
|
||||||
)
|
)
|
||||||
@ -67,13 +69,21 @@ func TestTLSDialTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialContextCancel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
|
||||||
|
t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCredentialsMisuse(t *testing.T) {
|
func TestCredentialsMisuse(t *testing.T) {
|
||||||
tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create authenticator %v", err)
|
t.Fatalf("Failed to create authenticator %v", err)
|
||||||
}
|
}
|
||||||
// Two conflicting credential configurations
|
// Two conflicting credential configurations
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict {
|
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
|
||||||
}
|
}
|
||||||
rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
|
rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
|
||||||
@ -81,7 +91,7 @@ func TestCredentialsMisuse(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create credentials %v", err)
|
t.Fatalf("Failed to create credentials %v", err)
|
||||||
}
|
}
|
||||||
// security info on insecure connection
|
// security info on insecure connection
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
|
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -123,4 +133,5 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
|
|||||||
if actual != *expected {
|
if actual != *expected {
|
||||||
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
|
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
|
||||||
}
|
}
|
||||||
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,6 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
@ -93,11 +92,12 @@ type TransportCredentials interface {
|
|||||||
// ClientHandshake does the authentication handshake specified by the corresponding
|
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||||
// authentication protocol on rawConn for clients. It returns the authenticated
|
// authentication protocol on rawConn for clients. It returns the authenticated
|
||||||
// connection and the corresponding auth information about the connection.
|
// connection and the corresponding auth information about the connection.
|
||||||
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
|
// Implementations must use the provided context to implement timely cancellation.
|
||||||
|
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// ServerHandshake does the authentication handshake for servers. It returns
|
// ServerHandshake does the authentication handshake for servers. It returns
|
||||||
// the authenticated connection and the corresponding auth information about
|
// the authenticated connection and the corresponding auth information about
|
||||||
// the connection.
|
// the connection.
|
||||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
||||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||||
Info() ProtocolInfo
|
Info() ProtocolInfo
|
||||||
}
|
}
|
||||||
@ -116,7 +116,7 @@ func (t TLSInfo) AuthType() string {
|
|||||||
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
||||||
type tlsCreds struct {
|
type tlsCreds struct {
|
||||||
// TLS configuration
|
// TLS configuration
|
||||||
config tls.Config
|
config *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c tlsCreds) Info() ProtocolInfo {
|
func (c tlsCreds) Info() ProtocolInfo {
|
||||||
@ -136,40 +136,28 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
type timeoutError struct{}
|
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
|
||||||
|
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||||
func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
cfg := cloneTLSConfig(c.config)
|
||||||
func (timeoutError) Timeout() bool { return true }
|
if cfg.ServerName == "" {
|
||||||
func (timeoutError) Temporary() bool { return true }
|
|
||||||
|
|
||||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
|
|
||||||
// borrow some code from tls.DialWithDialer
|
|
||||||
var errChannel chan error
|
|
||||||
if timeout != 0 {
|
|
||||||
errChannel = make(chan error, 2)
|
|
||||||
time.AfterFunc(timeout, func() {
|
|
||||||
errChannel <- timeoutError{}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if c.config.ServerName == "" {
|
|
||||||
colonPos := strings.LastIndex(addr, ":")
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
if colonPos == -1 {
|
if colonPos == -1 {
|
||||||
colonPos = len(addr)
|
colonPos = len(addr)
|
||||||
}
|
}
|
||||||
c.config.ServerName = addr[:colonPos]
|
cfg.ServerName = addr[:colonPos]
|
||||||
}
|
}
|
||||||
conn := tls.Client(rawConn, &c.config)
|
conn := tls.Client(rawConn, cfg)
|
||||||
if timeout == 0 {
|
errChannel := make(chan error, 1)
|
||||||
err = conn.Handshake()
|
go func() {
|
||||||
} else {
|
errChannel <- conn.Handshake()
|
||||||
go func() {
|
}()
|
||||||
errChannel <- conn.Handshake()
|
select {
|
||||||
}()
|
case err := <-errChannel:
|
||||||
err = <-errChannel
|
if err != nil {
|
||||||
}
|
return nil, nil, err
|
||||||
if err != nil {
|
}
|
||||||
rawConn.Close()
|
case <-ctx.Done():
|
||||||
return nil, nil, err
|
return nil, nil, ctx.Err()
|
||||||
}
|
}
|
||||||
// TODO(zhaoq): Omit the auth info for client now. It is more for
|
// TODO(zhaoq): Omit the auth info for client now. It is more for
|
||||||
// information than anything else.
|
// information than anything else.
|
||||||
@ -177,9 +165,8 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
|
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
|
||||||
conn := tls.Server(rawConn, &c.config)
|
conn := tls.Server(rawConn, c.config)
|
||||||
if err := conn.Handshake(); err != nil {
|
if err := conn.Handshake(); err != nil {
|
||||||
rawConn.Close()
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return conn, TLSInfo{conn.ConnectionState()}, nil
|
return conn, TLSInfo{conn.ConnectionState()}, nil
|
||||||
@ -187,7 +174,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
|||||||
|
|
||||||
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
// NewTLS uses c to construct a TransportCredentials based on TLS.
|
||||||
func NewTLS(c *tls.Config) TransportCredentials {
|
func NewTLS(c *tls.Config) TransportCredentials {
|
||||||
tc := &tlsCreds{*c}
|
tc := &tlsCreds{cloneTLSConfig(c)}
|
||||||
tc.config.NextProtos = alpnProtoStr
|
tc.config.NextProtos = alpnProtoStr
|
||||||
return tc
|
return tc
|
||||||
}
|
}
|
||||||
|
76
credentials/credentials_util_go17.go
Normal file
76
credentials/credentials_util_go17.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cloneTLSConfig returns a shallow clone of the exported
|
||||||
|
// fields of cfg, ignoring the unexported sync.Once, which
|
||||||
|
// contains a mutex and must not be copied.
|
||||||
|
//
|
||||||
|
// If cfg is nil, a new zero tls.Config is returned.
|
||||||
|
//
|
||||||
|
// TODO replace this function with official clone function.
|
||||||
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: cfg.SessionTicketKey,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: cfg.Renegotiation,
|
||||||
|
}
|
||||||
|
}
|
74
credentials/credentials_util_pre_go17.go
Normal file
74
credentials/credentials_util_pre_go17.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
// +build !go1.7
|
||||||
|
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cloneTLSConfig returns a shallow clone of the exported
|
||||||
|
// fields of cfg, ignoring the unexported sync.Once, which
|
||||||
|
// contains a mutex and must not be copied.
|
||||||
|
//
|
||||||
|
// If cfg is nil, a new zero tls.Config is returned.
|
||||||
|
//
|
||||||
|
// TODO replace this function with official clone function.
|
||||||
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
|
||||||
|
SessionTicketKey: cfg.SessionTicketKey,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
}
|
||||||
|
}
|
@ -28,12 +28,12 @@ Then change your current directory to `grpc-go/examples/route_guide`:
|
|||||||
$ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide
|
$ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide
|
||||||
```
|
```
|
||||||
|
|
||||||
You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](examples/).
|
You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](https://github.com/grpc/grpc-go/tree/master/examples/).
|
||||||
|
|
||||||
|
|
||||||
## Defining the service
|
## Defining the service
|
||||||
|
|
||||||
Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [`examples/route_guide/proto/route_guide.proto`](examples/route_guide/proto/route_guide.proto).
|
Our first step (as you'll know from the [quick start](http://www.grpc.io/docs/#quick-start)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [examples/route_guide/routeguide/route_guide.proto](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/routeguide/route_guide.proto).
|
||||||
|
|
||||||
To define a service, you specify a named `service` in your .proto file:
|
To define a service, you specify a named `service` in your .proto file:
|
||||||
|
|
||||||
|
@ -115,12 +115,12 @@ func runRecordRoute(client pb.RouteGuideClient) {
|
|||||||
// runRouteChat receives a sequence of route notes, while sending notes for various locations.
|
// runRouteChat receives a sequence of route notes, while sending notes for various locations.
|
||||||
func runRouteChat(client pb.RouteGuideClient) {
|
func runRouteChat(client pb.RouteGuideClient) {
|
||||||
notes := []*pb.RouteNote{
|
notes := []*pb.RouteNote{
|
||||||
{&pb.Point{0, 1}, "First message"},
|
{&pb.Point{Latitude: 0, Longitude: 1}, "First message"},
|
||||||
{&pb.Point{0, 2}, "Second message"},
|
{&pb.Point{Latitude: 0, Longitude: 2}, "Second message"},
|
||||||
{&pb.Point{0, 3}, "Third message"},
|
{&pb.Point{Latitude: 0, Longitude: 3}, "Third message"},
|
||||||
{&pb.Point{0, 1}, "Fourth message"},
|
{&pb.Point{Latitude: 0, Longitude: 1}, "Fourth message"},
|
||||||
{&pb.Point{0, 2}, "Fifth message"},
|
{&pb.Point{Latitude: 0, Longitude: 2}, "Fifth message"},
|
||||||
{&pb.Point{0, 3}, "Sixth message"},
|
{&pb.Point{Latitude: 0, Longitude: 3}, "Sixth message"},
|
||||||
}
|
}
|
||||||
stream, err := client.RouteChat(context.Background())
|
stream, err := client.RouteChat(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -153,7 +153,7 @@ func runRouteChat(client pb.RouteGuideClient) {
|
|||||||
func randomPoint(r *rand.Rand) *pb.Point {
|
func randomPoint(r *rand.Rand) *pb.Point {
|
||||||
lat := (r.Int31n(180) - 90) * 1e7
|
lat := (r.Int31n(180) - 90) * 1e7
|
||||||
long := (r.Int31n(360) - 180) * 1e7
|
long := (r.Int31n(360) - 180) * 1e7
|
||||||
return &pb.Point{lat, long}
|
return &pb.Point{Latitude: lat, Longitude: long}
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -186,13 +186,16 @@ func main() {
|
|||||||
client := pb.NewRouteGuideClient(conn)
|
client := pb.NewRouteGuideClient(conn)
|
||||||
|
|
||||||
// Looking for a valid feature
|
// Looking for a valid feature
|
||||||
printFeature(client, &pb.Point{409146138, -746188906})
|
printFeature(client, &pb.Point{Latitude: 409146138, Longitude: -746188906})
|
||||||
|
|
||||||
// Feature missing.
|
// Feature missing.
|
||||||
printFeature(client, &pb.Point{0, 0})
|
printFeature(client, &pb.Point{Latitude: 0, Longitude: 0})
|
||||||
|
|
||||||
// Looking for features between 40, -75 and 42, -73.
|
// Looking for features between 40, -75 and 42, -73.
|
||||||
printFeatures(client, &pb.Rectangle{&pb.Point{400000000, -750000000}, &pb.Point{420000000, -730000000}})
|
printFeatures(client, &pb.Rectangle{
|
||||||
|
Lo: &pb.Point{Latitude: 400000000, Longitude: -750000000},
|
||||||
|
Hi: &pb.Point{Latitude: 420000000, Longitude: -730000000},
|
||||||
|
})
|
||||||
|
|
||||||
// RecordRoute
|
// RecordRoute
|
||||||
runRecordRoute(client)
|
runRecordRoute(client)
|
||||||
|
@ -79,7 +79,7 @@ func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// No feature was found, return an unnamed feature
|
// No feature was found, return an unnamed feature
|
||||||
return &pb.Feature{"", point}, nil
|
return &pb.Feature{Location: point}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFeatures lists all features contained within the given bounding Rectangle.
|
// ListFeatures lists all features contained within the given bounding Rectangle.
|
||||||
|
@ -11,19 +11,22 @@ import (
|
|||||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HealthServer struct {
|
// Server implements `service Health`.
|
||||||
|
type Server struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// statusMap stores the serving status of the services this HealthServer monitors.
|
// statusMap stores the serving status of the services this Server monitors.
|
||||||
statusMap map[string]healthpb.HealthCheckResponse_ServingStatus
|
statusMap map[string]healthpb.HealthCheckResponse_ServingStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHealthServer() *HealthServer {
|
// NewServer returns a new Server.
|
||||||
return &HealthServer{
|
func NewServer() *Server {
|
||||||
|
return &Server{
|
||||||
statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus),
|
statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
|
// Check implements `service Health`.
|
||||||
|
func (s *Server) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if in.Service == "" {
|
if in.Service == "" {
|
||||||
@ -42,7 +45,7 @@ func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckReques
|
|||||||
|
|
||||||
// SetServingStatus is called when need to reset the serving status of a service
|
// SetServingStatus is called when need to reset the serving status of a service
|
||||||
// or insert a new service entry into the statusMap.
|
// or insert a new service entry into the statusMap.
|
||||||
func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
|
func (s *Server) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.statusMap[service] = status
|
s.statusMap[service] = status
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
|
|||||||
|
|
||||||
// DecodeKeyValue returns the original key and value corresponding to the
|
// DecodeKeyValue returns the original key and value corresponding to the
|
||||||
// encoded data in k, v.
|
// encoded data in k, v.
|
||||||
|
// If k is a binary header and v contains comma, v is split on comma before decoded,
|
||||||
|
// and the decoded v will be joined with comma before returned.
|
||||||
func DecodeKeyValue(k, v string) (string, string, error) {
|
func DecodeKeyValue(k, v string) (string, string, error) {
|
||||||
if !strings.HasSuffix(k, binHdrSuffix) {
|
if !strings.HasSuffix(k, binHdrSuffix) {
|
||||||
return k, v, nil
|
return k, v, nil
|
||||||
}
|
}
|
||||||
val, err := base64.StdEncoding.DecodeString(v)
|
vvs := strings.Split(v, ",")
|
||||||
if err != nil {
|
for i, vv := range vvs {
|
||||||
return "", "", err
|
val, err := base64.StdEncoding.DecodeString(vv)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
vvs[i] = string(val)
|
||||||
}
|
}
|
||||||
return k, string(val), nil
|
return k, strings.Join(vvs, ","), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MD is a mapping from metadata keys to values. Users should use the following
|
// MD is a mapping from metadata keys to values. Users should use the following
|
||||||
|
@ -74,6 +74,8 @@ func TestDecodeKeyValue(t *testing.T) {
|
|||||||
{"a", "abc", "a", "abc", nil},
|
{"a", "abc", "a", "abc", nil},
|
||||||
{"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil},
|
{"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil},
|
||||||
{"key-bin", "woA=", "key-bin", binaryValue, nil},
|
{"key-bin", "woA=", "key-bin", binaryValue, nil},
|
||||||
|
{"a", "abc,efg", "a", "abc,efg", nil},
|
||||||
|
{"key-bin", "Zm9vAGJhcg==,Zm9vAGJhcg==", "key-bin", "foo\x00bar,foo\x00bar", nil},
|
||||||
} {
|
} {
|
||||||
k, v, err := DecodeKeyValue(test.kin, test.vin)
|
k, v, err := DecodeKeyValue(test.kin, test.vin)
|
||||||
if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {
|
if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {
|
||||||
|
@ -70,7 +70,7 @@ import (
|
|||||||
type serverReflectionServer struct {
|
type serverReflectionServer struct {
|
||||||
s *grpc.Server
|
s *grpc.Server
|
||||||
// TODO add more cache if necessary
|
// TODO add more cache if necessary
|
||||||
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo()
|
serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers the server reflection service on the given gRPC server.
|
// Register registers the server reflection service on the given gRPC server.
|
||||||
@ -214,19 +214,19 @@ func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interfac
|
|||||||
return nil, fmt.Errorf("unknown symbol: %v", name)
|
return nil, fmt.Errorf("unknown symbol: %v", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search for method in info.
|
// Search the method name in info.Methods.
|
||||||
var found bool
|
var found bool
|
||||||
for _, m := range info.Methods {
|
for _, m := range info.Methods {
|
||||||
if m == name[pos+1:] {
|
if m.Name == name[pos+1:] {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if found {
|
||||||
return nil, fmt.Errorf("unknown symbol: %v", name)
|
return info.Metadata, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return info.Metadata, nil
|
return nil, fmt.Errorf("unknown symbol: %v", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
|
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
|
||||||
@ -253,7 +253,7 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
|
|||||||
// Metadata not valid.
|
// Metadata not valid.
|
||||||
enc, ok := meta.([]byte)
|
enc, ok := meta.([]byte)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid file descriptor for symbol: %v")
|
return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fd, err = s.decodeFileDesc(enc)
|
fd, err = s.decodeFileDesc(enc)
|
||||||
|
@ -273,6 +273,7 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe
|
|||||||
}{
|
}{
|
||||||
{"grpc.testing.SearchService", fdTestByte},
|
{"grpc.testing.SearchService", fdTestByte},
|
||||||
{"grpc.testing.SearchService.Search", fdTestByte},
|
{"grpc.testing.SearchService.Search", fdTestByte},
|
||||||
|
{"grpc.testing.SearchService.StreamingSearch", fdTestByte},
|
||||||
{"grpc.testing.SearchResponse", fdTestByte},
|
{"grpc.testing.SearchResponse", fdTestByte},
|
||||||
{"grpc.testing.ToBeExtened", fdProto2Byte},
|
{"grpc.testing.ToBeExtened", fdProto2Byte},
|
||||||
} {
|
} {
|
||||||
|
66
rpc_util.go
66
rpc_util.go
@ -141,6 +141,8 @@ type callInfo struct {
|
|||||||
traceInfo traceInfo // in trace.go
|
traceInfo traceInfo // in trace.go
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultCallInfo = callInfo{failFast: true}
|
||||||
|
|
||||||
// CallOption configures a Call before it starts or extracts information from
|
// CallOption configures a Call before it starts or extracts information from
|
||||||
// a Call after it completes.
|
// a Call after it completes.
|
||||||
type CallOption interface {
|
type CallOption interface {
|
||||||
@ -179,6 +181,19 @@ func Trailer(md *metadata.MD) CallOption {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FailFast configures the action to take when an RPC is attempted on broken
|
||||||
|
// connections or unreachable servers. If failfast is true, the RPC will fail
|
||||||
|
// immediately. Otherwise, the RPC client will block the call until a
|
||||||
|
// connection is available (or the call is canceled or times out) and will retry
|
||||||
|
// the call if it fails due to a transient error. Please refer to
|
||||||
|
// https://github.com/grpc/grpc/blob/master/doc/fail_fast.md
|
||||||
|
func FailFast(failFast bool) CallOption {
|
||||||
|
return beforeCall(func(c *callInfo) error {
|
||||||
|
c.failFast = failFast
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// The format of the payload: compressed or not?
|
// The format of the payload: compressed or not?
|
||||||
type payloadFormat uint8
|
type payloadFormat uint8
|
||||||
|
|
||||||
@ -212,7 +227,7 @@ type parser struct {
|
|||||||
// No other error values or types must be returned, which also means
|
// No other error values or types must be returned, which also means
|
||||||
// that the underlying io.Reader must not return an incompatible
|
// that the underlying io.Reader must not return an incompatible
|
||||||
// error.
|
// error.
|
||||||
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
|
||||||
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
@ -223,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
|
|||||||
if length == 0 {
|
if length == 0 {
|
||||||
return pf, nil, nil
|
return pf, nil, nil
|
||||||
}
|
}
|
||||||
|
if length > uint32(maxMsgSize) {
|
||||||
|
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
|
||||||
|
}
|
||||||
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
||||||
// of making it for each message:
|
// of making it for each message:
|
||||||
msg = make([]byte, int(length))
|
msg = make([]byte, int(length))
|
||||||
@ -293,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
|
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
|
||||||
pf, d, err := p.recvMsg()
|
pf, d, err := p.recvMsg(maxMsgSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -304,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
|
|||||||
if pf == compressionMade {
|
if pf == compressionMade {
|
||||||
d, err = dc.Do(bytes.NewReader(d))
|
d, err = dc.Do(bytes.NewReader(d))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(d) > maxMsgSize {
|
||||||
|
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||||
|
// implementation.
|
||||||
|
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
|
||||||
|
}
|
||||||
if err := c.Unmarshal(d, m); err != nil {
|
if err := c.Unmarshal(d, m); err != nil {
|
||||||
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -319,7 +342,7 @@ type rpcError struct {
|
|||||||
desc string
|
desc string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e rpcError) Error() string {
|
func (e *rpcError) Error() string {
|
||||||
return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc)
|
return fmt.Sprintf("rpc error: code = %d desc = %s", e.code, e.desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,7 +352,7 @@ func Code(err error) codes.Code {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return codes.OK
|
return codes.OK
|
||||||
}
|
}
|
||||||
if e, ok := err.(rpcError); ok {
|
if e, ok := err.(*rpcError); ok {
|
||||||
return e.code
|
return e.code
|
||||||
}
|
}
|
||||||
return codes.Unknown
|
return codes.Unknown
|
||||||
@ -341,7 +364,7 @@ func ErrorDesc(err error) string {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if e, ok := err.(rpcError); ok {
|
if e, ok := err.(*rpcError); ok {
|
||||||
return e.desc
|
return e.desc
|
||||||
}
|
}
|
||||||
return err.Error()
|
return err.Error()
|
||||||
@ -353,7 +376,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
|
|||||||
if c == codes.OK {
|
if c == codes.OK {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return rpcError{
|
return &rpcError{
|
||||||
code: c,
|
code: c,
|
||||||
desc: fmt.Sprintf(format, a...),
|
desc: fmt.Sprintf(format, a...),
|
||||||
}
|
}
|
||||||
@ -362,18 +385,37 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
|
|||||||
// toRPCErr converts an error into a rpcError.
|
// toRPCErr converts an error into a rpcError.
|
||||||
func toRPCErr(err error) error {
|
func toRPCErr(err error) error {
|
||||||
switch e := err.(type) {
|
switch e := err.(type) {
|
||||||
case rpcError:
|
case *rpcError:
|
||||||
return err
|
return err
|
||||||
case transport.StreamError:
|
case transport.StreamError:
|
||||||
return rpcError{
|
return &rpcError{
|
||||||
code: e.Code,
|
code: e.Code,
|
||||||
desc: e.Desc,
|
desc: e.Desc,
|
||||||
}
|
}
|
||||||
case transport.ConnectionError:
|
case transport.ConnectionError:
|
||||||
return rpcError{
|
return &rpcError{
|
||||||
code: codes.Internal,
|
code: codes.Internal,
|
||||||
desc: e.Desc,
|
desc: e.Desc,
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
switch err {
|
||||||
|
case context.DeadlineExceeded:
|
||||||
|
return &rpcError{
|
||||||
|
code: codes.DeadlineExceeded,
|
||||||
|
desc: err.Error(),
|
||||||
|
}
|
||||||
|
case context.Canceled:
|
||||||
|
return &rpcError{
|
||||||
|
code: codes.Canceled,
|
||||||
|
desc: err.Error(),
|
||||||
|
}
|
||||||
|
case ErrClientConnClosing:
|
||||||
|
return &rpcError{
|
||||||
|
code: codes.FailedPrecondition,
|
||||||
|
desc: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return Errorf(codes.Unknown, "%v", err)
|
return Errorf(codes.Unknown, "%v", err)
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,7 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) {
|
|||||||
} {
|
} {
|
||||||
buf := bytes.NewReader(test.p)
|
buf := bytes.NewReader(test.p)
|
||||||
parser := &parser{r: buf}
|
parser := &parser{r: buf}
|
||||||
pt, b, err := parser.recvMsg()
|
pt, b, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
||||||
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) {
|
|||||||
{compressionNone, []byte("d")},
|
{compressionNone, []byte("d")},
|
||||||
}
|
}
|
||||||
for i, want := range wantRecvs {
|
for i, want := range wantRecvs {
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
||||||
t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, <nil>",
|
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
||||||
i, p, pt, data, err, want.pt, want.data)
|
i, p, pt, data, err, want.pt, want.data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pt, data, err := parser.recvMsg()
|
pt, data, err := parser.recvMsg(math.MaxInt32)
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v",
|
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
|
||||||
len(wantRecvs), p, pt, data, err, io.EOF)
|
len(wantRecvs), p, pt, data, err, io.EOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -149,13 +150,17 @@ func TestToRPCErr(t *testing.T) {
|
|||||||
// input
|
// input
|
||||||
errIn error
|
errIn error
|
||||||
// outputs
|
// outputs
|
||||||
errOut error
|
errOut *rpcError
|
||||||
}{
|
}{
|
||||||
{transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")},
|
{transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)},
|
||||||
{transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)},
|
{transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)},
|
||||||
} {
|
} {
|
||||||
err := toRPCErr(test.errIn)
|
err := toRPCErr(test.errIn)
|
||||||
if err != test.errOut {
|
rpcErr, ok := err.(*rpcError)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{})
|
||||||
|
}
|
||||||
|
if *rpcErr != *test.errOut {
|
||||||
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
|
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -178,6 +183,18 @@ func TestContextErr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestErrorsWithSameParameters(t *testing.T) {
|
||||||
|
const description = "some description"
|
||||||
|
e1 := Errorf(codes.AlreadyExists, description).(*rpcError)
|
||||||
|
e2 := Errorf(codes.AlreadyExists, description).(*rpcError)
|
||||||
|
if e1 == e2 {
|
||||||
|
t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
|
||||||
|
}
|
||||||
|
if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) {
|
||||||
|
t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
||||||
// bytes.
|
// bytes.
|
||||||
func bmEncode(b *testing.B, mSize int) {
|
func bmEncode(b *testing.B, mSize int) {
|
||||||
|
136
server.go
136
server.go
@ -89,9 +89,13 @@ type service struct {
|
|||||||
type Server struct {
|
type Server struct {
|
||||||
opts options
|
opts options
|
||||||
|
|
||||||
mu sync.Mutex // guards following
|
mu sync.Mutex // guards following
|
||||||
lis map[net.Listener]bool
|
lis map[net.Listener]bool
|
||||||
conns map[io.Closer]bool
|
conns map[io.Closer]bool
|
||||||
|
drain bool
|
||||||
|
// A CondVar to let GracefulStop() blocks until all the pending RPCs are finished
|
||||||
|
// and all the transport goes away.
|
||||||
|
cv *sync.Cond
|
||||||
m map[string]*service // service name -> service info
|
m map[string]*service // service name -> service info
|
||||||
events trace.EventLog
|
events trace.EventLog
|
||||||
}
|
}
|
||||||
@ -101,12 +105,15 @@ type options struct {
|
|||||||
codec Codec
|
codec Codec
|
||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
|
maxMsgSize int
|
||||||
unaryInt UnaryServerInterceptor
|
unaryInt UnaryServerInterceptor
|
||||||
streamInt StreamServerInterceptor
|
streamInt StreamServerInterceptor
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
useHandlerImpl bool // use http.Handler-based server
|
useHandlerImpl bool // use http.Handler-based server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit
|
||||||
|
|
||||||
// A ServerOption sets options.
|
// A ServerOption sets options.
|
||||||
type ServerOption func(*options)
|
type ServerOption func(*options)
|
||||||
|
|
||||||
@ -117,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
|
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
|
||||||
func RPCCompressor(cp Compressor) ServerOption {
|
func RPCCompressor(cp Compressor) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.cp = cp
|
o.cp = cp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
|
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
|
||||||
func RPCDecompressor(dc Decompressor) ServerOption {
|
func RPCDecompressor(dc Decompressor) ServerOption {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.dc = dc
|
o.dc = dc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
|
||||||
|
// If this is not set, gRPC uses the default 4MB.
|
||||||
|
func MaxMsgSize(m int) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.maxMsgSize = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
|
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
|
||||||
// of concurrent streams to each ServerTransport.
|
// of concurrent streams to each ServerTransport.
|
||||||
func MaxConcurrentStreams(n uint32) ServerOption {
|
func MaxConcurrentStreams(n uint32) ServerOption {
|
||||||
@ -173,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
|
|||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
var opts options
|
var opts options
|
||||||
|
opts.maxMsgSize = defaultMaxMsgSize
|
||||||
for _, o := range opt {
|
for _, o := range opt {
|
||||||
o(&opts)
|
o(&opts)
|
||||||
}
|
}
|
||||||
@ -186,6 +202,7 @@ func NewServer(opt ...ServerOption) *Server {
|
|||||||
conns: make(map[io.Closer]bool),
|
conns: make(map[io.Closer]bool),
|
||||||
m: make(map[string]*service),
|
m: make(map[string]*service),
|
||||||
}
|
}
|
||||||
|
s.cv = sync.NewCond(&s.mu)
|
||||||
if EnableTracing {
|
if EnableTracing {
|
||||||
_, file, line, _ := runtime.Caller(1)
|
_, file, line, _ := runtime.Caller(1)
|
||||||
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
|
s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
|
||||||
@ -245,28 +262,45 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) {
|
|||||||
s.m[sd.ServiceName] = srv
|
s.m[sd.ServiceName] = srv
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServiceInfo contains method names and metadata for a service.
|
// MethodInfo contains the information of an RPC including its method name and type.
|
||||||
|
type MethodInfo struct {
|
||||||
|
// Name is the method name only, without the service name or package name.
|
||||||
|
Name string
|
||||||
|
// IsClientStream indicates whether the RPC is a client streaming RPC.
|
||||||
|
IsClientStream bool
|
||||||
|
// IsServerStream indicates whether the RPC is a server streaming RPC.
|
||||||
|
IsServerStream bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServiceInfo contains unary RPC method info, streaming RPC methid info and metadata for a service.
|
||||||
type ServiceInfo struct {
|
type ServiceInfo struct {
|
||||||
// Methods are method names only, without the service name or package name.
|
Methods []MethodInfo
|
||||||
Methods []string
|
|
||||||
// Metadata is the metadata specified in ServiceDesc when registering service.
|
// Metadata is the metadata specified in ServiceDesc when registering service.
|
||||||
Metadata interface{}
|
Metadata interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServiceInfo returns a map from service names to ServiceInfo.
|
// GetServiceInfo returns a map from service names to ServiceInfo.
|
||||||
// Service names include the package names, in the form of <package>.<service>.
|
// Service names include the package names, in the form of <package>.<service>.
|
||||||
func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
|
func (s *Server) GetServiceInfo() map[string]ServiceInfo {
|
||||||
ret := make(map[string]*ServiceInfo)
|
ret := make(map[string]ServiceInfo)
|
||||||
for n, srv := range s.m {
|
for n, srv := range s.m {
|
||||||
methods := make([]string, 0, len(srv.md)+len(srv.sd))
|
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
|
||||||
for m := range srv.md {
|
for m := range srv.md {
|
||||||
methods = append(methods, m)
|
methods = append(methods, MethodInfo{
|
||||||
|
Name: m,
|
||||||
|
IsClientStream: false,
|
||||||
|
IsServerStream: false,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
for m := range srv.sd {
|
for m, d := range srv.sd {
|
||||||
methods = append(methods, m)
|
methods = append(methods, MethodInfo{
|
||||||
|
Name: m,
|
||||||
|
IsClientStream: d.ClientStreams,
|
||||||
|
IsServerStream: d.ServerStreams,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
ret[n] = &ServiceInfo{
|
ret[n] = ServiceInfo{
|
||||||
Methods: methods,
|
Methods: methods,
|
||||||
Metadata: srv.mdata,
|
Metadata: srv.mdata,
|
||||||
}
|
}
|
||||||
@ -303,9 +337,11 @@ func (s *Server) Serve(lis net.Listener) error {
|
|||||||
s.lis[lis] = true
|
s.lis[lis] = true
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
defer func() {
|
defer func() {
|
||||||
lis.Close()
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
delete(s.lis, lis)
|
if s.lis != nil && s.lis[lis] {
|
||||||
|
lis.Close()
|
||||||
|
delete(s.lis, lis)
|
||||||
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}()
|
}()
|
||||||
for {
|
for {
|
||||||
@ -449,7 +485,7 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
|
|||||||
func (s *Server) addConn(c io.Closer) bool {
|
func (s *Server) addConn(c io.Closer) bool {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.conns == nil {
|
if s.conns == nil || s.drain {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
s.conns[c] = true
|
s.conns[c] = true
|
||||||
@ -461,6 +497,7 @@ func (s *Server) removeConn(c io.Closer) {
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.conns != nil {
|
if s.conns != nil {
|
||||||
delete(s.conns, c)
|
delete(s.conns, c)
|
||||||
|
s.cv.Signal()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -501,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg()
|
pf, req, err := p.recvMsg(s.opts.maxMsgSize)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// The entire stream is done (for unary RPC only).
|
// The entire stream is done (for unary RPC only).
|
||||||
return err
|
return err
|
||||||
@ -511,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err := err.(type) {
|
switch err := err.(type) {
|
||||||
|
case *rpcError:
|
||||||
|
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
|
||||||
|
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
|
||||||
|
}
|
||||||
case transport.ConnectionError:
|
case transport.ConnectionError:
|
||||||
// Nothing to do here.
|
// Nothing to do here.
|
||||||
case transport.StreamError:
|
case transport.StreamError:
|
||||||
@ -550,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(req) > s.opts.maxMsgSize {
|
||||||
|
// TODO: Revisit the error code. Currently keep it consistent with
|
||||||
|
// java implementation.
|
||||||
|
statusCode = codes.Internal
|
||||||
|
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
|
||||||
|
}
|
||||||
if err := s.opts.codec.Unmarshal(req, v); err != nil {
|
if err := s.opts.codec.Unmarshal(req, v); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -560,7 +607,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
if err, ok := appErr.(rpcError); ok {
|
if err, ok := appErr.(*rpcError); ok {
|
||||||
statusCode = err.code
|
statusCode = err.code
|
||||||
statusDesc = err.desc
|
statusDesc = err.desc
|
||||||
} else {
|
} else {
|
||||||
@ -609,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
stream.SetSendCompress(s.opts.cp.Type())
|
stream.SetSendCompress(s.opts.cp.Type())
|
||||||
}
|
}
|
||||||
ss := &serverStream{
|
ss := &serverStream{
|
||||||
t: t,
|
t: t,
|
||||||
s: stream,
|
s: stream,
|
||||||
p: &parser{r: stream},
|
p: &parser{r: stream},
|
||||||
codec: s.opts.codec,
|
codec: s.opts.codec,
|
||||||
cp: s.opts.cp,
|
cp: s.opts.cp,
|
||||||
dc: s.opts.dc,
|
dc: s.opts.dc,
|
||||||
trInfo: trInfo,
|
maxMsgSize: s.opts.maxMsgSize,
|
||||||
|
trInfo: trInfo,
|
||||||
}
|
}
|
||||||
if ss.cp != nil {
|
if ss.cp != nil {
|
||||||
ss.cbuf = new(bytes.Buffer)
|
ss.cbuf = new(bytes.Buffer)
|
||||||
@ -645,7 +693,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
appErr = s.opts.streamInt(srv.server, ss, info, sd.Handler)
|
||||||
}
|
}
|
||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
if err, ok := appErr.(rpcError); ok {
|
if err, ok := appErr.(*rpcError); ok {
|
||||||
ss.statusCode = err.code
|
ss.statusCode = err.code
|
||||||
ss.statusDesc = err.desc
|
ss.statusDesc = err.desc
|
||||||
} else if err, ok := appErr.(transport.StreamError); ok {
|
} else if err, ok := appErr.(transport.StreamError); ok {
|
||||||
@ -747,14 +795,16 @@ func (s *Server) Stop() {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
listeners := s.lis
|
listeners := s.lis
|
||||||
s.lis = nil
|
s.lis = nil
|
||||||
cs := s.conns
|
st := s.conns
|
||||||
s.conns = nil
|
s.conns = nil
|
||||||
|
// interrupt GracefulStop if Stop and GracefulStop are called concurrently.
|
||||||
|
s.cv.Signal()
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
for lis := range listeners {
|
for lis := range listeners {
|
||||||
lis.Close()
|
lis.Close()
|
||||||
}
|
}
|
||||||
for c := range cs {
|
for c := range st {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -766,6 +816,32 @@ func (s *Server) Stop() {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GracefulStop stops the gRPC server gracefully. It stops the server to accept new
|
||||||
|
// connections and RPCs and blocks until all the pending RPCs are finished.
|
||||||
|
func (s *Server) GracefulStop() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.drain == true || s.conns == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.drain = true
|
||||||
|
for lis := range s.lis {
|
||||||
|
lis.Close()
|
||||||
|
}
|
||||||
|
s.lis = nil
|
||||||
|
for c := range s.conns {
|
||||||
|
c.(transport.ServerTransport).Drain()
|
||||||
|
}
|
||||||
|
for len(s.conns) != 0 {
|
||||||
|
s.cv.Wait()
|
||||||
|
}
|
||||||
|
s.conns = nil
|
||||||
|
if s.events != nil {
|
||||||
|
s.events.Finish()
|
||||||
|
s.events = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
internal.TestingCloseConns = func(arg interface{}) {
|
internal.TestingCloseConns = func(arg interface{}) {
|
||||||
arg.(*Server).testingCloseConns()
|
arg.(*Server).testingCloseConns()
|
||||||
|
@ -79,7 +79,7 @@ func TestGetServiceInfo(t *testing.T) {
|
|||||||
{
|
{
|
||||||
StreamName: "EmptyStream",
|
StreamName: "EmptyStream",
|
||||||
Handler: nil,
|
Handler: nil,
|
||||||
ServerStreams: true,
|
ServerStreams: false,
|
||||||
ClientStreams: true,
|
ClientStreams: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -90,17 +90,24 @@ func TestGetServiceInfo(t *testing.T) {
|
|||||||
server.RegisterService(&testSd, &testServer{})
|
server.RegisterService(&testSd, &testServer{})
|
||||||
|
|
||||||
info := server.GetServiceInfo()
|
info := server.GetServiceInfo()
|
||||||
want := map[string]*ServiceInfo{
|
want := map[string]ServiceInfo{
|
||||||
"grpc.testing.EmptyService": &ServiceInfo{
|
"grpc.testing.EmptyService": {
|
||||||
Methods: []string{
|
Methods: []MethodInfo{
|
||||||
"EmptyCall",
|
{
|
||||||
"EmptyStream",
|
Name: "EmptyCall",
|
||||||
},
|
IsClientStream: false,
|
||||||
|
IsServerStream: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "EmptyStream",
|
||||||
|
IsClientStream: true,
|
||||||
|
IsServerStream: false,
|
||||||
|
}},
|
||||||
Metadata: []int{0, 2, 1, 3},
|
Metadata: []int{0, 2, 1, 3},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(info, want) {
|
if !reflect.DeepEqual(info, want) {
|
||||||
t.Errorf("GetServiceInfo() = %q, want %q", info, want)
|
t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
165
stream.go
165
stream.go
@ -37,6 +37,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -84,12 +85,9 @@ type ClientStream interface {
|
|||||||
// Header returns the header metadata received from the server if there
|
// Header returns the header metadata received from the server if there
|
||||||
// is any. It blocks if the metadata is not ready to read.
|
// is any. It blocks if the metadata is not ready to read.
|
||||||
Header() (metadata.MD, error)
|
Header() (metadata.MD, error)
|
||||||
// Trailer returns the trailer metadata from the server. It must be called
|
// Trailer returns the trailer metadata from the server, if there is any.
|
||||||
// after stream.Recv() returns non-nil error (including io.EOF) for
|
// It must only be called after stream.CloseAndRecv has returned, or
|
||||||
// bi-directional streaming and server streaming or stream.CloseAndRecv()
|
// stream.Recv has returned a non-nil error (including io.EOF).
|
||||||
// returns for client streaming in order to receive trailer metadata if
|
|
||||||
// present. Otherwise, it could returns an empty MD even though trailer
|
|
||||||
// is present.
|
|
||||||
Trailer() metadata.MD
|
Trailer() metadata.MD
|
||||||
// CloseSend closes the send direction of the stream. It closes the stream
|
// CloseSend closes the send direction of the stream. It closes the stream
|
||||||
// when non-nil error is met.
|
// when non-nil error is met.
|
||||||
@ -99,19 +97,17 @@ type ClientStream interface {
|
|||||||
|
|
||||||
// NewClientStream creates a new Stream for the client side. This is called
|
// NewClientStream creates a new Stream for the client side. This is called
|
||||||
// by generated code.
|
// by generated code.
|
||||||
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
|
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
|
||||||
var (
|
var (
|
||||||
t transport.ClientTransport
|
t transport.ClientTransport
|
||||||
err error
|
s *transport.Stream
|
||||||
put func()
|
put func()
|
||||||
)
|
)
|
||||||
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
|
c := defaultCallInfo
|
||||||
gopts := BalancerGetOptions{
|
for _, o := range opts {
|
||||||
BlockingWait: false,
|
if err := o.before(&c); err != nil {
|
||||||
}
|
return nil, toRPCErr(err)
|
||||||
t, put, err = cc.getTransport(ctx, gopts)
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, toRPCErr(err)
|
|
||||||
}
|
}
|
||||||
callHdr := &transport.CallHdr{
|
callHdr := &transport.CallHdr{
|
||||||
Host: cc.authority,
|
Host: cc.authority,
|
||||||
@ -121,41 +117,98 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
if cc.dopts.cp != nil {
|
if cc.dopts.cp != nil {
|
||||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||||
}
|
}
|
||||||
|
var trInfo traceInfo
|
||||||
|
if EnableTracing {
|
||||||
|
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
|
||||||
|
trInfo.firstLine.client = true
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
trInfo.firstLine.deadline = deadline.Sub(time.Now())
|
||||||
|
}
|
||||||
|
trInfo.tr.LazyLog(&trInfo.firstLine, false)
|
||||||
|
ctx = trace.NewContext(ctx, trInfo.tr)
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
// Need to call tr.finish() if error is returned.
|
||||||
|
// Because tr will not be returned to caller.
|
||||||
|
trInfo.tr.LazyPrintf("RPC: [%v]", err)
|
||||||
|
trInfo.tr.SetError()
|
||||||
|
trInfo.tr.Finish()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
gopts := BalancerGetOptions{
|
||||||
|
BlockingWait: !c.failFast,
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
t, put, err = cc.getTransport(ctx, gopts)
|
||||||
|
if err != nil {
|
||||||
|
// TODO(zhaoq): Probably revisit the error handling.
|
||||||
|
if _, ok := err.(*rpcError); ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err == errConnClosing || err == errConnUnavailable {
|
||||||
|
if c.failFast {
|
||||||
|
return nil, Errorf(codes.Unavailable, "%v", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// All the other errors are treated as Internal errors.
|
||||||
|
return nil, Errorf(codes.Internal, "%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = t.NewStream(ctx, callHdr)
|
||||||
|
if err != nil {
|
||||||
|
if put != nil {
|
||||||
|
put()
|
||||||
|
put = nil
|
||||||
|
}
|
||||||
|
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
||||||
|
if c.failFast {
|
||||||
|
return nil, toRPCErr(err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil, toRPCErr(err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
cs := &clientStream{
|
cs := &clientStream{
|
||||||
desc: desc,
|
opts: opts,
|
||||||
put: put,
|
c: c,
|
||||||
codec: cc.dopts.codec,
|
desc: desc,
|
||||||
cp: cc.dopts.cp,
|
codec: cc.dopts.codec,
|
||||||
dc: cc.dopts.dc,
|
cp: cc.dopts.cp,
|
||||||
|
dc: cc.dopts.dc,
|
||||||
|
|
||||||
|
put: put,
|
||||||
|
t: t,
|
||||||
|
s: s,
|
||||||
|
p: &parser{r: s},
|
||||||
|
|
||||||
tracing: EnableTracing,
|
tracing: EnableTracing,
|
||||||
|
trInfo: trInfo,
|
||||||
}
|
}
|
||||||
if cc.dopts.cp != nil {
|
if cc.dopts.cp != nil {
|
||||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
|
||||||
cs.cbuf = new(bytes.Buffer)
|
cs.cbuf = new(bytes.Buffer)
|
||||||
}
|
}
|
||||||
if cs.tracing {
|
// Listen on ctx.Done() to detect cancellation and s.Done() to detect normal termination
|
||||||
cs.trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
|
// when there is no pending I/O operations on this stream.
|
||||||
cs.trInfo.firstLine.client = true
|
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
|
||||||
cs.trInfo.firstLine.deadline = deadline.Sub(time.Now())
|
|
||||||
}
|
|
||||||
cs.trInfo.tr.LazyLog(&cs.trInfo.firstLine, false)
|
|
||||||
ctx = trace.NewContext(ctx, cs.trInfo.tr)
|
|
||||||
}
|
|
||||||
s, err := t.NewStream(ctx, callHdr)
|
|
||||||
if err != nil {
|
|
||||||
cs.finish(err)
|
|
||||||
return nil, toRPCErr(err)
|
|
||||||
}
|
|
||||||
cs.t = t
|
|
||||||
cs.s = s
|
|
||||||
cs.p = &parser{r: s}
|
|
||||||
// Listen on ctx.Done() to detect cancellation when there is no pending
|
|
||||||
// I/O operations on this stream.
|
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-t.Error():
|
case <-t.Error():
|
||||||
// Incur transport error, simply exit.
|
// Incur transport error, simply exit.
|
||||||
|
case <-s.Done():
|
||||||
|
// TODO: The trace of the RPC is terminated here when there is no pending
|
||||||
|
// I/O, which is probably not the optimal solution.
|
||||||
|
if s.StatusCode() == codes.OK {
|
||||||
|
cs.finish(nil)
|
||||||
|
} else {
|
||||||
|
cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc()))
|
||||||
|
}
|
||||||
|
cs.closeTransportStream(nil)
|
||||||
|
case <-s.GoAway():
|
||||||
|
cs.finish(errConnDrain)
|
||||||
|
cs.closeTransportStream(errConnDrain)
|
||||||
case <-s.Context().Done():
|
case <-s.Context().Done():
|
||||||
err := s.Context().Err()
|
err := s.Context().Err()
|
||||||
cs.finish(err)
|
cs.finish(err)
|
||||||
@ -167,6 +220,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
|
|
||||||
// clientStream implements a client side Stream.
|
// clientStream implements a client side Stream.
|
||||||
type clientStream struct {
|
type clientStream struct {
|
||||||
|
opts []CallOption
|
||||||
|
c callInfo
|
||||||
t transport.ClientTransport
|
t transport.ClientTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
p *parser
|
||||||
@ -216,7 +271,17 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cs.finish(err)
|
cs.finish(err)
|
||||||
}
|
}
|
||||||
if err == nil || err == io.EOF {
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
// Specialize the process for server streaming. SendMesg is only called
|
||||||
|
// once when creating the stream object. io.EOF needs to be skipped when
|
||||||
|
// the rpc is early finished (before the stream object is created.).
|
||||||
|
// TODO: It is probably better to move this into the generated code.
|
||||||
|
if !cs.desc.ClientStreams && cs.desc.ServerStreams {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); !ok {
|
if _, ok := err.(transport.ConnectionError); !ok {
|
||||||
@ -237,7 +302,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||||
defer func() {
|
defer func() {
|
||||||
// err != nil indicates the termination of the stream.
|
// err != nil indicates the termination of the stream.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -256,7 +321,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Special handling for client streaming rpc.
|
// Special handling for client streaming rpc.
|
||||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
|
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
|
||||||
cs.closeTransportStream(err)
|
cs.closeTransportStream(err)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||||
@ -291,7 +356,7 @@ func (cs *clientStream) CloseSend() (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if err == nil || err == io.EOF {
|
if err == nil || err == io.EOF {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); !ok {
|
if _, ok := err.(transport.ConnectionError); !ok {
|
||||||
cs.closeTransportStream(err)
|
cs.closeTransportStream(err)
|
||||||
@ -312,15 +377,18 @@ func (cs *clientStream) closeTransportStream(err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) finish(err error) {
|
func (cs *clientStream) finish(err error) {
|
||||||
if !cs.tracing {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.Unlock()
|
defer cs.mu.Unlock()
|
||||||
|
for _, o := range cs.opts {
|
||||||
|
o.after(&cs.c)
|
||||||
|
}
|
||||||
if cs.put != nil {
|
if cs.put != nil {
|
||||||
cs.put()
|
cs.put()
|
||||||
cs.put = nil
|
cs.put = nil
|
||||||
}
|
}
|
||||||
|
if !cs.tracing {
|
||||||
|
return
|
||||||
|
}
|
||||||
if cs.trInfo.tr != nil {
|
if cs.trInfo.tr != nil {
|
||||||
if err == nil || err == io.EOF {
|
if err == nil || err == io.EOF {
|
||||||
cs.trInfo.tr.LazyPrintf("RPC: [OK]")
|
cs.trInfo.tr.LazyPrintf("RPC: [OK]")
|
||||||
@ -354,6 +422,7 @@ type serverStream struct {
|
|||||||
cp Compressor
|
cp Compressor
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
cbuf *bytes.Buffer
|
cbuf *bytes.Buffer
|
||||||
|
maxMsgSize int
|
||||||
statusCode codes.Code
|
statusCode codes.Code
|
||||||
statusDesc string
|
statusDesc string
|
||||||
trInfo *traceInfo
|
trInfo *traceInfo
|
||||||
@ -420,5 +489,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
ss.mu.Unlock()
|
ss.mu.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return recv(ss.p, ss.codec, ss.s, ss.dc, m)
|
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
|
||||||
}
|
}
|
||||||
|
@ -162,7 +162,7 @@ func (s *server) GetAllGauges(in *metricspb.EmptyMessage, stream metricspb.Metri
|
|||||||
defer s.mutex.RUnlock()
|
defer s.mutex.RUnlock()
|
||||||
|
|
||||||
for name, gauge := range s.gauges {
|
for name, gauge := range s.gauges {
|
||||||
if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{gauge.get()}}); err != nil {
|
if err := stream.Send(&metricspb.GaugeResponse{Name: name, Value: &metricspb.GaugeResponse_LongValue{LongValue: gauge.get()}}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,7 +175,7 @@ func (s *server) GetGauge(ctx context.Context, in *metricspb.GaugeRequest) (*met
|
|||||||
defer s.mutex.RUnlock()
|
defer s.mutex.RUnlock()
|
||||||
|
|
||||||
if g, ok := s.gauges[in.Name]; ok {
|
if g, ok := s.gauges[in.Name]; ok {
|
||||||
return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{g.get()}}, nil
|
return &metricspb.GaugeResponse{Name: in.Name, Value: &metricspb.GaugeResponse_LongValue{LongValue: g.get()}}, nil
|
||||||
}
|
}
|
||||||
return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
|
return nil, grpc.Errorf(codes.InvalidArgument, "gauge with name %s not found", in.Name)
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -72,6 +72,11 @@ type resetStream struct {
|
|||||||
|
|
||||||
func (*resetStream) item() {}
|
func (*resetStream) item() {}
|
||||||
|
|
||||||
|
type goAway struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*goAway) item() {}
|
||||||
|
|
||||||
type flushIO struct {
|
type flushIO struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
46
transport/go16.go
Normal file
46
transport/go16.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// +build go1.6,!go1.7
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialContext connects to the address on the named network.
|
||||||
|
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
|
||||||
|
}
|
46
transport/go17.go
Normal file
46
transport/go17.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialContext connects to the address on the named network.
|
||||||
|
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, network, address)
|
||||||
|
}
|
@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||||||
}
|
}
|
||||||
|
|
||||||
if v := r.Header.Get("grpc-timeout"); v != "" {
|
if v := r.Header.Get("grpc-timeout"); v != "" {
|
||||||
to, err := timeoutDecode(v)
|
to, err := decodeTimeout(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
|
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
|
||||||
}
|
}
|
||||||
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
|
|||||||
h := ht.rw.Header()
|
h := ht.rw.Header()
|
||||||
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
|
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
|
||||||
if statusDesc != "" {
|
if statusDesc != "" {
|
||||||
h.Set("Grpc-Message", statusDesc)
|
h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
|
||||||
}
|
}
|
||||||
if md := s.Trailer(); len(md) > 0 {
|
if md := s.Trailer(); len(md) > 0 {
|
||||||
for k, vv := range md {
|
for k, vv := range md {
|
||||||
@ -312,7 +312,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
|
|||||||
Addr: ht.RemoteAddr(),
|
Addr: ht.RemoteAddr(),
|
||||||
}
|
}
|
||||||
if req.TLS != nil {
|
if req.TLS != nil {
|
||||||
pr.AuthInfo = credentials.TLSInfo{*req.TLS}
|
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
||||||
}
|
}
|
||||||
ctx = metadata.NewContext(ctx, ht.headerMD)
|
ctx = metadata.NewContext(ctx, ht.headerMD)
|
||||||
ctx = peer.NewContext(ctx, pr)
|
ctx = peer.NewContext(ctx, pr)
|
||||||
@ -370,6 +370,10 @@ func (ht *serverHandlerTransport) runStream() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ht *serverHandlerTransport) Drain() {
|
||||||
|
panic("Drain() is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// mapRecvMsgError returns the non-nil err into the appropriate
|
// mapRecvMsgError returns the non-nil err into the appropriate
|
||||||
// error value as expected by callers of *grpc.parser.recvMsg.
|
// error value as expected by callers of *grpc.parser.recvMsg.
|
||||||
// In particular, in can only be:
|
// In particular, in can only be:
|
||||||
|
@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
|
|||||||
"Content-Type": {"application/grpc"},
|
"Content-Type": {"application/grpc"},
|
||||||
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
||||||
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
|
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
|
||||||
"Grpc-Message": {msg},
|
"Grpc-Message": {encodeGrpcMessage(msg)},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
|
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
|
||||||
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
|
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
|
||||||
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
|||||||
"Content-Type": {"application/grpc"},
|
"Content-Type": {"application/grpc"},
|
||||||
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
||||||
"Grpc-Status": {"4"},
|
"Grpc-Status": {"4"},
|
||||||
"Grpc-Message": {"too slow"},
|
"Grpc-Message": {encodeGrpcMessage("too slow")},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
|
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
|
||||||
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
|
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
|
||||||
|
@ -35,6 +35,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
@ -71,6 +72,9 @@ type http2Client struct {
|
|||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
// errorChan is closed to notify the I/O error to the caller.
|
// errorChan is closed to notify the I/O error to the caller.
|
||||||
errorChan chan struct{}
|
errorChan chan struct{}
|
||||||
|
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
|
||||||
|
// that the server sent GoAway on this transport.
|
||||||
|
goAway chan struct{}
|
||||||
|
|
||||||
framer *framer
|
framer *framer
|
||||||
hBuf *bytes.Buffer // the buffer for HPACK encoding
|
hBuf *bytes.Buffer // the buffer for HPACK encoding
|
||||||
@ -97,41 +101,49 @@ type http2Client struct {
|
|||||||
maxStreams int
|
maxStreams int
|
||||||
// the per-stream outbound flow control window size set by the peer.
|
// the per-stream outbound flow control window size set by the peer.
|
||||||
streamSendQuota uint32
|
streamSendQuota uint32
|
||||||
|
// goAwayID records the Last-Stream-ID in the GoAway frame from the server.
|
||||||
|
goAwayID uint32
|
||||||
|
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
|
||||||
|
prevGoAwayID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if fn != nil {
|
||||||
|
return fn(ctx, addr)
|
||||||
|
}
|
||||||
|
return dialContext(ctx, "tcp", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||||
// and starts to receive messages on it. Non-nil error returns if construction
|
// and starts to receive messages on it. Non-nil error returns if construction
|
||||||
// fails.
|
// fails.
|
||||||
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
|
||||||
if opts.Dialer == nil {
|
|
||||||
// Set the default Dialer.
|
|
||||||
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
|
|
||||||
return net.DialTimeout("tcp", addr, timeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scheme := "http"
|
scheme := "http"
|
||||||
startT := time.Now()
|
conn, connErr := dial(opts.Dialer, ctx, addr)
|
||||||
timeout := opts.Timeout
|
|
||||||
conn, connErr := opts.Dialer(addr, timeout)
|
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr)
|
||||||
}
|
}
|
||||||
var authInfo credentials.AuthInfo
|
// Any further errors will close the underlying connection
|
||||||
if opts.TransportCredentials != nil {
|
defer func(conn net.Conn) {
|
||||||
scheme = "https"
|
|
||||||
if timeout > 0 {
|
|
||||||
timeout -= time.Since(startT)
|
|
||||||
}
|
|
||||||
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
|
|
||||||
}
|
|
||||||
if connErr != nil {
|
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}(conn)
|
||||||
|
var authInfo credentials.AuthInfo
|
||||||
|
if creds := opts.TransportCredentials; creds != nil {
|
||||||
|
scheme = "https"
|
||||||
|
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
|
||||||
|
}
|
||||||
|
if connErr != nil {
|
||||||
|
// Credentials handshake error is not a temporary error (unless the error
|
||||||
|
// was the connection closing or deadline exceeded).
|
||||||
|
var temp bool
|
||||||
|
switch connErr {
|
||||||
|
case io.EOF, context.DeadlineExceeded:
|
||||||
|
temp = true
|
||||||
|
}
|
||||||
|
return nil, ConnectionErrorf(temp, connErr, "transport: %v", connErr)
|
||||||
|
}
|
||||||
ua := primaryUA
|
ua := primaryUA
|
||||||
if opts.UserAgent != "" {
|
if opts.UserAgent != "" {
|
||||||
ua = opts.UserAgent + " " + ua
|
ua = opts.UserAgent + " " + ua
|
||||||
@ -147,6 +159,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
writableChan: make(chan int, 1),
|
writableChan: make(chan int, 1),
|
||||||
shutdownChan: make(chan struct{}),
|
shutdownChan: make(chan struct{}),
|
||||||
errorChan: make(chan struct{}),
|
errorChan: make(chan struct{}),
|
||||||
|
goAway: make(chan struct{}),
|
||||||
framer: newFramer(conn),
|
framer: newFramer(conn),
|
||||||
hBuf: &buf,
|
hBuf: &buf,
|
||||||
hEnc: hpack.NewEncoder(&buf),
|
hEnc: hpack.NewEncoder(&buf),
|
||||||
@ -168,26 +181,29 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
n, err := t.conn.Write(clientPreface)
|
n, err := t.conn.Write(clientPreface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if n != len(clientPreface) {
|
if n != len(clientPreface) {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
||||||
}
|
}
|
||||||
if initialWindowSize != defaultWindowSize {
|
if initialWindowSize != defaultWindowSize {
|
||||||
err = t.framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
|
err = t.framer.writeSettings(true, http2.Setting{
|
||||||
|
ID: http2.SettingInitialWindowSize,
|
||||||
|
Val: uint32(initialWindowSize),
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
err = t.framer.writeSettings(true)
|
err = t.framer.writeSettings(true)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
// Adjust the connection flow control window if needed.
|
// Adjust the connection flow control window if needed.
|
||||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||||
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go t.controller()
|
go t.controller()
|
||||||
@ -199,6 +215,8 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|||||||
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
|
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
id: t.nextID,
|
id: t.nextID,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
goAway: make(chan struct{}),
|
||||||
method: callHdr.Method,
|
method: callHdr.Method,
|
||||||
sendCompress: callHdr.SendCompress,
|
sendCompress: callHdr.SendCompress,
|
||||||
buf: newRecvBuffer(),
|
buf: newRecvBuffer(),
|
||||||
@ -213,8 +231,9 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
|||||||
// Make a stream be able to cancel the pending operations by itself.
|
// Make a stream be able to cancel the pending operations by itself.
|
||||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
s.dec = &recvBufferReader{
|
s.dec = &recvBufferReader{
|
||||||
ctx: s.ctx,
|
ctx: s.ctx,
|
||||||
recv: s.buf,
|
goAway: s.goAway,
|
||||||
|
recv: s.buf,
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@ -268,6 +287,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return nil, ErrConnClosing
|
return nil, ErrConnClosing
|
||||||
}
|
}
|
||||||
|
if t.state == draining {
|
||||||
|
t.mu.Unlock()
|
||||||
|
return nil, ErrStreamDrain
|
||||||
|
}
|
||||||
if t.state != reachable {
|
if t.state != reachable {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return nil, ErrConnClosing
|
return nil, ErrConnClosing
|
||||||
@ -275,7 +298,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
checkStreamsQuota := t.streamsQuota != nil
|
checkStreamsQuota := t.streamsQuota != nil
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
if checkStreamsQuota {
|
if checkStreamsQuota {
|
||||||
sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
|
sq, err := wait(ctx, nil, nil, t.shutdownChan, t.streamsQuota.acquire())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -284,7 +307,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
t.streamsQuota.add(sq - 1)
|
t.streamsQuota.add(sq - 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
||||||
// Return the quota back now because there is no stream returned to the caller.
|
// Return the quota back now because there is no stream returned to the caller.
|
||||||
if _, ok := err.(StreamError); ok && checkStreamsQuota {
|
if _, ok := err.(StreamError); ok && checkStreamsQuota {
|
||||||
t.streamsQuota.add(1)
|
t.streamsQuota.add(1)
|
||||||
@ -292,6 +315,15 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
|
if t.state == draining {
|
||||||
|
t.mu.Unlock()
|
||||||
|
if checkStreamsQuota {
|
||||||
|
t.streamsQuota.add(1)
|
||||||
|
}
|
||||||
|
// Need to make t writable again so that the rpc in flight can still proceed.
|
||||||
|
t.writableChan <- 0
|
||||||
|
return nil, ErrStreamDrain
|
||||||
|
}
|
||||||
if t.state != reachable {
|
if t.state != reachable {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return nil, ErrConnClosing
|
return nil, ErrConnClosing
|
||||||
@ -326,7 +358,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
|
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
|
||||||
}
|
}
|
||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
|
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
|
||||||
}
|
}
|
||||||
for k, v := range authData {
|
for k, v := range authData {
|
||||||
// Capital header names are illegal in HTTP/2.
|
// Capital header names are illegal in HTTP/2.
|
||||||
@ -381,7 +413,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.notifyError(err)
|
t.notifyError(err)
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
@ -400,22 +432,17 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|||||||
if t.streamsQuota != nil {
|
if t.streamsQuota != nil {
|
||||||
updateStreams = true
|
updateStreams = true
|
||||||
}
|
}
|
||||||
if t.state == draining && len(t.activeStreams) == 1 {
|
delete(t.activeStreams, s.id)
|
||||||
|
if t.state == draining && len(t.activeStreams) == 0 {
|
||||||
// The transport is draining and s is the last live stream on t.
|
// The transport is draining and s is the last live stream on t.
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
t.Close()
|
t.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
delete(t.activeStreams, s.id)
|
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
if updateStreams {
|
if updateStreams {
|
||||||
t.streamsQuota.add(1)
|
t.streamsQuota.add(1)
|
||||||
}
|
}
|
||||||
// In case stream sending and receiving are invoked in separate
|
|
||||||
// goroutines (e.g., bi-directional streaming), the caller needs
|
|
||||||
// to call cancel on the stream to interrupt the blocking on
|
|
||||||
// other goroutines.
|
|
||||||
s.cancel()
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if q := s.fc.resetPendingData(); q > 0 {
|
if q := s.fc.resetPendingData(); q > 0 {
|
||||||
if n := t.fc.onRead(q); n > 0 {
|
if n := t.fc.onRead(q); n > 0 {
|
||||||
@ -442,13 +469,13 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
|
|||||||
// accessed any more.
|
// accessed any more.
|
||||||
func (t *http2Client) Close() (err error) {
|
func (t *http2Client) Close() (err error) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if t.state == reachable {
|
|
||||||
close(t.errorChan)
|
|
||||||
}
|
|
||||||
if t.state == closing {
|
if t.state == closing {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if t.state == reachable || t.state == draining {
|
||||||
|
close(t.errorChan)
|
||||||
|
}
|
||||||
t.state = closing
|
t.state = closing
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
close(t.shutdownChan)
|
close(t.shutdownChan)
|
||||||
@ -472,10 +499,35 @@ func (t *http2Client) Close() (err error) {
|
|||||||
|
|
||||||
func (t *http2Client) GracefulClose() error {
|
func (t *http2Client) GracefulClose() error {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if t.state == closing {
|
switch t.state {
|
||||||
|
case unreachable:
|
||||||
|
// The server may close the connection concurrently. t is not available for
|
||||||
|
// any streams. Close it now.
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.Close()
|
||||||
|
return nil
|
||||||
|
case closing:
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
// Notify the streams which were initiated after the server sent GOAWAY.
|
||||||
|
select {
|
||||||
|
case <-t.goAway:
|
||||||
|
n := t.prevGoAwayID
|
||||||
|
if n == 0 && t.nextID > 1 {
|
||||||
|
n = t.nextID - 2
|
||||||
|
}
|
||||||
|
m := t.goAwayID + 2
|
||||||
|
if m == 2 {
|
||||||
|
m = 1
|
||||||
|
}
|
||||||
|
for i := m; i <= n; i += 2 {
|
||||||
|
if s, ok := t.activeStreams[i]; ok {
|
||||||
|
close(s.goAway)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
if t.state == draining {
|
if t.state == draining {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
@ -501,15 +553,15 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
size := http2MaxFrameLen
|
size := http2MaxFrameLen
|
||||||
s.sendQuotaPool.add(0)
|
s.sendQuotaPool.add(0)
|
||||||
// Wait until the stream has some quota to send the data.
|
// Wait until the stream has some quota to send the data.
|
||||||
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
|
sq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, s.sendQuotaPool.acquire())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.sendQuotaPool.add(0)
|
t.sendQuotaPool.add(0)
|
||||||
// Wait until the transport has some quota to send the data.
|
// Wait until the transport has some quota to send the data.
|
||||||
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
|
tq, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.sendQuotaPool.acquire())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(StreamError); ok {
|
if _, ok := err.(StreamError); ok || err == io.EOF {
|
||||||
t.sendQuotaPool.cancel()
|
t.sendQuotaPool.cancel()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@ -541,8 +593,8 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
// Indicate there is a writer who is about to write a data frame.
|
// Indicate there is a writer who is about to write a data frame.
|
||||||
t.framer.adjustNumWriters(1)
|
t.framer.adjustNumWriters(1)
|
||||||
// Got some quota. Try to acquire writing privilege on the transport.
|
// Got some quota. Try to acquire writing privilege on the transport.
|
||||||
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(s.ctx, s.done, s.goAway, t.shutdownChan, t.writableChan); err != nil {
|
||||||
if _, ok := err.(StreamError); ok {
|
if _, ok := err.(StreamError); ok || err == io.EOF {
|
||||||
// Return the connection quota back.
|
// Return the connection quota back.
|
||||||
t.sendQuotaPool.add(len(p))
|
t.sendQuotaPool.add(len(p))
|
||||||
}
|
}
|
||||||
@ -575,7 +627,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
// invoked.
|
// invoked.
|
||||||
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
|
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
|
||||||
t.notifyError(err)
|
t.notifyError(err)
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if t.framer.adjustNumWriters(-1) == 0 {
|
if t.framer.adjustNumWriters(-1) == 0 {
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
@ -590,11 +642,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state != streamDone {
|
if s.state != streamDone {
|
||||||
if s.state == streamReadDone {
|
s.state = streamWriteDone
|
||||||
s.state = streamDone
|
|
||||||
} else {
|
|
||||||
s.state = streamWriteDone
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
@ -627,7 +675,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
|||||||
func (t *http2Client) handleData(f *http2.DataFrame) {
|
func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
size := len(f.Data())
|
size := len(f.Data())
|
||||||
if err := t.fc.onData(uint32(size)); err != nil {
|
if err := t.fc.onData(uint32(size)); err != nil {
|
||||||
t.notifyError(ConnectionErrorf("%v", err))
|
t.notifyError(ConnectionErrorf(true, err, "%v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Select the right stream to dispatch.
|
// Select the right stream to dispatch.
|
||||||
@ -652,6 +700,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
|||||||
s.state = streamDone
|
s.state = streamDone
|
||||||
s.statusCode = codes.Internal
|
s.statusCode = codes.Internal
|
||||||
s.statusDesc = err.Error()
|
s.statusDesc = err.Error()
|
||||||
|
close(s.done)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
|
||||||
@ -669,13 +718,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
|||||||
// the read direction is closed, and set the status appropriately.
|
// the read direction is closed, and set the status appropriately.
|
||||||
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state == streamWriteDone {
|
if s.state == streamDone {
|
||||||
s.state = streamDone
|
s.mu.Unlock()
|
||||||
} else {
|
return
|
||||||
s.state = streamReadDone
|
|
||||||
}
|
}
|
||||||
|
s.state = streamDone
|
||||||
s.statusCode = codes.Internal
|
s.statusCode = codes.Internal
|
||||||
s.statusDesc = "server closed the stream without sending trailers"
|
s.statusDesc = "server closed the stream without sending trailers"
|
||||||
|
close(s.done)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
}
|
}
|
||||||
@ -701,6 +751,8 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
|
|||||||
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
|
grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode)
|
||||||
s.statusCode = codes.Unknown
|
s.statusCode = codes.Unknown
|
||||||
}
|
}
|
||||||
|
s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode)
|
||||||
|
close(s.done)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
}
|
}
|
||||||
@ -725,7 +777,32 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
||||||
// TODO(zhaoq): GoAwayFrame handler to be implemented
|
t.mu.Lock()
|
||||||
|
if t.state == reachable || t.state == draining {
|
||||||
|
if f.LastStreamID > 0 && f.LastStreamID%2 != 1 {
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-t.goAway:
|
||||||
|
id := t.goAwayID
|
||||||
|
// t.goAway has been closed (i.e.,multiple GoAways).
|
||||||
|
if id < f.LastStreamID {
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.prevGoAwayID = id
|
||||||
|
t.goAwayID = f.LastStreamID
|
||||||
|
t.mu.Unlock()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
t.goAwayID = f.LastStreamID
|
||||||
|
close(t.goAway)
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
|
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
|
||||||
@ -777,11 +854,11 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||||||
if len(state.mdata) > 0 {
|
if len(state.mdata) > 0 {
|
||||||
s.trailer = state.mdata
|
s.trailer = state.mdata
|
||||||
}
|
}
|
||||||
s.state = streamDone
|
|
||||||
s.statusCode = state.statusCode
|
s.statusCode = state.statusCode
|
||||||
s.statusDesc = state.statusDesc
|
s.statusDesc = state.statusDesc
|
||||||
|
close(s.done)
|
||||||
|
s.state = streamDone
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -934,13 +1011,22 @@ func (t *http2Client) Error() <-chan struct{} {
|
|||||||
return t.errorChan
|
return t.errorChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *http2Client) GoAway() <-chan struct{} {
|
||||||
|
return t.goAway
|
||||||
|
}
|
||||||
|
|
||||||
func (t *http2Client) notifyError(err error) {
|
func (t *http2Client) notifyError(err error) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
|
||||||
// make sure t.errorChan is closed only once.
|
// make sure t.errorChan is closed only once.
|
||||||
|
if t.state == draining {
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
if t.state == reachable {
|
if t.state == reachable {
|
||||||
t.state = unreachable
|
t.state = unreachable
|
||||||
close(t.errorChan)
|
close(t.errorChan)
|
||||||
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
|
grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err)
|
||||||
}
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
@ -100,18 +100,23 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
|
|||||||
if maxStreams == 0 {
|
if maxStreams == 0 {
|
||||||
maxStreams = math.MaxUint32
|
maxStreams = math.MaxUint32
|
||||||
} else {
|
} else {
|
||||||
settings = append(settings, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams})
|
settings = append(settings, http2.Setting{
|
||||||
|
ID: http2.SettingMaxConcurrentStreams,
|
||||||
|
Val: maxStreams,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if initialWindowSize != defaultWindowSize {
|
if initialWindowSize != defaultWindowSize {
|
||||||
settings = append(settings, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)})
|
settings = append(settings, http2.Setting{
|
||||||
|
ID: http2.SettingInitialWindowSize,
|
||||||
|
Val: uint32(initialWindowSize)})
|
||||||
}
|
}
|
||||||
if err := framer.writeSettings(true, settings...); err != nil {
|
if err := framer.writeSettings(true, settings...); err != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
// Adjust the connection flow control window if needed.
|
// Adjust the connection flow control window if needed.
|
||||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||||
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
|
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
@ -137,7 +142,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
|
|||||||
}
|
}
|
||||||
|
|
||||||
// operateHeader takes action on the decoded headers.
|
// operateHeader takes action on the decoded headers.
|
||||||
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) {
|
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) {
|
||||||
buf := newRecvBuffer()
|
buf := newRecvBuffer()
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
id: frame.Header().StreamID,
|
id: frame.Header().StreamID,
|
||||||
@ -200,6 +205,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.id%2 != 1 || s.id <= t.maxStreamID {
|
||||||
|
t.mu.Unlock()
|
||||||
|
// illegal gRPC stream id.
|
||||||
|
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", s.id)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
t.maxStreamID = s.id
|
||||||
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota))
|
||||||
t.activeStreams[s.id] = s
|
t.activeStreams[s.id] = s
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
@ -207,6 +219,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
t.updateWindow(s, uint32(n))
|
t.updateWindow(s, uint32(n))
|
||||||
}
|
}
|
||||||
handle(s)
|
handle(s)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleStreams receives incoming streams using the given handler. This is
|
// HandleStreams receives incoming streams using the given handler. This is
|
||||||
@ -226,6 +239,10 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
frame, err := t.framer.readFrame()
|
frame, err := t.framer.readFrame()
|
||||||
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
t.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
|
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
|
||||||
t.Close()
|
t.Close()
|
||||||
@ -252,20 +269,20 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
|
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
|
t.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
|
||||||
t.Close()
|
t.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch frame := frame.(type) {
|
switch frame := frame.(type) {
|
||||||
case *http2.MetaHeadersFrame:
|
case *http2.MetaHeadersFrame:
|
||||||
id := frame.Header().StreamID
|
if t.operateHeaders(frame, handle) {
|
||||||
if id%2 != 1 || id <= t.maxStreamID {
|
|
||||||
// illegal gRPC stream id.
|
|
||||||
grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id)
|
|
||||||
t.Close()
|
t.Close()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
t.maxStreamID = id
|
|
||||||
t.operateHeaders(frame, handle)
|
|
||||||
case *http2.DataFrame:
|
case *http2.DataFrame:
|
||||||
t.handleData(frame)
|
t.handleData(frame)
|
||||||
case *http2.RSTStreamFrame:
|
case *http2.RSTStreamFrame:
|
||||||
@ -277,7 +294,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
case *http2.WindowUpdateFrame:
|
case *http2.WindowUpdateFrame:
|
||||||
t.handleWindowUpdate(frame)
|
t.handleWindowUpdate(frame)
|
||||||
case *http2.GoAwayFrame:
|
case *http2.GoAwayFrame:
|
||||||
break
|
// TODO: Handle GoAway from the client appropriately.
|
||||||
default:
|
default:
|
||||||
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
|
grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
|
||||||
}
|
}
|
||||||
@ -359,11 +376,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
|
|||||||
// Received the end of stream from the client.
|
// Received the end of stream from the client.
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state != streamDone {
|
if s.state != streamDone {
|
||||||
if s.state == streamWriteDone {
|
s.state = streamReadDone
|
||||||
s.state = streamDone
|
|
||||||
} else {
|
|
||||||
s.state = streamReadDone
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
@ -435,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -450,7 +463,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
|
|||||||
}
|
}
|
||||||
s.headerOk = true
|
s.headerOk = true
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.hBuf.Reset()
|
t.hBuf.Reset()
|
||||||
@ -490,7 +503,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
|
|||||||
headersSent = true
|
headersSent = true
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.hBuf.Reset()
|
t.hBuf.Reset()
|
||||||
@ -503,7 +516,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
|
|||||||
Name: "grpc-status",
|
Name: "grpc-status",
|
||||||
Value: strconv.Itoa(int(statusCode)),
|
Value: strconv.Itoa(int(statusCode)),
|
||||||
})
|
})
|
||||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc})
|
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
|
||||||
// Attach the trailer metadata.
|
// Attach the trailer metadata.
|
||||||
for k, v := range s.trailer {
|
for k, v := range s.trailer {
|
||||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||||
@ -539,7 +552,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
if writeHeaderFrame {
|
if writeHeaderFrame {
|
||||||
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.hBuf.Reset()
|
t.hBuf.Reset()
|
||||||
@ -555,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
if err := t.framer.writeHeaders(false, p); err != nil {
|
if err := t.framer.writeHeaders(false, p); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
}
|
}
|
||||||
@ -567,13 +580,13 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
size := http2MaxFrameLen
|
size := http2MaxFrameLen
|
||||||
s.sendQuotaPool.add(0)
|
s.sendQuotaPool.add(0)
|
||||||
// Wait until the stream has some quota to send the data.
|
// Wait until the stream has some quota to send the data.
|
||||||
sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire())
|
sq, err := wait(s.ctx, nil, nil, t.shutdownChan, s.sendQuotaPool.acquire())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
t.sendQuotaPool.add(0)
|
t.sendQuotaPool.add(0)
|
||||||
// Wait until the transport has some quota to send the data.
|
// Wait until the transport has some quota to send the data.
|
||||||
tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire())
|
tq, err := wait(s.ctx, nil, nil, t.shutdownChan, t.sendQuotaPool.acquire())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(StreamError); ok {
|
if _, ok := err.(StreamError); ok {
|
||||||
t.sendQuotaPool.cancel()
|
t.sendQuotaPool.cancel()
|
||||||
@ -599,7 +612,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
t.framer.adjustNumWriters(1)
|
t.framer.adjustNumWriters(1)
|
||||||
// Got some quota. Try to acquire writing privilege on the
|
// Got some quota. Try to acquire writing privilege on the
|
||||||
// transport.
|
// transport.
|
||||||
if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil {
|
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
|
||||||
if _, ok := err.(StreamError); ok {
|
if _, ok := err.(StreamError); ok {
|
||||||
// Return the connection quota back.
|
// Return the connection quota back.
|
||||||
t.sendQuotaPool.add(ps)
|
t.sendQuotaPool.add(ps)
|
||||||
@ -629,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
|
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if t.framer.adjustNumWriters(-1) == 0 {
|
if t.framer.adjustNumWriters(-1) == 0 {
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
@ -674,6 +687,17 @@ func (t *http2Server) controller() {
|
|||||||
}
|
}
|
||||||
case *resetStream:
|
case *resetStream:
|
||||||
t.framer.writeRSTStream(true, i.streamID, i.code)
|
t.framer.writeRSTStream(true, i.streamID, i.code)
|
||||||
|
case *goAway:
|
||||||
|
t.mu.Lock()
|
||||||
|
if t.state == closing {
|
||||||
|
t.mu.Unlock()
|
||||||
|
// The transport is closing.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sid := t.maxStreamID
|
||||||
|
t.state = draining
|
||||||
|
t.mu.Unlock()
|
||||||
|
t.framer.writeGoAway(true, sid, http2.ErrCodeNo, nil)
|
||||||
case *flushIO:
|
case *flushIO:
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
case *ping:
|
case *ping:
|
||||||
@ -719,6 +743,9 @@ func (t *http2Server) Close() (err error) {
|
|||||||
func (t *http2Server) closeStream(s *Stream) {
|
func (t *http2Server) closeStream(s *Stream) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
delete(t.activeStreams, s.id)
|
delete(t.activeStreams, s.id)
|
||||||
|
if t.state == draining && len(t.activeStreams) == 0 {
|
||||||
|
defer t.Close()
|
||||||
|
}
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
// In case stream sending and receiving are invoked in separate
|
// In case stream sending and receiving are invoked in separate
|
||||||
// goroutines (e.g., bi-directional streaming), cancel needs to be
|
// goroutines (e.g., bi-directional streaming), cancel needs to be
|
||||||
@ -741,3 +768,7 @@ func (t *http2Server) closeStream(s *Stream) {
|
|||||||
func (t *http2Server) RemoteAddr() net.Addr {
|
func (t *http2Server) RemoteAddr() net.Addr {
|
||||||
return t.conn.RemoteAddr()
|
return t.conn.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *http2Server) Drain() {
|
||||||
|
t.controlBuf.put(&goAway{})
|
||||||
|
}
|
||||||
|
@ -35,6 +35,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -52,7 +53,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// The primary user agent
|
// The primary user agent
|
||||||
primaryUA = "grpc-go/0.11"
|
primaryUA = "grpc-go/1.0"
|
||||||
// http2MaxFrameLen specifies the max length of a HTTP2 frame.
|
// http2MaxFrameLen specifies the max length of a HTTP2 frame.
|
||||||
http2MaxFrameLen = 16384 // 16KB frame
|
http2MaxFrameLen = 16384 // 16KB frame
|
||||||
// http://http2.github.io/http2-spec/#SettingValues
|
// http://http2.github.io/http2-spec/#SettingValues
|
||||||
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
|||||||
}
|
}
|
||||||
d.statusCode = codes.Code(code)
|
d.statusCode = codes.Code(code)
|
||||||
case "grpc-message":
|
case "grpc-message":
|
||||||
d.statusDesc = f.Value
|
d.statusDesc = decodeGrpcMessage(f.Value)
|
||||||
case "grpc-timeout":
|
case "grpc-timeout":
|
||||||
d.timeoutSet = true
|
d.timeoutSet = true
|
||||||
var err error
|
var err error
|
||||||
d.timeout, err = timeoutDecode(f.Value)
|
d.timeout, err = decodeTimeout(f.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
|
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
|
||||||
return
|
return
|
||||||
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
|
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
|
||||||
func timeoutEncode(t time.Duration) string {
|
func encodeTimeout(t time.Duration) string {
|
||||||
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
|
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
|
||||||
return strconv.FormatInt(d, 10) + "n"
|
return strconv.FormatInt(d, 10) + "n"
|
||||||
}
|
}
|
||||||
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
|
|||||||
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
|
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
|
||||||
}
|
}
|
||||||
|
|
||||||
func timeoutDecode(s string) (time.Duration, error) {
|
func decodeTimeout(s string) (time.Duration, error) {
|
||||||
size := len(s)
|
size := len(s)
|
||||||
if size < 2 {
|
if size < 2 {
|
||||||
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
|
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
|
||||||
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
|
|||||||
return d * time.Duration(t), nil
|
return d * time.Duration(t), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
spaceByte = ' '
|
||||||
|
tildaByte = '~'
|
||||||
|
percentByte = '%'
|
||||||
|
)
|
||||||
|
|
||||||
|
// encodeGrpcMessage is used to encode status code in header field
|
||||||
|
// "grpc-message".
|
||||||
|
// It checks to see if each individual byte in msg is an
|
||||||
|
// allowable byte, and then either percent encoding or passing it through.
|
||||||
|
// When percent encoding, the byte is converted into hexadecimal notation
|
||||||
|
// with a '%' prepended.
|
||||||
|
func encodeGrpcMessage(msg string) string {
|
||||||
|
if msg == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lenMsg := len(msg)
|
||||||
|
for i := 0; i < lenMsg; i++ {
|
||||||
|
c := msg[i]
|
||||||
|
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
|
||||||
|
return encodeGrpcMessageUnchecked(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeGrpcMessageUnchecked(msg string) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
lenMsg := len(msg)
|
||||||
|
for i := 0; i < lenMsg; i++ {
|
||||||
|
c := msg[i]
|
||||||
|
if c >= spaceByte && c < tildaByte && c != percentByte {
|
||||||
|
buf.WriteByte(c)
|
||||||
|
} else {
|
||||||
|
buf.WriteString(fmt.Sprintf("%%%02X", c))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
|
||||||
|
func decodeGrpcMessage(msg string) string {
|
||||||
|
if msg == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lenMsg := len(msg)
|
||||||
|
for i := 0; i < lenMsg; i++ {
|
||||||
|
if msg[i] == percentByte && i+2 < lenMsg {
|
||||||
|
return decodeGrpcMessageUnchecked(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeGrpcMessageUnchecked(msg string) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
lenMsg := len(msg)
|
||||||
|
for i := 0; i < lenMsg; i++ {
|
||||||
|
c := msg[i]
|
||||||
|
if c == percentByte && i+2 < lenMsg {
|
||||||
|
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
|
||||||
|
if err != nil {
|
||||||
|
buf.WriteByte(c)
|
||||||
|
} else {
|
||||||
|
buf.WriteByte(byte(parsed))
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
type framer struct {
|
type framer struct {
|
||||||
numWriters int32
|
numWriters int32
|
||||||
reader io.Reader
|
reader io.Reader
|
||||||
|
@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
|
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
|
||||||
}
|
}
|
||||||
out := timeoutEncode(d)
|
out := encodeTimeout(d)
|
||||||
if out != test.out {
|
if out != test.out {
|
||||||
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
|
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
|
||||||
}
|
}
|
||||||
@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) {
|
|||||||
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
|
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
|
||||||
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
|
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
|
||||||
} {
|
} {
|
||||||
d, err := timeoutDecode(test.s)
|
d, err := decodeTimeout(test.s)
|
||||||
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
|
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
|
||||||
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
|
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
|
||||||
}
|
}
|
||||||
@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeGrpcMessage(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"", ""},
|
||||||
|
{"Hello", "Hello"},
|
||||||
|
{"my favorite character is \u0000", "my favorite character is %00"},
|
||||||
|
{"my favorite character is %", "my favorite character is %25"},
|
||||||
|
} {
|
||||||
|
actual := encodeGrpcMessage(tt.input)
|
||||||
|
if tt.expected != actual {
|
||||||
|
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecodeGrpcMessage(t *testing.T) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"", ""},
|
||||||
|
{"Hello", "Hello"},
|
||||||
|
{"H%61o", "Hao"},
|
||||||
|
{"H%6", "H%6"},
|
||||||
|
{"%G0", "%G0"},
|
||||||
|
} {
|
||||||
|
actual := decodeGrpcMessage(tt.input)
|
||||||
|
if tt.expected != actual {
|
||||||
|
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
51
transport/pre_go16.go
Normal file
51
transport/pre_go16.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// +build !go1.6
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Copyright 2016, Google Inc.
|
||||||
|
* All rights reserved.
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are
|
||||||
|
* met:
|
||||||
|
*
|
||||||
|
* * Redistributions of source code must retain the above copyright
|
||||||
|
* notice, this list of conditions and the following disclaimer.
|
||||||
|
* * Redistributions in binary form must reproduce the above
|
||||||
|
* copyright notice, this list of conditions and the following disclaimer
|
||||||
|
* in the documentation and/or other materials provided with the
|
||||||
|
* distribution.
|
||||||
|
* * Neither the name of Google Inc. nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dialContext connects to the address on the named network.
|
||||||
|
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
var dialer net.Dialer
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
dialer.Timeout = deadline.Sub(time.Now())
|
||||||
|
}
|
||||||
|
return dialer.Dial(network, address)
|
||||||
|
}
|
@ -44,7 +44,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"golang.org/x/net/trace"
|
"golang.org/x/net/trace"
|
||||||
@ -120,10 +119,11 @@ func (b *recvBuffer) get() <-chan item {
|
|||||||
// recvBufferReader implements io.Reader interface to read the data from
|
// recvBufferReader implements io.Reader interface to read the data from
|
||||||
// recvBuffer.
|
// recvBuffer.
|
||||||
type recvBufferReader struct {
|
type recvBufferReader struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
recv *recvBuffer
|
goAway chan struct{}
|
||||||
last *bytes.Reader // Stores the remaining data in the previous calls.
|
recv *recvBuffer
|
||||||
err error
|
last *bytes.Reader // Stores the remaining data in the previous calls.
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads the next len(p) bytes from last. If last is drained, it tries to
|
// Read reads the next len(p) bytes from last. If last is drained, it tries to
|
||||||
@ -141,6 +141,8 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
|||||||
select {
|
select {
|
||||||
case <-r.ctx.Done():
|
case <-r.ctx.Done():
|
||||||
return 0, ContextErr(r.ctx.Err())
|
return 0, ContextErr(r.ctx.Err())
|
||||||
|
case <-r.goAway:
|
||||||
|
return 0, ErrStreamDrain
|
||||||
case i := <-r.recv.get():
|
case i := <-r.recv.get():
|
||||||
r.recv.load()
|
r.recv.load()
|
||||||
m := i.(*recvMsg)
|
m := i.(*recvMsg)
|
||||||
@ -158,7 +160,7 @@ const (
|
|||||||
streamActive streamState = iota
|
streamActive streamState = iota
|
||||||
streamWriteDone // EndStream sent
|
streamWriteDone // EndStream sent
|
||||||
streamReadDone // EndStream received
|
streamReadDone // EndStream received
|
||||||
streamDone // sendDone and recvDone or RSTStreamFrame is sent or received.
|
streamDone // the entire stream is finished.
|
||||||
)
|
)
|
||||||
|
|
||||||
// Stream represents an RPC in the transport layer.
|
// Stream represents an RPC in the transport layer.
|
||||||
@ -169,6 +171,10 @@ type Stream struct {
|
|||||||
// ctx is the associated context of the stream.
|
// ctx is the associated context of the stream.
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
// done is closed when the final status arrives.
|
||||||
|
done chan struct{}
|
||||||
|
// goAway is closed when the server sent GoAways signal before this stream was initiated.
|
||||||
|
goAway chan struct{}
|
||||||
// method records the associated RPC method of the stream.
|
// method records the associated RPC method of the stream.
|
||||||
method string
|
method string
|
||||||
recvCompress string
|
recvCompress string
|
||||||
@ -214,6 +220,18 @@ func (s *Stream) SetSendCompress(str string) {
|
|||||||
s.sendCompress = str
|
s.sendCompress = str
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Done returns a chanel which is closed when it receives the final status
|
||||||
|
// from the server.
|
||||||
|
func (s *Stream) Done() <-chan struct{} {
|
||||||
|
return s.done
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoAway returns a channel which is closed when the server sent GoAways signal
|
||||||
|
// before this stream was initiated.
|
||||||
|
func (s *Stream) GoAway() <-chan struct{} {
|
||||||
|
return s.goAway
|
||||||
|
}
|
||||||
|
|
||||||
// Header acquires the key-value pairs of header metadata once it
|
// Header acquires the key-value pairs of header metadata once it
|
||||||
// is available. It blocks until i) the metadata is ready or ii) there is no
|
// is available. It blocks until i) the metadata is ready or ii) there is no
|
||||||
// header metadata or iii) the stream is cancelled/expired.
|
// header metadata or iii) the stream is cancelled/expired.
|
||||||
@ -221,6 +239,8 @@ func (s *Stream) Header() (metadata.MD, error) {
|
|||||||
select {
|
select {
|
||||||
case <-s.ctx.Done():
|
case <-s.ctx.Done():
|
||||||
return nil, ContextErr(s.ctx.Err())
|
return nil, ContextErr(s.ctx.Err())
|
||||||
|
case <-s.goAway:
|
||||||
|
return nil, ErrStreamDrain
|
||||||
case <-s.headerChan:
|
case <-s.headerChan:
|
||||||
return s.header.Copy(), nil
|
return s.header.Copy(), nil
|
||||||
}
|
}
|
||||||
@ -335,19 +355,17 @@ type ConnectOptions struct {
|
|||||||
// UserAgent is the application user agent.
|
// UserAgent is the application user agent.
|
||||||
UserAgent string
|
UserAgent string
|
||||||
// Dialer specifies how to dial a network address.
|
// Dialer specifies how to dial a network address.
|
||||||
Dialer func(string, time.Duration) (net.Conn, error)
|
Dialer func(context.Context, string) (net.Conn, error)
|
||||||
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
|
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
|
||||||
PerRPCCredentials []credentials.PerRPCCredentials
|
PerRPCCredentials []credentials.PerRPCCredentials
|
||||||
// TransportCredentials stores the Authenticator required to setup a client connection.
|
// TransportCredentials stores the Authenticator required to setup a client connection.
|
||||||
TransportCredentials credentials.TransportCredentials
|
TransportCredentials credentials.TransportCredentials
|
||||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
|
||||||
Timeout time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||||
// and returns it to the caller.
|
// and returns it to the caller.
|
||||||
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
|
func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
|
||||||
return newHTTP2Client(target, opts)
|
return newHTTP2Client(ctx, target, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options provides additional hints and information for message
|
// Options provides additional hints and information for message
|
||||||
@ -417,6 +435,11 @@ type ClientTransport interface {
|
|||||||
// and create a new one) in error case. It should not return nil
|
// and create a new one) in error case. It should not return nil
|
||||||
// once the transport is initiated.
|
// once the transport is initiated.
|
||||||
Error() <-chan struct{}
|
Error() <-chan struct{}
|
||||||
|
|
||||||
|
// GoAway returns a channel that is closed when ClientTranspor
|
||||||
|
// receives the draining signal from the server (e.g., GOAWAY frame in
|
||||||
|
// HTTP/2).
|
||||||
|
GoAway() <-chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerTransport is the common interface for all gRPC server-side transport
|
// ServerTransport is the common interface for all gRPC server-side transport
|
||||||
@ -448,6 +471,9 @@ type ServerTransport interface {
|
|||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
// RemoteAddr returns the remote network address.
|
||||||
RemoteAddr() net.Addr
|
RemoteAddr() net.Addr
|
||||||
|
|
||||||
|
// Drain notifies the client this ServerTransport stops accepting new RPCs.
|
||||||
|
Drain()
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamErrorf creates an StreamError with the specified error code and description.
|
// StreamErrorf creates an StreamError with the specified error code and description.
|
||||||
@ -459,9 +485,11 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionErrorf creates an ConnectionError with the specified error description.
|
// ConnectionErrorf creates an ConnectionError with the specified error description.
|
||||||
func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
|
func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError {
|
||||||
return ConnectionError{
|
return ConnectionError{
|
||||||
Desc: fmt.Sprintf(format, a...),
|
Desc: fmt.Sprintf(format, a...),
|
||||||
|
temp: temp,
|
||||||
|
err: e,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -469,14 +497,36 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
|
|||||||
// entire connection and the retry of all the active streams.
|
// entire connection and the retry of all the active streams.
|
||||||
type ConnectionError struct {
|
type ConnectionError struct {
|
||||||
Desc string
|
Desc string
|
||||||
|
temp bool
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e ConnectionError) Error() string {
|
func (e ConnectionError) Error() string {
|
||||||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrConnClosing indicates that the transport is closing.
|
// Temporary indicates if this connection error is temporary or fatal.
|
||||||
var ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
func (e ConnectionError) Temporary() bool {
|
||||||
|
return e.temp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Origin returns the original error of this connection error.
|
||||||
|
func (e ConnectionError) Origin() error {
|
||||||
|
// Never return nil error here.
|
||||||
|
// If the original error is nil, return itself.
|
||||||
|
if e.err == nil {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrConnClosing indicates that the transport is closing.
|
||||||
|
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
|
||||||
|
// ErrStreamDrain indicates that the stream is rejected by the server because
|
||||||
|
// the server stops accepting new RPCs.
|
||||||
|
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
|
||||||
|
)
|
||||||
|
|
||||||
// StreamError is an error that only affects one stream within a connection.
|
// StreamError is an error that only affects one stream within a connection.
|
||||||
type StreamError struct {
|
type StreamError struct {
|
||||||
@ -501,12 +551,25 @@ func ContextErr(err error) StreamError {
|
|||||||
|
|
||||||
// wait blocks until it can receive from ctx.Done, closing, or proceed.
|
// wait blocks until it can receive from ctx.Done, closing, or proceed.
|
||||||
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
|
// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err.
|
||||||
|
// If it receives from done, it returns 0, io.EOF if ctx is not done; otherwise
|
||||||
|
// it return the StreamError for ctx.Err.
|
||||||
|
// If it receives from goAway, it returns 0, ErrStreamDrain.
|
||||||
// If it receives from closing, it returns 0, ErrConnClosing.
|
// If it receives from closing, it returns 0, ErrConnClosing.
|
||||||
// If it receives from proceed, it returns the received integer, nil.
|
// If it receives from proceed, it returns the received integer, nil.
|
||||||
func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) {
|
func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-chan int) (int, error) {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return 0, ContextErr(ctx.Err())
|
return 0, ContextErr(ctx.Err())
|
||||||
|
case <-done:
|
||||||
|
// User cancellation has precedence.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ContextErr(ctx.Err())
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
case <-goAway:
|
||||||
|
return 0, ErrStreamDrain
|
||||||
case <-closing:
|
case <-closing:
|
||||||
return 0, ErrConnClosing
|
return 0, ErrConnClosing
|
||||||
case i := <-proceed:
|
case i := <-proceed:
|
||||||
|
@ -39,7 +39,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@ -75,7 +74,7 @@ const (
|
|||||||
normal hType = iota
|
normal hType = iota
|
||||||
suspended
|
suspended
|
||||||
misbehaved
|
misbehaved
|
||||||
malformedStatus
|
encodingRequiredStatus
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||||
@ -111,27 +110,34 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
|
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
|
||||||
}
|
}
|
||||||
size := 1
|
|
||||||
if s.Method() == "foo.MaxFrame" {
|
|
||||||
size = http2MaxFrameLen
|
|
||||||
}
|
|
||||||
// Drain the client side stream flow control window.
|
|
||||||
var sent int
|
var sent int
|
||||||
for sent <= initialWindowSize {
|
p := make([]byte, http2MaxFrameLen)
|
||||||
|
for sent < initialWindowSize {
|
||||||
<-conn.writableChan
|
<-conn.writableChan
|
||||||
if err := conn.framer.writeData(true, s.id, false, make([]byte, size)); err != nil {
|
n := initialWindowSize - sent
|
||||||
|
// The last message may be smaller than http2MaxFrameLen
|
||||||
|
if n <= http2MaxFrameLen {
|
||||||
|
if s.Method() == "foo.Connection" {
|
||||||
|
// Violate connection level flow control window of client but do not
|
||||||
|
// violate any stream level windows.
|
||||||
|
p = make([]byte, n)
|
||||||
|
} else {
|
||||||
|
// Violate stream level flow control window of client.
|
||||||
|
p = make([]byte, n+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := conn.framer.writeData(true, s.id, false, p); err != nil {
|
||||||
conn.writableChan <- 0
|
conn.writableChan <- 0
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
conn.writableChan <- 0
|
conn.writableChan <- 0
|
||||||
sent += size
|
sent += len(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) {
|
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
|
||||||
// raw newline is not accepted by http2 framer and a http2.StreamError is
|
// raw newline is not accepted by http2 framer so it must be encoded.
|
||||||
// generated.
|
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
|
||||||
h.t.WriteStatus(s, codes.Internal, "\n")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
||||||
@ -179,9 +185,9 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
|||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
go h.handleStreamMisbehave(t, s)
|
go h.handleStreamMisbehave(t, s)
|
||||||
})
|
})
|
||||||
case malformedStatus:
|
case encodingRequiredStatus:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
go h.handleStreamMalformedStatus(t, s)
|
go h.handleStreamEncodingRequiredStatus(t, s)
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(func(s *Stream) {
|
||||||
@ -221,7 +227,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
|
|||||||
ct ClientTransport
|
ct ClientTransport
|
||||||
connErr error
|
connErr error
|
||||||
)
|
)
|
||||||
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
|
ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
t.Fatalf("failed to create transport: %v", connErr)
|
t.Fatalf("failed to create transport: %v", connErr)
|
||||||
}
|
}
|
||||||
@ -252,7 +258,7 @@ func TestClientSendAndReceive(t *testing.T) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s1, expectedRequest, &opts); err != nil {
|
if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||||
t.Fatalf("failed to send data: %v", err)
|
t.Fatalf("failed to send data: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponse))
|
p := make([]byte, len(expectedResponse))
|
||||||
@ -289,7 +295,7 @@ func performOneRPC(ct ClientTransport) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, expectedRequest, &opts); err == nil {
|
if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
// The following s.Recv()'s could error out because the
|
// The following s.Recv()'s could error out because the
|
||||||
// underlying transport is gone.
|
// underlying transport is gone.
|
||||||
@ -333,7 +339,7 @@ func TestLargeMessage(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
|
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
|
||||||
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
p := make([]byte, len(expectedResponseLarge))
|
||||||
@ -369,8 +375,8 @@ func TestGracefulClose(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing {
|
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain {
|
||||||
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
|
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@ -379,7 +385,7 @@ func TestGracefulClose(t *testing.T) {
|
|||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
// The stream which was created before graceful close can still proceed.
|
// The stream which was created before graceful close can still proceed.
|
||||||
if err := ct.Write(s, expectedRequest, &opts); err != nil {
|
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||||
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponse))
|
p := make([]byte, len(expectedResponse))
|
||||||
@ -409,7 +415,7 @@ func TestLargeMessageSuspension(t *testing.T) {
|
|||||||
// Write should not be done successfully due to flow control.
|
// Write should not be done successfully due to flow control.
|
||||||
err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false})
|
err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false})
|
||||||
expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
|
expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
|
||||||
if err == nil || err != expectedErr {
|
if err != expectedErr {
|
||||||
t.Fatalf("Write got %v, want %v", err, expectedErr)
|
t.Fatalf("Write got %v, want %v", err, expectedErr)
|
||||||
}
|
}
|
||||||
ct.Close()
|
ct.Close()
|
||||||
@ -433,14 +439,21 @@ func TestMaxStreams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
ch := make(chan int)
|
ch := make(chan int)
|
||||||
|
ready := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(5 * time.Millisecond):
|
case <-time.After(5 * time.Millisecond):
|
||||||
ch <- 0
|
select {
|
||||||
|
case ch <- 0:
|
||||||
|
case <-ready:
|
||||||
|
return
|
||||||
|
}
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
close(done)
|
close(done)
|
||||||
return
|
return
|
||||||
|
case <-ready:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -467,6 +480,7 @@ func TestMaxStreams(t *testing.T) {
|
|||||||
}
|
}
|
||||||
cc.mu.Unlock()
|
cc.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
close(ready)
|
||||||
// Close the pending stream so that the streams quota becomes available for the next new stream.
|
// Close the pending stream so that the streams quota becomes available for the next new stream.
|
||||||
ct.CloseStream(s, nil)
|
ct.CloseStream(s, nil)
|
||||||
select {
|
select {
|
||||||
@ -546,6 +560,7 @@ func TestServerContextCanceledOnClosedConnection(t *testing.T) {
|
|||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
t.Fatalf("Failed to cancel the context of the sever side stream.")
|
t.Fatalf("Failed to cancel the context of the sever side stream.")
|
||||||
}
|
}
|
||||||
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerWithMisbehavedClient(t *testing.T) {
|
func TestServerWithMisbehavedClient(t *testing.T) {
|
||||||
@ -652,7 +667,7 @@ func TestClientWithMisbehavedServer(t *testing.T) {
|
|||||||
server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
|
server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
|
||||||
callHdr := &CallHdr{
|
callHdr := &CallHdr{
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Method: "foo",
|
Method: "foo.Stream",
|
||||||
}
|
}
|
||||||
conn, ok := ct.(*http2Client)
|
conn, ok := ct.(*http2Client)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -663,7 +678,8 @@ func TestClientWithMisbehavedServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to open stream: %v", err)
|
t.Fatalf("Failed to open stream: %v", err)
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
|
d := make([]byte, 1)
|
||||||
|
if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
|
||||||
t.Fatalf("Failed to write: %v", err)
|
t.Fatalf("Failed to write: %v", err)
|
||||||
}
|
}
|
||||||
// Read without window update.
|
// Read without window update.
|
||||||
@ -685,17 +701,15 @@ func TestClientWithMisbehavedServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
// Test the logic for the violation of the connection flow control window size restriction.
|
// Test the logic for the violation of the connection flow control window size restriction.
|
||||||
//
|
//
|
||||||
// Generate enough streams to drain the connection window.
|
// Generate enough streams to drain the connection window. Make the server flood the traffic
|
||||||
callHdr = &CallHdr{
|
// to violate flow control window size of the connection.
|
||||||
Host: "localhost",
|
callHdr.Method = "foo.Connection"
|
||||||
Method: "foo.MaxFrame",
|
|
||||||
}
|
|
||||||
for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
|
for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
|
||||||
s, err := ct.NewStream(context.Background(), callHdr)
|
s, err := ct.NewStream(context.Background(), callHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
|
if err := ct.Write(s, d, &Options{Last: true, Delay: false}); err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -705,8 +719,13 @@ func TestClientWithMisbehavedServer(t *testing.T) {
|
|||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMalformedStatus(t *testing.T) {
|
var (
|
||||||
server, ct := setUp(t, 0, math.MaxUint32, malformedStatus)
|
encodingTestStatusCode = codes.Internal
|
||||||
|
encodingTestStatusDesc = "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodingRequiredStatus(t *testing.T) {
|
||||||
|
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
|
||||||
callHdr := &CallHdr{
|
callHdr := &CallHdr{
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Method: "foo",
|
Method: "foo",
|
||||||
@ -719,24 +738,26 @@ func TestMalformedStatus(t *testing.T) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, expectedRequest, &opts); err != nil {
|
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||||
t.Fatalf("Failed to write the request: %v", err)
|
t.Fatalf("Failed to write the request: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, http2MaxFrameLen)
|
p := make([]byte, http2MaxFrameLen)
|
||||||
expectedErr := StreamErrorf(codes.Internal, "invalid header field value \"\\n\"")
|
if _, err := s.dec.Read(p); err != io.EOF {
|
||||||
if _, err = s.dec.Read(p); err != expectedErr {
|
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
||||||
t.Fatalf("Read the err %v, want %v", err, expectedErr)
|
}
|
||||||
|
if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
|
||||||
|
t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
|
||||||
}
|
}
|
||||||
ct.Close()
|
ct.Close()
|
||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamContext(t *testing.T) {
|
func TestStreamContext(t *testing.T) {
|
||||||
expectedStream := Stream{}
|
expectedStream := &Stream{}
|
||||||
ctx := newContextWithStream(context.Background(), &expectedStream)
|
ctx := newContextWithStream(context.Background(), expectedStream)
|
||||||
s, ok := StreamFromContext(ctx)
|
s, ok := StreamFromContext(ctx)
|
||||||
if !ok || !reflect.DeepEqual(expectedStream, *s) {
|
if !ok || expectedStream != s {
|
||||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream)
|
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user