776 lines
22 KiB
Go
776 lines
22 KiB
Go
// +build go1.13
|
|
|
|
/*
|
|
*
|
|
* Copyright 2020 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package sts
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/internal"
|
|
"google.golang.org/grpc/internal/grpctest"
|
|
"google.golang.org/grpc/internal/testutils"
|
|
)
|
|
|
|
// Fixed values shared by the tests in this file.
const (
	// Token exchange service endpoint and request fields.
	serviceURI         = "http://localhost"
	exampleResource    = "https://backend.example.com/api"
	exampleAudience    = "example-backend-service"
	testScope          = "https://www.googleapis.com/auth/monitoring"
	requestedTokenType = "urn:ietf:params:oauth:token-type:access-token"

	// Subject token: where it is read from, its type, and the contents the
	// overridden reader hands back.
	subjectTokenPath     = "/var/run/secrets/token.jwt"
	subjectTokenType     = "urn:ietf:params:oauth:token-type:id_token"
	subjectTokenContents = "subjectToken.jwt.contents"

	// Actor token: the same triple for the optional actor token.
	actorTokenPath     = "/var/run/secrets/token.jwt"
	actorTokenType     = "urn:ietf:params:oauth:token-type:refresh_token"
	actorTokenContents = "actorToken.jwt.contents"

	// Access token value served by the fake token exchange server.
	accessTokenContents = "access_token"

	// Upper bound on how long a test waits on one of the fake channels.
	defaultTestTimeout = time.Second
)
|
|
|
|
var (
|
|
goodOptions = Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
Audience: exampleAudience,
|
|
RequestedTokenType: requestedTokenType,
|
|
SubjectTokenPath: subjectTokenPath,
|
|
SubjectTokenType: subjectTokenType,
|
|
}
|
|
goodRequestParams = &requestParameters{
|
|
GrantType: tokenExchangeGrantType,
|
|
Audience: exampleAudience,
|
|
Scope: defaultCloudPlatformScope,
|
|
RequestedTokenType: requestedTokenType,
|
|
SubjectToken: subjectTokenContents,
|
|
SubjectTokenType: subjectTokenType,
|
|
}
|
|
goodMetadata = map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
|
|
}
|
|
)
|
|
|
|
type s struct {
|
|
grpctest.Tester
|
|
}
|
|
|
|
func Test(t *testing.T) {
|
|
grpctest.RunSubTests(t, s{})
|
|
}
|
|
|
|
// A struct that implements AuthInfo interface and added to the context passed
|
|
// to GetRequestMetadata from tests.
|
|
type testAuthInfo struct {
|
|
credentials.CommonAuthInfo
|
|
}
|
|
|
|
func (ta testAuthInfo) AuthType() string {
|
|
return "testAuthInfo"
|
|
}
|
|
|
|
func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
|
|
auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
|
|
ri := credentials.RequestInfo{
|
|
Method: "testInfo",
|
|
AuthInfo: auth,
|
|
}
|
|
return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
|
|
}
|
|
|
|
// errReader is an io.Reader whose Read always fails. It stands in for a
// response body that cannot be read.
type errReader struct{}

// Read ignores the buffer, reads nothing, and always reports an error.
func (errReader) Read([]byte) (int, error) {
	return 0, errors.New("read error")
}
|
|
|
|
// We need a function to construct the response instead of simply declaring it
|
|
// as a variable since the the response body will be consumed by the
|
|
// credentials, and therefore we will need a new one everytime.
|
|
func makeGoodResponse() *http.Response {
|
|
respJSON, _ := json.Marshal(responseParameters{
|
|
AccessToken: accessTokenContents,
|
|
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
|
|
return &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Body: respBody,
|
|
}
|
|
}
|
|
|
|
// fakeHTTPDoer helps mock out the http.Client.Do calls made by the credentials
|
|
// code under test. It makes the http.Request made by the credentials available
|
|
// through a channel, and makes it possible to inject various responses.
|
|
type fakeHTTPDoer struct {
|
|
reqCh *testutils.Channel
|
|
respCh *testutils.Channel
|
|
err error
|
|
}
|
|
|
|
func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
|
|
fc.reqCh.Send(req)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
|
|
val, err := fc.respCh.Receive(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return val.(*http.Response), fc.err
|
|
}
|
|
|
|
// Overrides the http.Client with a fakeClient which sends a good response.
|
|
func overrideHTTPClientGood() (*fakeHTTPDoer, func()) {
|
|
fc := &fakeHTTPDoer{
|
|
reqCh: testutils.NewChannel(),
|
|
respCh: testutils.NewChannel(),
|
|
}
|
|
fc.respCh.Send(makeGoodResponse())
|
|
|
|
origMakeHTTPDoer := makeHTTPDoer
|
|
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
|
|
return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
|
|
}
|
|
|
|
// Overrides the http.Client with the provided fakeClient.
|
|
func overrideHTTPClient(fc *fakeHTTPDoer) func() {
|
|
origMakeHTTPDoer := makeHTTPDoer
|
|
makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
|
|
return func() { makeHTTPDoer = origMakeHTTPDoer }
|
|
}
|
|
|
|
// Overrides the subject token read to return a const which we can compare in
|
|
// our tests.
|
|
func overrideSubjectTokenGood() func() {
|
|
origReadSubjectTokenFrom := readSubjectTokenFrom
|
|
readSubjectTokenFrom = func(path string) ([]byte, error) {
|
|
return []byte(subjectTokenContents), nil
|
|
}
|
|
return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
|
|
}
|
|
|
|
// Overrides the subject token read to always return an error.
|
|
func overrideSubjectTokenError() func() {
|
|
origReadSubjectTokenFrom := readSubjectTokenFrom
|
|
readSubjectTokenFrom = func(path string) ([]byte, error) {
|
|
return nil, errors.New("error reading subject token")
|
|
}
|
|
return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
|
|
}
|
|
|
|
// Overrides the actor token read to return a const which we can compare in
|
|
// our tests.
|
|
func overrideActorTokenGood() func() {
|
|
origReadActorTokenFrom := readActorTokenFrom
|
|
readActorTokenFrom = func(path string) ([]byte, error) {
|
|
return []byte(actorTokenContents), nil
|
|
}
|
|
return func() { readActorTokenFrom = origReadActorTokenFrom }
|
|
}
|
|
|
|
// Overrides the actor token read to always return an error.
|
|
func overrideActorTokenError() func() {
|
|
origReadActorTokenFrom := readActorTokenFrom
|
|
readActorTokenFrom = func(path string) ([]byte, error) {
|
|
return nil, errors.New("error reading actor token")
|
|
}
|
|
return func() { readActorTokenFrom = origReadActorTokenFrom }
|
|
}
|
|
|
|
// compareRequest compares the http.Request received in the test with the
|
|
// expected requestParameters specified in wantReqParams.
|
|
func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
|
|
jsonBody, err := json.Marshal(wantReqParams)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create http request: %v", err)
|
|
}
|
|
wantReq.Header.Set("Content-Type", "application/json")
|
|
|
|
wantR, err := httputil.DumpRequestOut(wantReq, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
gotR, err := httputil.DumpRequestOut(gotRequest, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
|
|
return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// receiveAndCompareRequest waits for a request to be sent out by the
|
|
// credentials implementation using the fakeHTTPClient and compares it to an
|
|
// expected goodRequest. This is expected to be called in a separate goroutine
|
|
// by the tests. So, any errors encountered are pushed to an error channel
|
|
// which is monitored by the test.
|
|
func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
|
|
val, err := reqCh.Receive(ctx)
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
req := val.(*http.Request)
|
|
if err := compareRequest(req, goodRequestParams); err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
errCh <- nil
|
|
}
|
|
|
|
// TestGetRequestMetadataSuccess verifies the successful case of sending an
|
|
// token exchange request and processing the response.
|
|
func (s) TestGetRequestMetadataSuccess(t *testing.T) {
|
|
defer overrideSubjectTokenGood()()
|
|
fc, cancel := overrideHTTPClientGood()
|
|
defer cancel()
|
|
|
|
creds, err := NewCredentials(goodOptions)
|
|
if err != nil {
|
|
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
|
|
}
|
|
|
|
errCh := make(chan error, 1)
|
|
go receiveAndCompareRequest(fc.reqCh, errCh)
|
|
|
|
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
|
|
if err != nil {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v", err)
|
|
}
|
|
if !cmp.Equal(gotMetadata, goodMetadata) {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
|
|
}
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Make another call to get request metadata and this should return contents
|
|
// from the cache. This will fail if the credentials tries to send a fresh
|
|
// request here since we have not configured our fakeClient to return any
|
|
// response on retries.
|
|
gotMetadata, err = creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
|
|
if err != nil {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v", err)
|
|
}
|
|
if !cmp.Equal(gotMetadata, goodMetadata) {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
|
|
}
|
|
}
|
|
|
|
// TestGetRequestMetadataBadSecurityLevel verifies the case where the
|
|
// securityLevel specified in the context passed to GetRequestMetadata is not
|
|
// sufficient.
|
|
func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
|
|
defer overrideSubjectTokenGood()()
|
|
|
|
creds, err := NewCredentials(goodOptions)
|
|
if err != nil {
|
|
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
|
|
}
|
|
|
|
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.IntegrityOnly), "")
|
|
if err == nil {
|
|
t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
|
|
}
|
|
}
|
|
|
|
// TestGetRequestMetadataCacheExpiry verifies the case where the cached access
|
|
// token has expired, and the credentials implementation will have to send a
|
|
// fresh token exchange request.
|
|
func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
|
|
const expiresInSecs = 1
|
|
defer overrideSubjectTokenGood()()
|
|
fc := &fakeHTTPDoer{
|
|
reqCh: testutils.NewChannel(),
|
|
respCh: testutils.NewChannel(),
|
|
}
|
|
defer overrideHTTPClient(fc)()
|
|
|
|
creds, err := NewCredentials(goodOptions)
|
|
if err != nil {
|
|
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
|
|
}
|
|
|
|
// The fakeClient is configured to return an access_token with a one second
|
|
// expiry. So, in the second iteration, the credentials will find the cache
|
|
// entry, but that would have expired, and therefore we expect it to send
|
|
// out a fresh request.
|
|
for i := 0; i < 2; i++ {
|
|
errCh := make(chan error, 1)
|
|
go receiveAndCompareRequest(fc.reqCh, errCh)
|
|
|
|
respJSON, _ := json.Marshal(responseParameters{
|
|
AccessToken: accessTokenContents,
|
|
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: expiresInSecs,
|
|
})
|
|
respBody := ioutil.NopCloser(bytes.NewReader(respJSON))
|
|
resp := &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Body: respBody,
|
|
}
|
|
fc.respCh.Send(resp)
|
|
|
|
gotMetadata, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), "")
|
|
if err != nil {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v", err)
|
|
}
|
|
if !cmp.Equal(gotMetadata, goodMetadata) {
|
|
t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
|
|
}
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
time.Sleep(expiresInSecs * time.Second)
|
|
}
|
|
}
|
|
|
|
// TestGetRequestMetadataBadResponses verifies the scenario where the token
|
|
// exchange server returns bad responses.
|
|
func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
response *http.Response
|
|
}{
|
|
{
|
|
name: "bad JSON",
|
|
response: &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Body: ioutil.NopCloser(strings.NewReader("not JSON")),
|
|
},
|
|
},
|
|
{
|
|
name: "no access token",
|
|
response: &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Body: ioutil.NopCloser(strings.NewReader("{}")),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
defer overrideSubjectTokenGood()()
|
|
|
|
fc := &fakeHTTPDoer{
|
|
reqCh: testutils.NewChannel(),
|
|
respCh: testutils.NewChannel(),
|
|
}
|
|
defer overrideHTTPClient(fc)()
|
|
|
|
creds, err := NewCredentials(goodOptions)
|
|
if err != nil {
|
|
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
|
|
}
|
|
|
|
errCh := make(chan error, 1)
|
|
go receiveAndCompareRequest(fc.reqCh, errCh)
|
|
|
|
fc.respCh.Send(test.response)
|
|
if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
|
|
t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
|
|
}
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetRequestMetadataBadSubjectTokenRead verifies the scenario where the
|
|
// attempt to read the subjectToken fails.
|
|
func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
|
|
defer overrideSubjectTokenError()()
|
|
fc, cancel := overrideHTTPClientGood()
|
|
defer cancel()
|
|
|
|
creds, err := NewCredentials(goodOptions)
|
|
if err != nil {
|
|
t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
|
|
}
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
|
|
if _, err := fc.reqCh.Receive(ctx); err != context.DeadlineExceeded {
|
|
errCh <- err
|
|
return
|
|
}
|
|
errCh <- nil
|
|
}()
|
|
|
|
if _, err := creds.GetRequestMetadata(createTestContext(context.Background(), credentials.PrivacyAndIntegrity), ""); err == nil {
|
|
t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
|
|
}
|
|
if err := <-errCh; err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func (s) TestNewCredentials(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
opts Options
|
|
errSystemRoots bool
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "invalid options - empty subjectTokenPath",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid system root certs",
|
|
opts: goodOptions,
|
|
errSystemRoots: true,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "good case",
|
|
opts: goodOptions,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.errSystemRoots {
|
|
oldSystemRoots := loadSystemCertPool
|
|
loadSystemCertPool = func() (*x509.CertPool, error) {
|
|
return nil, errors.New("failed to load system cert pool")
|
|
}
|
|
defer func() {
|
|
loadSystemCertPool = oldSystemRoots
|
|
}()
|
|
}
|
|
|
|
creds, err := NewCredentials(test.opts)
|
|
if (err != nil) != test.wantErr {
|
|
t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
|
|
}
|
|
if err == nil {
|
|
if !creds.RequireTransportSecurity() {
|
|
t.Errorf("creds.RequireTransportSecurity() returned false")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestValidateOptions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
opts Options
|
|
wantErrPrefix string
|
|
}{
|
|
{
|
|
name: "empty token exchange service URI",
|
|
opts: Options{},
|
|
wantErrPrefix: "empty token_exchange_service_uri in options",
|
|
},
|
|
{
|
|
name: "invalid URI",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: "\tI'm a bad URI\n",
|
|
},
|
|
wantErrPrefix: "invalid control character in URL",
|
|
},
|
|
{
|
|
name: "unsupported scheme",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: "unix:///path/to/socket",
|
|
},
|
|
wantErrPrefix: "scheme is not supported",
|
|
},
|
|
{
|
|
name: "empty subjectTokenPath",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
},
|
|
wantErrPrefix: "required field SubjectTokenPath is not specified",
|
|
},
|
|
{
|
|
name: "empty subjectTokenType",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
SubjectTokenPath: subjectTokenPath,
|
|
},
|
|
wantErrPrefix: "required field SubjectTokenType is not specified",
|
|
},
|
|
{
|
|
name: "good options",
|
|
opts: goodOptions,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
err := validateOptions(test.opts)
|
|
if (err != nil) != (test.wantErrPrefix != "") {
|
|
t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
|
|
t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestConstructRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
opts Options
|
|
subjectTokenReadErr bool
|
|
actorTokenReadErr bool
|
|
wantReqParams *requestParameters
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "subject token read failure",
|
|
subjectTokenReadErr: true,
|
|
opts: goodOptions,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "actor token read failure",
|
|
actorTokenReadErr: true,
|
|
opts: Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
Audience: exampleAudience,
|
|
RequestedTokenType: requestedTokenType,
|
|
SubjectTokenPath: subjectTokenPath,
|
|
SubjectTokenType: subjectTokenType,
|
|
ActorTokenPath: actorTokenPath,
|
|
ActorTokenType: actorTokenType,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "default cloud platform scope",
|
|
opts: goodOptions,
|
|
wantReqParams: goodRequestParams,
|
|
},
|
|
{
|
|
name: "all good",
|
|
opts: Options{
|
|
TokenExchangeServiceURI: serviceURI,
|
|
Resource: exampleResource,
|
|
Audience: exampleAudience,
|
|
Scope: testScope,
|
|
RequestedTokenType: requestedTokenType,
|
|
SubjectTokenPath: subjectTokenPath,
|
|
SubjectTokenType: subjectTokenType,
|
|
ActorTokenPath: actorTokenPath,
|
|
ActorTokenType: actorTokenType,
|
|
},
|
|
wantReqParams: &requestParameters{
|
|
GrantType: tokenExchangeGrantType,
|
|
Resource: exampleResource,
|
|
Audience: exampleAudience,
|
|
Scope: testScope,
|
|
RequestedTokenType: requestedTokenType,
|
|
SubjectToken: subjectTokenContents,
|
|
SubjectTokenType: subjectTokenType,
|
|
ActorToken: actorTokenContents,
|
|
ActorTokenType: actorTokenType,
|
|
},
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.subjectTokenReadErr {
|
|
defer overrideSubjectTokenError()()
|
|
} else {
|
|
defer overrideSubjectTokenGood()()
|
|
}
|
|
|
|
if test.actorTokenReadErr {
|
|
defer overrideActorTokenError()()
|
|
} else {
|
|
defer overrideActorTokenGood()()
|
|
}
|
|
|
|
gotRequest, err := constructRequest(context.Background(), test.opts)
|
|
if (err != nil) != test.wantErr {
|
|
t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
|
|
}
|
|
if test.wantErr {
|
|
return
|
|
}
|
|
if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestSendRequest(t *testing.T) {
|
|
defer overrideSubjectTokenGood()()
|
|
req, err := constructRequest(context.Background(), goodOptions)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
resp *http.Response
|
|
respErr error
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "client error",
|
|
respErr: errors.New("http.Client.Do failed"),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "bad response body",
|
|
resp: &http.Response{
|
|
Status: "200 OK",
|
|
StatusCode: http.StatusOK,
|
|
Body: ioutil.NopCloser(errReader{}),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "nonOK status code",
|
|
resp: &http.Response{
|
|
Status: "400 BadRequest",
|
|
StatusCode: http.StatusBadRequest,
|
|
Body: ioutil.NopCloser(strings.NewReader("")),
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "good case",
|
|
resp: makeGoodResponse(),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
client := &fakeHTTPDoer{
|
|
reqCh: testutils.NewChannel(),
|
|
respCh: testutils.NewChannel(),
|
|
err: test.respErr,
|
|
}
|
|
client.respCh.Send(test.resp)
|
|
_, err := sendRequest(client, req)
|
|
if (err != nil) != test.wantErr {
|
|
t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestTokenInfoFromResponse(t *testing.T) {
|
|
noAccessToken, _ := json.Marshal(responseParameters{
|
|
IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
goodResponse, _ := json.Marshal(responseParameters{
|
|
IssuedTokenType: requestedTokenType,
|
|
AccessToken: accessTokenContents,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
respBody []byte
|
|
wantTokenInfo *tokenInfo
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "bad JSON",
|
|
respBody: []byte("not JSON"),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "empty response",
|
|
respBody: []byte(""),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "non-empty response with no access token",
|
|
respBody: noAccessToken,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "good response",
|
|
respBody: goodResponse,
|
|
wantTokenInfo: &tokenInfo{
|
|
tokenType: "Bearer",
|
|
token: accessTokenContents,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
|
|
if (err != nil) != test.wantErr {
|
|
t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
|
|
}
|
|
if test.wantErr {
|
|
return
|
|
}
|
|
// Can't do a cmp.Equal on the whole struct since the expiryField
|
|
// is populated based on time.Now().
|
|
if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
|
|
t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
|
|
}
|
|
})
|
|
}
|
|
}
|