mirror of
https://github.com/grafana/grafana.git
synced 2025-09-23 08:23:02 +08:00
AuthN: Add post auth hook for oauth token refresh (#61608)
* AuthN: rename package to sync * AuthN: rename sync files * Ouath: Add mock for OauthTokenService * AuthN: Implement access token refresh hook * AuthN: remove feature check from hook * AuthN: register post auth hook for oauth token refresh
This commit is contained in:
@ -14,10 +14,12 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/apikey"
|
"github.com/grafana/grafana/pkg/services/apikey"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
sync "github.com/grafana/grafana/pkg/services/authn/authnimpl/usersync"
|
"github.com/grafana/grafana/pkg/services/authn/authnimpl/sync"
|
||||||
"github.com/grafana/grafana/pkg/services/authn/clients"
|
"github.com/grafana/grafana/pkg/services/authn/clients"
|
||||||
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/loginattempt"
|
"github.com/grafana/grafana/pkg/services/loginattempt"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/services/quota"
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
@ -43,6 +45,7 @@ func ProvideService(
|
|||||||
userProtectionService login.UserProtectionService,
|
userProtectionService login.UserProtectionService,
|
||||||
loginAttempts loginattempt.Service, quotaService quota.Service,
|
loginAttempts loginattempt.Service, quotaService quota.Service,
|
||||||
authInfoService login.AuthInfoService, renderService rendering.Service,
|
authInfoService login.AuthInfoService, renderService rendering.Service,
|
||||||
|
features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService,
|
||||||
) *Service {
|
) *Service {
|
||||||
s := &Service{
|
s := &Service{
|
||||||
log: log.New("authn.service"),
|
log: log.New("authn.service"),
|
||||||
@ -111,6 +114,10 @@ func ProvideService(
|
|||||||
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen)
|
s.RegisterPostAuthHook(sync.ProvideUserLastSeenSync(userService).SyncLastSeen)
|
||||||
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen)
|
s.RegisterPostAuthHook(sync.ProvideAPIKeyLastSeenSync(apikeyService).SyncLastSeen)
|
||||||
|
|
||||||
|
if features.IsEnabled(featuremgmt.FlagAccessTokenExpirationCheck) {
|
||||||
|
s.RegisterPostAuthHook(sync.ProvideOauthTokenSync(oauthTokenService, sessionService).SyncOauthToken)
|
||||||
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
79
pkg/services/authn/authnimpl/sync/oauth_token_sync.go
Normal file
79
pkg/services/authn/authnimpl/sync/oauth_token_sync.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"github.com/grafana/grafana/pkg/util/errutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token")
|
||||||
|
)
|
||||||
|
|
||||||
|
func ProvideOauthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService) *OauthTokenSync {
|
||||||
|
return &OauthTokenSync{
|
||||||
|
log.New("oauth_token.sync"),
|
||||||
|
service,
|
||||||
|
sessionService,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OauthTokenSync struct {
|
||||||
|
log log.Logger
|
||||||
|
service oauthtoken.OAuthTokenService
|
||||||
|
sessionService auth.UserTokenService
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OauthTokenSync) SyncOauthToken(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
|
||||||
|
namespace, id := identity.NamespacedID()
|
||||||
|
// only perform oauth token check if identity is a user
|
||||||
|
if namespace != authn.NamespaceUser {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// not authenticated through session tokens, so we can skip this hook
|
||||||
|
if identity.SessionToken == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token, exists, _ := s.service.HasOAuthEntry(ctx, &user.SignedInUser{UserID: id})
|
||||||
|
// user is not authenticated through oauth so skip further checks
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// token has no expire time configured, so we don't have to refresh it
|
||||||
|
if token.OAuthExpiry.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// token has not expired, so we don't have to refresh it
|
||||||
|
if !token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(time.Now()) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.service.TryTokenRefresh(ctx, token); err != nil {
|
||||||
|
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||||
|
s.log.FromContext(ctx).Error("could not refresh oauth access token for user", "userId", id, "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.service.InvalidateOAuthTokens(ctx, token); err != nil {
|
||||||
|
s.log.FromContext(ctx).Error("could not invalidate OAuth tokens", "userId", id, "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.sessionService.RevokeToken(ctx, identity.SessionToken, false); err != nil {
|
||||||
|
s.log.FromContext(ctx).Error("could not revoke token", "userId", id, "tokenId", identity.SessionToken.Id, "err", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", auth.ErrInvalidSessionToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
134
pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go
Normal file
134
pkg/services/authn/authnimpl/sync/oauth_token_sync_test.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package sync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOauthTokenSync_SyncOauthToken(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
identity *authn.Identity
|
||||||
|
|
||||||
|
expectedHasEntryToken *models.UserAuth
|
||||||
|
expectHasEntryCalled bool
|
||||||
|
|
||||||
|
expectedTryRefreshErr error
|
||||||
|
expectTryRefreshTokenCalled bool
|
||||||
|
|
||||||
|
expectRevokeTokenCalled bool
|
||||||
|
expectInvalidateOauthTokensCalled bool
|
||||||
|
|
||||||
|
expectedErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should skip sync when identity is not a user",
|
||||||
|
identity: &authn.Identity{ID: "service-account:1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when identity is a user but is not authenticated with session token",
|
||||||
|
identity: &authn.Identity{ID: "user:1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when user has session but is not authenticated with oauth",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync for when access token don't have expire time",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &models.UserAuth{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when access token has no expired yet",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should skip sync when access token has no expired yet",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should refresh access token when is has expired",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should invalidate access token and session token if access token can't be refreshed",
|
||||||
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||||
|
expectHasEntryCalled: true,
|
||||||
|
expectedTryRefreshErr: errors.New("some err"),
|
||||||
|
expectTryRefreshTokenCalled: true,
|
||||||
|
expectInvalidateOauthTokensCalled: true,
|
||||||
|
expectRevokeTokenCalled: true,
|
||||||
|
expectedHasEntryToken: &models.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
|
expectedErr: errExpiredAccessToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
hasEntryCalled bool
|
||||||
|
tryRefreshCalled bool
|
||||||
|
invalidateTokensCalled bool
|
||||||
|
revokeTokenCalled bool
|
||||||
|
)
|
||||||
|
|
||||||
|
service := &oauthtokentest.MockOauthTokenService{
|
||||||
|
HasOAuthEntryFunc: func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
hasEntryCalled = true
|
||||||
|
return tt.expectedHasEntryToken, tt.expectedHasEntryToken != nil, nil
|
||||||
|
},
|
||||||
|
InvalidateOAuthTokensFunc: func(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
invalidateTokensCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
TryTokenRefreshFunc: func(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
tryRefreshCalled = true
|
||||||
|
return tt.expectedTryRefreshErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionService := &authtest.FakeUserAuthTokenService{
|
||||||
|
RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error {
|
||||||
|
revokeTokenCalled = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sync := &OauthTokenSync{
|
||||||
|
log: log.NewNopLogger(),
|
||||||
|
service: service,
|
||||||
|
sessionService: sessionService,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := sync.SyncOauthToken(context.Background(), tt.identity, nil)
|
||||||
|
assert.ErrorIs(t, err, tt.expectedErr)
|
||||||
|
assert.Equal(t, tt.expectHasEntryCalled, hasEntryCalled)
|
||||||
|
assert.Equal(t, tt.expectTryRefreshTokenCalled, tryRefreshCalled)
|
||||||
|
assert.Equal(t, tt.expectInvalidateOauthTokensCalled, invalidateTokensCalled)
|
||||||
|
assert.Equal(t, tt.expectRevokeTokenCalled, revokeTokenCalled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package usersync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
@ -1,4 +1,4 @@
|
|||||||
package usersync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
@ -1,4 +1,4 @@
|
|||||||
package usersync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
@ -1,4 +1,4 @@
|
|||||||
package usersync
|
package sync
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
53
pkg/services/oauthtoken/oauthtokentest/mock.go
Normal file
53
pkg/services/oauthtoken/oauthtokentest/mock.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package oauthtokentest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/services/datasources"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockOauthTokenService struct {
|
||||||
|
GetCurrentOauthTokenFunc func(ctx context.Context, usr *user.SignedInUser) *oauth2.Token
|
||||||
|
IsOAuthPassThruEnabledFunc func(ds *datasources.DataSource) bool
|
||||||
|
HasOAuthEntryFunc func(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error)
|
||||||
|
InvalidateOAuthTokensFunc func(ctx context.Context, usr *models.UserAuth) error
|
||||||
|
TryTokenRefreshFunc func(ctx context.Context, usr *models.UserAuth) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOauthTokenService) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUser) *oauth2.Token {
|
||||||
|
if m.GetCurrentOauthTokenFunc != nil {
|
||||||
|
return m.GetCurrentOauthTokenFunc(ctx, usr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOauthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||||
|
if m.IsOAuthPassThruEnabledFunc != nil {
|
||||||
|
return m.IsOAuthPassThruEnabledFunc(ds)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOauthTokenService) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
if m.HasOAuthEntryFunc != nil {
|
||||||
|
return m.HasOAuthEntryFunc(ctx, usr)
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOauthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
if m.InvalidateOAuthTokensFunc != nil {
|
||||||
|
return m.InvalidateOAuthTokensFunc(ctx, usr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockOauthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
if m.TryTokenRefreshFunc != nil {
|
||||||
|
return m.TryTokenRefreshFunc(ctx, usr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Reference in New Issue
Block a user