mirror of
https://github.com/grafana/grafana.git
synced 2025-07-29 23:52:19 +08:00
98 lines
2.5 KiB
Go
98 lines
2.5 KiB
Go
package db_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/grafana/grafana/pkg/storage/unified/sql/db"
|
|
"github.com/grafana/grafana/pkg/storage/unified/sql/db/mocks"
|
|
"github.com/grafana/grafana/pkg/util/testutil"
|
|
)
|
|
|
|
// errTest is the single sentinel failure injected through the mocks by the
// tests in this file; it stands in for failures of BeginTx, Commit, Rollback
// and of the transactional function itself.
var (
	errTest = errors.New("you shall not pass")
)
|
|
|
|
// The constants below duplicate the error-message fragments defined in
// `service.go`. This file lives in the separate db_test package to avoid
// circular dependencies, so those constants cannot be referenced directly.
// Keep both sets in sync by hand.
const (
	// beginStr is expected in the error returned when BeginTx fails.
	beginStr = "begin"
	// txOpStr is expected in the error returned when the transactional
	// function fails.
	txOpStr = "transactional operation"
	// commitStr is expected in the error returned when Commit fails.
	commitStr = "commit"
	// rollbackStr is expected in the error returned when Rollback fails.
	rollbackStr = "rollback"
)
|
|
|
|
func TestNewWithTxFunc(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
execTest := func(t *testing.T, d db.DB, txErr error) error {
|
|
ctx := testutil.NewDefaultTestContext(t)
|
|
return db.NewWithTxFunc(d.BeginTx).WithTx(ctx, nil,
|
|
func(context.Context, db.Tx) error {
|
|
return txErr
|
|
})
|
|
}
|
|
|
|
t.Run("happy path", func(t *testing.T) {
|
|
t.Parallel()
|
|
mDB, mTx := mocks.NewDB(t), mocks.NewTx(t)
|
|
|
|
mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil)
|
|
mTx.EXPECT().Commit().Return(nil)
|
|
|
|
err := execTest(t, mDB, nil)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("failed begin", func(t *testing.T) {
|
|
t.Parallel()
|
|
mDB := mocks.NewDB(t)
|
|
|
|
mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(nil, errTest)
|
|
|
|
err := execTest(t, mDB, nil)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, beginStr)
|
|
})
|
|
|
|
t.Run("fail tx", func(t *testing.T) {
|
|
t.Parallel()
|
|
mDB, mTx := mocks.NewDB(t), mocks.NewTx(t)
|
|
|
|
mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil)
|
|
mTx.EXPECT().Rollback().Return(nil)
|
|
|
|
err := execTest(t, mDB, errTest)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, txOpStr)
|
|
})
|
|
|
|
t.Run("fail tx; fail rollback", func(t *testing.T) {
|
|
t.Parallel()
|
|
mDB, mTx := mocks.NewDB(t), mocks.NewTx(t)
|
|
|
|
mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil)
|
|
mTx.EXPECT().Rollback().Return(errTest)
|
|
|
|
err := execTest(t, mDB, errTest)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, txOpStr)
|
|
require.ErrorContains(t, err, rollbackStr)
|
|
})
|
|
|
|
t.Run("fail commit", func(t *testing.T) {
|
|
t.Parallel()
|
|
mDB, mTx := mocks.NewDB(t), mocks.NewTx(t)
|
|
|
|
mDB.EXPECT().BeginTx(mock.Anything, mock.Anything).Return(mTx, nil)
|
|
mTx.EXPECT().Commit().Return(errTest)
|
|
|
|
err := execTest(t, mDB, nil)
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, commitStr)
|
|
})
|
|
}
|