mirror of
https://github.com/grafana/grafana.git
synced 2025-08-02 03:12:13 +08:00
Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)
* Verify OAuth token expiration for oauth users in the ctx handler middleware * Use refresh token to get a new access token * Refactor oauth_token.go * Add tests for the middleware changes * Align other tests * Add tests, wip * Add more tests * Add InvalidateOAuthTokens method * Fix ExpiryDate update to default * Invalidate OAuth tokens during logout * Improve logout * Add more comments * Cleanup * Fix import order * Add error to HasOAuthEntry return values * add dev debug logs * Fix tests Co-authored-by: jguer <joao.guerreiro@grafana.com>
This commit is contained in:
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
assert.Nil(t, sc.context.UserToken)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.userService.ExpectedSignedInUser = signedInUser
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{
|
||||
UserId: userID,
|
||||
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
|
||||
OAuthAccessToken: "access_token",
|
||||
OAuthRefreshToken: "refresh_token"}
|
||||
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
|
||||
assert.Equal(t, token.AccessToken, "")
|
||||
assert.Equal(t, token.RefreshToken, "")
|
||||
assert.True(t, token.Expiry.IsZero())
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.Nil(t, sc.context.UserToken)
|
||||
assert.False(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, int64(0), sc.context.UserID)
|
||||
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
sc.userService = usertest.NewUserServiceFake()
|
||||
sc.orgService = orgtest.NewOrgServiceFake()
|
||||
sc.apiKeyService = &apikeytest.Service{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService)
|
||||
sc.oauthTokenService = &auth.FakeOAuthTokenService{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
|
||||
sc.sqlStore = ctxHdlr.SQLStore
|
||||
sc.contextHandler = ctxHdlr
|
||||
sc.m.Use(ctxHdlr.Middleware)
|
||||
@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
|
||||
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
||||
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
||||
oauthTokenService *auth.FakeOAuthTokenService,
|
||||
) *contexthandler.ContextHandler {
|
||||
t.Helper()
|
||||
|
||||
@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S
|
||||
tracer := tracing.InitializeTracerForTest()
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
|
||||
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService)
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService)
|
||||
}
|
||||
|
||||
type fakeRenderService struct {
|
||||
|
Reference in New Issue
Block a user