From 88a73d35c975fb7bf79071dbcdb84472e7e6669c Mon Sep 17 00:00:00 2001 From: MakMukhi Date: Thu, 11 May 2017 11:07:38 -0700 Subject: [PATCH] Adding dial options for PerRPCCredentials (#1225) * Adding dial options for PerRPCCredentials * Added tests for PerRPCCredentials * Post-review updates * post-review updates --- call.go | 3 + rpc_util.go | 11 ++++ stream.go | 3 + test/end2end_test.go | 121 ++++++++++++++++++++++++++++++++++++++ transport/http2_client.go | 52 +++++++++++++--- transport/transport.go | 3 + 6 files changed, 184 insertions(+), 9 deletions(-) diff --git a/call.go b/call.go index 0eb5f5cf..08438652 100644 --- a/call.go +++ b/call.go @@ -219,6 +219,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } + if c.creds != nil { + callHdr.Creds = c.creds + } gopts := BalancerGetOptions{ BlockingWait: !c.failFast, diff --git a/rpc_util.go b/rpc_util.go index 31a87325..6a32afdf 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -46,6 +46,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -141,6 +142,7 @@ type callInfo struct { trailerMD metadata.MD peer *peer.Peer traceInfo traceInfo // in trace.go + creds credentials.PerRPCCredentials } var defaultCallInfo = callInfo{failFast: true} @@ -207,6 +209,15 @@ func FailFast(failFast bool) CallOption { }) } +// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials +// for a call. +func PerRPCCredentials(creds credentials.PerRPCCredentials) CallOption { + return beforeCall(func(c *callInfo) error { + c.creds = creds + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 diff --git a/stream.go b/stream.go index 25764330..ec534a01 100644 --- a/stream.go +++ b/stream.go @@ -132,6 +132,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() } + if c.creds != nil { + callHdr.Creds = c.creds + } var trInfo traceInfo if EnableTracing { trInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) diff --git a/test/end2end_test.go b/test/end2end_test.go index 01b3e4f7..ced25096 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -449,6 +449,7 @@ type test struct { serverInitialConnWindowSize int32 clientInitialWindowSize int32 clientInitialConnWindowSize int32 + perRPCCreds credentials.PerRPCCredentials // srv and srvAddr are set once startServer is called. srv *grpc.Server @@ -621,6 +622,9 @@ func (te *test) clientConn() *grpc.ClientConn { if te.clientInitialConnWindowSize > 0 { opts = append(opts, grpc.WithInitialConnWindowSize(te.clientInitialConnWindowSize)) } + if te.perRPCCreds != nil { + opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds)) + } var err error te.cc, err = grpc.Dial(te.srvAddr, opts...) if err != nil { @@ -3984,3 +3988,120 @@ func testConfigurableWindowSize(t *testing.T, e env, wc windowSizeConfig) { t.Fatalf("%v.CloseSend() = %v, want ", stream, err) } } + +var ( + // test authdata + authdata = map[string]string{ + "test-key": "test-value", + "test-key2-bin": string([]byte{1, 2, 3}), + } +) + +type testPerRPCCredentials struct{} + +func (cr testPerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return authdata, nil +} + +func (cr testPerRPCCredentials) RequireTransportSecurity() bool { + return false +} + +func authHandle(ctx context.Context, info *tap.Info) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, fmt.Errorf("didn't find metadata in context") + } + for k, vwant := range authdata { + vgot, ok := md[k] + if !ok { + return ctx, fmt.Errorf("didn't find authdata key %v in context", k) + } + if vgot[0] != vwant { + return ctx, fmt.Errorf("for key %v, got value %v, want %v", k, vgot, vwant) + } + } + return ctx, nil +} + +func TestPerRPCCredentialsViaDialOptions(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testPerRPCCredentialsViaDialOptions(t, e) + } +} + +func testPerRPCCredentialsViaDialOptions(t *testing.T, e env) { + te := newTest(t, e) + te.tapHandle = authHandle + te.perRPCCreds = testPerRPCCredentials{} + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} + +func TestPerRPCCredentialsViaCallOptions(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testPerRPCCredentialsViaCallOptions(t, e) + } +} + +func testPerRPCCredentialsViaCallOptions(t *testing.T, e env) { + te := newTest(t, e) + te.tapHandle = authHandle + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} + +func TestPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testPerRPCCredentialsViaDialOptionsAndCallOptions(t, e) + } +} + +func testPerRPCCredentialsViaDialOptionsAndCallOptions(t *testing.T, e env) { + te := newTest(t, e) + te.perRPCCreds = testPerRPCCredentials{} + // When credentials are provided via both dial options and call options, + // we apply both sets. + te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, fmt.Errorf("couldn't find metadata in context") + } + for k, vwant := range authdata { + vgot, ok := md[k] + if !ok { + return ctx, fmt.Errorf("couldn't find metadata for key %v", k) + } + if len(vgot) != 2 { + return ctx, fmt.Errorf("len of value for key %v was %v, want 2", k, len(vgot)) + } + if vgot[0] != vwant || vgot[1] != vwant { + return ctx, fmt.Errorf("value for %v was %v, want [%v, %v]", k, vgot, vwant, vwant) + } + } + return ctx, nil + } + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.PerRPCCredentials(testPerRPCCredentials{})); err != nil { + t.Fatalf("Test failed. Reason: %v", err) + } +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 3e5ff731..7db73d38 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -101,6 +101,8 @@ type http2Client struct { // The scheme used: https if TLS is on, http otherwise. scheme string + isSecure bool + creds []credentials.PerRPCCredentials // Boolean to keep track of reading activity on transport. @@ -181,7 +183,10 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( conn.Close() } }(conn) - var authInfo credentials.AuthInfo + var ( + isSecure bool + authInfo credentials.AuthInfo + ) if creds := opts.TransportCredentials; creds != nil { scheme = "https" conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn) @@ -191,6 +196,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( temp := isTemporary(err) return nil, connectionErrorf(temp, err, "transport: %v", err) } + isSecure = true } kp := opts.KeepaliveParams // Validate keepalive parameters. @@ -230,6 +236,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( scheme: scheme, state: reachable, activeStreams: make(map[uint32]*Stream), + isSecure: isSecure, creds: opts.PerRPCCredentials, maxStreams: defaultMaxStreamsClient, streamsQuota: newQuotaPool(defaultMaxStreamsClient), @@ -335,8 +342,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea pr.AuthInfo = t.authInfo } ctx = peer.NewContext(ctx, pr) - authData := make(map[string]string) - for _, c := range t.creds { + var ( + authData = make(map[string]string) + audience string + ) + // Create an audience string only if needed. + if len(t.creds) > 0 || callHdr.Creds != nil { // Construct URI required to get auth request metadata. var port string if pos := strings.LastIndex(t.target, ":"); pos != -1 { @@ -347,17 +358,39 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } pos := strings.LastIndex(callHdr.Method, "/") if pos == -1 { - return nil, streamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) + pos = len(callHdr.Method) } - audience := "https://" + callHdr.Host + port + callHdr.Method[:pos] + audience = "https://" + callHdr.Host + port + callHdr.Method[:pos] + } + for _, c := range t.creds { data, err := c.GetRequestMetadata(ctx, audience) if err != nil { - return nil, streamErrorf(codes.InvalidArgument, "transport: %v", err) + return nil, streamErrorf(codes.Internal, "transport: %v", err) } for k, v := range data { + // Capital header names are illegal in HTTP/2. + k = strings.ToLower(k) authData[k] = v } } + callAuthData := make(map[string]string) + // Check if credentials.PerRPCCredentials were provided via call options. + // Note: if these credentials are provided both via dial options and call + // options, then both sets of credentials will be applied. + if callCreds := callHdr.Creds; callCreds != nil { + if !t.isSecure && callCreds.RequireTransportSecurity() { + return nil, streamErrorf(codes.Unauthenticated, "transport: cannot send secure credentials on an insecure conneciton") + } + data, err := callCreds.GetRequestMetadata(ctx, audience) + if err != nil { + return nil, streamErrorf(codes.Internal, "transport: %v", err) + } + for k, v := range data { + // Capital header names are illegal in HTTP/2 + k = strings.ToLower(k) + callAuthData[k] = v + } + } t.mu.Lock() if t.activeStreams == nil { t.mu.Unlock() @@ -435,9 +468,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } for k, v := range authData { - // Capital header names are illegal in HTTP/2. - k = strings.ToLower(k) - t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + for k, v := range callAuthData { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } var ( hasMD bool diff --git a/transport/transport.go b/transport/transport.go index 4bd4dc44..2dff7c8b 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -469,6 +469,9 @@ type CallHdr struct { // outbound message. SendCompress string + // Creds specifies credentials.PerRPCCredentials for a call. + Creds credentials.PerRPCCredentials + // Flush indicates whether a new stream command should be sent // to the peer without waiting for the first data. This is // only a hint. The transport may modify the flush decision