Middleware: Rewrite tests to use standard library (#29535)

* middleware: Rewrite tests to use standard library

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>
This commit is contained in:
Arve Knudsen
2020-12-03 08:28:54 +01:00
committed by GitHub
parent 0b6434d0e8
commit 58dbf96a12
10 changed files with 1008 additions and 898 deletions

View File

@ -7,13 +7,11 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
. "github.com/smartystreets/goconvey/convey"
)
func TestMiddlewareAuth(t *testing.T) {
Convey("Given the grafana middleware", t, func() {
reqSignIn := Auth(&AuthOptions{ReqSignedIn: true})
middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(sc *scenarioContext) {
@ -21,9 +19,7 @@ func TestMiddlewareAuth(t *testing.T) {
sc.fakeReq("GET", "/secure").exec()
Convey("Should redirect to login", func() {
So(sc.resp.Code, ShouldEqual, 302)
})
assert.Equal(t, 302, sc.resp.Code)
})
middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(sc *scenarioContext) {
@ -31,12 +27,12 @@ func TestMiddlewareAuth(t *testing.T) {
sc.fakeReq("GET", "/api/secure").exec()
Convey("Should return 401", func() {
So(sc.resp.Code, ShouldEqual, 401)
})
assert.Equal(t, 401, sc.resp.Code)
})
Convey("Anonymous auth enabled", func() {
t.Run("Anonymous auth enabled", func(t *testing.T) {
const orgID int64 = 1
origEnabled := setting.AnonymousEnabled
t.Cleanup(func() {
setting.AnonymousEnabled = origEnabled
@ -48,62 +44,63 @@ func TestMiddlewareAuth(t *testing.T) {
setting.AnonymousEnabled = true
setting.AnonymousOrgName = "test"
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: 1, Name: "test"}
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(sc *scenarioContext) {
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?forceLogin=true").exec()
Convey("Should redirect to login", func() {
So(sc.resp.Code, ShouldEqual, 302)
assert.Equal(sc.t, 302, sc.resp.Code)
location, ok := sc.resp.Header()["Location"]
So(ok, ShouldBeTrue)
So(location[0], ShouldEqual, "/login")
})
assert.True(t, ok)
assert.Equal(t, "/login", location[0])
})
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?orgId=1").exec()
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec()
Convey("Should not redirect to login", func() {
So(sc.resp.Code, ShouldEqual, 200)
})
assert.Equal(sc.t, 200, sc.resp.Code)
})
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?orgId=2").exec()
Convey("Should redirect to login", func() {
So(sc.resp.Code, ShouldEqual, 302)
assert.Equal(sc.t, 302, sc.resp.Code)
location, ok := sc.resp.Header()["Location"]
So(ok, ShouldBeTrue)
So(location[0], ShouldEqual, "/login")
})
assert.True(sc.t, ok)
assert.Equal(sc.t, "/login", location[0])
})
})
Convey("snapshot public mode or signed in", func() {
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(sc *scenarioContext) {
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
So(sc.resp.Code, ShouldEqual, 401)
assert.Equal(sc.t, 401, sc.resp.Code)
})
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(sc *scenarioContext) {
setting.SnapshotPublicMode = true
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
})
assert.Equal(sc.t, 200, sc.resp.Code)
})
}

View File

@ -8,6 +8,8 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/util"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMiddlewareDashboardRedirect(t *testing.T) {
@ -22,62 +24,57 @@ func TestMiddlewareDashboardRedirect(t *testing.T) {
fakeDash.HasAcl = false
fakeDash.Uid = util.GenerateShortUID()
middlewareScenario(t, "GET dashboard by legacy url", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetDashboardQuery) error {
query.Result = fakeDash
return nil
})
middlewareScenario(t, "GET dashboard by legacy url", func(sc *scenarioContext) {
sc.m.Get("/dashboard/db/:slug", redirectFromLegacyDashboardUrl, sc.defaultHandler)
sc.fakeReqWithParams("GET", "/dashboard/db/dash?orgId=1&panelId=2", map[string]string{}).exec()
Convey("Should redirect to new dashboard url with a 301 Moved Permanently", func() {
So(sc.resp.Code, ShouldEqual, 301)
assert.Equal(t, 301, sc.resp.Code)
resp := sc.resp.Result()
defer resp.Body.Close()
resp.Body.Close()
redirectURL, err := resp.Location()
So(err, ShouldBeNil)
So(redirectURL.Path, ShouldEqual, models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug))
So(len(redirectURL.Query()), ShouldEqual, 2)
})
require.NoError(t, err)
assert.Equal(t, models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug), redirectURL.Path)
assert.Equal(t, 2, len(redirectURL.Query()))
})
middlewareScenario(t, "GET dashboard solo by legacy url", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetDashboardQuery) error {
query.Result = fakeDash
return nil
})
sc.m.Get("/dashboard-solo/db/:slug", redirectFromLegacyDashboardSoloUrl, sc.defaultHandler)
sc.fakeReqWithParams("GET", "/dashboard-solo/db/dash?orgId=1&panelId=2", map[string]string{}).exec()
Convey("Should redirect to new dashboard url with a 301 Moved Permanently", func() {
So(sc.resp.Code, ShouldEqual, 301)
assert.Equal(t, 301, sc.resp.Code)
resp := sc.resp.Result()
defer resp.Body.Close()
resp.Body.Close()
redirectURL, err := resp.Location()
So(err, ShouldBeNil)
require.NoError(t, err)
expectedURL := models.GetDashboardUrl(fakeDash.Uid, fakeDash.Slug)
expectedURL = strings.Replace(expectedURL, "/d/", "/d-solo/", 1)
So(redirectURL.Path, ShouldEqual, expectedURL)
So(len(redirectURL.Query()), ShouldEqual, 2)
assert.Equal(t, expectedURL, redirectURL.Path)
assert.Equal(t, 2, len(redirectURL.Query()))
})
})
})
Convey("Given the dashboard legacy edit panel middleware", t, func() {
bus.ClearBusHandlers()
middlewareScenario(t, "GET dashboard by legacy edit url", func(sc *scenarioContext) {
sc.m.Get("/d/:uid/:slug", RedirectFromLegacyPanelEditURL(), sc.defaultHandler)
sc.fakeReqWithParams("GET", "/d/asd/dash?orgId=1&panelId=12&fullscreen&edit", map[string]string{}).exec()
Convey("Should redirect to new dashboard edit url with a 301 Moved Permanently", func() {
So(sc.resp.Code, ShouldEqual, 301)
assert.Equal(t, 301, sc.resp.Code)
resp := sc.resp.Result()
defer resp.Body.Close()
resp.Body.Close()
redirectURL, err := resp.Location()
So(err, ShouldBeNil)
So(redirectURL.String(), ShouldEqual, "/d/asd/d/asd/dash?editPanel=12&orgId=1")
})
})
require.NoError(t, err)
assert.Equal(t, "/d/asd/d/asd/dash?editPanel=12&orgId=1", redirectURL.String())
})
}

