From 474679aec49b8320fac9633f194b9b92a196a034 Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Wed, 5 Oct 2016 15:51:45 -0700 Subject: [PATCH] grpclb: override credentials server name using the metadata in name resolution --- grpclb/grpclb.go | 63 ++++++++++++++++++++++++------ grpclb/grpclb_test.go | 91 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 130 insertions(+), 24 deletions(-) diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go index 932fdf06..27fd033a 100644 --- a/grpclb/grpclb.go +++ b/grpclb/grpclb.go @@ -48,6 +48,26 @@ import ( "google.golang.org/grpc/naming" ) +// AddressType indicates the address type returned by name resolution. +type AddressType uint8 + +const ( + // Backend indicates the server is a backend server. + Backend AddressType = iota + // GRPCLB indicates the server is a grpclb load balancer. + GRPCLB +) + +// Metadata contains the information the name resolution for grpclb should provide. The +// name resolver used by grpclb balancer is required to provide this type of metadata in +// its address updates. +type Metadata struct { + // AddrType is the type of server (grpc load balancer or backend). + AddrType AddressType + // ServerName is the name of the grpc load balancer. Used for authentication. + ServerName string +} + // Balancer creates a grpclb load balancer. func Balancer(r naming.Resolver) grpc.Balancer { return &balancer{ @@ -56,7 +76,8 @@ func Balancer(r naming.Resolver) grpc.Balancer { } type remoteBalancerInfo struct { - addr grpc.Address + addr string + // the server name used for authentication with the remote LB server. name string } @@ -95,16 +116,12 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo bAddr = b.rbs[0] } for _, update := range updates { - addr := grpc.Address{ - Addr: update.Addr, - Metadata: update.Metadata, - } switch update.Op { case naming.Add: var exist bool for _, v := range b.rbs { // TODO: Is the same addr with different server name a different balancer? - if addr == v.addr { + if update.Addr == v.addr { exist = true break } @@ -112,10 +129,29 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo if exist { continue } - b.rbs = append(b.rbs, remoteBalancerInfo{addr: addr}) + md, ok := update.Metadata.(*Metadata) + if !ok { + // TODO: Revisit the handling here and may introduce some fallback mechanism. + grpclog.Printf("The name resolution contains unexpected metadata %v", update.Metadata) + continue + } + switch md.AddrType { + case Backend: + // TODO: Revisit the handling here and may introduce some fallback mechanism. + grpclog.Printf("The name resolution does not give grpclb addresses") + continue + case GRPCLB: + b.rbs = append(b.rbs, remoteBalancerInfo{ + addr: update.Addr, + name: md.ServerName, + }) + default: + grpclog.Printf("Received unknow address type %d", md.AddrType) + continue + } case naming.Delete: for i, v := range b.rbs { - if addr == v.addr { + if update.Addr == v.addr { copy(b.rbs[i:], b.rbs[i+1:]) b.rbs = b.rbs[:len(b.rbs)-1] break @@ -267,16 +303,21 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error { // b is closing. return } - // Talk to the remote load balancer to get the server list. // // TODO: override the server name in creds using Metadata in addr. var err error creds := config.DialCreds if creds == nil { - cc, err = grpc.Dial(rb.addr.Addr, grpc.WithInsecure()) + cc, err = grpc.Dial(rb.addr, grpc.WithInsecure()) } else { - cc, err = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds)) + if rb.name != "" { + if err := creds.OverrideServerName(rb.name); err != nil { + grpclog.Printf("Failed to override the server name in the credentials: %v", err) + continue + } + } + cc, err = grpc.Dial(rb.addr, grpc.WithTransportCredentials(creds)) } if err != nil { grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err) diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index fabb9fec..1e7eee81 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -34,7 +34,9 @@ package grpclb import ( + "errors" "fmt" + "io" "net" "strconv" "strings" @@ -43,10 +45,16 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" "google.golang.org/grpc/naming" ) +var ( + lbsn = "bar.com" + besn = "foo.com" +) + type testWatcher struct { // the channel to receives name resolution updates update chan *naming.Update @@ -101,6 +109,10 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { r.w.update <- &naming.Update{ Op: naming.Add, Addr: r.addr, + Metadata: &Metadata{ + AddrType: GRPCLB, + ServerName: lbsn, + }, } go func() { <-r.w.readDone @@ -108,6 +120,45 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { return r.w, nil } +type serverNameCheckCreds struct { + t *testing.T + expected string + sn string +} + +func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + if _, err := io.WriteString(rawConn, c.sn); err != nil { + c.t.Errorf("Failed to write the server name %s to the client %v", c.sn, err) + return nil, nil, err + } + return rawConn, nil, nil +} +func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + b := make([]byte, len(c.expected)) + if _, err := rawConn.Read(b); err != nil { + c.t.Errorf("Failed to read the server name from the server %v", err) + return nil, nil, err + } + if c.expected != string(b) { + c.t.Errorf("Read the server name %s want %s", string(b), c.expected) + return nil, nil, errors.New("received unexpected server name") + } + return rawConn, nil, nil +} +func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials { + return &serverNameCheckCreds{ + t: c.t, + expected: c.expected, + } +} +func (c *serverNameCheckCreds) OverrideServerName(s string) error { + c.expected = s + return nil +} + type remoteBalancer struct { servers *lbpb.ServerList done chan struct{} @@ -123,6 +174,7 @@ func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer { func (b *remoteBalancer) stop() { close(b.done) } + func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error { resp := &lbpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ @@ -144,9 +196,13 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) return nil } -func startBackends(lis ...net.Listener) (servers []*grpc.Server) { +func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) { for _, l := range lis { - s := grpc.NewServer() + creds := &serverNameCheckCreds{ + t: t, + sn: sn, + } + s := grpc.NewServer(grpc.Creds(creds)) servers = append(servers, s) go func(s *grpc.Server, l net.Listener) { s.Serve(l) @@ -167,22 +223,27 @@ func TestGRPCLB(t *testing.T) { if err != nil { t.Fatalf("Failed to listen %v", err) } - backends := startBackends(beLis) + beAddr := strings.Split(beLis.Addr().String(), ":") + bePort, err := strconv.Atoi(beAddr[1]) + backends := startBackends(t, besn, beLis) defer stopBackends(backends) + // Start a load balancer. - lis, err := net.Listen("tcp", "localhost:0") + lbLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to create the listener for the load balancer %v", err) } - lb := grpc.NewServer() - addr := strings.Split(lis.Addr().String(), ":") - port, err := strconv.Atoi(addr[1]) + lbCreds := &serverNameCheckCreds{ + t: t, + sn: lbsn, + } + lb := grpc.NewServer(grpc.Creds(lbCreds)) if err != nil { t.Fatalf("Failed to generate the port number %v", err) } be := &lbpb.Server{ - IpAddress: []byte(addr[0]), - Port: int32(port), + IpAddress: []byte(beAddr[0]), + Port: int32(bePort), } var bes []*lbpb.Server bes = append(bes, be) @@ -192,15 +253,19 @@ func TestGRPCLB(t *testing.T) { ls := newRemoteBalancer(sl) lbpb.RegisterLoadBalancerServer(lb, ls) go func() { - lb.Serve(lis) + lb.Serve(lbLis) }() defer func() { ls.stop() lb.Stop() }() - cc, err := grpc.Dial("foo.bar.com", grpc.WithBalancer(Balancer(&testNameResolver{ - addr: lis.Addr().String(), - })), grpc.WithInsecure(), grpc.WithBlock()) + creds := serverNameCheckCreds{ + t: t, + expected: besn, + } + cc, err := grpc.Dial(besn, grpc.WithBalancer(Balancer(&testNameResolver{ + addr: lbLis.Addr().String(), + })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) }