Auth: Update external session data regardless if the improved session handling is enabled (#98139)

Update ext session tokens regardless if the ft is enabled
This commit is contained in:
Misi
2024-12-18 11:27:50 +01:00
committed by GitHub
parent a572dca2d6
commit a5635d7e89
2 changed files with 16 additions and 10 deletions

View File

@ -440,20 +440,20 @@ func (o *Service) tryGetOrRefreshOAuthToken(ctx context.Context, persistedToken
) )
} }
if o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) { if !o.features.IsEnabledGlobally(featuremgmt.FlagImprovedExternalSessionHandling) {
if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
Token: token,
}); err != nil {
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
return token, err
}
} else {
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err) ctxLogger.Error("Failed to update auth info during token refresh", "authID", usr.GetAuthID(), "error", err)
return token, err return token, err
} }
} }
if err := o.sessionService.UpdateExternalSession(ctx, sessionToken.ExternalSessionId, &auth.UpdateExternalSessionCommand{
Token: token,
}); err != nil {
ctxLogger.Error("Failed to update external session during token refresh", "error", err)
return token, err
}
ctxLogger.Debug("Updated oauth info for user") ctxLogger.Debug("Updated oauth info for user")
} }

View File

@ -85,6 +85,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) {
} }
type environment struct { type environment struct {
sessionService *authtest.MockUserAuthTokenService
authInfoService *authinfotest.FakeService authInfoService *authinfotest.FakeService
serverLock *serverlock.ServerLockService serverLock *serverlock.ServerLockService
socialConnector *socialtest.MockSocialConnector socialConnector *socialtest.MockSocialConnector
@ -231,6 +232,8 @@ func TestIntegration_TryTokenRefresh(t *testing.T) {
OAuthTokenType: expiredToken.TokenType, OAuthTokenType: expiredToken.TokenType,
OAuthIdToken: EXPIRED_ID_TOKEN, OAuthIdToken: EXPIRED_ID_TOKEN,
} }
env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
}, },
expectedToken: unexpiredTokenWithIDToken, expectedToken: unexpiredTokenWithIDToken,
@ -252,6 +255,8 @@ func TestIntegration_TryTokenRefresh(t *testing.T) {
OAuthTokenType: unexpiredTokenWithIDToken.TokenType, OAuthTokenType: unexpiredTokenWithIDToken.TokenType,
OAuthIdToken: EXPIRED_ID_TOKEN, OAuthIdToken: EXPIRED_ID_TOKEN,
} }
env.sessionService.On("UpdateExternalSession", mock.Anything, int64(1), mock.MatchedBy(verifyUpdateExternalSessionCommand(unexpiredTokenWithIDToken))).Return(nil).Once()
env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once() env.socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(unexpiredTokenWithIDToken)).Once()
}, },
expectedToken: unexpiredTokenWithIDToken, expectedToken: unexpiredTokenWithIDToken,
@ -288,6 +293,7 @@ func TestIntegration_TryTokenRefresh(t *testing.T) {
store := db.InitTestDB(t) store := db.InitTestDB(t)
env := environment{ env := environment{
sessionService: authtest.NewMockUserAuthTokenService(t),
authInfoService: &authinfotest.FakeService{}, authInfoService: &authinfotest.FakeService{},
serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()), serverLock: serverlock.ProvideService(store, tracing.InitializeTracerForTest()),
socialConnector: socialConnector, socialConnector: socialConnector,
@ -308,12 +314,12 @@ func TestIntegration_TryTokenRefresh(t *testing.T) {
prometheus.NewRegistry(), prometheus.NewRegistry(),
env.serverLock, env.serverLock,
tracing.InitializeTracerForTest(), tracing.InitializeTracerForTest(),
nil, env.sessionService,
featuremgmt.WithFeatures(), featuremgmt.WithFeatures(),
) )
// token refresh // token refresh
actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, nil) actualToken, err := env.service.TryTokenRefresh(context.Background(), tt.identity, &usertoken.UserToken{ExternalSessionId: 1})
if tt.expectedErr != nil { if tt.expectedErr != nil {
assert.ErrorIs(t, err, tt.expectedErr) assert.ErrorIs(t, err, tt.expectedErr)