mirror of
https://github.com/grafana/grafana.git
synced 2025-07-30 03:32:20 +08:00
Chore: Remove bus from authproxy (#46936)
* Make authproxy injectable * Fix import * Provide function was in wrong place * Fixing tests * More imports and rollback a change * Fix lint
This commit is contained in:
@ -27,11 +27,13 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
|
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
|
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||||
"github.com/grafana/grafana/pkg/services/dashboards"
|
"github.com/grafana/grafana/pkg/services/dashboards"
|
||||||
dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database"
|
dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database"
|
||||||
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager"
|
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
"github.com/grafana/grafana/pkg/services/ldap"
|
"github.com/grafana/grafana/pkg/services/ldap"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
"github.com/grafana/grafana/pkg/services/quota"
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
"github.com/grafana/grafana/pkg/services/searchusers"
|
"github.com/grafana/grafana/pkg/services/searchusers"
|
||||||
@ -193,7 +195,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
|||||||
authJWTSvc := models.NewFakeJWTService()
|
authJWTSvc := models.NewFakeJWTService()
|
||||||
tracer, err := tracing.InitializeTracerForTest()
|
tracer, err := tracing.InitializeTracerForTest()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore)
|
||||||
|
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||||
|
|
||||||
return ctxHdlr
|
return ctxHdlr
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -25,8 +24,10 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/util"
|
"github.com/grafana/grafana/pkg/util"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
@ -364,10 +365,7 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
const group = "grafana-core-team"
|
const group = "grafana-core-team"
|
||||||
|
|
||||||
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -387,11 +385,11 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
|
|
||||||
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
|
||||||
var actualAuthProxyAutoSignUp *bool = nil
|
var actualAuthProxyAutoSignUp *bool = nil
|
||||||
|
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
|
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
|
||||||
return login.ErrInvalidCredentials
|
return nil
|
||||||
})
|
}
|
||||||
|
sc.loginService.ExpectedError = login.ErrInvalidCredentials
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -407,18 +405,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||||
if query.UserId > 0 {
|
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -435,19 +423,11 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
|
|
||||||
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
|
||||||
var storedRoleInfo map[int64]models.RoleType = nil
|
var storedRoleInfo map[int64]models.RoleType = nil
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||||
if query.UserId > 0 {
|
|
||||||
query.Result = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||||
return nil
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
||||||
})
|
return &models.User{Id: userID}
|
||||||
|
}
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -466,19 +446,11 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
|
|
||||||
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
|
||||||
var storedRoleInfo map[int64]models.RoleType = nil
|
var storedRoleInfo map[int64]models.RoleType = nil
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||||
if query.UserId > 0 {
|
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||||
return nil
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
|
||||||
})
|
return &models.User{Id: userID}
|
||||||
|
}
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -499,27 +471,17 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
var targetOrgID int64 = 123
|
||||||
if query.UserId > 0 {
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: targetOrgID, UserId: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: query.OrgId, UserId: userID}
|
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
targetOrgID := 123
|
|
||||||
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
|
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
sc.exec()
|
sc.exec()
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
assert.True(t, sc.context.IsSignedIn)
|
||||||
assert.Equal(t, userID, sc.context.UserId)
|
assert.Equal(t, userID, sc.context.UserId)
|
||||||
assert.Equal(t, int64(targetOrgID), sc.context.OrgId)
|
assert.Equal(t, targetOrgID, sc.context.OrgId)
|
||||||
}, func(cfg *setting.Cfg) {
|
}, func(cfg *setting.Cfg) {
|
||||||
configure(cfg)
|
configure(cfg)
|
||||||
cfg.LDAPEnabled = false
|
cfg.LDAPEnabled = false
|
||||||
@ -554,15 +516,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
const userID int64 = 12
|
const userID int64 = 12
|
||||||
const orgID int64 = 2
|
const orgID int64 = 2
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -577,15 +532,8 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -602,15 +550,7 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -626,10 +566,6 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("LDAP", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
return errors.New("Do not add user")
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
sc.exec()
|
sc.exec()
|
||||||
@ -639,10 +575,6 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
}, configure)
|
}, configure)
|
||||||
|
|
||||||
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
||||||
bus.AddHandler("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
|
||||||
return errors.New("Do not add user")
|
|
||||||
})
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
sc.fakeReq("GET", "/")
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||||
sc.exec()
|
sc.exec()
|
||||||
@ -684,7 +616,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
sc.m.UseMiddleware(AddCSPHeader(cfg, logger))
|
sc.m.UseMiddleware(AddCSPHeader(cfg, logger))
|
||||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||||
|
|
||||||
ctxHdlr := getContextHandler(t, cfg)
|
sc.mockSQLStore = mockstore.NewSQLStoreMock()
|
||||||
|
sc.loginService = &loginservice.LoginServiceMock{}
|
||||||
|
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService)
|
||||||
sc.sqlStore = ctxHdlr.SQLStore
|
sc.sqlStore = ctxHdlr.SQLStore
|
||||||
sc.contextHandler = ctxHdlr
|
sc.contextHandler = ctxHdlr
|
||||||
sc.m.Use(ctxHdlr.Middleware)
|
sc.m.Use(ctxHdlr.Middleware)
|
||||||
@ -714,7 +648,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler {
|
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
sqlStore := sqlstore.InitTestDB(t)
|
sqlStore := sqlstore.InitTestDB(t)
|
||||||
@ -730,8 +664,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
|||||||
renderSvc := &fakeRenderService{}
|
renderSvc := &fakeRenderService{}
|
||||||
authJWTSvc := models.NewFakeJWTService()
|
authJWTSvc := models.NewFakeJWTService()
|
||||||
tracer, err := tracing.InitializeTracerForTest()
|
tracer, err := tracing.InitializeTracerForTest()
|
||||||
|
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeRenderService struct {
|
type fakeRenderService struct {
|
||||||
|
@ -33,7 +33,7 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
|
|||||||
|
|
||||||
m := web.New()
|
m := web.New()
|
||||||
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
|
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
|
||||||
m.Use(getContextHandler(t, cfg).Middleware)
|
m.Use(getContextHandler(t, cfg, nil, nil).Middleware)
|
||||||
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
|
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
|
||||||
|
|
||||||
fn(func() *httptest.ResponseRecorder {
|
fn(func() *httptest.ResponseRecorder {
|
||||||
|
@ -70,7 +70,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
|||||||
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
||||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
||||||
|
|
||||||
contextHandler := getContextHandler(t, nil)
|
contextHandler := getContextHandler(t, nil, nil, nil)
|
||||||
sc.m.Use(contextHandler.Middleware)
|
sc.m.Use(contextHandler.Middleware)
|
||||||
// mock out gc goroutine
|
// mock out gc goroutine
|
||||||
sc.m.Use(OrgRedirect(cfg))
|
sc.m.Use(OrgRedirect(cfg))
|
||||||
|
@ -10,7 +10,9 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -34,7 +36,9 @@ type scenarioContext struct {
|
|||||||
remoteCacheService *remotecache.RemoteCache
|
remoteCacheService *remotecache.RemoteCache
|
||||||
cfg *setting.Cfg
|
cfg *setting.Cfg
|
||||||
sqlStore sqlstore.Store
|
sqlStore sqlstore.Store
|
||||||
|
mockSQLStore *mockstore.SQLStoreMock
|
||||||
contextHandler *contexthandler.ContextHandler
|
contextHandler *contexthandler.ContextHandler
|
||||||
|
loginService *loginservice.LoginServiceMock
|
||||||
|
|
||||||
req *http.Request
|
req *http.Request
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/cleanup"
|
"github.com/grafana/grafana/pkg/services/cleanup"
|
||||||
"github.com/grafana/grafana/pkg/services/comments"
|
"github.com/grafana/grafana/pkg/services/comments"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
|
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||||
"github.com/grafana/grafana/pkg/services/dashboardimport"
|
"github.com/grafana/grafana/pkg/services/dashboardimport"
|
||||||
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
|
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
|
||||||
"github.com/grafana/grafana/pkg/services/dashboards"
|
"github.com/grafana/grafana/pkg/services/dashboards"
|
||||||
@ -228,6 +229,7 @@ var wireBasicSet = wire.NewSet(
|
|||||||
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
|
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
|
||||||
comments.ProvideService,
|
comments.ProvideService,
|
||||||
guardian.ProvideService,
|
guardian.ProvideService,
|
||||||
|
authproxy.ProvideAuthProxy,
|
||||||
)
|
)
|
||||||
|
|
||||||
var wireSet = wire.NewSet(
|
var wireSet = wire.NewSet(
|
||||||
|
@ -6,13 +6,13 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
@ -20,42 +20,18 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const userID = int64(1)
|
||||||
|
const orgID = int64(4)
|
||||||
|
|
||||||
// Test initContextWithAuthProxy with a cached user ID that is no longer valid.
|
// Test initContextWithAuthProxy with a cached user ID that is no longer valid.
|
||||||
//
|
//
|
||||||
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
|
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
|
||||||
// in without cache.
|
// in without cache.
|
||||||
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
||||||
const name = "markelog"
|
const name = "markelog"
|
||||||
const userID = int64(1)
|
|
||||||
const orgID = int64(4)
|
|
||||||
|
|
||||||
svc := getContextHandler(t)
|
svc := getContextHandler(t)
|
||||||
|
|
||||||
// XXX: These handlers have to be injected AFTER calling getContextHandler, since the latter
|
|
||||||
// creates a SQLStore which installs its own handlers.
|
|
||||||
upsertHandler := func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
require.Equal(t, name, cmd.ExternalUser.Login)
|
|
||||||
cmd.Result = &models.User{Id: userID}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
|
||||||
// Simulate that the cached user ID is stale
|
|
||||||
if query.UserId != userID {
|
|
||||||
return models.ErrUserNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
query.Result = &models.SignedInUser{
|
|
||||||
UserId: userID,
|
|
||||||
OrgId: orgID,
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
bus.AddHandler("", upsertHandler)
|
|
||||||
bus.AddHandler("", getUserHandler)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
bus.ClearBusHandlers()
|
|
||||||
})
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", "http://example.com", nil)
|
req, err := http.NewRequest("POST", "http://example.com", nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ctx := &models.ReqContext{
|
ctx := &models.ReqContext{
|
||||||
@ -106,5 +82,24 @@ func getContextHandler(t *testing.T) *ContextHandler {
|
|||||||
tracer, err := tracing.InitializeTracerForTest()
|
tracer, err := tracing.InitializeTracerForTest()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}}
|
||||||
|
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{})
|
||||||
|
|
||||||
|
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FakeGetSignUserStore struct {
|
||||||
|
sqlstore.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||||
|
if query.UserId != userID {
|
||||||
|
return models.ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Result = &models.SignedInUser{
|
||||||
|
UserId: userID,
|
||||||
|
OrgId: orgID,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -13,12 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/ldap"
|
"github.com/grafana/grafana/pkg/services/ldap"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/multildap"
|
"github.com/grafana/grafana/pkg/services/multildap"
|
||||||
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/util"
|
"github.com/grafana/grafana/pkg/util"
|
||||||
)
|
)
|
||||||
@ -49,11 +50,22 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"}
|
|||||||
|
|
||||||
// AuthProxy struct
|
// AuthProxy struct
|
||||||
type AuthProxy struct {
|
type AuthProxy struct {
|
||||||
cfg *setting.Cfg
|
cfg *setting.Cfg
|
||||||
remoteCache *remotecache.RemoteCache
|
remoteCache *remotecache.RemoteCache
|
||||||
ctx *models.ReqContext
|
loginService login.Service
|
||||||
orgID int64
|
sqlStore sqlstore.Store
|
||||||
header string
|
|
||||||
|
logger log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy {
|
||||||
|
return &AuthProxy{
|
||||||
|
cfg: cfg,
|
||||||
|
remoteCache: remoteCache,
|
||||||
|
loginService: loginService,
|
||||||
|
sqlStore: sqlStore,
|
||||||
|
logger: log.New("auth.proxy"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error auth proxy specific error
|
// Error auth proxy specific error
|
||||||
@ -75,40 +87,20 @@ func (err Error) Error() string {
|
|||||||
return err.Message
|
return err.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options for the AuthProxy
|
|
||||||
type Options struct {
|
|
||||||
RemoteCache *remotecache.RemoteCache
|
|
||||||
Ctx *models.ReqContext
|
|
||||||
OrgID int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// New instance of the AuthProxy.
|
|
||||||
func New(cfg *setting.Cfg, options *Options) *AuthProxy {
|
|
||||||
auth := &AuthProxy{
|
|
||||||
remoteCache: options.RemoteCache,
|
|
||||||
cfg: cfg,
|
|
||||||
ctx: options.Ctx,
|
|
||||||
orgID: options.OrgID,
|
|
||||||
}
|
|
||||||
auth.header = auth.getDecodedHeader(cfg.AuthProxyHeaderName)
|
|
||||||
return auth
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEnabled checks if the auth proxy is enabled.
|
// IsEnabled checks if the auth proxy is enabled.
|
||||||
func (auth *AuthProxy) IsEnabled() bool {
|
func (auth *AuthProxy) IsEnabled() bool {
|
||||||
// Bail if the setting is not enabled
|
// Bail if the setting is not enabled
|
||||||
return auth.cfg.AuthProxyEnabled
|
return auth.cfg.AuthProxyEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHeader checks if the we have specified header
|
// HasHeader checks if we have specified header
|
||||||
func (auth *AuthProxy) HasHeader() bool {
|
func (auth *AuthProxy) HasHeader(reqCtx *models.ReqContext) bool {
|
||||||
return len(auth.header) != 0
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
|
return len(header) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAllowedIP returns whether provided IP is allowed.
|
// IsAllowedIP returns whether provided IP is allowed.
|
||||||
func (auth *AuthProxy) IsAllowedIP() error {
|
func (auth *AuthProxy) IsAllowedIP(ip string) error {
|
||||||
ip := auth.ctx.Req.RemoteAddr
|
|
||||||
|
|
||||||
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
|
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -137,7 +129,7 @@ func (auth *AuthProxy) IsAllowedIP() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return newError("proxy authentication required", fmt.Errorf(
|
return newError("proxy authentication required", fmt.Errorf(
|
||||||
"request for user (%s) from %s is not from the authentication proxy", auth.header,
|
"request for user from %s is not from the authentication proxy",
|
||||||
sourceIP,
|
sourceIP,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
@ -153,10 +145,11 @@ func HashCacheKey(key string) (string, error) {
|
|||||||
// getKey forms a key for the cache based on the headers received as part of the authentication flow.
|
// getKey forms a key for the cache based on the headers received as part of the authentication flow.
|
||||||
// Our configuration supports multiple headers. The main header contains the email or username.
|
// Our configuration supports multiple headers. The main header contains the email or username.
|
||||||
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
|
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
|
||||||
func (auth *AuthProxy) getKey() (string, error) {
|
func (auth *AuthProxy) getKey(reqCtx *models.ReqContext) (string, error) {
|
||||||
key := strings.TrimSpace(auth.header) // start the key with the main header
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
|
key := strings.TrimSpace(header) // start the key with the main header
|
||||||
|
|
||||||
auth.headersIterator(func(_, header string) {
|
auth.headersIterator(reqCtx, func(_, header string) {
|
||||||
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
|
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -168,17 +161,17 @@ func (auth *AuthProxy) getKey() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Login logs in user ID by whatever means possible.
|
// Login logs in user ID by whatever means possible.
|
||||||
func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) {
|
func (auth *AuthProxy) Login(reqCtx *models.ReqContext, ignoreCache bool) (int64, error) {
|
||||||
if !ignoreCache {
|
if !ignoreCache {
|
||||||
// Error here means absent cache - we don't need to handle that
|
// Error here means absent cache - we don't need to handle that
|
||||||
id, err := auth.GetUserViaCache(logger)
|
id, err := auth.getUserViaCache(reqCtx)
|
||||||
if err == nil && id != 0 {
|
if err == nil && id != 0 {
|
||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if isLDAPEnabled(auth.cfg) {
|
if isLDAPEnabled(auth.cfg) {
|
||||||
id, err := auth.LoginViaLDAP()
|
id, err := auth.LoginViaLDAP(reqCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ldap.ErrInvalidCredentials) {
|
if errors.Is(err, ldap.ErrInvalidCredentials) {
|
||||||
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials)
|
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials)
|
||||||
@ -189,7 +182,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := auth.LoginViaHeader()
|
id, err := auth.loginViaHeader(reqCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, newError("failed to log in as user, specified in auth proxy header", err)
|
return 0, newError("failed to log in as user, specified in auth proxy header", err)
|
||||||
}
|
}
|
||||||
@ -197,87 +190,89 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
|
|||||||
return id, nil
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserViaCache gets user ID from cache.
|
// getUserViaCache gets user ID from cache.
|
||||||
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
|
func (auth *AuthProxy) getUserViaCache(reqCtx *models.ReqContext) (int64, error) {
|
||||||
cacheKey, err := auth.getKey()
|
cacheKey, err := auth.getKey(reqCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
||||||
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), cacheKey)
|
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("Failed getting user ID via auth cache", "error", err)
|
auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Successfully got user ID via auth cache", "id", userID)
|
auth.logger.Debug("Successfully got user ID via auth cache", "id", userID)
|
||||||
return userID.(int64), nil
|
return userID.(int64), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveUserFromCache removes user from cache.
|
// RemoveUserFromCache removes user from cache.
|
||||||
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
|
func (auth *AuthProxy) RemoveUserFromCache(reqCtx *models.ReqContext) error {
|
||||||
cacheKey, err := auth.getKey()
|
cacheKey, err := auth.getKey(reqCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
|
auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
|
||||||
if err := auth.remoteCache.Delete(auth.ctx.Req.Context(), cacheKey); err != nil {
|
if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
|
auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginViaLDAP logs in user via LDAP request
|
// LoginViaLDAP logs in user via LDAP request
|
||||||
func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
|
func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) {
|
||||||
config, err := getLDAPConfig(auth.cfg)
|
config, err := getLDAPConfig(auth.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, newError("failed to get LDAP config", err)
|
return 0, newError("failed to get LDAP config", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
mldap := newLDAP(config.Servers)
|
mldap := newLDAP(config.Servers)
|
||||||
extUser, _, err := mldap.User(auth.header)
|
extUser, _, err := mldap.User(header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Have to sync grafana and LDAP user during log in
|
// Have to sync grafana and LDAP user during log in
|
||||||
upsert := &models.UpsertUserCommand{
|
upsert := &models.UpsertUserCommand{
|
||||||
ReqContext: auth.ctx,
|
ReqContext: reqCtx,
|
||||||
SignupAllowed: auth.cfg.LDAPAllowSignup,
|
SignupAllowed: auth.cfg.LDAPAllowSignup,
|
||||||
ExternalUser: extUser,
|
ExternalUser: extUser,
|
||||||
}
|
}
|
||||||
if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil {
|
if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return upsert.Result.Id, nil
|
return upsert.Result.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginViaHeader logs in user from the header only
|
// loginViaHeader logs in user from the header only
|
||||||
func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) {
|
||||||
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
extUser := &models.ExternalUserInfo{
|
extUser := &models.ExternalUserInfo{
|
||||||
AuthModule: "authproxy",
|
AuthModule: "authproxy",
|
||||||
AuthId: auth.header,
|
AuthId: header,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch auth.cfg.AuthProxyHeaderProperty {
|
switch auth.cfg.AuthProxyHeaderProperty {
|
||||||
case "username":
|
case "username":
|
||||||
extUser.Login = auth.header
|
extUser.Login = header
|
||||||
|
|
||||||
emailAddr, emailErr := mail.ParseAddress(auth.header) // only set Email if it can be parsed as an email address
|
emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address
|
||||||
if emailErr == nil {
|
if emailErr == nil {
|
||||||
extUser.Email = emailAddr.Address
|
extUser.Email = emailAddr.Address
|
||||||
}
|
}
|
||||||
case "email":
|
case "email":
|
||||||
extUser.Email = auth.header
|
extUser.Email = header
|
||||||
extUser.Login = auth.header
|
extUser.Login = header
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("auth proxy header property invalid")
|
return 0, fmt.Errorf("auth proxy header property invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.headersIterator(func(field string, header string) {
|
auth.headersIterator(reqCtx, func(field string, header string) {
|
||||||
switch field {
|
switch field {
|
||||||
case "Groups":
|
case "Groups":
|
||||||
extUser.Groups = util.SplitString(header)
|
extUser.Groups = util.SplitString(header)
|
||||||
@ -300,12 +295,12 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
upsert := &models.UpsertUserCommand{
|
upsert := &models.UpsertUserCommand{
|
||||||
ReqContext: auth.ctx,
|
ReqContext: reqCtx,
|
||||||
SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
|
SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
|
||||||
ExternalUser: extUser,
|
ExternalUser: extUser,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := bus.Dispatch(auth.ctx.Req.Context(), upsert)
|
err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -314,8 +309,8 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getDecodedHeader gets decoded value of a header with given headerName
|
// getDecodedHeader gets decoded value of a header with given headerName
|
||||||
func (auth *AuthProxy) getDecodedHeader(headerName string) string {
|
func (auth *AuthProxy) getDecodedHeader(reqCtx *models.ReqContext, headerName string) string {
|
||||||
headerValue := auth.ctx.Req.Header.Get(headerName)
|
headerValue := reqCtx.Req.Header.Get(headerName)
|
||||||
|
|
||||||
if auth.cfg.AuthProxyHeadersEncoded {
|
if auth.cfg.AuthProxyHeadersEncoded {
|
||||||
headerValue = util.DecodeQuotedPrintable(headerValue)
|
headerValue = util.DecodeQuotedPrintable(headerValue)
|
||||||
@ -325,27 +320,27 @@ func (auth *AuthProxy) getDecodedHeader(headerName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// headersIterator iterates over all non-empty supported additional headers
|
// headersIterator iterates over all non-empty supported additional headers
|
||||||
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
|
func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field string, header string)) {
|
||||||
for _, field := range supportedHeaderFields {
|
for _, field := range supportedHeaderFields {
|
||||||
h := auth.cfg.AuthProxyHeaders[field]
|
h := auth.cfg.AuthProxyHeaders[field]
|
||||||
if h == "" {
|
if h == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if value := auth.getDecodedHeader(h); value != "" {
|
if value := auth.getDecodedHeader(reqCtx, h); value != "" {
|
||||||
fn(field, strings.TrimSpace(value))
|
fn(field, strings.TrimSpace(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignedUser gets full signed in user info.
|
// GetSignedInUser gets full signed in user info.
|
||||||
func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) {
|
func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*models.SignedInUser, error) {
|
||||||
query := &models.GetSignedInUserQuery{
|
query := &models.GetSignedInUserQuery{
|
||||||
OrgId: auth.orgID,
|
OrgId: orgID,
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := bus.Dispatch(context.Background(), query); err != nil {
|
if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -353,21 +348,21 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remember user in cache
|
// Remember user in cache
|
||||||
func (auth *AuthProxy) Remember(id int64) error {
|
func (auth *AuthProxy) Remember(reqCtx *models.ReqContext, id int64) error {
|
||||||
key, err := auth.getKey()
|
key, err := auth.getKey(reqCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user already in cache
|
// Check if user already in cache
|
||||||
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), key)
|
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key)
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
|
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
|
||||||
|
|
||||||
if err := auth.remoteCache.Set(auth.ctx.Req.Context(), key, id, expiration); err != nil {
|
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/ldap"
|
"github.com/grafana/grafana/pkg/services/ldap"
|
||||||
@ -20,8 +20,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const hdrName = "markelog"
|
const hdrName = "markelog"
|
||||||
|
const id int64 = 42
|
||||||
|
|
||||||
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) *AuthProxy {
|
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *models.ReqContext) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", "http://example.com", nil)
|
req, err := http.NewRequest("POST", "http://example.com", nil)
|
||||||
@ -40,17 +41,16 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi
|
|||||||
Context: &web.Context{Req: req},
|
Context: &web.Context{Req: req},
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := New(cfg, &Options{
|
loginService := loginservice.LoginServiceMock{
|
||||||
RemoteCache: remoteCache,
|
ExpectedUser: &models.User{
|
||||||
Ctx: ctx,
|
Id: id,
|
||||||
OrgID: 4,
|
},
|
||||||
})
|
}
|
||||||
|
|
||||||
return auth
|
return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMiddlewareContext(t *testing.T) {
|
func TestMiddlewareContext(t *testing.T) {
|
||||||
logger := log.New("test")
|
|
||||||
cache := remotecache.NewFakeStore(t)
|
cache := remotecache.NewFakeStore(t)
|
||||||
|
|
||||||
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
|
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
|
||||||
@ -62,12 +62,12 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
err = cache.Set(context.Background(), key, id, 0)
|
err = cache.Set(context.Background(), key, id, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Set up the middleware
|
// Set up the middleware
|
||||||
auth := prepareMiddleware(t, cache, nil)
|
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||||
gotKey, err := auth.getKey()
|
gotKey, err := auth.getKey(reqCtx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, key, gotKey)
|
assert.Equal(t, key, gotKey)
|
||||||
|
|
||||||
gotID, err := auth.Login(logger, false)
|
gotID, err := auth.Login(reqCtx, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, id, gotID)
|
assert.Equal(t, id, gotID)
|
||||||
@ -84,7 +84,7 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
err = cache.Set(context.Background(), key, id, 0)
|
err = cache.Set(context.Background(), key, id, 0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||||
cfg.AuthProxyHeaderName = "X-Killa"
|
cfg.AuthProxyHeaderName = "X-Killa"
|
||||||
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
|
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
|
||||||
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
|
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
|
||||||
@ -93,26 +93,14 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key)
|
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key)
|
||||||
|
|
||||||
gotID, err := auth.Login(logger, false)
|
gotID, err := auth.Login(reqCtx, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, id, gotID)
|
assert.Equal(t, id, gotID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMiddlewareContext_ldap(t *testing.T) {
|
func TestMiddlewareContext_ldap(t *testing.T) {
|
||||||
logger := log.New("test")
|
|
||||||
|
|
||||||
t.Run("Logs in via LDAP", func(t *testing.T) {
|
t.Run("Logs in via LDAP", func(t *testing.T) {
|
||||||
const id int64 = 42
|
|
||||||
|
|
||||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
|
||||||
cmd.Result = &models.User{
|
|
||||||
Id: id,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
origIsLDAPEnabled := isLDAPEnabled
|
origIsLDAPEnabled := isLDAPEnabled
|
||||||
origGetLDAPConfig := getLDAPConfig
|
origGetLDAPConfig := getLDAPConfig
|
||||||
origNewLDAP := newLDAP
|
origNewLDAP := newLDAP
|
||||||
@ -147,9 +135,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
|||||||
|
|
||||||
cache := remotecache.NewFakeStore(t)
|
cache := remotecache.NewFakeStore(t)
|
||||||
|
|
||||||
auth := prepareMiddleware(t, cache, nil)
|
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||||
|
|
||||||
gotID, err := auth.Login(logger, false)
|
gotID, err := auth.Login(reqCtx, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, id, gotID)
|
assert.Equal(t, id, gotID)
|
||||||
@ -177,7 +165,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
|||||||
|
|
||||||
cache := remotecache.NewFakeStore(t)
|
cache := remotecache.NewFakeStore(t)
|
||||||
|
|
||||||
auth := prepareMiddleware(t, cache, nil)
|
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||||
|
|
||||||
stub := &multildap.MultiLDAPmock{
|
stub := &multildap.MultiLDAPmock{
|
||||||
ID: id,
|
ID: id,
|
||||||
@ -187,7 +175,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
|||||||
return stub
|
return stub
|
||||||
}
|
}
|
||||||
|
|
||||||
gotID, err := auth.Login(logger, false)
|
gotID, err := auth.Login(reqCtx, false)
|
||||||
require.EqualError(t, err, "failed to get the user")
|
require.EqualError(t, err, "failed to get the user")
|
||||||
|
|
||||||
assert.NotEqual(t, id, gotID)
|
assert.NotEqual(t, id, gotID)
|
||||||
@ -198,22 +186,24 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
|||||||
func TestDecodeHeader(t *testing.T) {
|
func TestDecodeHeader(t *testing.T) {
|
||||||
cache := remotecache.NewFakeStore(t)
|
cache := remotecache.NewFakeStore(t)
|
||||||
t.Run("should not decode header if not enabled in settings", func(t *testing.T) {
|
t.Run("should not decode header if not enabled in settings", func(t *testing.T) {
|
||||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
||||||
cfg.AuthProxyHeadersEncoded = false
|
cfg.AuthProxyHeadersEncoded = false
|
||||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
||||||
})
|
})
|
||||||
|
|
||||||
assert.Equal(t, "M=C3=BCnchen", auth.header)
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
|
assert.Equal(t, "M=C3=BCnchen", header)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should decode header if enabled in settings", func(t *testing.T) {
|
t.Run("should decode header if enabled in settings", func(t *testing.T) {
|
||||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
||||||
cfg.AuthProxyHeadersEncoded = true
|
cfg.AuthProxyHeadersEncoded = true
|
||||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
||||||
})
|
})
|
||||||
|
|
||||||
assert.Equal(t, "München", auth.header)
|
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||||
|
assert.Equal(t, "München", header)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ const ServiceName = "ContextHandler"
|
|||||||
|
|
||||||
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService,
|
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService,
|
||||||
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore,
|
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore,
|
||||||
tracer tracing.Tracer) *ContextHandler {
|
tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler {
|
||||||
return &ContextHandler{
|
return &ContextHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthTokenService: tokenService,
|
AuthTokenService: tokenService,
|
||||||
@ -45,6 +45,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
|
|||||||
RenderService: renderService,
|
RenderService: renderService,
|
||||||
SQLStore: sqlStore,
|
SQLStore: sqlStore,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
|
authProxy: authProxy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ type ContextHandler struct {
|
|||||||
RenderService rendering.Service
|
RenderService rendering.Service
|
||||||
SQLStore sqlstore.Store
|
SQLStore sqlstore.Store
|
||||||
tracer tracing.Tracer
|
tracer tracing.Tracer
|
||||||
|
authProxy *authproxy.AuthProxy
|
||||||
// GetTime returns the current time.
|
// GetTime returns the current time.
|
||||||
// Stubbable by tests.
|
// Stubbable by tests.
|
||||||
GetTime func() time.Time
|
GetTime func() time.Time
|
||||||
@ -419,10 +421,10 @@ func (h *ContextHandler) initContextWithRenderAuth(reqContext *models.ReqContext
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
|
func logUserIn(reqContext *models.ReqContext, auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
|
||||||
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
|
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
|
||||||
// Try to log in user via various providers
|
// Try to log in user via various providers
|
||||||
id, err := auth.Login(logger, ignoreCache)
|
id, err := auth.Login(reqContext, ignoreCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
details := err
|
details := err
|
||||||
var e authproxy.Error
|
var e authproxy.Error
|
||||||
@ -451,36 +453,31 @@ func (h *ContextHandler) handleError(ctx *models.ReqContext, err error, statusCo
|
|||||||
|
|
||||||
func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool {
|
func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool {
|
||||||
username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
|
username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
|
||||||
auth := authproxy.New(h.Cfg, &authproxy.Options{
|
|
||||||
RemoteCache: h.RemoteCache,
|
|
||||||
Ctx: reqContext,
|
|
||||||
OrgID: orgID,
|
|
||||||
})
|
|
||||||
|
|
||||||
logger := log.New("auth.proxy")
|
logger := log.New("auth.proxy")
|
||||||
|
|
||||||
// Bail if auth proxy is not enabled
|
// Bail if auth proxy is not enabled
|
||||||
if !auth.IsEnabled() {
|
if !h.authProxy.IsEnabled() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is no header - we can't move forward
|
// If there is no header - we can't move forward
|
||||||
if !auth.HasHeader() {
|
if !h.authProxy.HasHeader(reqContext) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy")
|
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
// Check if allowed to continue with this IP
|
// Check if allowed continuing with this IP
|
||||||
if err := auth.IsAllowedIP(); err != nil {
|
if err := h.authProxy.IsAllowedIP(reqContext.Req.RemoteAddr); err != nil {
|
||||||
h.handleError(reqContext, err, 407, func(details error) {
|
h.handleError(reqContext, err, 407, func(details error) {
|
||||||
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
|
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
|
||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := logUserIn(auth, username, logger, false)
|
id, err := logUserIn(reqContext, h.authProxy, username, logger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.handleError(reqContext, err, 407, nil)
|
h.handleError(reqContext, err, 407, nil)
|
||||||
return true
|
return true
|
||||||
@ -488,7 +485,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
|||||||
|
|
||||||
logger.Debug("Got user ID, getting full user info", "userID", id)
|
logger.Debug("Got user ID, getting full user info", "userID", id)
|
||||||
|
|
||||||
user, err := auth.GetSignedInUser(id)
|
user, err := h.authProxy.GetSignedInUser(id, orgID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
|
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
|
||||||
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
|
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
|
||||||
@ -496,18 +493,18 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
|||||||
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
|
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
|
||||||
// log the user in again without the cache.
|
// log the user in again without the cache.
|
||||||
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
|
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
|
||||||
if err := auth.RemoveUserFromCache(logger); err != nil {
|
if err := h.authProxy.RemoveUserFromCache(reqContext); err != nil {
|
||||||
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
|
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
|
||||||
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
|
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
id, err = logUserIn(auth, username, logger, true)
|
id, err = logUserIn(reqContext, h.authProxy, username, logger, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.handleError(reqContext, err, 407, nil)
|
h.handleError(reqContext, err, 407, nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err = auth.GetSignedInUser(id)
|
user, err = h.authProxy.GetSignedInUser(id, orgID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.handleError(reqContext, err, 407, nil)
|
h.handleError(reqContext, err, 407, nil)
|
||||||
return true
|
return true
|
||||||
@ -521,7 +518,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
|||||||
reqContext.IsSignedIn = true
|
reqContext.IsSignedIn = true
|
||||||
|
|
||||||
// Remember user data in cache
|
// Remember user data in cache
|
||||||
if err := auth.Remember(id); err != nil {
|
if err := h.authProxy.Remember(reqContext, id); err != nil {
|
||||||
h.handleError(reqContext, err, 500, func(details error) {
|
h.handleError(reqContext, err, 500, func(details error) {
|
||||||
logger.Error(
|
logger.Error(
|
||||||
"Failed to store user in cache",
|
"Failed to store user in cache",
|
||||||
|
@ -15,6 +15,9 @@ type LoginServiceMock struct {
|
|||||||
NoExistingOrgId int64
|
NoExistingOrgId int64
|
||||||
AlreadyExitingLogin string
|
AlreadyExitingLogin string
|
||||||
GeneratedUserId int64
|
GeneratedUserId int64
|
||||||
|
ExpectedUser *models.User
|
||||||
|
ExpectedUserFunc func(cmd *models.UpsertUserCommand) *models.User
|
||||||
|
ExpectedError error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) {
|
func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) {
|
||||||
@ -35,5 +38,10 @@ func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||||
return nil
|
if s.ExpectedUserFunc != nil {
|
||||||
|
cmd.Result = s.ExpectedUserFunc(cmd)
|
||||||
|
return s.ExpectedError
|
||||||
|
}
|
||||||
|
cmd.Result = s.ExpectedUser
|
||||||
|
return s.ExpectedError
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user