From 58dbf96a12d75f58b46d6956a030909ff1134e53 Mon Sep 17 00:00:00 2001 From: Arve Knudsen Date: Thu, 3 Dec 2020 08:28:54 +0100 Subject: [PATCH] Middleware: Rewrite tests to use standard library (#29535) * middleware: Rewrite tests to use standard library Signed-off-by: Arve Knudsen --- pkg/middleware/auth_test.go | 149 ++- pkg/middleware/dashboard_redirect_test.go | 77 +- pkg/middleware/middleware.go | 17 +- pkg/middleware/middleware_basic_auth_test.go | 211 ++--- pkg/middleware/middleware_test.go | 903 ++++++++++--------- pkg/middleware/org_redirect_test.go | 94 +- pkg/middleware/quota_test.go | 259 +++--- pkg/middleware/recovery_test.go | 32 +- pkg/middleware/testing.go | 19 +- pkg/services/alerting/notifier_test.go | 145 +-- 10 files changed, 1008 insertions(+), 898 deletions(-) diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go index 6f6b243d8a9..ea96dcb9d1a 100644 --- a/pkg/middleware/auth_test.go +++ b/pkg/middleware/auth_test.go @@ -7,103 +7,100 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - . "github.com/smartystreets/goconvey/convey" ) func TestMiddlewareAuth(t *testing.T) { - Convey("Given the grafana middleware", t, func() { - reqSignIn := Auth(&AuthOptions{ReqSignedIn: true}) + reqSignIn := Auth(&AuthOptions{ReqSignedIn: true}) - middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(sc *scenarioContext) { - sc.m.Get("/secure", reqSignIn, sc.defaultHandler) + middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(sc *scenarioContext) { + sc.m.Get("/secure", reqSignIn, sc.defaultHandler) - sc.fakeReq("GET", "/secure").exec() + sc.fakeReq("GET", "/secure").exec() - Convey("Should redirect to login", func() { - So(sc.resp.Code, ShouldEqual, 302) - }) + assert.Equal(t, 302, sc.resp.Code) + }) + + middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(sc *scenarioContext) { + sc.m.Get("/api/secure", reqSignIn, sc.defaultHandler) + + sc.fakeReq("GET", "/api/secure").exec() + + assert.Equal(t, 401, sc.resp.Code) + }) + + t.Run("Anonymous auth enabled", func(t *testing.T) { + const orgID int64 = 1 + + origEnabled := setting.AnonymousEnabled + t.Cleanup(func() { + setting.AnonymousEnabled = origEnabled }) - - middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(sc *scenarioContext) { - sc.m.Get("/api/secure", reqSignIn, sc.defaultHandler) - - sc.fakeReq("GET", "/api/secure").exec() - - Convey("Should return 401", func() { - So(sc.resp.Code, ShouldEqual, 401) - }) + origName := setting.AnonymousOrgName + t.Cleanup(func() { + setting.AnonymousOrgName = origName }) + setting.AnonymousEnabled = true + setting.AnonymousOrgName = "test" - Convey("Anonymous auth enabled", func() { - origEnabled := setting.AnonymousEnabled - t.Cleanup(func() { - setting.AnonymousEnabled = origEnabled - }) - origName := setting.AnonymousOrgName - t.Cleanup(func() { - setting.AnonymousOrgName = origName - }) - setting.AnonymousEnabled = true - setting.AnonymousOrgName = "test" - + middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(sc *scenarioContext) { bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { - query.Result = &models.Org{Id: 1, Name: "test"} + query.Result = &models.Org{Id: orgID, Name: "test"} return nil }) - middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(sc *scenarioContext) { - sc.m.Get("/secure", reqSignIn, sc.defaultHandler) + sc.m.Get("/secure", reqSignIn, sc.defaultHandler) - sc.fakeReq("GET", "/secure?forceLogin=true").exec() + sc.fakeReq("GET", "/secure?forceLogin=true").exec() - Convey("Should redirect to login", func() { - So(sc.resp.Code, ShouldEqual, 302) - location, ok := sc.resp.Header()["Location"] - So(ok, ShouldBeTrue) - So(location[0], ShouldEqual, "/login") - }) - }) - - middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(sc *scenarioContext) { - sc.m.Get("/secure", reqSignIn, sc.defaultHandler) - - sc.fakeReq("GET", "/secure?orgId=1").exec() - - Convey("Should not redirect to login", func() { - So(sc.resp.Code, ShouldEqual, 200) - }) - }) - - middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(sc *scenarioContext) { - sc.m.Get("/secure", reqSignIn, sc.defaultHandler) - - sc.fakeReq("GET", "/secure?orgId=2").exec() - - Convey("Should redirect to login", func() { - So(sc.resp.Code, ShouldEqual, 302) - location, ok := sc.resp.Header()["Location"] - So(ok, ShouldBeTrue) - So(location[0], ShouldEqual, "/login") - }) - }) + assert.Equal(sc.t, 302, sc.resp.Code) + location, ok := sc.resp.Header()["Location"] + assert.True(t, ok) + assert.Equal(t, "/login", location[0]) }) - Convey("snapshot public mode or signed in", func() { - middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(sc *scenarioContext) { - sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) - sc.fakeReq("GET", "/api/snapshot").exec() - So(sc.resp.Code, ShouldEqual, 401) + middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(sc *scenarioContext) { + bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { + query.Result = &models.Org{Id: orgID, Name: "test"} + return nil }) - middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(sc *scenarioContext) { - setting.SnapshotPublicMode = true - sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) - sc.fakeReq("GET", "/api/snapshot").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) + sc.m.Get("/secure", reqSignIn, sc.defaultHandler) + + sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec() + + assert.Equal(sc.t, 200, sc.resp.Code) }) + + middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(sc *scenarioContext) { + bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { + query.Result = &models.Org{Id: orgID, Name: "test"} + return nil + }) + + sc.m.Get("/secure", reqSignIn, sc.defaultHandler) + + sc.fakeReq("GET", "/secure?orgId=2").exec() + + assert.Equal(sc.t, 302, sc.resp.Code) + location, ok := sc.resp.Header()["Location"] + assert.True(sc.t, ok) + assert.Equal(sc.t, "/login", location[0]) + }) + }) + + middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(sc *scenarioContext) { + sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) + sc.fakeReq("GET", "/api/snapshot").exec() + assert.Equal(sc.t, 401, sc.resp.Code) + }) + + middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(sc *scenarioContext) { + setting.SnapshotPublicMode = true + sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) + sc.fakeReq("GET", "/api/snapshot").exec() + assert.Equal(sc.t, 200, sc.resp.Code) }) } diff --git a/pkg/middleware/dashboard_redirect_test.go b/pkg/middleware/dashboard_redirect_test.go index 5730bf15d14..7efcbd801c1 100644 --- a/pkg/middleware/dashboard_redirect_test.go +++ b/pkg/middleware/dashboard_redirect_test.go @@ -8,6 +8,8 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/util" . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMiddlewareDashboardRedirect(t *testing.T) { @@ -22,62 +24,57 @@ func TestMiddlewareDashboardRedirect(t *testing.T) { fakeDash.HasAcl = false fakeDash.Uid = util.GenerateShortUID() - bus.AddHandler("test", func(query *models.GetDashboardQuery) error { - query.Result = fakeDash - return nil - }) - middlewareScenario(t, "GET dashboard by legacy url", func(sc *scenarioContext) { + bus.AddHandler("test", func(query *models.GetDashboardQuery) error { + query.Result = fakeDash + return nil + }) + sc.m.Get("/dashboard/db/:slug", redirectFromLegacyDashboardUrl, sc.defaultHandler) sc.fakeReqWithParams("GET", "/dashboard/db/dash?orgId=1&panelId=2", map[string]string{}).exec() - Convey("Should redirect to new dashboard url with a 301 Moved Permanently", func() { - So(sc.resp.Code, ShouldEqual, 301) - resp := sc.resp.Result() - defer resp.Body.Close() - redirectURL, err := resp.Location() - So(err, ShouldBeNil) - So(redirectURL.Path, ShouldEqual, models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug)) - So(len(redirectURL.Query()), ShouldEqual, 2) - }) + assert.Equal(t, 301, sc.resp.Code) + resp := sc.resp.Result() + resp.Body.Close() + redirectURL, err := resp.Location() + require.NoError(t, err) + assert.Equal(t, models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug), redirectURL.Path) + assert.Equal(t, 2, len(redirectURL.Query())) }) middlewareScenario(t, "GET dashboard solo by legacy url", func(sc *scenarioContext) { + bus.AddHandler("test", func(query *models.GetDashboardQuery) error { + query.Result = fakeDash + return nil + }) + sc.m.Get("/dashboard-solo/db/:slug", redirectFromLegacyDashboardSoloUrl, sc.defaultHandler) sc.fakeReqWithParams("GET", "/dashboard-solo/db/dash?orgId=1&panelId=2", map[string]string{}).exec() - Convey("Should redirect to new dashboard url with a 301 Moved Permanently", func() { - So(sc.resp.Code, ShouldEqual, 301) - resp := sc.resp.Result() - defer resp.Body.Close() - redirectURL, err := resp.Location() - So(err, ShouldBeNil) - expectedURL := models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug) - expectedURL = strings.Replace(expectedURL, "/d/", "/d-solo/", 1) - So(redirectURL.Path, ShouldEqual, expectedURL) - So(len(redirectURL.Query()), ShouldEqual, 2) - }) + assert.Equal(t, 301, sc.resp.Code) + resp := sc.resp.Result() + resp.Body.Close() + redirectURL, err := resp.Location() + require.NoError(t, err) + expectedURL := models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug) + expectedURL = strings.Replace(expectedURL, "/d/", "/d-solo/", 1) + assert.Equal(t, expectedURL, redirectURL.Path) + assert.Equal(t, 2, len(redirectURL.Query())) }) }) - Convey("Given the dashboard legacy edit panel middleware", t, func() { - bus.ClearBusHandlers() + middlewareScenario(t, "GET dashboard by legacy edit url", func(sc *scenarioContext) { + sc.m.Get("/d/:uid/:slug", RedirectFromLegacyPanelEditURL(), sc.defaultHandler) - middlewareScenario(t, "GET dashboard by legacy edit url", func(sc *scenarioContext) { - sc.m.Get("/d/:uid/:slug", RedirectFromLegacyPanelEditURL(), sc.defaultHandler) + sc.fakeReqWithParams("GET", "/d/asd/dash?orgId=1&panelId=12&fullscreen&edit", map[string]string{}).exec() - sc.fakeReqWithParams("GET", "/d/asd/dash?orgId=1&panelId=12&fullscreen&edit", map[string]string{}).exec() - - Convey("Should redirect to new dashboard edit url with a 301 Moved Permanently", func() { - So(sc.resp.Code, ShouldEqual, 301) - resp := sc.resp.Result() - defer resp.Body.Close() - redirectURL, err := resp.Location() - So(err, ShouldBeNil) - So(redirectURL.String(), ShouldEqual, "/d/asd/d/asd/dash?editPanel=12&orgId=1") - }) - }) + assert.Equal(t, 301, sc.resp.Code) + resp := sc.resp.Result() + resp.Body.Close() + redirectURL, err := resp.Location() + require.NoError(t, err) + assert.Equal(t, "/d/asd/d/asd/dash?editPanel=12&orgId=1", redirectURL.String()) }) } diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 6d5dd32fc26..00531f22823 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -54,10 +54,13 @@ func GetContextHandler( Logger: log.New("context"), } - orgId := int64(0) - orgIdHeader := ctx.Req.Header.Get("X-Grafana-Org-Id") - if orgIdHeader != "" { - orgId, _ = strconv.ParseInt(orgIdHeader, 10, 64) + orgID := int64(0) + orgIDHeader := ctx.Req.Header.Get("X-Grafana-Org-Id") + if orgIDHeader != "" { + orgIDParsed, err := strconv.ParseInt(orgIDHeader, 10, 64) + if err == nil { + orgID = orgIDParsed + } } // the order in which these are tested are important @@ -68,9 +71,9 @@ func GetContextHandler( switch { case initContextWithRenderAuth(ctx, renderService): case initContextWithApiKey(ctx): - case initContextWithBasicAuth(ctx, orgId): - case initContextWithAuthProxy(remoteCache, ctx, orgId): - case initContextWithToken(ats, ctx, orgId): + case initContextWithBasicAuth(ctx, orgID): + case initContextWithAuthProxy(remoteCache, ctx, orgID): + case initContextWithToken(ats, ctx, orgID): case initContextWithAnonymousUser(ctx): } diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go index 1535af0f5ea..17386afcc56 100644 --- a/pkg/middleware/middleware_basic_auth_test.go +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -4,149 +4,136 @@ import ( "encoding/json" "testing" - . "github.com/smartystreets/goconvey/convey" - "github.com/grafana/grafana/pkg/bus" authLogin "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMiddlewareBasicAuth(t *testing.T) { - Convey("Given the basic auth", t, func() { - var oldBasicAuthEnabled = setting.BasicAuthEnabled - var oldDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection - var id int64 = 12 + var origBasicAuthEnabled = setting.BasicAuthEnabled + var origDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection + t.Cleanup(func() { + setting.BasicAuthEnabled = origBasicAuthEnabled + setting.DisableBruteForceLoginProtection = origDisableBruteForceLoginProtection + }) + setting.BasicAuthEnabled = true + setting.DisableBruteForceLoginProtection = true - Convey("Setup", func() { - setting.BasicAuthEnabled = true - setting.DisableBruteForceLoginProtection = true - bus.ClearBusHandlers() + bus.ClearBusHandlers() + + const id int64 = 12 + + middlewareScenario(t, "Valid API key", func(sc *scenarioContext) { + const orgID int64 = 2 + keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") + require.NoError(t, err) + + bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { + query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} + return nil }) - middlewareScenario(t, "Valid API key", func(sc *scenarioContext) { - var orgID int64 = 2 - keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") - So(err, ShouldBeNil) + authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9") + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) + assert.Equal(t, 200, sc.resp.Code) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, orgID, sc.context.OrgId) + assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) + }) - authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9") - sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() + middlewareScenario(t, "Handle auth", func(sc *scenarioContext) { + const password = "MyPass" + const salt = "Salt" + const orgID int64 = 2 - Convey("Should return 200", func() { - So(sc.resp.Code, ShouldEqual, 200) - }) + t.Cleanup(bus.ClearBusHandlers) - Convey("Should init middleware context", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.OrgId, ShouldEqual, orgID) - So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR) - }) + bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error { + encoded, err := util.EncodePassword(password, salt) + if err != nil { + return err + } + query.User = &models.User{ + Password: encoded, + Salt: salt, + } + return nil }) - middlewareScenario(t, "Handle auth", func(sc *scenarioContext) { - var password = "MyPass" - var salt = "Salt" - var orgID int64 = 2 - - bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error { - encoded, err := util.EncodePassword(password, salt) - if err != nil { - return err - } - query.User = &models.User{ - Password: encoded, - Salt: salt, - } - return nil - }) - - bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} - return nil - }) - - authHeader := util.GetBasicAuthHeader("myUser", password) - sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - - Convey("Should init middleware context with users", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.OrgId, ShouldEqual, orgID) - So(sc.context.UserId, ShouldEqual, id) - }) - - bus.ClearBusHandlers() + bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} + return nil }) - middlewareScenario(t, "Auth sequence", func(sc *scenarioContext) { - var password = "MyPass" - var salt = "Salt" + authHeader := util.GetBasicAuthHeader("myUser", password) + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - authLogin.Init() + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, orgID, sc.context.OrgId) + assert.Equal(t, id, sc.context.UserId) + }) - bus.AddHandler("user-query", func(query *models.GetUserByLoginQuery) error { - encoded, err := util.EncodePassword(password, salt) - if err != nil { - return err - } - query.Result = &models.User{ - Password: encoded, - Id: id, - Salt: salt, - } - return nil - }) + middlewareScenario(t, "Auth sequence", func(sc *scenarioContext) { + const password = "MyPass" + const salt = "Salt" - bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{UserId: query.UserId} - return nil - }) + authLogin.Init() - authHeader := util.GetBasicAuthHeader("myUser", password) - sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - - Convey("Should init middleware context with user", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.UserId, ShouldEqual, id) - }) + bus.AddHandler("user-query", func(query *models.GetUserByLoginQuery) error { + encoded, err := util.EncodePassword(password, salt) + if err != nil { + return err + } + query.Result = &models.User{ + Password: encoded, + Id: id, + Salt: salt, + } + return nil }) - middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) { - sc.fakeReq("GET", "/") - sc.req.SetBasicAuth("user", "password") - sc.exec() - - err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - So(err, ShouldNotBeNil) - - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) + bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{UserId: query.UserId} + return nil }) - middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) { - bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error { - return nil - }) + authHeader := util.GetBasicAuthHeader("myUser", password) + sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() - sc.fakeReq("GET", "/") - sc.req.SetBasicAuth("killa", "gorilla") - sc.exec() + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, id, sc.context.UserId) + }) - err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - So(err, ShouldNotBeNil) + middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) { + sc.fakeReq("GET", "/") + sc.req.SetBasicAuth("user", "password") + sc.exec() - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword) + err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) + require.Error(t, err) + + assert.Equal(t, 401, sc.resp.Code) + assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) + }) + + middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) { + bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error { + return nil }) - Convey("Destroy", func() { - setting.BasicAuthEnabled = oldBasicAuthEnabled - setting.DisableBruteForceLoginProtection = oldDisableBruteForceLoginProtection - }) + sc.fakeReq("GET", "/") + sc.req.SetBasicAuth("killa", "gorilla") + sc.exec() + + err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) + require.Error(t, err) + + assert.Equal(t, 401, sc.resp.Code) + assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) }) } diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index f78906485c8..69c44527890 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/macaron.v1" @@ -45,226 +44,245 @@ func resetGetTime() { } func TestMiddleWareSecurityHeaders(t *testing.T) { + origErrTemplateName := setting.ErrTemplateName + t.Cleanup(func() { + setting.ErrTemplateName = origErrTemplateName + }) setting.ErrTemplateName = errorTemplate - Convey("Given the grafana middleware", t, func() { - middlewareScenario(t, "middleware should get correct x-xss-protection header", func(sc *scenarioContext) { - setting.XSSProtectionHeader = true - sc.fakeReq("GET", "/api/").exec() - So(sc.resp.Header().Get("X-XSS-Protection"), ShouldEqual, "1; mode=block") + middlewareScenario(t, "middleware should get correct x-xss-protection header", func(sc *scenarioContext) { + origXSSProtectionHeader := setting.XSSProtectionHeader + t.Cleanup(func() { + setting.XSSProtectionHeader = origXSSProtectionHeader }) + setting.XSSProtectionHeader = true + sc.fakeReq("GET", "/api/").exec() + assert.Equal(t, "1; mode=block", sc.resp.Header().Get("X-XSS-Protection")) + }) - middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(sc *scenarioContext) { - setting.XSSProtectionHeader = false - sc.fakeReq("GET", "/api/").exec() - So(sc.resp.Header().Get("X-XSS-Protection"), ShouldBeEmpty) + middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(sc *scenarioContext) { + origXSSProtectionHeader := setting.XSSProtectionHeader + t.Cleanup(func() { + setting.XSSProtectionHeader = origXSSProtectionHeader }) + setting.XSSProtectionHeader = false + sc.fakeReq("GET", "/api/").exec() + assert.Empty(t, sc.resp.Header().Get("X-XSS-Protection")) + }) - middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(sc *scenarioContext) { - setting.StrictTransportSecurity = true - setting.Protocol = setting.HTTPSScheme - setting.StrictTransportSecurityMaxAge = 64000 - sc.fakeReq("GET", "/api/").exec() - So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000") - setting.StrictTransportSecurityPreload = true - sc.fakeReq("GET", "/api/").exec() - So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000; preload") - setting.StrictTransportSecuritySubDomains = true - sc.fakeReq("GET", "/api/").exec() - So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000; preload; includeSubDomains") + middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(sc *scenarioContext) { + origStrictTransportSecurity := setting.StrictTransportSecurity + origProtocol := setting.Protocol + origStrictTransportSecurityMaxAge := setting.StrictTransportSecurityMaxAge + t.Cleanup(func() { + setting.StrictTransportSecurity = origStrictTransportSecurity + setting.Protocol = origProtocol + setting.StrictTransportSecurityMaxAge = origStrictTransportSecurityMaxAge }) + setting.StrictTransportSecurity = true + setting.Protocol = setting.HTTPSScheme + setting.StrictTransportSecurityMaxAge = 64000 + + sc.fakeReq("GET", "/api/").exec() + assert.Equal(t, "max-age=64000", sc.resp.Header().Get("Strict-Transport-Security")) + setting.StrictTransportSecurityPreload = true + sc.fakeReq("GET", "/api/").exec() + assert.Equal(t, "max-age=64000; preload", sc.resp.Header().Get("Strict-Transport-Security")) + setting.StrictTransportSecuritySubDomains = true + sc.fakeReq("GET", "/api/").exec() + assert.Equal(t, "max-age=64000; preload; includeSubDomains", sc.resp.Header().Get("Strict-Transport-Security")) }) } func TestMiddlewareContext(t *testing.T) { + origErrTemplateName := setting.ErrTemplateName + t.Cleanup(func() { + setting.ErrTemplateName = origErrTemplateName + }) setting.ErrTemplateName = errorTemplate - Convey("Given the grafana middleware", t, func() { - middlewareScenario(t, "middleware should add context to injector", func(sc *scenarioContext) { - sc.fakeReq("GET", "/").exec() - So(sc.context, ShouldNotBeNil) - }) + middlewareScenario(t, "middleware should add context to injector", func(sc *scenarioContext) { + sc.fakeReq("GET", "/").exec() + assert.NotNil(t, sc.context) + }) - middlewareScenario(t, "Default middleware should allow get request", func(sc *scenarioContext) { - sc.fakeReq("GET", "/").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) + middlewareScenario(t, "Default middleware should allow get request", func(sc *scenarioContext) { + sc.fakeReq("GET", "/").exec() + assert.Equal(t, 200, sc.resp.Code) + }) - middlewareScenario(t, "middleware should add Cache-Control header for requests to API", func(sc *scenarioContext) { - sc.fakeReq("GET", "/api/search").exec() - So(sc.resp.Header().Get("Cache-Control"), ShouldEqual, "no-cache") - So(sc.resp.Header().Get("Pragma"), ShouldEqual, "no-cache") - So(sc.resp.Header().Get("Expires"), ShouldEqual, "-1") - }) + middlewareScenario(t, "middleware should add Cache-Control header for requests to API", func(sc *scenarioContext) { + sc.fakeReq("GET", "/api/search").exec() + assert.Equal(t, "no-cache", sc.resp.Header().Get("Cache-Control")) + assert.Equal(t, "no-cache", sc.resp.Header().Get("Pragma")) + assert.Equal(t, "-1", sc.resp.Header().Get("Expires")) + }) - middlewareScenario(t, "middleware should not add Cache-Control header for requests to datasource proxy API", func(sc *scenarioContext) { - sc.fakeReq("GET", "/api/datasources/proxy/1/test").exec() - So(sc.resp.Header().Get("Cache-Control"), ShouldBeEmpty) - So(sc.resp.Header().Get("Pragma"), ShouldBeEmpty) - So(sc.resp.Header().Get("Expires"), ShouldBeEmpty) - }) + middlewareScenario(t, "middleware should not add Cache-Control header for requests to datasource proxy API", func(sc *scenarioContext) { + sc.fakeReq("GET", "/api/datasources/proxy/1/test").exec() + assert.Empty(t, sc.resp.Header().Get("Cache-Control")) + assert.Empty(t, sc.resp.Header().Get("Pragma")) + assert.Empty(t, sc.resp.Header().Get("Expires")) + }) - middlewareScenario(t, "middleware should add Cache-Control header for requests with html response", func(sc *scenarioContext) { - sc.handler(func(c *models.ReqContext) { - data := &dtos.IndexViewData{ - User: &dtos.CurrentUser{}, - Settings: map[string]interface{}{}, - NavTree: []*dtos.NavLink{}, - } - c.HTML(200, "index-template", data) - }) - sc.fakeReq("GET", "/").exec() - So(sc.resp.Code, ShouldEqual, 200) - So(sc.resp.Header().Get("Cache-Control"), ShouldEqual, "no-cache") - So(sc.resp.Header().Get("Pragma"), ShouldEqual, "no-cache") - So(sc.resp.Header().Get("Expires"), ShouldEqual, "-1") - }) - - middlewareScenario(t, "middleware should add X-Frame-Options header with deny for request when not allowing embedding", func(sc *scenarioContext) { - sc.fakeReq("GET", "/api/search").exec() - So(sc.resp.Header().Get("X-Frame-Options"), ShouldEqual, "deny") - }) - - middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func(sc *scenarioContext) { - setting.AllowEmbedding = true - sc.fakeReq("GET", "/api/search").exec() - So(sc.resp.Header().Get("X-Frame-Options"), ShouldBeEmpty) - }) - - middlewareScenario(t, "Invalid api key", func(sc *scenarioContext) { - sc.apiKey = "invalid_key_test" - sc.fakeReq("GET", "/").exec() - - Convey("Should not init session", func() { - So(sc.resp.Header().Get("Set-Cookie"), ShouldBeEmpty) - }) - - Convey("Should return 401", func() { - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidAPIKey) - }) - }) - - middlewareScenario(t, "Valid api key", func(sc *scenarioContext) { - keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") - So(err, ShouldBeNil) - - bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) - - sc.fakeReq("GET", "/").withValidApiKey().exec() - - Convey("Should return 200", func() { - So(sc.resp.Code, ShouldEqual, 200) - }) - - Convey("Should init middleware context", func() { - So(sc.context.IsSignedIn, ShouldEqual, true) - So(sc.context.OrgId, ShouldEqual, 12) - So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR) - }) - }) - - middlewareScenario(t, "Valid api key, but does not match db hash", func(sc *scenarioContext) { - keyhash := "Something_not_matching" - - bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) - - sc.fakeReq("GET", "/").withValidApiKey().exec() - - Convey("Should return api key invalid", func() { - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, errStringInvalidAPIKey) - }) - }) - - middlewareScenario(t, "Valid api key, but expired", func(sc *scenarioContext) { - mockGetTime() - defer resetGetTime() - - keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") - So(err, ShouldBeNil) - - bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { - // api key expired one second before - expires := getTime().Add(-1 * time.Second).Unix() - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, - Expires: &expires} - return nil - }) - - sc.fakeReq("GET", "/").withValidApiKey().exec() - - Convey("Should return 401", func() { - So(sc.resp.Code, ShouldEqual, 401) - So(sc.respJson["message"], ShouldEqual, "Expired API key") - }) - }) - - middlewareScenario(t, "Non-expired auth token in cookie which not are being rotated", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} - return nil - }) - - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ - UserId: 12, - UnhashedToken: unhashedToken, - }, nil + middlewareScenario(t, "middleware should add Cache-Control header for requests with html response", func(sc *scenarioContext) { + sc.handler(func(c *models.ReqContext) { + data := &dtos.IndexViewData{ + User: &dtos.CurrentUser{}, + Settings: map[string]interface{}{}, + NavTree: []*dtos.NavLink{}, } + c.HTML(200, "index-template", data) + }) + sc.fakeReq("GET", "/").exec() + assert.Equal(t, 200, sc.resp.Code) + assert.Equal(t, "no-cache", sc.resp.Header().Get("Cache-Control")) + assert.Equal(t, "no-cache", sc.resp.Header().Get("Pragma")) + assert.Equal(t, "-1", sc.resp.Header().Get("Expires")) + }) - sc.fakeReq("GET", "/").exec() + middlewareScenario(t, "middleware should add X-Frame-Options header with deny for request when not allowing embedding", func(sc *scenarioContext) { + sc.fakeReq("GET", "/api/search").exec() + assert.Equal(t, "deny", sc.resp.Header().Get("X-Frame-Options")) + }) - Convey("Should init context with user info", func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 12) - So(sc.context.UserToken.UserId, ShouldEqual, 12) - So(sc.context.UserToken.UnhashedToken, ShouldEqual, "token") - }) + middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func(sc *scenarioContext) { + origAllowEmbedding := setting.AllowEmbedding + t.Cleanup(func() { + setting.AllowEmbedding = origAllowEmbedding + }) + setting.AllowEmbedding = true + sc.fakeReq("GET", "/api/search").exec() + assert.Empty(t, sc.resp.Header().Get("X-Frame-Options")) + }) - Convey("Should not set cookie", func() { - So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, "") - }) + middlewareScenario(t, "Invalid api key", func(sc *scenarioContext) { + sc.apiKey = "invalid_key_test" + sc.fakeReq("GET", "/").exec() + + assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) + assert.Equal(t, 401, sc.resp.Code) + assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) + }) + + middlewareScenario(t, "Valid api key", func(sc *scenarioContext) { + const orgID int64 = 12 + keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") + require.NoError(t, err) + + bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { + query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} + return nil }) - middlewareScenario(t, "Non-expired auth token in cookie which are being rotated", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") + sc.fakeReq("GET", "/").withValidApiKey().exec() - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} - return nil - }) + assert.Equal(t, 200, sc.resp.Code) - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ - UserId: 12, - UnhashedToken: "", - }, nil - } + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, orgID, sc.context.OrgId) + assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) + }) - sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *models.UserToken, - clientIP net.IP, userAgent string) (bool, error) { - userToken.UnhashedToken = "rotated" - return true, nil - } + middlewareScenario(t, "Valid api key, but does not match db hash", func(sc *scenarioContext) { + keyhash := "Something_not_matching" - maxAge := int(setting.LoginMaxLifetime.Seconds()) + bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { + query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} + return nil + }) + + sc.fakeReq("GET", "/").withValidApiKey().exec() + + assert.Equal(t, 401, sc.resp.Code) + assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) + }) + + middlewareScenario(t, "Valid api key, but expired", func(sc *scenarioContext) { + mockGetTime() + defer resetGetTime() + + keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") + require.NoError(t, err) + + bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { + // api key expired one second before + expires := getTime().Add(-1 * time.Second).Unix() + query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, + Expires: &expires} + return nil + }) + + sc.fakeReq("GET", "/").withValidApiKey().exec() + + assert.Equal(t, 401, sc.resp.Code) + assert.Equal(t, "Expired API key", sc.respJson["message"]) + }) + + middlewareScenario(t, "Non-expired auth token in cookie which not are being rotated", func(sc *scenarioContext) { + const userID int64 = 12 + + sc.withTokenSessionCookie("token") + + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: 2, UserId: userID} + return nil + }) + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: unhashedToken, + }, nil + } + + sc.fakeReq("GET", "/").exec() + + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, userID, sc.context.UserToken.UserId) + assert.Equal(t, "token", sc.context.UserToken.UnhashedToken) + assert.Equal(t, "", sc.resp.Header().Get("Set-Cookie")) + }) + + middlewareScenario(t, "Non-expired auth token in cookie which are being rotated", func(sc *scenarioContext) { + const userID int64 = 12 + + sc.withTokenSessionCookie("token") + + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: 2, UserId: userID} + return nil + }) + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: "", + }, nil + } + + sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *models.UserToken, + clientIP net.IP, userAgent string) (bool, error) { + userToken.UnhashedToken = "rotated" + return true, nil + } + + maxAge := int(setting.LoginMaxLifetime.Seconds()) + + sameSiteModes := []http.SameSite{ + http.SameSiteNoneMode, + http.SameSiteLaxMode, + http.SameSiteStrictMode, + } + for _, sameSiteMode := range sameSiteModes { + t.Run(fmt.Sprintf("Same site mode %d", sameSiteMode), func(t *testing.T) { + origCookieSameSiteMode := setting.CookieSameSiteMode + t.Cleanup(func() { + setting.CookieSameSiteMode = origCookieSameSiteMode + }) + setting.CookieSameSiteMode = sameSiteMode - sameSitePolicies := []http.SameSite{ - http.SameSiteNoneMode, - http.SameSiteLaxMode, - http.SameSiteStrictMode, - } - for _, sameSitePolicy := range sameSitePolicies { - setting.CookieSameSiteMode = sameSitePolicy expectedCookiePath := "/" if len(setting.AppSubUrl) > 0 { expectedCookiePath = setting.AppSubUrl @@ -276,283 +294,334 @@ func TestMiddlewareContext(t *testing.T) { HttpOnly: true, MaxAge: maxAge, Secure: setting.CookieSecure, - SameSite: sameSitePolicy, + SameSite: sameSiteMode, } sc.fakeReq("GET", "/").exec() - Convey(fmt.Sprintf("Should init context with user info and setting.SameSite=%v", sameSitePolicy), func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 12) - So(sc.context.UserToken.UserId, ShouldEqual, 12) - So(sc.context.UserToken.UnhashedToken, ShouldEqual, "rotated") - }) - - Convey(fmt.Sprintf("Should set cookie with setting.SameSite=%v", sameSitePolicy), func() { - So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, expectedCookie.String()) - }) - } - - Convey("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func() { - setting.CookieSameSiteDisabled = true - setting.CookieSameSiteMode = http.SameSiteLaxMode - expectedCookiePath := "/" - if len(setting.AppSubUrl) > 0 { - expectedCookiePath = setting.AppSubUrl - } - expectedCookie := &http.Cookie{ - Name: setting.LoginCookieName, - Value: "rotated", - Path: expectedCookiePath, - HttpOnly: true, - MaxAge: maxAge, - Secure: setting.CookieSecure, - } - - sc.fakeReq("GET", "/").exec() - So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, expectedCookie.String()) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, userID, sc.context.UserToken.UserId) + assert.Equal(t, "rotated", sc.context.UserToken.UnhashedToken) + assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie")) }) - }) + } - middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") + t.Run("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func(t *testing.T) { + origCookieSameSiteDisabled := setting.CookieSameSiteDisabled + origCookieSameSiteMode := setting.CookieSameSiteMode + t.Cleanup(func() { + setting.CookieSameSiteDisabled = origCookieSameSiteDisabled + setting.CookieSameSiteMode = origCookieSameSiteMode + }) + setting.CookieSameSiteDisabled = true + setting.CookieSameSiteMode = http.SameSiteLaxMode - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return nil, models.ErrUserTokenNotFound + expectedCookiePath := "/" + if len(setting.AppSubUrl) > 0 { + expectedCookiePath = setting.AppSubUrl + } + expectedCookie := &http.Cookie{ + Name: setting.LoginCookieName, + Value: "rotated", + Path: expectedCookiePath, + HttpOnly: true, + MaxAge: maxAge, + Secure: setting.CookieSecure, } sc.fakeReq("GET", "/").exec() + assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie")) + }) + }) - Convey("Should not init context with user info", func() { - So(sc.context.IsSignedIn, ShouldBeFalse) - So(sc.context.UserId, ShouldEqual, 0) - So(sc.context.UserToken, ShouldBeNil) - }) + middlewareScenario(t, "Invalid/expired auth token in cookie", func(sc *scenarioContext) { + sc.withTokenSessionCookie("token") + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return nil, models.ErrUserTokenNotFound + } + + sc.fakeReq("GET", "/").exec() + + assert.False(t, sc.context.IsSignedIn) + assert.Equal(t, int64(0), sc.context.UserId) + assert.Nil(t, sc.context.UserToken) + }) + + middlewareScenario(t, "When anonymous access is enabled", func(sc *scenarioContext) { + const orgID int64 = 2 + + origAnonymousEnabled := setting.AnonymousEnabled + origAnonymousOrgName := setting.AnonymousOrgName + origAnonymousOrgRole := setting.AnonymousOrgRole + t.Cleanup(func() { + setting.AnonymousEnabled = origAnonymousEnabled + setting.AnonymousOrgName = origAnonymousOrgName + setting.AnonymousOrgRole = origAnonymousOrgRole + }) + setting.AnonymousEnabled = true + setting.AnonymousOrgName = "test" + setting.AnonymousOrgRole = string(models.ROLE_EDITOR) + + bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { + assert.Equal(t, "test", query.Name) + + query.Result = &models.Org{Id: orgID, Name: "test"} + return nil }) - middlewareScenario(t, "When anonymous access is enabled", func(sc *scenarioContext) { - setting.AnonymousEnabled = true - setting.AnonymousOrgName = "test" - setting.AnonymousOrgRole = string(models.ROLE_EDITOR) + sc.fakeReq("GET", "/").exec() - bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { - So(query.Name, ShouldEqual, "test") + assert.Equal(t, int64(0), sc.context.UserId) + assert.Equal(t, orgID, sc.context.OrgId) + assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) + assert.False(t, sc.context.IsSignedIn) + }) - query.Result = &models.Org{Id: 2, Name: "test"} + t.Run("auth_proxy", func(t *testing.T) { + const userID int64 = 33 + const orgID int64 = 4 + + origAuthProxyEnabled := setting.AuthProxyEnabled + origAuthProxyWhitelist := setting.AuthProxyWhitelist + origAuthProxyAutoSignUp := setting.AuthProxyAutoSignUp + origLDAPEnabled := setting.LDAPEnabled + origAuthProxyHeaderName := setting.AuthProxyHeaderName + origAuthProxyHeaderProperty := setting.AuthProxyHeaderProperty + origAuthProxyHeaders := setting.AuthProxyHeaders + t.Cleanup(func() { + setting.AuthProxyEnabled = origAuthProxyEnabled + setting.AuthProxyWhitelist = origAuthProxyWhitelist + setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp + setting.LDAPEnabled = origLDAPEnabled + setting.AuthProxyHeaderName = origAuthProxyHeaderName + setting.AuthProxyHeaderProperty = origAuthProxyHeaderProperty + setting.AuthProxyHeaders = origAuthProxyHeaders + }) + setting.AuthProxyEnabled = true + setting.AuthProxyWhitelist = "" + setting.AuthProxyAutoSignUp = true + setting.LDAPEnabled = true + setting.AuthProxyHeaderName = "X-WEBAUTH-USER" + setting.AuthProxyHeaderProperty = "username" + setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} + + const hdrName = "markelog" + const group = "grafana-core-team" + + middlewareScenario(t, "Should not sync the user if it's in the cache", func(sc *scenarioContext) { + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId} return nil }) - sc.fakeReq("GET", "/").exec() + key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(hdrName+"-"+group)) + err := sc.remoteCacheService.Set(key, userID, 0) + require.NoError(t, err) + sc.fakeReq("GET", "/") - Convey("Should init context with org info", func() { - So(sc.context.UserId, ShouldEqual, 0) - So(sc.context.OrgId, ShouldEqual, 2) - So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR) - }) + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.Header.Set("X-WEBAUTH-GROUPS", group) + sc.exec() - Convey("context signed in should be false", func() { - So(sc.context.IsSignedIn, ShouldBeFalse) - }) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, orgID, sc.context.OrgId) }) - Convey("auth_proxy", func() { - setting.AuthProxyEnabled = true - setting.AuthProxyWhitelist = "" + middlewareScenario(t, "Should respect auto signup option", func(sc *scenarioContext) { + origLDAPEnabled = setting.LDAPEnabled + origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp + t.Cleanup(func() { + setting.LDAPEnabled = origLDAPEnabled + setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp + }) + setting.LDAPEnabled = false + setting.AuthProxyAutoSignUp = false + + var actualAuthProxyAutoSignUp *bool = nil + + bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { + actualAuthProxyAutoSignUp = &cmd.SignupAllowed + return login.ErrInvalidCredentials + }) + + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.exec() + + assert.False(t, *actualAuthProxyAutoSignUp) + assert.Equal(t, sc.resp.Code, 407) + assert.Nil(t, sc.context) + }) + + middlewareScenario(t, "Should create an user from a header", func(sc *scenarioContext) { + origLDAPEnabled = setting.LDAPEnabled + origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp + t.Cleanup(func() { + setting.LDAPEnabled = origLDAPEnabled + setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp + }) + setting.LDAPEnabled = false setting.AuthProxyAutoSignUp = true - setting.LDAPEnabled = true - setting.AuthProxyHeaderName = "X-WEBAUTH-USER" - setting.AuthProxyHeaderProperty = "username" - setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} - name := "markelog" - group := "grafana-core-team" - middlewareScenario(t, "Should not sync the user if it's in the cache", func(sc *scenarioContext) { - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 4, UserId: query.UserId} + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + if query.UserId > 0 { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} return nil - }) - - key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name+"-"+group)) - err := sc.remoteCacheService.Set(key, int64(33), 0) - So(err, ShouldBeNil) - sc.fakeReq("GET", "/") - - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.req.Header.Add("X-WEBAUTH-GROUPS", group) - sc.exec() - - Convey("Should init user via cache", func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 33) - So(sc.context.OrgId, ShouldEqual, 4) - }) + } + return models.ErrUserNotFound }) - middlewareScenario(t, "Should respect auto signup option", func(sc *scenarioContext) { - setting.LDAPEnabled = false - setting.AuthProxyAutoSignUp = false - var actualAuthProxyAutoSignUp *bool = nil - - bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { - actualAuthProxyAutoSignUp = &cmd.SignupAllowed - return login.ErrInvalidCredentials - }) - - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.exec() - - assert.False(t, *actualAuthProxyAutoSignUp) - assert.Equal(t, sc.resp.Code, 407) - assert.Nil(t, sc.context) + bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { + cmd.Result = &models.User{Id: userID} + return nil }) - middlewareScenario(t, "Should create an user from a header", func(sc *scenarioContext) { - setting.LDAPEnabled = false - setting.AuthProxyAutoSignUp = true + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.exec() - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - if query.UserId > 0 { - query.Result = &models.SignedInUser{OrgId: 4, UserId: 33} - return nil - } - return models.ErrUserNotFound - }) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, orgID, sc.context.OrgId) + }) - bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: 33} - return nil - }) + middlewareScenario(t, "Should get an existing user from header", func(sc *scenarioContext) { + const userID int64 = 12 + const orgID int64 = 2 - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.exec() + origLDAPEnabled = setting.LDAPEnabled + t.Cleanup(func() { + setting.LDAPEnabled = origLDAPEnabled + }) + setting.LDAPEnabled = false - Convey("Should create user from header info", func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 33) - So(sc.context.OrgId, ShouldEqual, 4) - }) + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} + return nil }) - middlewareScenario(t, "Should get an existing user from header", func(sc *scenarioContext) { - setting.LDAPEnabled = false - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} - return nil - }) - - bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: 12} - return nil - }) - - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.exec() - - Convey("Should init context with user info", func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 12) - So(sc.context.OrgId, ShouldEqual, 2) - }) + bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { + cmd.Result = &models.User{Id: userID} + return nil }) - middlewareScenario(t, "Should allow the request from whitelist IP", func(sc *scenarioContext) { - setting.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120" - setting.LDAPEnabled = false + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.exec() - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 4, UserId: 33} - return nil - }) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, orgID, sc.context.OrgId) + }) - bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: 33} - return nil - }) + middlewareScenario(t, "Should allow the request from whitelist IP", func(sc *scenarioContext) { + origAuthProxyWhitelist = setting.AuthProxyWhitelist + origLDAPEnabled = setting.LDAPEnabled + t.Cleanup(func() { + setting.AuthProxyWhitelist = origAuthProxyWhitelist + setting.LDAPEnabled = origLDAPEnabled + }) + setting.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120" + setting.LDAPEnabled = false - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.req.RemoteAddr = "[2001::23]:12345" - sc.exec() - - Convey("Should init context with user info", func() { - So(sc.context.IsSignedIn, ShouldBeTrue) - So(sc.context.UserId, ShouldEqual, 33) - So(sc.context.OrgId, ShouldEqual, 4) - }) + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} + return nil }) - middlewareScenario(t, "Should not allow the request from whitelist IP", func(sc *scenarioContext) { - setting.AuthProxyWhitelist = "8.8.8.8" - setting.LDAPEnabled = false - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 4, UserId: 33} - return nil - }) - - bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { - cmd.Result = &models.User{Id: 33} - return nil - }) - - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.req.RemoteAddr = "[2001::23]:12345" - sc.exec() - - Convey("Should return 407 status code", func() { - So(sc.resp.Code, ShouldEqual, 407) - So(sc.context, ShouldBeNil) - }) + bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { + cmd.Result = &models.User{Id: userID} + return nil }) - middlewareScenario(t, "Should return 407 status code if LDAP says no", func(sc *scenarioContext) { - bus.AddHandler("LDAP", func(cmd *models.UpsertUserCommand) error { - return errors.New("Do not add user") - }) + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.RemoteAddr = "[2001::23]:12345" + sc.exec() - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.exec() + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserId) + assert.Equal(t, orgID, sc.context.OrgId) + }) - Convey("Should return 407 status code", func() { - So(sc.resp.Code, ShouldEqual, 407) - So(sc.context, ShouldBeNil) - }) + middlewareScenario(t, "Should not allow the request from whitelist IP", func(sc *scenarioContext) { + origAuthProxyWhitelist = setting.AuthProxyWhitelist + origLDAPEnabled = setting.LDAPEnabled + t.Cleanup(func() { + setting.AuthProxyWhitelist = origAuthProxyWhitelist + setting.LDAPEnabled = origLDAPEnabled + }) + setting.AuthProxyWhitelist = "8.8.8.8" + setting.LDAPEnabled = false + + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} + return nil }) - middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(sc *scenarioContext) { - bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error { - return errors.New("Do not add user") - }) - - sc.fakeReq("GET", "/") - sc.req.Header.Add(setting.AuthProxyHeaderName, name) - sc.exec() - - Convey("Should return 407 status code", func() { - So(sc.resp.Code, ShouldEqual, 407) - So(sc.context, ShouldBeNil) - }) + bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { + cmd.Result = &models.User{Id: userID} + return nil }) + + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.req.RemoteAddr = "[2001::23]:12345" + sc.exec() + + assert.Equal(t, 407, sc.resp.Code) + assert.Nil(t, sc.context) + }) + + middlewareScenario(t, "Should return 407 status code if LDAP says no", func(sc *scenarioContext) { + bus.AddHandler("LDAP", func(cmd *models.UpsertUserCommand) error { + return errors.New("Do not add user") + }) + + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.exec() + + assert.Equal(t, 407, sc.resp.Code) + assert.Nil(t, sc.context) + }) + + middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(sc *scenarioContext) { + bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error { + return errors.New("Do not add user") + }) + + sc.fakeReq("GET", "/") + sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) + sc.exec() + + assert.Equal(t, 407, sc.resp.Code) + assert.Nil(t, sc.context) }) }) } func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { - Convey(desc, func() { - defer bus.ClearBusHandlers() + t.Helper() + t.Run(desc, func(t *testing.T) { + t.Cleanup(bus.ClearBusHandlers) + + origLoginCookieName := setting.LoginCookieName + origLoginMaxLifetime := setting.LoginMaxLifetime + t.Cleanup(func() { + setting.LoginCookieName = origLoginCookieName + setting.LoginMaxLifetime = origLoginMaxLifetime + }) setting.LoginCookieName = "grafana_session" var err error setting.LoginMaxLifetime, err = gtime.ParseDuration("30d") require.NoError(t, err) - sc := &scenarioContext{} + sc := &scenarioContext{t: t} viewsPath, err := filepath.Abs("../../public/views") require.NoError(t, err) @@ -590,7 +659,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { func TestDontRotateTokensOnCancelledRequests(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - reqContext, _, err := initTokenRotationTest(ctx) + reqContext, _, err := initTokenRotationTest(ctx, t) require.NoError(t, err) tryRotateCallCount := 0 @@ -612,7 +681,7 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) { } func TestTokenRotationAtEndOfRequest(t *testing.T) { - reqContext, rr, err := initTokenRotationTest(context.Background()) + reqContext, rr, err := initTokenRotationTest(context.Background(), t) require.NoError(t, err) uts := &auth.FakeUserAuthTokenService{ @@ -643,7 +712,15 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) { assert.True(t, foundLoginCookie, "Could not find cookie") } -func initTokenRotationTest(ctx context.Context) (*models.ReqContext, *httptest.ResponseRecorder, error) { +func initTokenRotationTest(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) { + t.Helper() + + origLoginCookieName := setting.LoginCookieName + origLoginMaxLifetime := setting.LoginMaxLifetime + t.Cleanup(func() { + setting.LoginCookieName = origLoginCookieName + setting.LoginMaxLifetime = origLoginMaxLifetime + }) setting.LoginCookieName = "login_token" var err error setting.LoginMaxLifetime, err = gtime.ParseDuration("7d") diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index f0698ba30e4..8225b9970ed 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -7,61 +7,55 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" ) func TestOrgRedirectMiddleware(t *testing.T) { - Convey("Can redirect to correct org", t, func() { - middlewareScenario(t, "when setting a correct org for the user", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") - bus.AddHandler("test", func(query *models.SetUsingOrgCommand) error { - return nil - }) - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} - return nil - }) - - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ - UserId: 0, - UnhashedToken: "", - }, nil - } - - sc.m.Get("/", sc.defaultHandler) - sc.fakeReq("GET", "/?orgId=3").exec() - - Convey("change org and redirect", func() { - So(sc.resp.Code, ShouldEqual, 302) - }) + middlewareScenario(t, "when setting a correct org for the user", func(sc *scenarioContext) { + sc.withTokenSessionCookie("token") + bus.AddHandler("test", func(query *models.SetUsingOrgCommand) error { + return nil }) - middlewareScenario(t, "when setting an invalid org for user", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") - bus.AddHandler("test", func(query *models.SetUsingOrgCommand) error { - return fmt.Errorf("") - }) - - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} - return nil - }) - - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ - UserId: 12, - UnhashedToken: "", - }, nil - } - - sc.m.Get("/", sc.defaultHandler) - sc.fakeReq("GET", "/?orgId=3").exec() - - Convey("not allowed to change org", func() { - So(sc.resp.Code, ShouldEqual, 404) - }) + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} + return nil }) + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: 0, + UnhashedToken: "", + }, nil + } + + sc.m.Get("/", sc.defaultHandler) + sc.fakeReq("GET", "/?orgId=3").exec() + + assert.Equal(t, 302, sc.resp.Code) + }) + + middlewareScenario(t, "when setting an invalid org for user", func(sc *scenarioContext) { + sc.withTokenSessionCookie("token") + bus.AddHandler("test", func(query *models.SetUsingOrgCommand) error { + return fmt.Errorf("") + }) + + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} + return nil + }) + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: 12, + UnhashedToken: "", + }, nil + } + + sc.m.Get("/", sc.defaultHandler) + sc.fakeReq("GET", "/?orgId=3").exec() + + assert.Equal(t, 404, sc.resp.Code) }) } diff --git a/pkg/middleware/quota_test.go b/pkg/middleware/quota_test.go index 75135e20612..fc01ca62a3b 100644 --- a/pkg/middleware/quota_test.go +++ b/pkg/middleware/quota_test.go @@ -9,40 +9,40 @@ import ( "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/setting" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" ) func TestMiddlewareQuota(t *testing.T) { - Convey("Given the grafana quota middleware", t, func() { - setting.AnonymousEnabled = false - setting.Quota = setting.QuotaSettings{ - Enabled: true, - Org: &setting.OrgQuota{ - User: 5, - Dashboard: 5, - DataSource: 5, - ApiKey: 5, - }, - User: &setting.UserQuota{ - Org: 5, - }, - Global: &setting.GlobalQuota{ - Org: 5, - User: 5, - Dashboard: 5, - DataSource: 5, - ApiKey: 5, - Session: 5, - }, - } + setting.AnonymousEnabled = false + setting.Quota = setting.QuotaSettings{ + Enabled: true, + Org: &setting.OrgQuota{ + User: 5, + Dashboard: 5, + DataSource: 5, + ApiKey: 5, + }, + User: &setting.UserQuota{ + Org: 5, + }, + Global: &setting.GlobalQuota{ + Org: 5, + User: 5, + Dashboard: 5, + DataSource: 5, + ApiKey: 5, + Session: 5, + }, + } - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() - qs := "a.QuotaService{ - AuthTokenService: fakeAuthTokenService, - } - QuotaFn := Quota(qs) + fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + qs := "a.QuotaService{ + AuthTokenService: fakeAuthTokenService, + } + quotaFn := Quota(qs) - middlewareScenario(t, "with user not logged in", func(sc *scenarioContext) { + t.Run("With user not logged in", func(t *testing.T) { + middlewareScenario(t, "and global quota not reached", func(sc *scenarioContext) { bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error { query.Result = &models.GlobalQuotaDTO{ Target: query.Target, @@ -52,48 +52,12 @@ func TestMiddlewareQuota(t *testing.T) { return nil }) - Convey("global quota not reached", func() { - sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler) - sc.fakeReq("GET", "/user").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) - - Convey("global quota reached", func() { - setting.Quota.Global.User = 4 - sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler) - sc.fakeReq("GET", "/user").exec() - So(sc.resp.Code, ShouldEqual, 403) - }) - - Convey("global session quota not reached", func() { - setting.Quota.Global.Session = 10 - sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler) - sc.fakeReq("GET", "/user").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) - - Convey("global session quota reached", func() { - setting.Quota.Global.Session = 1 - sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler) - sc.fakeReq("GET", "/user").exec() - So(sc.resp.Code, ShouldEqual, 403) - }) + sc.m.Get("/user", quotaFn("user"), sc.defaultHandler) + sc.fakeReq("GET", "/user").exec() + assert.Equal(sc.t, 200, sc.resp.Code) }) - middlewareScenario(t, "with user logged in", func(sc *scenarioContext) { - sc.withTokenSessionCookie("token") - bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} - return nil - }) - - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ - UserId: 12, - UnhashedToken: "", - }, nil - } - + middlewareScenario(t, "and global quota reached", func(sc *scenarioContext) { bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error { query.Result = &models.GlobalQuotaDTO{ Target: query.Target, @@ -103,8 +67,20 @@ func TestMiddlewareQuota(t *testing.T) { return nil }) - bus.AddHandler("userQuota", func(query *models.GetUserQuotaByTargetQuery) error { - query.Result = &models.UserQuotaDTO{ + origUser := setting.Quota.Global.User + t.Cleanup(func() { + setting.Quota.Global.User = origUser + }) + setting.Quota.Global.User = 4 + + sc.m.Get("/user", quotaFn("user"), sc.defaultHandler) + sc.fakeReq("GET", "/user").exec() + assert.Equal(t, 403, sc.resp.Code) + }) + + middlewareScenario(t, "and global session quota not reached", func(sc *scenarioContext) { + bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error { + query.Result = &models.GlobalQuotaDTO{ Target: query.Target, Limit: query.Default, Used: 4, @@ -112,57 +88,112 @@ func TestMiddlewareQuota(t *testing.T) { return nil }) - bus.AddHandler("orgQuota", func(query *models.GetOrgQuotaByTargetQuery) error { - query.Result = &models.OrgQuotaDTO{ - Target: query.Target, - Limit: query.Default, - Used: 4, - } - return nil + origSession := setting.Quota.Global.Session + t.Cleanup(func() { + setting.Quota.Global.Session = origSession }) + setting.Quota.Global.Session = 10 - Convey("global datasource quota reached", func() { - setting.Quota.Global.DataSource = 4 - sc.m.Get("/ds", QuotaFn("data_source"), sc.defaultHandler) - sc.fakeReq("GET", "/ds").exec() - So(sc.resp.Code, ShouldEqual, 403) - }) + sc.m.Get("/user", quotaFn("session"), sc.defaultHandler) + sc.fakeReq("GET", "/user").exec() + assert.Equal(t, 200, sc.resp.Code) + }) - Convey("user Org quota not reached", func() { - setting.Quota.User.Org = 5 - sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler) - sc.fakeReq("GET", "/org").exec() - So(sc.resp.Code, ShouldEqual, 200) + middlewareScenario(t, "and global session quota reached", func(sc *scenarioContext) { + origSession := setting.Quota.Global.Session + t.Cleanup(func() { + setting.Quota.Global.Session = origSession }) + setting.Quota.Global.Session = 1 - Convey("user Org quota reached", func() { - setting.Quota.User.Org = 4 - sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler) - sc.fakeReq("GET", "/org").exec() - So(sc.resp.Code, ShouldEqual, 403) - }) + sc.m.Get("/user", quotaFn("session"), sc.defaultHandler) + sc.fakeReq("GET", "/user").exec() + assert.Equal(sc.t, 403, sc.resp.Code) + }) + }) - Convey("org dashboard quota not reached", func() { - setting.Quota.Org.Dashboard = 10 - sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler) - sc.fakeReq("GET", "/dashboard").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) + middlewareScenario(t, "with user logged in", func(sc *scenarioContext) { + sc.withTokenSessionCookie("token") + bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { + query.Result = &models.SignedInUser{OrgId: 2, UserId: 12} + return nil + }) - Convey("org dashboard quota reached", func() { - setting.Quota.Org.Dashboard = 4 - sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler) - sc.fakeReq("GET", "/dashboard").exec() - So(sc.resp.Code, ShouldEqual, 403) - }) + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: 12, + UnhashedToken: "", + }, nil + } - Convey("org dashboard quota reached but quotas disabled", func() { - setting.Quota.Org.Dashboard = 4 - setting.Quota.Enabled = false - sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler) - sc.fakeReq("GET", "/dashboard").exec() - So(sc.resp.Code, ShouldEqual, 200) - }) + bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error { + query.Result = &models.GlobalQuotaDTO{ + Target: query.Target, + Limit: query.Default, + Used: 4, + } + return nil + }) + + bus.AddHandler("userQuota", func(query *models.GetUserQuotaByTargetQuery) error { + query.Result = &models.UserQuotaDTO{ + Target: query.Target, + Limit: query.Default, + Used: 4, + } + return nil + }) + + bus.AddHandler("orgQuota", func(query *models.GetOrgQuotaByTargetQuery) error { + query.Result = &models.OrgQuotaDTO{ + Target: query.Target, + Limit: query.Default, + Used: 4, + } + return nil + }) + + t.Run("global datasource quota reached", func(t *testing.T) { + setting.Quota.Global.DataSource = 4 + sc.m.Get("/ds", quotaFn("data_source"), sc.defaultHandler) + sc.fakeReq("GET", "/ds").exec() + assert.Equal(t, 403, sc.resp.Code) + }) + + t.Run("user Org quota not reached", func(t *testing.T) { + setting.Quota.User.Org = 5 + sc.m.Get("/org", quotaFn("org"), sc.defaultHandler) + sc.fakeReq("GET", "/org").exec() + assert.Equal(t, 200, sc.resp.Code) + }) + + t.Run("user Org quota reached", func(t *testing.T) { + setting.Quota.User.Org = 4 + sc.m.Get("/org", quotaFn("org"), sc.defaultHandler) + sc.fakeReq("GET", "/org").exec() + assert.Equal(t, 403, sc.resp.Code) + }) + + t.Run("org dashboard quota not reached", func(t *testing.T) { + setting.Quota.Org.Dashboard = 10 + sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler) + sc.fakeReq("GET", "/dashboard").exec() + assert.Equal(t, 200, sc.resp.Code) + }) + + t.Run("org dashboard quota reached", func(t *testing.T) { + setting.Quota.Org.Dashboard = 4 + sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler) + sc.fakeReq("GET", "/dashboard").exec() + assert.Equal(t, 403, sc.resp.Code) + }) + + t.Run("org dashboard quota reached but quotas disabled", func(t *testing.T) { + setting.Quota.Org.Dashboard = 4 + setting.Quota.Enabled = false + sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler) + sc.fakeReq("GET", "/dashboard").exec() + assert.Equal(t, 200, sc.resp.Code) }) }) } diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 03746daa2d4..150167dc4bc 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -2,6 +2,7 @@ package middleware import ( "path/filepath" + "strings" "testing" "github.com/grafana/grafana/pkg/bus" @@ -9,52 +10,55 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/setting" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" macaron "gopkg.in/macaron.v1" ) func TestRecoveryMiddleware(t *testing.T) { setting.ErrTemplateName = "error-template" - Convey("Given an api route that panics", t, func() { + t.Run("Given an API route that panics", func(t *testing.T) { apiURL := "/api/whatever" recoveryScenario(t, "recovery middleware should return json", apiURL, func(sc *scenarioContext) { - sc.handlerFunc = PanicHandler + sc.handlerFunc = panicHandler sc.fakeReq("GET", apiURL).exec() sc.req.Header.Add("content-type", "application/json") - So(sc.resp.Code, ShouldEqual, 500) - So(sc.respJson["message"], ShouldStartWith, "Internal Server Error - Check the Grafana server logs for the detailed error message.") - So(sc.respJson["error"], ShouldStartWith, "Server Error") + assert.Equal(t, 500, sc.resp.Code) + assert.Equal(t, "Internal Server Error - Check the Grafana server logs for the detailed error message.", sc.respJson["message"]) + assert.True(t, strings.HasPrefix(sc.respJson["error"].(string), "Server Error")) }) }) - Convey("Given a non-api route that panics", t, func() { + t.Run("Given a non-API route that panics", func(t *testing.T) { apiURL := "/whatever" recoveryScenario(t, "recovery middleware should return html", apiURL, func(sc *scenarioContext) { - sc.handlerFunc = PanicHandler + sc.handlerFunc = panicHandler sc.fakeReq("GET", apiURL).exec() - So(sc.resp.Code, ShouldEqual, 500) - So(sc.resp.Header().Get("content-type"), ShouldEqual, "text/html; charset=UTF-8") - So(sc.resp.Body.String(), ShouldContainSubstring, "Grafana - Error") + assert.Equal(t, 500, sc.resp.Code) + assert.Equal(t, "text/html; charset=UTF-8", sc.resp.Header().Get("content-type")) + assert.True(t, strings.Contains(sc.resp.Body.String(), "Grafana - Error")) }) }) } -func PanicHandler(c *models.ReqContext) { +func panicHandler(c *models.ReqContext) { panic("Handler has panicked") } func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { - Convey(desc, func() { + t.Run(desc, func(t *testing.T) { defer bus.ClearBusHandlers() sc := &scenarioContext{ + t: t, url: url, } - viewsPath, _ := filepath.Abs("../../public/views") + viewsPath, err := filepath.Abs("../../public/views") + require.NoError(t, err) sc.m = macaron.New() sc.m.Use(Recovery()) diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go index 0e24cc3e7cc..21a9852790d 100644 --- a/pkg/middleware/testing.go +++ b/pkg/middleware/testing.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "testing" "gopkg.in/macaron.v1" @@ -11,10 +12,11 @@ import ( "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/setting" - "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/require" ) type scenarioContext struct { + t *testing.T m *macaron.Macaron context *models.ReqContext resp *httptest.ResponseRecorder @@ -47,15 +49,19 @@ func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioC } func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext { + sc.t.Helper() + sc.resp = httptest.NewRecorder() req, err := http.NewRequest(method, url, nil) - convey.So(err, convey.ShouldBeNil) + require.NoError(sc.t, err) sc.req = req return sc } func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext { + sc.t.Helper() + sc.resp = httptest.NewRecorder() req, err := http.NewRequest(method, url, nil) q := req.URL.Query() @@ -63,7 +69,7 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map q.Add(k, v) } req.URL.RawQuery = q.Encode() - convey.So(err, convey.ShouldBeNil) + require.NoError(sc.t, err) sc.req = req return sc @@ -75,15 +81,20 @@ func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext { } func (sc *scenarioContext) exec() { + sc.t.Helper() + if sc.apiKey != "" { + sc.t.Logf(`Adding header "Authorization: Bearer %s"`, sc.apiKey) sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey) } if sc.authHeader != "" { + sc.t.Logf(`Adding header "Authorization: %s"`, sc.authHeader) sc.req.Header.Add("Authorization", sc.authHeader) } if sc.tokenSessionCookie != "" { + sc.t.Log(`Adding cookie`, "name", setting.LoginCookieName, "value", sc.tokenSessionCookie) sc.req.AddCookie(&http.Cookie{ Name: setting.LoginCookieName, Value: sc.tokenSessionCookie, @@ -94,7 +105,7 @@ func (sc *scenarioContext) exec() { if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" { err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson) - convey.So(err, convey.ShouldBeNil) + require.NoError(sc.t, err) } } diff --git a/pkg/services/alerting/notifier_test.go b/pkg/services/alerting/notifier_test.go index 0b2f2b779bc..fcb49e3abc0 100644 --- a/pkg/services/alerting/notifier_test.go +++ b/pkg/services/alerting/notifier_test.go @@ -30,94 +30,102 @@ func TestNotificationService(t *testing.T) { } evalCtx := NewEvalContext(context.Background(), testRule) - notificationServiceScenario(t, "Given alert rule with upload image enabled should render and upload image and send notification", evalCtx, true, func(scenarioCtx *scenarioContext) { - err := scenarioCtx.notificationService.SendIfNeeded(evalCtx) - require.NoError(t, err) + notificationServiceScenario(t, "Given alert rule with upload image enabled should render and upload image and send notification", + evalCtx, true, func(sc *scenarioContext) { + err := sc.notificationService.SendIfNeeded(evalCtx) + require.NoError(sc.t, err) - require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but wasn't") - require.Equalf(t, 1, scenarioCtx.imageUploadCount, "expected image to be uploaded, but wasn't") - require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") - }) + require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but wasn't") + require.Equalf(sc.t, 1, sc.imageUploadCount, "expected image to be uploaded, but wasn't") + require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") + }) - notificationServiceScenario(t, "Given alert rule with upload image enabled but no renderer available should render and upload unavailable image and send notification", evalCtx, true, func(scenarioCtx *scenarioContext) { - scenarioCtx.rendererAvailable = false - err := scenarioCtx.notificationService.SendIfNeeded(evalCtx) - require.NoError(t, err) + notificationServiceScenario(t, + "Given alert rule with upload image enabled but no renderer available should render and upload unavailable image and send notification", + evalCtx, true, func(sc *scenarioContext) { + sc.rendererAvailable = false + err := sc.notificationService.SendIfNeeded(evalCtx) + require.NoError(sc.t, err) - require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but it wasn't") - require.Equalf(t, 1, scenarioCtx.imageUploadCount, "expected image to be uploaded, but it wasn't") - require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") - }) + require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but it wasn't") + require.Equalf(sc.t, 1, sc.imageUploadCount, "expected image to be uploaded, but it wasn't") + require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") + }) - notificationServiceScenario(t, "Given alert rule with upload image disabled should not render and upload image, but send notification", evalCtx, false, func(scenarioCtx *scenarioContext) { - err := scenarioCtx.notificationService.SendIfNeeded(evalCtx) - require.NoError(t, err) + notificationServiceScenario( + t, "Given alert rule with upload image disabled should not render and upload image, but send notification", + evalCtx, false, func(sc *scenarioContext) { + err := sc.notificationService.SendIfNeeded(evalCtx) + require.NoError(t, err) - require.Equalf(t, 0, scenarioCtx.renderCount, "expected render not to be called, but it was") - require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was") - require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") - }) + require.Equalf(sc.t, 0, sc.renderCount, "expected render not to be called, but it was") + require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was") + require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") + }) - notificationServiceScenario(t, "Given alert rule with upload image enabled and render times out should send notification", evalCtx, true, func(scenarioCtx *scenarioContext) { - setting.AlertingNotificationTimeout = 200 * time.Millisecond - scenarioCtx.renderProvider = func(ctx context.Context, opts rendering.Opts) (*rendering.RenderResult, error) { - wait := make(chan bool) + notificationServiceScenario(t, "Given alert rule with upload image enabled and render times out should send notification", + evalCtx, true, func(sc *scenarioContext) { + setting.AlertingNotificationTimeout = 200 * time.Millisecond + sc.renderProvider = func(ctx context.Context, opts rendering.Opts) (*rendering.RenderResult, error) { + wait := make(chan bool) - go func() { - time.Sleep(1 * time.Second) - wait <- true - }() + go func() { + time.Sleep(1 * time.Second) + wait <- true + }() - select { - case <-ctx.Done(): - if err := ctx.Err(); err != nil { - return nil, err + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return nil, err + } + break + case <-wait: } - break - case <-wait: + + return nil, nil } + err := sc.notificationService.SendIfNeeded(evalCtx) + require.NoError(sc.t, err) - return nil, nil - } - err := scenarioCtx.notificationService.SendIfNeeded(evalCtx) - require.NoError(t, err) + require.Equalf(sc.t, 0, sc.renderCount, "expected render not to be called, but it was") + require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was") + require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") + }) - require.Equalf(t, 0, scenarioCtx.renderCount, "expected render not to be called, but it was") - require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was") - require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") - }) + notificationServiceScenario(t, "Given alert rule with upload image enabled and upload times out should send notification", + evalCtx, true, func(sc *scenarioContext) { + setting.AlertingNotificationTimeout = 200 * time.Millisecond + sc.uploadProvider = func(ctx context.Context, path string) (string, error) { + wait := make(chan bool) - notificationServiceScenario(t, "Given alert rule with upload image enabled and upload times out should send notification", evalCtx, true, func(scenarioCtx *scenarioContext) { - setting.AlertingNotificationTimeout = 200 * time.Millisecond - scenarioCtx.uploadProvider = func(ctx context.Context, path string) (string, error) { - wait := make(chan bool) + go func() { + time.Sleep(1 * time.Second) + wait <- true + }() - go func() { - time.Sleep(1 * time.Second) - wait <- true - }() - - select { - case <-ctx.Done(): - if err := ctx.Err(); err != nil { - return "", err + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return "", err + } + break + case <-wait: } - break - case <-wait: + + return "", nil } + err := sc.notificationService.SendIfNeeded(evalCtx) + require.NoError(sc.t, err) - return "", nil - } - err := scenarioCtx.notificationService.SendIfNeeded(evalCtx) - require.NoError(t, err) - - require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but wasn't") - require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was") - require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") - }) + require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but wasn't") + require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was") + require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't") + }) } type scenarioContext struct { + t *testing.T evalCtx *EvalContext notificationService *notificationService imageUploadCount int @@ -175,6 +183,7 @@ func notificationServiceScenario(t *testing.T, name string, evalCtx *EvalContext setting.AlertingNotificationTimeout = 30 * time.Second scenarioCtx := &scenarioContext{ + t: t, evalCtx: evalCtx, }