diff --git a/credentials/credentials.go b/credentials/credentials.go index 492517fc..10e56f15 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -64,7 +64,7 @@ type Credentials interface { // be used for timeout and cancellation. // TODO(zhaoq): Define the set of the qualified keys instead of leaving // it as an arbitrary string. - GetRequestMetadata(ctx context.Context) (map[string]string, error) + GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) // RequireTransportSecurity indicates whether the credentails requires // transport security. RequireTransportSecurity() bool @@ -140,7 +140,7 @@ func (c tlsCreds) Info() ProtocolInfo { // GetRequestMetadata returns nil, nil since TLS credentials does not have // metadata. -func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) { +func (c *tlsCreds) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) { return nil, nil } diff --git a/credentials/oauth/oauth.go b/credentials/oauth/oauth.go index 43ebeeaf..7b402b84 100644 --- a/credentials/oauth/oauth.go +++ b/credentials/oauth/oauth.go @@ -51,7 +51,7 @@ type TokenSource struct { } // GetRequestMetadata gets the request metadata as a map from a TokenSource. -func (ts TokenSource) GetRequestMetadata(ctx context.Context) (map[string]string, error) { +func (ts TokenSource) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) { token, err := ts.Token() if err != nil { return nil, err @@ -66,27 +66,28 @@ func (ts TokenSource) RequireTransportSecurity() bool { } type jwtAccess struct { - ts oauth2.TokenSource + jsonKey []byte + //ts oauth2.TokenSource } -func NewJWTAccessFromFile(keyFile string, audience string) (credentials.Credentials, error) { +func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { jsonKey, err := ioutil.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) } - return NewJWTAccessFromKey(jsonKey, audience) + return NewJWTAccessFromKey(jsonKey) } -func NewJWTAccessFromKey(jsonKey []byte, audience string) (credentials.Credentials, error) { - ts, err := google.JWTAccessTokenSourceFromJSON(jsonKey, audience) +func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) { + return jwtAccess{ jsonKey }, nil +} + +func (j jwtAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) { + ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, audience[0]) if err != nil { return nil, err } - return jwtAccess{ts: ts}, nil -} - -func (j jwtAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) { - token, err := j.ts.Token() + token, err := ts.Token() if err != nil { return nil, err } @@ -109,7 +110,7 @@ func NewOauthAccess(token *oauth2.Token) credentials.Credentials { return oauthAccess{token: *token} } -func (oa oauthAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) { +func (oa oauthAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) { return map[string]string{ "authorization": oa.token.TokenType + " " + oa.token.AccessToken, }, nil @@ -132,7 +133,7 @@ type serviceAccount struct { config *jwt.Config } -func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) { +func (s serviceAccount) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) { token, err := s.config.TokenSource(ctx).Token() if err != nil { return nil, err diff --git a/interop/client/client.go b/interop/client/client.go index c6133780..4f715d35 100644 --- a/interop/client/client.go +++ b/interop/client/client.go @@ -514,7 +514,7 @@ func main() { } opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds)) } else if *testCase == "jwt_token_creds" { - jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile, "https://"+*serverHost+":"+string(*serverPort)+"/"+"TestService") + jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile) if err != nil { grpclog.Fatalf("Failed to create JWT credentials: %v", err) } diff --git a/transport/http2_client.go b/transport/http2_client.go index 715e2dbc..67eb8334 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -39,6 +39,7 @@ import ( "io" "math" "net" + "strings" "sync" "time" @@ -243,7 +244,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } authData := make(map[string]string) for _, c := range t.authCreds { - data, err := c.GetRequestMetadata(ctx) + // Generate the audience string. + var port string + if pos := strings.LastIndex(t.target, ":"); pos != -1 { + port = ":" + t.target[pos+1:] + } + pos := strings.LastIndex(callHdr.Method, "/") + if pos == -1 { + return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) + } + audience := "https://" + callHdr.Host + port + callHdr.Method[:pos] + data, err := c.GetRequestMetadata(ctx, audience) if err != nil { return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err) }