Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)

* Verify OAuth token expiration for oauth users in the ctx handler middleware

* Use refresh token to get a new access token

* Refactor oauth_token.go

* Add tests for the middleware changes

* Align other tests

* Add tests, wip

* Add more tests

* Add InvalidateOAuthTokens method

* Fix ExpiryDate update to default

* Invalidate OAuth tokens during logout

* Improve logout

* Add more comments

* Cleanup

* Fix import order

* Add error to HasOAuthEntry return values

* add dev debug logs

* Fix tests

Co-authored-by: jguer <joao.guerreiro@grafana.com>
This commit is contained in:
Misi
2022-10-18 18:17:28 +02:00
committed by GitHub
parent 984ec00aac
commit 9c954d06ab
17 changed files with 828 additions and 89 deletions

View File

@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) {
assert.Nil(t, sc.context.UserToken)
})
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
sc.userService.ExpectedSignedInUser = signedInUser
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{
UserId: userID,
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
OAuthAccessToken: "access_token",
OAuthRefreshToken: "refresh_token"}
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
assert.Equal(t, token.AccessToken, "")
assert.Equal(t, token.RefreshToken, "")
assert.True(t, token.Expiry.IsZero())
require.NotNil(t, sc.context)
require.Nil(t, sc.context.UserToken)
assert.False(t, sc.context.IsSignedIn)
assert.Equal(t, int64(0), sc.context.UserID)
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.userService = usertest.NewUserServiceFake()
sc.orgService = orgtest.NewOrgServiceFake()
sc.apiKeyService = &apikeytest.Service{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService)
sc.oauthTokenService = &auth.FakeOAuthTokenService{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware)
@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
oauthTokenService *auth.FakeOAuthTokenService,
) *contexthandler.ContextHandler {
t.Helper()
@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S
tracer := tracing.InitializeTracerForTest()
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService)
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService)
}
type fakeRenderService struct {