Chore: Remove bus from authproxy (#46936)

* Make authproxy injectable

* Fix import

* Provide function was in wrong place

* Fixing tests

* More imports and rollback a change

* Fix lint
This commit is contained in:
Selene
2022-03-30 17:01:24 +02:00
committed by GitHub
parent 118b87ee8f
commit 8e52dbb87b
11 changed files with 189 additions and 260 deletions

View File

@ -27,11 +27,13 @@ import (
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/dashboards"
dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database" dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database"
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager" dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager"
"github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/searchusers" "github.com/grafana/grafana/pkg/services/searchusers"
@ -193,7 +195,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
authJWTSvc := models.NewFakeJWTService() authJWTSvc := models.NewFakeJWTService()
tracer, err := tracing.InitializeTracerForTest() tracer, err := tracing.InitializeTracerForTest()
require.NoError(t, err) require.NoError(t, err)
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore)
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
return ctxHdlr return ctxHdlr
} }

View File

@ -2,7 +2,6 @@ package middleware
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -25,8 +24,10 @@ import (
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
@ -364,10 +365,7 @@ func TestMiddlewareContext(t *testing.T) {
const group = "grafana-core-team" const group = "grafana-core-team"
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
return nil
})
h, err := authproxy.HashCacheKey(hdrName + "-" + group) h, err := authproxy.HashCacheKey(hdrName + "-" + group)
require.NoError(t, err) require.NoError(t, err)
@ -387,11 +385,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
var actualAuthProxyAutoSignUp *bool = nil var actualAuthProxyAutoSignUp *bool = nil
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
actualAuthProxyAutoSignUp = &cmd.SignupAllowed actualAuthProxyAutoSignUp = &cmd.SignupAllowed
return login.ErrInvalidCredentials return nil
}) }
sc.loginService.ExpectedError = login.ErrInvalidCredentials
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -407,18 +405,8 @@ func TestMiddlewareContext(t *testing.T) {
}) })
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
if query.UserId > 0 { sc.loginService.ExpectedUser = &models.User{Id: userID}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -435,19 +423,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]models.RoleType = nil var storedRoleInfo map[int64]models.RoleType = nil
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
storedRoleInfo = cmd.ExternalUser.OrgRoles storedRoleInfo = cmd.ExternalUser.OrgRoles
return nil sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
}) return &models.User{Id: userID}
}
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -466,19 +446,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]models.RoleType = nil var storedRoleInfo map[int64]models.RoleType = nil
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
storedRoleInfo = cmd.ExternalUser.OrgRoles storedRoleInfo = cmd.ExternalUser.OrgRoles
return nil sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
}) return &models.User{Id: userID}
}
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -499,27 +471,17 @@ func TestMiddlewareContext(t *testing.T) {
}) })
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { var targetOrgID int64 = 123
if query.UserId > 0 { sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: targetOrgID, UserId: userID}
query.Result = &models.SignedInUser{OrgId: query.OrgId, UserId: userID} sc.loginService.ExpectedUser = &models.User{Id: userID}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
targetOrgID := 123
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID)) sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, int64(targetOrgID), sc.context.OrgId) assert.Equal(t, targetOrgID, sc.context.OrgId)
}, func(cfg *setting.Cfg) { }, func(cfg *setting.Cfg) {
configure(cfg) configure(cfg)
cfg.LDAPEnabled = false cfg.LDAPEnabled = false
@ -554,15 +516,8 @@ func TestMiddlewareContext(t *testing.T) {
const userID int64 = 12 const userID int64 = 12
const orgID int64 = 2 const orgID int64 = 2
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} sc.loginService.ExpectedUser = &models.User{Id: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -577,15 +532,8 @@ func TestMiddlewareContext(t *testing.T) {
}) })
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} sc.loginService.ExpectedUser = &models.User{Id: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -602,15 +550,7 @@ func TestMiddlewareContext(t *testing.T) {
}) })
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { sc.loginService.ExpectedUser = &models.User{Id: userID}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -626,10 +566,6 @@ func TestMiddlewareContext(t *testing.T) {
}) })
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("LDAP", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
return errors.New("Do not add user")
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
@ -639,10 +575,6 @@ func TestMiddlewareContext(t *testing.T) {
}, configure) }, configure)
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
return errors.New("Do not add user")
})
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
@ -684,7 +616,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.m.UseMiddleware(AddCSPHeader(cfg, logger)) sc.m.UseMiddleware(AddCSPHeader(cfg, logger))
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]")) sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
ctxHdlr := getContextHandler(t, cfg) sc.mockSQLStore = mockstore.NewSQLStoreMock()
sc.loginService = &loginservice.LoginServiceMock{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService)
sc.sqlStore = ctxHdlr.SQLStore sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware) sc.m.Use(ctxHdlr.Middleware)
@ -714,7 +648,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
}) })
} }
func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler { func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler {
t.Helper() t.Helper()
sqlStore := sqlstore.InitTestDB(t) sqlStore := sqlstore.InitTestDB(t)
@ -730,8 +664,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
renderSvc := &fakeRenderService{} renderSvc := &fakeRenderService{}
authJWTSvc := models.NewFakeJWTService() authJWTSvc := models.NewFakeJWTService()
tracer, err := tracing.InitializeTracerForTest() tracer, err := tracing.InitializeTracerForTest()
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore)
require.NoError(t, err) require.NoError(t, err)
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
} }
type fakeRenderService struct { type fakeRenderService struct {

View File

@ -33,7 +33,7 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
m := web.New() m := web.New()
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]")) m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
m.Use(getContextHandler(t, cfg).Middleware) m.Use(getContextHandler(t, cfg, nil, nil).Middleware)
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler) m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
fn(func() *httptest.ResponseRecorder { fn(func() *httptest.ResponseRecorder {

View File

@ -70,7 +70,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t) sc.remoteCacheService = remotecache.NewFakeStore(t)
contextHandler := getContextHandler(t, nil) contextHandler := getContextHandler(t, nil, nil, nil)
sc.m.Use(contextHandler.Middleware) sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine // mock out gc goroutine
sc.m.Use(OrgRedirect(cfg)) sc.m.Use(OrgRedirect(cfg))

View File

@ -10,7 +10,9 @@ import (
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -34,7 +36,9 @@ type scenarioContext struct {
remoteCacheService *remotecache.RemoteCache remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg cfg *setting.Cfg
sqlStore sqlstore.Store sqlStore sqlstore.Store
mockSQLStore *mockstore.SQLStoreMock
contextHandler *contexthandler.ContextHandler contextHandler *contexthandler.ContextHandler
loginService *loginservice.LoginServiceMock
req *http.Request req *http.Request
} }

View File

@ -32,6 +32,7 @@ import (
"github.com/grafana/grafana/pkg/services/cleanup" "github.com/grafana/grafana/pkg/services/cleanup"
"github.com/grafana/grafana/pkg/services/comments" "github.com/grafana/grafana/pkg/services/comments"
"github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/dashboardimport" "github.com/grafana/grafana/pkg/services/dashboardimport"
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service" dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
"github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/dashboards"
@ -228,6 +229,7 @@ var wireBasicSet = wire.NewSet(
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)), wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
comments.ProvideService, comments.ProvideService,
guardian.ProvideService, guardian.ProvideService,
authproxy.ProvideAuthProxy,
) )
var wireSet = wire.NewSet( var wireSet = wire.NewSet(

View File

@ -6,13 +6,13 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
@ -20,42 +20,18 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const userID = int64(1)
const orgID = int64(4)
// Test initContextWithAuthProxy with a cached user ID that is no longer valid. // Test initContextWithAuthProxy with a cached user ID that is no longer valid.
// //
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user // In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
// in without cache. // in without cache.
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
const name = "markelog" const name = "markelog"
const userID = int64(1)
const orgID = int64(4)
svc := getContextHandler(t) svc := getContextHandler(t)
// XXX: These handlers have to be injected AFTER calling getContextHandler, since the latter
// creates a SQLStore which installs its own handlers.
upsertHandler := func(ctx context.Context, cmd *models.UpsertUserCommand) error {
require.Equal(t, name, cmd.ExternalUser.Login)
cmd.Result = &models.User{Id: userID}
return nil
}
getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error {
// Simulate that the cached user ID is stale
if query.UserId != userID {
return models.ErrUserNotFound
}
query.Result = &models.SignedInUser{
UserId: userID,
OrgId: orgID,
}
return nil
}
bus.AddHandler("", upsertHandler)
bus.AddHandler("", getUserHandler)
t.Cleanup(func() {
bus.ClearBusHandlers()
})
req, err := http.NewRequest("POST", "http://example.com", nil) req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err) require.NoError(t, err)
ctx := &models.ReqContext{ ctx := &models.ReqContext{
@ -106,5 +82,24 @@ func getContextHandler(t *testing.T) *ContextHandler {
tracer, err := tracing.InitializeTracerForTest() tracer, err := tracing.InitializeTracerForTest()
require.NoError(t, err) require.NoError(t, err)
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer) loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}}
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{})
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
}
type FakeGetSignUserStore struct {
sqlstore.Store
}
func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId != userID {
return models.ErrUserNotFound
}
query.Result = &models.SignedInUser{
UserId: userID,
OrgId: orgID,
}
return nil
} }

