Chore: Add context to org (#40685)

* Add context to org

* Rebase

* Fix rebase
This commit is contained in:
idafurjes
2021-11-03 11:31:56 +01:00
committed by GitHub
parent 91da1bbb79
commit 47f6bb3583
24 changed files with 120 additions and 112 deletions

View File

@ -33,7 +33,7 @@ func (hs *HTTPServer) AdminProvisioningReloadPlugins(c *models.ReqContext) respo
} }
func (hs *HTTPServer) AdminProvisioningReloadNotifications(c *models.ReqContext) response.Response { func (hs *HTTPServer) AdminProvisioningReloadNotifications(c *models.ReqContext) response.Response {
err := hs.ProvisioningService.ProvisionNotifications() err := hs.ProvisioningService.ProvisionNotifications(c.Req.Context())
if err != nil { if err != nil {
return response.Error(500, "", err) return response.Error(500, "", err)
} }

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -143,7 +144,7 @@ func TestAPI_AdminProvisioningReload_AccessControl(t *testing.T) {
sc, hs := setupAccessControlScenarioContext(t, cfg, test.url, test.permissions) sc, hs := setupAccessControlScenarioContext(t, cfg, test.url, test.permissions)
// Setup the mock // Setup the mock
provisioningMock := provisioning.NewProvisioningServiceMock() provisioningMock := provisioning.NewProvisioningServiceMock(context.Background())
hs.ProvisioningService = provisioningMock hs.ProvisioningService = provisioningMock
sc.resp = httptest.NewRecorder() sc.resp = httptest.NewRecorder()

View File

@ -972,7 +972,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
loggedInUserScenarioWithRole(t, "When calling GET on", "GET", "/api/dashboards/uid/dash", "/api/dashboards/uid/:uid", models.ROLE_EDITOR, func(sc *scenarioContext) { loggedInUserScenarioWithRole(t, "When calling GET on", "GET", "/api/dashboards/uid/dash", "/api/dashboards/uid/:uid", models.ROLE_EDITOR, func(sc *scenarioContext) {
setUp() setUp()
mock := provisioning.NewProvisioningServiceMock() mock := provisioning.NewProvisioningServiceMock(context.Background())
mock.GetDashboardProvisionerResolvedPathFunc = func(name string) string { mock.GetDashboardProvisionerResolvedPathFunc = func(name string) string {
return "/tmp/grafana/dashboards" return "/tmp/grafana/dashboards"
} }
@ -985,7 +985,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
loggedInUserScenarioWithRole(t, "When allowUiUpdates is true and calling GET on", "GET", "/api/dashboards/uid/dash", "/api/dashboards/uid/:uid", models.ROLE_EDITOR, func(sc *scenarioContext) { loggedInUserScenarioWithRole(t, "When allowUiUpdates is true and calling GET on", "GET", "/api/dashboards/uid/dash", "/api/dashboards/uid/:uid", models.ROLE_EDITOR, func(sc *scenarioContext) {
setUp() setUp()
mock := provisioning.NewProvisioningServiceMock() mock := provisioning.NewProvisioningServiceMock(context.Background())
mock.GetDashboardProvisionerResolvedPathFunc = func(name string) string { mock.GetDashboardProvisionerResolvedPathFunc = func(name string) string {
return "/tmp/grafana/dashboards" return "/tmp/grafana/dashboards"
} }
@ -1015,7 +1015,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
func getDashboardShouldReturn200WithConfig(sc *scenarioContext, provisioningService provisioning.ProvisioningService) dtos. func getDashboardShouldReturn200WithConfig(sc *scenarioContext, provisioningService provisioning.ProvisioningService) dtos.
DashboardFullWithMeta { DashboardFullWithMeta {
if provisioningService == nil { if provisioningService == nil {
provisioningService = provisioning.NewProvisioningServiceMock() provisioningService = provisioning.NewProvisioningServiceMock(context.Background())
} }
libraryPanelsService := mockLibraryPanelService{} libraryPanelsService := mockLibraryPanelService{}
@ -1110,7 +1110,7 @@ func postDashboardScenario(t *testing.T, desc string, url string, routePattern s
hs := HTTPServer{ hs := HTTPServer{
Bus: bus.GetBus(), Bus: bus.GetBus(),
Cfg: cfg, Cfg: cfg,
ProvisioningService: provisioning.NewProvisioningServiceMock(), ProvisioningService: provisioning.NewProvisioningServiceMock(context.Background()),
Live: newTestLive(t), Live: newTestLive(t),
QuotaService: &quota.QuotaService{ QuotaService: &quota.QuotaService{
Cfg: cfg, Cfg: cfg,
@ -1179,7 +1179,7 @@ func restoreDashboardVersionScenario(t *testing.T, desc string, url string, rout
hs := HTTPServer{ hs := HTTPServer{
Cfg: cfg, Cfg: cfg,
Bus: bus.GetBus(), Bus: bus.GetBus(),
ProvisioningService: provisioning.NewProvisioningServiceMock(), ProvisioningService: provisioning.NewProvisioningServiceMock(context.Background()),
Live: newTestLive(t), Live: newTestLive(t),
QuotaService: &quota.QuotaService{Cfg: cfg}, QuotaService: &quota.QuotaService{Cfg: cfg},
LibraryPanelService: &mockLibraryPanelService{}, LibraryPanelService: &mockLibraryPanelService{},

View File

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
@ -15,12 +16,12 @@ import (
// GET /api/org // GET /api/org
func GetCurrentOrg(c *models.ReqContext) response.Response { func GetCurrentOrg(c *models.ReqContext) response.Response {
return getOrgHelper(c.OrgId) return getOrgHelper(c.Req.Context(), c.OrgId)
} }
// GET /api/orgs/:orgId // GET /api/orgs/:orgId
func GetOrgByID(c *models.ReqContext) response.Response { func GetOrgByID(c *models.ReqContext) response.Response {
return getOrgHelper(c.ParamsInt64(":orgId")) return getOrgHelper(c.Req.Context(), c.ParamsInt64(":orgId"))
} }
// Get /api/orgs/name/:name // Get /api/orgs/name/:name
@ -49,14 +50,13 @@ func (hs *HTTPServer) GetOrgByName(c *models.ReqContext) response.Response {
return response.JSON(200, &result) return response.JSON(200, &result)
} }
func getOrgHelper(orgID int64) response.Response { func getOrgHelper(ctx context.Context, orgID int64) response.Response {
query := models.GetOrgByIdQuery{Id: orgID} query := models.GetOrgByIdQuery{Id: orgID}
if err := sqlstore.GetOrgById(&query); err != nil { if err := sqlstore.GetOrgById(ctx, &query); err != nil {
if errors.Is(err, models.ErrOrgNotFound) { if errors.Is(err, models.ErrOrgNotFound) {
return response.Error(404, "Organization not found", err) return response.Error(404, "Organization not found", err)
} }
return response.Error(500, "Failed to get organization", err) return response.Error(500, "Failed to get organization", err)
} }
@ -85,7 +85,7 @@ func (hs *HTTPServer) CreateOrg(c *models.ReqContext, cmd models.CreateOrgComman
} }
cmd.UserId = c.UserId cmd.UserId = c.UserId
if err := sqlstore.CreateOrg(&cmd); err != nil { if err := sqlstore.CreateOrg(c.Req.Context(), &cmd); err != nil {
if errors.Is(err, models.ErrOrgNameTaken) { if errors.Is(err, models.ErrOrgNameTaken) {
return response.Error(409, "Organization name taken", err) return response.Error(409, "Organization name taken", err)
} }
@ -102,17 +102,17 @@ func (hs *HTTPServer) CreateOrg(c *models.ReqContext, cmd models.CreateOrgComman
// PUT /api/org // PUT /api/org
func UpdateCurrentOrg(c *models.ReqContext, form dtos.UpdateOrgForm) response.Response { func UpdateCurrentOrg(c *models.ReqContext, form dtos.UpdateOrgForm) response.Response {
return updateOrgHelper(form, c.OrgId) return updateOrgHelper(c.Req.Context(), form, c.OrgId)
} }
// PUT /api/orgs/:orgId // PUT /api/orgs/:orgId
func UpdateOrg(c *models.ReqContext, form dtos.UpdateOrgForm) response.Response { func UpdateOrg(c *models.ReqContext, form dtos.UpdateOrgForm) response.Response {
return updateOrgHelper(form, c.ParamsInt64(":orgId")) return updateOrgHelper(c.Req.Context(), form, c.ParamsInt64(":orgId"))
} }
func updateOrgHelper(form dtos.UpdateOrgForm, orgID int64) response.Response { func updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) response.Response {
cmd := models.UpdateOrgCommand{Name: form.Name, OrgId: orgID} cmd := models.UpdateOrgCommand{Name: form.Name, OrgId: orgID}
if err := sqlstore.UpdateOrg(&cmd); err != nil { if err := sqlstore.UpdateOrg(ctx, &cmd); err != nil {
if errors.Is(err, models.ErrOrgNameTaken) { if errors.Is(err, models.ErrOrgNameTaken) {
return response.Error(400, "Organization name taken", err) return response.Error(400, "Organization name taken", err)
} }
@ -124,15 +124,15 @@ func updateOrgHelper(form dtos.UpdateOrgForm, orgID int64) response.Response {
// PUT /api/org/address // PUT /api/org/address
func UpdateCurrentOrgAddress(c *models.ReqContext, form dtos.UpdateOrgAddressForm) response.Response { func UpdateCurrentOrgAddress(c *models.ReqContext, form dtos.UpdateOrgAddressForm) response.Response {
return updateOrgAddressHelper(form, c.OrgId) return updateOrgAddressHelper(c.Req.Context(), form, c.OrgId)
} }
// PUT /api/orgs/:orgId/address // PUT /api/orgs/:orgId/address
func UpdateOrgAddress(c *models.ReqContext, form dtos.UpdateOrgAddressForm) response.Response { func UpdateOrgAddress(c *models.ReqContext, form dtos.UpdateOrgAddressForm) response.Response {
return updateOrgAddressHelper(form, c.ParamsInt64(":orgId")) return updateOrgAddressHelper(c.Req.Context(), form, c.ParamsInt64(":orgId"))
} }
func updateOrgAddressHelper(form dtos.UpdateOrgAddressForm, orgID int64) response.Response { func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, orgID int64) response.Response {
cmd := models.UpdateOrgAddressCommand{ cmd := models.UpdateOrgAddressCommand{
OrgId: orgID, OrgId: orgID,
Address: models.Address{ Address: models.Address{
@ -145,7 +145,7 @@ func updateOrgAddressHelper(form dtos.UpdateOrgAddressForm, orgID int64) respons
}, },
} }
if err := sqlstore.UpdateOrgAddress(&cmd); err != nil { if err := sqlstore.UpdateOrgAddress(ctx, &cmd); err != nil {
return response.Error(500, "Failed to update org address", err) return response.Error(500, "Failed to update org address", err)
} }
@ -160,7 +160,7 @@ func DeleteOrgByID(c *models.ReqContext) response.Response {
return response.Error(400, "Can not delete org for current user", nil) return response.Error(400, "Can not delete org for current user", nil)
} }
if err := sqlstore.DeleteOrg(&models.DeleteOrgCommand{Id: orgID}); err != nil { if err := sqlstore.DeleteOrg(c.Req.Context(), &models.DeleteOrgCommand{Id: orgID}); err != nil {
if errors.Is(err, models.ErrOrgNotFound) { if errors.Is(err, models.ErrOrgNotFound) {
return response.Error(404, "Failed to delete organization. ID not found", nil) return response.Error(404, "Failed to delete organization. ID not found", nil)
} }
@ -184,7 +184,7 @@ func SearchOrgs(c *models.ReqContext) response.Response {
Limit: perPage, Limit: perPage,
} }
if err := sqlstore.SearchOrgs(&query); err != nil { if err := sqlstore.SearchOrgs(c.Req.Context(), &query); err != nil {
return response.Error(500, "Failed to search orgs", err) return response.Error(500, "Failed to search orgs", err)
} }

View File

@ -197,7 +197,7 @@ func ProvideService(plugCtxProvider *plugincontext.Provider, cfg *setting.Cfg, r
// This can be unreasonable to have in production scenario with many // This can be unreasonable to have in production scenario with many
// organizations. // organizations.
query := &models.SearchOrgsQuery{} query := &models.SearchOrgsQuery{}
err := sqlstore.SearchOrgs(query) err := sqlstore.SearchOrgs(context.TODO(), query)
if err != nil { if err != nil {
return nil, fmt.Errorf("can't get org list: %w", err) return nil, fmt.Errorf("can't get org list: %w", err)
} }

View File

@ -1,6 +1,7 @@
package dashboards package dashboards
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -62,7 +63,7 @@ func (cr *configReader) parseConfigs(file os.FileInfo) ([]*config, error) {
return []*config{}, nil return []*config{}, nil
} }
func (cr *configReader) readConfig() ([]*config, error) { func (cr *configReader) readConfig(ctx context.Context) ([]*config, error) {
var dashboards []*config var dashboards []*config
files, err := ioutil.ReadDir(cr.path) files, err := ioutil.ReadDir(cr.path)
@ -92,7 +93,7 @@ func (cr *configReader) readConfig() ([]*config, error) {
dashboard.OrgID = 1 dashboard.OrgID = 1
} }
if err := utils.CheckOrgExists(dashboard.OrgID); err != nil { if err := utils.CheckOrgExists(ctx, dashboard.OrgID); err != nil {
return nil, fmt.Errorf("failed to provision dashboards with %q reader: %w", dashboard.Name, err) return nil, fmt.Errorf("failed to provision dashboards with %q reader: %w", dashboard.Name, err)
} }

View File

@ -1,6 +1,7 @@
package dashboards package dashboards
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -28,20 +29,20 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Should fail if orgs don't exist in the database", func(t *testing.T) { t.Run("Should fail if orgs don't exist in the database", func(t *testing.T) {
cfgProvider := configReader{path: appliedDefaults, log: logger} cfgProvider := configReader{path: appliedDefaults, log: logger}
_, err := cfgProvider.readConfig() _, err := cfgProvider.readConfig(context.Background())
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, models.ErrOrgNotFound)) assert.True(t, errors.Is(err, models.ErrOrgNotFound))
}) })
for i := 1; i <= 2; i++ { for i := 1; i <= 2; i++ {
orgCommand := models.CreateOrgCommand{Name: fmt.Sprintf("Main Org. %v", i)} orgCommand := models.CreateOrgCommand{Name: fmt.Sprintf("Main Org. %v", i)}
err := sqlstore.CreateOrg(&orgCommand) err := sqlstore.CreateOrg(context.Background(), &orgCommand)
require.NoError(t, err) require.NoError(t, err)
} }
t.Run("default values should be applied", func(t *testing.T) { t.Run("default values should be applied", func(t *testing.T) {
cfgProvider := configReader{path: appliedDefaults, log: logger} cfgProvider := configReader{path: appliedDefaults, log: logger}
cfg, err := cfgProvider.readConfig() cfg, err := cfgProvider.readConfig(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "file", cfg[0].Type) require.Equal(t, "file", cfg[0].Type)
@ -52,7 +53,7 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Can read config file version 1 format", func(t *testing.T) { t.Run("Can read config file version 1 format", func(t *testing.T) {
_ = os.Setenv("TEST_VAR", "general") _ = os.Setenv("TEST_VAR", "general")
cfgProvider := configReader{path: simpleDashboardConfig, log: logger} cfgProvider := configReader{path: simpleDashboardConfig, log: logger}
cfg, err := cfgProvider.readConfig() cfg, err := cfgProvider.readConfig(context.Background())
_ = os.Unsetenv("TEST_VAR") _ = os.Unsetenv("TEST_VAR")
require.NoError(t, err) require.NoError(t, err)
@ -61,7 +62,7 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Can read config file in version 0 format", func(t *testing.T) { t.Run("Can read config file in version 0 format", func(t *testing.T) {
cfgProvider := configReader{path: oldVersion, log: logger} cfgProvider := configReader{path: oldVersion, log: logger}
cfg, err := cfgProvider.readConfig() cfg, err := cfgProvider.readConfig(context.Background())
require.NoError(t, err) require.NoError(t, err)
validateDashboardAsConfig(t, cfg) validateDashboardAsConfig(t, cfg)
@ -69,7 +70,7 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Should skip invalid path", func(t *testing.T) { t.Run("Should skip invalid path", func(t *testing.T) {
cfgProvider := configReader{path: "/invalid-directory", log: logger} cfgProvider := configReader{path: "/invalid-directory", log: logger}
cfg, err := cfgProvider.readConfig() cfg, err := cfgProvider.readConfig(context.Background())
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
} }
@ -79,7 +80,7 @@ func TestDashboardsAsConfig(t *testing.T) {
t.Run("Should skip broken config files", func(t *testing.T) { t.Run("Should skip broken config files", func(t *testing.T) {
cfgProvider := configReader{path: brokenConfigs, log: logger} cfgProvider := configReader{path: brokenConfigs, log: logger}
cfg, err := cfgProvider.readConfig() cfg, err := cfgProvider.readConfig(context.Background())
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
} }

View File

@ -23,7 +23,7 @@ type DashboardProvisioner interface {
} }
// DashboardProvisionerFactory creates DashboardProvisioners based on input // DashboardProvisionerFactory creates DashboardProvisioners based on input
type DashboardProvisionerFactory func(string, dashboards.Store) (DashboardProvisioner, error) type DashboardProvisionerFactory func(context.Context, string, dashboards.Store) (DashboardProvisioner, error)
// Provisioner is responsible for syncing dashboard from disk to Grafana's database. // Provisioner is responsible for syncing dashboard from disk to Grafana's database.
type Provisioner struct { type Provisioner struct {
@ -34,10 +34,10 @@ type Provisioner struct {
} }
// New returns a new DashboardProvisioner // New returns a new DashboardProvisioner
func New(configDirectory string, store dashboards.Store) (DashboardProvisioner, error) { func New(ctx context.Context, configDirectory string, store dashboards.Store) (DashboardProvisioner, error) {
logger := log.New("provisioning.dashboard") logger := log.New("provisioning.dashboard")
cfgReader := &configReader{path: configDirectory, log: logger} cfgReader := &configReader{path: configDirectory, log: logger}
configs, err := cfgReader.readConfig() configs, err := cfgReader.readConfig(ctx)
if err != nil { if err != nil {
return nil, errutil.Wrap("Failed to read dashboards config", err) return nil, errutil.Wrap("Failed to read dashboards config", err)
} }

View File

@ -1,6 +1,7 @@
package datasources package datasources
import ( import (
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -17,7 +18,7 @@ type configReader struct {
log log.Logger log log.Logger
} }
func (cr *configReader) readConfig(path string) ([]*configs, error) { func (cr *configReader) readConfig(ctx context.Context, path string) ([]*configs, error) {
var datasources []*configs var datasources []*configs
files, err := ioutil.ReadDir(path) files, err := ioutil.ReadDir(path)
@ -39,7 +40,7 @@ func (cr *configReader) readConfig(path string) ([]*configs, error) {
} }
} }
err = cr.validateDefaultUniqueness(datasources) err = cr.validateDefaultUniqueness(ctx, datasources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +89,7 @@ func (cr *configReader) parseDatasourceConfig(path string, file os.FileInfo) (*c
return v0.mapToDatasourceFromConfig(apiVersion.APIVersion), nil return v0.mapToDatasourceFromConfig(apiVersion.APIVersion), nil
} }
func (cr *configReader) validateDefaultUniqueness(datasources []*configs) error { func (cr *configReader) validateDefaultUniqueness(ctx context.Context, datasources []*configs) error {
defaultCount := map[int64]int{} defaultCount := map[int64]int{}
for i := range datasources { for i := range datasources {
if datasources[i].Datasources == nil { if datasources[i].Datasources == nil {
@ -100,7 +101,7 @@ func (cr *configReader) validateDefaultUniqueness(datasources []*configs) error
ds.OrgID = 1 ds.OrgID = 1
} }
if err := cr.validateAccessAndOrgID(ds); err != nil { if err := cr.validateAccessAndOrgID(ctx, ds); err != nil {
return fmt.Errorf("failed to provision %q data source: %w", ds.Name, err) return fmt.Errorf("failed to provision %q data source: %w", ds.Name, err)
} }
@ -122,8 +123,8 @@ func (cr *configReader) validateDefaultUniqueness(datasources []*configs) error
return nil return nil
} }
func (cr *configReader) validateAccessAndOrgID(ds *upsertDataSourceFromConfig) error { func (cr *configReader) validateAccessAndOrgID(ctx context.Context, ds *upsertDataSourceFromConfig) error {
if err := utils.CheckOrgExists(ds.OrgID); err != nil { if err := utils.CheckOrgExists(ctx, ds.OrgID); err != nil {
return err return err
} }

View File

@ -151,20 +151,20 @@ func TestDatasourceAsConfig(t *testing.T) {
t.Run("broken yaml should return error", func(t *testing.T) { t.Run("broken yaml should return error", func(t *testing.T) {
reader := &configReader{} reader := &configReader{}
_, err := reader.readConfig(brokenYaml) _, err := reader.readConfig(context.Background(), brokenYaml)
require.NotNil(t, err) require.NotNil(t, err)
}) })
t.Run("invalid access should warn about invalid value and return 'proxy'", func(t *testing.T) { t.Run("invalid access should warn about invalid value and return 'proxy'", func(t *testing.T) {
reader := &configReader{log: logger} reader := &configReader{log: logger}
configs, err := reader.readConfig(invalidAccess) configs, err := reader.readConfig(context.Background(), invalidAccess)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, configs[0].Datasources[0].Access, models.DS_ACCESS_PROXY) require.Equal(t, configs[0].Datasources[0].Access, models.DS_ACCESS_PROXY)
}) })
t.Run("skip invalid directory", func(t *testing.T) { t.Run("skip invalid directory", func(t *testing.T) {
cfgProvider := &configReader{log: log.New("test logger")} cfgProvider := &configReader{log: log.New("test logger")}
cfg, err := cfgProvider.readConfig("./invalid-directory") cfg, err := cfgProvider.readConfig(context.Background(), "./invalid-directory")
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
} }
@ -175,7 +175,7 @@ func TestDatasourceAsConfig(t *testing.T) {
t.Run("can read all properties from version 1", func(t *testing.T) { t.Run("can read all properties from version 1", func(t *testing.T) {
_ = os.Setenv("TEST_VAR", "name") _ = os.Setenv("TEST_VAR", "name")
cfgProvider := &configReader{log: log.New("test logger")} cfgProvider := &configReader{log: log.New("test logger")}
cfg, err := cfgProvider.readConfig(allProperties) cfg, err := cfgProvider.readConfig(context.Background(), allProperties)
_ = os.Unsetenv("TEST_VAR") _ = os.Unsetenv("TEST_VAR")
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
@ -204,7 +204,7 @@ func TestDatasourceAsConfig(t *testing.T) {
t.Run("can read all properties from version 0", func(t *testing.T) { t.Run("can read all properties from version 0", func(t *testing.T) {
cfgProvider := &configReader{log: log.New("test logger")} cfgProvider := &configReader{log: log.New("test logger")}
cfg, err := cfgProvider.readConfig(versionZero) cfg, err := cfgProvider.readConfig(context.Background(), versionZero)
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
} }

View File

@ -69,7 +69,7 @@ func (dc *DatasourceProvisioner) apply(ctx context.Context, cfg *configs) error
} }
func (dc *DatasourceProvisioner) applyChanges(ctx context.Context, configPath string) error { func (dc *DatasourceProvisioner) applyChanges(ctx context.Context, configPath string) error {
configs, err := dc.cfgProvider.readConfig(configPath) configs, err := dc.cfgProvider.readConfig(ctx, configPath)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,12 +5,13 @@ import (
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/encryption" "github.com/grafana/grafana/pkg/services/encryption"
"golang.org/x/net/context"
) )
// Provision alert notifiers // Provision alert notifiers
func Provision(configDirectory string, encryptionService encryption.Service) error { func Provision(ctx context.Context, configDirectory string, encryptionService encryption.Service) error {
dc := newNotificationProvisioner(encryptionService, log.New("provisioning.notifiers")) dc := newNotificationProvisioner(encryptionService, log.New("provisioning.notifiers"))
return dc.applyChanges(configDirectory) return dc.applyChanges(ctx, configDirectory)
} }
// NotificationProvisioner is responsible for provsioning alert notifiers // NotificationProvisioner is responsible for provsioning alert notifiers
@ -132,8 +133,8 @@ func (dc *NotificationProvisioner) mergeNotifications(notificationToMerge []*not
return nil return nil
} }
func (dc *NotificationProvisioner) applyChanges(configPath string) error { func (dc *NotificationProvisioner) applyChanges(ctx context.Context, configPath string) error {
configs, err := dc.cfgProvider.readConfig(configPath) configs, err := dc.cfgProvider.readConfig(ctx, configPath)
if err != nil { if err != nil {
return err return err
} }

View File

@ -22,7 +22,7 @@ type configReader struct {
log log.Logger log log.Logger
} }
func (cr *configReader) readConfig(path string) ([]*notificationsAsConfig, error) { func (cr *configReader) readConfig(ctx context.Context, path string) ([]*notificationsAsConfig, error) {
var notifications []*notificationsAsConfig var notifications []*notificationsAsConfig
cr.log.Debug("Looking for alert notification provisioning files", "path", path) cr.log.Debug("Looking for alert notification provisioning files", "path", path)
@ -51,7 +51,7 @@ func (cr *configReader) readConfig(path string) ([]*notificationsAsConfig, error
return nil, err return nil, err
} }
if err := cr.checkOrgIDAndOrgName(notifications); err != nil { if err := cr.checkOrgIDAndOrgName(ctx, notifications); err != nil {
return nil, err return nil, err
} }
@ -81,7 +81,7 @@ func (cr *configReader) parseNotificationConfig(path string, file os.FileInfo) (
return cfg.mapToNotificationFromConfig(), nil return cfg.mapToNotificationFromConfig(), nil
} }
func (cr *configReader) checkOrgIDAndOrgName(notifications []*notificationsAsConfig) error { func (cr *configReader) checkOrgIDAndOrgName(ctx context.Context, notifications []*notificationsAsConfig) error {
for i := range notifications { for i := range notifications {
for _, notification := range notifications[i].Notifications { for _, notification := range notifications[i].Notifications {
if notification.OrgID < 1 { if notification.OrgID < 1 {
@ -91,7 +91,7 @@ func (cr *configReader) checkOrgIDAndOrgName(notifications []*notificationsAsCon
notification.OrgID = 0 notification.OrgID = 0
} }
} else { } else {
if err := utils.CheckOrgExists(notification.OrgID); err != nil { if err := utils.CheckOrgExists(ctx, notification.OrgID); err != nil {
return fmt.Errorf("failed to provision %q notification: %w", notification.Name, err) return fmt.Errorf("failed to provision %q notification: %w", notification.Name, err)
} }
} }

View File

@ -1,6 +1,7 @@
package notifiers package notifiers
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"testing" "testing"
@ -40,7 +41,7 @@ func TestNotificationAsConfig(t *testing.T) {
for i := 1; i < 5; i++ { for i := 1; i < 5; i++ {
orgCommand := models.CreateOrgCommand{Name: fmt.Sprintf("Main Org. %v", i)} orgCommand := models.CreateOrgCommand{Name: fmt.Sprintf("Main Org. %v", i)}
err := sqlstore.CreateOrg(&orgCommand) err := sqlstore.CreateOrg(context.Background(), &orgCommand)
require.NoError(t, err) require.NoError(t, err)
} }
@ -65,7 +66,7 @@ func TestNotificationAsConfig(t *testing.T) {
log: log.New("test logger"), log: log.New("test logger"),
} }
cfg, err := cfgProvider.readConfig(correctProperties) cfg, err := cfgProvider.readConfig(context.Background(), correctProperties)
_ = os.Unsetenv("TEST_VAR") _ = os.Unsetenv("TEST_VAR")
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
@ -140,7 +141,7 @@ func TestNotificationAsConfig(t *testing.T) {
setup() setup()
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err := dc.applyChanges(twoNotificationsConfig) err := dc.applyChanges(context.Background(), twoNotificationsConfig)
if err != nil { if err != nil {
t.Fatalf("applyChanges return an error %v", err) t.Fatalf("applyChanges return an error %v", err)
} }
@ -170,7 +171,7 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("should update one notification", func(t *testing.T) { t.Run("should update one notification", func(t *testing.T) {
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err = dc.applyChanges(twoNotificationsConfig) err = dc.applyChanges(context.Background(), twoNotificationsConfig)
if err != nil { if err != nil {
t.Fatalf("applyChanges return an error %v", err) t.Fatalf("applyChanges return an error %v", err)
} }
@ -194,7 +195,7 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("Two notifications with is_default", func(t *testing.T) { t.Run("Two notifications with is_default", func(t *testing.T) {
setup() setup()
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err := dc.applyChanges(doubleNotificationsConfig) err := dc.applyChanges(context.Background(), doubleNotificationsConfig)
t.Run("should both be inserted", func(t *testing.T) { t.Run("should both be inserted", func(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
notificationsQuery := models.GetAllAlertNotificationsQuery{OrgId: 1} notificationsQuery := models.GetAllAlertNotificationsQuery{OrgId: 1}
@ -237,7 +238,7 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("should have two new notifications", func(t *testing.T) { t.Run("should have two new notifications", func(t *testing.T) {
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err := dc.applyChanges(twoNotificationsConfig) err := dc.applyChanges(context.Background(), twoNotificationsConfig)
if err != nil { if err != nil {
t.Fatalf("applyChanges return an error %v", err) t.Fatalf("applyChanges return an error %v", err)
} }
@ -253,11 +254,11 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("Can read correct properties with orgName instead of orgId", func(t *testing.T) { t.Run("Can read correct properties with orgName instead of orgId", func(t *testing.T) {
setup() setup()
existingOrg1 := models.GetOrgByNameQuery{Name: "Main Org. 1"} existingOrg1 := models.GetOrgByNameQuery{Name: "Main Org. 1"}
err := sqlstore.GetOrgByName(&existingOrg1) err := sqlstore.GetOrgByName(context.Background(), &existingOrg1)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, existingOrg1.Result) require.NotNil(t, existingOrg1.Result)
existingOrg2 := models.GetOrgByNameQuery{Name: "Main Org. 2"} existingOrg2 := models.GetOrgByNameQuery{Name: "Main Org. 2"}
err = sqlstore.GetOrgByName(&existingOrg2) err = sqlstore.GetOrgByName(context.Background(), &existingOrg2)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, existingOrg2.Result) require.NotNil(t, existingOrg2.Result)
@ -271,7 +272,7 @@ func TestNotificationAsConfig(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err = dc.applyChanges(correctPropertiesWithOrgName) err = dc.applyChanges(context.Background(), correctPropertiesWithOrgName)
if err != nil { if err != nil {
t.Fatalf("applyChanges return an error %v", err) t.Fatalf("applyChanges return an error %v", err)
} }
@ -290,7 +291,7 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("Config doesn't contain required field", func(t *testing.T) { t.Run("Config doesn't contain required field", func(t *testing.T) {
setup() setup()
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err := dc.applyChanges(noRequiredFields) err := dc.applyChanges(context.Background(), noRequiredFields)
require.NotNil(t, err) require.NotNil(t, err)
errString := err.Error() errString := err.Error()
@ -304,7 +305,7 @@ func TestNotificationAsConfig(t *testing.T) {
t.Run("should have not changed repo", func(t *testing.T) { t.Run("should have not changed repo", func(t *testing.T) {
setup() setup()
dc := newNotificationProvisioner(ossencryption.ProvideService(), logger) dc := newNotificationProvisioner(ossencryption.ProvideService(), logger)
err := dc.applyChanges(emptyFile) err := dc.applyChanges(context.Background(), emptyFile)
if err != nil { if err != nil {
t.Fatalf("applyChanges return an error %v", err) t.Fatalf("applyChanges return an error %v", err)
} }
@ -321,7 +322,7 @@ func TestNotificationAsConfig(t *testing.T) {
log: log.New("test logger"), log: log.New("test logger"),
} }
_, err := reader.readConfig(brokenYaml) _, err := reader.readConfig(context.Background(), brokenYaml)
require.NotNil(t, err) require.NotNil(t, err)
}) })
@ -331,7 +332,7 @@ func TestNotificationAsConfig(t *testing.T) {
log: log.New("test logger"), log: log.New("test logger"),
} }
cfg, err := cfgProvider.readConfig(emptyFolder) cfg, err := cfgProvider.readConfig(context.Background(), emptyFolder)
if err != nil { if err != nil {
t.Fatalf("readConfig return an error %v", err) t.Fatalf("readConfig return an error %v", err)
} }
@ -343,7 +344,7 @@ func TestNotificationAsConfig(t *testing.T) {
encryptionService: ossencryption.ProvideService(), encryptionService: ossencryption.ProvideService(),
log: log.New("test logger"), log: log.New("test logger"),
} }
_, err := cfgProvider.readConfig(unknownNotifier) _, err := cfgProvider.readConfig(context.Background(), unknownNotifier)
require.NotNil(t, err) require.NotNil(t, err)
require.Equal(t, err.Error(), `unsupported notification type "nonexisting"`) require.Equal(t, err.Error(), `unsupported notification type "nonexisting"`)
}) })
@ -353,7 +354,7 @@ func TestNotificationAsConfig(t *testing.T) {
encryptionService: ossencryption.ProvideService(), encryptionService: ossencryption.ProvideService(),
log: log.New("test logger"), log: log.New("test logger"),
} }
_, err := cfgProvider.readConfig(incorrectSettings) _, err := cfgProvider.readConfig(context.Background(), incorrectSettings)
require.NotNil(t, err) require.NotNil(t, err)
require.Equal(t, err.Error(), "alert validation error: token must be specified when using the Slack chat API") require.Equal(t, err.Error(), "alert validation error: token must be specified when using the Slack chat API")
}) })
@ -362,7 +363,7 @@ func TestNotificationAsConfig(t *testing.T) {
func setupBusHandlers(sqlStore *sqlstore.SQLStore) { func setupBusHandlers(sqlStore *sqlstore.SQLStore) {
bus.AddHandler("getOrg", func(q *models.GetOrgByNameQuery) error { bus.AddHandler("getOrg", func(q *models.GetOrgByNameQuery) error {
return sqlstore.GetOrgByName(q) return sqlstore.GetOrgByName(context.Background(), q)
}) })
bus.AddHandler("getAlertNotifications", func(q *models.GetAlertNotificationsWithUidQuery) error { bus.AddHandler("getAlertNotifications", func(q *models.GetAlertNotificationsWithUidQuery) error {

View File

@ -39,7 +39,7 @@ type ProvisioningService interface {
RunInitProvisioners(ctx context.Context) error RunInitProvisioners(ctx context.Context) error
ProvisionDatasources(ctx context.Context) error ProvisionDatasources(ctx context.Context) error
ProvisionPlugins() error ProvisionPlugins() error
ProvisionNotifications() error ProvisionNotifications(ctx context.Context) error
ProvisionDashboards(ctx context.Context) error ProvisionDashboards(ctx context.Context) error
GetDashboardProvisionerResolvedPath(name string) string GetDashboardProvisionerResolvedPath(name string) string
GetAllowUIUpdatesFromConfig(name string) bool GetAllowUIUpdatesFromConfig(name string) bool
@ -59,7 +59,7 @@ func NewProvisioningServiceImpl() *ProvisioningServiceImpl {
// Used for testing purposes // Used for testing purposes
func newProvisioningServiceImpl( func newProvisioningServiceImpl(
newDashboardProvisioner dashboards.DashboardProvisionerFactory, newDashboardProvisioner dashboards.DashboardProvisionerFactory,
provisionNotifiers func(string, encryption.Service) error, provisionNotifiers func(context.Context, string, encryption.Service) error,
provisionDatasources func(context.Context, string) error, provisionDatasources func(context.Context, string) error,
provisionPlugins func(string, plugifaces.Store) error, provisionPlugins func(string, plugifaces.Store) error,
) *ProvisioningServiceImpl { ) *ProvisioningServiceImpl {
@ -81,7 +81,7 @@ type ProvisioningServiceImpl struct {
pollingCtxCancel context.CancelFunc pollingCtxCancel context.CancelFunc
newDashboardProvisioner dashboards.DashboardProvisionerFactory newDashboardProvisioner dashboards.DashboardProvisionerFactory
dashboardProvisioner dashboards.DashboardProvisioner dashboardProvisioner dashboards.DashboardProvisioner
provisionNotifiers func(string, encryption.Service) error provisionNotifiers func(context.Context, string, encryption.Service) error
provisionDatasources func(context.Context, string) error provisionDatasources func(context.Context, string) error
provisionPlugins func(string, plugifaces.Store) error provisionPlugins func(string, plugifaces.Store) error
mutex sync.Mutex mutex sync.Mutex
@ -98,7 +98,7 @@ func (ps *ProvisioningServiceImpl) RunInitProvisioners(ctx context.Context) erro
return err return err
} }
err = ps.ProvisionNotifications() err = ps.ProvisionNotifications(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -147,15 +147,15 @@ func (ps *ProvisioningServiceImpl) ProvisionPlugins() error {
return errutil.Wrap("app provisioning error", err) return errutil.Wrap("app provisioning error", err)
} }
func (ps *ProvisioningServiceImpl) ProvisionNotifications() error { func (ps *ProvisioningServiceImpl) ProvisionNotifications(ctx context.Context) error {
alertNotificationsPath := filepath.Join(ps.Cfg.ProvisioningPath, "notifiers") alertNotificationsPath := filepath.Join(ps.Cfg.ProvisioningPath, "notifiers")
err := ps.provisionNotifiers(alertNotificationsPath, ps.EncryptionService) err := ps.provisionNotifiers(ctx, alertNotificationsPath, ps.EncryptionService)
return errutil.Wrap("Alert notification provisioning error", err) return errutil.Wrap("Alert notification provisioning error", err)
} }
func (ps *ProvisioningServiceImpl) ProvisionDashboards(ctx context.Context) error { func (ps *ProvisioningServiceImpl) ProvisionDashboards(ctx context.Context) error {
dashboardPath := filepath.Join(ps.Cfg.ProvisioningPath, "dashboards") dashboardPath := filepath.Join(ps.Cfg.ProvisioningPath, "dashboards")
dashProvisioner, err := ps.newDashboardProvisioner(dashboardPath, ps.SQLStore) dashProvisioner, err := ps.newDashboardProvisioner(ctx, dashboardPath, ps.SQLStore)
if err != nil { if err != nil {
return errutil.Wrap("Failed to create provisioner", err) return errutil.Wrap("Failed to create provisioner", err)
} }

View File

@ -25,7 +25,7 @@ type ProvisioningServiceMock struct {
RunFunc func(ctx context.Context) error RunFunc func(ctx context.Context) error
} }
func NewProvisioningServiceMock() *ProvisioningServiceMock { func NewProvisioningServiceMock(ctx context.Context) *ProvisioningServiceMock {
return &ProvisioningServiceMock{ return &ProvisioningServiceMock{
Calls: &Calls{}, Calls: &Calls{},
} }
@ -55,7 +55,7 @@ func (mock *ProvisioningServiceMock) ProvisionPlugins() error {
return nil return nil
} }
func (mock *ProvisioningServiceMock) ProvisionNotifications() error { func (mock *ProvisioningServiceMock) ProvisionNotifications(ctx context.Context) error {
mock.Calls.ProvisionNotifications = append(mock.Calls.ProvisionNotifications, nil) mock.Calls.ProvisionNotifications = append(mock.Calls.ProvisionNotifications, nil)
if mock.ProvisionNotificationsFunc != nil { if mock.ProvisionNotificationsFunc != nil {
return mock.ProvisionNotificationsFunc() return mock.ProvisionNotificationsFunc()

View File

@ -92,7 +92,7 @@ func setup() *serviceTestStruct {
} }
serviceTest.service = newProvisioningServiceImpl( serviceTest.service = newProvisioningServiceImpl(
func(string, dboards.Store) (dashboards.DashboardProvisioner, error) { func(context.Context, string, dboards.Store) (dashboards.DashboardProvisioner, error) {
return serviceTest.mock, nil return serviceTest.mock, nil
}, },
nil, nil,

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
@ -8,9 +9,9 @@ import (
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
) )
func CheckOrgExists(orgID int64) error { func CheckOrgExists(ctx context.Context, orgID int64) error {
query := models.GetOrgByIdQuery{Id: orgID} query := models.GetOrgByIdQuery{Id: orgID}
if err := bus.Dispatch(&query); err != nil { if err := bus.DispatchCtx(ctx, &query); err != nil {
if errors.Is(err, models.ErrOrgNotFound) { if errors.Is(err, models.ErrOrgNotFound) {
return err return err
} }

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"context"
"testing" "testing"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
@ -14,16 +15,16 @@ func TestCheckOrgExists(t *testing.T) {
sqlstore.InitTestDB(t) sqlstore.InitTestDB(t)
defaultOrg := models.CreateOrgCommand{Name: "Main Org."} defaultOrg := models.CreateOrgCommand{Name: "Main Org."}
err := sqlstore.CreateOrg(&defaultOrg) err := sqlstore.CreateOrg(context.Background(), &defaultOrg)
require.NoError(t, err) require.NoError(t, err)
t.Run("default org exists", func(t *testing.T) { t.Run("default org exists", func(t *testing.T) {
err := CheckOrgExists(defaultOrg.Result.Id) err := CheckOrgExists(context.Background(), defaultOrg.Result.Id)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("other org doesn't exist", func(t *testing.T) { t.Run("other org doesn't exist", func(t *testing.T) {
err := CheckOrgExists(defaultOrg.Result.Id + 1) err := CheckOrgExists(context.Background(), defaultOrg.Result.Id+1)
require.Equal(t, err, models.ErrOrgNotFound) require.Equal(t, err, models.ErrOrgNotFound)
}) })
}) })

View File

@ -16,16 +16,16 @@ import (
const MainOrgName = "Main Org." const MainOrgName = "Main Org."
func init() { func init() {
bus.AddHandler("sql", GetOrgById) bus.AddHandlerCtx("sql", GetOrgById)
bus.AddHandler("sql", CreateOrg) bus.AddHandlerCtx("sql", CreateOrg)
bus.AddHandler("sql", UpdateOrg) bus.AddHandlerCtx("sql", UpdateOrg)
bus.AddHandler("sql", UpdateOrgAddress) bus.AddHandlerCtx("sql", UpdateOrgAddress)
bus.AddHandler("sql", GetOrgByName) bus.AddHandlerCtx("sql", GetOrgByName)
bus.AddHandler("sql", SearchOrgs) bus.AddHandlerCtx("sql", SearchOrgs)
bus.AddHandler("sql", DeleteOrg) bus.AddHandlerCtx("sql", DeleteOrg)
} }
func SearchOrgs(query *models.SearchOrgsQuery) error { func SearchOrgs(ctx context.Context, query *models.SearchOrgsQuery) error {
query.Result = make([]*models.OrgDTO, 0) query.Result = make([]*models.OrgDTO, 0)
sess := x.Table("org") sess := x.Table("org")
if query.Query != "" { if query.Query != "" {
@ -48,7 +48,7 @@ func SearchOrgs(query *models.SearchOrgsQuery) error {
return err return err
} }
func GetOrgById(query *models.GetOrgByIdQuery) error { func GetOrgById(ctx context.Context, query *models.GetOrgByIdQuery) error {
var org models.Org var org models.Org
exists, err := x.Id(query.Id).Get(&org) exists, err := x.Id(query.Id).Get(&org)
if err != nil { if err != nil {
@ -63,7 +63,7 @@ func GetOrgById(query *models.GetOrgByIdQuery) error {
return nil return nil
} }
func GetOrgByName(query *models.GetOrgByNameQuery) error { func GetOrgByName(ctx context.Context, query *models.GetOrgByNameQuery) error {
var org models.Org var org models.Org
exists, err := x.Where("name=?", query.Name).Get(&org) exists, err := x.Where("name=?", query.Name).Get(&org)
if err != nil { if err != nil {
@ -154,7 +154,7 @@ func (ss *SQLStore) CreateOrgWithMember(name string, userID int64) (models.Org,
return createOrg(name, userID, ss.engine) return createOrg(name, userID, ss.engine)
} }
func CreateOrg(cmd *models.CreateOrgCommand) error { func CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error {
org, err := createOrg(cmd.Name, cmd.UserId, x) org, err := createOrg(cmd.Name, cmd.UserId, x)
if err != nil { if err != nil {
return err return err
@ -164,7 +164,7 @@ func CreateOrg(cmd *models.CreateOrgCommand) error {
return nil return nil
} }
func UpdateOrg(cmd *models.UpdateOrgCommand) error { func UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
if isNameTaken, err := isOrgNameTaken(cmd.Name, cmd.OrgId, sess); err != nil { if isNameTaken, err := isOrgNameTaken(cmd.Name, cmd.OrgId, sess); err != nil {
return err return err
@ -197,7 +197,7 @@ func UpdateOrg(cmd *models.UpdateOrgCommand) error {
}) })
} }
func UpdateOrgAddress(cmd *models.UpdateOrgAddressCommand) error { func UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
org := models.Org{ org := models.Org{
Address1: cmd.Address1, Address1: cmd.Address1,
@ -224,7 +224,7 @@ func UpdateOrgAddress(cmd *models.UpdateOrgAddressCommand) error {
}) })
} }
func DeleteOrg(cmd *models.DeleteOrgCommand) error { func DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
if res, err := sess.Query("SELECT 1 from org WHERE id=?", cmd.Id); err != nil { if res, err := sess.Query("SELECT 1 from org WHERE id=?", cmd.Id); err != nil {
return err return err

View File

@ -25,14 +25,14 @@ func TestAccountDataAccess(t *testing.T) {
for i := 1; i < 4; i++ { for i := 1; i < 4; i++ {
cmd = &models.CreateOrgCommand{Name: fmt.Sprint("Org #", i)} cmd = &models.CreateOrgCommand{Name: fmt.Sprint("Org #", i)}
err = CreateOrg(cmd) err = CreateOrg(context.Background(), cmd)
require.NoError(t, err) require.NoError(t, err)
ids = append(ids, cmd.Result.Id) ids = append(ids, cmd.Result.Id)
} }
query := &models.SearchOrgsQuery{Ids: ids} query := &models.SearchOrgsQuery{Ids: ids}
err = SearchOrgs(query) err = SearchOrgs(context.Background(), query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(query.Result), 3) require.Equal(t, len(query.Result), 3)
@ -42,13 +42,13 @@ func TestAccountDataAccess(t *testing.T) {
sqlStore = InitTestDB(t) sqlStore = InitTestDB(t)
for i := 1; i < 4; i++ { for i := 1; i < 4; i++ {
cmd := &models.CreateOrgCommand{Name: fmt.Sprint("Org #", i)} cmd := &models.CreateOrgCommand{Name: fmt.Sprint("Org #", i)}
err := CreateOrg(cmd) err := CreateOrg(context.Background(), cmd)
require.NoError(t, err) require.NoError(t, err)
} }
t.Run("Should be able to search with defaults", func(t *testing.T) { t.Run("Should be able to search with defaults", func(t *testing.T) {
query := &models.SearchOrgsQuery{} query := &models.SearchOrgsQuery{}
err := SearchOrgs(query) err := SearchOrgs(context.Background(), query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(query.Result), 3) require.Equal(t, len(query.Result), 3)
@ -56,7 +56,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Should be able to limit search", func(t *testing.T) { t.Run("Should be able to limit search", func(t *testing.T) {
query := &models.SearchOrgsQuery{Limit: 1} query := &models.SearchOrgsQuery{Limit: 1}
err := SearchOrgs(query) err := SearchOrgs(context.Background(), query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(query.Result), 1) require.Equal(t, len(query.Result), 1)
@ -64,7 +64,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Should be able to limit and paginate search", func(t *testing.T) { t.Run("Should be able to limit and paginate search", func(t *testing.T) {
query := &models.SearchOrgsQuery{Limit: 2, Page: 1} query := &models.SearchOrgsQuery{Limit: 2, Page: 1}
err := SearchOrgs(query) err := SearchOrgs(context.Background(), query)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(query.Result), 1) require.Equal(t, len(query.Result), 1)
@ -278,7 +278,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Removing user from org should delete user completely if in no other org", func(t *testing.T) { t.Run("Removing user from org should delete user completely if in no other org", func(t *testing.T) {
// make sure ac2 has no org // make sure ac2 has no org
err := DeleteOrg(&models.DeleteOrgCommand{Id: ac2.OrgId}) err := DeleteOrg(context.Background(), &models.DeleteOrgCommand{Id: ac2.OrgId})
require.NoError(t, err) require.NoError(t, err)
// remove ac2 user from ac1 org // remove ac2 user from ac1 org

View File

@ -49,7 +49,7 @@ func TestQuotaCommandsAndQueries(t *testing.T) {
UserId: 1, UserId: 1,
} }
err := CreateOrg(&userCmd) err := CreateOrg(context.Background(), &userCmd)
require.NoError(t, err) require.NoError(t, err)
orgId = userCmd.Result.Id orgId = userCmd.Result.Id

View File

@ -78,7 +78,7 @@ func populateDB(t *testing.T, sqlStore *SQLStore) {
// get 1st user's organisation // get 1st user's organisation
getOrgByIdQuery := &models.GetOrgByIdQuery{Id: users[0].OrgId} getOrgByIdQuery := &models.GetOrgByIdQuery{Id: users[0].OrgId}
err := GetOrgById(getOrgByIdQuery) err := GetOrgById(context.Background(), getOrgByIdQuery)
require.NoError(t, err) require.NoError(t, err)
org := getOrgByIdQuery.Result org := getOrgByIdQuery.Result
@ -102,7 +102,7 @@ func populateDB(t *testing.T, sqlStore *SQLStore) {
// get 2nd user's organisation // get 2nd user's organisation
getOrgByIdQuery = &models.GetOrgByIdQuery{Id: users[1].OrgId} getOrgByIdQuery = &models.GetOrgByIdQuery{Id: users[1].OrgId}
err = GetOrgById(getOrgByIdQuery) err = GetOrgById(context.Background(), getOrgByIdQuery)
require.NoError(t, err) require.NoError(t, err)
org = getOrgByIdQuery.Result org = getOrgByIdQuery.Result

View File

@ -80,7 +80,7 @@ func TestUserDataAccess(t *testing.T) {
}() }()
orgCmd := &models.CreateOrgCommand{Name: "Some Test Org"} orgCmd := &models.CreateOrgCommand{Name: "Some Test Org"}
err := CreateOrg(orgCmd) err := CreateOrg(context.Background(), orgCmd)
require.Nil(t, err) require.Nil(t, err)
cmd := models.CreateUserCommand{ cmd := models.CreateUserCommand{