Migrated to the latest google.org/x/oauth2 package and added support for JWT.

This commit is contained in:
iamqizhao
2015-02-18 12:02:43 -08:00
parent 78d3bc72bf
commit 6148d0a55d
3 changed files with 79 additions and 91 deletions

View File

@ -60,11 +60,9 @@ func WithClientTLS(creds credentials.TransportAuthenticator) DialOption {
}
}
// WithComputeEngine returns a DialOption which sets
// credentials which use application default credentials as provided to
// Google Compute Engine. Note that TLS credentials is typically also
// needed. If it is the case, users need to pass WithTLS option too.
func WithComputeEngine(creds credentials.Credentials) DialOption {
// WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
return func(o *dialOptions) {
o.authOptions = append(o.authOptions, creds)
}

View File

@ -31,25 +31,23 @@
*
*/
// Package credentials implements various credentials supported by gRPC library.
// Package credentials implements various credentials supported by gRPC library,
// which encapsulate all the state needed by a client to authenticate with a
// server and make various assertions, e.g., about the client's identity, role,
// or whether it is authorized to make a particular call.
package credentials // import "google.golang.org/grpc/credentials"
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
metadataServer = "metadata"
serviceAccountPath = "/computeMetadata/v1/instance/service-accounts/default/token"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt"
)
var (
@ -63,8 +61,11 @@ type Credentials interface {
// GetRequestMetadata gets the current request metadata, refreshing
// tokens if required. This should be called by the transport layer on
// each request, and the data should be populated in headers or other
// context. The operation may do things like refresh tokens.
GetRequestMetadata() (map[string]string, error)
// context. When supported by the underlying implementation, ctx can
// be used for timeout and cancellation.
// TODO(zhaoq): Define the set of the qualified keys instead of leaving
// it as an arbitrary string.
GetRequestMetadata(ctx context.Context) (map[string]string, error)
}
// TransportAuthenticator defines the common interface all supported transport
@ -98,7 +99,9 @@ type tlsCreds struct {
// GetRequestMetadata returns nil, nil since TLS credentials does not have
// metadata.
func (c *tlsCreds) GetRequestMetadata() (map[string]string, error) {
// TODO(zhaoq): Define the set of the qualified keys instead of leaving it as an
// arbitrary string.
func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
return nil, nil
}
@ -108,7 +111,7 @@ func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) {
if name == "" {
name, _, err = net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("failed to parse server address %v", err)
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
}
}
return tls.Dial("tcp", addr, &tls.Config{
@ -143,7 +146,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator,
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(b) {
return nil, fmt.Errorf("failed to append certificates")
return nil, fmt.Errorf("credentials: failed to append certificates")
}
return &tlsCreds{
serverName: serverName,
@ -170,86 +173,68 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err
}, nil
}
type tokenData struct {
accessToken string
expiresIn float64
tokeType string
}
type token struct {
accessToken string
expiry time.Time
}
// expired returns true if there is no access token or the
// access token is expired.
func (t token) expired() bool {
if t.accessToken == "" {
return true
}
if t.expiry.IsZero() {
return false
}
return t.expiry.Before(time.Now())
}
// computeEngine uses the Application Default Credentials as provided to Google Compute Engine instances.
// computeEngine represents credentials for the built-in service account for
// the currently running Google Compute Engine (GCE) instance. It uses the
// metadata server to get access tokens.
type computeEngine struct {
mu sync.Mutex
t token
ts oauth2.TokenSource
}
// GetRequestMetadata returns a refreshed access token.
func (c *computeEngine) GetRequestMetadata() (map[string]string, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.t.expired() {
if err := c.refresh(); err != nil {
return nil, err
}
func (c computeEngine) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
token, err := c.ts.Token()
if err != nil {
return nil, err
}
return map[string]string{
"authorization": "Bearer " + c.t.accessToken,
"authorization": token.TokenType + " " + token.AccessToken,
}, nil
}
func (c *computeEngine) refresh() error {
// https://developers.google.com/compute/docs/metadata
// v1 requires "Metadata-Flavor: Google" header.
tokenURL := &url.URL{
Scheme: "http",
Host: metadataServer,
Path: serviceAccountPath,
// NewComputeEngine constructs the credentials that fetches access tokens from
// Google Compute Engine (GCE)'s metadata server. It is only valid to use this
// if your program is running on a GCE instance.
func NewComputeEngine() Credentials {
return computeEngine{
ts: google.ComputeTokenSource(""),
}
req, err := http.NewRequest("GET", tokenURL.String(), nil)
if err != nil {
return err
}
req.Header.Add("Metadata-Flavor", "Google")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
var td tokenData
err = json.NewDecoder(resp.Body).Decode(&td)
if err != nil {
return err
}
// No need to check td.tokenType.
c.t = token{
accessToken: td.accessToken,
expiry: time.Now().Add(time.Duration(td.expiresIn) * time.Second),
}
return nil
}
// NewComputeEngine constructs a credentials for GCE.
func NewComputeEngine() (Credentials, error) {
creds := &computeEngine{}
// TODO(zhaoq): This is not optimal if refresh() is persistently failed.
if err := creds.refresh(); err != nil {
// serviceAccount represents credentials via JWT signing key.
type serviceAccount struct {
config *jwt.Config
}
func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) {
c, ok := ctx.(oauth2.Context)
if !ok {
return nil, fmt.Errorf("credentials: the context %v is invalid", ctx)
}
token, err := s.config.TokenSource(c).Token()
if err != nil {
return nil, err
}
return creds, nil
return map[string]string{
"authorization": token.TokenType + " " + token.AccessToken,
}, nil
}
// NewServiceAccountFromKey constructs the credentials using the JSON key slice
// from a Google Developers service account.
func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (Credentials, error) {
config, err := google.JWTConfigFromJSON(jsonKey, scope...)
if err != nil {
return nil, err
}
return serviceAccount{config: config}, nil
}
// NewServiceAccountFromFile constructs the credentials using the JSON key file
// of a Google Developers service account.
func NewServiceAccountFromFile(keyFile string, scope ...string) (Credentials, error) {
jsonKey, err := ioutil.ReadFile(keyFile)
if err != nil {
return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err)
}
return NewServiceAccountFromKey(jsonKey, scope...)
}

View File

@ -44,10 +44,10 @@ import (
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"golang.org/x/net/context"
)
// http2Client implements the ClientTransport interface with HTTP2.
@ -218,7 +218,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
for _, c := range t.authCreds {
m, err := c.GetRequestMetadata()
m, err := c.GetRequestMetadata(ctx)
select {
case <-ctx.Done():
return nil, ContextErr(ctx.Err())
default:
}
if err != nil {
return nil, StreamErrorf(codes.InvalidArgument, "%v", err)
}