diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index 6a96ccb6bad..c7f2b01e878 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -19,6 +19,7 @@ type UserAuth struct { Created time.Time OAuthAccessToken string OAuthRefreshToken string + OAuthIdToken string OAuthTokenType string OAuthExpiry time.Time } diff --git a/pkg/services/login/authinfoservice/database.go b/pkg/services/login/authinfoservice/database.go index fede7e0e095..45452878af0 100644 --- a/pkg/services/login/authinfoservice/database.go +++ b/pkg/services/login/authinfoservice/database.go @@ -71,9 +71,14 @@ func (s *Implementation) GetAuthInfo(ctx context.Context, query *models.GetAuthI if err != nil { return err } + secretIdToken, err := s.decodeAndDecrypt(userAuth.OAuthIdToken) + if err != nil { + return err + } userAuth.OAuthAccessToken = secretAccessToken userAuth.OAuthRefreshToken = secretRefreshToken userAuth.OAuthTokenType = secretTokenType + userAuth.OAuthIdToken = secretIdToken query.Result = userAuth return nil @@ -101,9 +106,18 @@ func (s *Implementation) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInf return err } + var secretIdToken string + if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" { + secretIdToken, err = s.encryptAndEncode(idToken) + if err != nil { + return err + } + } + authUser.OAuthAccessToken = secretAccessToken authUser.OAuthRefreshToken = secretRefreshToken authUser.OAuthTokenType = secretTokenType + authUser.OAuthIdToken = secretIdToken authUser.OAuthExpiry = cmd.OAuthToken.Expiry } @@ -135,9 +149,18 @@ func (s *Implementation) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateA return err } + var secretIdToken string + if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" { + secretIdToken, err = s.encryptAndEncode(idToken) + if err != nil { + return err + } + } + authUser.OAuthAccessToken = secretAccessToken authUser.OAuthRefreshToken = secretRefreshToken authUser.OAuthTokenType = secretTokenType + authUser.OAuthIdToken = secretIdToken authUser.OAuthExpiry = cmd.OAuthToken.Expiry } diff --git a/pkg/services/login/authinfoservice/user_auth_test.go b/pkg/services/login/authinfoservice/user_auth_test.go index c6a42cbf275..a5cf659db1b 100644 --- a/pkg/services/login/authinfoservice/user_auth_test.go +++ b/pkg/services/login/authinfoservice/user_auth_test.go @@ -133,6 +133,8 @@ func TestUserAuth(t *testing.T) { Expiry: time.Now(), TokenType: "Bearer", } + idToken := "testidtoken" + token = token.WithExtra(map[string]interface{}{"id_token": idToken}) // Find a user to set tokens on login := "loginuser0" @@ -161,9 +163,10 @@ func TestUserAuth(t *testing.T) { err = srv.GetAuthInfo(context.Background(), getAuthQuery) require.Nil(t, err) - require.Equal(t, getAuthQuery.Result.OAuthAccessToken, token.AccessToken) - require.Equal(t, getAuthQuery.Result.OAuthRefreshToken, token.RefreshToken) - require.Equal(t, getAuthQuery.Result.OAuthTokenType, token.TokenType) + require.Equal(t, token.AccessToken, getAuthQuery.Result.OAuthAccessToken) + require.Equal(t, token.RefreshToken, getAuthQuery.Result.OAuthRefreshToken) + require.Equal(t, token.TokenType, getAuthQuery.Result.OAuthTokenType) + require.Equal(t, idToken, getAuthQuery.Result.OAuthIdToken) }) t.Run("Always return the most recently used auth_module", func(t *testing.T) { diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index 9276c4ebab3..4b541433b74 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -68,6 +68,11 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, user *models.SignedI RefreshToken: authInfoQuery.Result.OAuthRefreshToken, TokenType: authInfoQuery.Result.OAuthTokenType, } + + if authInfoQuery.Result.OAuthIdToken != "" { + persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken}) + } + // TokenSource handles refreshing the token if it has expired token, err := connect.TokenSource(ctx, persistedToken).Token() if err != nil { diff --git a/pkg/services/sqlstore/migrations/user_auth_mig.go b/pkg/services/sqlstore/migrations/user_auth_mig.go index e84d9b50014..bf3b98955b7 100644 --- a/pkg/services/sqlstore/migrations/user_auth_mig.go +++ b/pkg/services/sqlstore/migrations/user_auth_mig.go @@ -42,4 +42,8 @@ func addUserAuthMigrations(mg *Migrator) { mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{ Cols: []string{"user_id"}, })) + + mg.AddMigration("Add OAuth ID token to user_auth", NewAddColumnMigration(userAuthV1, &Column{ + Name: "o_auth_id_token", Type: DB_Text, Nullable: true, + })) }