Credentials API and jwtAccess implementation tunning
This commit is contained in:
@ -64,7 +64,7 @@ type Credentials interface {
|
|||||||
// be used for timeout and cancellation.
|
// be used for timeout and cancellation.
|
||||||
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
|
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
|
||||||
// it as an arbitrary string.
|
// it as an arbitrary string.
|
||||||
GetRequestMetadata(ctx context.Context) (map[string]string, error)
|
GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error)
|
||||||
// RequireTransportSecurity indicates whether the credentails requires
|
// RequireTransportSecurity indicates whether the credentails requires
|
||||||
// transport security.
|
// transport security.
|
||||||
RequireTransportSecurity() bool
|
RequireTransportSecurity() bool
|
||||||
@ -140,7 +140,7 @@ func (c tlsCreds) Info() ProtocolInfo {
|
|||||||
|
|
||||||
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
// GetRequestMetadata returns nil, nil since TLS credentials does not have
|
||||||
// metadata.
|
// metadata.
|
||||||
func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
func (c *tlsCreds) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ type TokenSource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRequestMetadata gets the request metadata as a map from a TokenSource.
|
// GetRequestMetadata gets the request metadata as a map from a TokenSource.
|
||||||
func (ts TokenSource) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
func (ts TokenSource) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||||
token, err := ts.Token()
|
token, err := ts.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -66,27 +66,28 @@ func (ts TokenSource) RequireTransportSecurity() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type jwtAccess struct {
|
type jwtAccess struct {
|
||||||
ts oauth2.TokenSource
|
jsonKey []byte
|
||||||
|
//ts oauth2.TokenSource
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTAccessFromFile(keyFile string, audience string) (credentials.Credentials, error) {
|
func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) {
|
||||||
jsonKey, err := ioutil.ReadFile(keyFile)
|
jsonKey, err := ioutil.ReadFile(keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
|
||||||
}
|
}
|
||||||
return NewJWTAccessFromKey(jsonKey, audience)
|
return NewJWTAccessFromKey(jsonKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewJWTAccessFromKey(jsonKey []byte, audience string) (credentials.Credentials, error) {
|
func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) {
|
||||||
ts, err := google.JWTAccessTokenSourceFromJSON(jsonKey, audience)
|
return jwtAccess{ jsonKey }, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j jwtAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||||
|
ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, audience[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return jwtAccess{ts: ts}, nil
|
token, err := ts.Token()
|
||||||
}
|
|
||||||
|
|
||||||
func (j jwtAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
|
||||||
token, err := j.ts.Token()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -109,7 +110,7 @@ func NewOauthAccess(token *oauth2.Token) credentials.Credentials {
|
|||||||
return oauthAccess{token: *token}
|
return oauthAccess{token: *token}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oa oauthAccess) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
func (oa oauthAccess) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
"authorization": oa.token.TokenType + " " + oa.token.AccessToken,
|
"authorization": oa.token.TokenType + " " + oa.token.AccessToken,
|
||||||
}, nil
|
}, nil
|
||||||
@ -132,7 +133,7 @@ type serviceAccount struct {
|
|||||||
config *jwt.Config
|
config *jwt.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
|
func (s serviceAccount) GetRequestMetadata(ctx context.Context, audience ...string) (map[string]string, error) {
|
||||||
token, err := s.config.TokenSource(ctx).Token()
|
token, err := s.config.TokenSource(ctx).Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -514,7 +514,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds))
|
opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds))
|
||||||
} else if *testCase == "jwt_token_creds" {
|
} else if *testCase == "jwt_token_creds" {
|
||||||
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile, "https://"+*serverHost+":"+string(*serverPort)+"/"+"TestService")
|
jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Fatalf("Failed to create JWT credentials: %v", err)
|
grpclog.Fatalf("Failed to create JWT credentials: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -243,7 +244,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
authData := make(map[string]string)
|
authData := make(map[string]string)
|
||||||
for _, c := range t.authCreds {
|
for _, c := range t.authCreds {
|
||||||
data, err := c.GetRequestMetadata(ctx)
|
// Generate the audience string.
|
||||||
|
var port string
|
||||||
|
if pos := strings.LastIndex(t.target, ":"); pos != -1 {
|
||||||
|
port = ":" + t.target[pos+1:]
|
||||||
|
}
|
||||||
|
pos := strings.LastIndex(callHdr.Method, "/")
|
||||||
|
if pos == -1 {
|
||||||
|
return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method)
|
||||||
|
}
|
||||||
|
audience := "https://" + callHdr.Host + port + callHdr.Method[:pos]
|
||||||
|
data, err := c.GetRequestMetadata(ctx, audience)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
|
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user