diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 25040a02729..7fddd31bfaa 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -467,7 +467,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { } m.Use(middleware.Recovery(hs.Cfg)) - m.UseMiddleware(middleware.CSRF(hs.Cfg.LoginCookieName)) + m.UseMiddleware(middleware.CSRF(hs.Cfg.LoginCookieName, hs.log)) hs.mapStatic(m, hs.Cfg.StaticRootPath, "build", "public/build") hs.mapStatic(m, hs.Cfg.StaticRootPath, "", "public", "/public/views/swagger.html") diff --git a/pkg/middleware/csrf.go b/pkg/middleware/csrf.go index bc70d09779d..7bce53f5666 100644 --- a/pkg/middleware/csrf.go +++ b/pkg/middleware/csrf.go @@ -4,10 +4,12 @@ import ( "errors" "net/http" "net/url" - "strings" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/util" ) -func CSRF(loginCookieName string) func(http.Handler) http.Handler { +func CSRF(loginCookieName string, logger log.Logger) func(http.Handler) http.Handler { // As per RFC 7231/4.2.2 these methods are idempotent: // (GET is excluded because it may have side effects in some APIs) safeMethods := []string{"HEAD", "OPTIONS", "TRACE"} @@ -27,12 +29,21 @@ func CSRF(loginCookieName string) func(http.Handler) http.Handler { } } // Otherwise - verify that Origin matches the server origin - host := strings.Split(r.Host, ":")[0] + netAddr, err := util.SplitHostPortDefault(r.Host, "", "0") // we ignore the port + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + origin, err := url.Parse(r.Header.Get("Origin")) - if err != nil || (origin.String() != "" && origin.Hostname() != host) { + if err != nil { + logger.Error("error parsing Origin header", "err", err) + } + if err != nil || netAddr.Host == "" || (origin.String() != "" && origin.Hostname() != netAddr.Host) { http.Error(w, "origin not allowed", http.StatusForbidden) return } + next.ServeHTTP(w, r) }) } diff --git a/pkg/middleware/csrf_test.go b/pkg/middleware/csrf_test.go new file mode 100644 index 00000000000..312356cce3f --- /dev/null +++ b/pkg/middleware/csrf_test.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/stretchr/testify/require" +) + +func TestMiddlewareCSRF(t *testing.T) { + tests := []struct { + name string + cookieName string + method string + origin string + host string + code int + }{ + { + name: "mismatched origin and host is forbidden", + cookieName: "foo", + method: "GET", + origin: "http://notLocalhost", + host: "localhost", + code: http.StatusForbidden, + }, + { + name: "mismatched origin and host is NOT forbidden with a 'Safe Method'", + cookieName: "foo", + method: "TRACE", + origin: "http://notLocalhost", + host: "localhost", + code: http.StatusOK, + }, + { + name: "mismatched origin and host is NOT forbidden without a cookie", + cookieName: "", + method: "GET", + origin: "http://notLocalhost", + host: "localhost", + code: http.StatusOK, + }, + { + name: "malformed host is a bad request", + cookieName: "foo", + method: "GET", + host: "localhost:80:80", + code: http.StatusBadRequest, + }, + { + name: "host works without port", + cookieName: "foo", + method: "GET", + host: "localhost", + origin: "http://localhost", + code: http.StatusOK, + }, + { + name: "port does not have to match", + cookieName: "foo", + method: "GET", + host: "localhost:80", + origin: "http://localhost:3000", + code: http.StatusOK, + }, + { + name: "IPv6 host works with port", + cookieName: "foo", + method: "GET", + host: "[::1]:3000", + origin: "http://[::1]:3000", + code: http.StatusOK, + }, + { + name: "IPv6 host (with longer address) works with port", + cookieName: "foo", + method: "GET", + host: "[2001:db8::1]:3000", + origin: "http://[2001:db8::1]:3000", + code: http.StatusOK, + }, + { + name: "IPv6 host (with longer address) works without port", + cookieName: "foo", + method: "GET", + host: "[2001:db8::1]", + origin: "http://[2001:db8::1]", + code: http.StatusOK, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rr := csrfScenario(t, tt.cookieName, tt.method, tt.origin, tt.host) + require.Equal(t, tt.code, rr.Code) + }) + } +} + +func csrfScenario(t *testing.T, cookieName, method, origin, host string) *httptest.ResponseRecorder { + req, err := http.NewRequest(method, "/", nil) + if err != nil { + t.Fatal(err) + } + req.AddCookie(&http.Cookie{ + Name: cookieName, + }) + + // Note: Not sure where host header populates req.Host, or how that works. + req.Host = host + req.Header.Set("HOST", host) + + req.Header.Set("ORIGIN", origin) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + }) + + rr := httptest.NewRecorder() + handler := CSRF(cookieName, log.New())(testHandler) + handler.ServeHTTP(rr, req) + return rr +}