View File

@ -54,10 +54,13 @@ func GetContextHandler(
Logger: log.New("context"),
}
orgId := int64(0)
orgIdHeader := ctx.Req.Header.Get("X-Grafana-Org-Id")
if orgIdHeader != "" {
orgId, _ = strconv.ParseInt(orgIdHeader, 10, 64)
orgID := int64(0)
orgIDHeader := ctx.Req.Header.Get("X-Grafana-Org-Id")
if orgIDHeader != "" {
orgIDParsed, err := strconv.ParseInt(orgIDHeader, 10, 64)
if err == nil {
orgID = orgIDParsed
}
}
// the order in which these are tested are important
@ -68,9 +71,9 @@ func GetContextHandler(
switch {
case initContextWithRenderAuth(ctx, renderService):
case initContextWithApiKey(ctx):
case initContextWithBasicAuth(ctx, orgId):
case initContextWithAuthProxy(remoteCache, ctx, orgId):
case initContextWithToken(ats, ctx, orgId):
case initContextWithBasicAuth(ctx, orgID):
case initContextWithAuthProxy(remoteCache, ctx, orgID):
case initContextWithToken(ats, ctx, orgID):
case initContextWithAnonymousUser(ctx):
}

View File

@ -4,31 +4,33 @@ import (
"encoding/json"
"testing"
. "github.com/smartystreets/goconvey/convey"
"github.com/grafana/grafana/pkg/bus"
authLogin "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMiddlewareBasicAuth(t *testing.T) {
Convey("Given the basic auth", t, func() {
var oldBasicAuthEnabled = setting.BasicAuthEnabled
var oldDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection
var id int64 = 12
Convey("Setup", func() {
var origBasicAuthEnabled = setting.BasicAuthEnabled
var origDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection
t.Cleanup(func() {
setting.BasicAuthEnabled = origBasicAuthEnabled
setting.DisableBruteForceLoginProtection = origDisableBruteForceLoginProtection
})
setting.BasicAuthEnabled = true
setting.DisableBruteForceLoginProtection = true
bus.ClearBusHandlers()
})
const id int64 = 12
middlewareScenario(t, "Valid API key", func(sc *scenarioContext) {
var orgID int64 = 2
const orgID int64 = 2
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
require.NoError(t, err)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash}
@ -38,21 +40,18 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
Convey("Should return 200", func() {
So(sc.resp.Code, ShouldEqual, 200)
})
Convey("Should init middleware context", func() {
So(sc.context.IsSignedIn, ShouldEqual, true)
So(sc.context.OrgId, ShouldEqual, orgID)
So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR)
})
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
})
middlewareScenario(t, "Handle auth", func(sc *scenarioContext) {
var password = "MyPass"
var salt = "Salt"
var orgID int64 = 2
const password = "MyPass"
const salt = "Salt"
const orgID int64 = 2
t.Cleanup(bus.ClearBusHandlers)
bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error {
encoded, err := util.EncodePassword(password, salt)
@ -74,18 +73,14 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
Convey("Should init middleware context with users", func() {
So(sc.context.IsSignedIn, ShouldEqual, true)
So(sc.context.OrgId, ShouldEqual, orgID)
So(sc.context.UserId, ShouldEqual, id)
})
bus.ClearBusHandlers()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, id, sc.context.UserId)
})
middlewareScenario(t, "Auth sequence", func(sc *scenarioContext) {
var password = "MyPass"
var salt = "Salt"
const password = "MyPass"
const salt = "Salt"
authLogin.Init()
@ -110,10 +105,8 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
Convey("Should init middleware context with user", func() {
So(sc.context.IsSignedIn, ShouldEqual, true)
So(sc.context.UserId, ShouldEqual, id)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, id, sc.context.UserId)
})
middlewareScenario(t, "Should return error if user is not found", func(sc *scenarioContext) {
@ -122,10 +115,10 @@ func TestMiddlewareBasicAuth(t *testing.T) {
sc.exec()
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
So(err, ShouldNotBeNil)
require.Error(t, err)
So(sc.resp.Code, ShouldEqual, 401)
So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"])
})
middlewareScenario(t, "Should return error if user & password do not match", func(sc *scenarioContext) {
@ -138,15 +131,9 @@ func TestMiddlewareBasicAuth(t *testing.T) {
sc.exec()
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
So(err, ShouldNotBeNil)
require.Error(t, err)
So(sc.resp.Code, ShouldEqual, 401)
So(sc.respJson["message"], ShouldEqual, errStringInvalidUsernamePassword)
})
Convey("Destroy", func() {
setting.BasicAuthEnabled = oldBasicAuthEnabled
setting.DisableBruteForceLoginProtection = oldDisableBruteForceLoginProtection
})
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"])
})
}

View File

@ -11,7 +11,6 @@ import (
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/macaron.v1"
@ -45,63 +44,85 @@ func resetGetTime() {
}
func TestMiddleWareSecurityHeaders(t *testing.T) {
origErrTemplateName := setting.ErrTemplateName
t.Cleanup(func() {
setting.ErrTemplateName = origErrTemplateName
})
setting.ErrTemplateName = errorTemplate
Convey("Given the grafana middleware", t, func() {
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(sc *scenarioContext) {
origXSSProtectionHeader := setting.XSSProtectionHeader
t.Cleanup(func() {
setting.XSSProtectionHeader = origXSSProtectionHeader
})
setting.XSSProtectionHeader = true
sc.fakeReq("GET", "/api/").exec()
So(sc.resp.Header().Get("X-XSS-Protection"), ShouldEqual, "1; mode=block")
assert.Equal(t, "1; mode=block", sc.resp.Header().Get("X-XSS-Protection"))
})
middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(sc *scenarioContext) {
origXSSProtectionHeader := setting.XSSProtectionHeader
t.Cleanup(func() {
setting.XSSProtectionHeader = origXSSProtectionHeader
})
setting.XSSProtectionHeader = false
sc.fakeReq("GET", "/api/").exec()
So(sc.resp.Header().Get("X-XSS-Protection"), ShouldBeEmpty)
assert.Empty(t, sc.resp.Header().Get("X-XSS-Protection"))
})
middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(sc *scenarioContext) {
origStrictTransportSecurity := setting.StrictTransportSecurity
origProtocol := setting.Protocol
origStrictTransportSecurityMaxAge := setting.StrictTransportSecurityMaxAge
t.Cleanup(func() {
setting.StrictTransportSecurity = origStrictTransportSecurity
setting.Protocol = origProtocol
setting.StrictTransportSecurityMaxAge = origStrictTransportSecurityMaxAge
})
setting.StrictTransportSecurity = true
setting.Protocol = setting.HTTPSScheme
setting.StrictTransportSecurityMaxAge = 64000
sc.fakeReq("GET", "/api/").exec()
So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000")
assert.Equal(t, "max-age=64000", sc.resp.Header().Get("Strict-Transport-Security"))
setting.StrictTransportSecurityPreload = true
sc.fakeReq("GET", "/api/").exec()
So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000; preload")
assert.Equal(t, "max-age=64000; preload", sc.resp.Header().Get("Strict-Transport-Security"))
setting.StrictTransportSecuritySubDomains = true
sc.fakeReq("GET", "/api/").exec()
So(sc.resp.Header().Get("Strict-Transport-Security"), ShouldEqual, "max-age=64000; preload; includeSubDomains")
})
assert.Equal(t, "max-age=64000; preload; includeSubDomains", sc.resp.Header().Get("Strict-Transport-Security"))
})
}
func TestMiddlewareContext(t *testing.T) {
origErrTemplateName := setting.ErrTemplateName
t.Cleanup(func() {
setting.ErrTemplateName = origErrTemplateName
})
setting.ErrTemplateName = errorTemplate
Convey("Given the grafana middleware", t, func() {
middlewareScenario(t, "middleware should add context to injector", func(sc *scenarioContext) {
sc.fakeReq("GET", "/").exec()
So(sc.context, ShouldNotBeNil)
assert.NotNil(t, sc.context)
})
middlewareScenario(t, "Default middleware should allow get request", func(sc *scenarioContext) {
sc.fakeReq("GET", "/").exec()
So(sc.resp.Code, ShouldEqual, 200)
assert.Equal(t, 200, sc.resp.Code)
})
middlewareScenario(t, "middleware should add Cache-Control header for requests to API", func(sc *scenarioContext) {
sc.fakeReq("GET", "/api/search").exec()
So(sc.resp.Header().Get("Cache-Control"), ShouldEqual, "no-cache")
So(sc.resp.Header().Get("Pragma"), ShouldEqual, "no-cache")
So(sc.resp.Header().Get("Expires"), ShouldEqual, "-1")
assert.Equal(t, "no-cache", sc.resp.Header().Get("Cache-Control"))
assert.Equal(t, "no-cache", sc.resp.Header().Get("Pragma"))
assert.Equal(t, "-1", sc.resp.Header().Get("Expires"))
})
middlewareScenario(t, "middleware should not add Cache-Control header for requests to datasource proxy API", func(sc *scenarioContext) {
sc.fakeReq("GET", "/api/datasources/proxy/1/test").exec()
So(sc.resp.Header().Get("Cache-Control"), ShouldBeEmpty)
So(sc.resp.Header().Get("Pragma"), ShouldBeEmpty)
So(sc.resp.Header().Get("Expires"), ShouldBeEmpty)
assert.Empty(t, sc.resp.Header().Get("Cache-Control"))
assert.Empty(t, sc.resp.Header().Get("Pragma"))
assert.Empty(t, sc.resp.Header().Get("Expires"))
})
middlewareScenario(t, "middleware should add Cache-Control header for requests with html response", func(sc *scenarioContext) {
@ -114,57 +135,53 @@ func TestMiddlewareContext(t *testing.T) {
c.HTML(200, "index-template", data)
})
sc.fakeReq("GET", "/").exec()
So(sc.resp.Code, ShouldEqual, 200)
So(sc.resp.Header().Get("Cache-Control"), ShouldEqual, "no-cache")
So(sc.resp.Header().Get("Pragma"), ShouldEqual, "no-cache")
So(sc.resp.Header().Get("Expires"), ShouldEqual, "-1")
assert.Equal(t, 200, sc.resp.Code)
assert.Equal(t, "no-cache", sc.resp.Header().Get("Cache-Control"))
assert.Equal(t, "no-cache", sc.resp.Header().Get("Pragma"))
assert.Equal(t, "-1", sc.resp.Header().Get("Expires"))
})
middlewareScenario(t, "middleware should add X-Frame-Options header with deny for request when not allowing embedding", func(sc *scenarioContext) {
sc.fakeReq("GET", "/api/search").exec()
So(sc.resp.Header().Get("X-Frame-Options"), ShouldEqual, "deny")
assert.Equal(t, "deny", sc.resp.Header().Get("X-Frame-Options"))
})
middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func(sc *scenarioContext) {
origAllowEmbedding := setting.AllowEmbedding
t.Cleanup(func() {
setting.AllowEmbedding = origAllowEmbedding
})
setting.AllowEmbedding = true
sc.fakeReq("GET", "/api/search").exec()
So(sc.resp.Header().Get("X-Frame-Options"), ShouldBeEmpty)
assert.Empty(t, sc.resp.Header().Get("X-Frame-Options"))
})
middlewareScenario(t, "Invalid api key", func(sc *scenarioContext) {
sc.apiKey = "invalid_key_test"
sc.fakeReq("GET", "/").exec()
Convey("Should not init session", func() {
So(sc.resp.Header().Get("Set-Cookie"), ShouldBeEmpty)
})
Convey("Should return 401", func() {
So(sc.resp.Code, ShouldEqual, 401)
So(sc.respJson["message"], ShouldEqual, errStringInvalidAPIKey)
})
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"])
})
middlewareScenario(t, "Valid api key", func(sc *scenarioContext) {
const orgID int64 = 12
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
require.NoError(t, err)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash}
query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash}
return nil
})
sc.fakeReq("GET", "/").withValidApiKey().exec()
Convey("Should return 200", func() {
So(sc.resp.Code, ShouldEqual, 200)
})
assert.Equal(t, 200, sc.resp.Code)
Convey("Should init middleware context", func() {
So(sc.context.IsSignedIn, ShouldEqual, true)
So(sc.context.OrgId, ShouldEqual, 12)
So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
})
middlewareScenario(t, "Valid api key, but does not match db hash", func(sc *scenarioContext) {
@ -177,10 +194,8 @@ func TestMiddlewareContext(t *testing.T) {
sc.fakeReq("GET", "/").withValidApiKey().exec()
Convey("Should return api key invalid", func() {
So(sc.resp.Code, ShouldEqual, 401)
So(sc.respJson["message"], ShouldEqual, errStringInvalidAPIKey)
})
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"])
})
middlewareScenario(t, "Valid api key, but expired", func(sc *scenarioContext) {
@ -188,7 +203,7 @@ func TestMiddlewareContext(t *testing.T) {
defer resetGetTime()
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
require.NoError(t, err)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
// api key expired one second before
@ -200,52 +215,49 @@ func TestMiddlewareContext(t *testing.T) {
sc.fakeReq("GET", "/").withValidApiKey().exec()
Convey("Should return 401", func() {
So(sc.resp.Code, ShouldEqual, 401)
So(sc.respJson["message"], ShouldEqual, "Expired API key")
})
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, "Expired API key", sc.respJson["message"])
})
middlewareScenario(t, "Non-expired auth token in cookie which not are being rotated", func(sc *scenarioContext) {
const userID int64 = 12
sc.withTokenSessionCookie("token")
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
query.Result = &models.SignedInUser{OrgId: 2, UserId: userID}
return nil
})
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: 12,
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
Convey("Should init context with user info", func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 12)
So(sc.context.UserToken.UserId, ShouldEqual, 12)
So(sc.context.UserToken.UnhashedToken, ShouldEqual, "token")
})
Convey("Should not set cookie", func() {
So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, "")
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Equal(t, "", sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie which are being rotated", func(sc *scenarioContext) {
const userID int64 = 12
sc.withTokenSessionCookie("token")
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
query.Result = &models.SignedInUser{OrgId: 2, UserId: userID}
return nil
})
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: 12,
UserId: userID,
UnhashedToken: "",
}, nil
}
@ -258,13 +270,19 @@ func TestMiddlewareContext(t *testing.T) {
maxAge := int(setting.LoginMaxLifetime.Seconds())
sameSitePolicies := []http.SameSite{
sameSiteModes := []http.SameSite{
http.SameSiteNoneMode,
http.SameSiteLaxMode,
http.SameSiteStrictMode,
}
for _, sameSitePolicy := range sameSitePolicies {
setting.CookieSameSiteMode = sameSitePolicy
for _, sameSiteMode := range sameSiteModes {
t.Run(fmt.Sprintf("Same site mode %d", sameSiteMode), func(t *testing.T) {
origCookieSameSiteMode := setting.CookieSameSiteMode
t.Cleanup(func() {
setting.CookieSameSiteMode = origCookieSameSiteMode
})
setting.CookieSameSiteMode = sameSiteMode
expectedCookiePath := "/"
if len(setting.AppSubUrl) > 0 {
expectedCookiePath = setting.AppSubUrl
@ -276,26 +294,29 @@ func TestMiddlewareContext(t *testing.T) {
HttpOnly: true,
MaxAge: maxAge,
Secure: setting.CookieSecure,
SameSite: sameSitePolicy,
SameSite: sameSiteMode,
}
sc.fakeReq("GET", "/").exec()
Convey(fmt.Sprintf("Should init context with user info and setting.SameSite=%v", sameSitePolicy), func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 12)
So(sc.context.UserToken.UserId, ShouldEqual, 12)
So(sc.context.UserToken.UnhashedToken, ShouldEqual, "rotated")
})
Convey(fmt.Sprintf("Should set cookie with setting.SameSite=%v", sameSitePolicy), func() {
So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, expectedCookie.String())
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "rotated", sc.context.UserToken.UnhashedToken)
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
})
}
Convey("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func() {
t.Run("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func(t *testing.T) {
origCookieSameSiteDisabled := setting.CookieSameSiteDisabled
origCookieSameSiteMode := setting.CookieSameSiteMode
t.Cleanup(func() {
setting.CookieSameSiteDisabled = origCookieSameSiteDisabled
setting.CookieSameSiteMode = origCookieSameSiteMode
})
setting.CookieSameSiteDisabled = true
setting.CookieSameSiteMode = http.SameSiteLaxMode
expectedCookiePath := "/"
if len(setting.AppSubUrl) > 0 {
expectedCookiePath = setting.AppSubUrl
@ -310,7 +331,7 @@ func TestMiddlewareContext(t *testing.T) {
}
sc.fakeReq("GET", "/").exec()
So(sc.resp.Header().Get("Set-Cookie"), ShouldEqual, expectedCookie.String())
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
})
})
@ -323,39 +344,61 @@ func TestMiddlewareContext(t *testing.T) {
sc.fakeReq("GET", "/").exec()
Convey("Should not init context with user info", func() {
So(sc.context.IsSignedIn, ShouldBeFalse)
So(sc.context.UserId, ShouldEqual, 0)
So(sc.context.UserToken, ShouldBeNil)
})
assert.False(t, sc.context.IsSignedIn)
assert.Equal(t, int64(0), sc.context.UserId)
assert.Nil(t, sc.context.UserToken)
})
middlewareScenario(t, "When anonymous access is enabled", func(sc *scenarioContext) {
const orgID int64 = 2
origAnonymousEnabled := setting.AnonymousEnabled
origAnonymousOrgName := setting.AnonymousOrgName
origAnonymousOrgRole := setting.AnonymousOrgRole
t.Cleanup(func() {
setting.AnonymousEnabled = origAnonymousEnabled
setting.AnonymousOrgName = origAnonymousOrgName
setting.AnonymousOrgRole = origAnonymousOrgRole
})
setting.AnonymousEnabled = true
setting.AnonymousOrgName = "test"
setting.AnonymousOrgRole = string(models.ROLE_EDITOR)
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
So(query.Name, ShouldEqual, "test")
assert.Equal(t, "test", query.Name)
query.Result = &models.Org{Id: 2, Name: "test"}
query.Result = &models.Org{Id: orgID, Name: "test"}
return nil
})
sc.fakeReq("GET", "/").exec()
Convey("Should init context with org info", func() {
So(sc.context.UserId, ShouldEqual, 0)
So(sc.context.OrgId, ShouldEqual, 2)
So(sc.context.OrgRole, ShouldEqual, models.ROLE_EDITOR)
assert.Equal(t, int64(0), sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
assert.False(t, sc.context.IsSignedIn)
})
Convey("context signed in should be false", func() {
So(sc.context.IsSignedIn, ShouldBeFalse)
})
})
t.Run("auth_proxy", func(t *testing.T) {
const userID int64 = 33
const orgID int64 = 4
Convey("auth_proxy", func() {
origAuthProxyEnabled := setting.AuthProxyEnabled
origAuthProxyWhitelist := setting.AuthProxyWhitelist
origAuthProxyAutoSignUp := setting.AuthProxyAutoSignUp
origLDAPEnabled := setting.LDAPEnabled
origAuthProxyHeaderName := setting.AuthProxyHeaderName
origAuthProxyHeaderProperty := setting.AuthProxyHeaderProperty
origAuthProxyHeaders := setting.AuthProxyHeaders
t.Cleanup(func() {
setting.AuthProxyEnabled = origAuthProxyEnabled
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyHeaderName = origAuthProxyHeaderName
setting.AuthProxyHeaderProperty = origAuthProxyHeaderProperty
setting.AuthProxyHeaders = origAuthProxyHeaders
})
setting.AuthProxyEnabled = true
setting.AuthProxyWhitelist = ""
setting.AuthProxyAutoSignUp = true
@ -363,34 +406,40 @@ func TestMiddlewareContext(t *testing.T) {
setting.AuthProxyHeaderName = "X-WEBAUTH-USER"
setting.AuthProxyHeaderProperty = "username"
setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
name := "markelog"
group := "grafana-core-team"
const hdrName = "markelog"
const group = "grafana-core-team"
middlewareScenario(t, "Should not sync the user if it's in the cache", func(sc *scenarioContext) {
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 4, UserId: query.UserId}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
return nil
})
key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name+"-"+group))
err := sc.remoteCacheService.Set(key, int64(33), 0)
So(err, ShouldBeNil)
key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(hdrName+"-"+group))
err := sc.remoteCacheService.Set(key, userID, 0)
require.NoError(t, err)
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Add("X-WEBAUTH-GROUPS", group)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("X-WEBAUTH-GROUPS", group)
sc.exec()
Convey("Should init user via cache", func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 33)
So(sc.context.OrgId, ShouldEqual, 4)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
})
middlewareScenario(t, "Should respect auto signup option", func(sc *scenarioContext) {
origLDAPEnabled = setting.LDAPEnabled
origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
})
setting.LDAPEnabled = false
setting.AuthProxyAutoSignUp = false
var actualAuthProxyAutoSignUp *bool = nil
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
@ -399,7 +448,7 @@ func TestMiddlewareContext(t *testing.T) {
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.exec()
assert.False(t, *actualAuthProxyAutoSignUp)
@ -408,106 +457,123 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should create an user from a header", func(sc *scenarioContext) {
origLDAPEnabled = setting.LDAPEnabled
origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
})
setting.LDAPEnabled = false
setting.AuthProxyAutoSignUp = true
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: 4, UserId: 33}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: 33}
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.exec()
Convey("Should create user from header info", func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 33)
So(sc.context.OrgId, ShouldEqual, 4)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
})
middlewareScenario(t, "Should get an existing user from header", func(sc *scenarioContext) {
const userID int64 = 12
const orgID int64 = 2
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
})
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: 12}
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.exec()
Convey("Should init context with user info", func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 12)
So(sc.context.OrgId, ShouldEqual, 2)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
})
middlewareScenario(t, "Should allow the request from whitelist IP", func(sc *scenarioContext) {
origAuthProxyWhitelist = setting.AuthProxyWhitelist
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.LDAPEnabled = origLDAPEnabled
})
setting.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 4, UserId: 33}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: 33}
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec()
Convey("Should init context with user info", func() {
So(sc.context.IsSignedIn, ShouldBeTrue)
So(sc.context.UserId, ShouldEqual, 33)
So(sc.context.OrgId, ShouldEqual, 4)
})
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId)
})
middlewareScenario(t, "Should not allow the request from whitelist IP", func(sc *scenarioContext) {
origAuthProxyWhitelist = setting.AuthProxyWhitelist
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.LDAPEnabled = origLDAPEnabled
})
setting.AuthProxyWhitelist = "8.8.8.8"
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: 4, UserId: 33}
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: 33}
cmd.Result = &models.User{Id: userID}
return nil
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec()
Convey("Should return 407 status code", func() {
So(sc.resp.Code, ShouldEqual, 407)
So(sc.context, ShouldBeNil)
})
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
})
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(sc *scenarioContext) {
@ -516,13 +582,11 @@ func TestMiddlewareContext(t *testing.T) {
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.exec()
Convey("Should return 407 status code", func() {
So(sc.resp.Code, ShouldEqual, 407)
So(sc.context, ShouldBeNil)
})
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
})
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(sc *scenarioContext) {
@ -531,28 +595,33 @@ func TestMiddlewareContext(t *testing.T) {
})
sc.fakeReq("GET", "/")
sc.req.Header.Add(setting.AuthProxyHeaderName, name)
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName)
sc.exec()
Convey("Should return 407 status code", func() {
So(sc.resp.Code, ShouldEqual, 407)
So(sc.context, ShouldBeNil)
})
})
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
})
})
}
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) {
Convey(desc, func() {
defer bus.ClearBusHandlers()
t.Helper()
t.Run(desc, func(t *testing.T) {
t.Cleanup(bus.ClearBusHandlers)
origLoginCookieName := setting.LoginCookieName
origLoginMaxLifetime := setting.LoginMaxLifetime
t.Cleanup(func() {
setting.LoginCookieName = origLoginCookieName
setting.LoginMaxLifetime = origLoginMaxLifetime
})
setting.LoginCookieName = "grafana_session"
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("30d")
require.NoError(t, err)
sc := &scenarioContext{}
sc := &scenarioContext{t: t}
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
@ -590,7 +659,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) {
func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationTest(ctx)
reqContext, _, err := initTokenRotationTest(ctx, t)
require.NoError(t, err)
tryRotateCallCount := 0
@ -612,7 +681,7 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
}
func TestTokenRotationAtEndOfRequest(t *testing.T) {
reqContext, rr, err := initTokenRotationTest(context.Background())
reqContext, rr, err := initTokenRotationTest(context.Background(), t)
require.NoError(t, err)
uts := &auth.FakeUserAuthTokenService{
@ -643,7 +712,15 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
assert.True(t, foundLoginCookie, "Could not find cookie")
}
func initTokenRotationTest(ctx context.Context) (*models.ReqContext, *httptest.ResponseRecorder, error) {
func initTokenRotationTest(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) {
t.Helper()
origLoginCookieName := setting.LoginCookieName
origLoginMaxLifetime := setting.LoginMaxLifetime
t.Cleanup(func() {
setting.LoginCookieName = origLoginCookieName
setting.LoginMaxLifetime = origLoginMaxLifetime
})
setting.LoginCookieName = "login_token"
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("7d")

View File

@ -7,11 +7,10 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
)
func TestOrgRedirectMiddleware(t *testing.T) {
Convey("Can redirect to correct org", t, func() {
middlewareScenario(t, "when setting a correct org for the user", func(sc *scenarioContext) {
sc.withTokenSessionCookie("token")
bus.AddHandler("test", func(query *models.SetUsingOrgCommand) error {
@ -33,9 +32,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
sc.m.Get("/", sc.defaultHandler)
sc.fakeReq("GET", "/?orgId=3").exec()
Convey("change org and redirect", func() {
So(sc.resp.Code, ShouldEqual, 302)
})
assert.Equal(t, 302, sc.resp.Code)
})
middlewareScenario(t, "when setting an invalid org for user", func(sc *scenarioContext) {
@ -59,9 +56,6 @@ func TestOrgRedirectMiddleware(t *testing.T) {
sc.m.Get("/", sc.defaultHandler)
sc.fakeReq("GET", "/?orgId=3").exec()
Convey("not allowed to change org", func() {
So(sc.resp.Code, ShouldEqual, 404)
})
})
assert.Equal(t, 404, sc.resp.Code)
})
}

View File

@ -9,11 +9,10 @@ import (
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
)
func TestMiddlewareQuota(t *testing.T) {
Convey("Given the grafana quota middleware", t, func() {
setting.AnonymousEnabled = false
setting.Quota = setting.QuotaSettings{
Enabled: true,
@ -40,9 +39,10 @@ func TestMiddlewareQuota(t *testing.T) {
qs := &quota.QuotaService{
AuthTokenService: fakeAuthTokenService,
}
QuotaFn := Quota(qs)
quotaFn := Quota(qs)
middlewareScenario(t, "with user not logged in", func(sc *scenarioContext) {
t.Run("With user not logged in", func(t *testing.T) {
middlewareScenario(t, "and global quota not reached", func(sc *scenarioContext) {
bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
@ -52,31 +52,63 @@ func TestMiddlewareQuota(t *testing.T) {
return nil
})
Convey("global quota not reached", func() {
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.m.Get("/user", quotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200)
assert.Equal(sc.t, 200, sc.resp.Code)
})
Convey("global quota reached", func() {
middlewareScenario(t, "and global quota reached", func(sc *scenarioContext) {
bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: 4,
}
return nil
})
origUser := setting.Quota.Global.User
t.Cleanup(func() {
setting.Quota.Global.User = origUser
})
setting.Quota.Global.User = 4
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.m.Get("/user", quotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403)
assert.Equal(t, 403, sc.resp.Code)
})
Convey("global session quota not reached", func() {
middlewareScenario(t, "and global session quota not reached", func(sc *scenarioContext) {
bus.AddHandler("globalQuota", func(query *models.GetGlobalQuotaByTargetQuery) error {
query.Result = &models.GlobalQuotaDTO{
Target: query.Target,
Limit: query.Default,
Used: 4,
}
return nil
})
origSession := setting.Quota.Global.Session
t.Cleanup(func() {
setting.Quota.Global.Session = origSession
})
setting.Quota.Global.Session = 10
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.m.Get("/user", quotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200)
assert.Equal(t, 200, sc.resp.Code)
})
Convey("global session quota reached", func() {
middlewareScenario(t, "and global session quota reached", func(sc *scenarioContext) {
origSession := setting.Quota.Global.Session
t.Cleanup(func() {
setting.Quota.Global.Session = origSession
})
setting.Quota.Global.Session = 1
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.m.Get("/user", quotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403)
assert.Equal(sc.t, 403, sc.resp.Code)
})
})
@ -121,48 +153,47 @@ func TestMiddlewareQuota(t *testing.T) {
return nil
})
Convey("global datasource quota reached", func() {
t.Run("global datasource quota reached", func(t *testing.T) {
setting.Quota.Global.DataSource = 4
sc.m.Get("/ds", QuotaFn("data_source"), sc.defaultHandler)
sc.m.Get("/ds", quotaFn("data_source"), sc.defaultHandler)
sc.fakeReq("GET", "/ds").exec()
So(sc.resp.Code, ShouldEqual, 403)
assert.Equal(t, 403, sc.resp.Code)
})
Convey("user Org quota not reached", func() {
t.Run("user Org quota not reached", func(t *testing.T) {
setting.Quota.User.Org = 5
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.m.Get("/org", quotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 200)
assert.Equal(t, 200, sc.resp.Code)
})
Convey("user Org quota reached", func() {
t.Run("user Org quota reached", func(t *testing.T) {
setting.Quota.User.Org = 4
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.m.Get("/org", quotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 403)
assert.Equal(t, 403, sc.resp.Code)
})
Convey("org dashboard quota not reached", func() {
t.Run("org dashboard quota not reached", func(t *testing.T) {
setting.Quota.Org.Dashboard = 10
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200)
assert.Equal(t, 200, sc.resp.Code)
})
Convey("org dashboard quota reached", func() {
t.Run("org dashboard quota reached", func(t *testing.T) {
setting.Quota.Org.Dashboard = 4
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 403)
assert.Equal(t, 403, sc.resp.Code)
})
Convey("org dashboard quota reached but quotas disabled", func() {
t.Run("org dashboard quota reached but quotas disabled", func(t *testing.T) {
setting.Quota.Org.Dashboard = 4
setting.Quota.Enabled = false
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", quotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
assert.Equal(t, 200, sc.resp.Code)
})
})
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"path/filepath"
"strings"
"testing"
"github.com/grafana/grafana/pkg/bus"
@ -9,52 +10,55 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
macaron "gopkg.in/macaron.v1"
)
func TestRecoveryMiddleware(t *testing.T) {
setting.ErrTemplateName = "error-template"
Convey("Given an api route that panics", t, func() {
t.Run("Given an API route that panics", func(t *testing.T) {
apiURL := "/api/whatever"
recoveryScenario(t, "recovery middleware should return json", apiURL, func(sc *scenarioContext) {
sc.handlerFunc = PanicHandler
sc.handlerFunc = panicHandler
sc.fakeReq("GET", apiURL).exec()
sc.req.Header.Add("content-type", "application/json")
So(sc.resp.Code, ShouldEqual, 500)
So(sc.respJson["message"], ShouldStartWith, "Internal Server Error - Check the Grafana server logs for the detailed error message.")
So(sc.respJson["error"], ShouldStartWith, "Server Error")
assert.Equal(t, 500, sc.resp.Code)
assert.Equal(t, "Internal Server Error - Check the Grafana server logs for the detailed error message.", sc.respJson["message"])
assert.True(t, strings.HasPrefix(sc.respJson["error"].(string), "Server Error"))
})
})
Convey("Given a non-api route that panics", t, func() {
t.Run("Given a non-API route that panics", func(t *testing.T) {
apiURL := "/whatever"
recoveryScenario(t, "recovery middleware should return html", apiURL, func(sc *scenarioContext) {
sc.handlerFunc = PanicHandler
sc.handlerFunc = panicHandler
sc.fakeReq("GET", apiURL).exec()
So(sc.resp.Code, ShouldEqual, 500)
So(sc.resp.Header().Get("content-type"), ShouldEqual, "text/html; charset=UTF-8")
So(sc.resp.Body.String(), ShouldContainSubstring, "<title>Grafana - Error</title>")
assert.Equal(t, 500, sc.resp.Code)
assert.Equal(t, "text/html; charset=UTF-8", sc.resp.Header().Get("content-type"))
assert.True(t, strings.Contains(sc.resp.Body.String(), "<title>Grafana - Error</title>"))
})
})
}
func PanicHandler(c *models.ReqContext) {
func panicHandler(c *models.ReqContext) {
panic("Handler has panicked")
}
func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
Convey(desc, func() {
t.Run(desc, func(t *testing.T) {
defer bus.ClearBusHandlers()
sc := &scenarioContext{
t: t,
url: url,
}
viewsPath, _ := filepath.Abs("../../public/views")
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
sc.m = macaron.New()
sc.m.Use(Recovery())

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"gopkg.in/macaron.v1"
@ -11,10 +12,11 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/setting"
"github.com/smartystreets/goconvey/convey"
"github.com/stretchr/testify/require"
)
type scenarioContext struct {
t *testing.T
m *macaron.Macaron
context *models.ReqContext
resp *httptest.ResponseRecorder
@ -47,15 +49,19 @@ func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioC
}
func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
sc.t.Helper()
sc.resp = httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil)
convey.So(err, convey.ShouldBeNil)
require.NoError(sc.t, err)
sc.req = req
return sc
}
func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map[string]string) *scenarioContext {
sc.t.Helper()
sc.resp = httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil)
q := req.URL.Query()
@ -63,7 +69,7 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
convey.So(err, convey.ShouldBeNil)
require.NoError(sc.t, err)
sc.req = req
return sc
@ -75,15 +81,20 @@ func (sc *scenarioContext) handler(fn handlerFunc) *scenarioContext {
}
func (sc *scenarioContext) exec() {
sc.t.Helper()
if sc.apiKey != "" {
sc.t.Logf(`Adding header "Authorization: Bearer %s"`, sc.apiKey)
sc.req.Header.Add("Authorization", "Bearer "+sc.apiKey)
}
if sc.authHeader != "" {
sc.t.Logf(`Adding header "Authorization: %s"`, sc.authHeader)
sc.req.Header.Add("Authorization", sc.authHeader)
}
if sc.tokenSessionCookie != "" {
sc.t.Log(`Adding cookie`, "name", setting.LoginCookieName, "value", sc.tokenSessionCookie)
sc.req.AddCookie(&http.Cookie{
Name: setting.LoginCookieName,
Value: sc.tokenSessionCookie,
@ -94,7 +105,7 @@ func (sc *scenarioContext) exec() {
if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
convey.So(err, convey.ShouldBeNil)
require.NoError(sc.t, err)
}
}

View File

@ -30,37 +30,43 @@ func TestNotificationService(t *testing.T) {
}
evalCtx := NewEvalContext(context.Background(), testRule)
notificationServiceScenario(t, "Given alert rule with upload image enabled should render and upload image and send notification", evalCtx, true, func(scenarioCtx *scenarioContext) {
err := scenarioCtx.notificationService.SendIfNeeded(evalCtx)
require.NoError(t, err)
notificationServiceScenario(t, "Given alert rule with upload image enabled should render and upload image and send notification",
evalCtx, true, func(sc *scenarioContext) {
err := sc.notificationService.SendIfNeeded(evalCtx)
require.NoError(sc.t, err)
require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but wasn't")
require.Equalf(t, 1, scenarioCtx.imageUploadCount, "expected image to be uploaded, but wasn't")
require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but wasn't")
require.Equalf(sc.t, 1, sc.imageUploadCount, "expected image to be uploaded, but wasn't")
require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
})
notificationServiceScenario(t, "Given alert rule with upload image enabled but no renderer available should render and upload unavailable image and send notification", evalCtx, true, func(scenarioCtx *scenarioContext) {
scenarioCtx.rendererAvailable = false
err := scenarioCtx.notificationService.SendIfNeeded(evalCtx)
require.NoError(t, err)
notificationServiceScenario(t,
"Given alert rule with upload image enabled but no renderer available should render and upload unavailable image and send notification",
evalCtx, true, func(sc *scenarioContext) {
sc.rendererAvailable = false
err := sc.notificationService.SendIfNeeded(evalCtx)
require.NoError(sc.t, err)
require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but it wasn't")
require.Equalf(t, 1, scenarioCtx.imageUploadCount, "expected image to be uploaded, but it wasn't")
require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but it wasn't")
require.Equalf(sc.t, 1, sc.imageUploadCount, "expected image to be uploaded, but it wasn't")
require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
})
notificationServiceScenario(t, "Given alert rule with upload image disabled should not render and upload image, but send notification", evalCtx, false, func(scenarioCtx *scenarioContext) {
err := scenarioCtx.notificationService.SendIfNeeded(evalCtx)
notificationServiceScenario(
t, "Given alert rule with upload image disabled should not render and upload image, but send notification",
evalCtx, false, func(sc *scenarioContext) {
err := sc.notificationService.SendIfNeeded(evalCtx)
require.NoError(t, err)
require.Equalf(t, 0, scenarioCtx.renderCount, "expected render not to be called, but it was")
require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
require.Equalf(sc.t, 0, sc.renderCount, "expected render not to be called, but it was")
require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
})
notificationServiceScenario(t, "Given alert rule with upload image enabled and render times out should send notification", evalCtx, true, func(scenarioCtx *scenarioContext) {
notificationServiceScenario(t, "Given alert rule with upload image enabled and render times out should send notification",
evalCtx, true, func(sc *scenarioContext) {
setting.AlertingNotificationTimeout = 200 * time.Millisecond
scenarioCtx.renderProvider = func(ctx context.Context, opts rendering.Opts) (*rendering.RenderResult, error) {
sc.renderProvider = func(ctx context.Context, opts rendering.Opts) (*rendering.RenderResult, error) {
wait := make(chan bool)
go func() {
@ -79,17 +85,18 @@ func TestNotificationService(t *testing.T) {
return nil, nil
}
err := scenarioCtx.notificationService.SendIfNeeded(evalCtx)
require.NoError(t, err)
err := sc.notificationService.SendIfNeeded(evalCtx)
require.NoError(sc.t, err)
require.Equalf(t, 0, scenarioCtx.renderCount, "expected render not to be called, but it was")
require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
require.Equalf(sc.t, 0, sc.renderCount, "expected render not to be called, but it was")
require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
})
notificationServiceScenario(t, "Given alert rule with upload image enabled and upload times out should send notification", evalCtx, true, func(scenarioCtx *scenarioContext) {
notificationServiceScenario(t, "Given alert rule with upload image enabled and upload times out should send notification",
evalCtx, true, func(sc *scenarioContext) {
setting.AlertingNotificationTimeout = 200 * time.Millisecond
scenarioCtx.uploadProvider = func(ctx context.Context, path string) (string, error) {
sc.uploadProvider = func(ctx context.Context, path string) (string, error) {
wait := make(chan bool)
go func() {
@ -108,16 +115,17 @@ func TestNotificationService(t *testing.T) {
return "", nil
}
err := scenarioCtx.notificationService.SendIfNeeded(evalCtx)
require.NoError(t, err)
err := sc.notificationService.SendIfNeeded(evalCtx)
require.NoError(sc.t, err)
require.Equalf(t, 1, scenarioCtx.renderCount, "expected render to be called, but wasn't")
require.Equalf(t, 0, scenarioCtx.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
require.Equalf(sc.t, 1, sc.renderCount, "expected render to be called, but wasn't")
require.Equalf(sc.t, 0, sc.imageUploadCount, "expected image not to be uploaded, but it was")
require.Truef(sc.t, evalCtx.Ctx.Value(notificationSent{}).(bool), "expected notification to be sent, but wasn't")
})
}
type scenarioContext struct {
t *testing.T
evalCtx *EvalContext
notificationService *notificationService
imageUploadCount int
@ -175,6 +183,7 @@ func notificationServiceScenario(t *testing.T, name string, evalCtx *EvalContext
setting.AlertingNotificationTimeout = 30 * time.Second
scenarioCtx := &scenarioContext{
t: t,
evalCtx: evalCtx,
}