View File

@ -13,12 +13,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/multildap" "github.com/grafana/grafana/pkg/services/multildap"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
@ -49,11 +50,22 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"}
// AuthProxy struct // AuthProxy struct
type AuthProxy struct { type AuthProxy struct {
cfg *setting.Cfg cfg *setting.Cfg
remoteCache *remotecache.RemoteCache remoteCache *remotecache.RemoteCache
ctx *models.ReqContext loginService login.Service
orgID int64 sqlStore sqlstore.Store
header string
logger log.Logger
}
func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy {
return &AuthProxy{
cfg: cfg,
remoteCache: remoteCache,
loginService: loginService,
sqlStore: sqlStore,
logger: log.New("auth.proxy"),
}
} }
// Error auth proxy specific error // Error auth proxy specific error
@ -75,40 +87,20 @@ func (err Error) Error() string {
return err.Message return err.Message
} }
// Options for the AuthProxy
type Options struct {
RemoteCache *remotecache.RemoteCache
Ctx *models.ReqContext
OrgID int64
}
// New instance of the AuthProxy.
func New(cfg *setting.Cfg, options *Options) *AuthProxy {
auth := &AuthProxy{
remoteCache: options.RemoteCache,
cfg: cfg,
ctx: options.Ctx,
orgID: options.OrgID,
}
auth.header = auth.getDecodedHeader(cfg.AuthProxyHeaderName)
return auth
}
// IsEnabled checks if the auth proxy is enabled. // IsEnabled checks if the auth proxy is enabled.
func (auth *AuthProxy) IsEnabled() bool { func (auth *AuthProxy) IsEnabled() bool {
// Bail if the setting is not enabled // Bail if the setting is not enabled
return auth.cfg.AuthProxyEnabled return auth.cfg.AuthProxyEnabled
} }
// HasHeader checks if the we have specified header // HasHeader checks if we have specified header
func (auth *AuthProxy) HasHeader() bool { func (auth *AuthProxy) HasHeader(reqCtx *models.ReqContext) bool {
return len(auth.header) != 0 header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
return len(header) != 0
} }
// IsAllowedIP returns whether provided IP is allowed. // IsAllowedIP returns whether provided IP is allowed.
func (auth *AuthProxy) IsAllowedIP() error { func (auth *AuthProxy) IsAllowedIP(ip string) error {
ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 { if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
return nil return nil
} }
@ -137,7 +129,7 @@ func (auth *AuthProxy) IsAllowedIP() error {
} }
return newError("proxy authentication required", fmt.Errorf( return newError("proxy authentication required", fmt.Errorf(
"request for user (%s) from %s is not from the authentication proxy", auth.header, "request for user from %s is not from the authentication proxy",
sourceIP, sourceIP,
)) ))
} }
@ -153,10 +145,11 @@ func HashCacheKey(key string) (string, error) {
// getKey forms a key for the cache based on the headers received as part of the authentication flow. // getKey forms a key for the cache based on the headers received as part of the authentication flow.
// Our configuration supports multiple headers. The main header contains the email or username. // Our configuration supports multiple headers. The main header contains the email or username.
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups. // And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
func (auth *AuthProxy) getKey() (string, error) { func (auth *AuthProxy) getKey(reqCtx *models.ReqContext) (string, error) {
key := strings.TrimSpace(auth.header) // start the key with the main header header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
key := strings.TrimSpace(header) // start the key with the main header
auth.headersIterator(func(_, header string) { auth.headersIterator(reqCtx, func(_, header string) {
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
}) })
@ -168,17 +161,17 @@ func (auth *AuthProxy) getKey() (string, error) {
} }
// Login logs in user ID by whatever means possible. // Login logs in user ID by whatever means possible.
func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) { func (auth *AuthProxy) Login(reqCtx *models.ReqContext, ignoreCache bool) (int64, error) {
if !ignoreCache { if !ignoreCache {
// Error here means absent cache - we don't need to handle that // Error here means absent cache - we don't need to handle that
id, err := auth.GetUserViaCache(logger) id, err := auth.getUserViaCache(reqCtx)
if err == nil && id != 0 { if err == nil && id != 0 {
return id, nil return id, nil
} }
} }
if isLDAPEnabled(auth.cfg) { if isLDAPEnabled(auth.cfg) {
id, err := auth.LoginViaLDAP() id, err := auth.LoginViaLDAP(reqCtx)
if err != nil { if err != nil {
if errors.Is(err, ldap.ErrInvalidCredentials) { if errors.Is(err, ldap.ErrInvalidCredentials) {
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials) return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials)
@ -189,7 +182,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
return id, nil return id, nil
} }
id, err := auth.LoginViaHeader() id, err := auth.loginViaHeader(reqCtx)
if err != nil { if err != nil {
return 0, newError("failed to log in as user, specified in auth proxy header", err) return 0, newError("failed to log in as user, specified in auth proxy header", err)
} }
@ -197,87 +190,89 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
return id, nil return id, nil
} }
// GetUserViaCache gets user ID from cache. // getUserViaCache gets user ID from cache.
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) { func (auth *AuthProxy) getUserViaCache(reqCtx *models.ReqContext) (int64, error) {
cacheKey, err := auth.getKey() cacheKey, err := auth.getKey(reqCtx)
if err != nil { if err != nil {
return 0, err return 0, err
} }
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), cacheKey) userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey)
if err != nil { if err != nil {
logger.Debug("Failed getting user ID via auth cache", "error", err) auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
return 0, err return 0, err
} }
logger.Debug("Successfully got user ID via auth cache", "id", userID) auth.logger.Debug("Successfully got user ID via auth cache", "id", userID)
return userID.(int64), nil return userID.(int64), nil
} }
// RemoveUserFromCache removes user from cache. // RemoveUserFromCache removes user from cache.
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { func (auth *AuthProxy) RemoveUserFromCache(reqCtx *models.ReqContext) error {
cacheKey, err := auth.getKey() cacheKey, err := auth.getKey(reqCtx)
if err != nil { if err != nil {
return err return err
} }
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
if err := auth.remoteCache.Delete(auth.ctx.Req.Context(), cacheKey); err != nil { if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil {
return err return err
} }
logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey) auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
return nil return nil
} }
// LoginViaLDAP logs in user via LDAP request // LoginViaLDAP logs in user via LDAP request
func (auth *AuthProxy) LoginViaLDAP() (int64, error) { func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) {
config, err := getLDAPConfig(auth.cfg) config, err := getLDAPConfig(auth.cfg)
if err != nil { if err != nil {
return 0, newError("failed to get LDAP config", err) return 0, newError("failed to get LDAP config", err)
} }
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
mldap := newLDAP(config.Servers) mldap := newLDAP(config.Servers)
extUser, _, err := mldap.User(auth.header) extUser, _, err := mldap.User(header)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// Have to sync grafana and LDAP user during log in // Have to sync grafana and LDAP user during log in
upsert := &models.UpsertUserCommand{ upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx, ReqContext: reqCtx,
SignupAllowed: auth.cfg.LDAPAllowSignup, SignupAllowed: auth.cfg.LDAPAllowSignup,
ExternalUser: extUser, ExternalUser: extUser,
} }
if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil { if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
return 0, err return 0, err
} }
return upsert.Result.Id, nil return upsert.Result.Id, nil
} }
// LoginViaHeader logs in user from the header only // loginViaHeader logs in user from the header only
func (auth *AuthProxy) LoginViaHeader() (int64, error) { func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) {
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
extUser := &models.ExternalUserInfo{ extUser := &models.ExternalUserInfo{
AuthModule: "authproxy", AuthModule: "authproxy",
AuthId: auth.header, AuthId: header,
} }
switch auth.cfg.AuthProxyHeaderProperty { switch auth.cfg.AuthProxyHeaderProperty {
case "username": case "username":
extUser.Login = auth.header extUser.Login = header
emailAddr, emailErr := mail.ParseAddress(auth.header) // only set Email if it can be parsed as an email address emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address
if emailErr == nil { if emailErr == nil {
extUser.Email = emailAddr.Address extUser.Email = emailAddr.Address
} }
case "email": case "email":
extUser.Email = auth.header extUser.Email = header
extUser.Login = auth.header extUser.Login = header
default: default:
return 0, fmt.Errorf("auth proxy header property invalid") return 0, fmt.Errorf("auth proxy header property invalid")
} }
auth.headersIterator(func(field string, header string) { auth.headersIterator(reqCtx, func(field string, header string) {
switch field { switch field {
case "Groups": case "Groups":
extUser.Groups = util.SplitString(header) extUser.Groups = util.SplitString(header)
@ -300,12 +295,12 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
}) })
upsert := &models.UpsertUserCommand{ upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx, ReqContext: reqCtx,
SignupAllowed: auth.cfg.AuthProxyAutoSignUp, SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
ExternalUser: extUser, ExternalUser: extUser,
} }
err := bus.Dispatch(auth.ctx.Req.Context(), upsert) err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -314,8 +309,8 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
} }
// getDecodedHeader gets decoded value of a header with given headerName // getDecodedHeader gets decoded value of a header with given headerName
func (auth *AuthProxy) getDecodedHeader(headerName string) string { func (auth *AuthProxy) getDecodedHeader(reqCtx *models.ReqContext, headerName string) string {
headerValue := auth.ctx.Req.Header.Get(headerName) headerValue := reqCtx.Req.Header.Get(headerName)
if auth.cfg.AuthProxyHeadersEncoded { if auth.cfg.AuthProxyHeadersEncoded {
headerValue = util.DecodeQuotedPrintable(headerValue) headerValue = util.DecodeQuotedPrintable(headerValue)
@ -325,27 +320,27 @@ func (auth *AuthProxy) getDecodedHeader(headerName string) string {
} }
// headersIterator iterates over all non-empty supported additional headers // headersIterator iterates over all non-empty supported additional headers
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) { func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field string, header string)) {
for _, field := range supportedHeaderFields { for _, field := range supportedHeaderFields {
h := auth.cfg.AuthProxyHeaders[field] h := auth.cfg.AuthProxyHeaders[field]
if h == "" { if h == "" {
continue continue
} }
if value := auth.getDecodedHeader(h); value != "" { if value := auth.getDecodedHeader(reqCtx, h); value != "" {
fn(field, strings.TrimSpace(value)) fn(field, strings.TrimSpace(value))
} }
} }
} }
// GetSignedUser gets full signed in user info. // GetSignedInUser gets full signed in user info.
func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) { func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*models.SignedInUser, error) {
query := &models.GetSignedInUserQuery{ query := &models.GetSignedInUserQuery{
OrgId: auth.orgID, OrgId: orgID,
UserId: userID, UserId: userID,
} }
if err := bus.Dispatch(context.Background(), query); err != nil { if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil {
return nil, err return nil, err
} }
@ -353,21 +348,21 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro
} }
// Remember user in cache // Remember user in cache
func (auth *AuthProxy) Remember(id int64) error { func (auth *AuthProxy) Remember(reqCtx *models.ReqContext, id int64) error {
key, err := auth.getKey() key, err := auth.getKey(reqCtx)
if err != nil { if err != nil {
return err return err
} }
// Check if user already in cache // Check if user already in cache
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), key) userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key)
if err == nil && userID != nil { if err == nil && userID != nil {
return nil return nil
} }
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
if err := auth.remoteCache.Set(auth.ctx.Req.Context(), key, id, expiration); err != nil { if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil {
return err return err
} }

