Files
grafana/pkg/services/auth/authimpl/external_session_store.go
Misi 84b8296ffb OAuth: Use the attached external session data in OAuthToken and OAuthTokenSync (#96655)
* wip

* wip + tests

* wip

* wip opt2

* Use authn.Identity struct's SessionToken

* Merge fixes

* Handle disabling the feature flag correctly

* Fix test

* Cleanup

* Remove HasOAuthEntry from the OAuthTokenService interface

* Remove unused function
2024-11-27 11:06:39 +01:00

284 lines
6.8 KiB
Go

package authimpl
import (
"context"
"crypto/sha256"
"encoding/base64"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/secrets"
)
// Compile-time check that *store satisfies auth.ExternalSessionStore.
var _ auth.ExternalSessionStore = (*store)(nil)

// store is the database-backed auth.ExternalSessionStore. Token and
// identifier fields are encrypted through secretsService before they are
// written and decrypted again when they are read back.
type store struct {
	sqlStore       db.DB           // database handle used for every query
	secretsService secrets.Service // encrypts/decrypts the secret columns
	tracer         tracing.Tracer  // opens one span per store method
}

// provideExternalSessionStore wires the database, the secrets service and the
// tracer into a store and exposes it as an auth.ExternalSessionStore.
func provideExternalSessionStore(sqlStore db.DB, secretService secrets.Service, tracer tracing.Tracer) auth.ExternalSessionStore {
	s := &store{tracer: tracer}
	s.sqlStore = sqlStore
	s.secretsService = secretService
	return s
}
// Get loads the external session with the given ID and decrypts its secret
// fields before returning it.
//
// It returns auth.ErrExternalSessionNotFound when no session has that ID,
// including for non-positive IDs.
func (s *store) Get(ctx context.Context, ID int64) (*auth.ExternalSession, error) {
	ctx, span := s.tracer.Start(ctx, "externalsession.Get")
	defer span.End()

	// The lookup below uses the bean's fields as query conditions, and the ORM
	// skips zero-valued fields when building them. Without this guard Get(0)
	// would run an unconditioned query and hand back an arbitrary session
	// instead of reporting "not found". IDs are expected to be positive, so
	// negative values are short-circuited the same way.
	if ID <= 0 {
		return nil, auth.ErrExternalSessionNotFound
	}

	externalSession := &auth.ExternalSession{ID: ID}
	err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		found, err := sess.Get(externalSession)
		if err != nil {
			return err
		}
		if !found {
			return auth.ErrExternalSessionNotFound
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	if err := s.decryptSecrets(externalSession); err != nil {
		return nil, err
	}
	return externalSession, nil
}
// List returns the external sessions matching query, with their secret
// fields decrypted. The ID, SessionID and NameID values set in query are used
// to build the filter bean. SessionID and NameID are matched through their
// SHA-256 hashes (raw, unpadded base64), because the stored values themselves
// are encrypted. When nothing matches, an empty non-nil slice is returned.
func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) ([]*auth.ExternalSession, error) {
	ctx, span := s.tracer.Start(ctx, "externalsession.List")
	defer span.End()

	filter := &auth.ExternalSession{ID: query.ID}
	if query.SessionID != "" {
		digest := sha256.Sum256([]byte(query.SessionID))
		filter.SessionIDHash = base64.RawStdEncoding.EncodeToString(digest[:])
	}
	if query.NameID != "" {
		digest := sha256.Sum256([]byte(query.NameID))
		filter.NameIDHash = base64.RawStdEncoding.EncodeToString(digest[:])
	}

	sessions := make([]*auth.ExternalSession, 0)
	if err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		return sess.Find(&sessions, filter)
	}); err != nil {
		return nil, err
	}

	for i := range sessions {
		if err := s.decryptSecrets(sessions[i]); err != nil {
			return nil, err
		}
	}
	return sessions, nil
}
func (s *store) Create(ctx context.Context, extSession *auth.ExternalSession) error {
ctx, span := s.tracer.Start(ctx, "externalsession.Create")
defer span.End()
var err error
clone := extSession.Clone()
clone.AccessToken, err = s.encryptAndEncode(extSession.AccessToken)
if err != nil {
return err
}
clone.RefreshToken, err = s.encryptAndEncode(extSession.RefreshToken)
if err != nil {
return err
}
clone.IDToken, err = s.encryptAndEncode(extSession.IDToken)
if err != nil {
return err
}
if extSession.NameID != "" {
hash := sha256.New()
hash.Write([]byte(extSession.NameID))
clone.NameIDHash = base64.RawStdEncoding.EncodeToString(hash.Sum(nil))
}
clone.NameID, err = s.encryptAndEncode(extSession.NameID)
if err != nil {
return err
}
if extSession.SessionID != "" {
hash := sha256.New()
hash.Write([]byte(extSession.SessionID))
clone.SessionIDHash = base64.RawStdEncoding.EncodeToString(hash.Sum(nil))
}
clone.SessionID, err = s.encryptAndEncode(extSession.SessionID)
if err != nil {
return err
}
err = s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
_, err := sess.Insert(clone)
return err
})
if err != nil {
return err
}
extSession.ID = clone.ID
return nil
}
// Update overwrites the stored tokens and expiry of the external session
// identified by ID with the values carried by cmd.Token. The access and
// refresh tokens, and the "id_token" extra when present and non-empty, are
// encrypted before being written. Only the access_token, refresh_token,
// id_token and expires_at columns are named in the update.
//
// NOTE(review): id_token is always part of the column list. When cmd.Token has
// no non-empty "id_token" extra, the bean's IDToken stays empty, so this
// presumably clears the stored ID token. Confirm this is intended.
func (s *store) Update(ctx context.Context, ID int64, cmd *auth.UpdateExternalSessionCommand) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.Update")
	defer span.End()

	accessToken, err := s.encryptAndEncode(cmd.Token.AccessToken)
	if err != nil {
		return err
	}
	refreshToken, err := s.encryptAndEncode(cmd.Token.RefreshToken)
	if err != nil {
		return err
	}

	var idToken string
	if raw, ok := cmd.Token.Extra("id_token").(string); ok && raw != "" {
		if idToken, err = s.encryptAndEncode(raw); err != nil {
			return err
		}
	}

	row := &auth.ExternalSession{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		IDToken:      idToken,
		ExpiresAt:    cmd.Token.Expiry,
	}

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, updateErr := sess.ID(ID).Cols("access_token", "refresh_token", "id_token", "expires_at").Update(row)
		return updateErr
	})
}
// Delete removes the external session with the given ID, using it as the
// delete condition. Any database error is returned unchanged.
func (s *store) Delete(ctx context.Context, ID int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.Delete")
	defer span.End()

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, deleteErr := sess.Delete(&auth.ExternalSession{ID: ID})
		return deleteErr
	})
}
// DeleteExternalSessionsByUserID removes the external sessions whose UserID
// equals userID. Any database error is returned unchanged.
func (s *store) DeleteExternalSessionsByUserID(ctx context.Context, userID int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.DeleteExternalSessionsByUserID")
	defer span.End()

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, deleteErr := sess.Delete(&auth.ExternalSession{UserID: userID})
		return deleteErr
	})
}
// BatchDeleteExternalSessionsByUserIDs removes every external session whose
// user_id is one of userIDs. An empty (or nil) userIDs is a no-op and does not
// touch the database. Any database error is returned unchanged.
func (s *store) BatchDeleteExternalSessionsByUserIDs(ctx context.Context, userIDs []int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.BatchDeleteExternalSessionsByUserIDs")
	defer span.End()

	// Nothing to delete: skip the database round-trip instead of relying on how
	// the query builder renders an empty IN (...) list.
	if len(userIDs) == 0 {
		return nil
	}

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, err := sess.In("user_id", userIDs).Delete(&auth.ExternalSession{})
		return err
	})
}
// decryptSecrets decrypts extSession's AccessToken, RefreshToken, IDToken,
// NameID and SessionID in place, in that order. It stops at the first failure.
// Fields handled before that point hold their decrypted values, and the
// failing field is set to whatever decodeAndDecrypt returned alongside the
// error.
func (s *store) decryptSecrets(extSession *auth.ExternalSession) error {
	fields := []*string{
		&extSession.AccessToken,
		&extSession.RefreshToken,
		&extSession.IDToken,
		&extSession.NameID,
		&extSession.SessionID,
	}
	for _, field := range fields {
		plain, err := s.decodeAndDecrypt(*field)
		*field = plain
		if err != nil {
			return err
		}
	}
	return nil
}
func (s *store) encryptAndEncode(str string) (string, error) {
if str == "" {
return "", nil
}
encrypted, err := s.secretsService.Encrypt(context.Background(), []byte(str), secrets.WithoutScope())
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
// decodeAndDecrypt reverses encryptAndEncode: it base64-decodes str (padded
// standard alphabet) and decrypts the result with the secrets service. An
// empty str is returned as-is without calling Decrypt, which per the original
// note crashes on empty input. Decode or decrypt errors are returned with an
// empty string.
func (s *store) decodeAndDecrypt(str string) (string, error) {
	if len(str) == 0 {
		return "", nil
	}

	ciphertext, decodeErr := base64.StdEncoding.DecodeString(str)
	if decodeErr != nil {
		return "", decodeErr
	}

	plain, decryptErr := s.secretsService.Decrypt(context.Background(), ciphertext)
	if decryptErr != nil {
		return "", decryptErr
	}
	return string(plain), nil
}