diff --git a/pkg/infra/httpclient/count_bytes_reader.go b/pkg/infra/httpclient/count_bytes_reader.go new file mode 100644 index 00000000000..278a0b378c3 --- /dev/null +++ b/pkg/infra/httpclient/count_bytes_reader.go @@ -0,0 +1,39 @@ +package httpclient + +import ( + "io" +) + +type CloseCallbackFunc func(bytesRead int64) + +// CountBytesReader counts the total amount of bytes read from the underlying reader. +// +// The provided callback func will be called before the underlying reader is closed. +func CountBytesReader(reader io.ReadCloser, callback CloseCallbackFunc) io.ReadCloser { + if reader == nil { + panic("reader cannot be nil") + } + + if callback == nil { + panic("callback cannot be nil") + } + + return &countBytesReader{reader: reader, callback: callback} +} + +type countBytesReader struct { + reader io.ReadCloser + callback CloseCallbackFunc + counter int64 +} + +func (r *countBytesReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + r.counter += int64(n) + return n, err +} + +func (r *countBytesReader) Close() error { + r.callback(r.counter) + return r.reader.Close() +} diff --git a/pkg/infra/httpclient/count_bytes_reader_test.go b/pkg/infra/httpclient/count_bytes_reader_test.go new file mode 100644 index 00000000000..d8cb077328d --- /dev/null +++ b/pkg/infra/httpclient/count_bytes_reader_test.go @@ -0,0 +1,38 @@ +package httpclient + +import ( + "fmt" + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCountBytesReader(t *testing.T) { + tcs := []struct { + body string + expectedBytesCount int64 + }{ + {body: "d", expectedBytesCount: 1}, + {body: "dummy", expectedBytesCount: 5}, + } + + for index, tc := range tcs { + t.Run(fmt.Sprintf("Test CountBytesReader %d", index), func(t *testing.T) { + body := ioutil.NopCloser(strings.NewReader(tc.body)) + var actualBytesRead int64 + + readCloser := CountBytesReader(body, func(bytesRead int64) { + actualBytesRead = bytesRead + }) + + bodyBytes, err := ioutil.ReadAll(readCloser) + require.NoError(t, err) + err = readCloser.Close() + require.NoError(t, err) + require.Equal(t, tc.expectedBytesCount, actualBytesRead) + require.Equal(t, string(bodyBytes), tc.body) + }) + } +} diff --git a/pkg/infra/httpclient/httpclientprovider/datasource_metrics_middleware.go b/pkg/infra/httpclient/httpclientprovider/datasource_metrics_middleware.go index cdf2addcf11..2709ed6ab76 100644 --- a/pkg/infra/httpclient/httpclientprovider/datasource_metrics_middleware.go +++ b/pkg/infra/httpclient/httpclientprovider/datasource_metrics_middleware.go @@ -3,7 +3,8 @@ package httpclientprovider import ( "net/http" - "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient" + "github.com/grafana/grafana/pkg/infra/httpclient" "github.com/grafana/grafana/pkg/infra/metrics/metricutil" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -56,8 +57,8 @@ const DataSourceMetricsMiddlewareName = "metrics" var executeMiddlewareFunc = executeMiddleware -func DataSourceMetricsMiddleware() httpclient.Middleware { - return httpclient.NamedMiddlewareFunc(DataSourceMetricsMiddlewareName, func(opts httpclient.Options, next http.RoundTripper) http.RoundTripper { +func DataSourceMetricsMiddleware() sdkhttpclient.Middleware { + return sdkhttpclient.NamedMiddlewareFunc(DataSourceMetricsMiddlewareName, func(opts sdkhttpclient.Options, next http.RoundTripper) http.RoundTripper { if opts.Labels == nil { return next } @@ -81,7 +82,7 @@ func DataSourceMetricsMiddleware() httpclient.Middleware { } func executeMiddleware(next http.RoundTripper, datasourceLabel prometheus.Labels) http.RoundTripper { - return httpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { + return sdkhttpclient.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { requestCounter := datasourceRequestCounter.MustCurryWith(datasourceLabel) requestSummary := datasourceRequestSummary.MustCurryWith(datasourceLabel) requestInFlight := datasourceRequestsInFlight.With(datasourceLabel) @@ -94,10 +95,11 @@ func executeMiddleware(next http.RoundTripper, datasourceLabel prometheus.Labels if err != nil { return nil, err } - // we avoid measuring contentlength less than zero because it indicates - // that the content size is unknown. https://godoc.org/github.com/badu/http#Response - if res != nil && res.ContentLength > 0 { - responseSizeSummary.Observe(float64(res.ContentLength)) + + if res != nil { + res.Body = httpclient.CountBytesReader(res.Body, func(bytesRead int64) { + responseSizeSummary.Observe(float64(bytesRead)) + }) } return res, nil