View File

@ -7,8 +7,8 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/ldap"
@ -20,8 +20,9 @@ import (
) )
const hdrName = "markelog" const hdrName = "markelog"
const id int64 = 42
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) *AuthProxy { func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *models.ReqContext) {
t.Helper() t.Helper()
req, err := http.NewRequest("POST", "http://example.com", nil) req, err := http.NewRequest("POST", "http://example.com", nil)
@ -40,17 +41,16 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi
Context: &web.Context{Req: req}, Context: &web.Context{Req: req},
} }
auth := New(cfg, &Options{ loginService := loginservice.LoginServiceMock{
RemoteCache: remoteCache, ExpectedUser: &models.User{
Ctx: ctx, Id: id,
OrgID: 4, },
}) }
return auth return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx
} }
func TestMiddlewareContext(t *testing.T) { func TestMiddlewareContext(t *testing.T) {
logger := log.New("test")
cache := remotecache.NewFakeStore(t) cache := remotecache.NewFakeStore(t)
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) { t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
@ -62,12 +62,12 @@ func TestMiddlewareContext(t *testing.T) {
err = cache.Set(context.Background(), key, id, 0) err = cache.Set(context.Background(), key, id, 0)
require.NoError(t, err) require.NoError(t, err)
// Set up the middleware // Set up the middleware
auth := prepareMiddleware(t, cache, nil) auth, reqCtx := prepareMiddleware(t, cache, nil)
gotKey, err := auth.getKey() gotKey, err := auth.getKey(reqCtx)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, key, gotKey) assert.Equal(t, key, gotKey)
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, id, gotID) assert.Equal(t, id, gotID)
@ -84,7 +84,7 @@ func TestMiddlewareContext(t *testing.T) {
err = cache.Set(context.Background(), key, id, 0) err = cache.Set(context.Background(), key, id, 0)
require.NoError(t, err) require.NoError(t, err)
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-Killa" cfg.AuthProxyHeaderName = "X-Killa"
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"} cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
req.Header.Set(cfg.AuthProxyHeaderName, hdrName) req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
@ -93,26 +93,14 @@ func TestMiddlewareContext(t *testing.T) {
}) })
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key) assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key)
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, id, gotID) assert.Equal(t, id, gotID)
}) })
} }
func TestMiddlewareContext_ldap(t *testing.T) { func TestMiddlewareContext_ldap(t *testing.T) {
logger := log.New("test")
t.Run("Logs in via LDAP", func(t *testing.T) { t.Run("Logs in via LDAP", func(t *testing.T) {
const id int64 = 42
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{
Id: id,
}
return nil
})
origIsLDAPEnabled := isLDAPEnabled origIsLDAPEnabled := isLDAPEnabled
origGetLDAPConfig := getLDAPConfig origGetLDAPConfig := getLDAPConfig
origNewLDAP := newLDAP origNewLDAP := newLDAP
@ -147,9 +135,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
cache := remotecache.NewFakeStore(t) cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, cache, nil) auth, reqCtx := prepareMiddleware(t, cache, nil)
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, id, gotID) assert.Equal(t, id, gotID)
@ -177,7 +165,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
cache := remotecache.NewFakeStore(t) cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, cache, nil) auth, reqCtx := prepareMiddleware(t, cache, nil)
stub := &multildap.MultiLDAPmock{ stub := &multildap.MultiLDAPmock{
ID: id, ID: id,
@ -187,7 +175,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return stub return stub
} }
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(reqCtx, false)
require.EqualError(t, err, "failed to get the user") require.EqualError(t, err, "failed to get the user")
assert.NotEqual(t, id, gotID) assert.NotEqual(t, id, gotID)
@ -198,22 +186,24 @@ func TestMiddlewareContext_ldap(t *testing.T) {
func TestDecodeHeader(t *testing.T) { func TestDecodeHeader(t *testing.T) {
cache := remotecache.NewFakeStore(t) cache := remotecache.NewFakeStore(t)
t.Run("should not decode header if not enabled in settings", func(t *testing.T) { t.Run("should not decode header if not enabled in settings", func(t *testing.T) {
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
cfg.AuthProxyHeadersEncoded = false cfg.AuthProxyHeadersEncoded = false
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
}) })
assert.Equal(t, "M=C3=BCnchen", auth.header) header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
assert.Equal(t, "M=C3=BCnchen", header)
}) })
t.Run("should decode header if enabled in settings", func(t *testing.T) { t.Run("should decode header if enabled in settings", func(t *testing.T) {
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
cfg.AuthProxyHeadersEncoded = true cfg.AuthProxyHeadersEncoded = true
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
}) })
assert.Equal(t, "München", auth.header) header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
assert.Equal(t, "München", header)
}) })
} }

