diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index e732249e27f..915fe604200 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -27,11 +27,13 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/dashboards" dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database" dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager" "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/ldap" + "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/searchusers" @@ -193,7 +195,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa authJWTSvc := models.NewFakeJWTService() tracer, err := tracing.InitializeTracerForTest() require.NoError(t, err) - ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore) + ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) return ctxHdlr } diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 3a30084454f..637e02c0a10 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -2,7 +2,6 @@ package middleware import ( "context" - "errors" "fmt" "io" "net" @@ -25,8 +24,10 @@ import ( "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" + "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" @@ -364,10 +365,7 @@ func TestMiddlewareContext(t *testing.T) { const group = "grafana-core-team" middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID} h, err := authproxy.HashCacheKey(hdrName + "-" + group) require.NoError(t, err) @@ -387,11 +385,11 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) { var actualAuthProxyAutoSignUp *bool = nil - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { + sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User { actualAuthProxyAutoSignUp = &cmd.SignupAllowed - return login.ErrInvalidCredentials - }) + return nil + } + sc.loginService.ExpectedError = login.ErrInvalidCredentials sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -407,18 +405,8 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - if query.UserId > 0 { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} - return nil - } - return models.ErrUserNotFound - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID} + sc.loginService.ExpectedUser = &models.User{Id: userID} sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -435,19 +423,11 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) { var storedRoleInfo map[int64]models.RoleType = nil - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - if query.UserId > 0 { - query.Result = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]} - return nil - } - return models.ErrUserNotFound - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} + sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User { storedRoleInfo = cmd.ExternalUser.OrgRoles - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]} + return &models.User{Id: userID} + } sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -466,19 +446,11 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) { var storedRoleInfo map[int64]models.RoleType = nil - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - if query.UserId > 0 { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]} - return nil - } - return models.ErrUserNotFound - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} + sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User { storedRoleInfo = cmd.ExternalUser.OrgRoles - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]} + return &models.User{Id: userID} + } sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -499,27 +471,17 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - if query.UserId > 0 { - query.Result = &models.SignedInUser{OrgId: query.OrgId, UserId: userID} - return nil - } - return models.ErrUserNotFound - }) + var targetOrgID int64 = 123 + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: targetOrgID, UserId: userID} + sc.loginService.ExpectedUser = &models.User{Id: userID} - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} - return nil - }) - - targetOrgID := 123 sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID)) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserId) - assert.Equal(t, int64(targetOrgID), sc.context.OrgId) + assert.Equal(t, targetOrgID, sc.context.OrgId) }, func(cfg *setting.Cfg) { configure(cfg) cfg.LDAPEnabled = false @@ -554,15 +516,8 @@ func TestMiddlewareContext(t *testing.T) { const userID int64 = 12 const orgID int64 = 2 - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} - return nil - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID} + sc.loginService.ExpectedUser = &models.User{Id: userID} sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -577,15 +532,8 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} - return nil - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID} + sc.loginService.ExpectedUser = &models.User{Id: userID} sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -602,15 +550,7 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} - return nil - }) - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: userID} - return nil - }) + sc.loginService.ExpectedUser = &models.User{Id: userID} sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) @@ -626,10 +566,6 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("LDAP", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - return errors.New("Do not add user") - }) - sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() @@ -639,10 +575,6 @@ func TestMiddlewareContext(t *testing.T) { }, configure) middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - return errors.New("Do not add user") - }) - sc.fakeReq("GET", "/") sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() @@ -684,7 +616,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.m.UseMiddleware(AddCSPHeader(cfg, logger)) sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]")) - ctxHdlr := getContextHandler(t, cfg) + sc.mockSQLStore = mockstore.NewSQLStoreMock() + sc.loginService = &loginservice.LoginServiceMock{} + ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService) sc.sqlStore = ctxHdlr.SQLStore sc.contextHandler = ctxHdlr sc.m.Use(ctxHdlr.Middleware) @@ -714,7 +648,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( }) } -func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler { +func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler { t.Helper() sqlStore := sqlstore.InitTestDB(t) @@ -730,8 +664,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer, err := tracing.InitializeTracerForTest() + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore) require.NoError(t, err) - return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) + return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) } type fakeRenderService struct { diff --git a/pkg/middleware/rate_limit_test.go b/pkg/middleware/rate_limit_test.go index 65b1c365774..641e385018a 100644 --- a/pkg/middleware/rate_limit_test.go +++ b/pkg/middleware/rate_limit_test.go @@ -33,7 +33,7 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL m := web.New() m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]")) - m.Use(getContextHandler(t, cfg).Middleware) + m.Use(getContextHandler(t, cfg, nil, nil).Middleware) m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler) fn(func() *httptest.ResponseRecorder { diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index f654628c5e9..51b1c6c1c4d 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -70,7 +70,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() sc.remoteCacheService = remotecache.NewFakeStore(t) - contextHandler := getContextHandler(t, nil) + contextHandler := getContextHandler(t, nil, nil, nil) sc.m.Use(contextHandler.Middleware) // mock out gc goroutine sc.m.Use(OrgRedirect(cfg)) diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go index 12d4cc08e17..aa677b53770 100644 --- a/pkg/middleware/testing.go +++ b/pkg/middleware/testing.go @@ -10,7 +10,9 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" "github.com/stretchr/testify/require" @@ -34,7 +36,9 @@ type scenarioContext struct { remoteCacheService *remotecache.RemoteCache cfg *setting.Cfg sqlStore sqlstore.Store + mockSQLStore *mockstore.SQLStoreMock contextHandler *contexthandler.ContextHandler + loginService *loginservice.LoginServiceMock req *http.Request } diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 4d018369bd4..dc59444d16d 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -32,6 +32,7 @@ import ( "github.com/grafana/grafana/pkg/services/cleanup" "github.com/grafana/grafana/pkg/services/comments" "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/dashboardimport" dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service" "github.com/grafana/grafana/pkg/services/dashboards" @@ -228,6 +229,7 @@ var wireBasicSet = wire.NewSet( wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)), comments.ProvideService, guardian.ProvideService, + authproxy.ProvideAuthProxy, ) var wireSet = wire.NewSet( diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 5eac50fce84..82e6245c428 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -6,13 +6,13 @@ import ( "net/http" "testing" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" + "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" @@ -20,42 +20,18 @@ import ( "github.com/stretchr/testify/require" ) +const userID = int64(1) +const orgID = int64(4) + // Test initContextWithAuthProxy with a cached user ID that is no longer valid. // // In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user // in without cache. func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { const name = "markelog" - const userID = int64(1) - const orgID = int64(4) svc := getContextHandler(t) - // XXX: These handlers have to be injected AFTER calling getContextHandler, since the latter - // creates a SQLStore which installs its own handlers. - upsertHandler := func(ctx context.Context, cmd *models.UpsertUserCommand) error { - require.Equal(t, name, cmd.ExternalUser.Login) - cmd.Result = &models.User{Id: userID} - return nil - } - getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error { - // Simulate that the cached user ID is stale - if query.UserId != userID { - return models.ErrUserNotFound - } - - query.Result = &models.SignedInUser{ - UserId: userID, - OrgId: orgID, - } - return nil - } - bus.AddHandler("", upsertHandler) - bus.AddHandler("", getUserHandler) - t.Cleanup(func() { - bus.ClearBusHandlers() - }) - req, err := http.NewRequest("POST", "http://example.com", nil) require.NoError(t, err) ctx := &models.ReqContext{ @@ -106,5 +82,24 @@ func getContextHandler(t *testing.T) *ContextHandler { tracer, err := tracing.InitializeTracerForTest() require.NoError(t, err) - return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) + loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}} + authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{}) + + return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) +} + +type FakeGetSignUserStore struct { + sqlstore.Store +} + +func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error { + if query.UserId != userID { + return models.ErrUserNotFound + } + + query.Result = &models.SignedInUser{ + UserId: userID, + OrgId: orgID, + } + return nil } diff --git a/pkg/services/contexthandler/authproxy/authproxy.go b/pkg/services/contexthandler/authproxy/authproxy.go index d3efa3073e1..ba4a68f3ad6 100644 --- a/pkg/services/contexthandler/authproxy/authproxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -13,12 +13,13 @@ import ( "strings" "time" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/ldap" + "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/multildap" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) @@ -49,11 +50,22 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"} // AuthProxy struct type AuthProxy struct { - cfg *setting.Cfg - remoteCache *remotecache.RemoteCache - ctx *models.ReqContext - orgID int64 - header string + cfg *setting.Cfg + remoteCache *remotecache.RemoteCache + loginService login.Service + sqlStore sqlstore.Store + + logger log.Logger +} + +func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy { + return &AuthProxy{ + cfg: cfg, + remoteCache: remoteCache, + loginService: loginService, + sqlStore: sqlStore, + logger: log.New("auth.proxy"), + } } // Error auth proxy specific error @@ -75,40 +87,20 @@ func (err Error) Error() string { return err.Message } -// Options for the AuthProxy -type Options struct { - RemoteCache *remotecache.RemoteCache - Ctx *models.ReqContext - OrgID int64 -} - -// New instance of the AuthProxy. -func New(cfg *setting.Cfg, options *Options) *AuthProxy { - auth := &AuthProxy{ - remoteCache: options.RemoteCache, - cfg: cfg, - ctx: options.Ctx, - orgID: options.OrgID, - } - auth.header = auth.getDecodedHeader(cfg.AuthProxyHeaderName) - return auth -} - // IsEnabled checks if the auth proxy is enabled. func (auth *AuthProxy) IsEnabled() bool { // Bail if the setting is not enabled return auth.cfg.AuthProxyEnabled } -// HasHeader checks if the we have specified header -func (auth *AuthProxy) HasHeader() bool { - return len(auth.header) != 0 +// HasHeader checks if we have specified header +func (auth *AuthProxy) HasHeader(reqCtx *models.ReqContext) bool { + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) + return len(header) != 0 } // IsAllowedIP returns whether provided IP is allowed. -func (auth *AuthProxy) IsAllowedIP() error { - ip := auth.ctx.Req.RemoteAddr - +func (auth *AuthProxy) IsAllowedIP(ip string) error { if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 { return nil } @@ -137,7 +129,7 @@ func (auth *AuthProxy) IsAllowedIP() error { } return newError("proxy authentication required", fmt.Errorf( - "request for user (%s) from %s is not from the authentication proxy", auth.header, + "request for user from %s is not from the authentication proxy", sourceIP, )) } @@ -153,10 +145,11 @@ func HashCacheKey(key string) (string, error) { // getKey forms a key for the cache based on the headers received as part of the authentication flow. // Our configuration supports multiple headers. The main header contains the email or username. // And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups. -func (auth *AuthProxy) getKey() (string, error) { - key := strings.TrimSpace(auth.header) // start the key with the main header +func (auth *AuthProxy) getKey(reqCtx *models.ReqContext) (string, error) { + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) + key := strings.TrimSpace(header) // start the key with the main header - auth.headersIterator(func(_, header string) { + auth.headersIterator(reqCtx, func(_, header string) { key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers }) @@ -168,17 +161,17 @@ func (auth *AuthProxy) getKey() (string, error) { } // Login logs in user ID by whatever means possible. -func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) { +func (auth *AuthProxy) Login(reqCtx *models.ReqContext, ignoreCache bool) (int64, error) { if !ignoreCache { // Error here means absent cache - we don't need to handle that - id, err := auth.GetUserViaCache(logger) + id, err := auth.getUserViaCache(reqCtx) if err == nil && id != 0 { return id, nil } } if isLDAPEnabled(auth.cfg) { - id, err := auth.LoginViaLDAP() + id, err := auth.LoginViaLDAP(reqCtx) if err != nil { if errors.Is(err, ldap.ErrInvalidCredentials) { return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials) @@ -189,7 +182,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) return id, nil } - id, err := auth.LoginViaHeader() + id, err := auth.loginViaHeader(reqCtx) if err != nil { return 0, newError("failed to log in as user, specified in auth proxy header", err) } @@ -197,87 +190,89 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) return id, nil } -// GetUserViaCache gets user ID from cache. -func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) { - cacheKey, err := auth.getKey() +// getUserViaCache gets user ID from cache. +func (auth *AuthProxy) getUserViaCache(reqCtx *models.ReqContext) (int64, error) { + cacheKey, err := auth.getKey(reqCtx) if err != nil { return 0, err } - logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) - userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), cacheKey) + auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) + userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey) if err != nil { - logger.Debug("Failed getting user ID via auth cache", "error", err) + auth.logger.Debug("Failed getting user ID via auth cache", "error", err) return 0, err } - logger.Debug("Successfully got user ID via auth cache", "id", userID) + auth.logger.Debug("Successfully got user ID via auth cache", "id", userID) return userID.(int64), nil } // RemoveUserFromCache removes user from cache. -func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { - cacheKey, err := auth.getKey() +func (auth *AuthProxy) RemoveUserFromCache(reqCtx *models.ReqContext) error { + cacheKey, err := auth.getKey(reqCtx) if err != nil { return err } - logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) - if err := auth.remoteCache.Delete(auth.ctx.Req.Context(), cacheKey); err != nil { + auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) + if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil { return err } - logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey) + auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey) return nil } // LoginViaLDAP logs in user via LDAP request -func (auth *AuthProxy) LoginViaLDAP() (int64, error) { +func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) { config, err := getLDAPConfig(auth.cfg) if err != nil { return 0, newError("failed to get LDAP config", err) } + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) mldap := newLDAP(config.Servers) - extUser, _, err := mldap.User(auth.header) + extUser, _, err := mldap.User(header) if err != nil { return 0, err } // Have to sync grafana and LDAP user during log in upsert := &models.UpsertUserCommand{ - ReqContext: auth.ctx, + ReqContext: reqCtx, SignupAllowed: auth.cfg.LDAPAllowSignup, ExternalUser: extUser, } - if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil { + if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil { return 0, err } return upsert.Result.Id, nil } -// LoginViaHeader logs in user from the header only -func (auth *AuthProxy) LoginViaHeader() (int64, error) { +// loginViaHeader logs in user from the header only +func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) { + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) extUser := &models.ExternalUserInfo{ AuthModule: "authproxy", - AuthId: auth.header, + AuthId: header, } switch auth.cfg.AuthProxyHeaderProperty { case "username": - extUser.Login = auth.header + extUser.Login = header - emailAddr, emailErr := mail.ParseAddress(auth.header) // only set Email if it can be parsed as an email address + emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address if emailErr == nil { extUser.Email = emailAddr.Address } case "email": - extUser.Email = auth.header - extUser.Login = auth.header + extUser.Email = header + extUser.Login = header default: return 0, fmt.Errorf("auth proxy header property invalid") } - auth.headersIterator(func(field string, header string) { + auth.headersIterator(reqCtx, func(field string, header string) { switch field { case "Groups": extUser.Groups = util.SplitString(header) @@ -300,12 +295,12 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { }) upsert := &models.UpsertUserCommand{ - ReqContext: auth.ctx, + ReqContext: reqCtx, SignupAllowed: auth.cfg.AuthProxyAutoSignUp, ExternalUser: extUser, } - err := bus.Dispatch(auth.ctx.Req.Context(), upsert) + err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert) if err != nil { return 0, err } @@ -314,8 +309,8 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { } // getDecodedHeader gets decoded value of a header with given headerName -func (auth *AuthProxy) getDecodedHeader(headerName string) string { - headerValue := auth.ctx.Req.Header.Get(headerName) +func (auth *AuthProxy) getDecodedHeader(reqCtx *models.ReqContext, headerName string) string { + headerValue := reqCtx.Req.Header.Get(headerName) if auth.cfg.AuthProxyHeadersEncoded { headerValue = util.DecodeQuotedPrintable(headerValue) @@ -325,27 +320,27 @@ func (auth *AuthProxy) getDecodedHeader(headerName string) string { } // headersIterator iterates over all non-empty supported additional headers -func (auth *AuthProxy) headersIterator(fn func(field string, header string)) { +func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field string, header string)) { for _, field := range supportedHeaderFields { h := auth.cfg.AuthProxyHeaders[field] if h == "" { continue } - if value := auth.getDecodedHeader(h); value != "" { + if value := auth.getDecodedHeader(reqCtx, h); value != "" { fn(field, strings.TrimSpace(value)) } } } -// GetSignedUser gets full signed in user info. -func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) { +// GetSignedInUser gets full signed in user info. +func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*models.SignedInUser, error) { query := &models.GetSignedInUserQuery{ - OrgId: auth.orgID, + OrgId: orgID, UserId: userID, } - if err := bus.Dispatch(context.Background(), query); err != nil { + if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil { return nil, err } @@ -353,21 +348,21 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro } // Remember user in cache -func (auth *AuthProxy) Remember(id int64) error { - key, err := auth.getKey() +func (auth *AuthProxy) Remember(reqCtx *models.ReqContext, id int64) error { + key, err := auth.getKey(reqCtx) if err != nil { return err } // Check if user already in cache - userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), key) + userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key) if err == nil && userID != nil { return nil } expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute - if err := auth.remoteCache.Set(auth.ctx.Req.Context(), key, id, expiration); err != nil { + if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil { return err } diff --git a/pkg/services/contexthandler/authproxy/authproxy_test.go b/pkg/services/contexthandler/authproxy/authproxy_test.go index 565a6305cb9..caa4dc46f71 100644 --- a/pkg/services/contexthandler/authproxy/authproxy_test.go +++ b/pkg/services/contexthandler/authproxy/authproxy_test.go @@ -7,8 +7,8 @@ import ( "net/http" "testing" - "github.com/grafana/grafana/pkg/bus" - "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/login/loginservice" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/ldap" @@ -20,8 +20,9 @@ import ( ) const hdrName = "markelog" +const id int64 = 42 -func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) *AuthProxy { +func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *models.ReqContext) { t.Helper() req, err := http.NewRequest("POST", "http://example.com", nil) @@ -40,17 +41,16 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi Context: &web.Context{Req: req}, } - auth := New(cfg, &Options{ - RemoteCache: remoteCache, - Ctx: ctx, - OrgID: 4, - }) + loginService := loginservice.LoginServiceMock{ + ExpectedUser: &models.User{ + Id: id, + }, + } - return auth + return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx } func TestMiddlewareContext(t *testing.T) { - logger := log.New("test") cache := remotecache.NewFakeStore(t) t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) { @@ -62,12 +62,12 @@ func TestMiddlewareContext(t *testing.T) { err = cache.Set(context.Background(), key, id, 0) require.NoError(t, err) // Set up the middleware - auth := prepareMiddleware(t, cache, nil) - gotKey, err := auth.getKey() + auth, reqCtx := prepareMiddleware(t, cache, nil) + gotKey, err := auth.getKey(reqCtx) require.NoError(t, err) assert.Equal(t, key, gotKey) - gotID, err := auth.Login(logger, false) + gotID, err := auth.Login(reqCtx, false) require.NoError(t, err) assert.Equal(t, id, gotID) @@ -84,7 +84,7 @@ func TestMiddlewareContext(t *testing.T) { err = cache.Set(context.Background(), key, id, 0) require.NoError(t, err) - auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { + auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { cfg.AuthProxyHeaderName = "X-Killa" cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"} req.Header.Set(cfg.AuthProxyHeaderName, hdrName) @@ -93,26 +93,14 @@ func TestMiddlewareContext(t *testing.T) { }) assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key) - gotID, err := auth.Login(logger, false) + gotID, err := auth.Login(reqCtx, false) require.NoError(t, err) assert.Equal(t, id, gotID) }) } func TestMiddlewareContext_ldap(t *testing.T) { - logger := log.New("test") - t.Run("Logs in via LDAP", func(t *testing.T) { - const id int64 = 42 - - bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{ - Id: id, - } - - return nil - }) - origIsLDAPEnabled := isLDAPEnabled origGetLDAPConfig := getLDAPConfig origNewLDAP := newLDAP @@ -147,9 +135,9 @@ func TestMiddlewareContext_ldap(t *testing.T) { cache := remotecache.NewFakeStore(t) - auth := prepareMiddleware(t, cache, nil) + auth, reqCtx := prepareMiddleware(t, cache, nil) - gotID, err := auth.Login(logger, false) + gotID, err := auth.Login(reqCtx, false) require.NoError(t, err) assert.Equal(t, id, gotID) @@ -177,7 +165,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { cache := remotecache.NewFakeStore(t) - auth := prepareMiddleware(t, cache, nil) + auth, reqCtx := prepareMiddleware(t, cache, nil) stub := &multildap.MultiLDAPmock{ ID: id, @@ -187,7 +175,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { return stub } - gotID, err := auth.Login(logger, false) + gotID, err := auth.Login(reqCtx, false) require.EqualError(t, err, "failed to get the user") assert.NotEqual(t, id, gotID) @@ -198,22 +186,24 @@ func TestMiddlewareContext_ldap(t *testing.T) { func TestDecodeHeader(t *testing.T) { cache := remotecache.NewFakeStore(t) t.Run("should not decode header if not enabled in settings", func(t *testing.T) { - auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { + auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" cfg.AuthProxyHeadersEncoded = false req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") }) - assert.Equal(t, "M=C3=BCnchen", auth.header) + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) + assert.Equal(t, "M=C3=BCnchen", header) }) t.Run("should decode header if enabled in settings", func(t *testing.T) { - auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { + auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" cfg.AuthProxyHeadersEncoded = true req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") }) - assert.Equal(t, "München", auth.header) + header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) + assert.Equal(t, "München", header) }) } diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 3d66a81a72b..3cf9c391810 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -36,7 +36,7 @@ const ServiceName = "ContextHandler" func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService, remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore, - tracer tracing.Tracer) *ContextHandler { + tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler { return &ContextHandler{ Cfg: cfg, AuthTokenService: tokenService, @@ -45,6 +45,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS RenderService: renderService, SQLStore: sqlStore, tracer: tracer, + authProxy: authProxy, } } @@ -57,6 +58,7 @@ type ContextHandler struct { RenderService rendering.Service SQLStore sqlstore.Store tracer tracing.Tracer + authProxy *authproxy.AuthProxy // GetTime returns the current time. // Stubbable by tests. GetTime func() time.Time @@ -419,10 +421,10 @@ func (h *ContextHandler) initContextWithRenderAuth(reqContext *models.ReqContext return true } -func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) { +func logUserIn(reqContext *models.ReqContext, auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) { logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache) // Try to log in user via various providers - id, err := auth.Login(logger, ignoreCache) + id, err := auth.Login(reqContext, ignoreCache) if err != nil { details := err var e authproxy.Error @@ -451,36 +453,31 @@ func (h *ContextHandler) handleError(ctx *models.ReqContext, err error, statusCo func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool { username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName) - auth := authproxy.New(h.Cfg, &authproxy.Options{ - RemoteCache: h.RemoteCache, - Ctx: reqContext, - OrgID: orgID, - }) logger := log.New("auth.proxy") // Bail if auth proxy is not enabled - if !auth.IsEnabled() { + if !h.authProxy.IsEnabled() { return false } // If there is no header - we can't move forward - if !auth.HasHeader() { + if !h.authProxy.HasHeader(reqContext) { return false } _, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy") defer span.End() - // Check if allowed to continue with this IP - if err := auth.IsAllowedIP(); err != nil { + // Check if allowed continuing with this IP + if err := h.authProxy.IsAllowedIP(reqContext.Req.RemoteAddr); err != nil { h.handleError(reqContext, err, 407, func(details error) { logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details) }) return true } - id, err := logUserIn(auth, username, logger, false) + id, err := logUserIn(reqContext, h.authProxy, username, logger, false) if err != nil { h.handleError(reqContext, err, 407, nil) return true @@ -488,7 +485,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, logger.Debug("Got user ID, getting full user info", "userID", id) - user, err := auth.GetSignedInUser(id) + user, err := h.authProxy.GetSignedInUser(id, orgID) if err != nil { // The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale // cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated @@ -496,18 +493,18 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, // we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to // log the user in again without the cache. logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id) - if err := auth.RemoveUserFromCache(logger); err != nil { + if err := h.authProxy.RemoveUserFromCache(reqContext); err != nil { if !errors.Is(err, remotecache.ErrCacheItemNotFound) { logger.Error("Got unexpected error when removing user from auth cache", "error", err) } } - id, err = logUserIn(auth, username, logger, true) + id, err = logUserIn(reqContext, h.authProxy, username, logger, true) if err != nil { h.handleError(reqContext, err, 407, nil) return true } - user, err = auth.GetSignedInUser(id) + user, err = h.authProxy.GetSignedInUser(id, orgID) if err != nil { h.handleError(reqContext, err, 407, nil) return true @@ -521,7 +518,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, reqContext.IsSignedIn = true // Remember user data in cache - if err := auth.Remember(id); err != nil { + if err := h.authProxy.Remember(reqContext, id); err != nil { h.handleError(reqContext, err, 500, func(details error) { logger.Error( "Failed to store user in cache", diff --git a/pkg/services/login/loginservice/loginservice_mock.go b/pkg/services/login/loginservice/loginservice_mock.go index 46752034e02..abd77f9f2bd 100644 --- a/pkg/services/login/loginservice/loginservice_mock.go +++ b/pkg/services/login/loginservice/loginservice_mock.go @@ -15,6 +15,9 @@ type LoginServiceMock struct { NoExistingOrgId int64 AlreadyExitingLogin string GeneratedUserId int64 + ExpectedUser *models.User + ExpectedUserFunc func(cmd *models.UpsertUserCommand) *models.User + ExpectedError error } func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) { @@ -35,5 +38,10 @@ func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User } func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error { - return nil + if s.ExpectedUserFunc != nil { + cmd.Result = s.ExpectedUserFunc(cmd) + return s.ExpectedError + } + cmd.Result = s.ExpectedUser + return s.ExpectedError }