diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index 75da4c3ac27..92f97a0529b 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "net/url" "strconv" @@ -228,7 +229,19 @@ func initContextWithToken(authTokenService models.UserTokenService, ctx *models. // Rotate the token just before we write response headers to ensure there is no delay between // the new token being generated and the client receiving it. - ctx.Resp.Before(func(w macaron.ResponseWriter) { + ctx.Resp.Before(rotateEndOfRequestFunc(ctx, authTokenService, token)) + + return true +} + +func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.UserTokenService, token *models.UserToken) macaron.BeforeFunc { + return func(w macaron.ResponseWriter) { + // if the request is cancelled by the client we should not try + // to rotate the token since the client would not accept any result. + if ctx.Context.Req.Context().Err() == context.Canceled { + return + } + rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ctx.RemoteAddr(), ctx.Req.UserAgent()) if err != nil { ctx.Logger.Error("Failed to rotate token", "error", err) @@ -238,9 +251,7 @@ func initContextWithToken(authTokenService models.UserTokenService, ctx *models. if rotated { WriteSessionCookie(ctx, token.UnhashedToken, setting.LoginMaxLifetimeDays) } - }) - - return true + } } func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetimeDays int) { diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 1977eb66d28..dd91f58e7ef 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -6,16 +6,19 @@ import ( "errors" "fmt" "net/http" + "net/http/httptest" "path/filepath" "testing" "time" . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/macaron.v1" "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/bus" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" authproxy "github.com/grafana/grafana/pkg/middleware/auth_proxy" "github.com/grafana/grafana/pkg/models" @@ -541,7 +544,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { sc := &scenarioContext{} - viewsPath, _ := filepath.Abs("../../public/views") + viewsPath, err := filepath.Abs("../../public/views") + require.NoError(t, err) sc.m = macaron.New() sc.m.Use(AddDefaultResponseHeaders()) @@ -571,3 +575,88 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { fn(sc) }) } + +func TestDontRotateTokensOnCancelledRequests(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + reqContext, _, err := initTokenRotationTest(ctx) + require.NoError(t, err) + + tryRotateCallCount := 0 + uts := &auth.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) { + tryRotateCallCount++ + return false, nil + }, + } + + token := &models.UserToken{AuthToken: "oldtoken"} + + fn := rotateEndOfRequestFunc(reqContext, uts, token) + cancel() + fn(reqContext.Resp) + + assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted") +} + +func TestTokenRotationAtEndOfRequest(t *testing.T) { + reqContext, rr, err := initTokenRotationTest(context.Background()) + require.NoError(t, err) + + uts := &auth.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP, userAgent string) (bool, error) { + newToken, err := util.RandomHex(16) + require.NoError(t, err) + token.AuthToken = newToken + return true, nil + }, + } + + token := &models.UserToken{AuthToken: "oldtoken"} + + rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp) + + foundLoginCookie := false + for _, c := range rr.Result().Cookies() { + if c.Name == "login_token" { + foundLoginCookie = true + + require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same") + } + } + + assert.True(t, foundLoginCookie, "Could not find cookie") +} + +func initTokenRotationTest(ctx context.Context) (*models.ReqContext, *httptest.ResponseRecorder, error) { + setting.LoginCookieName = "login_token" + setting.LoginMaxLifetimeDays = 7 + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, "", "", nil) + if err != nil { + return nil, nil, err + } + reqContext := &models.ReqContext{ + Context: &macaron.Context{ + Req: macaron.Request{ + Request: req, + }, + }, + Logger: log.New("testlogger"), + } + + mw := mockWriter{rr} + reqContext.Resp = mw + + return reqContext, rr, nil +} + +type mockWriter struct { + *httptest.ResponseRecorder +} + +func (mw mockWriter) Flush() {} +func (mw mockWriter) Status() int { return 0 } +func (mw mockWriter) Size() int { return 0 } +func (mw mockWriter) Written() bool { return false } +func (mw mockWriter) Before(macaron.BeforeFunc) {}