Credentials API and jwtAccess implementation tunning

This commit is contained in:
iamqizhao
2015-08-28 16:51:45 -07:00
parent 3af5617830
commit 6be470f058
4 changed files with 29 additions and 17 deletions

View File

@ -64,7 +64,7 @@ type Credentials interface {
// be used for timeout and cancellation.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
GetRequestMetadata(ctx context.Context) (map[string]string, error)
GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error)
// RequireTransportSecurity indicates whether the credentails requires
// transport security.
RequireTransportSecurity() bool
@ -140,7 +140,7 @@ func (c tlsCreds) Info() ProtocolInfo {
// GetRequestMetadata returns nil, nil since TLS credentials does not have
// metadata.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
return nil, nil
}

View File

@ -51,7 +51,7 @@ type TokenSource struct {
}
// GetRequestMetadata gets the request metadata as a map from a TokenSource.
func (ts TokenSource) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
func (ts TokenSource) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
token, err := ts.Token()
if err != nil {
return nil, err
@ -66,27 +66,28 @@ func (ts TokenSource) RequireTransportSecurity() bool {
}
type jwtAccess struct {
ts oauth2.TokenSource
jsonKey []byte
//ts oauth2.TokenSource
}
func NewJWTAccessFromFile(keyFile string, audience string) (credentials.Credentials, error) {
func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) {
jsonKey, err := ioutil.ReadFile(keyFile)
if err != nil {
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
}
return NewJWTAccessFromKey(jsonKey, audience)
return NewJWTAccessFromKey(jsonKey)
}
func NewJWTAccessFromKey(jsonKey []byte, audience string) (credentials.Credentials, error) {
ts, err := google.JWTAccessTokenSourceFromJSON(jsonKey, audience)
func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) {
return jwtAccess{ jsonKey }, nil
}
func (j jwtAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, audience[0])
if err != nil {
return nil, err
}
return jwtAccess{ts: ts}, nil
}
func (j jwtAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
token, err := j.ts.Token()
token, err := ts.Token()
if err != nil {
return nil, err
}
@ -109,7 +110,7 @@ func NewOauthAccess(token *oauth2.Token) credentials.Credentials {
return oauthAccess{token: *token}
}
func (oa oauthAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
func (oa oauthAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
return map[string]string{
"authorization": oa.token.TokenType + " " + oa.token.AccessToken,
}, nil
@ -132,7 +133,7 @@ type serviceAccount struct {
config *jwt.Config
}
func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
func (s serviceAccount) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
token, err := s.config.TokenSource(ctx).Token()
if err != nil {
return nil, err

View File

@ -514,7 +514,7 @@ func main() {
}
opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds))
} else if *testCase == "jwt_token_creds" {
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile, "https://"+*serverHost+":"+string(*serverPort)+"/"+"TestService")
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile)
if err != nil {
grpclog.Fatalf("Failed to create JWT credentials: %v", err)
}

View File

@ -39,6 +39,7 @@ import (
"io"
"math"
"net"
"strings"
"sync"
"time"
@ -243,7 +244,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
authData := make(map[string]string)
for _, c := range t.authCreds {
data, err := c.GetRequestMetadata(ctx)
// Generate the audience string.
var port string
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
port = ":" + t.target[pos+1:]
}
pos := strings.LastIndex(callHdr.Method, "/")
if pos == -1 {
return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
}
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
data, err := c.GetRequestMetadata(ctx, audience)
if err != nil {
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
}