diff --git a/pkg/api/admin_users_test.go b/pkg/api/admin_users_test.go index a45afadc18f..e58e6186086 100644 --- a/pkg/api/admin_users_test.go +++ b/pkg/api/admin_users_test.go @@ -22,7 +22,7 @@ const ( func TestAdminAPIEndpoint(t *testing.T) { const role = models.ROLE_ADMIN - t.Run("Given a server admin attempts to remove themself as an admin", func(t *testing.T) { + t.Run("Given a server admin attempts to remove themselves as an admin", func(t *testing.T) { updateCmd := dtos.AdminUpdateUserPermissionsForm{ IsGrafanaAdmin: false, } diff --git a/pkg/api/api.go b/pkg/api/api.go index 85fcac24f02..cb8a85af58b 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -18,7 +18,7 @@ func (hs *HTTPServer) registerRoutes() { reqEditorRole := middleware.ReqEditorRole reqOrgAdmin := middleware.ReqOrgAdmin reqCanAccessTeams := middleware.AdminOrFeatureEnabled(hs.Cfg.EditorsCanAdmin) - reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn() + reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn(hs.Cfg) redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL() redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL() redirectFromLegacyPanelEditURL := middleware.RedirectFromLegacyPanelEditURL() diff --git a/pkg/api/common.go b/pkg/api/common.go index 72ae41e5643..f3420354ff8 100644 --- a/pkg/api/common.go +++ b/pkg/api/common.go @@ -85,7 +85,7 @@ func Success(message string) *NormalResponse { return JSON(200, resp) } -// Error create a erroneous response +// Error creates an error response. func Error(status int, message string, err error) *NormalResponse { data := make(map[string]interface{}) diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index c1a4c358f76..71d878368d4 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -8,9 +8,14 @@ import ( "testing" "github.com/grafana/grafana/pkg/bus" - "github.com/grafana/grafana/pkg/middleware" + "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" "gopkg.in/macaron.v1" ) @@ -141,20 +146,68 @@ func (sc *scenarioContext) exec() { type scenarioFunc func(c *scenarioContext) type handlerFunc func(c *models.ReqContext) Response +func getContextHandler(t *testing.T) *contexthandler.ContextHandler { + t.Helper() + + sqlStore := sqlstore.InitTestDB(t) + remoteCacheSvc := &remotecache.RemoteCache{} + cfg := setting.NewCfg() + cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{ + Name: "database", + } + userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + renderSvc := &fakeRenderService{} + ctxHdlr := &contexthandler.ContextHandler{} + + err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{ + { + Name: sqlstore.ServiceName, + Instance: sqlStore, + }, + { + Name: remotecache.ServiceName, + Instance: remoteCacheSvc, + }, + { + Name: auth.ServiceName, + Instance: userAuthTokenSvc, + }, + { + Name: rendering.ServiceName, + Instance: renderSvc, + }, + { + Name: contexthandler.ServiceName, + Instance: ctxHdlr, + }, + }) + require.NoError(t, err) + + return ctxHdlr +} + func setupScenarioContext(t *testing.T, url string) *scenarioContext { sc := &scenarioContext{ url: url, t: t, } - viewsPath, _ := filepath.Abs("../../public/views") + viewsPath, err := filepath.Abs("../../public/views") + require.NoError(t, err) sc.m = macaron.New() sc.m.Use(macaron.Renderer(macaron.RenderOptions{ Directory: viewsPath, Delims: macaron.Delims{Left: "[[", Right: "]]"}, })) - - sc.m.Use(middleware.GetContextHandler(nil, nil, nil)) + sc.m.Use(getContextHandler(t).Middleware) return sc } + +type fakeRenderService struct { + rendering.Service +} + +func (s *fakeRenderService) Init() error { + return nil +} diff --git a/pkg/api/frontendsettings.go b/pkg/api/frontendsettings.go index b734f4dd78b..efc366351fc 100644 --- a/pkg/api/frontendsettings.go +++ b/pkg/api/frontendsettings.go @@ -193,11 +193,11 @@ func (hs *HTTPServer) getFrontendSettingsMap(c *models.ReqContext) (map[string]i "datasources": dataSources, "minRefreshInterval": setting.MinRefreshInterval, "panels": panels, - "appUrl": setting.AppUrl, - "appSubUrl": setting.AppSubUrl, + "appUrl": hs.Cfg.AppURL, + "appSubUrl": hs.Cfg.AppSubURL, "allowOrgCreate": (setting.AllowUserOrgCreate && c.IsSignedIn) || c.IsGrafanaAdmin, "authProxyEnabled": setting.AuthProxyEnabled, - "ldapEnabled": setting.LDAPEnabled, + "ldapEnabled": hs.Cfg.LDAPEnabled, "alertingEnabled": setting.AlertingEnabled, "alertingErrorOrTimeout": setting.AlertingErrorOrTimeout, "alertingNoDataOrNullValues": setting.AlertingNoDataOrNullValues, diff --git a/pkg/api/frontendsettings_test.go b/pkg/api/frontendsettings_test.go index 34d62bbd149..94f731d66f6 100644 --- a/pkg/api/frontendsettings_test.go +++ b/pkg/api/frontendsettings_test.go @@ -18,7 +18,6 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/services/sqlstore" - "github.com/grafana/grafana/pkg/middleware" "gopkg.in/macaron.v1" "github.com/grafana/grafana/pkg/setting" @@ -53,7 +52,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg) (*macaron.Macaron, *HT } m := macaron.New() - m.Use(middleware.GetContextHandler(nil, nil, nil)) + m.Use(getContextHandler(t).Middleware) m.Use(macaron.Renderer(macaron.RenderOptions{ Directory: filepath.Join(setting.StaticRootPath, "views"), IndentJSON: true, @@ -84,10 +83,12 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) { setting.Env = "testing" tests := []struct { + desc string hideVersion bool expected settings }{ { + desc: "Not hiding version", hideVersion: false, expected: settings{ BuildInfo: buildInfo{ @@ -98,6 +99,7 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) { }, }, { + desc: "Hiding version", hideVersion: true, expected: settings{ BuildInfo: buildInfo{ @@ -110,16 +112,18 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) { } for _, test := range tests { - hs.Cfg.AnonymousHideVersion = test.hideVersion - expected := test.expected + t.Run(test.desc, func(t *testing.T) { + hs.Cfg.AnonymousHideVersion = test.hideVersion + expected := test.expected - recorder := httptest.NewRecorder() - m.ServeHTTP(recorder, req) - got := settings{} - err := json.Unmarshal(recorder.Body.Bytes(), &got) - require.NoError(t, err) - require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicates a failure") + recorder := httptest.NewRecorder() + m.ServeHTTP(recorder, req) + got := settings{} + err := json.Unmarshal(recorder.Body.Bytes(), &got) + require.NoError(t, err) + require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicate a failure") - assert.EqualValues(t, expected, got) + assert.EqualValues(t, expected, got) + }) } } diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 20363146da3..c0c02d765d9 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -29,6 +29,7 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/registry" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/login" @@ -75,6 +76,7 @@ type HTTPServer struct { SearchService *search.SearchService `inject:""` ShortURLService *shorturls.ShortURLService `inject:""` Live *live.GrafanaLive `inject:""` + ContextHandler *contexthandler.ContextHandler `inject:""` Listener net.Listener } @@ -100,7 +102,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error { Addr: net.JoinHostPort(setting.HttpAddr, setting.HttpPort), Handler: hs.macaron, } - switch setting.Protocol { + switch hs.Cfg.Protocol { case setting.HTTP2Scheme: if err := hs.configureHttp2(); err != nil { return err @@ -118,7 +120,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error { } hs.log.Info("HTTP Server Listen", "address", listener.Addr().String(), "protocol", - setting.Protocol, "subUrl", setting.AppSubUrl, "socket", setting.SocketPath) + hs.Cfg.Protocol, "subUrl", hs.Cfg.AppSubURL, "socket", hs.Cfg.SocketPath) var wg sync.WaitGroup wg.Add(1) @@ -133,7 +135,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error { } }() - switch setting.Protocol { + switch hs.Cfg.Protocol { case setting.HTTPScheme, setting.SocketScheme: if err := hs.httpSrv.Serve(listener); err != nil { if errors.Is(err, http.ErrServerClosed) { @@ -151,7 +153,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error { return err } default: - panic(fmt.Sprintf("Unhandled protocol %q", setting.Protocol)) + panic(fmt.Sprintf("Unhandled protocol %q", hs.Cfg.Protocol)) } wg.Wait() @@ -164,7 +166,7 @@ func (hs *HTTPServer) getListener() (net.Listener, error) { return hs.Listener, nil } - switch setting.Protocol { + switch hs.Cfg.Protocol { case setting.HTTPScheme, setting.HTTPSScheme, setting.HTTP2Scheme: listener, err := net.Listen("tcp", hs.httpSrv.Addr) if err != nil { @@ -172,21 +174,21 @@ func (hs *HTTPServer) getListener() (net.Listener, error) { } return listener, nil case setting.SocketScheme: - listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: setting.SocketPath, Net: "unix"}) + listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: hs.Cfg.SocketPath, Net: "unix"}) if err != nil { - return nil, errutil.Wrapf(err, "failed to open listener for socket %s", setting.SocketPath) + return nil, errutil.Wrapf(err, "failed to open listener for socket %s", hs.Cfg.SocketPath) } // Make socket writable by group // nolint:gosec - if err := os.Chmod(setting.SocketPath, 0660); err != nil { + if err := os.Chmod(hs.Cfg.SocketPath, 0660); err != nil { return nil, errutil.Wrapf(err, "failed to change socket permissions") } return listener, nil default: - hs.log.Error("Invalid protocol", "protocol", setting.Protocol) - return nil, fmt.Errorf("invalid protocol %q", setting.Protocol) + hs.log.Error("Invalid protocol", "protocol", hs.Cfg.Protocol) + return nil, fmt.Errorf("invalid protocol %q", hs.Cfg.Protocol) } } @@ -271,7 +273,7 @@ func (hs *HTTPServer) configureHttp2() error { } func (hs *HTTPServer) newMacaron() *macaron.Macaron { - macaron.Env = setting.Env + macaron.Env = hs.Cfg.Env m := macaron.New() // automatically set HEAD for every GET @@ -294,13 +296,13 @@ func (hs *HTTPServer) applyRoutes() { func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { m := hs.macaron - m.Use(middleware.Logger()) + m.Use(middleware.Logger(hs.Cfg)) if setting.EnableGzip { m.Use(middleware.Gziper()) } - m.Use(middleware.Recovery()) + m.Use(middleware.Recovery(hs.Cfg)) for _, route := range plugins.StaticRoutes { pluginRoute := path.Join("/public/plugins/", route.PluginId) @@ -316,7 +318,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { hs.mapStatic(m, hs.Cfg.ImagesDir, "", "/public/img/attachments") } - m.Use(middleware.AddDefaultResponseHeaders()) + m.Use(middleware.AddDefaultResponseHeaders(hs.Cfg)) if setting.ServeFromSubPath && setting.AppSubUrl != "" { m.SetURLPrefix(setting.AppSubUrl) @@ -334,16 +336,12 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { m.Use(hs.apiHealthHandler) m.Use(hs.metricsEndpoint) - m.Use(middleware.GetContextHandler( - hs.AuthTokenService, - hs.RemoteCacheService, - hs.RenderService, - )) + m.Use(hs.ContextHandler.Middleware) m.Use(middleware.OrgRedirect()) // needs to be after context handler if setting.EnforceDomain { - m.Use(middleware.ValidateHostHeader(setting.Domain)) + m.Use(middleware.ValidateHostHeader(hs.Cfg.Domain)) } m.Use(middleware.HandleNoCacheHeader()) @@ -433,7 +431,7 @@ func (hs *HTTPServer) mapStatic(m *macaron.Macaron, rootDir string, dir string, } } - if setting.Env == setting.Dev { + if hs.Cfg.Env == setting.Dev { headers = func(c *macaron.Context) { c.Resp.Header().Set("Cache-Control", "max-age=0, must-revalidate, no-cache") } diff --git a/pkg/api/index.go b/pkg/api/index.go index 79d35c4bf75..aaaaed3a1f6 100644 --- a/pkg/api/index.go +++ b/pkg/api/index.go @@ -300,7 +300,7 @@ func (hs *HTTPServer) getNavTree(c *models.ReqContext, hasEditPerm bool) ([]*dto {Text: "Stats", Id: "server-stats", Url: setting.AppSubUrl + "/admin/stats", Icon: "graph-bar"}, } - if setting.LDAPEnabled { + if hs.Cfg.LDAPEnabled { adminNavLinks = append(adminNavLinks, &dtos.NavLink{ Text: "LDAP", Id: "ldap", Url: setting.AppSubUrl + "/admin/ldap", Icon: "book", }) @@ -371,7 +371,7 @@ func (hs *HTTPServer) setIndexViewData(c *models.ReqContext) (*dtos.IndexViewDat // special case when doing localhost call from image renderer if c.IsRenderCall && !hs.Cfg.ServeFromSubPath { - appURL = fmt.Sprintf("%s://localhost:%s", setting.Protocol, setting.HttpPort) + appURL = fmt.Sprintf("%s://localhost:%s", hs.Cfg.Protocol, setting.HttpPort) appSubURL = "" settings["appSubUrl"] = "" } diff --git a/pkg/api/ldap_debug.go b/pkg/api/ldap_debug.go index c8af442371e..e9d6f5bfcfa 100644 --- a/pkg/api/ldap_debug.go +++ b/pkg/api/ldap_debug.go @@ -116,7 +116,7 @@ func (hs *HTTPServer) GetLDAPStatus(c *models.ReqContext) Response { return Error(http.StatusBadRequest, "LDAP is not enabled", nil) } - ldapConfig, err := getLDAPConfig() + ldapConfig, err := getLDAPConfig(hs.Cfg) if err != nil { return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err) @@ -158,7 +158,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response { return Error(http.StatusBadRequest, "LDAP is not enabled", nil) } - ldapConfig, err := getLDAPConfig() + ldapConfig, err := getLDAPConfig(hs.Cfg) if err != nil { return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err) } @@ -217,7 +217,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response { upsertCmd := &models.UpsertUserCommand{ ReqContext: c, ExternalUser: user, - SignupAllowed: setting.LDAPAllowSignup, + SignupAllowed: hs.Cfg.LDAPAllowSignup, } err = bus.Dispatch(upsertCmd) @@ -235,7 +235,7 @@ func (hs *HTTPServer) GetUserFromLDAP(c *models.ReqContext) Response { return Error(http.StatusBadRequest, "LDAP is not enabled", nil) } - ldapConfig, err := getLDAPConfig() + ldapConfig, err := getLDAPConfig(hs.Cfg) if err != nil { return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration", err) diff --git a/pkg/api/ldap_debug_test.go b/pkg/api/ldap_debug_test.go index f21c4289606..a929c06e640 100644 --- a/pkg/api/ldap_debug_test.go +++ b/pkg/api/ldap_debug_test.go @@ -74,7 +74,7 @@ func getUserFromLDAPContext(t *testing.T, requestURL string) *scenarioContext { } func TestGetUserFromLDAPAPIEndpoint_UserNotFound(t *testing.T) { - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -131,7 +131,7 @@ func TestGetUserFromLDAPAPIEndpoint_OrgNotfound(t *testing.T) { return nil }) - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -193,7 +193,7 @@ func TestGetUserFromLDAPAPIEndpoint(t *testing.T) { return nil }) - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -273,7 +273,7 @@ func TestGetUserFromLDAPAPIEndpoint_WithTeamHandler(t *testing.T) { return nil }) - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -349,7 +349,7 @@ func TestGetLDAPStatusAPIEndpoint(t *testing.T) { {Host: "10.0.0.5", Port: 361, Available: false, Error: errors.New("something is awfully wrong")}, } - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -412,7 +412,7 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(t func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -457,7 +457,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -485,7 +485,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } @@ -528,7 +528,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return &ldap.Config{}, nil } diff --git a/pkg/api/login.go b/pkg/api/login.go index 9ebaf1e19ea..517fc231b4a 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -13,7 +13,7 @@ import ( "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/infra/network" "github.com/grafana/grafana/pkg/login" - "github.com/grafana/grafana/pkg/middleware" + "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" @@ -61,12 +61,12 @@ func (hs *HTTPServer) ValidateRedirectTo(redirectTo string) error { return nil } -func (hs *HTTPServer) CookieOptionsFromCfg() middleware.CookieOptions { +func (hs *HTTPServer) CookieOptionsFromCfg() cookies.CookieOptions { path := "/" if len(hs.Cfg.AppSubURL) > 0 { path = hs.Cfg.AppSubURL } - return middleware.CookieOptions{ + return cookies.CookieOptions{ Path: path, Secure: hs.Cfg.CookieSecure, SameSiteDisabled: hs.Cfg.CookieSameSiteDisabled, @@ -101,7 +101,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) { // therefore the loginError should be passed to the view data // and the view should return immediately before attempting // to login again via OAuth and enter to a redirect loop - middleware.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg) viewData.Settings["loginError"] = loginError c.HTML(200, getViewIndex(), viewData) return @@ -113,7 +113,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) { if c.IsSignedIn { // Assign login token to auth proxy users if enable_login_token = true - if setting.AuthProxyEnabled && setting.AuthProxyEnableLoginToken { + if hs.Cfg.AuthProxyEnabled && hs.Cfg.AuthProxyEnableLoginToken { user := &models.User{Id: c.SignedInUser.UserId, Email: c.SignedInUser.Email, Login: c.SignedInUser.Login} err := hs.loginUserWithUser(user, c) if err != nil { @@ -129,7 +129,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) { log.Debugf("Ignored invalid redirect_to cookie value: %v", redirectTo) redirectTo = hs.Cfg.AppSubURL + "/" } - middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) + cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) c.Redirect(redirectTo) return } @@ -196,6 +196,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res Username: cmd.User, Password: cmd.Password, IpAddress: c.Req.RemoteAddr, + Cfg: hs.Cfg, } err := bus.Dispatch(authQuery) @@ -236,7 +237,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res } else { log.Infof("Ignored invalid redirect_to cookie value: %v", redirectTo) } - middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) + cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) } metrics.MApiLoginPost.Inc() @@ -263,7 +264,7 @@ func (hs *HTTPServer) loginUserWithUser(user *models.User, c *models.ReqContext) } hs.log.Info("Successful Login", "User", user.Email) - middleware.WriteSessionCookie(c, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime) + cookies.WriteSessionCookie(c, hs.Cfg, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime) return nil } @@ -278,7 +279,7 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) { hs.log.Error("failed to revoke auth token", "error", err) } - middleware.WriteSessionCookie(c, "", -1) + cookies.WriteSessionCookie(c, hs.Cfg, "", -1) if setting.SignoutRedirectUrl != "" { c.Redirect(setting.SignoutRedirectUrl) @@ -309,7 +310,7 @@ func (hs *HTTPServer) trySetEncryptedCookie(ctx *models.ReqContext, cookieName s return err } - middleware.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg) + cookies.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg) return nil } diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index 9f00db14f25..64a16c2c386 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -18,7 +18,7 @@ import ( "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login/social" - "github.com/grafana/grafana/pkg/middleware" + "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" ) @@ -81,7 +81,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { } hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret) - middleware.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) + cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) if setting.OAuthService.OAuthInfos[name].HostedDomain == "" { ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline)) } else { @@ -93,7 +93,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { cookieState := ctx.GetCookie(OauthStateCookieName) // delete cookie - middleware.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) + cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) if cookieState == "" { hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ @@ -192,7 +192,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { if redirectTo, err := url.QueryUnescape(ctx.GetCookie("redirect_to")); err == nil && len(redirectTo) > 0 { if err := hs.ValidateRedirectTo(redirectTo); err == nil { - middleware.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg) + cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg) ctx.Redirect(redirectTo) return } diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index 5abede806d3..1403fd478f2 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -592,8 +592,8 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte setting.OAuthService = &setting.OAuther{} setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) - setting.AuthProxyEnabled = true - setting.AuthProxyEnableLoginToken = enableLoginToken + hs.Cfg.AuthProxyEnabled = true + hs.Cfg.AuthProxyEnableLoginToken = enableLoginToken sc.m.Get(sc.url, sc.defaultHandler) sc.fakeReqNoAssertions("GET", sc.url).exec() diff --git a/pkg/infra/remotecache/remotecache.go b/pkg/infra/remotecache/remotecache.go index b1c3f6b485f..931df03acfe 100644 --- a/pkg/infra/remotecache/remotecache.go +++ b/pkg/infra/remotecache/remotecache.go @@ -23,8 +23,17 @@ var ( defaultMaxCacheExpiration = time.Hour * 24 ) +const ( + ServiceName = "RemoteCache" +) + func init() { - registry.RegisterService(&RemoteCache{}) + rc := &RemoteCache{} + registry.Register(®istry.Descriptor{ + Name: ServiceName, + Instance: rc, + InitPriority: registry.Medium, + }) } // CacheStorage allows the caller to set, get and delete items in the cache. diff --git a/pkg/login/auth.go b/pkg/login/auth.go index 2589ce2ee4e..d6fbe157bfa 100644 --- a/pkg/login/auth.go +++ b/pkg/login/auth.go @@ -25,12 +25,12 @@ var ( var loginLogger = log.New("login") func Init() { - bus.AddHandler("auth", AuthenticateUser) + bus.AddHandler("auth", authenticateUser) } -// AuthenticateUser authenticates the user via username & password -func AuthenticateUser(query *models.LoginUserQuery) error { - if err := validateLoginAttempts(query.Username); err != nil { +// authenticateUser authenticates the user via username & password +func authenticateUser(query *models.LoginUserQuery) error { + if err := validateLoginAttempts(query); err != nil { return err } diff --git a/pkg/login/auth_test.go b/pkg/login/auth_test.go index 82e76f46f19..84e54701414 100644 --- a/pkg/login/auth_test.go +++ b/pkg/login/auth_test.go @@ -21,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) { Username: "user", Password: "", } - err := AuthenticateUser(&loginQuery) + err := authenticateUser(&loginQuery) Convey("login should fail", func() { So(sc.grafanaLoginWasCalled, ShouldBeFalse) @@ -37,7 +37,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, ErrTooManyLoginAttempts) @@ -55,7 +55,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, nil) @@ -74,7 +74,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, customErr) @@ -92,7 +92,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(false, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, models.ErrUserNotFound) @@ -110,7 +110,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, ErrInvalidCredentials) @@ -128,7 +128,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldBeNil) @@ -147,7 +147,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, customErr, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, customErr) @@ -165,7 +165,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := AuthenticateUser(sc.loginUserQuery) + err := authenticateUser(sc.loginUserQuery) Convey("it should result in", func() { So(err, ShouldEqual, ErrInvalidCredentials) @@ -203,7 +203,7 @@ func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) { } func mockLoginAttemptValidation(err error, sc *authScenarioContext) { - validateLoginAttempts = func(username string) error { + validateLoginAttempts = func(*models.LoginUserQuery) error { sc.loginAttemptValidationWasCalled = true return err } diff --git a/pkg/login/brute_force_login_protection.go b/pkg/login/brute_force_login_protection.go index 2c0cacb99fc..3d914b48e45 100644 --- a/pkg/login/brute_force_login_protection.go +++ b/pkg/login/brute_force_login_protection.go @@ -5,7 +5,6 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/setting" ) var ( @@ -13,13 +12,13 @@ var ( loginAttemptsWindow = time.Minute * 5 ) -var validateLoginAttempts = func(username string) error { - if setting.DisableBruteForceLoginProtection { +var validateLoginAttempts = func(query *models.LoginUserQuery) error { + if query.Cfg.DisableBruteForceLoginProtection { return nil } loginAttemptCountQuery := models.GetUserLoginAttemptCountQuery{ - Username: username, + Username: query.Username, Since: time.Now().Add(-loginAttemptsWindow), } @@ -35,7 +34,7 @@ var validateLoginAttempts = func(username string) error { } var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { - if setting.DisableBruteForceLoginProtection { + if query.Cfg.DisableBruteForceLoginProtection { return nil } diff --git a/pkg/login/brute_force_login_protection_test.go b/pkg/login/brute_force_login_protection_test.go index f8a09000543..e9106a3e439 100644 --- a/pkg/login/brute_force_login_protection_test.go +++ b/pkg/login/brute_force_login_protection_test.go @@ -12,11 +12,16 @@ import ( func TestLoginAttemptsValidation(t *testing.T) { Convey("Validate login attempts", t, func() { Convey("Given brute force login protection enabled", func() { - setting.DisableBruteForceLoginProtection = false + cfg := setting.NewCfg() + cfg.DisableBruteForceLoginProtection = false + query := &models.LoginUserQuery{ + Username: "user", + Cfg: cfg, + } Convey("When user login attempt count equals max-1 ", func() { withLoginAttempts(maxInvalidLoginAttempts - 1) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should not result in error", func() { So(err, ShouldBeNil) @@ -25,7 +30,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Convey("When user login attempt count equals max ", func() { withLoginAttempts(maxInvalidLoginAttempts) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should result in too many login attempts error", func() { So(err, ShouldEqual, ErrTooManyLoginAttempts) @@ -34,7 +39,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Convey("When user login attempt count is greater than max ", func() { withLoginAttempts(maxInvalidLoginAttempts + 5) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should result in too many login attempts error", func() { So(err, ShouldEqual, ErrTooManyLoginAttempts) @@ -54,6 +59,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Username: "user", Password: "pwd", IpAddress: "192.168.1.1:56433", + Cfg: setting.NewCfg(), }) So(err, ShouldBeNil) @@ -66,11 +72,16 @@ func TestLoginAttemptsValidation(t *testing.T) { }) Convey("Given brute force login protection disabled", func() { - setting.DisableBruteForceLoginProtection = true + cfg := setting.NewCfg() + cfg.DisableBruteForceLoginProtection = true + query := &models.LoginUserQuery{ + Username: "user", + Cfg: cfg, + } Convey("When user login attempt count equals max-1 ", func() { withLoginAttempts(maxInvalidLoginAttempts - 1) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should not result in error", func() { So(err, ShouldBeNil) @@ -79,7 +90,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Convey("When user login attempt count equals max ", func() { withLoginAttempts(maxInvalidLoginAttempts) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should not result in error", func() { So(err, ShouldBeNil) @@ -88,7 +99,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Convey("When user login attempt count is greater than max ", func() { withLoginAttempts(maxInvalidLoginAttempts + 5) - err := validateLoginAttempts("user") + err := validateLoginAttempts(query) Convey("it should not result in error", func() { So(err, ShouldBeNil) @@ -97,7 +108,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Convey("When saving invalid login attempt", func() { defer bus.ClearBusHandlers() - createLoginAttemptCmd := (*models.CreateLoginAttemptCommand)(nil) + var createLoginAttemptCmd *models.CreateLoginAttemptCommand bus.AddHandler("test", func(cmd *models.CreateLoginAttemptCommand) error { createLoginAttemptCmd = cmd @@ -108,6 +119,7 @@ func TestLoginAttemptsValidation(t *testing.T) { Username: "user", Password: "pwd", IpAddress: "192.168.1.1:56433", + Cfg: cfg, }) So(err, ShouldBeNil) diff --git a/pkg/login/ldap_login.go b/pkg/login/ldap_login.go index f394ed53466..cb5d984e736 100644 --- a/pkg/login/ldap_login.go +++ b/pkg/login/ldap_login.go @@ -33,7 +33,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) { return false, nil } - config, err := getLDAPConfig() + config, err := getLDAPConfig(query.Cfg) if err != nil { return true, errutil.Wrap("Failed to get LDAP config", err) } diff --git a/pkg/login/ldap_login_test.go b/pkg/login/ldap_login_test.go index 63fdbb68e7a..949ae433982 100644 --- a/pkg/login/ldap_login_test.go +++ b/pkg/login/ldap_login_test.go @@ -20,7 +20,7 @@ func TestLDAPLogin(t *testing.T) { LDAPLoginScenario("When login", func(sc *LDAPLoginScenarioContext) { sc.withLoginResult(false) - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { config := &ldap.Config{ Servers: []*ldap.ServerConfig{}, } @@ -150,7 +150,14 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) { LDAPAuthenticatorMock: mock, } - getLDAPConfig = func() (*ldap.Config, error) { + origNewLDAP := newLDAP + origGetLDAPConfig := getLDAPConfig + defer func() { + newLDAP = origNewLDAP + getLDAPConfig = origGetLDAPConfig + }() + + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { config := &ldap.Config{ Servers: []*ldap.ServerConfig{ { @@ -166,11 +173,6 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) { return mock } - defer func() { - newLDAP = multildap.New - getLDAPConfig = multildap.GetConfig - }() - fn(sc) }) } diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index d24871d31ca..1492fc8f660 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -8,9 +8,9 @@ import ( macaron "gopkg.in/macaron.v1" + "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" ) type AuthOptions struct { @@ -18,22 +18,6 @@ type AuthOptions struct { ReqSignedIn bool } -func getApiKey(c *models.ReqContext) string { - header := c.Req.Header.Get("Authorization") - parts := strings.SplitN(header, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - key := parts[1] - return key - } - - username, password, err := util.DecodeBasicAuthHeader(header) - if err == nil && username == "api_key" { - return password - } - - return "" -} - func accessForbidden(c *models.ReqContext) { if c.IsApiRequest() { c.JsonApiErr(403, "Permission denied", nil) @@ -57,7 +41,7 @@ func notAuthorized(c *models.ReqContext) { // remove any forceLogin=true params redirectTo = removeForceLoginParams(redirectTo) - WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil) + cookies.WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil) c.Redirect(setting.AppSubUrl + "/login") } @@ -135,9 +119,9 @@ func AdminOrFeatureEnabled(enabled bool) macaron.Handler { } } -func SnapshotPublicModeOrSignedIn() macaron.Handler { +func SnapshotPublicModeOrSignedIn(cfg *setting.Cfg) macaron.Handler { return func(c *models.ReqContext) { - if setting.SnapshotPublicMode { + if cfg.SnapshotPublicMode { return } diff --git a/pkg/middleware/auth_proxy.go b/pkg/middleware/auth_proxy.go deleted file mode 100644 index 4c5f9d22c48..00000000000 --- a/pkg/middleware/auth_proxy.go +++ /dev/null @@ -1,133 +0,0 @@ -package middleware - -import ( - "errors" - - "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/infra/remotecache" - "github.com/grafana/grafana/pkg/middleware/authproxy" - "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/setting" -) - -var header = setting.AuthProxyHeaderName - -func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) { - logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache) - // Try to log in user via various providers - id, err := auth.Login(logger, ignoreCache) - if err != nil { - details := err - var e authproxy.Error - if errors.As(err, &e) { - details = e.DetailsError - } - logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details, - "ignoreCache", ignoreCache) - return 0, err - } - return id, nil -} - -// handleError calls ctx.Handle with the error message and the underlying error. -// If the error is of type authproxy.Error, its DetailsError is unwrapped and passed to ctx.Handle. -// If a callback is provided, it's called with either err.DetailsError, if err is of type -// authproxy.Error, otherwise err itself. -func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(err error)) { - details := err - var e authproxy.Error - if errors.As(err, &e) { - details = e.DetailsError - } - - ctx.Handle(statusCode, err.Error(), details) - - if cb != nil { - cb(details) - } -} - -func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqContext, orgID int64) bool { - username := ctx.Req.Header.Get(header) - auth := authproxy.New(&authproxy.Options{ - Store: store, - Ctx: ctx, - OrgID: orgID, - }) - - logger := log.New("auth.proxy") - - // Bail if auth proxy is not enabled - if !auth.IsEnabled() { - return false - } - - // If there is no header - we can't move forward - if !auth.HasHeader() { - return false - } - - // Check if allowed to continue with this IP - if err := auth.IsAllowedIP(); err != nil { - handleError(ctx, err, 407, func(details error) { - logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details) - }) - return true - } - - id, err := logUserIn(auth, username, logger, false) - if err != nil { - handleError(ctx, err, 407, nil) - return true - } - - logger.Debug("Got user ID, getting full user info", "userID", id) - - user, e := auth.GetSignedUser(id) - if e != nil { - // The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale - // cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated - // because cache keys are computed from request header values and not just the user ID. Meaning that - // we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to - // log the user in again without the cache. - logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id) - if err := auth.RemoveUserFromCache(logger); err != nil { - if !errors.Is(err, remotecache.ErrCacheItemNotFound) { - logger.Error("Got unexpected error when removing user from auth cache", "error", err) - } - } - id, err = logUserIn(auth, username, logger, true) - if err != nil { - handleError(ctx, err, 407, nil) - return true - } - - user, err = auth.GetSignedUser(id) - if err != nil { - handleError(ctx, err, 407, nil) - - return true - } - } - - logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login) - - // Add user info to context - ctx.SignedInUser = user - ctx.IsSignedIn = true - - // Remember user data in cache - if err := auth.Remember(id); err != nil { - handleError(ctx, err, 500, func(details error) { - logger.Error( - "Failed to store user in cache", - "username", username, - "message", e.Error(), - "error", details, - ) - }) - return true - } - - return true -} diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go index 1cf03aa4df2..8b734255312 100644 --- a/pkg/middleware/auth_test.go +++ b/pkg/middleware/auth_test.go @@ -33,16 +33,10 @@ func TestMiddlewareAuth(t *testing.T) { t.Run("Anonymous auth enabled", func(t *testing.T) { const orgID int64 = 1 - origEnabled := setting.AnonymousEnabled - t.Cleanup(func() { - setting.AnonymousEnabled = origEnabled - }) - origName := setting.AnonymousOrgName - t.Cleanup(func() { - setting.AnonymousOrgName = origName - }) - setting.AnonymousEnabled = true - setting.AnonymousOrgName = "test" + configure := func(cfg *setting.Cfg) { + cfg.AnonymousEnabled = true + cfg.AnonymousOrgName = "test" + } middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func( t *testing.T, sc *scenarioContext) { @@ -59,7 +53,7 @@ func TestMiddlewareAuth(t *testing.T) { location, ok := sc.resp.Header()["Location"] assert.True(t, ok) assert.Equal(t, "/login", location[0]) - }) + }, configure) middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func( t *testing.T, sc *scenarioContext) { @@ -73,7 +67,7 @@ func TestMiddlewareAuth(t *testing.T) { sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec() assert.Equal(t, 200, sc.resp.Code) - }) + }, configure) middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func( t *testing.T, sc *scenarioContext) { @@ -90,20 +84,20 @@ func TestMiddlewareAuth(t *testing.T) { location, ok := sc.resp.Header()["Location"] assert.True(t, ok) assert.Equal(t, "/login", location[0]) - }) + }, configure) }) middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func( t *testing.T, sc *scenarioContext) { - sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) + sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler) sc.fakeReq("GET", "/api/snapshot").exec() assert.Equal(t, 401, sc.resp.Code) }) middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func( t *testing.T, sc *scenarioContext) { - setting.SnapshotPublicMode = true - sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) + sc.cfg.SnapshotPublicMode = true + sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler) sc.fakeReq("GET", "/api/snapshot").exec() assert.Equal(t, 200, sc.resp.Code) }) diff --git a/pkg/middleware/cookie.go b/pkg/middleware/cookies/cookies.go similarity index 90% rename from pkg/middleware/cookie.go rename to pkg/middleware/cookies/cookies.go index 78939632dbc..b819320758d 100644 --- a/pkg/middleware/cookie.go +++ b/pkg/middleware/cookies/cookies.go @@ -1,4 +1,4 @@ -package middleware +package cookies import ( "net/http" @@ -55,8 +55,8 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g http.SetCookie(w, &cookie) } -func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) { - if setting.Env == setting.Dev { +func WriteSessionCookie(ctx *models.ReqContext, cfg *setting.Cfg, value string, maxLifetime time.Duration) { + if cfg.Env == setting.Dev { ctx.Logger.Info("New token", "unhashed token", value) } diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go index 94de372fee4..455231835ca 100644 --- a/pkg/middleware/logger.go +++ b/pkg/middleware/logger.go @@ -26,7 +26,7 @@ import ( "gopkg.in/macaron.v1" ) -func Logger() macaron.Handler { +func Logger(cfg *setting.Cfg) macaron.Handler { return func(res http.ResponseWriter, req *http.Request, c *macaron.Context) { start := time.Now() c.Data["perfmon.start"] = start @@ -43,7 +43,7 @@ func Logger() macaron.Handler { status := rw.Status() if status == 200 || status == 304 { - if !setting.RouterLogging { + if !cfg.RouterLogging { return } } diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 00531f22823..a49155445d6 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -1,32 +1,13 @@ package middleware import ( - "context" - "errors" "fmt" - "strconv" "strings" - "time" macaron "gopkg.in/macaron.v1" - "github.com/grafana/grafana/pkg/bus" - "github.com/grafana/grafana/pkg/components/apikeygen" - "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/infra/network" - "github.com/grafana/grafana/pkg/infra/remotecache" - "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/util" -) - -var getTime = time.Now - -const ( - errStringInvalidUsernamePassword = "Invalid username or password" - errStringInvalidAPIKey = "Invalid API key" ) var ( @@ -39,244 +20,7 @@ var ( ReqOrgAdmin = RoleAuth(models.ROLE_ADMIN) ) -func GetContextHandler( - ats models.UserTokenService, - remoteCache *remotecache.RemoteCache, - renderService rendering.Service, -) macaron.Handler { - return func(c *macaron.Context) { - ctx := &models.ReqContext{ - Context: c, - SignedInUser: &models.SignedInUser{}, - IsSignedIn: false, - AllowAnonymous: false, - SkipCache: false, - Logger: log.New("context"), - } - - orgID := int64(0) - orgIDHeader := ctx.Req.Header.Get("X-Grafana-Org-Id") - if orgIDHeader != "" { - orgIDParsed, err := strconv.ParseInt(orgIDHeader, 10, 64) - if err == nil { - orgID = orgIDParsed - } - } - - // the order in which these are tested are important - // look for api key in Authorization header first - // then init session and look for userId in session - // then look for api key in session (special case for render calls via api) - // then test if anonymous access is enabled - switch { - case initContextWithRenderAuth(ctx, renderService): - case initContextWithApiKey(ctx): - case initContextWithBasicAuth(ctx, orgID): - case initContextWithAuthProxy(remoteCache, ctx, orgID): - case initContextWithToken(ats, ctx, orgID): - case initContextWithAnonymousUser(ctx): - } - - ctx.Logger = log.New("context", "userId", ctx.UserId, "orgId", ctx.OrgId, "uname", ctx.Login) - ctx.Data["ctx"] = ctx - - c.Map(ctx) - - // update last seen every 5min - if ctx.ShouldUpdateLastSeenAt() { - ctx.Logger.Debug("Updating last user_seen_at", "user_id", ctx.UserId) - if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: ctx.UserId}); err != nil { - ctx.Logger.Error("Failed to update last_seen_at", "error", err) - } - } - } -} - -func initContextWithAnonymousUser(ctx *models.ReqContext) bool { - if !setting.AnonymousEnabled { - return false - } - - orgQuery := models.GetOrgByNameQuery{Name: setting.AnonymousOrgName} - if err := bus.Dispatch(&orgQuery); err != nil { - log.Errorf(3, "Anonymous access organization error: '%s': %s", setting.AnonymousOrgName, err) - return false - } - - ctx.IsSignedIn = false - ctx.AllowAnonymous = true - ctx.SignedInUser = &models.SignedInUser{IsAnonymous: true} - ctx.OrgRole = models.RoleType(setting.AnonymousOrgRole) - ctx.OrgId = orgQuery.Result.Id - ctx.OrgName = orgQuery.Result.Name - return true -} - -func initContextWithApiKey(ctx *models.ReqContext) bool { - var keyString string - if keyString = getApiKey(ctx); keyString == "" { - return false - } - - // base64 decode key - decoded, err := apikeygen.Decode(keyString) - if err != nil { - ctx.JsonApiErr(401, errStringInvalidAPIKey, err) - return true - } - - // fetch key - keyQuery := models.GetApiKeyByNameQuery{KeyName: decoded.Name, OrgId: decoded.OrgId} - if err := bus.Dispatch(&keyQuery); err != nil { - ctx.JsonApiErr(401, errStringInvalidAPIKey, err) - return true - } - - apikey := keyQuery.Result - - // validate api key - isValid, err := apikeygen.IsValid(decoded, apikey.Key) - if err != nil { - ctx.JsonApiErr(500, "Validating API key failed", err) - return true - } - if !isValid { - ctx.JsonApiErr(401, errStringInvalidAPIKey, err) - return true - } - - // check for expiration - if apikey.Expires != nil && *apikey.Expires <= getTime().Unix() { - ctx.JsonApiErr(401, "Expired API key", err) - return true - } - - ctx.IsSignedIn = true - ctx.SignedInUser = &models.SignedInUser{} - ctx.OrgRole = apikey.Role - ctx.ApiKeyId = apikey.Id - ctx.OrgId = apikey.OrgId - return true -} - -func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool { - if !setting.BasicAuthEnabled { - return false - } - - header := ctx.Req.Header.Get("Authorization") - if header == "" { - return false - } - - username, password, err := util.DecodeBasicAuthHeader(header) - if err != nil { - ctx.JsonApiErr(401, "Invalid Basic Auth Header", err) - return true - } - - authQuery := models.LoginUserQuery{ - Username: username, - Password: password, - } - if err := bus.Dispatch(&authQuery); err != nil { - ctx.Logger.Debug( - "Failed to authorize the user", - "username", username, - "err", err, - ) - - if errors.Is(err, models.ErrUserNotFound) { - err = login.ErrInvalidCredentials - } - ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err) - return true - } - - user := authQuery.User - - query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgId} - if err := bus.Dispatch(&query); err != nil { - ctx.Logger.Error( - "Failed at user signed in", - "id", user.Id, - "org", orgId, - ) - ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err) - return true - } - - ctx.SignedInUser = query.Result - ctx.IsSignedIn = true - return true -} - -func initContextWithToken(authTokenService models.UserTokenService, ctx *models.ReqContext, orgID int64) bool { - if setting.LoginCookieName == "" { - return false - } - - rawToken := ctx.GetCookie(setting.LoginCookieName) - if rawToken == "" { - return false - } - - token, err := authTokenService.LookupToken(ctx.Req.Context(), rawToken) - if err != nil { - ctx.Logger.Error("Failed to look up user based on cookie", "error", err) - WriteSessionCookie(ctx, "", -1) - return false - } - - query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID} - if err := bus.Dispatch(&query); err != nil { - ctx.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err) - return false - } - - ctx.SignedInUser = query.Result - ctx.IsSignedIn = true - ctx.UserToken = token - - // Rotate the token just before we write response headers to ensure there is no delay between - // the new token being generated and the client receiving it. - ctx.Resp.Before(rotateEndOfRequestFunc(ctx, authTokenService, token)) - - return true -} - -func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.UserTokenService, token *models.UserToken) macaron.BeforeFunc { - return func(w macaron.ResponseWriter) { - // if response has already been written, skip. - if w.Written() { - return - } - - // if the request is cancelled by the client we should not try - // to rotate the token since the client would not accept any result. - if errors.Is(ctx.Context.Req.Context().Err(), context.Canceled) { - return - } - - addr := ctx.RemoteAddr() - ip, err := network.GetIPFromAddress(addr) - if err != nil { - ctx.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err) - ip = nil - } - rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ip, ctx.Req.UserAgent()) - if err != nil { - ctx.Logger.Error("Failed to rotate token", "error", err) - return - } - - if rotated { - WriteSessionCookie(ctx, token.UnhashedToken, setting.LoginMaxLifetime) - } - } -} - -func AddDefaultResponseHeaders() macaron.Handler { +func AddDefaultResponseHeaders(cfg *setting.Cfg) macaron.Handler { return func(ctx *macaron.Context) { ctx.Resp.Before(func(w macaron.ResponseWriter) { // if response has already been written, skip. @@ -285,47 +29,46 @@ func AddDefaultResponseHeaders() macaron.Handler { } if !strings.HasPrefix(ctx.Req.URL.Path, "/api/datasources/proxy/") { - AddNoCacheHeaders(ctx.Resp) + addNoCacheHeaders(ctx.Resp) } - if !setting.AllowEmbedding { - AddXFrameOptionsDenyHeader(w) + if !cfg.AllowEmbedding { + addXFrameOptionsDenyHeader(w) } - AddSecurityHeaders(w) + addSecurityHeaders(w, cfg) }) } } -// AddSecurityHeaders adds various HTTP(S) response headers that enable various security protections behaviors in the client's browser. -func AddSecurityHeaders(w macaron.ResponseWriter) { - if (setting.Protocol == setting.HTTPSScheme || setting.Protocol == setting.HTTP2Scheme) && - setting.StrictTransportSecurity { - strictHeaderValues := []string{fmt.Sprintf("max-age=%v", setting.StrictTransportSecurityMaxAge)} - if setting.StrictTransportSecurityPreload { +// addSecurityHeaders adds HTTP(S) response headers that enable various security protections in the client's browser. +func addSecurityHeaders(w macaron.ResponseWriter, cfg *setting.Cfg) { + if (cfg.Protocol == setting.HTTPSScheme || cfg.Protocol == setting.HTTP2Scheme) && cfg.StrictTransportSecurity { + strictHeaderValues := []string{fmt.Sprintf("max-age=%v", cfg.StrictTransportSecurityMaxAge)} + if cfg.StrictTransportSecurityPreload { strictHeaderValues = append(strictHeaderValues, "preload") } - if setting.StrictTransportSecuritySubDomains { + if cfg.StrictTransportSecuritySubDomains { strictHeaderValues = append(strictHeaderValues, "includeSubDomains") } w.Header().Add("Strict-Transport-Security", strings.Join(strictHeaderValues, "; ")) } - if setting.ContentTypeProtectionHeader { + if cfg.ContentTypeProtectionHeader { w.Header().Add("X-Content-Type-Options", "nosniff") } - if setting.XSSProtectionHeader { + if cfg.XSSProtectionHeader { w.Header().Add("X-XSS-Protection", "1; mode=block") } } -func AddNoCacheHeaders(w macaron.ResponseWriter) { +func addNoCacheHeaders(w macaron.ResponseWriter) { w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Pragma", "no-cache") w.Header().Add("Expires", "-1") } -func AddXFrameOptionsDenyHeader(w macaron.ResponseWriter) { +func addXFrameOptionsDenyHeader(w macaron.ResponseWriter) { w.Header().Add("X-Frame-Options", "deny") } diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go index 447d715dd79..f403264e749 100644 --- a/pkg/middleware/middleware_basic_auth_test.go +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -7,6 +7,7 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" "github.com/stretchr/testify/assert" @@ -14,19 +15,13 @@ import ( ) func TestMiddlewareBasicAuth(t *testing.T) { - var origBasicAuthEnabled = setting.BasicAuthEnabled - var origDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection - t.Cleanup(func() { - setting.BasicAuthEnabled = origBasicAuthEnabled - setting.DisableBruteForceLoginProtection = origDisableBruteForceLoginProtection - }) - setting.BasicAuthEnabled = true - setting.DisableBruteForceLoginProtection = true - - bus.ClearBusHandlers() - const id int64 = 12 + configure := func(cfg *setting.Cfg) { + cfg.BasicAuthEnabled = true + cfg.DisableBruteForceLoginProtection = true + } + middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) { const orgID int64 = 2 keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") @@ -44,16 +39,15 @@ func TestMiddlewareBasicAuth(t *testing.T) { assert.True(t, sc.context.IsSignedIn) assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) - }) + }, configure) middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) { const password = "MyPass" const salt = "Salt" const orgID int64 = 2 - t.Cleanup(bus.ClearBusHandlers) - bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error { + t.Log("Handling LoginUserQuery") encoded, err := util.EncodePassword(password, salt) if err != nil { return err @@ -66,6 +60,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { }) bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { + t.Log("Handling GetSignedInUserQuery") query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} return nil }) @@ -76,7 +71,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { assert.True(t, sc.context.IsSignedIn) assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, id, sc.context.UserId) - }) + }, configure) middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) { const password = "MyPass" @@ -104,10 +99,11 @@ func TestMiddlewareBasicAuth(t *testing.T) { authHeader := util.GetBasicAuthHeader("myUser", password) sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() + require.NotNil(t, sc.context) assert.True(t, sc.context.IsSignedIn) assert.Equal(t, id, sc.context.UserId) - }) + }, configure) middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) { sc.fakeReq("GET", "/") @@ -118,8 +114,8 @@ func TestMiddlewareBasicAuth(t *testing.T) { require.Error(t, err) assert.Equal(t, 401, sc.resp.Code) - assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) - }) + assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"]) + }, configure) middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) { bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error { @@ -134,6 +130,6 @@ func TestMiddlewareBasicAuth(t *testing.T) { require.Error(t, err) assert.Equal(t, 401, sc.resp.Code) - assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) - }) + assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"]) + }, configure) } diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index b4dfde543b9..ff930813cbc 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/http/httptest" "path/filepath" "testing" "time" @@ -18,31 +17,30 @@ import ( "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/components/gtime" - "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" - "github.com/grafana/grafana/pkg/middleware/authproxy" + "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/services/auth" - "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" + "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" ) const errorTemplate = "error-template" -func mockGetTime() { +func fakeGetTime() func() time.Time { var timeSeed int64 - getTime = func() time.Time { + return func() time.Time { fakeNow := time.Unix(timeSeed, 0) timeSeed++ return fakeNow } } -func resetGetTime() { - getTime = time.Now -} - func TestMiddleWareSecurityHeaders(t *testing.T) { origErrTemplateName := setting.ErrTemplateName t.Cleanup(func() { @@ -51,46 +49,32 @@ func TestMiddleWareSecurityHeaders(t *testing.T) { setting.ErrTemplateName = errorTemplate middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) { - origXSSProtectionHeader := setting.XSSProtectionHeader - t.Cleanup(func() { - setting.XSSProtectionHeader = origXSSProtectionHeader - }) - setting.XSSProtectionHeader = true sc.fakeReq("GET", "/api/").exec() assert.Equal(t, "1; mode=block", sc.resp.Header().Get("X-XSS-Protection")) + }, func(cfg *setting.Cfg) { + cfg.XSSProtectionHeader = true }) middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(t *testing.T, sc *scenarioContext) { - origXSSProtectionHeader := setting.XSSProtectionHeader - t.Cleanup(func() { - setting.XSSProtectionHeader = origXSSProtectionHeader - }) - setting.XSSProtectionHeader = false sc.fakeReq("GET", "/api/").exec() assert.Empty(t, sc.resp.Header().Get("X-XSS-Protection")) + }, func(cfg *setting.Cfg) { + cfg.XSSProtectionHeader = false }) middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(t *testing.T, sc *scenarioContext) { - origStrictTransportSecurity := setting.StrictTransportSecurity - origProtocol := setting.Protocol - origStrictTransportSecurityMaxAge := setting.StrictTransportSecurityMaxAge - t.Cleanup(func() { - setting.StrictTransportSecurity = origStrictTransportSecurity - setting.Protocol = origProtocol - setting.StrictTransportSecurityMaxAge = origStrictTransportSecurityMaxAge - }) - setting.StrictTransportSecurity = true - setting.Protocol = setting.HTTPSScheme - setting.StrictTransportSecurityMaxAge = 64000 - sc.fakeReq("GET", "/api/").exec() assert.Equal(t, "max-age=64000", sc.resp.Header().Get("Strict-Transport-Security")) - setting.StrictTransportSecurityPreload = true + sc.cfg.StrictTransportSecurityPreload = true sc.fakeReq("GET", "/api/").exec() assert.Equal(t, "max-age=64000; preload", sc.resp.Header().Get("Strict-Transport-Security")) - setting.StrictTransportSecuritySubDomains = true + sc.cfg.StrictTransportSecuritySubDomains = true sc.fakeReq("GET", "/api/").exec() assert.Equal(t, "max-age=64000; preload; includeSubDomains", sc.resp.Header().Get("Strict-Transport-Security")) + }, func(cfg *setting.Cfg) { + cfg.Protocol = setting.HTTPSScheme + cfg.StrictTransportSecurity = true + cfg.StrictTransportSecurityMaxAge = 64000 }) } @@ -151,13 +135,10 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func( t *testing.T, sc *scenarioContext) { - origAllowEmbedding := setting.AllowEmbedding - t.Cleanup(func() { - setting.AllowEmbedding = origAllowEmbedding - }) - setting.AllowEmbedding = true sc.fakeReq("GET", "/api/search").exec() assert.Empty(t, sc.resp.Header().Get("X-Frame-Options")) + }, func(cfg *setting.Cfg) { + cfg.AllowEmbedding = true }) middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) { @@ -166,7 +147,7 @@ func TestMiddlewareContext(t *testing.T) { assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) assert.Equal(t, 401, sc.resp.Code) - assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) + assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"]) }) middlewareScenario(t, "Valid api key", func(t *testing.T, sc *scenarioContext) { @@ -199,19 +180,18 @@ func TestMiddlewareContext(t *testing.T) { sc.fakeReq("GET", "/").withValidApiKey().exec() assert.Equal(t, 401, sc.resp.Code) - assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) + assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"]) }) - middlewareScenario(t, "Valid api key, but expired", func(t *testing.T, sc *scenarioContext) { - mockGetTime() - defer resetGetTime() + middlewareScenario(t, "Valid API key, but expired", func(t *testing.T, sc *scenarioContext) { + sc.contextHandler.GetTime = fakeGetTime() keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") require.NoError(t, err) bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { // api key expired one second before - expires := getTime().Add(-1 * time.Second).Unix() + expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix() query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, Expires: &expires} return nil @@ -223,7 +203,7 @@ func TestMiddlewareContext(t *testing.T) { assert.Equal(t, "Expired API key", sc.respJson["message"]) }) - middlewareScenario(t, "Non-expired auth token in cookie which not are being rotated", func( + middlewareScenario(t, "Non-expired auth token in cookie which is not being rotated", func( t *testing.T, sc *scenarioContext) { const userID int64 = 12 @@ -357,18 +337,6 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) { const orgID int64 = 2 - origAnonymousEnabled := setting.AnonymousEnabled - origAnonymousOrgName := setting.AnonymousOrgName - origAnonymousOrgRole := setting.AnonymousOrgRole - t.Cleanup(func() { - setting.AnonymousEnabled = origAnonymousEnabled - setting.AnonymousOrgName = origAnonymousOrgName - setting.AnonymousOrgRole = origAnonymousOrgRole - }) - setting.AnonymousEnabled = true - setting.AnonymousOrgName = "test" - setting.AnonymousOrgRole = string(models.ROLE_EDITOR) - bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { assert.Equal(t, "test", query.Name) @@ -382,35 +350,24 @@ func TestMiddlewareContext(t *testing.T) { assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) assert.False(t, sc.context.IsSignedIn) + }, func(cfg *setting.Cfg) { + cfg.AnonymousEnabled = true + cfg.AnonymousOrgName = "test" + cfg.AnonymousOrgRole = string(models.ROLE_EDITOR) }) t.Run("auth_proxy", func(t *testing.T) { const userID int64 = 33 const orgID int64 = 4 - origAuthProxyEnabled := setting.AuthProxyEnabled - origAuthProxyWhitelist := setting.AuthProxyWhitelist - origAuthProxyAutoSignUp := setting.AuthProxyAutoSignUp - origLDAPEnabled := setting.LDAPEnabled - origAuthProxyHeaderName := setting.AuthProxyHeaderName - origAuthProxyHeaderProperty := setting.AuthProxyHeaderProperty - origAuthProxyHeaders := setting.AuthProxyHeaders - t.Cleanup(func() { - setting.AuthProxyEnabled = origAuthProxyEnabled - setting.AuthProxyWhitelist = origAuthProxyWhitelist - setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp - setting.LDAPEnabled = origLDAPEnabled - setting.AuthProxyHeaderName = origAuthProxyHeaderName - setting.AuthProxyHeaderProperty = origAuthProxyHeaderProperty - setting.AuthProxyHeaders = origAuthProxyHeaders - }) - setting.AuthProxyEnabled = true - setting.AuthProxyWhitelist = "" - setting.AuthProxyAutoSignUp = true - setting.LDAPEnabled = true - setting.AuthProxyHeaderName = "X-WEBAUTH-USER" - setting.AuthProxyHeaderProperty = "username" - setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} + configure := func(cfg *setting.Cfg) { + cfg.AuthProxyEnabled = true + cfg.AuthProxyAutoSignUp = true + cfg.LDAPEnabled = true + cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" + cfg.AuthProxyHeaderProperty = "username" + cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} + } const hdrName = "markelog" const group = "grafana-core-team" @@ -426,25 +383,16 @@ func TestMiddlewareContext(t *testing.T) { require.NoError(t, err) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.Header.Set("X-WEBAUTH-GROUPS", group) sc.exec() assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, orgID, sc.context.OrgId) - }) + }, configure) middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) { - origLDAPEnabled = setting.LDAPEnabled - origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp - t.Cleanup(func() { - setting.LDAPEnabled = origLDAPEnabled - setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp - }) - setting.LDAPEnabled = false - setting.AuthProxyAutoSignUp = false - var actualAuthProxyAutoSignUp *bool = nil bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { @@ -453,24 +401,19 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.False(t, *actualAuthProxyAutoSignUp) - assert.Equal(t, sc.resp.Code, 407) + assert.Equal(t, 407, sc.resp.Code) assert.Nil(t, sc.context) + }, func(cfg *setting.Cfg) { + configure(cfg) + cfg.LDAPEnabled = false + cfg.AuthProxyAutoSignUp = false }) middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) { - origLDAPEnabled = setting.LDAPEnabled - origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp - t.Cleanup(func() { - setting.LDAPEnabled = origLDAPEnabled - setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp - }) - setting.LDAPEnabled = false - setting.AuthProxyAutoSignUp = true - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { if query.UserId > 0 { query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} @@ -485,24 +428,22 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, orgID, sc.context.OrgId) + }, func(cfg *setting.Cfg) { + configure(cfg) + cfg.LDAPEnabled = false + cfg.AuthProxyAutoSignUp = true }) middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) { const userID int64 = 12 const orgID int64 = 2 - origLDAPEnabled = setting.LDAPEnabled - t.Cleanup(func() { - setting.LDAPEnabled = origLDAPEnabled - }) - setting.LDAPEnabled = false - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} return nil @@ -514,24 +455,18 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, orgID, sc.context.OrgId) + }, func(cfg *setting.Cfg) { + configure(cfg) + cfg.LDAPEnabled = false }) middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { - origAuthProxyWhitelist = setting.AuthProxyWhitelist - origLDAPEnabled = setting.LDAPEnabled - t.Cleanup(func() { - setting.AuthProxyWhitelist = origAuthProxyWhitelist - setting.LDAPEnabled = origLDAPEnabled - }) - setting.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120" - setting.LDAPEnabled = false - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} return nil @@ -543,25 +478,20 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.RemoteAddr = "[2001::23]:12345" sc.exec() assert.True(t, sc.context.IsSignedIn) assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, orgID, sc.context.OrgId) + }, func(cfg *setting.Cfg) { + configure(cfg) + cfg.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120" + cfg.LDAPEnabled = false }) - middlewareScenario(t, "Should not allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { - origAuthProxyWhitelist = setting.AuthProxyWhitelist - origLDAPEnabled = setting.LDAPEnabled - t.Cleanup(func() { - setting.AuthProxyWhitelist = origAuthProxyWhitelist - setting.LDAPEnabled = origLDAPEnabled - }) - setting.AuthProxyWhitelist = "8.8.8.8" - setting.LDAPEnabled = false - + middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) { bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} return nil @@ -573,12 +503,16 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.req.RemoteAddr = "[2001::23]:12345" sc.exec() assert.Equal(t, 407, sc.resp.Code) assert.Nil(t, sc.context) + }, func(cfg *setting.Cfg) { + configure(cfg) + cfg.AuthProxyWhitelist = "8.8.8.8" + cfg.LDAPEnabled = false }) middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) { @@ -587,12 +521,12 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.Equal(t, 407, sc.resp.Code) assert.Nil(t, sc.context) - }) + }, configure) middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) { bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error { @@ -600,52 +534,53 @@ func TestMiddlewareContext(t *testing.T) { }) sc.fakeReq("GET", "/") - sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName) sc.exec() assert.Equal(t, 407, sc.resp.Code) assert.Nil(t, sc.context) - }) + }, configure) }) } -func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { +func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) { t.Helper() t.Run(desc, func(t *testing.T) { t.Cleanup(bus.ClearBusHandlers) - origLoginCookieName := setting.LoginCookieName - origLoginMaxLifetime := setting.LoginMaxLifetime - t.Cleanup(func() { - setting.LoginCookieName = origLoginCookieName - setting.LoginMaxLifetime = origLoginMaxLifetime - }) - setting.LoginCookieName = "grafana_session" - var err error - setting.LoginMaxLifetime, err = gtime.ParseDuration("30d") + loginMaxLifetime, err := gtime.ParseDuration("30d") require.NoError(t, err) + cfg := setting.NewCfg() + cfg.LoginCookieName = "grafana_session" + cfg.LoginMaxLifetime = loginMaxLifetime + for _, cb := range cbs { + cb(cfg) + } - sc := &scenarioContext{t: t} + sc := &scenarioContext{t: t, cfg: cfg} viewsPath, err := filepath.Abs("../../public/views") require.NoError(t, err) sc.m = macaron.New() - sc.m.Use(AddDefaultResponseHeaders()) + sc.m.Use(AddDefaultResponseHeaders(cfg)) sc.m.Use(macaron.Renderer(macaron.RenderOptions{ Directory: viewsPath, Delims: macaron.Delims{Left: "[[", Right: "]]"}, })) - sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() - sc.remoteCacheService = remotecache.NewFakeStore(t) - - sc.m.Use(GetContextHandler(sc.userAuthTokenService, sc.remoteCacheService, nil)) - + ctxHdlr := getContextHandler(t, cfg) + sc.contextHandler = ctxHdlr + sc.m.Use(ctxHdlr.Middleware) sc.m.Use(OrgRedirect()) + sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*auth.FakeUserAuthTokenService) + sc.remoteCacheService = ctxHdlr.RemoteCache + sc.defaultHandler = func(c *models.ReqContext) { + require.NotNil(t, c) + t.Log("Default HTTP handler called") sc.context = c if sc.handlerFunc != nil { sc.handlerFunc(sc.context) @@ -662,106 +597,52 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { }) } -func TestDontRotateTokensOnCancelledRequests(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - reqContext, _, err := initTokenRotationTest(ctx, t) - require.NoError(t, err) - - tryRotateCallCount := 0 - uts := &auth.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, - userAgent string) (bool, error) { - tryRotateCallCount++ - return false, nil - }, - } - - token := &models.UserToken{AuthToken: "oldtoken"} - - fn := rotateEndOfRequestFunc(reqContext, uts, token) - cancel() - fn(reqContext.Resp) - - assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted") -} - -func TestTokenRotationAtEndOfRequest(t *testing.T) { - reqContext, rr, err := initTokenRotationTest(context.Background(), t) - require.NoError(t, err) - - uts := &auth.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, - userAgent string) (bool, error) { - newToken, err := util.RandomHex(16) - require.NoError(t, err) - token.AuthToken = newToken - return true, nil - }, - } - - token := &models.UserToken{AuthToken: "oldtoken"} - - rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp) - - foundLoginCookie := false - resp := rr.Result() - defer resp.Body.Close() - for _, c := range resp.Cookies() { - if c.Name == "login_token" { - foundLoginCookie = true - - require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same") - } - } - - assert.True(t, foundLoginCookie, "Could not find cookie") -} - -func initTokenRotationTest(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) { +func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler { t.Helper() - origLoginCookieName := setting.LoginCookieName - origLoginMaxLifetime := setting.LoginMaxLifetime - t.Cleanup(func() { - setting.LoginCookieName = origLoginCookieName - setting.LoginMaxLifetime = origLoginMaxLifetime - }) - setting.LoginCookieName = "login_token" - var err error - setting.LoginMaxLifetime, err = gtime.ParseDuration("7d") - if err != nil { - return nil, nil, err + sqlStore := sqlstore.InitTestDB(t) + remoteCacheSvc := &remotecache.RemoteCache{} + if cfg == nil { + cfg = setting.NewCfg() } + cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{ + Name: "database", + } + userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + renderSvc := &fakeRenderService{} + ctxHdlr := &contexthandler.ContextHandler{} - rr := httptest.NewRecorder() - req, err := http.NewRequestWithContext(ctx, "", "", nil) - if err != nil { - return nil, nil, err - } - reqContext := &models.ReqContext{ - Context: &macaron.Context{ - Req: macaron.Request{ - Request: req, - }, + err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{ + { + Name: sqlstore.ServiceName, + Instance: sqlStore, }, - Logger: log.New("testlogger"), - } + { + Name: remotecache.ServiceName, + Instance: remoteCacheSvc, + }, + { + Name: auth.ServiceName, + Instance: userAuthTokenSvc, + }, + { + Name: rendering.ServiceName, + Instance: renderSvc, + }, + { + Name: contexthandler.ServiceName, + Instance: ctxHdlr, + }, + }) + require.NoError(t, err) - mw := mockWriter{rr} - reqContext.Resp = mw - - return reqContext, rr, nil + return ctxHdlr } -type mockWriter struct { - *httptest.ResponseRecorder +type fakeRenderService struct { + rendering.Service } -func (mw mockWriter) Flush() {} -func (mw mockWriter) Status() int { return 0 } -func (mw mockWriter) Size() int { return 0 } -func (mw mockWriter) Written() bool { return false } -func (mw mockWriter) Before(macaron.BeforeFunc) {} -func (mw mockWriter) Push(target string, opts *http.PushOptions) error { +func (s *fakeRenderService) Init() error { return nil } diff --git a/pkg/middleware/rate_limit_test.go b/pkg/middleware/rate_limit_test.go index 53eb02d2bfd..c78bbe9136a 100644 --- a/pkg/middleware/rate_limit_test.go +++ b/pkg/middleware/rate_limit_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +19,8 @@ type advanceTimeFunc func(deltaTime time.Duration) type rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc) func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) { + t.Helper() + t.Run(desc, func(t *testing.T) { defaultHandler := func(c *models.ReqContext) { resp := make(map[string]interface{}) @@ -26,12 +29,14 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL } currentTime := time.Now() + cfg := setting.NewCfg() + m := macaron.New() m.Use(macaron.Renderer(macaron.RenderOptions{ Directory: "", Delims: macaron.Delims{Left: "[[", Right: "]]"}, })) - m.Use(GetContextHandler(nil, nil, nil)) + m.Use(getContextHandler(t, cfg).Middleware) m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler) fn(func() *httptest.ResponseRecorder { diff --git a/pkg/middleware/recovery.go b/pkg/middleware/recovery.go index 8048cc6f53b..5f564135826 100644 --- a/pkg/middleware/recovery.go +++ b/pkg/middleware/recovery.go @@ -103,7 +103,7 @@ func function(pc uintptr) []byte { // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. // While Martini is in development mode, Recovery will also output the panic as HTML. -func Recovery() macaron.Handler { +func Recovery(cfg *setting.Cfg) macaron.Handler { return func(c *macaron.Context) { defer func() { if r := recover(); r != nil { @@ -134,7 +134,7 @@ func Recovery() macaron.Handler { c.Data["Title"] = "Server Error" c.Data["AppSubUrl"] = setting.AppSubUrl - c.Data["Theme"] = setting.DefaultTheme + c.Data["Theme"] = cfg.DefaultTheme if setting.Env == setting.Dev { if err, ok := r.(error); ok { @@ -158,7 +158,7 @@ func Recovery() macaron.Handler { c.JSON(500, resp) } else { - c.HTML(500, setting.ErrTemplateName) + c.HTML(500, cfg.ErrTemplateName) } } }() diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 5507ba5aac5..9842b4ad659 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -16,8 +16,6 @@ import ( ) func TestRecoveryMiddleware(t *testing.T) { - setting.ErrTemplateName = "error-template" - t.Run("Given an API route that panics", func(t *testing.T) { apiURL := "/api/whatever" recoveryScenario(t, "recovery middleware should return json", apiURL, func(t *testing.T, sc *scenarioContext) { @@ -52,18 +50,21 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { t.Run(desc, func(t *testing.T) { defer bus.ClearBusHandlers() + cfg := setting.NewCfg() + cfg.ErrTemplateName = "error-template" sc := &scenarioContext{ t: t, url: url, + cfg: cfg, } viewsPath, err := filepath.Abs("../../public/views") require.NoError(t, err) sc.m = macaron.New() - sc.m.Use(Recovery()) + sc.m.Use(Recovery(cfg)) - sc.m.Use(AddDefaultResponseHeaders()) + sc.m.Use(AddDefaultResponseHeaders(cfg)) sc.m.Use(macaron.Renderer(macaron.RenderOptions{ Directory: viewsPath, Delims: macaron.Delims{Left: "[[", Right: "]]"}, @@ -72,7 +73,8 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() sc.remoteCacheService = remotecache.NewFakeStore(t) - sc.m.Use(GetContextHandler(sc.userAuthTokenService, sc.remoteCacheService, nil)) + contextHandler := getContextHandler(t, nil) + sc.m.Use(contextHandler.Middleware) // mock out gc goroutine sc.m.Use(OrgRedirect()) diff --git a/pkg/middleware/render_auth.go b/pkg/middleware/render_auth.go deleted file mode 100644 index ff5e33c842e..00000000000 --- a/pkg/middleware/render_auth.go +++ /dev/null @@ -1,31 +0,0 @@ -package middleware - -import ( - "time" - - "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/rendering" -) - -func initContextWithRenderAuth(ctx *models.ReqContext, renderService rendering.Service) bool { - key := ctx.GetCookie("renderKey") - if key == "" { - return false - } - - renderUser, exists := renderService.GetRenderUser(key) - if !exists { - ctx.JsonApiErr(401, "Invalid Render Key", nil) - return true - } - - ctx.IsSignedIn = true - ctx.SignedInUser = &models.SignedInUser{ - OrgId: renderUser.OrgID, - UserId: renderUser.UserID, - OrgRole: models.RoleType(renderUser.OrgRole), - } - ctx.IsRenderCall = true - ctx.LastSeenAt = time.Now() - return true -} diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go index 8229fb36116..72cccac2843 100644 --- a/pkg/middleware/testing.go +++ b/pkg/middleware/testing.go @@ -11,6 +11,7 @@ import ( "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" ) @@ -29,6 +30,8 @@ type scenarioContext struct { url string userAuthTokenService *auth.FakeUserAuthTokenService remoteCacheService *remotecache.RemoteCache + cfg *setting.Cfg + contextHandler *contexthandler.ContextHandler req *http.Request } @@ -94,9 +97,9 @@ func (sc *scenarioContext) exec() { } if sc.tokenSessionCookie != "" { - sc.t.Log(`Adding cookie`, "name", setting.LoginCookieName, "value", sc.tokenSessionCookie) + sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie) sc.req.AddCookie(&http.Cookie{ - Name: setting.LoginCookieName, + Name: sc.cfg.LoginCookieName, Value: sc.tokenSessionCookie, }) } diff --git a/pkg/models/user_auth.go b/pkg/models/user_auth.go index 3810de46b08..652e0517075 100644 --- a/pkg/models/user_auth.go +++ b/pkg/models/user_auth.go @@ -3,6 +3,7 @@ package models import ( "time" + "github.com/grafana/grafana/pkg/setting" "golang.org/x/oauth2" ) @@ -84,6 +85,7 @@ type LoginUserQuery struct { User *User IpAddress string AuthModule string + Cfg *setting.Cfg } type GetUserByAuthInfoQuery struct { diff --git a/pkg/registry/di.go b/pkg/registry/di.go new file mode 100644 index 00000000000..4d58fe2a6ec --- /dev/null +++ b/pkg/registry/di.go @@ -0,0 +1,45 @@ +package registry + +import ( + "fmt" + + "github.com/facebookgo/inject" +) + +// BuildServiceGraph builds a graph of services and their dependencies. +// The services are initialized after the graph is built. +func BuildServiceGraph(objs []interface{}, services []*Descriptor) error { + if services == nil { + services = GetServices() + } + for _, service := range services { + objs = append(objs, service.Instance) + } + + serviceGraph := inject.Graph{} + + // Provide services and their dependencies to the graph. + for _, obj := range objs { + if err := serviceGraph.Provide(&inject.Object{Value: obj}); err != nil { + return fmt.Errorf("failed to provide object to the graph: %w", err) + } + } + + // Resolve services and their dependencies. + if err := serviceGraph.Populate(); err != nil { + return fmt.Errorf("failed to populate service dependencies: %w", err) + } + + // Initialize services. + for _, service := range services { + if IsDisabled(service.Instance) { + continue + } + + if err := service.Instance.Init(); err != nil { + return fmt.Errorf("service init failed: %w", err) + } + } + + return nil +} diff --git a/pkg/services/auth/auth_token.go b/pkg/services/auth/auth_token.go index 490cacf5888..a4affaded2a 100644 --- a/pkg/services/auth/auth_token.go +++ b/pkg/services/auth/auth_token.go @@ -18,8 +18,14 @@ import ( "github.com/grafana/grafana/pkg/util" ) +const ServiceName = "UserAuthTokenService" + func init() { - registry.RegisterService(&UserAuthTokenService{}) + registry.Register(®istry.Descriptor{ + Name: ServiceName, + Instance: &UserAuthTokenService{}, + InitPriority: registry.Medium, + }) } var getTime = time.Now diff --git a/pkg/services/auth/testing.go b/pkg/services/auth/testing.go index d4395ab647a..2b1718ed21b 100644 --- a/pkg/services/auth/testing.go +++ b/pkg/services/auth/testing.go @@ -57,8 +57,13 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { } } -func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, - userAgent string) (*models.UserToken, error) { +// Init initializes the service. +// Required for dependency injection. +func (s *FakeUserAuthTokenService) Init() error { + return nil +} + +func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) { return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent) } diff --git a/pkg/middleware/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go similarity index 54% rename from pkg/middleware/auth_proxy_test.go rename to pkg/services/contexthandler/auth_proxy_test.go index ee022d7f74c..790f7959a22 100644 --- a/pkg/middleware/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -1,4 +1,4 @@ -package middleware +package contexthandler import ( "fmt" @@ -8,8 +8,12 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" - "github.com/grafana/grafana/pkg/middleware/authproxy" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/registry" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" + "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" macaron "gopkg.in/macaron.v1" @@ -41,25 +45,16 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { } return nil } - - origHeaderName := setting.AuthProxyHeaderName - origEnabled := setting.AuthProxyEnabled - origHeaderProperty := setting.AuthProxyHeaderProperty bus.AddHandler("", upsertHandler) bus.AddHandler("", getUserHandler) t.Cleanup(func() { - setting.AuthProxyHeaderName = origHeaderName - setting.AuthProxyEnabled = origEnabled - setting.AuthProxyHeaderProperty = origHeaderProperty bus.ClearBusHandlers() }) - setting.AuthProxyHeaderName = "X-Killa" - setting.AuthProxyEnabled = true - setting.AuthProxyHeaderProperty = "username" + svc := getContextHandler(t) + req, err := http.NewRequest("POST", "http://example.com", nil) require.NoError(t, err) - store := remotecache.NewFakeStore(t) ctx := &models.ReqContext{ Context: &macaron.Context{ Req: macaron.Request{ @@ -69,20 +64,72 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { }, Logger: log.New("Test"), } - req.Header.Add(setting.AuthProxyHeaderName, name) + req.Header.Set(svc.Cfg.AuthProxyHeaderName, name) key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name)) t.Logf("Injecting stale user ID in cache with key %q", key) - err = store.Set(key, int64(33), 0) + err = svc.RemoteCache.Set(key, int64(33), 0) require.NoError(t, err) - authEnabled := initContextWithAuthProxy(store, ctx, orgID) + authEnabled := svc.initContextWithAuthProxy(ctx, orgID) require.True(t, authEnabled) require.Equal(t, userID, ctx.SignedInUser.UserId) require.True(t, ctx.IsSignedIn) - i, err := store.Get(key) + i, err := svc.RemoteCache.Get(key) require.NoError(t, err) require.Equal(t, userID, i.(int64)) } + +type fakeRenderService struct { + rendering.Service +} + +func (s *fakeRenderService) Init() error { + return nil +} + +func getContextHandler(t *testing.T) *ContextHandler { + t.Helper() + + sqlStore := sqlstore.InitTestDB(t) + remoteCacheSvc := &remotecache.RemoteCache{} + + cfg := setting.NewCfg() + cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{ + Name: "database", + } + cfg.AuthProxyHeaderName = "X-Killa" + cfg.AuthProxyEnabled = true + cfg.AuthProxyHeaderProperty = "username" + userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + renderSvc := &fakeRenderService{} + svc := &ContextHandler{} + + err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{ + { + Name: sqlstore.ServiceName, + Instance: sqlStore, + }, + { + Name: remotecache.ServiceName, + Instance: remoteCacheSvc, + }, + { + Name: auth.ServiceName, + Instance: userAuthTokenSvc, + }, + { + Name: rendering.ServiceName, + Instance: renderSvc, + }, + { + Name: ServiceName, + Instance: svc, + }, + }) + require.NoError(t, err) + + return svc +} diff --git a/pkg/middleware/authproxy/auth_proxy.go b/pkg/services/contexthandler/authproxy/authproxy.go similarity index 80% rename from pkg/middleware/authproxy/auth_proxy.go rename to pkg/services/contexthandler/authproxy/authproxy.go index 2b80e0f8fd4..df3d52f3564 100644 --- a/pkg/middleware/authproxy/auth_proxy.go +++ b/pkg/services/contexthandler/authproxy/authproxy.go @@ -32,7 +32,13 @@ const ( var getLDAPConfig = ldap.GetConfig // isLDAPEnabled checks if LDAP is enabled -var isLDAPEnabled = ldap.IsEnabled +var isLDAPEnabled = func(cfg *setting.Cfg) bool { + if cfg != nil { + return cfg.LDAPEnabled + } + + return setting.LDAPEnabled +} // newLDAP creates multiple LDAP instance var newLDAP = multildap.New @@ -42,18 +48,11 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups"} // AuthProxy struct type AuthProxy struct { - store *remotecache.RemoteCache - ctx *models.ReqContext - orgID int64 - header string - - enabled bool - LDAPAllowSignup bool - AuthProxyAutoSignUp bool - whitelistIP string - headerType string - headers map[string]string - cacheTTL int + cfg *setting.Cfg + remoteCache *remotecache.RemoteCache + ctx *models.ReqContext + orgID int64 + header string } // Error auth proxy specific error @@ -77,35 +76,27 @@ func (err Error) Error() string { // Options for the AuthProxy type Options struct { - Store *remotecache.RemoteCache - Ctx *models.ReqContext - OrgID int64 + RemoteCache *remotecache.RemoteCache + Ctx *models.ReqContext + OrgID int64 } // New instance of the AuthProxy -func New(options *Options) *AuthProxy { - header := options.Ctx.Req.Header.Get(setting.AuthProxyHeaderName) - +func New(cfg *setting.Cfg, options *Options) *AuthProxy { + header := options.Ctx.Req.Header.Get(cfg.AuthProxyHeaderName) return &AuthProxy{ - store: options.Store, - ctx: options.Ctx, - orgID: options.OrgID, - header: header, - - enabled: setting.AuthProxyEnabled, - headerType: setting.AuthProxyHeaderProperty, - headers: setting.AuthProxyHeaders, - whitelistIP: setting.AuthProxyWhitelist, - cacheTTL: setting.AuthProxySyncTtl, - LDAPAllowSignup: setting.LDAPAllowSignup, - AuthProxyAutoSignUp: setting.AuthProxyAutoSignUp, + remoteCache: options.RemoteCache, + cfg: cfg, + ctx: options.Ctx, + orgID: options.OrgID, + header: header, } } // IsEnabled checks if the proxy auth is enabled func (auth *AuthProxy) IsEnabled() bool { // Bail if the setting is not enabled - return auth.enabled + return auth.cfg.AuthProxyEnabled } // HasHeader checks if the we have specified header @@ -113,15 +104,15 @@ func (auth *AuthProxy) HasHeader() bool { return len(auth.header) != 0 } -// IsAllowedIP compares presented IP with the whitelist one +// IsAllowedIP returns whether provided IP is allowed. func (auth *AuthProxy) IsAllowedIP() error { ip := auth.ctx.Req.RemoteAddr - if len(strings.TrimSpace(auth.whitelistIP)) == 0 { + if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 { return nil } - proxies := strings.Split(auth.whitelistIP, ",") + proxies := strings.Split(auth.cfg.AuthProxyWhitelist, ",") var proxyObjs []*net.IPNet for _, proxy := range proxies { result, err := coerceProxyAddress(proxy) @@ -181,7 +172,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) } } - if isLDAPEnabled() { + if isLDAPEnabled(auth.cfg) { id, err := auth.LoginViaLDAP() if err != nil { if errors.Is(err, ldap.ErrInvalidCredentials) { @@ -205,7 +196,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) { cacheKey := auth.getKey() logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) - userID, err := auth.store.Get(cacheKey) + userID, err := auth.remoteCache.Get(cacheKey) if err != nil { logger.Debug("Failed getting user ID via auth cache", "error", err) return 0, err @@ -219,7 +210,7 @@ func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) { func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { cacheKey := auth.getKey() logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) - if err := auth.store.Delete(cacheKey); err != nil { + if err := auth.remoteCache.Delete(cacheKey); err != nil { return err } @@ -229,12 +220,13 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { // LoginViaLDAP logs in user via LDAP request func (auth *AuthProxy) LoginViaLDAP() (int64, error) { - config, err := getLDAPConfig() + config, err := getLDAPConfig(auth.cfg) if err != nil { return 0, newError("failed to get LDAP config", err) } - extUser, _, err := newLDAP(config.Servers).User(auth.header) + mldap := newLDAP(config.Servers) + extUser, _, err := mldap.User(auth.header) if err != nil { return 0, err } @@ -242,7 +234,7 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) { // Have to sync grafana and LDAP user during log in upsert := &models.UpsertUserCommand{ ReqContext: auth.ctx, - SignupAllowed: auth.LDAPAllowSignup, + SignupAllowed: auth.cfg.LDAPAllowSignup, ExternalUser: extUser, } if err := bus.Dispatch(upsert); err != nil { @@ -259,7 +251,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { AuthId: auth.header, } - switch auth.headerType { + switch auth.cfg.AuthProxyHeaderProperty { case "username": extUser.Login = auth.header @@ -284,7 +276,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { upsert := &models.UpsertUserCommand{ ReqContext: auth.ctx, - SignupAllowed: setting.AuthProxyAutoSignUp, + SignupAllowed: auth.cfg.AuthProxyAutoSignUp, ExternalUser: extUser, } @@ -299,8 +291,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) { // headersIterator iterates over all non-empty supported additional headers func (auth *AuthProxy) headersIterator(fn func(field string, header string)) { for _, field := range supportedHeaderFields { - h := auth.headers[field] - + h := auth.cfg.AuthProxyHeaders[field] if h == "" { continue } @@ -311,8 +302,8 @@ func (auth *AuthProxy) headersIterator(fn func(field string, header string)) { } } -// GetSignedUser gets full signed user info. -func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, error) { +// GetSignedUser gets full signed in user info. +func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) { query := &models.GetSignedInUserQuery{ OrgId: auth.orgID, UserId: userID, @@ -330,14 +321,14 @@ func (auth *AuthProxy) Remember(id int64) error { key := auth.getKey() // Check if user already in cache - userID, _ := auth.store.Get(key) + userID, _ := auth.remoteCache.Get(key) if userID != nil { return nil } - expiration := time.Duration(auth.cacheTTL) * time.Minute + expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute - err := auth.store.Set(key, id, expiration) + err := auth.remoteCache.Set(key, id, expiration) if err != nil { return err } @@ -353,5 +344,8 @@ func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) { } _, network, err := net.ParseCIDR(proxyAddr) - return network, err + if err != nil { + return nil, fmt.Errorf("could not parse the network: %w", err) + } + return network, nil } diff --git a/pkg/middleware/authproxy/auth_proxy_test.go b/pkg/services/contexthandler/authproxy/authproxy_test.go similarity index 62% rename from pkg/middleware/authproxy/auth_proxy_test.go rename to pkg/services/contexthandler/authproxy/authproxy_test.go index 5a3604946e0..7c5130b0b90 100644 --- a/pkg/middleware/authproxy/auth_proxy_test.go +++ b/pkg/services/contexthandler/authproxy/authproxy_test.go @@ -47,9 +47,22 @@ func (m *fakeMultiLDAP) User(login string) ( return result, ldap.ServerConfig{}, nil } -func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.RemoteCache) *AuthProxy { +const hdrName = "markelog" + +func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, cb func(*http.Request, *setting.Cfg)) *AuthProxy { t.Helper() + cfg := setting.NewCfg() + cfg.AuthProxyHeaderName = "X-Killa" + + req, err := http.NewRequest("POST", "http://example.com", nil) + require.NoError(t, err) + req.Header.Set(cfg.AuthProxyHeaderName, hdrName) + + if cb != nil { + cb(req, cfg) + } + ctx := &models.ReqContext{ Context: &macaron.Context{ Req: macaron.Request{ @@ -58,10 +71,10 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot }, } - auth := New(&Options{ - Store: store, - Ctx: ctx, - OrgID: 4, + auth := New(cfg, &Options{ + RemoteCache: remoteCache, + Ctx: ctx, + OrgID: 4, }) return auth @@ -69,24 +82,17 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot func TestMiddlewareContext(t *testing.T) { logger := log.New("test") - req, err := http.NewRequest("POST", "http://example.com", nil) - require.NoError(t, err) - setting.AuthProxyHeaderName = "X-Killa" - store := remotecache.NewFakeStore(t) - - name := "markelog" - req.Header.Add(setting.AuthProxyHeaderName, name) + cache := remotecache.NewFakeStore(t) t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) { const id int64 = 33 // Set cache key - key := fmt.Sprintf(CachePrefix, HashCacheKey(name)) - err := store.Set(key, id, 0) + key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName)) + err := cache.Set(key, id, 0) require.NoError(t, err) - // Set up the middleware - auth := prepareMiddleware(t, req, store) - assert.Equal(t, "auth-proxy-sync-ttl:0a7f3374e9659b10980fd66247b0cf2f", auth.getKey()) + auth := prepareMiddleware(t, cache, nil) + assert.Equal(t, key, auth.getKey()) gotID, err := auth.Login(logger, false) require.NoError(t, err) @@ -96,15 +102,16 @@ func TestMiddlewareContext(t *testing.T) { t.Run("When the cache key contains additional headers", func(t *testing.T) { const id int64 = 33 - setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} - group := "grafana-core-team" - req.Header.Add("X-WEBAUTH-GROUPS", group) + const group = "grafana-core-team" - key := fmt.Sprintf(CachePrefix, HashCacheKey(name+"-"+group)) - err := store.Set(key, id, 0) + key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName+"-"+group)) + err := cache.Set(key, id, 0) require.NoError(t, err) - auth := prepareMiddleware(t, req, store) + auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { + req.Header.Set("X-WEBAUTH-GROUPS", group) + cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} + }) assert.Equal(t, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0", auth.getKey()) gotID, err := auth.Login(logger, false) @@ -115,12 +122,6 @@ func TestMiddlewareContext(t *testing.T) { func TestMiddlewareContext_ldap(t *testing.T) { logger := log.New("test") - req, err := http.NewRequest("POST", "http://example.com", nil) - require.NoError(t, err) - setting.AuthProxyHeaderName = "X-Killa" - - const headerName = "markelog" - req.Header.Add(setting.AuthProxyHeaderName, headerName) t.Run("Logs in via LDAP", func(t *testing.T) { const id int64 = 42 @@ -133,7 +134,16 @@ func TestMiddlewareContext_ldap(t *testing.T) { return nil }) - isLDAPEnabled = func() bool { + origIsLDAPEnabled := isLDAPEnabled + origGetLDAPConfig := getLDAPConfig + origNewLDAP := newLDAP + t.Cleanup(func() { + newLDAP = origNewLDAP + isLDAPEnabled = origIsLDAPEnabled + getLDAPConfig = origGetLDAPConfig + }) + + isLDAPEnabled = func(*setting.Cfg) bool { return true } @@ -141,7 +151,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { ID: id, } - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { config := &ldap.Config{ Servers: []*ldap.ServerConfig{ { @@ -156,15 +166,9 @@ func TestMiddlewareContext_ldap(t *testing.T) { return stub } - defer func() { - newLDAP = multildap.New - isLDAPEnabled = ldap.IsEnabled - getLDAPConfig = ldap.GetConfig - }() + cache := remotecache.NewFakeStore(t) - store := remotecache.NewFakeStore(t) - - auth := prepareMiddleware(t, req, store) + auth := prepareMiddleware(t, cache, nil) gotID, err := auth.Login(logger, false) require.NoError(t, err) @@ -173,25 +177,28 @@ func TestMiddlewareContext_ldap(t *testing.T) { assert.True(t, stub.userCalled) }) - t.Run("Gets nice error if ldap is enabled but not configured", func(t *testing.T) { + t.Run("Gets nice error if LDAP is enabled, but not configured", func(t *testing.T) { const id int64 = 42 - isLDAPEnabled = func() bool { + origIsLDAPEnabled := isLDAPEnabled + origNewLDAP := newLDAP + origGetLDAPConfig := getLDAPConfig + t.Cleanup(func() { + isLDAPEnabled = origIsLDAPEnabled + newLDAP = origNewLDAP + getLDAPConfig = origGetLDAPConfig + }) + + isLDAPEnabled = func(*setting.Cfg) bool { return true } - getLDAPConfig = func() (*ldap.Config, error) { + getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { return nil, errors.New("something went wrong") } - defer func() { - newLDAP = multildap.New - isLDAPEnabled = ldap.IsEnabled - getLDAPConfig = ldap.GetConfig - }() + cache := remotecache.NewFakeStore(t) - store := remotecache.NewFakeStore(t) - - auth := prepareMiddleware(t, req, store) + auth := prepareMiddleware(t, cache, nil) stub := &fakeMultiLDAP{ ID: id, diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go new file mode 100644 index 00000000000..1194db9f507 --- /dev/null +++ b/pkg/services/contexthandler/contexthandler.go @@ -0,0 +1,448 @@ +// Package contexthandler contains the ContextHandler service. +package contexthandler + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/grafana/grafana/pkg/bus" + "github.com/grafana/grafana/pkg/components/apikeygen" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/infra/network" + "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/middleware/cookies" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/registry" + "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/sqlstore" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" + "gopkg.in/macaron.v1" +) + +const ( + InvalidUsernamePassword = "invalid username or password" + InvalidAPIKey = "invalid API key" +) + +const ServiceName = "ContextHandler" + +func init() { + registry.Register(®istry.Descriptor{ + Name: ServiceName, + Instance: &ContextHandler{}, + InitPriority: registry.High, + }) +} + +// ContextHandler is a middleware. +type ContextHandler struct { + Cfg *setting.Cfg `inject:""` + AuthTokenService models.UserTokenService `inject:""` + RemoteCache *remotecache.RemoteCache `inject:""` + RenderService rendering.Service `inject:""` + SQLStore *sqlstore.SQLStore `inject:""` + + // GetTime returns the current time. + // Stubbable by tests. + GetTime func() time.Time +} + +// Init initializes the service. +func (h *ContextHandler) Init() error { + return nil +} + +// Middleware provides a middleware to initialize the Macaron context. +func (h *ContextHandler) Middleware(c *macaron.Context) { + ctx := &models.ReqContext{ + Context: c, + SignedInUser: &models.SignedInUser{}, + IsSignedIn: false, + AllowAnonymous: false, + SkipCache: false, + Logger: log.New("context"), + } + + const headerName = "X-Grafana-Org-Id" + orgID := int64(0) + orgIDHeader := ctx.Req.Header.Get(headerName) + if orgIDHeader != "" { + id, err := strconv.ParseInt(orgIDHeader, 10, 64) + if err == nil { + orgID = id + } else { + ctx.Logger.Debug("Received invalid header", "header", headerName, "value", orgIDHeader) + } + } + + // the order in which these are tested are important + // look for api key in Authorization header first + // then init session and look for userId in session + // then look for api key in session (special case for render calls via api) + // then test if anonymous access is enabled + switch { + case h.initContextWithRenderAuth(ctx): + case h.initContextWithAPIKey(ctx): + case h.initContextWithBasicAuth(ctx, orgID): + case h.initContextWithAuthProxy(ctx, orgID): + case h.initContextWithToken(ctx, orgID): + case h.initContextWithAnonymousUser(ctx): + } + + ctx.Logger = log.New("context", "userId", ctx.UserId, "orgId", ctx.OrgId, "uname", ctx.Login) + ctx.Data["ctx"] = ctx + + c.Map(ctx) + + // update last seen every 5min + if ctx.ShouldUpdateLastSeenAt() { + ctx.Logger.Debug("Updating last user_seen_at", "user_id", ctx.UserId) + if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: ctx.UserId}); err != nil { + ctx.Logger.Error("Failed to update last_seen_at", "error", err) + } + } +} + +func (h *ContextHandler) initContextWithAnonymousUser(ctx *models.ReqContext) bool { + if !h.Cfg.AnonymousEnabled { + return false + } + + orgQuery := models.GetOrgByNameQuery{Name: h.Cfg.AnonymousOrgName} + if err := bus.Dispatch(&orgQuery); err != nil { + log.Errorf(3, "Anonymous access organization error: '%s': %s", h.Cfg.AnonymousOrgName, err) + return false + } + + ctx.IsSignedIn = false + ctx.AllowAnonymous = true + ctx.SignedInUser = &models.SignedInUser{IsAnonymous: true} + ctx.OrgRole = models.RoleType(h.Cfg.AnonymousOrgRole) + ctx.OrgId = orgQuery.Result.Id + ctx.OrgName = orgQuery.Result.Name + return true +} + +func (h *ContextHandler) initContextWithAPIKey(ctx *models.ReqContext) bool { + header := ctx.Req.Header.Get("Authorization") + parts := strings.SplitN(header, " ", 2) + var keyString string + if len(parts) == 2 && parts[0] == "Bearer" { + keyString = parts[1] + } else { + username, password, err := util.DecodeBasicAuthHeader(header) + if err == nil && username == "api_key" { + keyString = password + } + } + + if keyString == "" { + return false + } + + // base64 decode key + decoded, err := apikeygen.Decode(keyString) + if err != nil { + ctx.JsonApiErr(401, InvalidAPIKey, err) + return true + } + + // fetch key + keyQuery := models.GetApiKeyByNameQuery{KeyName: decoded.Name, OrgId: decoded.OrgId} + if err := bus.Dispatch(&keyQuery); err != nil { + ctx.JsonApiErr(401, InvalidAPIKey, err) + return true + } + + apikey := keyQuery.Result + + // validate api key + isValid, err := apikeygen.IsValid(decoded, apikey.Key) + if err != nil { + ctx.JsonApiErr(500, "Validating API key failed", err) + return true + } + if !isValid { + ctx.JsonApiErr(401, InvalidAPIKey, err) + return true + } + + // check for expiration + getTime := h.GetTime + if getTime == nil { + getTime = time.Now + } + if apikey.Expires != nil && *apikey.Expires <= getTime().Unix() { + ctx.JsonApiErr(401, "Expired API key", err) + return true + } + + ctx.IsSignedIn = true + ctx.SignedInUser = &models.SignedInUser{} + ctx.OrgRole = apikey.Role + ctx.ApiKeyId = apikey.Id + ctx.OrgId = apikey.OrgId + return true +} + +func (h *ContextHandler) initContextWithBasicAuth(ctx *models.ReqContext, orgID int64) bool { + if !h.Cfg.BasicAuthEnabled { + return false + } + + header := ctx.Req.Header.Get("Authorization") + if header == "" { + return false + } + + username, password, err := util.DecodeBasicAuthHeader(header) + if err != nil { + ctx.JsonApiErr(401, "Invalid Basic Auth Header", err) + return true + } + + authQuery := models.LoginUserQuery{ + Username: username, + Password: password, + Cfg: h.Cfg, + } + if err := bus.Dispatch(&authQuery); err != nil { + ctx.Logger.Debug( + "Failed to authorize the user", + "username", username, + "err", err, + ) + + if errors.Is(err, models.ErrUserNotFound) { + err = login.ErrInvalidCredentials + } + ctx.JsonApiErr(401, InvalidUsernamePassword, err) + return true + } + + user := authQuery.User + + query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgID} + if err := bus.Dispatch(&query); err != nil { + ctx.Logger.Error( + "Failed at user signed in", + "id", user.Id, + "org", orgID, + ) + ctx.JsonApiErr(401, InvalidUsernamePassword, err) + return true + } + + ctx.SignedInUser = query.Result + ctx.IsSignedIn = true + return true +} + +func (h *ContextHandler) initContextWithToken(ctx *models.ReqContext, orgID int64) bool { + if h.Cfg.LoginCookieName == "" { + return false + } + + rawToken := ctx.GetCookie(h.Cfg.LoginCookieName) + if rawToken == "" { + return false + } + + token, err := h.AuthTokenService.LookupToken(ctx.Req.Context(), rawToken) + if err != nil { + ctx.Logger.Error("Failed to look up user based on cookie", "error", err) + cookies.WriteSessionCookie(ctx, h.Cfg, "", -1) + return false + } + + query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID} + if err := bus.Dispatch(&query); err != nil { + ctx.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err) + return false + } + + ctx.SignedInUser = query.Result + ctx.IsSignedIn = true + ctx.UserToken = token + + // Rotate the token just before we write response headers to ensure there is no delay between + // the new token being generated and the client receiving it. + ctx.Resp.Before(h.rotateEndOfRequestFunc(ctx, h.AuthTokenService, token)) + + return true +} + +func (h *ContextHandler) rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.UserTokenService, + token *models.UserToken) macaron.BeforeFunc { + return func(w macaron.ResponseWriter) { + // if response has already been written, skip. + if w.Written() { + return + } + + // if the request is cancelled by the client we should not try + // to rotate the token since the client would not accept any result. + if errors.Is(ctx.Context.Req.Context().Err(), context.Canceled) { + return + } + + addr := ctx.RemoteAddr() + ip, err := network.GetIPFromAddress(addr) + if err != nil { + ctx.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err) + ip = nil + } + rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ip, ctx.Req.UserAgent()) + if err != nil { + ctx.Logger.Error("Failed to rotate token", "error", err) + return + } + + if rotated { + cookies.WriteSessionCookie(ctx, h.Cfg, token.UnhashedToken, h.Cfg.LoginMaxLifetime) + } + } +} + +func (h *ContextHandler) initContextWithRenderAuth(ctx *models.ReqContext) bool { + key := ctx.GetCookie("renderKey") + if key == "" { + return false + } + + renderUser, exists := h.RenderService.GetRenderUser(key) + if !exists { + ctx.JsonApiErr(401, "Invalid Render Key", nil) + return true + } + + ctx.IsSignedIn = true + ctx.SignedInUser = &models.SignedInUser{ + OrgId: renderUser.OrgID, + UserId: renderUser.UserID, + OrgRole: models.RoleType(renderUser.OrgRole), + } + ctx.IsRenderCall = true + ctx.LastSeenAt = time.Now() + return true +} + +func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) { + logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache) + // Try to log in user via various providers + id, err := auth.Login(logger, ignoreCache) + if err != nil { + details := err + var e authproxy.Error + if errors.As(err, &e) { + details = e.DetailsError + } + logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details, + "ignoreCache", ignoreCache) + return 0, err + } + return id, nil +} + +func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(error)) { + details := err + var e authproxy.Error + if errors.As(err, &e) { + details = e.DetailsError + } + ctx.Handle(statusCode, err.Error(), details) + + if cb != nil { + cb(details) + } +} + +func (h *ContextHandler) initContextWithAuthProxy(ctx *models.ReqContext, orgID int64) bool { + username := ctx.Req.Header.Get(h.Cfg.AuthProxyHeaderName) + auth := authproxy.New(h.Cfg, &authproxy.Options{ + RemoteCache: h.RemoteCache, + Ctx: ctx, + OrgID: orgID, + }) + + logger := log.New("auth.proxy") + + // Bail if auth proxy is not enabled + if !auth.IsEnabled() { + return false + } + + // If there is no header - we can't move forward + if !auth.HasHeader() { + return false + } + + // Check if allowed to continue with this IP + if err := auth.IsAllowedIP(); err != nil { + handleError(ctx, err, 407, func(details error) { + logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details) + }) + return true + } + + id, err := logUserIn(auth, username, logger, false) + if err != nil { + handleError(ctx, err, 407, nil) + return true + } + + logger.Debug("Got user ID, getting full user info", "userID", id) + + user, err := auth.GetSignedInUser(id) + if err != nil { + // The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale + // cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated + // because cache keys are computed from request header values and not just the user ID. Meaning that + // we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to + // log the user in again without the cache. + logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id) + if err := auth.RemoveUserFromCache(logger); err != nil { + if !errors.Is(err, remotecache.ErrCacheItemNotFound) { + logger.Error("Got unexpected error when removing user from auth cache", "error", err) + } + } + id, err = logUserIn(auth, username, logger, true) + if err != nil { + handleError(ctx, err, 407, nil) + return true + } + + user, err = auth.GetSignedInUser(id) + if err != nil { + handleError(ctx, err, 407, nil) + return true + } + } + + logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login) + + // Add user info to context + ctx.SignedInUser = user + ctx.IsSignedIn = true + + // Remember user data in cache + if err := auth.Remember(id); err != nil { + handleError(ctx, err, 500, func(details error) { + logger.Error( + "Failed to store user in cache", + "username", username, + "message", err.Error(), + "error", details, + ) + }) + return true + } + + return true +} diff --git a/pkg/services/contexthandler/contexthandler_test.go b/pkg/services/contexthandler/contexthandler_test.go new file mode 100644 index 00000000000..d01c04b37dc --- /dev/null +++ b/pkg/services/contexthandler/contexthandler_test.go @@ -0,0 +1,126 @@ +package contexthandler + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/grafana/grafana/pkg/components/gtime" + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + macaron "gopkg.in/macaron.v1" +) + +func TestDontRotateTokensOnCancelledRequests(t *testing.T) { + ctxHdlr := getContextHandler(t) + + ctx, cancel := context.WithCancel(context.Background()) + reqContext, _, err := initTokenRotationScenario(ctx, t) + require.NoError(t, err) + + tryRotateCallCount := 0 + uts := &auth.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, + userAgent string) (bool, error) { + tryRotateCallCount++ + return false, nil + }, + } + + token := &models.UserToken{AuthToken: "oldtoken"} + + fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token) + cancel() + fn(reqContext.Resp) + + assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted") +} + +func TestTokenRotationAtEndOfRequest(t *testing.T) { + ctxHdlr := getContextHandler(t) + + reqContext, rr, err := initTokenRotationScenario(context.Background(), t) + require.NoError(t, err) + + uts := &auth.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, + userAgent string) (bool, error) { + newToken, err := util.RandomHex(16) + require.NoError(t, err) + token.AuthToken = newToken + return true, nil + }, + } + + token := &models.UserToken{AuthToken: "oldtoken"} + + ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp) + + foundLoginCookie := false + resp := rr.Result() + defer resp.Body.Close() + for _, c := range resp.Cookies() { + if c.Name == "login_token" { + foundLoginCookie = true + require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same") + } + } + + assert.True(t, foundLoginCookie, "Could not find cookie") +} + +func initTokenRotationScenario(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) { + t.Helper() + + origLoginCookieName := setting.LoginCookieName + origLoginMaxLifetime := setting.LoginMaxLifetime + t.Cleanup(func() { + setting.LoginCookieName = origLoginCookieName + setting.LoginMaxLifetime = origLoginMaxLifetime + }) + setting.LoginCookieName = "login_token" + var err error + setting.LoginMaxLifetime, err = gtime.ParseDuration("7d") + if err != nil { + return nil, nil, err + } + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, "", "", nil) + if err != nil { + return nil, nil, err + } + reqContext := &models.ReqContext{ + Context: &macaron.Context{ + Req: macaron.Request{ + Request: req, + }, + }, + Logger: log.New("testlogger"), + } + + mw := mockWriter{rr} + reqContext.Resp = mw + + return reqContext, rr, nil +} + +type mockWriter struct { + *httptest.ResponseRecorder +} + +func (mw mockWriter) Flush() {} +func (mw mockWriter) Status() int { return 0 } +func (mw mockWriter) Size() int { return 0 } +func (mw mockWriter) Written() bool { return false } +func (mw mockWriter) Before(macaron.BeforeFunc) {} +func (mw mockWriter) Push(target string, opts *http.PushOptions) error { + return nil +} diff --git a/pkg/services/ldap/settings.go b/pkg/services/ldap/settings.go index 93e94f3e016..ffd5d48c893 100644 --- a/pkg/services/ldap/settings.go +++ b/pkg/services/ldap/settings.go @@ -94,8 +94,12 @@ var config *Config // GetConfig returns the LDAP config if LDAP is enabled otherwise it returns nil. It returns either cached value of // the config or it reads it and caches it first. -func GetConfig() (*Config, error) { - if !IsEnabled() { +func GetConfig(cfg *setting.Cfg) (*Config, error) { + if cfg != nil { + if !cfg.LDAPEnabled { + return nil, nil + } + } else if !IsEnabled() { return nil, nil } diff --git a/pkg/services/rendering/rendering.go b/pkg/services/rendering/rendering.go index 091d2349157..2b06bd42765 100644 --- a/pkg/services/rendering/rendering.go +++ b/pkg/services/rendering/rendering.go @@ -25,12 +25,13 @@ import ( func init() { remotecache.Register(&RenderUser{}) registry.Register(®istry.Descriptor{ - Name: "RenderingService", + Name: ServiceName, Instance: &RenderingService{}, InitPriority: registry.High, }) } +const ServiceName = "RenderingService" const renderKeyPrefix = "render-%s" type RenderUser struct { @@ -226,8 +227,8 @@ func (rs *RenderingService) getURL(path string) string { return fmt.Sprintf("%s%s&render=1", rs.Cfg.RendererCallbackUrl, path) } - protocol := setting.Protocol - switch setting.Protocol { + protocol := rs.Cfg.Protocol + switch protocol { case setting.HTTPScheme: protocol = "http" case setting.HTTP2Scheme, setting.HTTPSScheme: diff --git a/pkg/services/rendering/rendering_test.go b/pkg/services/rendering/rendering_test.go index d8d02133c2b..e0064b134f8 100644 --- a/pkg/services/rendering/rendering_test.go +++ b/pkg/services/rendering/rendering_test.go @@ -28,7 +28,7 @@ func TestGetUrl(t *testing.T) { t.Run("And protocol HTTP configured should return expected path", func(t *testing.T) { rs.Cfg.ServeFromSubPath = false rs.Cfg.AppSubURL = "" - setting.Protocol = setting.HTTPScheme + rs.Cfg.Protocol = setting.HTTPScheme url := rs.getURL(path) require.Equal(t, "http://localhost:3000/"+path+"&render=1", url) @@ -43,7 +43,7 @@ func TestGetUrl(t *testing.T) { t.Run("And protocol HTTPS configured should return expected path", func(t *testing.T) { rs.Cfg.ServeFromSubPath = false rs.Cfg.AppSubURL = "" - setting.Protocol = setting.HTTPSScheme + rs.Cfg.Protocol = setting.HTTPSScheme url := rs.getURL(path) require.Equal(t, "https://localhost:3000/"+path+"&render=1", url) }) @@ -51,7 +51,7 @@ func TestGetUrl(t *testing.T) { t.Run("And protocol HTTP2 configured should return expected path", func(t *testing.T) { rs.Cfg.ServeFromSubPath = false rs.Cfg.AppSubURL = "" - setting.Protocol = setting.HTTP2Scheme + rs.Cfg.Protocol = setting.HTTP2Scheme url := rs.getURL(path) require.Equal(t, "https://localhost:3000/"+path+"&render=1", url) }) diff --git a/pkg/services/sqlstore/preferences.go b/pkg/services/sqlstore/preferences.go index 960f7ffb570..341f67a94ba 100644 --- a/pkg/services/sqlstore/preferences.go +++ b/pkg/services/sqlstore/preferences.go @@ -6,8 +6,6 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" - - "github.com/grafana/grafana/pkg/setting" ) func (ss *SQLStore) addPreferencesQueryAndCommandHandlers() { @@ -42,7 +40,7 @@ func (ss *SQLStore) GetPreferencesWithDefaults(query *models.GetPreferencesWithD } res := &models.Preferences{ - Theme: setting.DefaultTheme, + Theme: ss.Cfg.DefaultTheme, Timezone: ss.Cfg.DateFormats.DefaultTimezone, HomeDashboardId: 0, } diff --git a/pkg/services/sqlstore/preferences_test.go b/pkg/services/sqlstore/preferences_test.go index 4dd6705d224..410d92b6ee2 100644 --- a/pkg/services/sqlstore/preferences_test.go +++ b/pkg/services/sqlstore/preferences_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/require" ) @@ -14,7 +13,7 @@ func TestPreferencesDataAccess(t *testing.T) { ss := InitTestDB(t) t.Run("GetPreferencesWithDefaults with no saved preferences should return defaults", func(t *testing.T) { - setting.DefaultTheme = "light" + ss.Cfg.DefaultTheme = "light" ss.Cfg.DateFormats.DefaultTimezone = "UTC" query := &models.GetPreferencesWithDefaultsQuery{User: &models.SignedInUser{}} diff --git a/pkg/services/sqlstore/sqlstore.go b/pkg/services/sqlstore/sqlstore.go index ad67867284d..56997bae68b 100644 --- a/pkg/services/sqlstore/sqlstore.go +++ b/pkg/services/sqlstore/sqlstore.go @@ -39,16 +39,21 @@ var ( // ContextSessionKey is used as key to save values in `context.Context` type ContextSessionKey struct{} +const ServiceName = "SqlStore" +const InitPriority = registry.High + func init() { + ss := &SQLStore{} + // This change will make xorm use an empty default schema for postgres and // by that mimic the functionality of how it was functioning before // xorm's changes above. xorm.DefaultPostgresSchema = "" registry.Register(®istry.Descriptor{ - Name: "SQLStore", - Instance: &SQLStore{}, - InitPriority: registry.High, + Name: ServiceName, + Instance: ss, + InitPriority: InitPriority, }) } @@ -113,13 +118,20 @@ func (ss *SQLStore) Init() error { func (ss *SQLStore) ensureMainOrgAndAdminUser() error { err := ss.InTransaction(context.Background(), func(ctx context.Context) error { - systemUserCountQuery := models.GetSystemUserCountStatsQuery{} - err := bus.DispatchCtx(ctx, &systemUserCountQuery) + var stats models.SystemUserCountStats + err := ss.WithDbSession(ctx, func(sess *DBSession) error { + var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user") + if _, err := sess.SQL(rawSql).Get(&stats); err != nil { + return fmt.Errorf("could not determine if admin user exists: %w", err) + } + + return nil + }) if err != nil { - return fmt.Errorf("could not determine if admin user exists: %w", err) + return err } - if systemUserCountQuery.Result.Count > 0 { + if stats.Count > 0 { return nil } @@ -351,7 +363,7 @@ func InitTestDB(t ITestDB) *SQLStore { testSQLStore = &SQLStore{} testSQLStore.Bus = bus.New() testSQLStore.CacheService = localcache.New(5*time.Minute, 10*time.Minute) - testSQLStore.skipEnsureDefaultOrgAndUser = true + testSQLStore.skipEnsureDefaultOrgAndUser = false dbType := migrator.SQLite diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index 20c5691ffed..b6a61db825b 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -47,7 +47,7 @@ var ( // This constant corresponds to the default value for ldap_sync_ttl in .ini files // it is used for comparison and has to be kept in sync const ( - AuthProxySyncTTL = 60 + authProxySyncTTL = 60 ) var ( @@ -75,12 +75,8 @@ var ( CustomInitPath = "conf/custom.ini" // HTTP server options - Protocol Scheme - Domain string HttpAddr, HttpPort string CertFile, KeyFile string - SocketPath string - RouterLogging bool DataProxyLogging bool DataProxyTimeout int DataProxyTLSHandshakeTimeout int @@ -93,28 +89,19 @@ var ( EnforceDomain bool // Security settings. - SecretKey string - DisableGravatar bool - EmailCodeValidMinutes int - DataProxyWhiteList map[string]bool - DisableBruteForceLoginProtection bool - CookieSecure bool - CookieSameSiteDisabled bool - CookieSameSiteMode http.SameSite - AllowEmbedding bool - XSSProtectionHeader bool - ContentTypeProtectionHeader bool - StrictTransportSecurity bool - StrictTransportSecurityMaxAge int - StrictTransportSecurityPreload bool - StrictTransportSecuritySubDomains bool + SecretKey string + DisableGravatar bool + EmailCodeValidMinutes int + DataProxyWhiteList map[string]bool + CookieSecure bool + CookieSameSiteDisabled bool + CookieSameSiteMode http.SameSite // Snapshots ExternalSnapshotUrl string ExternalSnapshotName string ExternalEnabled bool SnapShotRemoveExpired bool - SnapshotPublicMode bool // Dashboard history DashboardVersionsToKeep int @@ -129,7 +116,6 @@ var ( VerifyEmailEnabled bool LoginHint string PasswordHint string - DefaultTheme string DisableLoginForm bool DisableSignoutMenu bool SignoutRedirectUrl string @@ -139,7 +125,7 @@ var ( OAuthAutoLogin bool ViewersCanEdit bool - // Http auth + // HTTP auth AdminUser string AdminPassword string LoginCookieName string @@ -147,18 +133,10 @@ var ( SigV4AuthEnabled bool AnonymousEnabled bool - AnonymousOrgName string - AnonymousOrgRole string // Auth proxy settings - AuthProxyEnabled bool - AuthProxyHeaderName string - AuthProxyHeaderProperty string - AuthProxyAutoSignUp bool - AuthProxyEnableLoginToken bool - AuthProxySyncTtl int - AuthProxyWhitelist string - AuthProxyHeaders map[string]string + AuthProxyEnabled bool + AuthProxyHeaderProperty string // Basic Auth BasicAuthEnabled bool @@ -224,6 +202,9 @@ type Cfg struct { ServeFromSubPath bool StaticRootPath string Protocol Scheme + SocketPath string + RouterLogging bool + Domain string // build BuildVersion string @@ -251,11 +232,18 @@ type Cfg struct { RendererConcurrentRequestLimit int // Security - DisableInitAdminCreation bool - DisableBruteForceLoginProtection bool - CookieSecure bool - CookieSameSiteDisabled bool - CookieSameSiteMode http.SameSite + DisableInitAdminCreation bool + DisableBruteForceLoginProtection bool + CookieSecure bool + CookieSameSiteDisabled bool + CookieSameSiteMode http.SameSite + AllowEmbedding bool + XSSProtectionHeader bool + ContentTypeProtectionHeader bool + StrictTransportSecurity bool + StrictTransportSecurityMaxAge int + StrictTransportSecurityPreload bool + StrictTransportSecuritySubDomains bool TempDataLifetime time.Duration PluginsEnableAlpha bool @@ -282,6 +270,17 @@ type Cfg struct { LoginMaxLifetime time.Duration TokenRotationIntervalMinutes int SigV4AuthEnabled bool + BasicAuthEnabled bool + + // Auth proxy settings + AuthProxyEnabled bool + AuthProxyHeaderName string + AuthProxyHeaderProperty string + AuthProxyAutoSignUp bool + AuthProxyEnableLoginToken bool + AuthProxyWhitelist string + AuthProxyHeaders map[string]string + AuthProxySyncTTL int // OAuth OAuthCookieMaxAge int @@ -302,6 +301,9 @@ type Cfg struct { // Use to enable new features which may still be in alpha/beta stage. FeatureToggles map[string]bool + AnonymousEnabled bool + AnonymousOrgName string + AnonymousOrgRole string AnonymousHideVersion bool DateFormats DateFormats @@ -317,6 +319,21 @@ type Cfg struct { // Sentry config Sentry Sentry + + // Snapshots + SnapshotPublicMode bool + + ErrTemplateName string + + Env string + + // LDAP + LDAPEnabled bool + LDAPAllowSignup bool + + Quota QuotaSettings + + DefaultTheme string } // IsExpressionsEnabled returns whether the expressions feature is enabled. @@ -707,9 +724,12 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error { cfg.IsEnterprise = IsEnterprise cfg.Packaging = Packaging + cfg.ErrTemplateName = ErrTemplateName + ApplicationName = "Grafana" Env = valueAsString(iniFile.Section(""), "app_mode", "development") + cfg.Env = Env InstanceName = valueAsString(iniFile.Section(""), "instance_name", "unknown_instance_name") plugins := valueAsString(iniFile.Section("paths"), "plugins", "") PluginsPath = makeAbsolute(plugins, HomePath) @@ -736,7 +756,7 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error { return err } - if err := readSnapshotsSettings(iniFile); err != nil { + if err := readSnapshotsSettings(cfg, iniFile); err != nil { return err } @@ -789,7 +809,6 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error { cfg.PluginsAllowUnsigned = append(cfg.PluginsAllowUnsigned, plug) } cfg.MarketplaceURL = pluginsSection.Key("marketplace_url").MustString("https://grafana.com/grafana/plugins/") - cfg.Protocol = Protocol // Read and populate feature toggles list featureTogglesSection := iniFile.Section("feature_toggles") @@ -858,8 +877,10 @@ func (cfg *Cfg) readLDAPConfig() { LDAPConfigFile = ldapSec.Key("config_file").String() LDAPSyncCron = ldapSec.Key("sync_cron").String() LDAPEnabled = ldapSec.Key("enabled").MustBool(false) + cfg.LDAPEnabled = LDAPEnabled LDAPActiveSyncEnabled = ldapSec.Key("active_sync_enabled").MustBool(false) LDAPAllowSignup = ldapSec.Key("allow_sign_up").MustBool(true) + cfg.LDAPAllowSignup = LDAPAllowSignup } func (cfg *Cfg) readSessionConfig() { @@ -910,7 +931,7 @@ func (cfg *Cfg) LogConfigSources() { cfg.Logger.Info("Path Logs", "path", cfg.LogsPath) cfg.Logger.Info("Path Plugins", "path", PluginsPath) cfg.Logger.Info("Path Provisioning", "path", cfg.ProvisioningPath) - cfg.Logger.Info("App mode " + Env) + cfg.Logger.Info("App mode " + cfg.Env) } type DynamicSection struct { @@ -949,7 +970,6 @@ func readSecuritySettings(iniFile *ini.File, cfg *Cfg) error { SecretKey = valueAsString(security, "secret_key", "") DisableGravatar = security.Key("disable_gravatar").MustBool(true) cfg.DisableBruteForceLoginProtection = security.Key("disable_brute_force_login_protection").MustBool(false) - DisableBruteForceLoginProtection = cfg.DisableBruteForceLoginProtection CookieSecure = security.Key("cookie_secure").MustBool(false) cfg.CookieSecure = CookieSecure @@ -974,14 +994,14 @@ func readSecuritySettings(iniFile *ini.File, cfg *Cfg) error { cfg.CookieSameSiteMode = CookieSameSiteMode } } - AllowEmbedding = security.Key("allow_embedding").MustBool(false) + cfg.AllowEmbedding = security.Key("allow_embedding").MustBool(false) - ContentTypeProtectionHeader = security.Key("x_content_type_options").MustBool(true) - XSSProtectionHeader = security.Key("x_xss_protection").MustBool(true) - StrictTransportSecurity = security.Key("strict_transport_security").MustBool(false) - StrictTransportSecurityMaxAge = security.Key("strict_transport_security_max_age_seconds").MustInt(86400) - StrictTransportSecurityPreload = security.Key("strict_transport_security_preload").MustBool(false) - StrictTransportSecuritySubDomains = security.Key("strict_transport_security_subdomains").MustBool(false) + cfg.ContentTypeProtectionHeader = security.Key("x_content_type_options").MustBool(true) + cfg.XSSProtectionHeader = security.Key("x_xss_protection").MustBool(true) + cfg.StrictTransportSecurity = security.Key("strict_transport_security").MustBool(false) + cfg.StrictTransportSecurityMaxAge = security.Key("strict_transport_security_max_age_seconds").MustInt(86400) + cfg.StrictTransportSecurityPreload = security.Key("strict_transport_security_preload").MustBool(false) + cfg.StrictTransportSecuritySubDomains = security.Key("strict_transport_security_subdomains").MustBool(false) // read data source proxy whitelist DataProxyWhiteList = make(map[string]bool) @@ -1054,41 +1074,45 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) { // anonymous access AnonymousEnabled = iniFile.Section("auth.anonymous").Key("enabled").MustBool(false) - AnonymousOrgName = valueAsString(iniFile.Section("auth.anonymous"), "org_name", "") - AnonymousOrgRole = valueAsString(iniFile.Section("auth.anonymous"), "org_role", "") + cfg.AnonymousEnabled = AnonymousEnabled + cfg.AnonymousOrgName = valueAsString(iniFile.Section("auth.anonymous"), "org_name", "") + cfg.AnonymousOrgRole = valueAsString(iniFile.Section("auth.anonymous"), "org_role", "") cfg.AnonymousHideVersion = iniFile.Section("auth.anonymous").Key("hide_version").MustBool(false) // basic auth authBasic := iniFile.Section("auth.basic") BasicAuthEnabled = authBasic.Key("enabled").MustBool(true) + cfg.BasicAuthEnabled = BasicAuthEnabled authProxy := iniFile.Section("auth.proxy") AuthProxyEnabled = authProxy.Key("enabled").MustBool(false) + cfg.AuthProxyEnabled = AuthProxyEnabled - AuthProxyHeaderName = valueAsString(authProxy, "header_name", "") + cfg.AuthProxyHeaderName = valueAsString(authProxy, "header_name", "") AuthProxyHeaderProperty = valueAsString(authProxy, "header_property", "") - AuthProxyAutoSignUp = authProxy.Key("auto_sign_up").MustBool(true) - AuthProxyEnableLoginToken = authProxy.Key("enable_login_token").MustBool(false) + cfg.AuthProxyHeaderProperty = AuthProxyHeaderProperty + cfg.AuthProxyAutoSignUp = authProxy.Key("auto_sign_up").MustBool(true) + cfg.AuthProxyEnableLoginToken = authProxy.Key("enable_login_token").MustBool(false) ldapSyncVal := authProxy.Key("ldap_sync_ttl").MustInt() syncVal := authProxy.Key("sync_ttl").MustInt() - if ldapSyncVal != AuthProxySyncTTL { - AuthProxySyncTtl = ldapSyncVal + if ldapSyncVal != authProxySyncTTL { + cfg.AuthProxySyncTTL = ldapSyncVal cfg.Logger.Warn("[Deprecated] the configuration setting 'ldap_sync_ttl' is deprecated, please use 'sync_ttl' instead") } else { - AuthProxySyncTtl = syncVal + cfg.AuthProxySyncTTL = syncVal } - AuthProxyWhitelist = valueAsString(authProxy, "whitelist", "") + cfg.AuthProxyWhitelist = valueAsString(authProxy, "whitelist", "") - AuthProxyHeaders = make(map[string]string) + cfg.AuthProxyHeaders = make(map[string]string) headers := valueAsString(authProxy, "headers", "") for _, propertyAndHeader := range util.SplitString(headers) { split := strings.SplitN(propertyAndHeader, ":", 2) if len(split) == 2 { - AuthProxyHeaders[split[0]] = split[1] + cfg.AuthProxyHeaders[split[0]] = split[1] } } @@ -1106,7 +1130,7 @@ func readUserSettings(iniFile *ini.File, cfg *Cfg) error { LoginHint = valueAsString(users, "login_hint", "") PasswordHint = valueAsString(users, "password_hint", "") - DefaultTheme = valueAsString(users, "default_theme", "") + cfg.DefaultTheme = valueAsString(users, "default_theme", "") ExternalUserMngLinkUrl = valueAsString(users, "external_manage_link_url", "") ExternalUserMngLinkName = valueAsString(users, "external_manage_link_name", "") ExternalUserMngInfo = valueAsString(users, "external_manage_info", "") @@ -1178,7 +1202,7 @@ func readAlertingSettings(iniFile *ini.File) error { return nil } -func readSnapshotsSettings(iniFile *ini.File) error { +func readSnapshotsSettings(cfg *Cfg, iniFile *ini.File) error { snapshots := iniFile.Section("snapshots") ExternalSnapshotUrl = valueAsString(snapshots, "external_snapshot_url", "") @@ -1186,7 +1210,7 @@ func readSnapshotsSettings(iniFile *ini.File) error { ExternalEnabled = snapshots.Key("external_enabled").MustBool(true) SnapShotRemoveExpired = snapshots.Key("snapshot_remove_expired").MustBool(true) - SnapshotPublicMode = snapshots.Key("public_mode").MustBool(false) + cfg.SnapshotPublicMode = snapshots.Key("public_mode").MustBool(false) return nil } @@ -1204,28 +1228,28 @@ func readServerSettings(iniFile *ini.File, cfg *Cfg) error { cfg.AppSubURL = AppSubUrl cfg.ServeFromSubPath = ServeFromSubPath - Protocol = HTTPScheme + cfg.Protocol = HTTPScheme protocolStr := valueAsString(server, "protocol", "http") if protocolStr == "https" { - Protocol = HTTPSScheme + cfg.Protocol = HTTPSScheme CertFile = server.Key("cert_file").String() KeyFile = server.Key("cert_key").String() } if protocolStr == "h2" { - Protocol = HTTP2Scheme + cfg.Protocol = HTTP2Scheme CertFile = server.Key("cert_file").String() KeyFile = server.Key("cert_key").String() } if protocolStr == "socket" { - Protocol = SocketScheme - SocketPath = server.Key("socket").String() + cfg.Protocol = SocketScheme + cfg.SocketPath = server.Key("socket").String() } - Domain = valueAsString(server, "domain", "localhost") + cfg.Domain = valueAsString(server, "domain", "localhost") HttpAddr = valueAsString(server, "http_addr", DefaultHTTPAddr) HttpPort = valueAsString(server, "http_port", "3000") - RouterLogging = server.Key("router_logging").MustBool(false) + cfg.RouterLogging = server.Key("router_logging").MustBool(false) EnableGzip = server.Key("enable_gzip").MustBool(false) EnforceDomain = server.Key("enforce_domain").MustBool(false) diff --git a/pkg/setting/setting_quota.go b/pkg/setting/setting_quota.go index 050d76215cd..be562a4c0e2 100644 --- a/pkg/setting/setting_quota.go +++ b/pkg/setting/setting_quota.go @@ -86,4 +86,6 @@ func (cfg *Cfg) readQuotaSettings() { ApiKey: quota.Key("global_api_key").MustInt64(-1), Session: quota.Key("global_session").MustInt64(-1), } + + cfg.Quota = Quota } diff --git a/pkg/setting/setting_test.go b/pkg/setting/setting_test.go index 0fb696769fd..ea1c23cae81 100644 --- a/pkg/setting/setting_test.go +++ b/pkg/setting/setting_test.go @@ -133,7 +133,7 @@ func TestLoadingSettings(t *testing.T) { }) So(err, ShouldBeNil) - So(Domain, ShouldEqual, "test2") + So(cfg.Domain, ShouldEqual, "test2") }) Convey("Defaults can be overridden in specified config file", func() { @@ -239,7 +239,7 @@ func TestLoadingSettings(t *testing.T) { }) So(err, ShouldBeNil) - So(AuthProxySyncTtl, ShouldEqual, 2) + So(cfg.AuthProxySyncTTL, ShouldEqual, 2) }) Convey("Only ldap_sync_ttl should return the value ldap_sync_ttl", func() { @@ -250,7 +250,7 @@ func TestLoadingSettings(t *testing.T) { }) So(err, ShouldBeNil) - So(AuthProxySyncTtl, ShouldEqual, 5) + So(cfg.AuthProxySyncTTL, ShouldEqual, 5) }) Convey("ldap_sync should override ldap_sync_ttl that is default value", func() { @@ -261,7 +261,7 @@ func TestLoadingSettings(t *testing.T) { }) So(err, ShouldBeNil) - So(AuthProxySyncTtl, ShouldEqual, 5) + So(cfg.AuthProxySyncTTL, ShouldEqual, 5) }) Convey("ldap_sync should not override ldap_sync_ttl that is different from default value", func() { @@ -272,7 +272,7 @@ func TestLoadingSettings(t *testing.T) { }) So(err, ShouldBeNil) - So(AuthProxySyncTtl, ShouldEqual, 12) + So(cfg.AuthProxySyncTTL, ShouldEqual, 12) }) })