Merge pull request #930 from iamqizhao/master
grpclb: work with LoadBalanceToken
This commit is contained in:
@ -684,7 +684,11 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
|
||||
connectTime := time.Now()
|
||||
newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
|
||||
sinfo := transport.TargetInfo{
|
||||
Addr: ac.addr.Addr,
|
||||
Metadata: ac.addr.Metadata,
|
||||
}
|
||||
newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts)
|
||||
if err != nil {
|
||||
cancel()
|
||||
|
||||
|
@ -45,6 +45,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/naming"
|
||||
)
|
||||
|
||||
@ -184,9 +185,10 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
|
||||
)
|
||||
for _, s := range servers {
|
||||
// TODO: Support ExpirationInterval
|
||||
md := metadata.Pairs("lb-token", s.LoadBalanceToken)
|
||||
addr := grpc.Address{
|
||||
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
|
||||
// TODO: include LoadBalanceToken in the Metadata
|
||||
Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
|
||||
Metadata: &md,
|
||||
}
|
||||
sl = append(sl, addrInfo{
|
||||
addr: addr,
|
||||
|
@ -47,13 +47,16 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
hwpb "google.golang.org/grpc/examples/helloworld/helloworld"
|
||||
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/naming"
|
||||
)
|
||||
|
||||
var (
|
||||
lbsn = "bar.com"
|
||||
besn = "foo.com"
|
||||
lbsn = "bar.com"
|
||||
besn = "foo.com"
|
||||
lbToken = "iamatoken"
|
||||
)
|
||||
|
||||
type testWatcher struct {
|
||||
@ -195,12 +198,29 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
type helloServer struct {
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
|
||||
}
|
||||
if md == nil || md["lb-token"][0] != lbToken {
|
||||
return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
|
||||
}
|
||||
return &hwpb.HelloReply{
|
||||
Message: "Hello " + in.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
|
||||
for _, l := range lis {
|
||||
creds := &serverNameCheckCreds{
|
||||
sn: sn,
|
||||
}
|
||||
s := grpc.NewServer(grpc.Creds(creds))
|
||||
hwpb.RegisterGreeterServer(s, &helloServer{})
|
||||
servers = append(servers, s)
|
||||
go func(s *grpc.Server, l net.Listener) {
|
||||
s.Serve(l)
|
||||
@ -239,8 +259,9 @@ func TestGRPCLB(t *testing.T) {
|
||||
t.Fatalf("Failed to generate the port number %v", err)
|
||||
}
|
||||
be := &lbpb.Server{
|
||||
IpAddress: []byte(beAddr[0]),
|
||||
Port: int32(bePort),
|
||||
IpAddress: []byte(beAddr[0]),
|
||||
Port: int32(bePort),
|
||||
LoadBalanceToken: lbToken,
|
||||
}
|
||||
var bes []*lbpb.Server
|
||||
bes = append(bes, be)
|
||||
@ -266,12 +287,9 @@ func TestGRPCLB(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
}
|
||||
// Issue an unimplemented RPC and expect codes.Unimplemented.
|
||||
var (
|
||||
req, reply lbpb.Duration
|
||||
)
|
||||
if err := grpc.Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented {
|
||||
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
|
||||
helloC := hwpb.NewGreeterClient(cc)
|
||||
if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
|
||||
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
|
||||
}
|
||||
cc.Close()
|
||||
}
|
||||
|
@ -57,6 +57,7 @@ import (
|
||||
type http2Client struct {
|
||||
target string // server name/addr
|
||||
userAgent string
|
||||
md interface{}
|
||||
conn net.Conn // underlying communication channel
|
||||
authInfo credentials.AuthInfo // auth info about the connection
|
||||
nextID uint32 // the next stream ID to be used
|
||||
@ -145,9 +146,9 @@ func isTemporary(err error) bool {
|
||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||
// and starts to receive messages on it. Non-nil error returns if construction
|
||||
// fails.
|
||||
func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
|
||||
func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) {
|
||||
scheme := "http"
|
||||
conn, err := dial(ctx, opts.Dialer, addr)
|
||||
conn, err := dial(ctx, opts.Dialer, addr.Addr)
|
||||
if err != nil {
|
||||
return nil, connectionErrorf(true, err, "transport: %v", err)
|
||||
}
|
||||
@ -160,7 +161,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
||||
var authInfo credentials.AuthInfo
|
||||
if creds := opts.TransportCredentials; creds != nil {
|
||||
scheme = "https"
|
||||
conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
|
||||
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
|
||||
if err != nil {
|
||||
// Credentials handshake errors are typically considered permanent
|
||||
// to avoid retrying on e.g. bad certificates.
|
||||
@ -174,8 +175,9 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
t := &http2Client{
|
||||
target: addr,
|
||||
target: addr.Addr,
|
||||
userAgent: ua,
|
||||
md: addr.Metadata,
|
||||
conn: conn,
|
||||
authInfo: authInfo,
|
||||
// The client initiated stream id is odd starting from 1.
|
||||
@ -400,6 +402,16 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
}
|
||||
}
|
||||
}
|
||||
if md, ok := t.md.(*metadata.MD); ok {
|
||||
for k, v := range *md {
|
||||
if isReservedHeader(k) {
|
||||
continue
|
||||
}
|
||||
for _, entry := range v {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
|
||||
}
|
||||
}
|
||||
}
|
||||
first := true
|
||||
// Sends the headers in a single batch even when they span multiple frames.
|
||||
for !endHeaders {
|
||||
|
@ -361,7 +361,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI
|
||||
return newHTTP2Server(conn, maxStreams, authInfo)
|
||||
}
|
||||
|
||||
// ConnectOptions covers all relevant options for dialing a server.
|
||||
// ConnectOptions covers all relevant options for communicating with the server.
|
||||
type ConnectOptions struct {
|
||||
// UserAgent is the application user agent.
|
||||
UserAgent string
|
||||
@ -373,9 +373,15 @@ type ConnectOptions struct {
|
||||
TransportCredentials credentials.TransportCredentials
|
||||
}
|
||||
|
||||
// TargetInfo contains the information of the target such as network address and metadata.
|
||||
type TargetInfo struct {
|
||||
Addr string
|
||||
Metadata interface{}
|
||||
}
|
||||
|
||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||
// and returns it to the caller.
|
||||
func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
|
||||
func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) {
|
||||
return newHTTP2Client(ctx, target, opts)
|
||||
}
|
||||
|
||||
|
@ -245,7 +245,10 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
|
||||
ct ClientTransport
|
||||
connErr error
|
||||
)
|
||||
ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
|
||||
target := TargetInfo{
|
||||
Addr: addr,
|
||||
}
|
||||
ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{})
|
||||
if connErr != nil {
|
||||
t.Fatalf("failed to create transport: %v", connErr)
|
||||
}
|
||||
|
Reference in New Issue
Block a user