Credentials API and jwtAccess implementation tunning
This commit is contained in:
@ -64,7 +64,7 @@ type Credentials interface {
|
||||
// be used for timeout and cancellation.
|
||||
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
|
||||
// it as an arbitrary string.
|
||||
GetRequestMetadata(ctx context.Context) (map[string]string, error)
|
||||
GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error)
|
||||
// RequireTransportSecurity indicates whether the credentails requires
|
||||
// transport security.
|
||||
RequireTransportSecurity() bool
|
||||
@ -140,7 +140,7 @@ func (c tlsCreds) Info() ProtocolInfo {
|
||||
|
||||
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
||||
// metadata.
|
||||
func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
||||
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -51,7 +51,7 @@ type TokenSource struct {
|
||||
}
|
||||
|
||||
// GetRequestMetadata gets the request metadata as a map from a TokenSource.
|
||||
func (ts TokenSource) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
||||
func (ts TokenSource) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||
token, err := ts.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -66,27 +66,28 @@ func (ts TokenSource) RequireTransportSecurity() bool {
|
||||
}
|
||||
|
||||
type jwtAccess struct {
|
||||
ts oauth2.TokenSource
|
||||
jsonKey []byte
|
||||
//ts oauth2.TokenSource
|
||||
}
|
||||
|
||||
func NewJWTAccessFromFile(keyFile string, audience string) (credentials.Credentials, error) {
|
||||
func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) {
|
||||
jsonKey, err := ioutil.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
||||
}
|
||||
return NewJWTAccessFromKey(jsonKey, audience)
|
||||
return NewJWTAccessFromKey(jsonKey)
|
||||
}
|
||||
|
||||
func NewJWTAccessFromKey(jsonKey []byte, audience string) (credentials.Credentials, error) {
|
||||
ts, err := google.JWTAccessTokenSourceFromJSON(jsonKey, audience)
|
||||
func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) {
|
||||
return jwtAccess{ jsonKey }, nil
|
||||
}
|
||||
|
||||
func (j jwtAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||
ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, audience[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jwtAccess{ts: ts}, nil
|
||||
}
|
||||
|
||||
func (j jwtAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
||||
token, err := j.ts.Token()
|
||||
token, err := ts.Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -109,7 +110,7 @@ func NewOauthAccess(token *oauth2.Token) credentials.Credentials {
|
||||
return oauthAccess{token: *token}
|
||||
}
|
||||
|
||||
func (oa oauthAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
||||
func (oa oauthAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
"authorization": oa.token.TokenType + " " + oa.token.AccessToken,
|
||||
}, nil
|
||||
@ -132,7 +133,7 @@ type serviceAccount struct {
|
||||
config *jwt.Config
|
||||
}
|
||||
|
||||
func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
||||
func (s serviceAccount) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||
token, err := s.config.TokenSource(ctx).Token()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -514,7 +514,7 @@ func main() {
|
||||
}
|
||||
opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds))
|
||||
} else if *testCase == "jwt_token_creds" {
|
||||
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile, "https://"+*serverHost+":"+string(*serverPort)+"/"+"TestService")
|
||||
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to create JWT credentials: %v", err)
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ import (
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -243,7 +244,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
}
|
||||
authData := make(map[string]string)
|
||||
for _, c := range t.authCreds {
|
||||
data, err := c.GetRequestMetadata(ctx)
|
||||
// Generate the audience string.
|
||||
var port string
|
||||
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||
port = ":" + t.target[pos+1:]
|
||||
}
|
||||
pos := strings.LastIndex(callHdr.Method, "/")
|
||||
if pos == -1 {
|
||||
return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
|
||||
}
|
||||
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
|
||||
data, err := c.GetRequestMetadata(ctx, audience)
|
||||
if err != nil {
|
||||
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user