diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index fc9b64edc6f..62708b0ff98 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -440,20 +440,20 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken ) } - if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { - if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{ - Token: token, - }); err != nil { - ctxLogger.Error("Failed to update external session during token refresh", "error", err) - return token, err - } - } else { + if !o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err) return token, err } } + if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{ + Token: token, + }); err != nil { + ctxLogger.Error("Failed to update external session during token refresh", "error", err) + return token, err + } + ctxLogger.Debug("Updated oauth info for user") } diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go index 36b5528a5c7..247254afb27 100644 --- a/pkg/services/oauthtoken/oauth_token_test.go +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -85,6 +85,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { } type environment struct { + sessionService *authtest.MockUserAuthTokenService authInfoService *authinfotest.FakeService serverLock *serverlock.ServerLockService socialConnector *socialtest.MockSocialConnector @@ -231,6 +232,8 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { OAuthTokenType: expiredToken.TokenType, OAuthIdToken: EXPIRED_ID_TOKEN, } + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() }, expectedToken: unexpiredTokenWithIDToken, @@ -252,6 +255,8 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { OAuthTokenType: unexpiredTokenWithIDToken.TokenType, OAuthIdToken: EXPIRED_ID_TOKEN, } + env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once() + env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() }, expectedToken: unexpiredTokenWithIDToken, @@ -288,6 +293,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { store := db.InitTestDB(t) env := environment{ + sessionService: authtest.NewMockUserAuthTokenService(t), authInfoService: &authinfotest.FakeService{}, serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), socialConnector: socialConnector, @@ -308,12 +314,12 @@ func TestIntegration_TryTokenRefresh(t *testing.T) { prometheus.NewRegistry(), env.serverLock, tracing.InitializeTracerForTest(), - nil, + env.sessionService, featuremgmt.WithFeatures(), ) // token refresh - actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, nil) + actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1}) if tt.expectedErr != nil { assert.ErrorIs(t, err, tt.expectedErr)