diff --git a/clientconn.go b/clientconn.go index 11dce44f..c81e3892 100644 --- a/clientconn.go +++ b/clientconn.go @@ -684,7 +684,11 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { } ctx, cancel := context.WithTimeout(ac.ctx, timeout) connectTime := time.Now() - newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts) + sinfo := transport.TargetInfo{ + Addr: ac.addr.Addr, + Metadata: ac.addr.Metadata, + } + newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts) if err != nil { cancel() diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go index 27fd033a..78137453 100644 --- a/grpclb/grpclb.go +++ b/grpclb/grpclb.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc" lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/naming" ) @@ -184,9 +185,10 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { ) for _, s := range servers { // TODO: Support ExpirationInterval + md := metadata.Pairs("lb-token", s.LoadBalanceToken) addr := grpc.Address{ - Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port), - // TODO: include LoadBalanceToken in the Metadata + Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port), + Metadata: &md, } sl = append(sl, addrInfo{ addr: addr, diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index 8d12ebb4..658e7225 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -47,13 +47,16 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" + hwpb "google.golang.org/grpc/examples/helloworld/helloworld" lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/naming" ) var ( - lbsn = "bar.com" - besn = "foo.com" + lbsn = "bar.com" + besn = "foo.com" + lbToken = "iamatoken" ) type testWatcher struct { @@ -195,12 +198,29 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) return nil } +type helloServer struct { +} + +func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) { + md, ok := metadata.FromContext(ctx) + if !ok { + return nil, grpc.Errorf(codes.Internal, "failed to receive metadata") + } + if md == nil || md["lb-token"][0] != lbToken { + return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md) + } + return &hwpb.HelloReply{ + Message: "Hello " + in.Name, + }, nil +} + func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) { for _, l := range lis { creds := &serverNameCheckCreds{ sn: sn, } s := grpc.NewServer(grpc.Creds(creds)) + hwpb.RegisterGreeterServer(s, &helloServer{}) servers = append(servers, s) go func(s *grpc.Server, l net.Listener) { s.Serve(l) @@ -239,8 +259,9 @@ func TestGRPCLB(t *testing.T) { t.Fatalf("Failed to generate the port number %v", err) } be := &lbpb.Server{ - IpAddress: []byte(beAddr[0]), - Port: int32(bePort), + IpAddress: []byte(beAddr[0]), + Port: int32(bePort), + LoadBalanceToken: lbToken, } var bes []*lbpb.Server bes = append(bes, be) @@ -266,12 +287,9 @@ func TestGRPCLB(t *testing.T) { if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } - // Issue an unimplemented RPC and expect codes.Unimplemented. - var ( - req, reply lbpb.Duration - ) - if err := grpc.Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented { - t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented) + helloC := hwpb.NewGreeterClient(cc) + if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", helloC, err) } cc.Close() } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3c185541..8e7701ef 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -57,6 +57,7 @@ import ( type http2Client struct { target string // server name/addr userAgent string + md interface{} conn net.Conn // underlying communication channel authInfo credentials.AuthInfo // auth info about the connection nextID uint32 // the next stream ID to be used @@ -145,9 +146,9 @@ func isTemporary(err error) bool { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) { +func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) { scheme := "http" - conn, err := dial(opts.Dialer, ctx, addr) + conn, err := dial(opts.Dialer, ctx, addr.Addr) if err != nil { return nil, connectionErrorf(true, err, "transport: %v", err) } @@ -160,7 +161,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl var authInfo credentials.AuthInfo if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn) + conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) if err != nil { // Credentials handshake errors are typically considered permanent // to avoid retrying on e.g. bad certificates. @@ -174,8 +175,9 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } var buf bytes.Buffer t := &http2Client{ - target: addr, + target: addr.Addr, userAgent: ua, + md: addr.Metadata, conn: conn, authInfo: authInfo, // The client initiated stream id is odd starting from 1. @@ -400,6 +402,16 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } } } + if md, ok := t.md.(*metadata.MD); ok { + for k, v := range *md { + if isReservedHeader(k) { + continue + } + for _, entry := range v { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) + } + } + } first := true // Sends the headers in a single batch even when they span multiple frames. for !endHeaders { diff --git a/transport/transport.go b/transport/transport.go index 3d6b6a6d..c82b5f37 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -343,7 +343,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI return newHTTP2Server(conn, maxStreams, authInfo) } -// ConnectOptions covers all relevant options for dialing a server. +// ConnectOptions covers all relevant options for communicating with the server. type ConnectOptions struct { // UserAgent is the application user agent. UserAgent string @@ -355,9 +355,15 @@ type ConnectOptions struct { TransportCredentials credentials.TransportCredentials } +// TargetInfo contains the information of the target such as network address and metadata. +type TargetInfo struct { + Addr string + Metadata interface{} +} + // NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) { +func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) { return newHTTP2Client(ctx, target, opts) }