mirror of
https://github.com/grafana/grafana.git
synced 2025-09-19 19:33:49 +08:00
Frontend logging: handle logging endpoints without expensive middleware (#54960)
This commit is contained in:
@ -31,9 +31,6 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/api/frontendlogging"
|
|
||||||
"github.com/grafana/grafana/pkg/api/routing"
|
"github.com/grafana/grafana/pkg/api/routing"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/middleware"
|
"github.com/grafana/grafana/pkg/middleware"
|
||||||
@ -650,11 +647,4 @@ func (hs *HTTPServer) registerRoutes() {
|
|||||||
r.Get("/api/snapshots/:key", routing.Wrap(hs.GetDashboardSnapshot))
|
r.Get("/api/snapshots/:key", routing.Wrap(hs.GetDashboardSnapshot))
|
||||||
r.Get("/api/snapshots-delete/:deleteKey", reqSnapshotPublicModeOrSignedIn, routing.Wrap(hs.DeleteDashboardSnapshotByDeleteKey))
|
r.Get("/api/snapshots-delete/:deleteKey", reqSnapshotPublicModeOrSignedIn, routing.Wrap(hs.DeleteDashboardSnapshotByDeleteKey))
|
||||||
r.Delete("/api/snapshots/:key", reqEditorRole, routing.Wrap(hs.DeleteDashboardSnapshot))
|
r.Delete("/api/snapshots/:key", reqEditorRole, routing.Wrap(hs.DeleteDashboardSnapshot))
|
||||||
|
|
||||||
// Frontend logs
|
|
||||||
sourceMapStore := frontendlogging.NewSourceMapStore(hs.Cfg, hs.pluginStaticRouteResolver, frontendlogging.ReadSourceMapFromFS)
|
|
||||||
r.Post("/log", middleware.RateLimit(hs.Cfg.Sentry.EndpointRPS, hs.Cfg.Sentry.EndpointBurst, time.Now),
|
|
||||||
routing.Wrap(NewFrontendLogMessageHandler(sourceMapStore)))
|
|
||||||
r.Post("/log-grafana-javascript-agent", middleware.RateLimit(hs.Cfg.GrafanaJavascriptAgent.EndpointRPS, hs.Cfg.GrafanaJavascriptAgent.EndpointBurst, time.Now),
|
|
||||||
routing.Wrap(GrafanaJavascriptAgentLogMessageHandler(sourceMapStore)))
|
|
||||||
}
|
}
|
||||||
|
@ -2,24 +2,33 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/api/frontendlogging"
|
"github.com/grafana/grafana/pkg/api/frontendlogging"
|
||||||
"github.com/grafana/grafana/pkg/api/response"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/models"
|
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
var frontendLogger = log.New("frontend")
|
var frontendLogger = log.New("frontend")
|
||||||
|
|
||||||
type frontendLogMessageHandler func(c *models.ReqContext) response.Response
|
type frontendLogMessageHandler func(hs *HTTPServer, c *web.Context)
|
||||||
|
|
||||||
|
const sentryLogEndpointPath = "/log"
|
||||||
|
const grafanaJavascriptAgentEndpointPath = "/log-grafana-javascript-agent"
|
||||||
|
|
||||||
func NewFrontendLogMessageHandler(store *frontendlogging.SourceMapStore) frontendLogMessageHandler {
|
func NewFrontendLogMessageHandler(store *frontendlogging.SourceMapStore) frontendLogMessageHandler {
|
||||||
return func(c *models.ReqContext) response.Response {
|
return func(hs *HTTPServer, c *web.Context) {
|
||||||
event := frontendlogging.FrontendSentryEvent{}
|
event := frontendlogging.FrontendSentryEvent{}
|
||||||
if err := web.Bind(c.Req, &event); err != nil {
|
if err := web.Bind(c.Req, &event); err != nil {
|
||||||
return response.Error(http.StatusBadRequest, "bad request data", err)
|
c.Resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, err = c.Resp.Write([]byte("bad request data"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg = "unknown"
|
var msg = "unknown"
|
||||||
@ -43,15 +52,23 @@ func NewFrontendLogMessageHandler(store *frontendlogging.SourceMapStore) fronten
|
|||||||
frontendLogger.Info(msg, ctx...)
|
frontendLogger.Info(msg, ctx...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return response.Success("ok")
|
c.Resp.WriteHeader(http.StatusAccepted)
|
||||||
|
_, err := c.Resp.Write([]byte("OK"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GrafanaJavascriptAgentLogMessageHandler(store *frontendlogging.SourceMapStore) frontendLogMessageHandler {
|
func GrafanaJavascriptAgentLogMessageHandler(store *frontendlogging.SourceMapStore) frontendLogMessageHandler {
|
||||||
return func(c *models.ReqContext) response.Response {
|
return func(hs *HTTPServer, c *web.Context) {
|
||||||
event := frontendlogging.FrontendGrafanaJavascriptAgentEvent{}
|
event := frontendlogging.FrontendGrafanaJavascriptAgentEvent{}
|
||||||
if err := web.Bind(c.Req, &event); err != nil {
|
if err := web.Bind(c.Req, &event); err != nil {
|
||||||
return response.Error(http.StatusBadRequest, "bad request data", err)
|
c.Resp.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, err = c.Resp.Write([]byte("bad request data"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Meta object is standard across event types, adding it globally.
|
// Meta object is standard across event types, adding it globally.
|
||||||
@ -112,6 +129,64 @@ func GrafanaJavascriptAgentLogMessageHandler(store *frontendlogging.SourceMapSto
|
|||||||
frontendLogger.Error(exception.Message(), ctx...)
|
frontendLogger.Error(exception.Message(), ctx...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return response.Success("ok")
|
c.Resp.WriteHeader(http.StatusAccepted)
|
||||||
|
_, err := c.Resp.Write([]byte("OK"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupFrontendLogHandlers will set up handlers for logs incoming from frontend.
|
||||||
|
// handlers are setup even if frontend logging is disabled, but in this case do nothing
|
||||||
|
// this is to avoid reporting errors in case config was changes but there are browser
|
||||||
|
// sessions still open with older config
|
||||||
|
func (hs *HTTPServer) frontendLogEndpoints() web.Handler {
|
||||||
|
if !(hs.Cfg.GrafanaJavascriptAgent.Enabled || hs.Cfg.Sentry.Enabled) {
|
||||||
|
return func(ctx *web.Context) {
|
||||||
|
if ctx.Req.Method == http.MethodPost && (ctx.Req.URL.Path == sentryLogEndpointPath || ctx.Req.URL.Path == grafanaJavascriptAgentEndpointPath) {
|
||||||
|
ctx.Resp.WriteHeader(http.StatusAccepted)
|
||||||
|
_, err := ctx.Resp.Write([]byte("OK"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceMapStore := frontendlogging.NewSourceMapStore(hs.Cfg, hs.pluginStaticRouteResolver, frontendlogging.ReadSourceMapFromFS)
|
||||||
|
|
||||||
|
var rateLimiter *rate.Limiter
|
||||||
|
var handler frontendLogMessageHandler
|
||||||
|
handlerEndpoint := ""
|
||||||
|
dummyEndpoint := ""
|
||||||
|
|
||||||
|
if hs.Cfg.GrafanaJavascriptAgent.Enabled {
|
||||||
|
rateLimiter = rate.NewLimiter(rate.Limit(hs.Cfg.GrafanaJavascriptAgent.EndpointRPS), hs.Cfg.GrafanaJavascriptAgent.EndpointBurst)
|
||||||
|
handler = GrafanaJavascriptAgentLogMessageHandler(sourceMapStore)
|
||||||
|
handlerEndpoint = grafanaJavascriptAgentEndpointPath
|
||||||
|
dummyEndpoint = sentryLogEndpointPath
|
||||||
|
} else {
|
||||||
|
rateLimiter = rate.NewLimiter(rate.Limit(hs.Cfg.Sentry.EndpointRPS), hs.Cfg.Sentry.EndpointBurst)
|
||||||
|
handler = NewFrontendLogMessageHandler(sourceMapStore)
|
||||||
|
handlerEndpoint = sentryLogEndpointPath
|
||||||
|
dummyEndpoint = grafanaJavascriptAgentEndpointPath
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(ctx *web.Context) {
|
||||||
|
if ctx.Req.Method == http.MethodPost && ctx.Req.URL.Path == dummyEndpoint {
|
||||||
|
ctx.Resp.WriteHeader(http.StatusAccepted)
|
||||||
|
_, err := ctx.Resp.Write([]byte("OK"))
|
||||||
|
if err != nil {
|
||||||
|
hs.log.Error("could not write to response", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ctx.Req.Method == http.MethodPost && ctx.Req.URL.Path == handlerEndpoint {
|
||||||
|
if !rateLimiter.AllowN(time.Now(), 1) {
|
||||||
|
ctx.Resp.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(hs, ctx)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
@ -92,7 +93,8 @@ func logSentryEventScenario(t *testing.T, desc string, event frontendlogging.Fro
|
|||||||
sc.context = c
|
sc.context = c
|
||||||
c.Req.Body = mockRequestBody(event)
|
c.Req.Body = mockRequestBody(event)
|
||||||
c.Req.Header.Add("Content-Type", "application/json")
|
c.Req.Header.Add("Content-Type", "application/json")
|
||||||
return loggingHandler(c)
|
loggingHandler(nil, c.Context)
|
||||||
|
return response.Success("ok")
|
||||||
})
|
})
|
||||||
|
|
||||||
sc.m.Post(sc.url, handler)
|
sc.m.Post(sc.url, handler)
|
||||||
@ -164,7 +166,8 @@ func logGrafanaJavascriptAgentEventScenario(t *testing.T, desc string, event fro
|
|||||||
sc.context = c
|
sc.context = c
|
||||||
c.Req.Body = mockRequestBody(event)
|
c.Req.Body = mockRequestBody(event)
|
||||||
c.Req.Header.Add("Content-Type", "application/json")
|
c.Req.Header.Add("Content-Type", "application/json")
|
||||||
return loggingHandler(c)
|
loggingHandler(nil, c.Context)
|
||||||
|
return response.Success("OK")
|
||||||
})
|
})
|
||||||
|
|
||||||
sc.m.Post(sc.url, handler)
|
sc.m.Post(sc.url, handler)
|
||||||
@ -227,7 +230,7 @@ func TestFrontendLoggingEndpointSentry(t *testing.T) {
|
|||||||
|
|
||||||
logSentryEventScenario(t, "Should log received error event", errorEvent,
|
logSentryEventScenario(t, "Should log received error event", errorEvent,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "logger", "frontend")
|
assertContextContains(t, logs, "logger", "frontend")
|
||||||
assertContextContains(t, logs, "url", errorEvent.Request.URL)
|
assertContextContains(t, logs, "url", errorEvent.Request.URL)
|
||||||
assertContextContains(t, logs, "user_agent", errorEvent.Request.Headers["User-Agent"])
|
assertContextContains(t, logs, "user_agent", errorEvent.Request.Headers["User-Agent"])
|
||||||
@ -253,7 +256,7 @@ func TestFrontendLoggingEndpointSentry(t *testing.T) {
|
|||||||
|
|
||||||
logSentryEventScenario(t, "Should log received message event", messageEvent,
|
logSentryEventScenario(t, "Should log received message event", messageEvent,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assert.Len(t, logs, 10)
|
assert.Len(t, logs, 10)
|
||||||
assertContextContains(t, logs, "logger", "frontend")
|
assertContextContains(t, logs, "logger", "frontend")
|
||||||
assertContextContains(t, logs, "msg", "hello world")
|
assertContextContains(t, logs, "msg", "hello world")
|
||||||
@ -290,7 +293,7 @@ func TestFrontendLoggingEndpointSentry(t *testing.T) {
|
|||||||
|
|
||||||
logSentryEventScenario(t, "Should log event context", eventWithContext,
|
logSentryEventScenario(t, "Should log event context", eventWithContext,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "context_foo_one", "two")
|
assertContextContains(t, logs, "context_foo_one", "two")
|
||||||
assertContextContains(t, logs, "context_foo_three", "4")
|
assertContextContains(t, logs, "context_foo_three", "4")
|
||||||
assertContextContains(t, logs, "context_bar", "baz")
|
assertContextContains(t, logs, "context_bar", "baz")
|
||||||
@ -356,7 +359,7 @@ func TestFrontendLoggingEndpointSentry(t *testing.T) {
|
|||||||
|
|
||||||
logSentryEventScenario(t, "Should load sourcemap and transform stacktrace line when possible",
|
logSentryEventScenario(t, "Should load sourcemap and transform stacktrace line when possible",
|
||||||
errorEventForSourceMapping, func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
errorEventForSourceMapping, func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assert.Len(t, logs, 9)
|
assert.Len(t, logs, 9)
|
||||||
assertContextContains(t, logs, "stacktrace", `UserError: Please replace user and try again
|
assertContextContains(t, logs, "stacktrace", `UserError: Please replace user and try again
|
||||||
at ? (core|webpack:///./some_source.ts:2:2)
|
at ? (core|webpack:///./some_source.ts:2:2)
|
||||||
@ -420,7 +423,7 @@ func TestFrontendLoggingEndpointGrafanaJavascriptAgent(t *testing.T) {
|
|||||||
|
|
||||||
logGrafanaJavascriptAgentEventScenario(t, "Should log received error event", errorEvent,
|
logGrafanaJavascriptAgentEventScenario(t, "Should log received error event", errorEvent,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "logger", "frontend")
|
assertContextContains(t, logs, "logger", "frontend")
|
||||||
assertContextContains(t, logs, "page_url", errorEvent.Meta.Page.URL)
|
assertContextContains(t, logs, "page_url", errorEvent.Meta.Page.URL)
|
||||||
assertContextContains(t, logs, "user_email", errorEvent.Meta.User.Email)
|
assertContextContains(t, logs, "user_email", errorEvent.Meta.User.Email)
|
||||||
@ -443,7 +446,7 @@ func TestFrontendLoggingEndpointGrafanaJavascriptAgent(t *testing.T) {
|
|||||||
|
|
||||||
logGrafanaJavascriptAgentEventScenario(t, "Should log received log event", logEvent,
|
logGrafanaJavascriptAgentEventScenario(t, "Should log received log event", logEvent,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assert.Len(t, logs, 11)
|
assert.Len(t, logs, 11)
|
||||||
assertContextContains(t, logs, "logger", "frontend")
|
assertContextContains(t, logs, "logger", "frontend")
|
||||||
assertContextContains(t, logs, "msg", "This is a test log message")
|
assertContextContains(t, logs, "msg", "This is a test log message")
|
||||||
@ -468,7 +471,7 @@ func TestFrontendLoggingEndpointGrafanaJavascriptAgent(t *testing.T) {
|
|||||||
|
|
||||||
logGrafanaJavascriptAgentEventScenario(t, "Should log received log context", logEventWithContext,
|
logGrafanaJavascriptAgentEventScenario(t, "Should log received log context", logEventWithContext,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "context_one", "two")
|
assertContextContains(t, logs, "context_one", "two")
|
||||||
assertContextContains(t, logs, "context_bar", "baz")
|
assertContextContains(t, logs, "context_bar", "baz")
|
||||||
})
|
})
|
||||||
@ -531,7 +534,7 @@ func TestFrontendLoggingEndpointGrafanaJavascriptAgent(t *testing.T) {
|
|||||||
|
|
||||||
logGrafanaJavascriptAgentEventScenario(t, "Should load sourcemap and transform stacktrace line when possible", errorEventForSourceMapping,
|
logGrafanaJavascriptAgentEventScenario(t, "Should load sourcemap and transform stacktrace line when possible", errorEventForSourceMapping,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "stacktrace", `UserError: Please replace user and try again
|
assertContextContains(t, logs, "stacktrace", `UserError: Please replace user and try again
|
||||||
at ? (webpack:///./some_source.ts:2:2)
|
at ? (webpack:///./some_source.ts:2:2)
|
||||||
at ? (webpack:///./some_source.ts:3:2)
|
at ? (webpack:///./some_source.ts:3:2)
|
||||||
@ -567,7 +570,7 @@ func TestFrontendLoggingEndpointGrafanaJavascriptAgent(t *testing.T) {
|
|||||||
|
|
||||||
logGrafanaJavascriptAgentEventScenario(t, "Should log web vitals as context", logWebVitals,
|
logGrafanaJavascriptAgentEventScenario(t, "Should log web vitals as context", logWebVitals,
|
||||||
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
func(sc *scenarioContext, logs map[string]interface{}, sourceMapReads []SourceMapReadRecord) {
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
assert.Equal(t, http.StatusAccepted, sc.resp.Code)
|
||||||
assertContextContains(t, logs, "CLS", float64(1))
|
assertContextContains(t, logs, "CLS", float64(1))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -572,6 +572,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
|
|||||||
m.Use(hs.apiHealthHandler)
|
m.Use(hs.apiHealthHandler)
|
||||||
m.Use(hs.metricsEndpoint)
|
m.Use(hs.metricsEndpoint)
|
||||||
m.Use(hs.pluginMetricsEndpoint)
|
m.Use(hs.pluginMetricsEndpoint)
|
||||||
|
m.Use(hs.frontendLogEndpoints())
|
||||||
|
|
||||||
m.UseMiddleware(hs.ContextHandler.Middleware)
|
m.UseMiddleware(hs.ContextHandler.Middleware)
|
||||||
m.Use(middleware.OrgRedirect(hs.Cfg, hs.SQLStore))
|
m.Use(middleware.OrgRedirect(hs.Cfg, hs.SQLStore))
|
||||||
|
@ -1,24 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/models"
|
|
||||||
"github.com/grafana/grafana/pkg/web"
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
type getTimeFn func() time.Time
|
|
||||||
|
|
||||||
// RateLimit is a very basic rate limiter.
|
|
||||||
// Will allow average of "rps" requests per second over an extended period of time, with max "burst" requests at the same time.
|
|
||||||
// getTime should return the current time. For non-testing purposes use time.Now
|
|
||||||
func RateLimit(rps, burst int, getTime getTimeFn) web.Handler {
|
|
||||||
l := rate.NewLimiter(rate.Limit(rps), burst)
|
|
||||||
return func(c *models.ReqContext) {
|
|
||||||
if !l.AllowN(getTime(), 1) {
|
|
||||||
c.JsonApiErr(429, "Rate limit reached", nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,88 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/models"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/web"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
type execFunc func() *httptest.ResponseRecorder
|
|
||||||
type advanceTimeFunc func(deltaTime time.Duration)
|
|
||||||
type rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc)
|
|
||||||
|
|
||||||
func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
t.Run(desc, func(t *testing.T) {
|
|
||||||
defaultHandler := func(c *models.ReqContext) {
|
|
||||||
resp := make(map[string]interface{})
|
|
||||||
resp["message"] = "OK"
|
|
||||||
c.JSON(http.StatusOK, resp)
|
|
||||||
}
|
|
||||||
currentTime := time.Now()
|
|
||||||
|
|
||||||
cfg := setting.NewCfg()
|
|
||||||
|
|
||||||
m := web.New()
|
|
||||||
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
|
|
||||||
m.Use(getContextHandler(t, cfg, nil, nil, nil, nil).Middleware)
|
|
||||||
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
|
|
||||||
|
|
||||||
fn(func() *httptest.ResponseRecorder {
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
req, err := http.NewRequest("GET", "/foo", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
m.ServeHTTP(resp, req)
|
|
||||||
return resp
|
|
||||||
}, func(deltaTime time.Duration) {
|
|
||||||
currentTime = currentTime.Add(deltaTime)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimitMiddleware(t *testing.T) {
|
|
||||||
rateLimiterScenario(t, "rate limit calls, with burst", 10, 10, func(doReq execFunc, advanceTime advanceTimeFunc) {
|
|
||||||
// first 10 calls succeed
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
resp := doReq()
|
|
||||||
assert.Equal(t, 200, resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
// next one fails
|
|
||||||
resp := doReq()
|
|
||||||
assert.Equal(t, 429, resp.Code)
|
|
||||||
|
|
||||||
// check that requests are accepted again in 1 sec
|
|
||||||
advanceTime(1 * time.Second)
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
resp := doReq()
|
|
||||||
assert.Equal(t, 200, resp.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
rateLimiterScenario(t, "rate limit calls, no burst", 10, 1, func(doReq execFunc, advanceTime advanceTimeFunc) {
|
|
||||||
// first calls succeeds
|
|
||||||
resp := doReq()
|
|
||||||
assert.Equal(t, 200, resp.Code)
|
|
||||||
|
|
||||||
// immediately fired next one fails
|
|
||||||
resp = doReq()
|
|
||||||
assert.Equal(t, 429, resp.Code)
|
|
||||||
|
|
||||||
// but spacing calls out works
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
advanceTime(100 * time.Millisecond)
|
|
||||||
resp := doReq()
|
|
||||||
assert.Equal(t, 200, resp.Code)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
Reference in New Issue
Block a user