grpclb: override credentials server name using the metadata in name resolution
This commit is contained in:
@ -48,6 +48,26 @@ import (
|
|||||||
"google.golang.org/grpc/naming"
|
"google.golang.org/grpc/naming"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AddressType indicates the address type returned by name resolution.
|
||||||
|
type AddressType uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Backend indicates the server is a backend server.
|
||||||
|
Backend AddressType = iota
|
||||||
|
// GRPCLB indicates the server is a grpclb load balancer.
|
||||||
|
GRPCLB
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metadata contains the information the name resolution for grpclb should provide. The
|
||||||
|
// name resolver used by grpclb balancer is required to provide this type of metadata in
|
||||||
|
// its address updates.
|
||||||
|
type Metadata struct {
|
||||||
|
// AddrType is the type of server (grpc load balancer or backend).
|
||||||
|
AddrType AddressType
|
||||||
|
// ServerName is the name of the grpc load balancer. Used for authentication.
|
||||||
|
ServerName string
|
||||||
|
}
|
||||||
|
|
||||||
// Balancer creates a grpclb load balancer.
|
// Balancer creates a grpclb load balancer.
|
||||||
func Balancer(r naming.Resolver) grpc.Balancer {
|
func Balancer(r naming.Resolver) grpc.Balancer {
|
||||||
return &balancer{
|
return &balancer{
|
||||||
@ -56,7 +76,8 @@ func Balancer(r naming.Resolver) grpc.Balancer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type remoteBalancerInfo struct {
|
type remoteBalancerInfo struct {
|
||||||
addr grpc.Address
|
addr string
|
||||||
|
// the server name used for authentication with the remote LB server.
|
||||||
name string
|
name string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,16 +116,12 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
|
|||||||
bAddr = b.rbs[0]
|
bAddr = b.rbs[0]
|
||||||
}
|
}
|
||||||
for _, update := range updates {
|
for _, update := range updates {
|
||||||
addr := grpc.Address{
|
|
||||||
Addr: update.Addr,
|
|
||||||
Metadata: update.Metadata,
|
|
||||||
}
|
|
||||||
switch update.Op {
|
switch update.Op {
|
||||||
case naming.Add:
|
case naming.Add:
|
||||||
var exist bool
|
var exist bool
|
||||||
for _, v := range b.rbs {
|
for _, v := range b.rbs {
|
||||||
// TODO: Is the same addr with different server name a different balancer?
|
// TODO: Is the same addr with different server name a different balancer?
|
||||||
if addr == v.addr {
|
if update.Addr == v.addr {
|
||||||
exist = true
|
exist = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -112,10 +129,29 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
|
|||||||
if exist {
|
if exist {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
b.rbs = append(b.rbs, remoteBalancerInfo{addr: addr})
|
md, ok := update.Metadata.(*Metadata)
|
||||||
|
if !ok {
|
||||||
|
// TODO: Revisit the handling here and may introduce some fallback mechanism.
|
||||||
|
grpclog.Printf("The name resolution contains unexpected metadata %v", update.Metadata)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch md.AddrType {
|
||||||
|
case Backend:
|
||||||
|
// TODO: Revisit the handling here and may introduce some fallback mechanism.
|
||||||
|
grpclog.Printf("The name resolution does not give grpclb addresses")
|
||||||
|
continue
|
||||||
|
case GRPCLB:
|
||||||
|
b.rbs = append(b.rbs, remoteBalancerInfo{
|
||||||
|
addr: update.Addr,
|
||||||
|
name: md.ServerName,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
grpclog.Printf("Received unknow address type %d", md.AddrType)
|
||||||
|
continue
|
||||||
|
}
|
||||||
case naming.Delete:
|
case naming.Delete:
|
||||||
for i, v := range b.rbs {
|
for i, v := range b.rbs {
|
||||||
if addr == v.addr {
|
if update.Addr == v.addr {
|
||||||
copy(b.rbs[i:], b.rbs[i+1:])
|
copy(b.rbs[i:], b.rbs[i+1:])
|
||||||
b.rbs = b.rbs[:len(b.rbs)-1]
|
b.rbs = b.rbs[:len(b.rbs)-1]
|
||||||
break
|
break
|
||||||
@ -267,16 +303,21 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
|
|||||||
// b is closing.
|
// b is closing.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Talk to the remote load balancer to get the server list.
|
// Talk to the remote load balancer to get the server list.
|
||||||
//
|
//
|
||||||
// TODO: override the server name in creds using Metadata in addr.
|
// TODO: override the server name in creds using Metadata in addr.
|
||||||
var err error
|
var err error
|
||||||
creds := config.DialCreds
|
creds := config.DialCreds
|
||||||
if creds == nil {
|
if creds == nil {
|
||||||
cc, err = grpc.Dial(rb.addr.Addr, grpc.WithInsecure())
|
cc, err = grpc.Dial(rb.addr, grpc.WithInsecure())
|
||||||
} else {
|
} else {
|
||||||
cc, err = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds))
|
if rb.name != "" {
|
||||||
|
if err := creds.OverrideServerName(rb.name); err != nil {
|
||||||
|
grpclog.Printf("Failed to override the server name in the credentials: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cc, err = grpc.Dial(rb.addr, grpc.WithTransportCredentials(creds))
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
|
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
|
||||||
|
@ -34,7 +34,9 @@
|
|||||||
package grpclb
|
package grpclb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -43,10 +45,16 @@ import (
|
|||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
|
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
|
||||||
"google.golang.org/grpc/naming"
|
"google.golang.org/grpc/naming"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
lbsn = "bar.com"
|
||||||
|
besn = "foo.com"
|
||||||
|
)
|
||||||
|
|
||||||
type testWatcher struct {
|
type testWatcher struct {
|
||||||
// the channel to receives name resolution updates
|
// the channel to receives name resolution updates
|
||||||
update chan *naming.Update
|
update chan *naming.Update
|
||||||
@ -101,6 +109,10 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
|
|||||||
r.w.update <- &naming.Update{
|
r.w.update <- &naming.Update{
|
||||||
Op: naming.Add,
|
Op: naming.Add,
|
||||||
Addr: r.addr,
|
Addr: r.addr,
|
||||||
|
Metadata: &Metadata{
|
||||||
|
AddrType: GRPCLB,
|
||||||
|
ServerName: lbsn,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
<-r.w.readDone
|
<-r.w.readDone
|
||||||
@ -108,6 +120,45 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
|
|||||||
return r.w, nil
|
return r.w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type serverNameCheckCreds struct {
|
||||||
|
t *testing.T
|
||||||
|
expected string
|
||||||
|
sn string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
if _, err := io.WriteString(rawConn, c.sn); err != nil {
|
||||||
|
c.t.Errorf("Failed to write the server name %s to the client %v", c.sn, err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return rawConn, nil, nil
|
||||||
|
}
|
||||||
|
func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
b := make([]byte, len(c.expected))
|
||||||
|
if _, err := rawConn.Read(b); err != nil {
|
||||||
|
c.t.Errorf("Failed to read the server name from the server %v", err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if c.expected != string(b) {
|
||||||
|
c.t.Errorf("Read the server name %s want %s", string(b), c.expected)
|
||||||
|
return nil, nil, errors.New("received unexpected server name")
|
||||||
|
}
|
||||||
|
return rawConn, nil, nil
|
||||||
|
}
|
||||||
|
func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
|
||||||
|
return credentials.ProtocolInfo{}
|
||||||
|
}
|
||||||
|
func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
|
||||||
|
return &serverNameCheckCreds{
|
||||||
|
t: c.t,
|
||||||
|
expected: c.expected,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (c *serverNameCheckCreds) OverrideServerName(s string) error {
|
||||||
|
c.expected = s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type remoteBalancer struct {
|
type remoteBalancer struct {
|
||||||
servers *lbpb.ServerList
|
servers *lbpb.ServerList
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
@ -123,6 +174,7 @@ func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
|
|||||||
func (b *remoteBalancer) stop() {
|
func (b *remoteBalancer) stop() {
|
||||||
close(b.done)
|
close(b.done)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error {
|
func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer) error {
|
||||||
resp := &lbpb.LoadBalanceResponse{
|
resp := &lbpb.LoadBalanceResponse{
|
||||||
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
|
LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
|
||||||
@ -144,9 +196,13 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startBackends(lis ...net.Listener) (servers []*grpc.Server) {
|
func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
|
||||||
for _, l := range lis {
|
for _, l := range lis {
|
||||||
s := grpc.NewServer()
|
creds := &serverNameCheckCreds{
|
||||||
|
t: t,
|
||||||
|
sn: sn,
|
||||||
|
}
|
||||||
|
s := grpc.NewServer(grpc.Creds(creds))
|
||||||
servers = append(servers, s)
|
servers = append(servers, s)
|
||||||
go func(s *grpc.Server, l net.Listener) {
|
go func(s *grpc.Server, l net.Listener) {
|
||||||
s.Serve(l)
|
s.Serve(l)
|
||||||
@ -167,22 +223,27 @@ func TestGRPCLB(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to listen %v", err)
|
t.Fatalf("Failed to listen %v", err)
|
||||||
}
|
}
|
||||||
backends := startBackends(beLis)
|
beAddr := strings.Split(beLis.Addr().String(), ":")
|
||||||
|
bePort, err := strconv.Atoi(beAddr[1])
|
||||||
|
backends := startBackends(t, besn, beLis)
|
||||||
defer stopBackends(backends)
|
defer stopBackends(backends)
|
||||||
|
|
||||||
// Start a load balancer.
|
// Start a load balancer.
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
lbLis, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
||||||
}
|
}
|
||||||
lb := grpc.NewServer()
|
lbCreds := &serverNameCheckCreds{
|
||||||
addr := strings.Split(lis.Addr().String(), ":")
|
t: t,
|
||||||
port, err := strconv.Atoi(addr[1])
|
sn: lbsn,
|
||||||
|
}
|
||||||
|
lb := grpc.NewServer(grpc.Creds(lbCreds))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to generate the port number %v", err)
|
t.Fatalf("Failed to generate the port number %v", err)
|
||||||
}
|
}
|
||||||
be := &lbpb.Server{
|
be := &lbpb.Server{
|
||||||
IpAddress: []byte(addr[0]),
|
IpAddress: []byte(beAddr[0]),
|
||||||
Port: int32(port),
|
Port: int32(bePort),
|
||||||
}
|
}
|
||||||
var bes []*lbpb.Server
|
var bes []*lbpb.Server
|
||||||
bes = append(bes, be)
|
bes = append(bes, be)
|
||||||
@ -192,15 +253,19 @@ func TestGRPCLB(t *testing.T) {
|
|||||||
ls := newRemoteBalancer(sl)
|
ls := newRemoteBalancer(sl)
|
||||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||||
go func() {
|
go func() {
|
||||||
lb.Serve(lis)
|
lb.Serve(lbLis)
|
||||||
}()
|
}()
|
||||||
defer func() {
|
defer func() {
|
||||||
ls.stop()
|
ls.stop()
|
||||||
lb.Stop()
|
lb.Stop()
|
||||||
}()
|
}()
|
||||||
cc, err := grpc.Dial("foo.bar.com", grpc.WithBalancer(Balancer(&testNameResolver{
|
creds := serverNameCheckCreds{
|
||||||
addr: lis.Addr().String(),
|
t: t,
|
||||||
})), grpc.WithInsecure(), grpc.WithBlock())
|
expected: besn,
|
||||||
|
}
|
||||||
|
cc, err := grpc.Dial(besn, grpc.WithBalancer(Balancer(&testNameResolver{
|
||||||
|
addr: lbLis.Addr().String(),
|
||||||
|
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to dial to the backend %v", err)
|
t.Fatalf("Failed to dial to the backend %v", err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user