View File

@ -36,7 +36,7 @@ const ServiceName = "ContextHandler"
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService, func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService,
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore, remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore,
tracer tracing.Tracer) *ContextHandler { tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler {
return &ContextHandler{ return &ContextHandler{
Cfg: cfg, Cfg: cfg,
AuthTokenService: tokenService, AuthTokenService: tokenService,
@ -45,6 +45,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
RenderService: renderService, RenderService: renderService,
SQLStore: sqlStore, SQLStore: sqlStore,
tracer: tracer, tracer: tracer,
authProxy: authProxy,
} }
} }
@ -57,6 +58,7 @@ type ContextHandler struct {
RenderService rendering.Service RenderService rendering.Service
SQLStore sqlstore.Store SQLStore sqlstore.Store
tracer tracing.Tracer tracer tracing.Tracer
authProxy *authproxy.AuthProxy
// GetTime returns the current time. // GetTime returns the current time.
// Stubbable by tests. // Stubbable by tests.
GetTime func() time.Time GetTime func() time.Time
@ -419,10 +421,10 @@ func (h *ContextHandler) initContextWithRenderAuth(reqContext *models.ReqContext
return true return true
} }
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) { func logUserIn(reqContext *models.ReqContext, auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache) logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers // Try to log in user via various providers
id, err := auth.Login(logger, ignoreCache) id, err := auth.Login(reqContext, ignoreCache)
if err != nil { if err != nil {
details := err details := err
var e authproxy.Error var e authproxy.Error
@ -451,36 +453,31 @@ func (h *ContextHandler) handleError(ctx *models.ReqContext, err error, statusCo
func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool { func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool {
username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName) username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
auth := authproxy.New(h.Cfg, &authproxy.Options{
RemoteCache: h.RemoteCache,
Ctx: reqContext,
OrgID: orgID,
})
logger := log.New("auth.proxy") logger := log.New("auth.proxy")
// Bail if auth proxy is not enabled // Bail if auth proxy is not enabled
if !auth.IsEnabled() { if !h.authProxy.IsEnabled() {
return false return false
} }
// If there is no header - we can't move forward // If there is no header - we can't move forward
if !auth.HasHeader() { if !h.authProxy.HasHeader(reqContext) {
return false return false
} }
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy") _, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy")
defer span.End() defer span.End()
// Check if allowed to continue with this IP // Check if allowed continuing with this IP
if err := auth.IsAllowedIP(); err != nil { if err := h.authProxy.IsAllowedIP(reqContext.Req.RemoteAddr); err != nil {
h.handleError(reqContext, err, 407, func(details error) { h.handleError(reqContext, err, 407, func(details error) {
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details) logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
}) })
return true return true
} }
id, err := logUserIn(auth, username, logger, false) id, err := logUserIn(reqContext, h.authProxy, username, logger, false)
if err != nil { if err != nil {
h.handleError(reqContext, err, 407, nil) h.handleError(reqContext, err, 407, nil)
return true return true
@ -488,7 +485,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
logger.Debug("Got user ID, getting full user info", "userID", id) logger.Debug("Got user ID, getting full user info", "userID", id)
user, err := auth.GetSignedInUser(id) user, err := h.authProxy.GetSignedInUser(id, orgID)
if err != nil { if err != nil {
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale // The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated // cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
@ -496,18 +493,18 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to // we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
// log the user in again without the cache. // log the user in again without the cache.
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id) logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
if err := auth.RemoveUserFromCache(logger); err != nil { if err := h.authProxy.RemoveUserFromCache(reqContext); err != nil {
if !errors.Is(err, remotecache.ErrCacheItemNotFound) { if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
logger.Error("Got unexpected error when removing user from auth cache", "error", err) logger.Error("Got unexpected error when removing user from auth cache", "error", err)
} }
} }
id, err = logUserIn(auth, username, logger, true) id, err = logUserIn(reqContext, h.authProxy, username, logger, true)
if err != nil { if err != nil {
h.handleError(reqContext, err, 407, nil) h.handleError(reqContext, err, 407, nil)
return true return true
} }
user, err = auth.GetSignedInUser(id) user, err = h.authProxy.GetSignedInUser(id, orgID)
if err != nil { if err != nil {
h.handleError(reqContext, err, 407, nil) h.handleError(reqContext, err, 407, nil)
return true return true
@ -521,7 +518,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
reqContext.IsSignedIn = true reqContext.IsSignedIn = true
// Remember user data in cache // Remember user data in cache
if err := auth.Remember(id); err != nil { if err := h.authProxy.Remember(reqContext, id); err != nil {
h.handleError(reqContext, err, 500, func(details error) { h.handleError(reqContext, err, 500, func(details error) {
logger.Error( logger.Error(
"Failed to store user in cache", "Failed to store user in cache",

View File

@ -15,6 +15,9 @@ type LoginServiceMock struct {
NoExistingOrgId int64 NoExistingOrgId int64
AlreadyExitingLogin string AlreadyExitingLogin string
GeneratedUserId int64 GeneratedUserId int64
ExpectedUser *models.User
ExpectedUserFunc func(cmd *models.UpsertUserCommand) *models.User
ExpectedError error
} }
func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) { func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) {
@ -35,5 +38,10 @@ func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User
} }
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error { func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
return nil if s.ExpectedUserFunc != nil {
cmd.Result = s.ExpectedUserFunc(cmd)
return s.ExpectedError
}
cmd.Result = s.ExpectedUser
return s.ExpectedError
} }