From 04e9f6c24f32b59f457590bc7827691fa11f1811 Mon Sep 17 00:00:00 2001 From: Arve Knudsen Date: Wed, 10 Mar 2021 10:14:00 +0100 Subject: [PATCH] UsageStatsService: Don't use global state (#31849) * UsageStatsService: Don't use global state Signed-off-by: Arve Knudsen --- pkg/infra/usagestats/service.go | 1 - pkg/infra/usagestats/usage_stats.go | 83 ++--- pkg/infra/usagestats/usage_stats_test.go | 422 +++++++++++------------ pkg/setting/setting.go | 18 +- 4 files changed, 254 insertions(+), 270 deletions(-) diff --git a/pkg/infra/usagestats/service.go b/pkg/infra/usagestats/service.go index c4d509c8b70..79b500f0a34 100644 --- a/pkg/infra/usagestats/service.go +++ b/pkg/infra/usagestats/service.go @@ -27,7 +27,6 @@ func init() { type UsageStats interface { GetUsageReport(ctx context.Context) (UsageReport, error) - RegisterMetric(name string, fn MetricFunc) } diff --git a/pkg/infra/usagestats/usage_stats.go b/pkg/infra/usagestats/usage_stats.go index 9479a9a24c9..c2dd3777731 100644 --- a/pkg/infra/usagestats/usage_stats.go +++ b/pkg/infra/usagestats/usage_stats.go @@ -13,7 +13,6 @@ import ( "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins/manager" - "github.com/grafana/grafana/pkg/setting" ) var usageStatsURL = "https://stats.grafana.org/grafana-usage-report" @@ -29,18 +28,22 @@ type UsageReport struct { } func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, error) { - version := strings.ReplaceAll(setting.BuildVersion, ".", "_") + version := strings.ReplaceAll(uss.Cfg.BuildVersion, ".", "_") metrics := map[string]interface{}{} + edition := "oss" + if uss.Cfg.IsEnterprise { + edition = "enterprise" + } report := UsageReport{ Version: version, Metrics: metrics, Os: runtime.GOOS, Arch: runtime.GOARCH, - Edition: getEdition(), + Edition: edition, HasValidLicense: uss.License.HasValidLicense(), - Packaging: setting.Packaging, + Packaging: uss.Cfg.Packaging, } statsQuery := models.GetSystemStatsQuery{} @@ -69,9 +72,19 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, metrics["stats.total_auth_token.count"] = statsQuery.Result.AuthTokens metrics["stats.dashboard_versions.count"] = statsQuery.Result.DashboardVersions metrics["stats.annotations.count"] = statsQuery.Result.Annotations - metrics["stats.valid_license.count"] = getValidLicenseCount(uss.License.HasValidLicense()) - metrics["stats.edition.oss.count"] = getOssEditionCount() - metrics["stats.edition.enterprise.count"] = getEnterpriseEditionCount() + validLicCount := 0 + if uss.License.HasValidLicense() { + validLicCount = 1 + } + metrics["stats.valid_license.count"] = validLicCount + ossEditionCount := 1 + enterpriseEditionCount := 0 + if uss.Cfg.IsEnterprise { + enterpriseEditionCount = 1 + ossEditionCount = 0 + } + metrics["stats.edition.oss.count"] = ossEditionCount + metrics["stats.edition.enterprise.count"] = enterpriseEditionCount uss.registerExternalMetrics(metrics) @@ -102,8 +115,8 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, } metrics["stats.ds.other.count"] = dsOtherCount - metrics["stats.packaging."+setting.Packaging+".count"] = 1 - metrics["stats.distributor."+setting.ReportingDistributor+".count"] = 1 + metrics["stats.packaging."+uss.Cfg.Packaging+".count"] = 1 + metrics["stats.distributor."+uss.Cfg.ReportingDistributor+".count"] = 1 // Alerting stats alertingUsageStats, err := uss.AlertingUsageStats.QueryUsageStats() @@ -170,10 +183,10 @@ func (uss *UsageStatsService) GetUsageReport(ctx context.Context) (UsageReport, // Add stats about auth configuration authTypes := map[string]bool{} - authTypes["anonymous"] = setting.AnonymousEnabled - authTypes["basic_auth"] = setting.BasicAuthEnabled - authTypes["ldap"] = setting.LDAPEnabled - authTypes["auth_proxy"] = setting.AuthProxyEnabled + authTypes["anonymous"] = uss.Cfg.AnonymousEnabled + authTypes["basic_auth"] = uss.Cfg.BasicAuthEnabled + authTypes["ldap"] = uss.Cfg.LDAPEnabled + authTypes["auth_proxy"] = uss.Cfg.AuthProxyEnabled for provider, enabled := range uss.oauthProviders { authTypes["oauth_"+provider] = enabled @@ -221,7 +234,7 @@ func (uss *UsageStatsService) RegisterMetric(name string, fn MetricFunc) { } func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error { - if !setting.ReportingEnabled { + if !uss.Cfg.ReportingEnabled { return nil } @@ -237,9 +250,17 @@ func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error { return err } data := bytes.NewBuffer(out) + sendUsageStats(data) - client := http.Client{Timeout: 5 * time.Second} + return nil +} + +// sendUsageStats sends usage statistics. +// +// Stubbable by tests. +var sendUsageStats = func(data *bytes.Buffer) { go func() { + client := http.Client{Timeout: 5 * time.Second} resp, err := client.Post(usageStatsURL, "application/json", data) if err != nil { metricsLogger.Error("Failed to send usage stats", "err", err) @@ -249,8 +270,6 @@ func (uss *UsageStatsService) sendUsageStats(ctx context.Context) error { metricsLogger.Warn("Failed to close response body", "err", err) } }() - - return nil } func (uss *UsageStatsService) updateTotalStats() { @@ -298,33 +317,3 @@ func (uss *UsageStatsService) shouldBeReported(dsType string) bool { return ds.Signature.IsValid() || ds.Signature.IsInternal() } - -func getEdition() string { - edition := "oss" - if setting.IsEnterprise { - edition = "enterprise" - } - - return edition -} - -func getEnterpriseEditionCount() int { - if setting.IsEnterprise { - return 1 - } - return 0 -} - -func getOssEditionCount() int { - if setting.IsEnterprise { - return 0 - } - return 1 -} - -func getValidLicenseCount(validLicense bool) int { - if validLicense { - return 1 - } - return 0 -} diff --git a/pkg/infra/usagestats/usage_stats_test.go b/pkg/infra/usagestats/usage_stats_test.go index 4973e086a00..b1247c2d470 100644 --- a/pkg/infra/usagestats/usage_stats_test.go +++ b/pkg/infra/usagestats/usage_stats_test.go @@ -6,7 +6,6 @@ import ( "errors" "io/ioutil" "runtime" - "sync" "testing" "time" @@ -40,13 +39,8 @@ func Test_InterfaceContractValidity(t *testing.T) { func TestMetrics(t *testing.T) { t.Run("When sending usage stats", func(t *testing.T) { - setupSomeDataSourcePlugins(t) - - uss := &UsageStatsService{ - Bus: bus.New(), - SQLStore: sqlstore.InitTestDB(t), - License: &licensing.OSSLicensingService{}, - } + uss := createService(t, setting.Cfg{}) + setupSomeDataSourcePlugins(t, uss) var getSystemStatsQuery *models.GetSystemStatsQuery uss.Bus.AddHandler(func(query *models.GetSystemStatsQuery) error { @@ -166,22 +160,6 @@ func TestMetrics(t *testing.T) { createConcurrentTokens(t, uss.SQLStore) uss.AlertingUsageStats = &alertingUsageMock{} - var wg sync.WaitGroup - var responseBuffer *bytes.Buffer - var req *http.Request - ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - req = r - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Fatalf("Failed to read response body, err=%v", err) - } - responseBuffer = bytes.NewBuffer(buf) - wg.Done() - })) - usageStatsURL = ts.URL - - defer ts.Close() - uss.oauthProviders = map[string]bool{ "github": true, "gitlab": true, @@ -195,125 +173,163 @@ func TestMetrics(t *testing.T) { require.NoError(t, err) t.Run("Given reporting not enabled and sending usage stats", func(t *testing.T) { - setting.ReportingEnabled = false + origSendUsageStats := sendUsageStats + t.Cleanup(func() { + sendUsageStats = origSendUsageStats + }) + statsSent := false + sendUsageStats = func(*bytes.Buffer) { + statsSent = true + } + + uss.Cfg.ReportingEnabled = false err := uss.sendUsageStats(context.Background()) require.NoError(t, err) - t.Run("Should not gather stats or call http endpoint", func(t *testing.T) { - assert.Nil(t, getSystemStatsQuery) - assert.Nil(t, getDataSourceStatsQuery) - assert.Nil(t, getDataSourceAccessStatsQuery) - assert.Nil(t, req) - }) + require.False(t, statsSent) + assert.Nil(t, getSystemStatsQuery) + assert.Nil(t, getDataSourceStatsQuery) + assert.Nil(t, getDataSourceAccessStatsQuery) }) - t.Run("Given reporting enabled and sending usage stats", func(t *testing.T) { - setting.ReportingEnabled = true - setting.BuildVersion = "5.0.0" - setting.AnonymousEnabled = true - setting.BasicAuthEnabled = true - setting.LDAPEnabled = true - setting.AuthProxyEnabled = true - setting.Packaging = "deb" - setting.ReportingDistributor = "hosted-grafana" + t.Run("Given reporting enabled, stats should be gathered and sent to HTTP endpoint", func(t *testing.T) { + origCfg := uss.Cfg + t.Cleanup(func() { + uss.Cfg = origCfg + }) + uss.Cfg = &setting.Cfg{ + ReportingEnabled: true, + BuildVersion: "5.0.0", + AnonymousEnabled: true, + BasicAuthEnabled: true, + LDAPEnabled: true, + AuthProxyEnabled: true, + Packaging: "deb", + ReportingDistributor: "hosted-grafana", + } + + ch := make(chan httpResp) + ticker := time.NewTicker(2 * time.Second) + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Logf("Fake HTTP handler received an error: %s", err.Error()) + ch <- httpResp{ + err: err, + } + return + } + require.NoError(t, err, "Failed to read response body, err=%v", err) + t.Logf("Fake HTTP handler received a response") + ch <- httpResp{ + responseBuffer: bytes.NewBuffer(buf), + req: r, + } + })) + t.Cleanup(ts.Close) + t.Cleanup(func() { + close(ch) + }) + usageStatsURL = ts.URL - wg.Add(1) err := uss.sendUsageStats(context.Background()) require.NoError(t, err) - t.Run("Should gather stats and call http endpoint", func(t *testing.T) { - if waitTimeout(&wg, 2*time.Second) { - t.Fatalf("Timed out waiting for http request") - } + // Wait for fake HTTP server to receive a request + var resp httpResp + select { + case resp = <-ch: + require.NoError(t, resp.err, "Fake server experienced an error") + case <-ticker.C: + t.Fatalf("Timed out waiting for HTTP request") + } - assert.NotNil(t, getSystemStatsQuery) - assert.NotNil(t, getDataSourceStatsQuery) - assert.NotNil(t, getDataSourceAccessStatsQuery) - assert.NotNil(t, getAlertNotifierUsageStatsQuery) - assert.NotNil(t, req) + t.Logf("Received response from fake HTTP server: %+v\n", resp) - assert.Equal(t, http.MethodPost, req.Method) - assert.Equal(t, "application/json", req.Header.Get("Content-Type")) + assert.NotNil(t, getSystemStatsQuery) + assert.NotNil(t, getDataSourceStatsQuery) + assert.NotNil(t, getDataSourceAccessStatsQuery) + assert.NotNil(t, getAlertNotifierUsageStatsQuery) + assert.NotNil(t, resp.req) - assert.NotNil(t, responseBuffer) + assert.Equal(t, http.MethodPost, resp.req.Method) + assert.Equal(t, "application/json", resp.req.Header.Get("Content-Type")) - j, err := simplejson.NewFromReader(responseBuffer) - assert.Nil(t, err) + require.NotNil(t, resp.responseBuffer) - assert.Equal(t, "5_0_0", j.Get("version").MustString()) - assert.Equal(t, runtime.GOOS, j.Get("os").MustString()) - assert.Equal(t, runtime.GOARCH, j.Get("arch").MustString()) + j, err := simplejson.NewFromReader(resp.responseBuffer) + require.NoError(t, err) - metrics := j.Get("metrics") - assert.Equal(t, getSystemStatsQuery.Result.Dashboards, metrics.Get("stats.dashboards.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Users, metrics.Get("stats.users.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Orgs, metrics.Get("stats.orgs.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Playlists, metrics.Get("stats.playlist.count").MustInt64()) - assert.Equal(t, len(manager.Apps), metrics.Get("stats.plugins.apps.count").MustInt()) - assert.Equal(t, len(manager.Panels), metrics.Get("stats.plugins.panels.count").MustInt()) - assert.Equal(t, len(manager.DataSources), metrics.Get("stats.plugins.datasources.count").MustInt()) - assert.Equal(t, getSystemStatsQuery.Result.Alerts, metrics.Get("stats.alerts.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.ActiveUsers, metrics.Get("stats.active_users.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Datasources, metrics.Get("stats.datasources.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Stars, metrics.Get("stats.stars.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Folders, metrics.Get("stats.folders.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.DashboardPermissions, metrics.Get("stats.dashboard_permissions.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.FolderPermissions, metrics.Get("stats.folder_permissions.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.ProvisionedDashboards, metrics.Get("stats.provisioned_dashboards.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Snapshots, metrics.Get("stats.snapshots.count").MustInt64()) - assert.Equal(t, getSystemStatsQuery.Result.Teams, metrics.Get("stats.teams.count").MustInt64()) - assert.Equal(t, 15, metrics.Get("stats.total_auth_token.count").MustInt()) - assert.Equal(t, 5, metrics.Get("stats.avg_auth_token_per_user.count").MustInt()) - assert.Equal(t, 16, metrics.Get("stats.dashboard_versions.count").MustInt()) - assert.Equal(t, 17, metrics.Get("stats.annotations.count").MustInt()) + assert.Equal(t, "5_0_0", j.Get("version").MustString()) + assert.Equal(t, runtime.GOOS, j.Get("os").MustString()) + assert.Equal(t, runtime.GOARCH, j.Get("arch").MustString()) - assert.Equal(t, 9, metrics.Get("stats.ds."+models.DS_ES+".count").MustInt()) - assert.Equal(t, 10, metrics.Get("stats.ds."+models.DS_PROMETHEUS+".count").MustInt()) - assert.Equal(t, 11+12, metrics.Get("stats.ds.other.count").MustInt()) + metrics := j.Get("metrics") + assert.Equal(t, getSystemStatsQuery.Result.Dashboards, metrics.Get("stats.dashboards.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Users, metrics.Get("stats.users.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Orgs, metrics.Get("stats.orgs.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Playlists, metrics.Get("stats.playlist.count").MustInt64()) + assert.Equal(t, len(manager.Apps), metrics.Get("stats.plugins.apps.count").MustInt()) + assert.Equal(t, len(manager.Panels), metrics.Get("stats.plugins.panels.count").MustInt()) + assert.Equal(t, len(manager.DataSources), metrics.Get("stats.plugins.datasources.count").MustInt()) + assert.Equal(t, getSystemStatsQuery.Result.Alerts, metrics.Get("stats.alerts.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.ActiveUsers, metrics.Get("stats.active_users.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Datasources, metrics.Get("stats.datasources.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Stars, metrics.Get("stats.stars.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Folders, metrics.Get("stats.folders.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.DashboardPermissions, metrics.Get("stats.dashboard_permissions.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.FolderPermissions, metrics.Get("stats.folder_permissions.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.ProvisionedDashboards, metrics.Get("stats.provisioned_dashboards.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Snapshots, metrics.Get("stats.snapshots.count").MustInt64()) + assert.Equal(t, getSystemStatsQuery.Result.Teams, metrics.Get("stats.teams.count").MustInt64()) + assert.Equal(t, 15, metrics.Get("stats.total_auth_token.count").MustInt()) + assert.Equal(t, 5, metrics.Get("stats.avg_auth_token_per_user.count").MustInt()) + assert.Equal(t, 16, metrics.Get("stats.dashboard_versions.count").MustInt()) + assert.Equal(t, 17, metrics.Get("stats.annotations.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.ds_access."+models.DS_ES+".direct.count").MustInt()) - assert.Equal(t, 2, metrics.Get("stats.ds_access."+models.DS_ES+".proxy.count").MustInt()) - assert.Equal(t, 3, metrics.Get("stats.ds_access."+models.DS_PROMETHEUS+".proxy.count").MustInt()) - assert.Equal(t, 6+7, metrics.Get("stats.ds_access.other.direct.count").MustInt()) - assert.Equal(t, 4+8, metrics.Get("stats.ds_access.other.proxy.count").MustInt()) + assert.Equal(t, 9, metrics.Get("stats.ds."+models.DS_ES+".count").MustInt()) + assert.Equal(t, 10, metrics.Get("stats.ds."+models.DS_PROMETHEUS+".count").MustInt()) + assert.Equal(t, 11+12, metrics.Get("stats.ds.other.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.alerting.ds.prometheus.count").MustInt()) - assert.Equal(t, 2, metrics.Get("stats.alerting.ds.graphite.count").MustInt()) - assert.Equal(t, 5, metrics.Get("stats.alerting.ds.mysql.count").MustInt()) - assert.Equal(t, 90, metrics.Get("stats.alerting.ds.other.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.ds_access."+models.DS_ES+".direct.count").MustInt()) + assert.Equal(t, 2, metrics.Get("stats.ds_access."+models.DS_ES+".proxy.count").MustInt()) + assert.Equal(t, 3, metrics.Get("stats.ds_access."+models.DS_PROMETHEUS+".proxy.count").MustInt()) + assert.Equal(t, 6+7, metrics.Get("stats.ds_access.other.direct.count").MustInt()) + assert.Equal(t, 4+8, metrics.Get("stats.ds_access.other.proxy.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.alert_notifiers.slack.count").MustInt()) - assert.Equal(t, 2, metrics.Get("stats.alert_notifiers.webhook.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.alerting.ds.prometheus.count").MustInt()) + assert.Equal(t, 2, metrics.Get("stats.alerting.ds.graphite.count").MustInt()) + assert.Equal(t, 5, metrics.Get("stats.alerting.ds.mysql.count").MustInt()) + assert.Equal(t, 90, metrics.Get("stats.alerting.ds.other.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.anonymous.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.basic_auth.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.ldap.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.auth_proxy.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_github.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_gitlab.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_google.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_azuread.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_generic_oauth.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.alert_notifiers.slack.count").MustInt()) + assert.Equal(t, 2, metrics.Get("stats.alert_notifiers.webhook.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.packaging.deb.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.distributor.hosted-grafana.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.anonymous.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.basic_auth.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.ldap.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.auth_proxy.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_github.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_gitlab.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_google.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_azuread.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_generic_oauth.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt()) - assert.Equal(t, 1, metrics.Get("stats.auth_token_per_user_le_3").MustInt()) - assert.Equal(t, 2, metrics.Get("stats.auth_token_per_user_le_6").MustInt()) - assert.Equal(t, 3, metrics.Get("stats.auth_token_per_user_le_9").MustInt()) - assert.Equal(t, 4, metrics.Get("stats.auth_token_per_user_le_12").MustInt()) - assert.Equal(t, 5, metrics.Get("stats.auth_token_per_user_le_15").MustInt()) - assert.Equal(t, 6, metrics.Get("stats.auth_token_per_user_le_inf").MustInt()) - }) + assert.Equal(t, 1, metrics.Get("stats.packaging.deb.count").MustInt()) + assert.Equal(t, 1, metrics.Get("stats.distributor.hosted-grafana.count").MustInt()) + + assert.Equal(t, 1, metrics.Get("stats.auth_token_per_user_le_3").MustInt()) + assert.Equal(t, 2, metrics.Get("stats.auth_token_per_user_le_6").MustInt()) + assert.Equal(t, 3, metrics.Get("stats.auth_token_per_user_le_9").MustInt()) + assert.Equal(t, 4, metrics.Get("stats.auth_token_per_user_le_12").MustInt()) + assert.Equal(t, 5, metrics.Get("stats.auth_token_per_user_le_15").MustInt()) + assert.Equal(t, 6, metrics.Get("stats.auth_token_per_user_le_inf").MustInt()) }) }) t.Run("When updating total stats", func(t *testing.T) { - uss := &UsageStatsService{ - Bus: bus.New(), - Cfg: setting.NewCfg(), - } + uss := createService(t, setting.Cfg{}) uss.Cfg.MetricsEndpointEnabled = true uss.Cfg.MetricsEndpointDisableTotalStats = false getSystemStatsWasCalled := false @@ -323,56 +339,44 @@ func TestMetrics(t *testing.T) { return nil }) - t.Run("When metrics is disabled and total stats is enabled", func(t *testing.T) { + t.Run("When metrics is disabled and total stats is enabled, stats should not be updated", func(t *testing.T) { uss.Cfg.MetricsEndpointEnabled = false uss.Cfg.MetricsEndpointDisableTotalStats = false - t.Run("Should not update stats", func(t *testing.T) { - uss.updateTotalStats() + uss.updateTotalStats() - assert.False(t, getSystemStatsWasCalled) - }) + assert.False(t, getSystemStatsWasCalled) }) - t.Run("When metrics is enabled and total stats is disabled", func(t *testing.T) { + t.Run("When metrics is enabled and total stats is disabled, stats should not be updated", func(t *testing.T) { uss.Cfg.MetricsEndpointEnabled = true uss.Cfg.MetricsEndpointDisableTotalStats = true - t.Run("Should not update stats", func(t *testing.T) { - uss.updateTotalStats() + uss.updateTotalStats() - assert.False(t, getSystemStatsWasCalled) - }) + assert.False(t, getSystemStatsWasCalled) }) - t.Run("When metrics is disabled and total stats is disabled", func(t *testing.T) { + t.Run("When metrics is disabled and total stats is disabled, stats should not be updated", func(t *testing.T) { uss.Cfg.MetricsEndpointEnabled = false uss.Cfg.MetricsEndpointDisableTotalStats = true - t.Run("Should not update stats", func(t *testing.T) { - uss.updateTotalStats() + uss.updateTotalStats() - assert.False(t, getSystemStatsWasCalled) - }) + assert.False(t, getSystemStatsWasCalled) }) - t.Run("When metrics is enabled and total stats is enabled", func(t *testing.T) { + t.Run("When metrics is enabled and total stats is enabled, stats should be updated", func(t *testing.T) { uss.Cfg.MetricsEndpointEnabled = true uss.Cfg.MetricsEndpointDisableTotalStats = false - t.Run("Should update stats", func(t *testing.T) { - uss.updateTotalStats() + uss.updateTotalStats() - assert.True(t, getSystemStatsWasCalled) - }) + assert.True(t, getSystemStatsWasCalled) }) }) t.Run("When registering a metric", func(t *testing.T) { - uss := &UsageStatsService{ - Bus: bus.New(), - Cfg: setting.NewCfg(), - externalMetrics: make(map[string]MetricFunc), - } + uss := createService(t, setting.Cfg{}) metricName := "stats.test_metric.count" t.Run("Adds a new metric to the external metrics", func(t *testing.T) { @@ -380,37 +384,31 @@ func TestMetrics(t *testing.T) { return 1, nil }) - metric, _ := uss.externalMetrics[metricName]() + metric, err := uss.externalMetrics[metricName]() + require.NoError(t, err) assert.Equal(t, 1, metric) }) - t.Run("When metric already exists", func(t *testing.T) { + t.Run("When metric already exists, the metric should be overridden", func(t *testing.T) { uss.RegisterMetric(metricName, func() (interface{}, error) { return 1, nil }) - metric, _ := uss.externalMetrics[metricName]() + metric, err := uss.externalMetrics[metricName]() + require.NoError(t, err) assert.Equal(t, 1, metric) - t.Run("Overrides the metric", func(t *testing.T) { - uss.RegisterMetric(metricName, func() (interface{}, error) { - return 2, nil - }) - newMetric, _ := uss.externalMetrics[metricName]() - assert.Equal(t, 2, newMetric) + uss.RegisterMetric(metricName, func() (interface{}, error) { + return 2, nil }) + newMetric, err := uss.externalMetrics[metricName]() + require.NoError(t, err) + assert.Equal(t, 2, newMetric) }) }) t.Run("When getting usage report", func(t *testing.T) { - uss := &UsageStatsService{ - Bus: bus.New(), - Cfg: setting.NewCfg(), - SQLStore: sqlstore.InitTestDB(t), - License: &licensing.OSSLicensingService{}, - AlertingUsageStats: &alertingUsageMock{}, - externalMetrics: make(map[string]MetricFunc), - } + uss := createService(t, setting.Cfg{}) metricName := "stats.test_metric.count" uss.Bus.AddHandler(func(query *models.GetSystemStatsQuery) error { @@ -453,7 +451,7 @@ func TestMetrics(t *testing.T) { }) report, err := uss.GetUsageReport(context.Background()) - assert.Nil(t, err, "Expected no error") + require.NoError(t, err, "Expected no error") metric := report.Metrics[metricName] assert.Equal(t, 1, metric) @@ -461,24 +459,18 @@ func TestMetrics(t *testing.T) { }) t.Run("When registering external metrics", func(t *testing.T) { - uss := &UsageStatsService{ - Bus: bus.New(), - Cfg: setting.NewCfg(), - externalMetrics: make(map[string]MetricFunc), - } + uss := createService(t, setting.Cfg{}) metrics := map[string]interface{}{"stats.test_metric.count": 1, "stats.test_metric_second.count": 2} extMetricName := "stats.test_external_metric.count" - t.Run("Should add to metrics", func(t *testing.T) { - uss.RegisterMetric(extMetricName, func() (interface{}, error) { - return 1, nil - }) - - uss.registerExternalMetrics(metrics) - - assert.Equal(t, 1, metrics[extMetricName]) + uss.RegisterMetric(extMetricName, func() (interface{}, error) { + return 1, nil }) + uss.registerExternalMetrics(metrics) + + assert.Equal(t, 1, metrics[extMetricName]) + t.Run("When loading a metric results to an error", func(t *testing.T) { uss.RegisterMetric(extMetricName, func() (interface{}, error) { return 1, nil @@ -495,7 +487,7 @@ func TestMetrics(t *testing.T) { extErrorMetric := metrics[extErrorMetricName] extMetric := metrics[extMetricName] - assert.Nil(t, extErrorMetric, "Invalid metric should not be added") + require.Nil(t, extErrorMetric, "Invalid metric should not be added") assert.Equal(t, 1, extMetric) assert.Len(t, metrics, 3, "Expected only one available metric") }) @@ -503,20 +495,6 @@ func TestMetrics(t *testing.T) { }) } -func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { - c := make(chan struct{}) - go func() { - defer close(c) - wg.Wait() - }() - select { - case <-c: - return false // completed normally - case <-time.After(timeout): - return true // timed out - } -} - type alertingUsageMock struct{} func (aum *alertingUsageMock) QueryUsageStats() (*alerting.UsageStats, error) { @@ -530,40 +508,58 @@ func (aum *alertingUsageMock) QueryUsageStats() (*alerting.UsageStats, error) { }, nil } -func setupSomeDataSourcePlugins(t *testing.T) { +func setupSomeDataSourcePlugins(t *testing.T, uss *UsageStatsService) { + t.Helper() + originalDataSources := manager.DataSources t.Cleanup(func() { manager.DataSources = originalDataSources }) - - manager.DataSources = make(map[string]*plugins.DataSourcePlugin) - - manager.DataSources[models.DS_ES] = &plugins.DataSourcePlugin{ - FrontendPluginBase: plugins.FrontendPluginBase{ - PluginBase: plugins.PluginBase{ - Signature: "internal", + manager.DataSources = map[string]*plugins.DataSourcePlugin{ + models.DS_ES: { + FrontendPluginBase: plugins.FrontendPluginBase{ + PluginBase: plugins.PluginBase{ + Signature: "internal", + }, }, }, - } - manager.DataSources[models.DS_PROMETHEUS] = &plugins.DataSourcePlugin{ - FrontendPluginBase: plugins.FrontendPluginBase{ - PluginBase: plugins.PluginBase{ - Signature: "internal", + models.DS_PROMETHEUS: { + FrontendPluginBase: plugins.FrontendPluginBase{ + PluginBase: plugins.PluginBase{ + Signature: "internal", + }, }, }, - } - - manager.DataSources[models.DS_GRAPHITE] = &plugins.DataSourcePlugin{ - FrontendPluginBase: plugins.FrontendPluginBase{ - PluginBase: plugins.PluginBase{ - Signature: "internal", + models.DS_GRAPHITE: { + FrontendPluginBase: plugins.FrontendPluginBase{ + PluginBase: plugins.PluginBase{ + Signature: "internal", + }, }, }, - } - - manager.DataSources[models.DS_MYSQL] = &plugins.DataSourcePlugin{ - FrontendPluginBase: plugins.FrontendPluginBase{ - PluginBase: plugins.PluginBase{ - Signature: "internal", + models.DS_MYSQL: { + FrontendPluginBase: plugins.FrontendPluginBase{ + PluginBase: plugins.PluginBase{ + Signature: "internal", + }, }, }, } } + +type httpResp struct { + req *http.Request + responseBuffer *bytes.Buffer + err error +} + +func createService(t *testing.T, cfg setting.Cfg) *UsageStatsService { + t.Helper() + + return &UsageStatsService{ + Bus: bus.New(), + Cfg: &cfg, + SQLStore: sqlstore.InitTestDB(t), + License: &licensing.OSSLicensingService{}, + AlertingUsageStats: &alertingUsageMock{}, + externalMetrics: make(map[string]MetricFunc), + } +} diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index 05a12bb2c22..c026d3f3eb4 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -141,10 +141,8 @@ var ( appliedEnvOverrides []string // analytics - ReportingEnabled bool - ReportingDistributor string - GoogleAnalyticsId string - GoogleTagManagerId string + GoogleAnalyticsId string + GoogleTagManagerId string // LDAP LDAPEnabled bool @@ -337,7 +335,9 @@ type Cfg struct { Env string // Analytics - CheckForUpdates bool + CheckForUpdates bool + ReportingDistributor string + ReportingEnabled bool // LDAP LDAPEnabled bool @@ -831,10 +831,10 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error { cfg.CheckForUpdates = analytics.Key("check_for_updates").MustBool(true) GoogleAnalyticsId = analytics.Key("google_analytics_ua_id").String() GoogleTagManagerId = analytics.Key("google_tag_manager_id").String() - ReportingEnabled = analytics.Key("reporting_enabled").MustBool(true) - ReportingDistributor = analytics.Key("reporting_distributor").MustString("grafana-labs") - if len(ReportingDistributor) >= 100 { - ReportingDistributor = ReportingDistributor[:100] + cfg.ReportingEnabled = analytics.Key("reporting_enabled").MustBool(true) + cfg.ReportingDistributor = analytics.Key("reporting_distributor").MustString("grafana-labs") + if len(cfg.ReportingDistributor) >= 100 { + cfg.ReportingDistributor = cfg.ReportingDistributor[:100] } if err := readAlertingSettings(iniFile); err != nil {