diff --git a/balancer.go b/balancer.go index 419e2146..e217a207 100644 --- a/balancer.go +++ b/balancer.go @@ -38,6 +38,7 @@ import ( "sync" "golang.org/x/net/context" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" ) @@ -52,6 +53,14 @@ type Address struct { Metadata interface{} } +// BalancerConfig specifies the configurations for Balancer. +type BalancerConfig struct { + // DialCreds is the transport credential the Balancer implementation can + // use to dial to a remote load balancer server. The Balancer implementations + // can ignore this if it does not need to talk to another party securely. + DialCreds credentials.TransportCredentials +} + // BalancerGetOptions configures a Get call. // This is the EXPERIMENTAL API and may be changed or extended in the future. type BalancerGetOptions struct { @@ -66,11 +75,11 @@ type Balancer interface { // Start does the initialization work to bootstrap a Balancer. For example, // this function may start the name resolution and watch the updates. It will // be called when dialing. - Start(target string) error + Start(target string, config BalancerConfig) error // Up informs the Balancer that gRPC has a connection to the server at // addr. It returns down which is called once the connection to addr gets // lost or closed. - // TODO: It is not clear how to construct and take advantage the meaningful error + // TODO: It is not clear how to construct and take advantage of the meaningful error // parameter for down. Need realistic demands to guide. Up(addr Address) (down func(error)) // Get gets the address of a server for the RPC corresponding to ctx. @@ -205,7 +214,12 @@ func (rr *roundRobin) watchAddrUpdates() error { return nil } -func (rr *roundRobin) Start(target string) error { +func (rr *roundRobin) Start(target string, config BalancerConfig) error { + rr.mu.Lock() + defer rr.mu.Unlock() + if rr.done { + return ErrClientConnClosing + } if rr.r == nil { // If there is no name resolver installed, it is not needed to // do name resolution. In this case, target is added into rr.addrs diff --git a/call.go b/call.go index fea07998..788b3d92 100644 --- a/call.go +++ b/call.go @@ -96,7 +96,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } outBuf, err := encode(codec, args, compressor, cbuf) if err != nil { - return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) + return nil, Errorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method @@ -112,7 +112,14 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd // Invoke sends the RPC request on the wire and returns after response is received. // Invoke is called by generated code. Also users can call Invoke directly when it // is really needed in their use cases. -func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { +func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) error { + if cc.dopts.unaryInt != nil { + return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...) + } + return invoke(ctx, method, args, reply, cc, opts...) 
+} + +func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { diff --git a/clientconn.go b/clientconn.go index 1d3b46c6..e18d9659 100644 --- a/clientconn.go +++ b/clientconn.go @@ -83,15 +83,17 @@ var ( // dialOptions configure a Dial call. dialOptions are set by the DialOption // values passed to Dial. type dialOptions struct { - codec Codec - cp Compressor - dc Decompressor - bs backoffStrategy - balancer Balancer - block bool - insecure bool - timeout time.Duration - copts transport.ConnectOptions + unaryInt UnaryClientInterceptor + streamInt StreamClientInterceptor + codec Codec + cp Compressor + dc Decompressor + bs backoffStrategy + balancer Balancer + block bool + insecure bool + timeout time.Duration + copts transport.ConnectOptions } // DialOption configures how we set up the connection. @@ -215,6 +217,20 @@ func WithUserAgent(s string) DialOption { } } +// WithUnaryInterceptor returns a DialOption that specifies the interceptor for unary RPCs. +func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption { + return func(o *dialOptions) { + o.unaryInt = f + } +} + +// WithStreamInterceptor returns a DialOption that specifies the interceptor for streaming RPCs. +func WithStreamInterceptor(f StreamClientInterceptor) DialOption { + return func(o *dialOptions) { + o.streamInt = f + } +} + // Dial creates a client connection to the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { return DialContext(context.Background(), target, opts...) @@ -254,31 +270,43 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } - - var ( - ok bool - addrs []Address - ) - if cc.dopts.balancer == nil { - // Connect to target directly if balancer is nil. - addrs = append(addrs, Address{Addr: target}) + creds := cc.dopts.copts.TransportCredentials + if creds != nil && creds.Info().ServerName != "" { + cc.authority = creds.Info().ServerName } else { - if err := cc.dopts.balancer.Start(target); err != nil { - return nil, err - } - ch := cc.dopts.balancer.Notify() - if ch == nil { - // There is no name resolver installed. - addrs = append(addrs, Address{Addr: target}) - } else { - addrs, ok = <-ch - if !ok || len(addrs) == 0 { - return nil, errNoAddr - } + colonPos := strings.LastIndex(target, ":") + if colonPos == -1 { + colonPos = len(target) } + cc.authority = target[:colonPos] } + var ok bool waitC := make(chan error, 1) go func() { + var addrs []Address + if cc.dopts.balancer == nil { + // Connect to target directly if balancer is nil. + addrs = append(addrs, Address{Addr: target}) + } else { + config := BalancerConfig{ + DialCreds: creds, + } + if err := cc.dopts.balancer.Start(target, config); err != nil { + waitC <- err + return + } + ch := cc.dopts.balancer.Notify() + if ch == nil { + // There is no name resolver installed. 
+ addrs = append(addrs, Address{Addr: target}) + } else { + addrs, ok = <-ch + if !ok || len(addrs) == 0 { + waitC <- errNoAddr + return + } + } + } for _, a := range addrs { if err := cc.resetAddrConn(a, false, nil); err != nil { waitC <- err @@ -306,11 +334,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * if ok { go cc.lbWatcher() } - colonPos := strings.LastIndex(target, ":") - if colonPos == -1 { - colonPos = len(target) - } - cc.authority = target[:colonPos] return cc, nil } @@ -664,7 +687,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { return err } - grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) + grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %v", err, ac.addr) ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. diff --git a/clientconn_test.go b/clientconn_test.go index c49548dc..179be251 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -65,7 +65,23 @@ func TestTLSDialTimeout(t *testing.T) { conn.Close() } if err != ErrClientConnTimeout { - t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) + } +} + +func TestTLSServerNameOverwrite(t *testing.T) { + overwriteServerName := "over.write.server.name" + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", overwriteServerName) + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + conn, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds)) + if err != nil { + t.Fatalf("Dial(_, _) = _, %v, want _, ", err) + } + conn.Close() + if conn.authority != overwriteServerName { + t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName) } } @@ -73,10 +89,47 @@ func TestDialContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled { - t.Fatalf("grpc.DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled) + t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled) } } +// blockingBalancer mimics the behavior of balancers whose initialization takes a long time. +// In this test, reading from blockingBalancer.Notify() blocks forever. 
+type blockingBalancer struct { + ch chan []Address +} + +func newBlockingBalancer() Balancer { + return &blockingBalancer{ch: make(chan []Address)} +} +func (b *blockingBalancer) Start(target string, config BalancerConfig) error { + return nil +} +func (b *blockingBalancer) Up(addr Address) func(error) { + return nil +} +func (b *blockingBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { + return Address{}, nil, nil +} +func (b *blockingBalancer) Notify() <-chan []Address { + return b.ch +} +func (b *blockingBalancer) Close() error { + close(b.ch) + return nil +} + +func TestDialWithBlockingBalancer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + dialDone := make(chan struct{}) + go func() { + DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure(), WithBalancer(newBlockingBalancer())) + close(dialDone) + }() + cancel() + <-dialDone +} + func TestCredentialsMisuse(t *testing.T) { tlsCreds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") if err != nil { diff --git a/credentials/credentials.go b/credentials/credentials.go index 3f17b706..5555ef02 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -40,6 +40,7 @@ package credentials // import "google.golang.org/grpc/credentials" import ( "crypto/tls" "crypto/x509" + "errors" "fmt" "io/ioutil" "net" @@ -71,7 +72,7 @@ type PerRPCCredentials interface { } // ProtocolInfo provides information regarding the gRPC wire protocol version, -// security protocol, security protocol version in use, etc. +// security protocol, security protocol version in use, server name, etc. type ProtocolInfo struct { // ProtocolVersion is the gRPC wire protocol version. ProtocolVersion string @@ -79,6 +80,8 @@ type ProtocolInfo struct { SecurityProtocol string // SecurityVersion is the security protocol version. SecurityVersion string + // ServerName is the user-configured server name. + ServerName string } // AuthInfo defines the common interface for the auth information the users are interested in. @@ -86,6 +89,12 @@ type AuthInfo interface { AuthType() string } +var ( + // ErrConnDispatched indicates that rawConn has been dispatched out of gRPC + // and the caller should not close rawConn. + ErrConnDispatched = errors.New("credentials: rawConn is dispatched out of gRPC") +) + // TransportCredentials defines the common interface for all the live gRPC wire // protocols and supported transport security protocols (e.g., TLS, SSL). type TransportCredentials interface { @@ -100,6 +109,12 @@ type TransportCredentials interface { ServerHandshake(net.Conn) (net.Conn, AuthInfo, error) // Info provides the ProtocolInfo of this TransportCredentials. Info() ProtocolInfo + // Clone makes a copy of this TransportCredentials. + Clone() TransportCredentials + // OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server. + // gRPC internals also use it to override the virtual hosting name if it is set. + // It must be called before dialing. Currently, this is only used by grpclb. + OverrideServerName(string) error } // TLSInfo contains the auth information for a TLS authenticated connection. @@ -123,19 +138,10 @@ func (c tlsCreds) Info() ProtocolInfo { return ProtocolInfo{ SecurityProtocol: "tls", SecurityVersion: "1.2", + ServerName: c.config.ServerName, } } -// GetRequestMetadata returns nil, nil since TLS credentials does not have -// metadata. 
-func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return nil, nil -} - -func (c *tlsCreds) RequireTransportSecurity() bool { - return true -} - func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := cloneTLSConfig(c.config) @@ -172,6 +178,15 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) return conn, TLSInfo{conn.ConnectionState()}, nil } +func (c *tlsCreds) Clone() TransportCredentials { + return NewTLS(c.config) +} + +func (c *tlsCreds) OverrideServerName(serverNameOverride string) error { + c.config.ServerName = serverNameOverride + return nil +} + // NewTLS uses c to construct a TransportCredentials based on TLS. func NewTLS(c *tls.Config) TransportCredentials { tc := &tlsCreds{cloneTLSConfig(c)} @@ -180,12 +195,16 @@ func NewTLS(c *tls.Config) TransportCredentials { } // NewClientTLSFromCert constructs a TLS from the input certificate for client. -func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportCredentials { - return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials { + return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. -func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, error) { +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header field) in requests. +func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) { b, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -194,7 +213,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportCredentials, er if !cp.AppendCertsFromPEM(b) { return nil, fmt.Errorf("credentials: failed to append certificates") } - return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil + return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil } // NewServerTLSFromCert constructs a TLS from the input certificate for server. diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go new file mode 100644 index 00000000..caf35b2f --- /dev/null +++ b/credentials/credentials_test.go @@ -0,0 +1,61 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package credentials + +import ( + "testing" +) + +func TestTLSOverrideServerName(t *testing.T) { + expectedServerName := "server.name" + c := NewTLS(nil) + c.OverrideServerName(expectedServerName) + if c.Info().ServerName != expectedServerName { + t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } +} + +func TestTLSClone(t *testing.T) { + expectedServerName := "server.name" + c := NewTLS(nil) + c.OverrideServerName(expectedServerName) + cc := c.Clone() + if cc.Info().ServerName != expectedServerName { + t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) + } + cc.OverrideServerName("") + if c.Info().ServerName != expectedServerName { + t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } +} diff --git a/grpclb/grpc_lb_v1/grpclb.pb.go b/grpclb/grpc_lb_v1/grpclb.pb.go new file mode 100644 index 00000000..da371e52 --- /dev/null +++ b/grpclb/grpc_lb_v1/grpclb.pb.go @@ -0,0 +1,557 @@ +// Code generated by protoc-gen-go. +// source: grpclb.proto +// DO NOT EDIT! + +/* +Package grpc_lb_v1 is a generated protocol buffer package. + +It is generated from these files: + grpclb.proto + +It has these top-level messages: + Duration + LoadBalanceRequest + InitialLoadBalanceRequest + ClientStats + LoadBalanceResponse + InitialLoadBalanceResponse + ServerList + Server +*/ +package grpc_lb_v1 + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Duration struct { + // Signed seconds of the span of time. Must be from -315,576,000,000 + // to +315,576,000,000 inclusive. + Seconds int64 `protobuf:"varint,1,opt,name=seconds" json:"seconds,omitempty"` + // Signed fractions of a second at nanosecond resolution of the span + // of time. Durations less than one second are represented with a 0 + // `seconds` field and a positive or negative `nanos` field. For durations + // of one second or more, a non-zero value for the `nanos` field must be + // of the same sign as the `seconds` field. Must be from -999,999,999 + // to +999,999,999 inclusive. 
+ Nanos int32 `protobuf:"varint,2,opt,name=nanos" json:"nanos,omitempty"` +} + +func (m *Duration) Reset() { *m = Duration{} } +func (m *Duration) String() string { return proto.CompactTextString(m) } +func (*Duration) ProtoMessage() {} +func (*Duration) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +type LoadBalanceRequest struct { + // Types that are valid to be assigned to LoadBalanceRequestType: + // *LoadBalanceRequest_InitialRequest + // *LoadBalanceRequest_ClientStats + LoadBalanceRequestType isLoadBalanceRequest_LoadBalanceRequestType `protobuf_oneof:"load_balance_request_type"` +} + +func (m *LoadBalanceRequest) Reset() { *m = LoadBalanceRequest{} } +func (m *LoadBalanceRequest) String() string { return proto.CompactTextString(m) } +func (*LoadBalanceRequest) ProtoMessage() {} +func (*LoadBalanceRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +type isLoadBalanceRequest_LoadBalanceRequestType interface { + isLoadBalanceRequest_LoadBalanceRequestType() +} + +type LoadBalanceRequest_InitialRequest struct { + InitialRequest *InitialLoadBalanceRequest `protobuf:"bytes,1,opt,name=initial_request,oneof"` +} +type LoadBalanceRequest_ClientStats struct { + ClientStats *ClientStats `protobuf:"bytes,2,opt,name=client_stats,oneof"` +} + +func (*LoadBalanceRequest_InitialRequest) isLoadBalanceRequest_LoadBalanceRequestType() {} +func (*LoadBalanceRequest_ClientStats) isLoadBalanceRequest_LoadBalanceRequestType() {} + +func (m *LoadBalanceRequest) GetLoadBalanceRequestType() isLoadBalanceRequest_LoadBalanceRequestType { + if m != nil { + return m.LoadBalanceRequestType + } + return nil +} + +func (m *LoadBalanceRequest) GetInitialRequest() *InitialLoadBalanceRequest { + if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_InitialRequest); ok { + return x.InitialRequest + } + return nil +} + +func (m *LoadBalanceRequest) GetClientStats() *ClientStats { + if x, ok := m.GetLoadBalanceRequestType().(*LoadBalanceRequest_ClientStats); ok { + return x.ClientStats + } + return nil +} + +// XXX_OneofFuncs is for the internal use of the proto package. 
+func (*LoadBalanceRequest) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) { + return _LoadBalanceRequest_OneofMarshaler, _LoadBalanceRequest_OneofUnmarshaler, _LoadBalanceRequest_OneofSizer, []interface{}{ + (*LoadBalanceRequest_InitialRequest)(nil), + (*LoadBalanceRequest_ClientStats)(nil), + } +} + +func _LoadBalanceRequest_OneofMarshaler(msg proto.Message, b *proto.Buffer) error { + m := msg.(*LoadBalanceRequest) + // load_balance_request_type + switch x := m.LoadBalanceRequestType.(type) { + case *LoadBalanceRequest_InitialRequest: + b.EncodeVarint(1<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.InitialRequest); err != nil { + return err + } + case *LoadBalanceRequest_ClientStats: + b.EncodeVarint(2<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.ClientStats); err != nil { + return err + } + case nil: + default: + return fmt.Errorf("LoadBalanceRequest.LoadBalanceRequestType has unexpected type %T", x) + } + return nil +} + +func _LoadBalanceRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) { + m := msg.(*LoadBalanceRequest) + switch tag { + case 1: // load_balance_request_type.initial_request + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(InitialLoadBalanceRequest) + err := b.DecodeMessage(msg) + m.LoadBalanceRequestType = &LoadBalanceRequest_InitialRequest{msg} + return true, err + case 2: // load_balance_request_type.client_stats + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(ClientStats) + err := b.DecodeMessage(msg) + m.LoadBalanceRequestType = &LoadBalanceRequest_ClientStats{msg} + return true, err + default: + return false, nil + } +} + +func _LoadBalanceRequest_OneofSizer(msg proto.Message) (n int) { + m := msg.(*LoadBalanceRequest) + // load_balance_request_type + switch x := m.LoadBalanceRequestType.(type) { + case *LoadBalanceRequest_InitialRequest: + s := proto.Size(x.InitialRequest) + n += proto.SizeVarint(1<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case *LoadBalanceRequest_ClientStats: + s := proto.Size(x.ClientStats) + n += proto.SizeVarint(2<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case nil: + default: + panic(fmt.Sprintf("proto: unexpected type %T in oneof", x)) + } + return n +} + +type InitialLoadBalanceRequest struct { + // Name of load balanced service (IE, service.grpc.gslb.google.com) + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *InitialLoadBalanceRequest) Reset() { *m = InitialLoadBalanceRequest{} } +func (m *InitialLoadBalanceRequest) String() string { return proto.CompactTextString(m) } +func (*InitialLoadBalanceRequest) ProtoMessage() {} +func (*InitialLoadBalanceRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } + +// Contains client level statistics that are useful to load balancing. Each +// count should be reset to zero after reporting the stats. +type ClientStats struct { + // The total number of requests sent by the client since the last report. + TotalRequests int64 `protobuf:"varint,1,opt,name=total_requests" json:"total_requests,omitempty"` + // The number of client rpc errors since the last report. + ClientRpcErrors int64 `protobuf:"varint,2,opt,name=client_rpc_errors" json:"client_rpc_errors,omitempty"` + // The number of dropped requests since the last report. 
+ DroppedRequests int64 `protobuf:"varint,3,opt,name=dropped_requests" json:"dropped_requests,omitempty"` +} + +func (m *ClientStats) Reset() { *m = ClientStats{} } +func (m *ClientStats) String() string { return proto.CompactTextString(m) } +func (*ClientStats) ProtoMessage() {} +func (*ClientStats) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } + +type LoadBalanceResponse struct { + // Types that are valid to be assigned to LoadBalanceResponseType: + // *LoadBalanceResponse_InitialResponse + // *LoadBalanceResponse_ServerList + LoadBalanceResponseType isLoadBalanceResponse_LoadBalanceResponseType `protobuf_oneof:"load_balance_response_type"` +} + +func (m *LoadBalanceResponse) Reset() { *m = LoadBalanceResponse{} } +func (m *LoadBalanceResponse) String() string { return proto.CompactTextString(m) } +func (*LoadBalanceResponse) ProtoMessage() {} +func (*LoadBalanceResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } + +type isLoadBalanceResponse_LoadBalanceResponseType interface { + isLoadBalanceResponse_LoadBalanceResponseType() +} + +type LoadBalanceResponse_InitialResponse struct { + InitialResponse *InitialLoadBalanceResponse `protobuf:"bytes,1,opt,name=initial_response,oneof"` +} +type LoadBalanceResponse_ServerList struct { + ServerList *ServerList `protobuf:"bytes,2,opt,name=server_list,oneof"` +} + +func (*LoadBalanceResponse_InitialResponse) isLoadBalanceResponse_LoadBalanceResponseType() {} +func (*LoadBalanceResponse_ServerList) isLoadBalanceResponse_LoadBalanceResponseType() {} + +func (m *LoadBalanceResponse) GetLoadBalanceResponseType() isLoadBalanceResponse_LoadBalanceResponseType { + if m != nil { + return m.LoadBalanceResponseType + } + return nil +} + +func (m *LoadBalanceResponse) GetInitialResponse() *InitialLoadBalanceResponse { + if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_InitialResponse); ok { + return x.InitialResponse + } + return nil +} + +func (m *LoadBalanceResponse) GetServerList() *ServerList { + if x, ok := m.GetLoadBalanceResponseType().(*LoadBalanceResponse_ServerList); ok { + return x.ServerList + } + return nil +} + +// XXX_OneofFuncs is for the internal use of the proto package. 
+func (*LoadBalanceResponse) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) { + return _LoadBalanceResponse_OneofMarshaler, _LoadBalanceResponse_OneofUnmarshaler, _LoadBalanceResponse_OneofSizer, []interface{}{ + (*LoadBalanceResponse_InitialResponse)(nil), + (*LoadBalanceResponse_ServerList)(nil), + } +} + +func _LoadBalanceResponse_OneofMarshaler(msg proto.Message, b *proto.Buffer) error { + m := msg.(*LoadBalanceResponse) + // load_balance_response_type + switch x := m.LoadBalanceResponseType.(type) { + case *LoadBalanceResponse_InitialResponse: + b.EncodeVarint(1<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.InitialResponse); err != nil { + return err + } + case *LoadBalanceResponse_ServerList: + b.EncodeVarint(2<<3 | proto.WireBytes) + if err := b.EncodeMessage(x.ServerList); err != nil { + return err + } + case nil: + default: + return fmt.Errorf("LoadBalanceResponse.LoadBalanceResponseType has unexpected type %T", x) + } + return nil +} + +func _LoadBalanceResponse_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error) { + m := msg.(*LoadBalanceResponse) + switch tag { + case 1: // load_balance_response_type.initial_response + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(InitialLoadBalanceResponse) + err := b.DecodeMessage(msg) + m.LoadBalanceResponseType = &LoadBalanceResponse_InitialResponse{msg} + return true, err + case 2: // load_balance_response_type.server_list + if wire != proto.WireBytes { + return true, proto.ErrInternalBadWireType + } + msg := new(ServerList) + err := b.DecodeMessage(msg) + m.LoadBalanceResponseType = &LoadBalanceResponse_ServerList{msg} + return true, err + default: + return false, nil + } +} + +func _LoadBalanceResponse_OneofSizer(msg proto.Message) (n int) { + m := msg.(*LoadBalanceResponse) + // load_balance_response_type + switch x := m.LoadBalanceResponseType.(type) { + case *LoadBalanceResponse_InitialResponse: + s := proto.Size(x.InitialResponse) + n += proto.SizeVarint(1<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case *LoadBalanceResponse_ServerList: + s := proto.Size(x.ServerList) + n += proto.SizeVarint(2<<3 | proto.WireBytes) + n += proto.SizeVarint(uint64(s)) + n += s + case nil: + default: + panic(fmt.Sprintf("proto: unexpected type %T in oneof", x)) + } + return n +} + +type InitialLoadBalanceResponse struct { + // This is an application layer redirect that indicates the client should use + // the specified server for load balancing. When this field is non-empty in + // the response, the client should open a separate connection to the + // load_balancer_delegate and call the BalanceLoad method. + LoadBalancerDelegate string `protobuf:"bytes,1,opt,name=load_balancer_delegate" json:"load_balancer_delegate,omitempty"` + // This interval defines how often the client should send the client stats + // to the load balancer. Stats should only be reported when the duration is + // positive. 
+ ClientStatsReportInterval *Duration `protobuf:"bytes,3,opt,name=client_stats_report_interval" json:"client_stats_report_interval,omitempty"` +} + +func (m *InitialLoadBalanceResponse) Reset() { *m = InitialLoadBalanceResponse{} } +func (m *InitialLoadBalanceResponse) String() string { return proto.CompactTextString(m) } +func (*InitialLoadBalanceResponse) ProtoMessage() {} +func (*InitialLoadBalanceResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +func (m *InitialLoadBalanceResponse) GetClientStatsReportInterval() *Duration { + if m != nil { + return m.ClientStatsReportInterval + } + return nil +} + +type ServerList struct { + // Contains a list of servers selected by the load balancer. The list will + // be updated when server resolutions change or as needed to balance load + // across more servers. The client should consume the server list in order + // unless instructed otherwise via the client_config. + Servers []*Server `protobuf:"bytes,1,rep,name=servers" json:"servers,omitempty"` + // Indicates the amount of time that the client should consider this server + // list as valid. It may be considered stale after waiting this interval of + // time after receiving the list. If the interval is not positive, the + // client can assume the list is valid until the next list is received. + ExpirationInterval *Duration `protobuf:"bytes,3,opt,name=expiration_interval" json:"expiration_interval,omitempty"` +} + +func (m *ServerList) Reset() { *m = ServerList{} } +func (m *ServerList) String() string { return proto.CompactTextString(m) } +func (*ServerList) ProtoMessage() {} +func (*ServerList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + +func (m *ServerList) GetServers() []*Server { + if m != nil { + return m.Servers + } + return nil +} + +func (m *ServerList) GetExpirationInterval() *Duration { + if m != nil { + return m.ExpirationInterval + } + return nil +} + +type Server struct { + // A resolved address for the server, serialized in network-byte-order. It may + // either be an IPv4 or IPv6 address. + IpAddress []byte `protobuf:"bytes,1,opt,name=ip_address,proto3" json:"ip_address,omitempty"` + // A resolved port number for the server. + Port int32 `protobuf:"varint,2,opt,name=port" json:"port,omitempty"` + // An opaque but printable token given to the frontend for each pick. All + // frontend requests for that pick must include the token in its initial + // metadata. The token is used by the backend to verify the request and to + // allow the backend to report load to the gRPC LB system. + LoadBalanceToken string `protobuf:"bytes,3,opt,name=load_balance_token" json:"load_balance_token,omitempty"` + // Indicates whether this particular request should be dropped by the client + // when this server is chosen from the list. 
+ DropRequest bool `protobuf:"varint,4,opt,name=drop_request" json:"drop_request,omitempty"` +} + +func (m *Server) Reset() { *m = Server{} } +func (m *Server) String() string { return proto.CompactTextString(m) } +func (*Server) ProtoMessage() {} +func (*Server) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} } + +func init() { + proto.RegisterType((*Duration)(nil), "grpc.lb.v1.Duration") + proto.RegisterType((*LoadBalanceRequest)(nil), "grpc.lb.v1.LoadBalanceRequest") + proto.RegisterType((*InitialLoadBalanceRequest)(nil), "grpc.lb.v1.InitialLoadBalanceRequest") + proto.RegisterType((*ClientStats)(nil), "grpc.lb.v1.ClientStats") + proto.RegisterType((*LoadBalanceResponse)(nil), "grpc.lb.v1.LoadBalanceResponse") + proto.RegisterType((*InitialLoadBalanceResponse)(nil), "grpc.lb.v1.InitialLoadBalanceResponse") + proto.RegisterType((*ServerList)(nil), "grpc.lb.v1.ServerList") + proto.RegisterType((*Server)(nil), "grpc.lb.v1.Server") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion3 + +// Client API for LoadBalancer service + +type LoadBalancerClient interface { + // Bidirectional rpc to get a list of servers. + BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) +} + +type loadBalancerClient struct { + cc *grpc.ClientConn +} + +func NewLoadBalancerClient(cc *grpc.ClientConn) LoadBalancerClient { + return &loadBalancerClient{cc} +} + +func (c *loadBalancerClient) BalanceLoad(ctx context.Context, opts ...grpc.CallOption) (LoadBalancer_BalanceLoadClient, error) { + stream, err := grpc.NewClientStream(ctx, &_LoadBalancer_serviceDesc.Streams[0], c.cc, "/grpc.lb.v1.LoadBalancer/BalanceLoad", opts...) + if err != nil { + return nil, err + } + x := &loadBalancerBalanceLoadClient{stream} + return x, nil +} + +type LoadBalancer_BalanceLoadClient interface { + Send(*LoadBalanceRequest) error + Recv() (*LoadBalanceResponse, error) + grpc.ClientStream +} + +type loadBalancerBalanceLoadClient struct { + grpc.ClientStream +} + +func (x *loadBalancerBalanceLoadClient) Send(m *LoadBalanceRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *loadBalancerBalanceLoadClient) Recv() (*LoadBalanceResponse, error) { + m := new(LoadBalanceResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for LoadBalancer service + +type LoadBalancerServer interface { + // Bidirectional rpc to get a list of servers. 
+ BalanceLoad(LoadBalancer_BalanceLoadServer) error +} + +func RegisterLoadBalancerServer(s *grpc.Server, srv LoadBalancerServer) { + s.RegisterService(&_LoadBalancer_serviceDesc, srv) +} + +func _LoadBalancer_BalanceLoad_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(LoadBalancerServer).BalanceLoad(&loadBalancerBalanceLoadServer{stream}) +} + +type LoadBalancer_BalanceLoadServer interface { + Send(*LoadBalanceResponse) error + Recv() (*LoadBalanceRequest, error) + grpc.ServerStream +} + +type loadBalancerBalanceLoadServer struct { + grpc.ServerStream +} + +func (x *loadBalancerBalanceLoadServer) Send(m *LoadBalanceResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *loadBalancerBalanceLoadServer) Recv() (*LoadBalanceRequest, error) { + m := new(LoadBalanceRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _LoadBalancer_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.lb.v1.LoadBalancer", + HandlerType: (*LoadBalancerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "BalanceLoad", + Handler: _LoadBalancer_BalanceLoad_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: fileDescriptor0, +} + +func init() { proto.RegisterFile("grpclb.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 471 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x8c, 0x93, 0x51, 0x6f, 0xd3, 0x3e, + 0x14, 0xc5, 0x9b, 0x7f, 0xb7, 0xfd, 0xb7, 0x9b, 0xc0, 0xc6, 0xdd, 0x54, 0xda, 0x32, 0x8d, 0x2a, + 0x08, 0x54, 0x90, 0x08, 0x2c, 0xbc, 0xf1, 0x84, 0x0a, 0x0f, 0x45, 0xda, 0xd3, 0xf6, 0x86, 0x90, + 0x2c, 0x27, 0xb9, 0x9a, 0x2c, 0x82, 0x6d, 0x6c, 0xaf, 0x1a, 0xdf, 0x07, 0xf1, 0x39, 0x91, 0xe3, + 0x94, 0x64, 0x54, 0x15, 0xbc, 0xc5, 0xbe, 0x3e, 0xf7, 0x1e, 0xff, 0x7c, 0x02, 0xc9, 0xb5, 0xd1, + 0x65, 0x5d, 0x64, 0xda, 0x28, 0xa7, 0x10, 0xfc, 0x2a, 0xab, 0x8b, 0x6c, 0x75, 0x9e, 0xbe, 0x80, + 0xfd, 0x0f, 0x37, 0x86, 0x3b, 0xa1, 0x24, 0x1e, 0xc2, 0xff, 0x96, 0x4a, 0x25, 0x2b, 0x3b, 0x8e, + 0x66, 0xd1, 0x7c, 0x88, 0xf7, 0x60, 0x57, 0x72, 0xa9, 0xec, 0xf8, 0xbf, 0x59, 0x34, 0xdf, 0x4d, + 0x7f, 0x44, 0x80, 0x17, 0x8a, 0x57, 0x0b, 0x5e, 0x73, 0x59, 0xd2, 0x25, 0x7d, 0xbb, 0x21, 0xeb, + 0xf0, 0x1d, 0x1c, 0x0a, 0x29, 0x9c, 0xe0, 0x35, 0x33, 0x61, 0xab, 0x91, 0xc7, 0xf9, 0xd3, 0xac, + 0x1b, 0x94, 0x7d, 0x0c, 0x47, 0x36, 0xf5, 0xcb, 0x01, 0xbe, 0x82, 0xa4, 0xac, 0x05, 0x49, 0xc7, + 0xac, 0xe3, 0x2e, 0x8c, 0x8b, 0xf3, 0x87, 0x7d, 0xf9, 0xfb, 0xa6, 0x7e, 0xe5, 0xcb, 0xcb, 0xc1, + 0xe2, 0x11, 0x4c, 0x6a, 0xc5, 0x2b, 0x56, 0x84, 0x4e, 0xeb, 0xb9, 0xcc, 0x7d, 0xd7, 0x94, 0x3e, + 0x87, 0xc9, 0xd6, 0x61, 0x98, 0xc0, 0x8e, 0xe4, 0x5f, 0xa9, 0x71, 0x78, 0x90, 0x7e, 0x82, 0xb8, + 0xd7, 0x18, 0x47, 0x70, 0xdf, 0x29, 0xd7, 0xdd, 0x63, 0xcd, 0x61, 0x02, 0x0f, 0x5a, 0x7f, 0x46, + 0x97, 0x8c, 0x8c, 0x51, 0x26, 0x98, 0x1c, 0xe2, 0x18, 0x8e, 0x2a, 0xa3, 0xb4, 0xa6, 0xaa, 0x13, + 0x0d, 0x7d, 0x25, 0xfd, 0x19, 0xc1, 0xf1, 0x1d, 0x03, 0x56, 0x2b, 0x69, 0x09, 0x17, 0x70, 0xd4, + 0xe1, 0x0a, 0x7b, 0x2d, 0xaf, 0x67, 0x7f, 0xe3, 0x15, 0x4e, 0x2f, 0x07, 0xf8, 0x12, 0x62, 0x4b, + 0x66, 0x45, 0x86, 0xd5, 0xc2, 0xba, 0x96, 0xd7, 0xa8, 0x2f, 0xbf, 0x6a, 0xca, 0x17, 0xc2, 0xf3, + 0x5d, 0x9c, 0xc2, 0xf4, 0x0f, 0x5c, 0xa1, 0x53, 0xe0, 0x75, 0x0b, 0xd3, 0xed, 0xc3, 0xf0, 0x0c, + 0x46, 0x7d, 0xad, 0x61, 0x15, 0xd5, 0x74, 0xcd, 0x5d, 0x8b, 0x10, 0xdf, 0xc2, 0x69, 0xff, 0xed, + 0x98, 0x21, 0xad, 0x8c, 0x63, 0x42, 0x3a, 0x32, 0x2b, 0x5e, 0x37, 
0x30, 0xe2, 0xfc, 0xa4, 0xef, + 0x6d, 0x1d, 0xb8, 0xb4, 0x02, 0xe8, 0x7c, 0xe2, 0x13, 0x1f, 0x3f, 0xbf, 0xf2, 0xd8, 0x87, 0xf3, + 0x38, 0xc7, 0xcd, 0x0b, 0xe1, 0x39, 0x1c, 0xd3, 0xad, 0x16, 0xa1, 0xc1, 0xbf, 0x4d, 0xf9, 0x0c, + 0x7b, 0xad, 0x18, 0x01, 0x84, 0x66, 0xbc, 0xaa, 0x0c, 0xd9, 0xf0, 0xb6, 0x89, 0x0f, 0x84, 0x37, + 0x1c, 0x22, 0x8e, 0x53, 0xc0, 0x3b, 0xa4, 0x9c, 0xfa, 0x42, 0xb2, 0xe9, 0x7e, 0x80, 0x27, 0x90, + 0xf8, 0xa7, 0xfe, 0x1d, 0xf2, 0x9d, 0x59, 0x34, 0xdf, 0xcf, 0x0b, 0x48, 0x7a, 0xd8, 0x0c, 0x5e, + 0x42, 0xdc, 0x7e, 0xfb, 0x6d, 0x3c, 0xeb, 0x5b, 0xda, 0xcc, 0xe3, 0xf4, 0xf1, 0xd6, 0x7a, 0xe0, + 0x3f, 0x8f, 0x5e, 0x47, 0xc5, 0x5e, 0xf3, 0xdf, 0xbe, 0xf9, 0x15, 0x00, 0x00, 0xff, 0xff, 0x01, + 0x8b, 0xc9, 0x26, 0xc7, 0x03, 0x00, 0x00, +} diff --git a/grpclb/grpc_lb_v1/grpclb.proto b/grpclb/grpc_lb_v1/grpclb.proto new file mode 100644 index 00000000..13219c6f --- /dev/null +++ b/grpclb/grpc_lb_v1/grpclb.proto @@ -0,0 +1,109 @@ +syntax = "proto3"; + +package grpc.lb.v1; + +message Duration { + // Signed seconds of the span of time. Must be from -315,576,000,000 + // to +315,576,000,000 inclusive. + int64 seconds = 1; + + // Signed fractions of a second at nanosecond resolution of the span + // of time. Durations less than one second are represented with a 0 + // `seconds` field and a positive or negative `nanos` field. For durations + // of one second or more, a non-zero value for the `nanos` field must be + // of the same sign as the `seconds` field. Must be from -999,999,999 + // to +999,999,999 inclusive. + int32 nanos = 2; +} + +service LoadBalancer { + // Bidirectional rpc to get a list of servers. + rpc BalanceLoad(stream LoadBalanceRequest) + returns (stream LoadBalanceResponse); +} + +message LoadBalanceRequest { + oneof load_balance_request_type { + // This message should be sent on the first request to the load balancer. + InitialLoadBalanceRequest initial_request = 1; + + // The client stats should be periodically reported to the load balancer + // based on the duration defined in the InitialLoadBalanceResponse. + ClientStats client_stats = 2; + } +} + +message InitialLoadBalanceRequest { + // Name of load balanced service (IE, service.grpc.gslb.google.com) + string name = 1; +} + +// Contains client level statistics that are useful to load balancing. Each +// count should be reset to zero after reporting the stats. +message ClientStats { + // The total number of requests sent by the client since the last report. + int64 total_requests = 1; + + // The number of client rpc errors since the last report. + int64 client_rpc_errors = 2; + + // The number of dropped requests since the last report. + int64 dropped_requests = 3; +} + +message LoadBalanceResponse { + oneof load_balance_response_type { + // This message should be sent on the first response to the client. + InitialLoadBalanceResponse initial_response = 1; + + // Contains the list of servers selected by the load balancer. The client + // should send requests to these servers in the specified order. + ServerList server_list = 2; + } +} + +message InitialLoadBalanceResponse { + // This is an application layer redirect that indicates the client should use + // the specified server for load balancing. When this field is non-empty in + // the response, the client should open a separate connection to the + // load_balancer_delegate and call the BalanceLoad method. + string load_balancer_delegate = 1; + + // This interval defines how often the client should send the client stats + // to the load balancer. 
Stats should only be reported when the duration is + // positive. + Duration client_stats_report_interval = 3; +} + +message ServerList { + // Contains a list of servers selected by the load balancer. The list will + // be updated when server resolutions change or as needed to balance load + // across more servers. The client should consume the server list in order + // unless instructed otherwise via the client_config. + repeated Server servers = 1; + + // Indicates the amount of time that the client should consider this server + // list as valid. It may be considered stale after waiting this interval of + // time after receiving the list. If the interval is not positive, the + // client can assume the list is valid until the next list is received. + Duration expiration_interval = 3; +} + +message Server { + // A resolved address for the server, serialized in network-byte-order. It may + // either be an IPv4 or IPv6 address. + bytes ip_address = 1; + + // A resolved port number for the server. + int32 port = 2; + + // An opaque but printable token given to the frontend for each pick. All + // frontend requests for that pick must include the token in its initial + // metadata. The token is used by the backend to verify the request and to + // allow the backend to report load to the gRPC LB system. + string load_balance_token = 3; + + // Indicates whether this particular request should be dropped by the client + // when this server is chosen from the list. + bool drop_request = 4; +} diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go new file mode 100644 index 00000000..932fdf06 --- /dev/null +++ b/grpclb/grpclb.go @@ -0,0 +1,450 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package grpclb implements the load balancing protocol defined at +// https://github.com/grpc/grpc/blob/master/doc/load-balancing.md. +// The implementation is currently EXPERIMENTAL. 
+package grpclb + +import ( + "errors" + "fmt" + "sync" + + "golang.org/x/net/context" + "google.golang.org/grpc" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/naming" +) + +// Balancer creates a grpclb load balancer. +func Balancer(r naming.Resolver) grpc.Balancer { + return &balancer{ + r: r, + } +} + +type remoteBalancerInfo struct { + addr grpc.Address + name string +} + +// addrInfo consists of the information of a backend server. +type addrInfo struct { + addr grpc.Address + connected bool + dropRequest bool +} + +type balancer struct { + r naming.Resolver + mu sync.Mutex + seq int // a sequence number to make sure addrCh does not get stale addresses. + w naming.Watcher + addrCh chan []grpc.Address + rbs []remoteBalancerInfo + addrs []addrInfo + next int + waitCh chan struct{} + done bool +} + +func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error { + updates, err := w.Next() + if err != nil { + return err + } + b.mu.Lock() + defer b.mu.Unlock() + if b.done { + return grpc.ErrClientConnClosing + } + var bAddr remoteBalancerInfo + if len(b.rbs) > 0 { + bAddr = b.rbs[0] + } + for _, update := range updates { + addr := grpc.Address{ + Addr: update.Addr, + Metadata: update.Metadata, + } + switch update.Op { + case naming.Add: + var exist bool + for _, v := range b.rbs { + // TODO: Is the same addr with different server name a different balancer? + if addr == v.addr { + exist = true + break + } + } + if exist { + continue + } + b.rbs = append(b.rbs, remoteBalancerInfo{addr: addr}) + case naming.Delete: + for i, v := range b.rbs { + if addr == v.addr { + copy(b.rbs[i:], b.rbs[i+1:]) + b.rbs = b.rbs[:len(b.rbs)-1] + break + } + } + default: + grpclog.Println("Unknown update.Op ", update.Op) + } + } + // TODO: Fall back to the basic round-robin load balancing if the resulting address is + // not a load balancer. + if len(b.rbs) > 0 { + // For simplicity, always use the first one now. May revisit this decision later. + if b.rbs[0] != bAddr { + select { + case <-ch: + default: + } + ch <- b.rbs[0] + } + } + return nil +} + +func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { + servers := l.GetServers() + var ( + sl []addrInfo + addrs []grpc.Address + ) + for _, s := range servers { + // TODO: Support ExpirationInterval + addr := grpc.Address{ + Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port), + // TODO: include LoadBalanceToken in the Metadata + } + sl = append(sl, addrInfo{ + addr: addr, + // TODO: Support dropRequest feature. + }) + addrs = append(addrs, addr) + } + b.mu.Lock() + defer b.mu.Unlock() + if b.done || seq < b.seq { + return + } + if len(sl) > 0 { + // reset b.next to 0 when replacing the server list. + b.next = 0 + b.addrs = sl + b.addrCh <- addrs + } + return +} + +func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false)) + if err != nil { + grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) + return + } + b.mu.Lock() + if b.done { + b.mu.Unlock() + return + } + b.seq++ + seq := b.seq + b.mu.Unlock() + initReq := &lbpb.LoadBalanceRequest{ + LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ + InitialRequest: new(lbpb.InitialLoadBalanceRequest), + }, + } + if err := stream.Send(initReq); err != nil { + // TODO: backoff on retry? 
+ return true + } + reply, err := stream.Recv() + if err != nil { + // TODO: backoff on retry? + return true + } + initResp := reply.GetInitialResponse() + if initResp == nil { + grpclog.Println("Failed to receive the initial response from the remote balancer.") + return + } + // TODO: Support delegation. + if initResp.LoadBalancerDelegate != "" { + // delegation + grpclog.Println("TODO: Delegation is not supported yet.") + return + } + // Retrieve the server list. + for { + reply, err := stream.Recv() + if err != nil { + break + } + if serverList := reply.GetServerList(); serverList != nil { + b.processServerList(serverList, seq) + } + } + return true +} + +func (b *balancer) Start(target string, config grpc.BalancerConfig) error { + // TODO: Fall back to the basic direct connection if there is no name resolver. + if b.r == nil { + return errors.New("there is no name resolver installed") + } + b.mu.Lock() + if b.done { + b.mu.Unlock() + return grpc.ErrClientConnClosing + } + b.addrCh = make(chan []grpc.Address) + w, err := b.r.Resolve(target) + if err != nil { + b.mu.Unlock() + return err + } + b.w = w + b.mu.Unlock() + balancerAddrCh := make(chan remoteBalancerInfo, 1) + // Spawn a goroutine to monitor the name resolution of remote load balancer. + go func() { + for { + if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil { + grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) + close(balancerAddrCh) + return + } + } + }() + // Spawn a goroutine to talk to the remote load balancer. + go func() { + var cc *grpc.ClientConn + for { + rb, ok := <-balancerAddrCh + if cc != nil { + cc.Close() + } + if !ok { + // b is closing. + return + } + + // Talk to the remote load balancer to get the server list. + // + // TODO: override the server name in creds using Metadata in addr. + var err error + creds := config.DialCreds + if creds == nil { + cc, err = grpc.Dial(rb.addr.Addr, grpc.WithInsecure()) + } else { + cc, err = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds)) + } + if err != nil { + grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err) + return + } + go func(cc *grpc.ClientConn) { + lbc := lbpb.NewLoadBalancerClient(cc) + for { + if retry := b.callRemoteBalancer(lbc); !retry { + cc.Close() + return + } + } + }(cc) + } + }() + return nil +} + +func (b *balancer) down(addr grpc.Address, err error) { + b.mu.Lock() + defer b.mu.Unlock() + for _, a := range b.addrs { + if addr == a.addr { + a.connected = false + break + } + } +} + +func (b *balancer) Up(addr grpc.Address) func(error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.done { + return nil + } + var cnt int + for _, a := range b.addrs { + if a.addr == addr { + if a.connected { + return nil + } + a.connected = true + } + if a.connected { + cnt++ + } + } + // addr is the only one which is connected. Notify the Get() callers who are blocking. 
+ if cnt == 1 && b.waitCh != nil { + close(b.waitCh) + b.waitCh = nil + } + return func(err error) { + b.down(addr, err) + } +} + +func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr grpc.Address, put func(), err error) { + var ch chan struct{} + b.mu.Lock() + if b.done { + b.mu.Unlock() + err = grpc.ErrClientConnClosing + return + } + + if len(b.addrs) > 0 { + if b.next >= len(b.addrs) { + b.next = 0 + } + next := b.next + for { + a := b.addrs[next] + next = (next + 1) % len(b.addrs) + if a.connected { + addr = a.addr + b.next = next + b.mu.Unlock() + return + } + if next == b.next { + // Has iterated all the possible address but none is connected. + break + } + } + } + if !opts.BlockingWait { + if len(b.addrs) == 0 { + b.mu.Unlock() + err = fmt.Errorf("there is no address available") + return + } + // Returns the next addr on b.addrs for a failfast RPC. + addr = b.addrs[b.next].addr + b.next++ + b.mu.Unlock() + return + } + // Wait on b.waitCh for non-failfast RPCs. + if b.waitCh == nil { + ch = make(chan struct{}) + b.waitCh = ch + } else { + ch = b.waitCh + } + b.mu.Unlock() + for { + select { + case <-ctx.Done(): + err = ctx.Err() + return + case <-ch: + b.mu.Lock() + if b.done { + b.mu.Unlock() + err = grpc.ErrClientConnClosing + return + } + + if len(b.addrs) > 0 { + if b.next >= len(b.addrs) { + b.next = 0 + } + next := b.next + for { + a := b.addrs[next] + next = (next + 1) % len(b.addrs) + if a.connected { + addr = a.addr + b.next = next + b.mu.Unlock() + return + } + if next == b.next { + // Has iterated all the possible address but none is connected. + break + } + } + } + // The newly added addr got removed by Down() again. + if b.waitCh == nil { + ch = make(chan struct{}) + b.waitCh = ch + } else { + ch = b.waitCh + } + b.mu.Unlock() + } + } +} + +func (b *balancer) Notify() <-chan []grpc.Address { + return b.addrCh +} + +func (b *balancer) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + b.done = true + if b.waitCh != nil { + close(b.waitCh) + } + if b.addrCh != nil { + close(b.addrCh) + } + if b.w != nil { + b.w.Close() + } + return nil +} diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go new file mode 100644 index 00000000..fabb9fec --- /dev/null +++ b/grpclb/grpclb_test.go @@ -0,0 +1,215 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+package grpclb
+
+import (
+	"fmt"
+	"net"
+	"strconv"
+	"strings"
+	"testing"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
+	"google.golang.org/grpc/naming"
+)
+
+type testWatcher struct {
+	// the channel to receive name resolution updates
+	update chan *naming.Update
+	// the side channel to get to know how many updates are in a batch
+	side chan int
+	// the channel to notify the update injector that the update reading is done
+	readDone chan int
+}
+
+func (w *testWatcher) Next() (updates []*naming.Update, err error) {
+	n, ok := <-w.side
+	if !ok {
+		return nil, fmt.Errorf("w.side is closed")
+	}
+	for i := 0; i < n; i++ {
+		u, ok := <-w.update
+		if !ok {
+			break
+		}
+		if u != nil {
+			updates = append(updates, u)
+		}
+	}
+	w.readDone <- 0
+	return
+}
+
+func (w *testWatcher) Close() {
+}
+
+// Inject naming resolution updates to the testWatcher.
+func (w *testWatcher) inject(updates []*naming.Update) {
+	w.side <- len(updates)
+	for _, u := range updates {
+		w.update <- u
+	}
+	<-w.readDone
+}
+
+type testNameResolver struct {
+	w *testWatcher
+	addr string
+}
+
+func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
+	r.w = &testWatcher{
+		update: make(chan *naming.Update, 1),
+		side: make(chan int, 1),
+		readDone: make(chan int),
+	}
+	r.w.side <- 1
+	r.w.update <- &naming.Update{
+		Op: naming.Add,
+		Addr: r.addr,
+	}
+	go func() {
+		<-r.w.readDone
+	}()
+	return r.w, nil
+}
+
+type remoteBalancer struct {
+	servers *lbpb.ServerList
+	done chan struct{}
+}
+
+func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
+	return &remoteBalancer{
+		servers: servers,
+		done: make(chan struct{}),
+	}
+}
+
+func (b *remoteBalancer) stop() {
+	close(b.done)
+}
+func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error {
+	resp := &lbpb.LoadBalanceResponse{
+		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
+			InitialResponse: new(lbpb.InitialLoadBalanceResponse),
+		},
+	}
+	if err := stream.Send(resp); err != nil {
+		return err
+	}
+	resp = &lbpb.LoadBalanceResponse{
+		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
+			ServerList: b.servers,
+		},
+	}
+	if err := stream.Send(resp); err != nil {
+		return err
+	}
+	<-b.done
+	return nil
+}
+
+func startBackends(lis ...net.Listener) (servers []*grpc.Server) {
+	for _, l := range lis {
+		s := grpc.NewServer()
+		servers = append(servers, s)
+		go func(s *grpc.Server, l net.Listener) {
+			s.Serve(l)
+		}(s, l)
+	}
+	return
+}
+
+func stopBackends(servers []*grpc.Server) {
+	for _, s := range servers {
+		s.Stop()
+	}
+}
+
+func TestGRPCLB(t *testing.T) {
+	// Start a backend.
+	beLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	backends := startBackends(beLis)
+	defer stopBackends(backends)
+	// Start a load balancer.
+ lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to create the listener for the load balancer %v", err) + } + lb := grpc.NewServer() + addr := strings.Split(lis.Addr().String(), ":") + port, err := strconv.Atoi(addr[1]) + if err != nil { + t.Fatalf("Failed to generate the port number %v", err) + } + be := &lbpb.Server{ + IpAddress: []byte(addr[0]), + Port: int32(port), + } + var bes []*lbpb.Server + bes = append(bes, be) + sl := &lbpb.ServerList{ + Servers: bes, + } + ls := newRemoteBalancer(sl) + lbpb.RegisterLoadBalancerServer(lb, ls) + go func() { + lb.Serve(lis) + }() + defer func() { + ls.stop() + lb.Stop() + }() + cc, err := grpc.Dial("foo.bar.com", grpc.WithBalancer(Balancer(&testNameResolver{ + addr: lis.Addr().String(), + })), grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + // Issue an unimplemented RPC and expect codes.Unimplemented. + var ( + req, reply lbpb.Duration + ) + if err := grpc.Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented { + t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented) + } + cc.Close() +} diff --git a/interceptor.go b/interceptor.go index 588f59e5..8d932efe 100644 --- a/interceptor.go +++ b/interceptor.go @@ -37,6 +37,22 @@ import ( "golang.org/x/net/context" ) +// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs. +type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error + +// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC +// and it is the responsibility of the interceptor to call it. +// This is the EXPERIMENTAL API. +type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error + +// Streamer is called by StreamClientInterceptor to create a ClientStream. +type Streamer func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) + +// StreamClientInterceptor intercepts the creation of ClientStream. It may return a custom ClientStream to intercept all I/O +// operations. streamer is the handlder to create a ClientStream and it is the responsibility of the interceptor to call it. +// This is the EXPERIMENTAL API. +type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) + // UnaryServerInfo consists of various information about a unary RPC on // server side. All per-rpc information may be mutated by the interceptor. type UnaryServerInfo struct { diff --git a/interop/client/client.go b/interop/client/client.go index 9eb6f305..7ae864e3 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -54,7 +54,7 @@ var ( defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") serverHost = flag.String("server_host", "127.0.0.1", "The server host name") serverPort = flag.Int("server_port", 10000, "The server port number") - tlsServerName = flag.String("server_host_override", "x.test.youtube.com", "The server name use to verify the hostname returned by TLS handshake if it is not empty. 
Otherwise, --server_host is used.") + tlsServerName = flag.String("server_host_override", "", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") testCase = flag.String("test_case", "large_unary", `Configure different test cases. Valid options are: empty_unary : empty (zero bytes) request and response; diff --git a/metadata/metadata.go b/metadata/metadata.go index 954c0f77..3c0ca7a3 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -117,10 +117,17 @@ func (md MD) Len() int { // Copy returns a copy of md. func (md MD) Copy() MD { + return Join(md) +} + +// Join joins any number of MDs into a single MD. +// The order of values for each key is determined by the order in which +// the MDs containing those values are presented to Join. +func Join(mds ...MD) MD { out := MD{} - for k, v := range md { - for _, i := range v { - out[k] = append(out[k], i) + for _, md := range mds { + for k, v := range md { + out[k] = append(out[k], v...) } } return out diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 99e86820..fef0a0f8 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -107,3 +107,33 @@ func TestPairsMD(t *testing.T) { } } } + +func TestCopy(t *testing.T) { + const key, val = "key", "val" + orig := Pairs(key, val) + copy := orig.Copy() + if !reflect.DeepEqual(orig, copy) { + t.Errorf("copied value not equal to the original, got %v, want %v", copy, orig) + } + orig[key][0] = "foo" + if v := copy[key][0]; v != val { + t.Errorf("change in original should not affect copy, got %q, want %q", v, val) + } +} + +func TestJoin(t *testing.T) { + for _, test := range []struct { + mds []MD + want MD + }{ + {[]MD{}, MD{}}, + {[]MD{Pairs("foo", "bar")}, Pairs("foo", "bar")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz")}, Pairs("foo", "bar", "foo", "baz")}, + {[]MD{Pairs("foo", "bar"), Pairs("foo", "baz"), Pairs("zip", "zap")}, Pairs("foo", "bar", "foo", "baz", "zip", "zap")}, + } { + md := Join(test.mds...) 
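+		// Join concatenates the values of duplicate keys in the order the
+		// MDs are passed, so, as in the table above,
+		// Join(Pairs("foo", "bar"), Pairs("foo", "baz")) is expected to
+		// equal Pairs("foo", "bar", "foo", "baz").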
+ if !reflect.DeepEqual(md, test.want) { + t.Errorf("context's metadata is %v, want %v", md, test.want) + } + } +} diff --git a/rpc_util.go b/rpc_util.go index 35ac9cc7..6b60095d 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -303,10 +303,10 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er case compressionNone: case compressionMade: if dc == nil || recvCompress != dc.Type() { - return transport.StreamErrorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) + return Errorf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) } default: - return transport.StreamErrorf(codes.Internal, "grpc: received unexpected payload format %d", pf) + return Errorf(codes.Internal, "grpc: received unexpected payload format %d", pf) } return nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 8a813c62..0ba2d44c 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -152,7 +152,7 @@ func TestToRPCErr(t *testing.T) { // outputs errOut *rpcError }{ - {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "").(*rpcError)}, + {transport.StreamError{codes.Unknown, ""}, Errorf(codes.Unknown, "").(*rpcError)}, {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)}, } { err := toRPCErr(test.errIn) @@ -173,8 +173,8 @@ func TestContextErr(t *testing.T) { // outputs errOut transport.StreamError }{ - {context.DeadlineExceeded, transport.StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)}, - {context.Canceled, transport.StreamErrorf(codes.Canceled, "%v", context.Canceled)}, + {context.DeadlineExceeded, transport.StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}}, + {context.Canceled, transport.StreamError{codes.Canceled, context.Canceled.Error()}}, } { err := transport.ContextErr(test.errIn) if err != test.errOut { diff --git a/server.go b/server.go index 1ed8aac9..2524dd22 100644 --- a/server.go +++ b/server.go @@ -324,7 +324,7 @@ func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credenti // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines // read gRPC requests and then call the registered handlers to reply to them. -// Service returns when lis.Accept fails. lis will be closed when +// Serve returns when lis.Accept fails. lis will be closed when // this method returns. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() @@ -367,7 +367,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() grpclog.Printf("grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) - rawConn.Close() + // If serverHandShake returns ErrConnDispatched, keep rawConn open. + if err != credentials.ErrConnDispatched { + rawConn.Close() + } return } @@ -544,7 +547,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } if err == io.ErrUnexpectedEOF { - err = transport.StreamError{Code: codes.Internal, Desc: "io.ErrUnexpectedEOF"} + err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { switch err := err.(type) { @@ -566,8 +569,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { switch err := err.(type) { - case transport.StreamError: - if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { + case *rpcError: + if err := t.WriteStatus(stream, err.code, err.desc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } default: diff --git a/stream.go b/stream.go index 51df3f01..e1b4759e 100644 --- a/stream.go +++ b/stream.go @@ -97,7 +97,14 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. -func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + if cc.dopts.streamInt != nil { + return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...) + } + return newClientStream(ctx, desc, cc, method, opts...) +} + +func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { var ( t transport.ClientTransport s *transport.Stream @@ -296,7 +303,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } }() if err != nil { - return transport.StreamErrorf(codes.Internal, "grpc: %v", err) + return Errorf(codes.Internal, "grpc: %v", err) } return cs.t.Write(cs.s, out, &transport.Options{Last: false}) } @@ -468,10 +475,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() if err != nil { - err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) + err = Errorf(codes.Internal, "grpc: %v", err) return err } - return ss.t.Write(ss.s, out, &transport.Options{Last: false}) + if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { + return toRPCErr(err) + } + return nil } func (ss *serverStream) RecvMsg(m interface{}) (err error) { @@ -489,5 +499,14 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize) + if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize); err != nil { + if err == io.EOF { + return err + } + if err == io.ErrUnexpectedEOF { + err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) + } + return toRPCErr(err) + } + return nil } diff --git a/test/end2end_test.go b/test/end2end_test.go index a97a8a47..bf4b49db 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -66,7 +66,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" testpb "google.golang.org/grpc/test/grpc_testing" - "google.golang.org/grpc/transport" ) var ( @@ -369,8 +368,10 @@ type test struct { userAgent string clientCompression bool serverCompression bool - unaryInt grpc.UnaryServerInterceptor - streamInt grpc.StreamServerInterceptor + unaryClientInt grpc.UnaryClientInterceptor + streamClientInt grpc.StreamClientInterceptor + unaryServerInt grpc.UnaryServerInterceptor + streamServerInt grpc.StreamServerInterceptor // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -413,8 +414,7 @@ func newTest(t *testing.T, e env) *test { // call to te.tearDown to clean up. 
func (te *test) startServer(ts testpb.TestServiceServer) { te.testServer = ts - e := te.e - te.t.Logf("Running test in %s environment...", e.name) + te.t.Logf("Running test in %s environment...", te.e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)} if te.maxMsgSize > 0 { sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize)) @@ -425,19 +425,19 @@ func (te *test) startServer(ts testpb.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } - if te.unaryInt != nil { - sopts = append(sopts, grpc.UnaryInterceptor(te.unaryInt)) + if te.unaryServerInt != nil { + sopts = append(sopts, grpc.UnaryInterceptor(te.unaryServerInt)) } - if te.streamInt != nil { - sopts = append(sopts, grpc.StreamInterceptor(te.streamInt)) + if te.streamServerInt != nil { + sopts = append(sopts, grpc.StreamInterceptor(te.streamServerInt)) } la := "localhost:0" - switch e.network { + switch te.e.network { case "unix": la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now()) syscall.Unlink(la) } - lis, err := net.Listen(e.network, la) + lis, err := net.Listen(te.e.network, la) if err != nil { te.t.Fatalf("Failed to listen: %v", err) } @@ -455,7 +455,7 @@ func (te *test) startServer(ts testpb.TestServiceServer) { } s := grpc.NewServer(sopts...) te.srv = s - if e.httpHandler { + if te.e.httpHandler { internal.TestingUseHandlerImpl(s) } if te.healthServer != nil { @@ -465,7 +465,7 @@ func (te *test) startServer(ts testpb.TestServiceServer) { testpb.RegisterTestServiceServer(s, te.testServer) } addr := la - switch e.network { + switch te.e.network { case "unix": default: _, port, err := net.SplitHostPort(lis.Addr().String()) @@ -494,6 +494,12 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } + if te.unaryClientInt != nil { + opts = append(opts, grpc.WithUnaryInterceptor(te.unaryClientInt)) + } + if te.streamClientInt != nil { + opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt)) + } switch te.e.security { case "tls": creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") @@ -566,7 +572,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { te.srv.Stop() ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) + t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %s", ctx, err, codes.DeadlineExceeded) } awaitNewConnLogOutput() } @@ -849,10 +855,10 @@ func testFailFast(t *testing.T, e env) { } // The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Unavailable. 
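+	// These calls use the default fail-fast behavior, so they should fail
+	// immediately with codes.Unavailable rather than block waiting for a
+	// new transport the way FailFast(false) calls do.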
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Unavailable) + t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %s", err, codes.Unavailable) } if _, err := tc.StreamingInputCall(context.Background()); grpc.Code(err) != codes.Unavailable { - t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Unavailable) + t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %s", err, codes.Unavailable) } awaitNewConnLogOutput() @@ -911,7 +917,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) { cc := te.clientConn() wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { - t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) + t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.DeadlineExceeded) } awaitNewConnLogOutput() } @@ -961,7 +967,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) { } wantErr := grpc.Errorf(codes.NotFound, "unknown service") if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { - t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound) + t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.NotFound) } hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) out, err = healthCheck(1*time.Second, cc, "grpc.health.v1.Health") @@ -1112,7 +1118,7 @@ func testExceedMsgLimit(t *testing.T, e env) { Payload: payload, } if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %d", err, codes.Internal) + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal) } stream, err := tc.FullDuplexCall(te.ctx) @@ -1139,7 +1145,7 @@ func testExceedMsgLimit(t *testing.T, e env) { t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal { - t.Fatalf("%v.Recv() = _, %v, want _, error code: %d", stream, err, codes.Internal) + t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal) } } @@ -1214,7 +1220,7 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { } ctx := metadata.NewContext(context.Background(), malformedHTTP2Metadata) if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Internal { - t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %q", ctx, err, codes.Internal) + t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %s", ctx, err, codes.Internal) } } @@ -1331,7 +1337,7 @@ func testRPCTimeout(t *testing.T, e env) { for i := -1; i <= 10; i++ { ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %d", err, codes.DeadlineExceeded) + t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %s", err, codes.DeadlineExceeded) } } } @@ -1368,7 +1374,7 @@ func testCancel(t *testing.T, e env) { ctx, cancel := context.WithCancel(context.Background()) time.AfterFunc(1*time.Millisecond, cancel) if r, err := tc.UnaryCall(ctx, 
req); grpc.Code(err) != codes.Canceled { - t.Fatalf("TestService/UnaryCall(_, _) = %v, %v; want _, error code: %d", r, err, codes.Canceled) + t.Fatalf("TestService/UnaryCall(_, _) = %v, %v; want _, error code: %s", r, err, codes.Canceled) } awaitNewConnLogOutput() } @@ -1416,7 +1422,7 @@ func testCancelNoIO(t *testing.T, e env) { if grpc.Code(err) == codes.DeadlineExceeded { break } - t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded) + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded) } // If there are any RPCs in flight before the client receives // the max streams setting, let them be expired. @@ -1465,7 +1471,7 @@ func testNoService(t *testing.T, e env) { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } if _, err := stream.Recv(); grpc.Code(err) != codes.Unimplemented { - t.Fatalf("stream.Recv() = _, %v, want _, error code %d", err, codes.Unimplemented) + t.Fatalf("stream.Recv() = _, %v, want _, error code %s", err, codes.Unimplemented) } } @@ -1689,6 +1695,85 @@ func testFailedServerStreaming(t *testing.T, e env) { } } +// checkTimeoutErrorServer is a gRPC server checks context timeout error in FullDuplexCall(). +// It is only used in TestStreamingRPCTimeoutServerError. +type checkTimeoutErrorServer struct { + t *testing.T + done chan struct{} + testpb.TestServiceServer +} + +func (s *checkTimeoutErrorServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + defer close(s.done) + for { + _, err := stream.Recv() + if err != nil { + if grpc.Code(err) != codes.DeadlineExceeded { + s.t.Errorf("stream.Recv() = _, %v, want error code %s", err, codes.DeadlineExceeded) + } + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: &testpb.Payload{ + Body: []byte{'0'}, + }, + }); err != nil { + if grpc.Code(err) != codes.DeadlineExceeded { + s.t.Errorf("stream.Send(_) = %v, want error code %s", err, codes.DeadlineExceeded) + } + return err + } + } +} + +func TestStreamingRPCTimeoutServerError(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testStreamingRPCTimeoutServerError(t, e) + } +} + +// testStreamingRPCTimeoutServerError tests the server side behavior. +// When context timeout happens on client side, server should get deadline exceeded error. +func testStreamingRPCTimeoutServerError(t *testing.T, e env) { + te := newTest(t, e) + serverDone := make(chan struct{}) + te.startServer(&checkTimeoutErrorServer{t: t, done: serverDone}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + req := &testpb.StreamingOutputCallRequest{} + for duration := 50 * time.Millisecond; ; duration *= 2 { + ctx, _ := context.WithTimeout(context.Background(), duration) + stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) + if grpc.Code(err) == codes.DeadlineExceeded { + // Redo test with double timeout. + continue + } + if err != nil { + t.Errorf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + return + } + for { + err := stream.Send(req) + if err != nil { + break + } + _, err = stream.Recv() + if err != nil { + break + } + } + + // Wait for context timeout on server before closing connection + // to make sure the server will get timeout error. + <-serverDone + break + } +} + // concurrentSendServer is a TestServiceServer whose // StreamingOutputCall makes ten serial Send calls, sending payloads // "0".."9", inclusive. 
TestServerStreamingConcurrent verifies they @@ -1848,7 +1933,7 @@ func testClientStreamingError(t *testing.T, e env) { continue } if _, err := stream.CloseAndRecv(); grpc.Code(err) != codes.NotFound { - t.Fatalf("%v.CloseAndRecv() = %v, want error %d", stream, err, codes.NotFound) + t.Fatalf("%v.CloseAndRecv() = %v, want error %s", stream, err, codes.NotFound) } break } @@ -1891,7 +1976,7 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { if grpc.Code(err) == codes.DeadlineExceeded { break } - t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded) + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded) } } @@ -1931,7 +2016,7 @@ func testStreamsQuotaRecovery(t *testing.T, e env) { if grpc.Code(err) == codes.DeadlineExceeded { break } - t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded) + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %s", tc, err, codes.DeadlineExceeded) } cancel() @@ -1977,7 +2062,7 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { Payload: payload, } if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Unimplemented { - t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented) + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code %s", err, codes.Unimplemented) } // Streaming RPC stream, err := tc.FullDuplexCall(context.Background()) @@ -2002,7 +2087,7 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { t.Fatalf("%v.Send(%v) = %v, want ", stream, sreq, err) } if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Unimplemented { - t.Fatalf("%v.Recv() = %v, want error code %d", stream, err, codes.Unimplemented) + t.Fatalf("%v.Recv() = %v, want error code %s", stream, err, codes.Unimplemented) } } @@ -2066,6 +2151,75 @@ func testCompressOK(t *testing.T, e env) { } } +func TestUnaryClientInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testUnaryClientInterceptor(t, e) + } +} + +func failOkayRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + err := invoker(ctx, method, req, reply, cc, opts...) + if err == nil { + return grpc.Errorf(codes.NotFound, "") + } + return err +} + +func testUnaryClientInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.userAgent = testAppUA + te.unaryClientInt = failOkayRPC + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.NotFound) + } +} + +func TestStreamClientInterceptor(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testStreamClientInterceptor(t, e) + } +} + +func failOkayStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + s, err := streamer(ctx, desc, cc, method, opts...) 
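+	// The check below is deliberately inverted, mirroring failOkayRPC above:
+	// a successfully created stream is reported as codes.NotFound so the
+	// test can tell the interceptor ran. A typical interceptor would instead
+	// return the stream (or a wrapper around it) when err == nil.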
+ if err == nil { + return nil, grpc.Errorf(codes.NotFound, "") + } + return s, nil +} + +func testStreamClientInterceptor(t *testing.T, e env) { + te := newTest(t, e) + te.streamClientInt = failOkayStream + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + tc := testpb.NewTestServiceClient(te.clientConn()) + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(1)), + }, + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1)) + if err != nil { + t.Fatal(err) + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: payload, + } + if _, err := tc.StreamingOutputCall(context.Background(), req); grpc.Code(err) != codes.NotFound { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, error code %s", tc, err, codes.NotFound) + } +} + func TestUnaryServerInterceptor(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -2079,13 +2233,13 @@ func errInjector(ctx context.Context, req interface{}, info *grpc.UnaryServerInf func testUnaryServerInterceptor(t *testing.T, e env) { te := newTest(t, e) - te.unaryInt = errInjector + te.unaryServerInt = errInjector te.startServer(&testServer{security: e.security}) defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.PermissionDenied { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied) + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, error code %s", tc, err, codes.PermissionDenied) } } @@ -2110,7 +2264,7 @@ func fullDuplexOnly(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServ func testStreamServerInterceptor(t *testing.T, e env) { te := newTest(t, e) - te.streamInt = fullDuplexOnly + te.streamServerInt = fullDuplexOnly te.startServer(&testServer{security: e.security}) defer te.tearDown() @@ -2134,7 +2288,7 @@ func testStreamServerInterceptor(t *testing.T, e env) { t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want _, ", tc, err) } if _, err := s1.Recv(); grpc.Code(err) != codes.PermissionDenied { - t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, error code %d", tc, err, codes.PermissionDenied) + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, error code %s", tc, err, codes.PermissionDenied) } s2, err := tc.FullDuplexCall(context.Background()) if err != nil { @@ -2281,8 +2435,8 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) { case <-time.After(3 * time.Second): t.Fatal("timeout waiting for error") } - if se, ok := got.(transport.StreamError); !ok || se.Code != codes.Canceled { - t.Errorf("error = %#v; want transport.StreamError with code Canceled", got) + if grpc.Code(got) != codes.Canceled { + t.Errorf("error = %#v; want error code %s", got, codes.Canceled) } }) } @@ -2302,10 +2456,16 @@ func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, crede func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { return credentials.ProtocolInfo{} } +func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { + return nil +} +func (c clientAlwaysFailCred) OverrideServerName(s string) error { + return nil +} func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) - te.startServer(&testServer{security: "clientAlwaysFailCred"}) + 
te.startServer(&testServer{security: te.e.security}) defer te.tearDown() var ( @@ -2321,7 +2481,7 @@ func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: true}) - te.startServer(&testServer{security: "clientAlwaysFailCred"}) + te.startServer(&testServer{security: te.e.security}) defer te.tearDown() cc := te.clientConn() @@ -2333,7 +2493,7 @@ func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) - te.startServer(&testServer{security: "clientAlwaysFailCred"}) + te.startServer(&testServer{security: te.e.security}) defer te.tearDown() cc := te.clientConn() @@ -2345,7 +2505,7 @@ func TestFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { func TestNonFailFastRPCWithNoBalancerErrorOnBadCertificates(t *testing.T) { te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: false}) - te.startServer(&testServer{security: "clientAlwaysFailCred"}) + te.startServer(&testServer{security: te.e.security}) defer te.tearDown() cc := te.clientConn() @@ -2372,11 +2532,17 @@ func (c *clientTimeoutCreds) ServerHandshake(rawConn net.Conn) (net.Conn, creden func (c *clientTimeoutCreds) Info() credentials.ProtocolInfo { return credentials.ProtocolInfo{} } +func (c *clientTimeoutCreds) Clone() credentials.TransportCredentials { + return nil +} +func (c *clientTimeoutCreds) OverrideServerName(s string) error { + return nil +} func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { te := newTest(t, env{name: "timeout-cred", network: "tcp", security: "clientTimeoutCreds", balancer: false}) te.userAgent = testAppUA - te.startServer(&testServer{security: "clientTimeoutCreds"}) + te.startServer(&testServer{security: te.e.security}) defer te.tearDown() cc := te.clientConn() @@ -2387,6 +2553,60 @@ func TestNonFailFastRPCSucceedOnTimeoutCreds(t *testing.T) { } } +type serverDispatchCred struct { + ready chan struct{} + rawConn net.Conn +} + +func newServerDispatchCred() *serverDispatchCred { + return &serverDispatchCred{ + ready: make(chan struct{}), + } +} +func (c *serverDispatchCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *serverDispatchCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.rawConn = rawConn + close(c.ready) + return nil, nil, credentials.ErrConnDispatched +} +func (c *serverDispatchCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *serverDispatchCred) Clone() credentials.TransportCredentials { + return nil +} +func (c *serverDispatchCred) OverrideServerName(s string) error { + return nil +} +func (c *serverDispatchCred) getRawConn() net.Conn { + <-c.ready + return c.rawConn +} + +func TestServerCredsDispatch(t *testing.T) { + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + cred := newServerDispatchCred() + s := grpc.NewServer(grpc.Creds(cred)) + go s.Serve(lis) + defer s.Stop() + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + + // Check 
rawConn is not closed. + if n, err := cred.getRawConn().Write([]byte{0}); n <= 0 || err != nil { + t.Errorf("Read() = %v, %v; want n>0, ", n, err) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { diff --git a/transport/handler_server.go b/transport/handler_server.go index 30e21ac0..114e3490 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -85,7 +85,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr if v := r.Header.Get("grpc-timeout"); v != "" { to, err := decodeTimeout(v) if err != nil { - return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) + return nil, streamErrorf(codes.Internal, "malformed time-out: %v", err) } st.timeoutSet = true st.timeout = to @@ -393,5 +393,5 @@ func mapRecvMsgError(err error) error { } } } - return ConnectionError{Desc: err.Error()} + return connectionErrorf(true, err, err.Error()) } diff --git a/transport/http2_client.go b/transport/http2_client.go index 5e7c8c25..3c185541 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -114,14 +114,42 @@ func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Contex return dialContext(ctx, "tcp", addr) } +func isTemporary(err error) bool { + switch err { + case io.EOF: + // Connection closures may be resolved upon retry, and are thus + // treated as temporary. + return true + case context.DeadlineExceeded: + // In Go 1.7, context.DeadlineExceeded implements Timeout(), and this + // special case is not needed. Until then, we need to keep this + // clause. + return true + } + + switch err := err.(type) { + case interface { + Temporary() bool + }: + return err.Temporary() + case interface { + Timeout() bool + }: + // Timeouts may be resolved upon retry, and are thus treated as + // temporary. + return err.Timeout() + } + return false +} + // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" - conn, connErr := dial(opts.Dialer, ctx, addr) - if connErr != nil { - return nil, ConnectionErrorf(true, connErr, "transport: %v", connErr) + conn, err := dial(opts.Dialer, ctx, addr) + if err != nil { + return nil, connectionErrorf(true, err, "transport: %v", err) } // Any further errors will close the underlying connection defer func(conn net.Conn) { @@ -132,17 +160,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl var authInfo credentials.AuthInfo if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn) - } - if connErr != nil { - // Credentials handshake error is not a temporary error (unless the error - // was the connection closing or deadline exceeded). - var temp bool - switch connErr { - case io.EOF, context.DeadlineExceeded: - temp = true + conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn) + if err != nil { + // Credentials handshake errors are typically considered permanent + // to avoid retrying on e.g. bad certificates. 
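+			// isTemporary (defined above) treats only io.EOF,
+			// context.DeadlineExceeded, and errors exposing a Temporary()
+			// or Timeout() method as retryable; anything else, such as a
+			// bad-certificate error, is reported as a permanent
+			// connection error.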
+ temp := isTemporary(err) + return nil, connectionErrorf(temp, err, "transport: %v", err) } - return nil, ConnectionErrorf(temp, connErr, "transport: %v", connErr) } ua := primaryUA if opts.UserAgent != "" { @@ -181,11 +205,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { err = t.framer.writeSettings(true, http2.Setting{ @@ -197,13 +221,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } if err != nil { t.Close() - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } } go t.controller() @@ -228,8 +252,10 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { s.windowHandler = func(n int) { t.updateWindow(s, uint32(n)) } - // Make a stream be able to cancel the pending operations by itself. - s.ctx, s.cancel = context.WithCancel(ctx) + // The client side stream context should have exactly the same life cycle with the user provided context. + // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. + // So we use the original context here instead of creating a copy. + s.ctx = ctx s.dec = &recvBufferReader{ ctx: s.ctx, goAway: s.goAway, @@ -241,16 +267,6 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // NewStream creates a stream and register it into the transport as "active" // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { - // Record the timeout value on the context. 
- var timeout time.Duration - if dl, ok := ctx.Deadline(); ok { - timeout = dl.Sub(time.Now()) - } - select { - case <-ctx.Done(): - return nil, ContextErr(ctx.Err()) - default: - } pr := &peer.Peer{ Addr: t.conn.RemoteAddr(), } @@ -271,12 +287,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } pos := strings.LastIndex(callHdr.Method, "/") if pos == -1 { - return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) + return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) } audience := "https://" + callHdr.Host + port + callHdr.Method[:pos] data, err := c.GetRequestMetadata(ctx, audience) if err != nil { - return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err) + return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err) } for k, v := range data { authData[k] = v @@ -357,9 +373,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if callHdr.SendCompress != "" { t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) } - if timeout > 0 { + if dl, ok := ctx.Deadline(); ok { + // Send out timeout regardless its value. The server can detect timeout context by itself. + timeout := dl.Sub(time.Now()) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)}) } + for k, v := range authData { // Capital header names are illegal in HTTP/2. k = strings.ToLower(k) @@ -413,7 +432,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } } t.writableChan <- 0 @@ -459,7 +478,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { } s.state = streamDone s.mu.Unlock() - if _, ok := err.(StreamError); ok { + if se, ok := err.(StreamError); ok && se.Code != codes.DeadlineExceeded { t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel}) } } @@ -627,7 +646,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf(true, err, "transport: %v", err) + return connectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -675,7 +694,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf(true, err, "%v", err)) + t.notifyError(connectionErrorf(true, err, "%v", err)) return } // Select the right stream to dispatch. @@ -781,7 +800,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { if t.state == reachable || t.state == draining { if f.LastStreamID > 0 && f.LastStreamID%2 != 1 { t.mu.Unlock() - t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) + t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: stream ID %d is even", f.LastStreamID)) return } select { @@ -790,7 +809,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { // t.goAway has been closed (i.e.,multiple GoAways). 
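+		// A follow-up GOAWAY may only carry a LastStreamID equal to or lower
+		// than the previous one; a larger value is illegal and is rejected
+		// via notifyError below.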
if id < f.LastStreamID { t.mu.Unlock() - t.notifyError(ConnectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) + t.notifyError(connectionErrorf(true, nil, "received illegal http2 GOAWAY frame: previously recv GOAWAY frame with LastStramID %d, currently recv %d", id, f.LastStreamID)) return } t.prevGoAwayID = id @@ -828,6 +847,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { state.processHeaderField(hf) } if state.err != nil { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.mu.Unlock() s.write(recvMsg{err: state.err}) // Something wrong. Stops reading even when there is remaining. return @@ -905,7 +930,7 @@ func (t *http2Client) reader() { t.mu.Unlock() if s != nil { // use error detail to provide better err message - handleMalformedHTTP2(s, StreamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail())) + handleMalformedHTTP2(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.errorDetail())) } continue } else { diff --git a/transport/http2_server.go b/transport/http2_server.go index 16010d55..f753c4f1 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { - return nil, ConnectionErrorf(true, err, "transport: %v", err) + return nil, connectionErrorf(true, err, "transport: %v", err) } } var buf bytes.Buffer @@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e } if err != nil { t.Close() - return ConnectionErrorf(true, err, "transport: %v", err) + return connectionErrorf(true, err, "transport: %v", err) } } return nil @@ -544,7 +544,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { s.mu.Lock() if s.state == streamDone { s.mu.Unlock() - return StreamErrorf(codes.Unknown, "the stream has been done") + return streamErrorf(codes.Unknown, "the stream has been done") } if !s.headerOk { writeHeaderFrame = true @@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeHeaders(false, p); err != nil { t.Close() - return ConnectionErrorf(true, err, "transport: %v", err) + return connectionErrorf(true, err, "transport: %v", err) } t.writableChan <- 0 } @@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() - return ConnectionErrorf(true, err, "transport: %v", err) + return connectionErrorf(true, err, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() diff --git a/transport/http_util.go b/transport/http_util.go index 79da5126..a3c68d4c 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -162,7 +162,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { switch f.Name { case "content-type": if !validContentType(f.Value) { - d.setErr(StreamErrorf(codes.FailedPrecondition, 
"transport: received the unexpected content-type %q", f.Value)) + d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) return } case "grpc-encoding": @@ -170,7 +170,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.setErr(StreamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)) + d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)) return } d.statusCode = codes.Code(code) @@ -181,7 +181,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { var err error d.timeout, err = decodeTimeout(f.Value) if err != nil { - d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) + d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) return } case ":path": @@ -253,6 +253,9 @@ func div(d, r time.Duration) int64 { // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. func encodeTimeout(t time.Duration) string { + if t <= 0 { + return "0n" + } if d := div(t, time.Nanosecond); d <= maxTimeoutValue { return strconv.FormatInt(d, 10) + "n" } @@ -349,7 +352,7 @@ func decodeGrpcMessageUnchecked(msg string) string { for i := 0; i < lenMsg; i++ { c := msg[i] if c == percentByte && i+2 < lenMsg { - parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8) + parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8) if err != nil { buf.WriteByte(c) } else { diff --git a/transport/http_util_test.go b/transport/http_util_test.go index 41bf5477..a5f8a85f 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -135,6 +135,7 @@ func TestDecodeGrpcMessage(t *testing.T) { {"H%61o", "Hao"}, {"H%6", "H%6"}, {"%G0", "%G0"}, + {"%E7%B3%BB%E7%BB%9F", "系统"}, } { actual := decodeGrpcMessage(tt.input) if tt.expected != actual { diff --git a/transport/transport.go b/transport/transport.go index d59e5113..e2f691ef 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -169,7 +169,8 @@ type Stream struct { // nil for client side Stream. st ServerTransport // ctx is the associated context of the stream. - ctx context.Context + ctx context.Context + // cancel is always nil for client side Stream. cancel context.CancelFunc // done is closed when the final status arrives. done chan struct{} @@ -476,16 +477,16 @@ type ServerTransport interface { Drain() } -// StreamErrorf creates an StreamError with the specified error code and description. -func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { +// streamErrorf creates an StreamError with the specified error code and description. +func streamErrorf(c codes.Code, format string, a ...interface{}) StreamError { return StreamError{ Code: c, Desc: fmt.Sprintf(format, a...), } } -// ConnectionErrorf creates an ConnectionError with the specified error description. -func ConnectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { +// connectionErrorf creates an ConnectionError with the specified error description. +func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { return ConnectionError{ Desc: fmt.Sprintf(format, a...), temp: temp, @@ -522,10 +523,10 @@ func (e ConnectionError) Origin() error { var ( // ErrConnClosing indicates that the transport is closing. 
- ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} + ErrConnClosing = connectionErrorf(true, nil, "transport is closing") // ErrStreamDrain indicates that the stream is rejected by the server because // the server stops accepting new RPCs. - ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs") + ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs") ) // StreamError is an error that only affects one stream within a connection. @@ -542,9 +543,9 @@ func (e StreamError) Error() string { func ContextErr(err error) StreamError { switch err { case context.DeadlineExceeded: - return StreamErrorf(codes.DeadlineExceeded, "%v", err) + return streamErrorf(codes.DeadlineExceeded, "%v", err) case context.Canceled: - return StreamErrorf(codes.Canceled, "%v", err) + return streamErrorf(codes.Canceled, "%v", err) } panic(fmt.Sprintf("Unexpected error from context packet: %v", err)) } diff --git a/transport/transport_test.go b/transport/transport_test.go index 2c2f1f66..01d95e47 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -40,12 +40,14 @@ import ( "math" "net" "strconv" + "strings" "sync" "testing" "time" "golang.org/x/net/context" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" ) @@ -58,14 +60,15 @@ type server struct { } var ( - expectedRequest = []byte("ping") - expectedResponse = []byte("pong") - expectedRequestLarge = make([]byte, initialWindowSize*2) - expectedResponseLarge = make([]byte, initialWindowSize*2) + expectedRequest = []byte("ping") + expectedResponse = []byte("pong") + expectedRequestLarge = make([]byte, initialWindowSize*2) + expectedResponseLarge = make([]byte, initialWindowSize*2) + expectedInvalidHeaderField = "invalid/content-type" ) type testStreamHandler struct { - t ServerTransport + t *http2Server } type hType int @@ -75,6 +78,7 @@ const ( suspended misbehaved encodingRequiredStatus + invalidHeaderField ) func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { @@ -140,6 +144,16 @@ func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s * h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) } +func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { + <-h.t.writableChan + h.t.hBuf.Reset() + h.t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) + if err := h.t.writeHeaders(s, h.t.hBuf, false); err != nil { + t.Fatalf("Failed to write headers: %v", err) + } + h.t.writableChan <- 0 +} + // start starts server. Other goroutines should block on s.readyChan for further operations. 
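+// The ht argument selects the stream handler the server installs; the
+// invalidHeaderField mode makes the server reply with a header block whose
+// content-type is not a valid gRPC content-type (see
+// handleStreamInvalidHeaderField above and TestInvalidHeaderField below).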
func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { var err error @@ -177,7 +191,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { } s.conns[transport] = true s.mu.Unlock() - h := &testStreamHandler{transport} + h := &testStreamHandler{transport.(*http2Server)} switch ht { case suspended: go transport.HandleStreams(h.handleStreamSuspension) @@ -189,6 +203,10 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { go transport.HandleStreams(func(s *Stream) { go h.handleStreamEncodingRequiredStatus(t, s) }) + case invalidHeaderField: + go transport.HandleStreams(func(s *Stream) { + go h.handleStreamInvalidHeaderField(t, s) + }) default: go transport.HandleStreams(func(s *Stream) { go h.handleStream(t, s) @@ -414,7 +432,7 @@ func TestLargeMessageSuspension(t *testing.T) { } // Write should not be done successfully due to flow control. err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}) - expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) + expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) if err != expectedErr { t.Fatalf("Write got %v, want %v", err, expectedErr) } @@ -752,6 +770,32 @@ func TestEncodingRequiredStatus(t *testing.T) { server.stop() } +func TestInvalidHeaderField(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + return + } + opts := Options{ + Last: true, + Delay: false, + } + if err := ct.Write(s, expectedRequest, &opts); err != nil && err != io.EOF { + t.Fatalf("Failed to write the request: %v", err) + } + p := make([]byte, http2MaxFrameLen) + _, err = s.dec.Read(p) + if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) { + t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField) + } + ct.Close() + server.stop() +} + func TestStreamContext(t *testing.T) { expectedStream := &Stream{} ctx := newContextWithStream(context.Background(), expectedStream)