From f76f426da3b5ba24ed40a45a935d8b6089131e5c Mon Sep 17 00:00:00 2001 From: Marcus Efraimsson Date: Thu, 27 May 2021 12:43:21 +0200 Subject: [PATCH] Chore: Refactor Prometheus HTTP client middleware (#34473) Following #33439 this refactors the Prometheus HTTP transport which is replaced by HTTP client middleware. --- pkg/models/datasource_cache.go | 7 +- .../custom_query_params_middleware.go | 48 +++++++ .../custom_query_params_middleware_test.go | 109 ++++++++++++++++ pkg/tsdb/prometheus/prometheus.go | 119 ++++-------------- pkg/tsdb/prometheus/prometheus_test.go | 34 +++-- 5 files changed, 203 insertions(+), 114 deletions(-) create mode 100644 pkg/tsdb/prometheus/custom_query_params_middleware.go create mode 100644 pkg/tsdb/prometheus/custom_query_params_middleware_test.go diff --git a/pkg/models/datasource_cache.go b/pkg/models/datasource_cache.go index 1ef513bf64b..bb0a2b28d40 100644 --- a/pkg/models/datasource_cache.go +++ b/pkg/models/datasource_cache.go @@ -50,7 +50,7 @@ func (ds *DataSource) GetHTTPClient(provider httpclient.Provider) (*http.Client, }, nil } -func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider) (http.RoundTripper, error) { +func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider, customMiddlewares ...sdkhttpclient.Middleware) (http.RoundTripper, error) { ptc.Lock() defer ptc.Unlock() @@ -58,7 +58,10 @@ func (ds *DataSource) GetHTTPTransport(provider httpclient.Provider) (http.Round return t.roundTripper, nil } - rt, err := provider.GetTransport(ds.HTTPClientOptions()) + opts := ds.HTTPClientOptions() + opts.Middlewares = customMiddlewares + + rt, err := provider.GetTransport(opts) if err != nil { return nil, err } diff --git a/pkg/tsdb/prometheus/custom_query_params_middleware.go b/pkg/tsdb/prometheus/custom_query_params_middleware.go new file mode 100644 index 00000000000..88f23d10c9c --- /dev/null +++ b/pkg/tsdb/prometheus/custom_query_params_middleware.go @@ -0,0 +1,48 @@ +package prometheus + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" +) + +const ( + customQueryParametersMiddlewareName = "prom-custom-query-parameters" + customQueryParametersKey = "customQueryParameters" +) + +func customQueryParametersMiddleware() sdkhttpclient.Middleware { + return sdkhttpclient.NamedMiddlewareFunc(customQueryParametersMiddlewareName, func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { + customQueryParamsVal, exists := opts.CustomOptions[customQueryParametersKey] + if !exists { + return next + } + customQueryParams, ok := customQueryParamsVal.(string) + if !ok || customQueryParams == "" { + return next + } + + return sdkhttpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + params := url.Values{} + for _, param := range strings.Split(customQueryParams, "&") { + parts := strings.Split(param, "=") + if len(parts) == 1 { + // This is probably a mistake on the users part in defining the params but we don't want to crash. + params.Add(parts[0], "") + } else { + params.Add(parts[0], parts[1]) + } + } + if req.URL.RawQuery != "" { + req.URL.RawQuery = fmt.Sprintf("%s&%s", req.URL.RawQuery, params.Encode()) + } else { + req.URL.RawQuery = params.Encode() + } + + return next.RoundTrip(req) + }) + }) +} diff --git a/pkg/tsdb/prometheus/custom_query_params_middleware_test.go b/pkg/tsdb/prometheus/custom_query_params_middleware_test.go new file mode 100644 index 00000000000..444ffcf3d5f --- /dev/null +++ b/pkg/tsdb/prometheus/custom_query_params_middleware_test.go @@ -0,0 +1,109 @@ +package prometheus + +import ( + "net/http" + "testing" + + "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/stretchr/testify/require" +) + +func TestCustomQueryParametersMiddleware(t *testing.T) { + require.Equal(t, "customQueryParameters", customQueryParametersKey) + + finalRoundTripper := httpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK}, nil + }) + + t.Run("Without custom query parameters set should not apply middleware", func(t *testing.T) { + mw := customQueryParametersMiddleware() + rt := mw.CreateMiddleware(httpclient.Options{}, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query?hello=name", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?hello=name", req.URL.String()) + }) + + t.Run("Without custom query parameters set as string should not apply middleware", func(t *testing.T) { + mw := customQueryParametersMiddleware() + rt := mw.CreateMiddleware(httpclient.Options{ + CustomOptions: map[string]interface{}{ + customQueryParametersKey: 64, + }, + }, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query?hello=name", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?hello=name", req.URL.String()) + }) + + t.Run("With custom query parameters set as empty string should not apply middleware", func(t *testing.T) { + mw := customQueryParametersMiddleware() + rt := mw.CreateMiddleware(httpclient.Options{ + CustomOptions: map[string]interface{}{ + customQueryParametersKey: "", + }, + }, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query?hello=name", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?hello=name", req.URL.String()) + }) + + t.Run("With custom query parameters set as string should apply middleware", func(t *testing.T) { + mw := customQueryParametersMiddleware() + rt := mw.CreateMiddleware(httpclient.Options{ + CustomOptions: map[string]interface{}{ + customQueryParametersKey: "custom=par/am&second=f oo", + }, + }, finalRoundTripper) + require.NotNil(t, rt) + middlewareName, ok := mw.(httpclient.MiddlewareName) + require.True(t, ok) + require.Equal(t, customQueryParametersMiddlewareName, middlewareName.MiddlewareName()) + + req, err := http.NewRequest(http.MethodGet, "http://test.com/query?hello=name", nil) + require.NoError(t, err) + res, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, res) + if res.Body != nil { + require.NoError(t, res.Body.Close()) + } + + require.Equal(t, "http://test.com/query?hello=name&custom=par%2Fam&second=f+oo", req.URL.String()) + }) +} diff --git a/pkg/tsdb/prometheus/prometheus.go b/pkg/tsdb/prometheus/prometheus.go index 91c7e869010..edf99ddf6e3 100644 --- a/pkg/tsdb/prometheus/prometheus.go +++ b/pkg/tsdb/prometheus/prometheus.go @@ -4,84 +4,22 @@ import ( "context" "errors" "fmt" - "net/url" "regexp" "strings" "time" - "github.com/opentracing/opentracing-go" - - "net/http" - "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/tsdb/interval" + "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/api" apiv1 "github.com/prometheus/client_golang/api/prometheus/v1" "github.com/prometheus/common/model" ) -type PrometheusExecutor struct { - baseRoundTripperFactory func(dsInfo *models.DataSource) (http.RoundTripper, error) - intervalCalculator interval.Calculator -} - -type prometheusTransport struct { - Transport http.RoundTripper - - hasBasicAuth bool - username string - password string - - customQueryParameters string -} - -func (transport *prometheusTransport) RoundTrip(req *http.Request) (*http.Response, error) { - if transport.hasBasicAuth { - req.SetBasicAuth(transport.username, transport.password) - } - - if transport.customQueryParameters != "" { - params := url.Values{} - for _, param := range strings.Split(transport.customQueryParameters, "&") { - parts := strings.Split(param, "=") - if len(parts) == 1 { - // This is probably a mistake on the users part in defining the params but we don't want to crash. - params.Add(parts[0], "") - } else { - params.Add(parts[0], parts[1]) - } - } - if req.URL.RawQuery != "" { - req.URL.RawQuery = fmt.Sprintf("%s&%s", req.URL.RawQuery, params.Encode()) - } else { - req.URL.RawQuery = params.Encode() - } - } - - return transport.Transport.RoundTrip(req) -} - -//nolint: staticcheck // plugins.DataPlugin deprecated -func New(provider httpclient.Provider) func(*models.DataSource) (plugins.DataPlugin, error) { - return func(dsInfo *models.DataSource) (plugins.DataPlugin, error) { - transport, err := dsInfo.GetHTTPTransport(provider) - if err != nil { - return nil, err - } - - return &PrometheusExecutor{ - intervalCalculator: interval.NewCalculator(interval.CalculatorOptions{MinInterval: time.Second * 1}), - baseRoundTripperFactory: func(ds *models.DataSource) (http.RoundTripper, error) { - return transport, nil - }, - }, nil - } -} - var ( plog log.Logger legendFormat *regexp.Regexp = regexp.MustCompile(`\{\{\s*(.+?)\s*\}\}`) @@ -91,32 +29,34 @@ func init() { plog = log.New("tsdb.prometheus") } -func (e *PrometheusExecutor) getClient(dsInfo *models.DataSource) (apiv1.API, error) { - // Would make sense to cache this but executor is recreated on every alert request anyway. - transport, err := e.baseRoundTripperFactory(dsInfo) - if err != nil { - return nil, err - } +type PrometheusExecutor struct { + client apiv1.API + intervalCalculator interval.Calculator +} - promTransport := &prometheusTransport{ - Transport: transport, - hasBasicAuth: dsInfo.BasicAuth, - username: dsInfo.BasicAuthUser, - password: dsInfo.DecryptedBasicAuthPassword(), - customQueryParameters: dsInfo.JsonData.Get("customQueryParameters").MustString(""), - } +//nolint: staticcheck // plugins.DataPlugin deprecated +func New(provider httpclient.Provider) func(*models.DataSource) (plugins.DataPlugin, error) { + return func(dsInfo *models.DataSource) (plugins.DataPlugin, error) { + transport, err := dsInfo.GetHTTPTransport(provider, customQueryParametersMiddleware()) + if err != nil { + return nil, err + } - cfg := api.Config{ - Address: dsInfo.Url, - RoundTripper: promTransport, - } + cfg := api.Config{ + Address: dsInfo.Url, + RoundTripper: transport, + } - client, err := api.NewClient(cfg) - if err != nil { - return nil, err - } + client, err := api.NewClient(cfg) + if err != nil { + return nil, err + } - return apiv1.NewAPI(client), nil + return &PrometheusExecutor{ + intervalCalculator: interval.NewCalculator(interval.CalculatorOptions{MinInterval: time.Second * 1}), + client: apiv1.NewAPI(client), + }, nil + } } //nolint: staticcheck // plugins.DataResponse deprecated @@ -126,11 +66,6 @@ func (e *PrometheusExecutor) DataQuery(ctx context.Context, dsInfo *models.DataS Results: map[string]plugins.DataQueryResult{}, } - client, err := e.getClient(dsInfo) - if err != nil { - return result, err - } - queries, err := e.parseQuery(dsInfo, tsdbQuery) if err != nil { return result, err @@ -145,13 +80,13 @@ func (e *PrometheusExecutor) DataQuery(ctx context.Context, dsInfo *models.DataS plog.Debug("Sending query", "start", timeRange.Start, "end", timeRange.End, "step", timeRange.Step, "query", query.Expr) - span, ctx := opentracing.StartSpanFromContext(ctx, "alerting.prometheus") + span, ctx := opentracing.StartSpanFromContext(ctx, "datasource.prometheus") span.SetTag("expr", query.Expr) span.SetTag("start_unixnano", query.Start.UnixNano()) span.SetTag("stop_unixnano", query.End.UnixNano()) defer span.Finish() - value, _, err := client.QueryRange(ctx, query.Expr, timeRange) + value, _, err := e.client.QueryRange(ctx, query.Expr, timeRange) if err != nil { return result, err diff --git a/pkg/tsdb/prometheus/prometheus_test.go b/pkg/tsdb/prometheus/prometheus_test.go index 6fead701028..73ef12f97a9 100644 --- a/pkg/tsdb/prometheus/prometheus_test.go +++ b/pkg/tsdb/prometheus/prometheus_test.go @@ -2,11 +2,11 @@ package prometheus import ( "context" - "fmt" "net/http" "testing" "time" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/models" @@ -22,7 +22,17 @@ func TestPrometheus(t *testing.T) { dsInfo := &models.DataSource{ JsonData: json, } - plug, err := New(httpclient.NewProvider())(dsInfo) + var capturedRequest *http.Request + mw := sdkhttpclient.MiddlewareFunc(func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { + return sdkhttpclient.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + capturedRequest = req + return &http.Response{StatusCode: http.StatusOK}, nil + }) + }) + provider := httpclient.NewProvider(sdkhttpclient.ProviderOptions{ + Middlewares: []sdkhttpclient.Middleware{mw}, + }) + plug, err := New(provider)(dsInfo) require.NoError(t, err) executor := plug.(*PrometheusExecutor) @@ -113,28 +123,12 @@ func TestPrometheus(t *testing.T) { "intervalFactor": 1, "refId": "A" }`) - queryParams := "" - executor.baseRoundTripperFactory = func(ds *models.DataSource) (http.RoundTripper, error) { - rt := &RoundTripperMock{} - rt.roundTrip = func(request *http.Request) (*http.Response, error) { - queryParams = request.URL.RawQuery - return nil, fmt.Errorf("this is fine") - } - return rt, nil - } _, _ = executor.DataQuery(context.Background(), dsInfo, query) - require.Equal(t, "custom=par%2Fam&second=f+oo", queryParams) + require.NotNil(t, capturedRequest) + require.Equal(t, "custom=par%2Fam&second=f+oo", capturedRequest.URL.RawQuery) }) } -type RoundTripperMock struct { - roundTrip func(*http.Request) (*http.Response, error) -} - -func (rt *RoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error) { - return rt.roundTrip(req) -} - func queryContext(json string) plugins.DataQuery { jsonModel, _ := simplejson.NewJson([]byte(json)) queryModels := []plugins.DataSubQuery{