mirror of
https://github.com/grafana/grafana.git
synced 2025-07-29 12:52:08 +08:00
683 lines
23 KiB
Go
683 lines
23 KiB
Go
package oauthtoken
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.org/x/oauth2"
|
|
|
|
claims "github.com/grafana/authlib/types"
|
|
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
|
"github.com/grafana/grafana/pkg/infra/db"
|
|
"github.com/grafana/grafana/pkg/infra/serverlock"
|
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
|
"github.com/grafana/grafana/pkg/login/social"
|
|
"github.com/grafana/grafana/pkg/login/social/socialtest"
|
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
|
"github.com/grafana/grafana/pkg/services/auth"
|
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
|
"github.com/grafana/grafana/pkg/services/authn"
|
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
|
"github.com/grafana/grafana/pkg/services/login"
|
|
"github.com/grafana/grafana/pkg/services/login/authinfotest"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
"github.com/grafana/grafana/pkg/tests/testsuite"
|
|
)
|
|
|
|
// Fixture ID tokens used by the refresh tests below. Both are HS256 JWTs with
// identical claims (iss https://example.com, sub 1234567890, aud
// your-client-id) whose signature segment is a placeholder, base64("signature").
//
// EXPIRED_ID_TOKEN carries exp=1600000000 (September 2020), i.e. in the past.
// UNEXPIRED_ID_TOKEN carries exp=4885608000 (the year 2124), i.e. far in the future.
const (
	EXPIRED_ID_TOKEN   = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjoxNjAwMDAwMDAwLCJpYXQiOjE2MDAwMDAwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl"   // #nosec G101 not a hardcoded credential
	UNEXPIRED_ID_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIiwic3ViIjoiMTIzNDU2Nzg5MCIsImF1ZCI6InlvdXItY2xpZW50LWlkIiwiZXhwIjo0ODg1NjA4MDAwLCJpYXQiOjE2ODU2MDgwMDAsIm5hbWUiOiJKb2huIERvZSIsImVtYWlsIjoiam9obkBleGFtcGxlLmNvbSJ9.c2lnbmF0dXJl" // #nosec G101 not a hardcoded credential
)
|
|
|
|
// TestMain hands control of this package's test binary to the shared Grafana
// test suite runner, testsuite.Run, which executes the tests in m.
func TestMain(m *testing.M) {
	testsuite.Run(m)
}
|
|
|
|
type FakeAuthInfoStore struct {
|
|
login.Store
|
|
ExpectedError error
|
|
ExpectedOAuth *login.UserAuth
|
|
}
|
|
|
|
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
|
|
return f.ExpectedOAuth, f.ExpectedError
|
|
}
|
|
|
|
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
|
|
return f.ExpectedError
|
|
}
|
|
|
|
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
|
|
f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken
|
|
f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry
|
|
f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType
|
|
f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken
|
|
return f.ExpectedError
|
|
}
|
|
|
|
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
|
|
return f.ExpectedError
|
|
}
|
|
|
|
// TestIntegration_TryTokenRefresh drives Service.TryTokenRefresh with no
// feature flags enabled. In this mode the stored OAuth tokens are supplied by
// the auth info service fake (env.authInfoService.ExpectedUserAuth), and a
// refresh is observed through the UpdateExternalSession and TokenSource mocks.
func TestIntegration_TryTokenRefresh(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}

	unexpiredToken := &oauth2.Token{
		AccessToken:  "testaccess",
		RefreshToken: "testrefresh",
		Expiry:       time.Now().Add(time.Hour),
		TokenType:    "Bearer",
	}
	unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{
		"id_token": UNEXPIRED_ID_TOKEN,
	})

	expiredToken := &oauth2.Token{
		AccessToken:  "testaccess",
		RefreshToken: "testrefresh",
		Expiry:       time.Now().Add(-time.Hour),
		TokenType:    "Bearer",
	}

	type environment struct {
		sessionService  *authtest.MockUserAuthTokenService
		authInfoService *authinfotest.FakeService
		serverLock      *serverlock.ServerLockService
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		store   db.DB
		service *Service
	}

	type testCase struct {
		desc          string
		identity      identity.Requester
		setup         func(env *environment)
		expectedToken *oauth2.Token
		expectedErr   error
	}

	userIdentity := &authn.Identity{
		AuthenticatedBy: login.GenericOAuthModule,
		ID:              "1234",
		Type:            claims.TypeUser,
	}

	tests := []testCase{
		{
			desc: "should skip sync when identity is nil",
		},
		{
			desc:     "should skip sync when identity is not a user",
			identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount},
		},
		{
			desc:     "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
			identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser},
		},
		{
			desc:     "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedError = errors.New("some error")
			},
		},
		{
			desc:     "should skip token refresh if the user doesn't have an oauth entry",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.SAMLAuthModule,
				}
			},
		},
		{
			desc:     "should skip token refresh when no oauth provider was found",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.GenericOAuthModule,
				}
			},
		},
		{
			desc:     "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule: login.GenericOAuthModule,
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: false,
				}
			},
		},
		{
			desc:     "should skip token refresh when the token is still valid and no id token is present",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthAccessToken:  unexpiredTokenWithIDToken.AccessToken,
					OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					OAuthExpiry:       unexpiredTokenWithIDToken.Expiry,
					OAuthTokenType:    unexpiredTokenWithIDToken.TokenType,
				}

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: unexpiredToken,
		},
		{
			desc:     "should not refresh the tokens if access token or id token have not expired yet",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthIdToken:      UNEXPIRED_ID_TOKEN,
					OAuthAccessToken:  unexpiredTokenWithIDToken.AccessToken,
					OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					OAuthExpiry:       unexpiredTokenWithIDToken.Expiry,
					OAuthTokenType:    unexpiredTokenWithIDToken.TokenType,
				}

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should skip token refresh when there is no refresh token",
			identity: userIdentity,
			setup: func(env *environment) {
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					OAuthAccessToken:  unexpiredTokenWithIDToken.AccessToken,
					OAuthRefreshToken: "",
					OAuthExpiry:       unexpiredTokenWithIDToken.Expiry,
				}
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: &oauth2.Token{
				AccessToken:  unexpiredTokenWithIDToken.AccessToken,
				RefreshToken: "",
				Expiry:       unexpiredTokenWithIDToken.Expiry,
			},
		},
		{
			desc:     "should do token refresh when the token is expired",
			identity: userIdentity,
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  expiredToken.AccessToken,
					OAuthRefreshToken: expiredToken.RefreshToken,
					OAuthExpiry:       expiredToken.Expiry,
					OAuthTokenType:    expiredToken.TokenType,
					OAuthIdToken:      EXPIRED_ID_TOKEN,
				}
				env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()

				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should refresh token when the id token is expired",
			identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule},
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1,
					OAuthAccessToken:  unexpiredTokenWithIDToken.AccessToken,
					OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					OAuthExpiry:       unexpiredTokenWithIDToken.Expiry,
					OAuthTokenType:    unexpiredTokenWithIDToken.TokenType,
					OAuthIdToken:      EXPIRED_ID_TOKEN,
				}
				env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()

				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should return ErrRetriesExhausted when lock cannot be acquired",
			identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule},
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
				env.authInfoService.ExpectedUserAuth = &login.UserAuth{
					AuthModule:        login.GenericOAuthModule,
					AuthId:            "subject",
					UserId:            1234,
					OAuthAccessToken:  unexpiredTokenWithIDToken.AccessToken,
					OAuthRefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					OAuthExpiry:       unexpiredTokenWithIDToken.Expiry,
					OAuthTokenType:    unexpiredTokenWithIDToken.TokenType,
					OAuthIdToken:      EXPIRED_ID_TOKEN,
				}
				// Pre-insert a server_lock row for this user's refresh key so the
				// service cannot acquire it. The error is deliberately ignored here
				// (setup has no *testing.T); if the insert fails, the lock stays
				// free and the ErrRetriesExhausted expectation fails the subtest.
				_ = env.store.WithDbSession(context.Background(), func(sess *db.Session) error {
					_, err := sess.Exec(`INSERT INTO server_lock (operation_uid, last_execution, version) VALUES (?, ?, ?)`, "oauth-refresh-token-1234", time.Now().Add(2*time.Second).Unix(), 0)
					return err
				})
			},
			expectedErr: ErrRetriesExhausted,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := socialtest.NewMockSocialConnector(t)

			store := db.InitTestDB(t)

			env := environment{
				sessionService:  authtest.NewMockUserAuthTokenService(t),
				authInfoService: &authinfotest.FakeService{},
				serverLock:      serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
				store: store,
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			env.service = ProvideService(
				env.socialService,
				env.authInfoService,
				setting.NewCfg(),
				prometheus.NewRegistry(),
				env.serverLock,
				tracing.InitializeTracerForTest(),
				env.sessionService,
				featuremgmt.WithFeatures(),
			)

			// token refresh
			actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1})

			if tt.expectedErr != nil {
				assert.ErrorIs(t, err, tt.expectedErr)
				return
			}

			// Every non-error case is expected to succeed silently, matching the
			// external-session variant of this test.
			assert.NoError(t, err)

			if tt.expectedToken == nil {
				assert.Nil(t, actualToken)
				return
			}

			// Report a missing token as a failure instead of panicking on the
			// field accesses below.
			if !assert.NotNil(t, actualToken) {
				return
			}

			assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken)
			assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken)
			assert.Equal(t, tt.expectedToken.Expiry, actualToken.Expiry)
			assert.Equal(t, tt.expectedToken.TokenType, actualToken.TokenType)
			// Compare the raw extras so a missing or non-string id_token shows
			// up as an assertion failure rather than a type-assertion panic.
			assert.Equal(t, tt.expectedToken.Extra("id_token"), actualToken.Extra("id_token"))
		})
	}
}
|
|
|
|
// TestIntegration_TryTokenRefresh_WithExternalSessions drives
// Service.TryTokenRefresh with FlagImprovedExternalSessionHandling enabled. In
// this mode the stored tokens come from the session service's
// GetExternalSession mock. No auth info service is wired in (nil is passed to
// ProvideService).
func TestIntegration_TryTokenRefresh_WithExternalSessions(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}

	unexpiredToken := &oauth2.Token{
		AccessToken:  "testaccess",
		RefreshToken: "testrefresh",
		Expiry:       time.Now().Add(time.Hour),
		TokenType:    "Bearer",
	}
	unexpiredTokenWithIDToken := unexpiredToken.WithExtra(map[string]interface{}{
		"id_token": UNEXPIRED_ID_TOKEN,
	})

	expiredToken := &oauth2.Token{
		AccessToken:  "testaccess",
		RefreshToken: "testrefresh",
		Expiry:       time.Now().Add(-time.Hour),
		TokenType:    "Bearer",
	}

	userIdentity := &authn.Identity{
		AuthenticatedBy: login.GenericOAuthModule,
		ID:              "1234",
		Type:            claims.TypeUser,
	}

	type environment struct {
		sessionService  *authtest.MockUserAuthTokenService
		serverLock      *serverlock.ServerLockService
		socialConnector *socialtest.MockSocialConnector
		socialService   *socialtest.FakeSocialService

		store   db.DB
		service *Service
	}

	type testCase struct {
		desc          string
		identity      identity.Requester
		setup         func(env *environment)
		expectedToken *oauth2.Token
		expectedErr   error
	}

	tests := []testCase{
		{
			desc: "should skip sync when identity is nil",
		},
		{
			desc:     "should skip sync when identity is not a user",
			identity: &authn.Identity{ID: "1", Type: claims.TypeServiceAccount},
		},
		{
			desc:     "should skip token refresh and return nil if namespace and id cannot be converted to user ID",
			identity: &authn.Identity{ID: "invalid", Type: claims.TypeUser},
		},
		{
			desc:     "should skip token refresh if there's an unexpected error while looking up the user oauth entry, additionally, no error should be returned",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(nil, assert.AnError).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
		},
		// Kinda impossible to happen, can only happen after the feature is enabled and logged in users don't have their external sessions set
		{
			desc:     "should skip token refresh if the user doesn't have an external session",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(nil, auth.ErrExternalSessionNotFound).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
		},
		{
			desc:     "should skip token refresh when no oauth provider was found",
			identity: userIdentity,
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = nil
			},
		},
		{
			desc:     "should skip token refresh when oauth provider token handling is disabled (UseRefreshToken is false)",
			identity: userIdentity,
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: false,
				}
			},
		},
		{
			desc:     "should skip token refresh when the token is still valid and no id token is present",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{
					ID:           1,
					UserID:       1,
					AccessToken:  unexpiredTokenWithIDToken.AccessToken,
					RefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					ExpiresAt:    unexpiredTokenWithIDToken.Expiry,
				}, nil).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: unexpiredToken,
		},
		{
			desc:     "should not do token refresh if access token or id token have not expired yet",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{
					ID:           1,
					UserID:       1,
					AccessToken:  unexpiredTokenWithIDToken.AccessToken,
					RefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					ExpiresAt:    unexpiredTokenWithIDToken.Expiry,
					IDToken:      UNEXPIRED_ID_TOKEN,
				}, nil).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should skip token refresh when there is no refresh token",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{
					ID:           1,
					UserID:       1,
					AccessToken:  unexpiredTokenWithIDToken.AccessToken,
					RefreshToken: "",
					ExpiresAt:    unexpiredTokenWithIDToken.Expiry,
				}, nil).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: &oauth2.Token{
				AccessToken:  unexpiredTokenWithIDToken.AccessToken,
				RefreshToken: "",
				Expiry:       unexpiredTokenWithIDToken.Expiry,
			},
		},
		{
			desc: "should refresh token when the access token is expired",
			identity: &authn.Identity{
				AuthenticatedBy: login.GenericOAuthModule,
				ID:              "1",
				Type:            claims.TypeUser,
			},
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{
					ID:           1,
					UserID:       1,
					AccessToken:  expiredToken.AccessToken,
					IDToken:      UNEXPIRED_ID_TOKEN,
					RefreshToken: expiredToken.RefreshToken,
					ExpiresAt:    expiredToken.Expiry,
				}, nil).Once()

				env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()

				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should refresh token when the id token is expired",
			identity: userIdentity,
			setup: func(env *environment) {
				env.sessionService.On("GetExternalSession", mock.Anything, int64(1)).Return(&auth.ExternalSession{
					ID:           1,
					UserID:       1,
					AccessToken:  unexpiredTokenWithIDToken.AccessToken,
					RefreshToken: unexpiredTokenWithIDToken.RefreshToken,
					ExpiresAt:    unexpiredTokenWithIDToken.Expiry,
					IDToken:      EXPIRED_ID_TOKEN,
				}, nil).Once()

				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}

				env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()

				env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
			},
			expectedToken: unexpiredTokenWithIDToken,
		},
		{
			desc:     "should return ErrRetriesExhausted when lock cannot be acquired",
			identity: &authn.Identity{ID: "1234", Type: claims.TypeUser, AuthenticatedBy: login.GenericOAuthModule},
			setup: func(env *environment) {
				env.socialService.ExpectedAuthInfoProvider = &social.OAuthInfo{
					UseRefreshToken: true,
				}

				// Pre-insert a server_lock row for the per-session refresh key so
				// the service cannot acquire it. The error is deliberately ignored
				// (setup has no *testing.T); if the insert fails, the lock stays
				// free and the ErrRetriesExhausted expectation fails the subtest.
				_ = env.store.WithDbSession(context.Background(), func(sess *db.Session) error {
					_, err := sess.Exec(`INSERT INTO server_lock (operation_uid, last_execution, version) VALUES (?, ?, ?)`, "oauth-refresh-token-1234-1", time.Now().Add(2*time.Second).Unix(), 0)
					return err
				})
			},
			expectedErr: ErrRetriesExhausted,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			socialConnector := socialtest.NewMockSocialConnector(t)

			store := db.InitTestDB(t)

			env := environment{
				sessionService:  authtest.NewMockUserAuthTokenService(t),
				serverLock:      serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
				socialConnector: socialConnector,
				socialService: &socialtest.FakeSocialService{
					ExpectedConnector: socialConnector,
				},
				store: store,
			}

			if tt.setup != nil {
				tt.setup(&env)
			}

			env.service = ProvideService(
				env.socialService,
				nil,
				setting.NewCfg(),
				prometheus.NewRegistry(),
				env.serverLock,
				tracing.InitializeTracerForTest(),
				env.sessionService,
				featuremgmt.WithFeatures(featuremgmt.FlagImprovedExternalSessionHandling),
			)

			// token refresh
			actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1})

			if tt.expectedErr != nil {
				assert.ErrorIs(t, err, tt.expectedErr)
				return
			}

			assert.NoError(t, err)

			if tt.expectedToken == nil {
				assert.Nil(t, actualToken)
				return
			}

			// Report a missing token as a failure instead of panicking on the
			// field accesses below.
			if !assert.NotNil(t, actualToken) {
				return
			}

			assert.Equal(t, tt.expectedToken.AccessToken, actualToken.AccessToken)
			assert.Equal(t, tt.expectedToken.RefreshToken, actualToken.RefreshToken)
			assert.Equal(t, tt.expectedToken.Expiry, actualToken.Expiry)
			// Compare the raw extras so a missing or non-string id_token shows
			// up as an assertion failure rather than a type-assertion panic.
			assert.Equal(t, tt.expectedToken.Extra("id_token"), actualToken.Extra("id_token"))
		})
	}
}
|
|
|
|
func verifyUpdateExternalSessionCommand(token *oauth2.Token) func(*auth.UpdateExternalSessionCommand) bool {
|
|
return func(cmd *auth.UpdateExternalSessionCommand) bool {
|
|
idToken := cmd.Token.Extra("id_token")
|
|
return cmd.Token.AccessToken == token.AccessToken &&
|
|
cmd.Token.RefreshToken == token.RefreshToken &&
|
|
cmd.Token.Expiry.Equal(token.Expiry) &&
|
|
idToken == token.Extra("id_token")
|
|
}
|
|
}
|
|
|
|
// TestOAuthTokenSync_needTokenRefresh checks the refresh decision that
// needTokenRefresh makes for tokens built from login.UserAuth entries via
// buildOAuthTokenFromAuthInfo. needTokenRefresh returns only a boolean, so the
// table asserts only that flag. The former expectedTokenDuration field was
// never read and has been removed.
func TestOAuthTokenSync_needTokenRefresh(t *testing.T) {
	tests := []struct {
		name                     string
		usr                      *login.UserAuth
		expectedTokenRefreshFlag bool
	}{
		{
			name:                     "should not need token refresh when token has no expiration date",
			usr:                      &login.UserAuth{},
			expectedTokenRefreshFlag: false,
		},
		{
			name: "should not need token refresh with an invalid jwt token that might result in an error when parsing",
			usr: &login.UserAuth{
				OAuthIdToken: "invalid_jwt_format",
			},
			expectedTokenRefreshFlag: false,
		},
		{
			name: "should flag token refresh with id token is expired",
			usr: &login.UserAuth{
				OAuthIdToken: EXPIRED_ID_TOKEN,
			},
			expectedTokenRefreshFlag: true,
		},
		{
			name: "should flag token refresh when expiry date is zero",
			usr: &login.UserAuth{
				OAuthExpiry: time.Unix(0, 0),
			},
			expectedTokenRefreshFlag: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			token := buildOAuthTokenFromAuthInfo(tt.usr)
			needsTokenRefresh := needTokenRefresh(context.Background(), token)

			assert.Equal(t, tt.expectedTokenRefreshFlag, needsTokenRefresh)
		})
	}
}
|