mirror of
https://github.com/grafana/grafana.git
synced 2025-08-01 02:31:50 +08:00
Auth: Implement the SSO Settings update endpoint (#79676)
* merge with system settings before storing them in the db * add base for validating sso settings * add unit tests for sso settings validation * call Reload() from sso service upsert() * remove actual validation because it was moved in a separate pr * use constant to fix lint error * check if provider is configurable in service Upsert() method * add unit tests for update provider settings api method * fix lint error
This commit is contained in:
@ -133,11 +133,8 @@ func (api *Api) updateProviderSettings(c *contextmodel.ReqContext) response.Resp
|
||||
settings.Provider = key
|
||||
|
||||
err := api.SSOSettingsService.Upsert(c.Req.Context(), settings)
|
||||
// TODO: first check whether the error is referring to validation errors
|
||||
|
||||
// other error
|
||||
if err != nil {
|
||||
return response.Error(http.StatusInternalServerError, "Failed to update provider settings", err)
|
||||
return response.ErrOrFallback(http.StatusInternalServerError, "Failed to update provider settings", err)
|
||||
}
|
||||
|
||||
return response.JSON(http.StatusNoContent, nil)
|
||||
|
@ -1,6 +1,8 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -11,16 +13,145 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/ssosettingstests"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web/webtest"
|
||||
)
|
||||
|
||||
func TestSSOSettingsAPI_Update(t *testing.T) {
|
||||
type TestCase struct {
|
||||
desc string
|
||||
key string
|
||||
body string
|
||||
action string
|
||||
scope string
|
||||
expectedError error
|
||||
expectedServiceCall bool
|
||||
expectedStatusCode int
|
||||
}
|
||||
|
||||
tests := []TestCase{
|
||||
{
|
||||
desc: "successfully updates SSO settings",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
desc: "fails when action doesn't match",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:read",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: false,
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
desc: "fails when scope doesn't match",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:read",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: false,
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
desc: "fails when scope contains another provider",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.okta:*",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: false,
|
||||
expectedStatusCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
desc: "fails with not found when key is empty",
|
||||
key: "",
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: false,
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "fails with bad request when body contains invalid json",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{ invalid json }`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: nil,
|
||||
expectedServiceCall: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
desc: "fails with bad request when key was not found",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: ssosettings.ErrInvalidProvider.Errorf("invalid provider"),
|
||||
expectedServiceCall: true,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
desc: "fails with internal server error when service returns an error",
|
||||
key: social.GitHubProviderName,
|
||||
body: `{"settings": {"enabled": true}}`,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: errors.New("something went wrong"),
|
||||
expectedServiceCall: true,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var input models.SSOSettings
|
||||
_ = json.Unmarshal([]byte(tt.body), &input)
|
||||
|
||||
settings := models.SSOSettings{
|
||||
Provider: tt.key,
|
||||
Settings: input.Settings,
|
||||
}
|
||||
|
||||
service := ssosettingstests.NewMockService(t)
|
||||
if tt.expectedServiceCall {
|
||||
service.On("Upsert", mock.Anything, settings).Return(tt.expectedError).Once()
|
||||
}
|
||||
server := setupTests(t, service)
|
||||
|
||||
path := fmt.Sprintf("/api/v1/sso-settings/%s", tt.key)
|
||||
req := server.NewRequest(http.MethodPut, path, bytes.NewBufferString(tt.body))
|
||||
webtest.RequestWithSignedInUser(req, &user.SignedInUser{
|
||||
OrgRole: org.RoleEditor,
|
||||
OrgID: 1,
|
||||
Permissions: getPermissionsForActionAndScope(tt.action, tt.scope),
|
||||
})
|
||||
res, err := server.SendJSON(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tt.expectedStatusCode, res.StatusCode)
|
||||
require.NoError(t, res.Body.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
type TestCase struct {
|
||||
desc string
|
||||
@ -35,7 +166,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
tests := []TestCase{
|
||||
{
|
||||
desc: "successfully deletes SSO settings",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.azuread:*",
|
||||
expectedError: nil,
|
||||
@ -44,7 +175,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "fails when action doesn't match",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:read",
|
||||
scope: "settings:auth.azuread:*",
|
||||
expectedError: nil,
|
||||
@ -53,7 +184,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "fails when scope doesn't match",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.azuread:read",
|
||||
expectedError: nil,
|
||||
@ -62,7 +193,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "fails when scope contains another provider",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.github:*",
|
||||
expectedError: nil,
|
||||
@ -80,7 +211,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "fails with not found when key was not found",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.azuread:*",
|
||||
expectedError: ssosettings.ErrNotFound,
|
||||
@ -89,7 +220,7 @@ func TestSSOSettingsAPI_Delete(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "fails with internal server error when service returns an error",
|
||||
key: "azuread",
|
||||
key: social.AzureADProviderName,
|
||||
action: "settings:write",
|
||||
scope: "settings:auth.azuread:*",
|
||||
expectedError: errors.New("something went wrong"),
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
|
||||
ErrInvalidProvider = errutil.ValidationFailed("sso.invalidProvider", errutil.WithPublicMessage("provider is invalid"))
|
||||
ErrInvalidSettings = errutil.ValidationFailed("sso.settings", errutil.WithPublicMessage("settings field is invalid"))
|
||||
ErrEmptyClientId = errutil.ValidationFailed("sso.emptyClientId", errutil.WithPublicMessage("settings.clientId cannot be empty"))
|
||||
)
|
||||
|
@ -36,6 +36,8 @@ type Service interface {
|
||||
}
|
||||
|
||||
// Reloadable is an interface that can be implemented by a provider to allow it to be validated and reloaded
|
||||
//
|
||||
//go:generate mockery --name Reloadable --structname MockReloadable --outpkg ssosettingstests --filename reloadable_mock.go --output ./ssosettingstests/
|
||||
type Reloadable interface {
|
||||
Reload(ctx context.Context, settings models.SSOSettings) error
|
||||
Validate(ctx context.Context, settings models.SSOSettings) error
|
||||
|
@ -110,16 +110,47 @@ func (s *SSOSettingsService) List(ctx context.Context) ([]*models.SSOSettings, e
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
var err error
|
||||
// TODO: also check whether the provider is configurable
|
||||
// Get the connector for the provider (from the reloadables) and call Validate
|
||||
if !isProviderConfigurable(settings.Provider) {
|
||||
return ssosettings.ErrInvalidProvider.Errorf("provider %s is not configurable", settings.Provider)
|
||||
}
|
||||
|
||||
social, ok := s.reloadables[settings.Provider]
|
||||
if !ok {
|
||||
return ssosettings.ErrInvalidProvider.Errorf("provider %s not found in reloadables", settings.Provider)
|
||||
}
|
||||
|
||||
err := social.Validate(ctx, settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
systemSettings, err := s.loadSettingsUsingFallbackStrategy(ctx, settings.Provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add the SSO settings from system that are not available in the user input
|
||||
// in order to have a complete set of SSO settings for every provider in the database
|
||||
settings.Settings = mergeSettings(settings.Settings, systemSettings.Settings)
|
||||
|
||||
settings.Settings, err = s.encryptSecrets(ctx, settings.Settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.store.Upsert(ctx, settings)
|
||||
err = s.store.Upsert(ctx, settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = social.Reload(context.Background(), settings)
|
||||
if err != nil {
|
||||
s.log.Error("failed to reload the provider", "provider", settings.Provider, "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) Patch(ctx context.Context, provider string, data map[string]any) error {
|
||||
@ -183,31 +214,60 @@ func (s *SSOSettingsService) getFallBackstrategyFor(provider string) (ssosetting
|
||||
}
|
||||
|
||||
func (s *SSOSettingsService) encryptSecrets(ctx context.Context, settings map[string]any) (map[string]any, error) {
|
||||
secretFieldPatterns := []string{"secret"}
|
||||
|
||||
isSecret := func(field string) bool {
|
||||
for _, v := range secretFieldPatterns {
|
||||
if strings.Contains(strings.ToLower(field), strings.ToLower(v)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
for k, v := range settings {
|
||||
if isSecret(k) {
|
||||
strValue, ok := v.(string)
|
||||
if !ok {
|
||||
return settings, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
||||
return result, fmt.Errorf("failed to encrypt %s setting because it is not a string: %v", k, v)
|
||||
}
|
||||
|
||||
encryptedSecret, err := s.secrets.Encrypt(ctx, []byte(strValue), secrets.WithoutScope())
|
||||
if err != nil {
|
||||
return settings, err
|
||||
return result, err
|
||||
}
|
||||
settings[k] = string(encryptedSecret)
|
||||
result[k] = string(encryptedSecret)
|
||||
} else {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func isSecret(fieldName string) bool {
|
||||
secretFieldPatterns := []string{"secret"}
|
||||
|
||||
for _, v := range secretFieldPatterns {
|
||||
if strings.Contains(strings.ToLower(fieldName), strings.ToLower(v)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeSettings(apiSettings, systemSettings map[string]any) map[string]any {
|
||||
settings := make(map[string]any)
|
||||
|
||||
for k, v := range apiSettings {
|
||||
settings[k] = v
|
||||
}
|
||||
|
||||
for k, v := range systemSettings {
|
||||
if _, ok := settings[k]; !ok {
|
||||
settings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func isProviderConfigurable(provider string) bool {
|
||||
for _, configurable := range ssosettings.ConfigurableOAuthProviders {
|
||||
if provider == configurable {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/acimpl"
|
||||
secretsFakes "github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
@ -252,8 +253,9 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
t.Run("successfully upsert SSO settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
@ -262,17 +264,180 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
env.reloadables[provider] = reloadable
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings.Settings["client_secret"] = "encrypted-client-secret"
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
|
||||
t.Run("successfully upsert SSO settings having system settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.GitHubProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
systemSettings := map[string]any{
|
||||
"api_url": "http://api-url",
|
||||
"use_refresh_token": true,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
env.reloadables[provider] = reloadable
|
||||
env.fallbackStrategy.ExpectedConfig = systemSettings
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings.Settings["client_secret"] = "encrypted-client-secret"
|
||||
settings.Settings["api_url"] = systemSettings["api_url"]
|
||||
settings.Settings["use_refresh_token"] = systemSettings["use_refresh_token"]
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
|
||||
t.Run("successfully upsert SSO settings having system settings without overwriting user settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.GitlabProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
systemSettings := map[string]any{
|
||||
"client_id": "client-id-from-system",
|
||||
"client_secret": "client-secret-from-system",
|
||||
"enabled": false,
|
||||
"api_url": "http://api-url",
|
||||
"use_refresh_token": true,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
reloadable.On("Reload", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
env.reloadables[provider] = reloadable
|
||||
env.fallbackStrategy.ExpectedConfig = systemSettings
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings.Settings["client_secret"] = "encrypted-client-secret"
|
||||
settings.Settings["api_url"] = systemSettings["api_url"]
|
||||
settings.Settings["use_refresh_token"] = systemSettings["use_refresh_token"]
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
|
||||
t.Run("returns error if provider is not configurable", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.GrafanaComProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
env.reloadables[provider] = reloadable
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if provider was not found in reloadables", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
// the reloadable is available for other provider
|
||||
env.reloadables["github"] = reloadable
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if validation fails", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(errors.New("validation failed"))
|
||||
env.reloadables[provider] = reloadable
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if a fallback strategy is not available for the provider", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
settings := models.SSOSettings{
|
||||
Provider: social.AzureADProviderName,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
env.fallbackStrategy.ExpectedIsMatch = false
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error if secrets encryption failed", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.OktaProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
@ -281,6 +446,9 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
env.reloadables[provider] = reloadable
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return(nil, errors.New("encryption failed")).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
@ -290,8 +458,9 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
t.Run("returns error if store failed to upsert settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: "azuread",
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
@ -300,19 +469,49 @@ func TestSSOSettingsService_Upsert(t *testing.T) {
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
env.reloadables[provider] = reloadable
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
env.store.ExpectedError = errors.New("upsert failed")
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("successfully upsert SSO settings if reload fails", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := social.AzureADProviderName
|
||||
settings := models.SSOSettings{
|
||||
Provider: provider,
|
||||
Settings: map[string]any{
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
"enabled": true,
|
||||
},
|
||||
IsDeleted: false,
|
||||
}
|
||||
|
||||
reloadable := ssosettingstests.NewMockReloadable(t)
|
||||
reloadable.On("Validate", mock.Anything, settings).Return(nil)
|
||||
reloadable.On("Reload", mock.Anything, mock.Anything).Return(errors.New("failed reloading new settings")).Maybe()
|
||||
env.reloadables[provider] = reloadable
|
||||
env.secrets.On("Encrypt", mock.Anything, []byte(settings.Settings["client_secret"].(string)), mock.Anything).Return([]byte("encrypted-client-secret"), nil).Once()
|
||||
|
||||
err := env.service.Upsert(context.Background(), settings)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings.Settings["client_secret"] = "encrypted-client-secret"
|
||||
require.EqualValues(t, settings, env.store.ActualSSOSettings)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSOSettingsService_Delete(t *testing.T) {
|
||||
t.Run("successfully delete SSO settings", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := "azuread"
|
||||
provider := social.AzureADProviderName
|
||||
env.store.ExpectedError = nil
|
||||
|
||||
err := env.service.Delete(context.Background(), provider)
|
||||
@ -322,7 +521,7 @@ func TestSSOSettingsService_Delete(t *testing.T) {
|
||||
t.Run("SSO settings not found for the specified provider", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := "azuread"
|
||||
provider := social.AzureADProviderName
|
||||
env.store.ExpectedError = ssosettings.ErrNotFound
|
||||
|
||||
err := env.service.Delete(context.Background(), provider)
|
||||
@ -333,7 +532,7 @@ func TestSSOSettingsService_Delete(t *testing.T) {
|
||||
t.Run("store fails to delete the SSO settings for the specified provider", func(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
|
||||
provider := "azuread"
|
||||
provider := social.AzureADProviderName
|
||||
env.store.ExpectedError = errors.New("delete sso settings failed")
|
||||
|
||||
err := env.service.Delete(context.Background(), provider)
|
||||
@ -347,13 +546,16 @@ func setupTestEnv(t *testing.T) testEnv {
|
||||
fallbackStrategy := ssosettingstests.NewFakeFallbackStrategy()
|
||||
secrets := secretsFakes.NewMockService(t)
|
||||
accessControl := acimpl.ProvideAccessControl(setting.NewCfg())
|
||||
reloadables := make(map[string]ssosettings.Reloadable)
|
||||
|
||||
fallbackStrategy.ExpectedIsMatch = true
|
||||
|
||||
svc := &SSOSettingsService{
|
||||
log: log.NewNopLogger(),
|
||||
store: store,
|
||||
ac: accessControl,
|
||||
fbStrategies: []ssosettings.FallbackStrategy{fallbackStrategy},
|
||||
reloadables: make(map[string]ssosettings.Reloadable),
|
||||
reloadables: reloadables,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
||||
@ -363,6 +565,7 @@ func setupTestEnv(t *testing.T) testEnv {
|
||||
ac: accessControl,
|
||||
fallbackStrategy: fallbackStrategy,
|
||||
secrets: secrets,
|
||||
reloadables: reloadables,
|
||||
}
|
||||
}
|
||||
|
||||
@ -372,4 +575,5 @@ type testEnv struct {
|
||||
ac accesscontrol.AccessControl
|
||||
fallbackStrategy *ssosettingstests.FakeFallbackStrategy
|
||||
secrets *secretsFakes.MockService
|
||||
reloadables map[string]ssosettings.Reloadable
|
||||
}
|
||||
|
57
pkg/services/ssosettings/ssosettingstests/reloadable_mock.go
Normal file
57
pkg/services/ssosettings/ssosettingstests/reloadable_mock.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Code generated by mockery v2.37.1. DO NOT EDIT.
|
||||
|
||||
package ssosettingstests
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
models "github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockReloadable is an autogenerated mock type for the Reloadable type
|
||||
type MockReloadable struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Reload provides a mock function with given fields: ctx, settings
|
||||
func (_m *MockReloadable) Reload(ctx context.Context, settings models.SSOSettings) error {
|
||||
ret := _m.Called(ctx, settings)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
||||
r0 = rf(ctx, settings)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Validate provides a mock function with given fields: ctx, settings
|
||||
func (_m *MockReloadable) Validate(ctx context.Context, settings models.SSOSettings) error {
|
||||
ret := _m.Called(ctx, settings)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, models.SSOSettings) error); ok {
|
||||
r0 = rf(ctx, settings)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewMockReloadable creates a new instance of MockReloadable. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockReloadable(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockReloadable {
|
||||
mock := &MockReloadable{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
@ -13,6 +13,8 @@ type FakeStore struct {
|
||||
ExpectedSSOSetting *models.SSOSettings
|
||||
ExpectedSSOSettings []*models.SSOSettings
|
||||
ExpectedError error
|
||||
|
||||
ActualSSOSettings models.SSOSettings
|
||||
}
|
||||
|
||||
func NewFakeStore() *FakeStore {
|
||||
@ -28,6 +30,8 @@ func (f *FakeStore) List(ctx context.Context) ([]*models.SSOSettings, error) {
|
||||
}
|
||||
|
||||
func (f *FakeStore) Upsert(ctx context.Context, settings models.SSOSettings) error {
|
||||
f.ActualSSOSettings = settings
|
||||
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user