Adding dial options for PerRPCCredentials (#1225)
* Adding dial options for PerRPCCredentials * Added tests for PerRPCCredentials * Post-review updates * post-review updates
This commit is contained in:
3
call.go
3
call.go
@ -219,6 +219,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
||||
if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
}
|
||||
if c.creds != nil {
|
||||
callHdr.Creds = c.creds
|
||||
}
|
||||
|
||||
gopts := BalancerGetOptions{
|
||||
BlockingWait: !c.failFast,
|
||||
|
11
rpc_util.go
11
rpc_util.go
@ -46,6 +46,7 @@ import (
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
@ -141,6 +142,7 @@ type callInfo struct {
|
||||
trailerMD metadata.MD
|
||||
peer *peer.Peer
|
||||
traceInfo traceInfo // in trace.go
|
||||
creds credentials.PerRPCCredentials
|
||||
}
|
||||
|
||||
var defaultCallInfo = callInfo{failFast: true}
|
||||
@ -207,6 +209,15 @@ func FailFast(failFast bool) CallOption {
|
||||
})
|
||||
}
|
||||
|
||||
// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
|
||||
// for a call.
|
||||
func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption {
|
||||
return beforeCall(func(c *callInfo) error {
|
||||
c.creds = creds
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// The format of the payload: compressed or not?
|
||||
type payloadFormat uint8
|
||||
|
||||
|
@ -132,6 +132,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
||||
if cc.dopts.cp != nil {
|
||||
callHdr.SendCompress = cc.dopts.cp.Type()
|
||||
}
|
||||
if c.creds != nil {
|
||||
callHdr.Creds = c.creds
|
||||
}
|
||||
var trInfo traceInfo
|
||||
if EnableTracing {
|
||||
trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
|
||||
|
@ -449,6 +449,7 @@ type test struct {
|
||||
serverInitialConnWindowSize int32
|
||||
clientInitialWindowSize int32
|
||||
clientInitialConnWindowSize int32
|
||||
perRPCCreds credentials.PerRPCCredentials
|
||||
|
||||
// srv and srvAddr are set once startServer is called.
|
||||
srv *grpc.Server
|
||||
@ -621,6 +622,9 @@ func (te *test) clientConn() *grpc.ClientConn {
|
||||
if te.clientInitialConnWindowSize > 0 {
|
||||
opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize))
|
||||
}
|
||||
if te.perRPCCreds != nil {
|
||||
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
|
||||
}
|
||||
var err error
|
||||
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
||||
if err != nil {
|
||||
@ -3984,3 +3988,120 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) {
|
||||
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// test authdata
|
||||
authdata = map[string]string{
|
||||
"test-key": "test-value",
|
||||
"test-key2-bin": string([]byte{1, 2, 3}),
|
||||
}
|
||||
)
|
||||
|
||||
type testPerRPCCredentials struct{}
|
||||
|
||||
func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return authdata, nil
|
||||
}
|
||||
|
||||
func (cr testPerRPCCredentials) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("didn't find authdata key %v in context", k)
|
||||
}
|
||||
if vgot[0] != vwant {
|
||||
return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func TestPerRPCCredentialsViaDialOptions(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerRPCCredentialsViaCallOptions(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.tapHandle = authHandle
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) {
|
||||
defer leakCheck(t)()
|
||||
for _, e := range listTestEnv() {
|
||||
testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
te.perRPCCreds = testPerRPCCredentials{}
|
||||
// When credentials are provided via both dial options and call options,
|
||||
// we apply both sets.
|
||||
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata in context")
|
||||
}
|
||||
for k, vwant := range authdata {
|
||||
vgot, ok := md[k]
|
||||
if !ok {
|
||||
return ctx, fmt.Errorf("couldn't find metadata for key %v", k)
|
||||
}
|
||||
if len(vgot) != 2 {
|
||||
return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot))
|
||||
}
|
||||
if vgot[0] != vwant || vgot[1] != vwant {
|
||||
return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant)
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil {
|
||||
t.Fatalf("Test failed. Reason: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +101,8 @@ type http2Client struct {
|
||||
// The scheme used: https if TLS is on, http otherwise.
|
||||
scheme string
|
||||
|
||||
isSecure bool
|
||||
|
||||
creds []credentials.PerRPCCredentials
|
||||
|
||||
// Boolean to keep track of reading activity on transport.
|
||||
@ -181,7 +183,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
|
||||
conn.Close()
|
||||
}
|
||||
}(conn)
|
||||
var authInfo credentials.AuthInfo
|
||||
var (
|
||||
isSecure bool
|
||||
authInfo credentials.AuthInfo
|
||||
)
|
||||
if creds := opts.TransportCredentials; creds != nil {
|
||||
scheme = "https"
|
||||
conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
|
||||
@ -191,6 +196,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
|
||||
temp := isTemporary(err)
|
||||
return nil, connectionErrorf(temp, err, "transport: %v", err)
|
||||
}
|
||||
isSecure = true
|
||||
}
|
||||
kp := opts.KeepaliveParams
|
||||
// Validate keepalive parameters.
|
||||
@ -230,6 +236,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
|
||||
scheme: scheme,
|
||||
state: reachable,
|
||||
activeStreams: make(map[uint32]*Stream),
|
||||
isSecure: isSecure,
|
||||
creds: opts.PerRPCCredentials,
|
||||
maxStreams: defaultMaxStreamsClient,
|
||||
streamsQuota: newQuotaPool(defaultMaxStreamsClient),
|
||||
@ -335,8 +342,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
pr.AuthInfo = t.authInfo
|
||||
}
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
authData := make(map[string]string)
|
||||
for _, c := range t.creds {
|
||||
var (
|
||||
authData = make(map[string]string)
|
||||
audience string
|
||||
)
|
||||
// Create an audience string only if needed.
|
||||
if len(t.creds) > 0 || callHdr.Creds != nil {
|
||||
// Construct URI required to get auth request metadata.
|
||||
var port string
|
||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||
@ -347,17 +358,39 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
}
|
||||
pos := strings.LastIndex(callHdr.Method, "/")
|
||||
if pos == -1 {
|
||||
return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
|
||||
pos = len(callHdr.Method)
|
||||
}
|
||||
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
|
||||
audience = "https://" + callHdr.Host + port + callHdr.Method[:pos]
|
||||
}
|
||||
for _, c := range t.creds {
|
||||
data, err := c.GetRequestMetadata(ctx, audience)
|
||||
if err != nil {
|
||||
return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err)
|
||||
return nil, streamErrorf(codes.Internal, "transport: %v", err)
|
||||
}
|
||||
for k, v := range data {
|
||||
// Capital header names are illegal in HTTP/2.
|
||||
k = strings.ToLower(k)
|
||||
authData[k] = v
|
||||
}
|
||||
}
|
||||
callAuthData := make(map[string]string)
|
||||
// Check if credentials.PerRPCCredentials were provided via call options.
|
||||
// Note: if these credentials are provided both via dial options and call
|
||||
// options, then both sets of credentials will be applied.
|
||||
if callCreds := callHdr.Creds; callCreds != nil {
|
||||
if !t.isSecure && callCreds.RequireTransportSecurity() {
|
||||
return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton")
|
||||
}
|
||||
data, err := callCreds.GetRequestMetadata(ctx, audience)
|
||||
if err != nil {
|
||||
return nil, streamErrorf(codes.Internal, "transport: %v", err)
|
||||
}
|
||||
for k, v := range data {
|
||||
// Capital header names are illegal in HTTP/2
|
||||
k = strings.ToLower(k)
|
||||
callAuthData[k] = v
|
||||
}
|
||||
}
|
||||
t.mu.Lock()
|
||||
if t.activeStreams == nil {
|
||||
t.mu.Unlock()
|
||||
@ -435,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
}
|
||||
|
||||
for k, v := range authData {
|
||||
// Capital header names are illegal in HTTP/2.
|
||||
k = strings.ToLower(k)
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
|
||||
}
|
||||
for k, v := range callAuthData {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
|
||||
}
|
||||
var (
|
||||
hasMD bool
|
||||
|
@ -469,6 +469,9 @@ type CallHdr struct {
|
||||
// outbound message.
|
||||
SendCompress string
|
||||
|
||||
// Creds specifies credentials.PerRPCCredentials for a call.
|
||||
Creds credentials.PerRPCCredentials
|
||||
|
||||
// Flush indicates whether a new stream command should be sent
|
||||
// to the peer without waiting for the first data. This is
|
||||
// only a hint. The transport may modify the flush decision
|
||||
|
Reference in New Issue
Block a user