Zanzana: Adds running migrations from openfga w. RunMigrations() (#105691)

This commit is contained in:
Eric Leijonmarck
2025-05-29 15:54:12 +01:00
committed by GitHub
parent 8dcd66e0e6
commit 69653ea3dc
3 changed files with 154 additions and 112 deletions

View File

@ -15,6 +15,7 @@ import (
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/authz/zanzana/common"
"github.com/grafana/grafana/pkg/services/authz/zanzana/store"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tests/testsuite"
@ -50,15 +51,20 @@ func TestIntegrationServer(t *testing.T) {
t.Skip("skipping integration test")
}
testDB, cfg := db.InitTestDBWithCfg(t)
// Create a test-specific config to avoid migration conflicts
cfg := setting.NewCfg()
// Use a test-specific database to avoid migration conflicts
testStore := sqlstore.NewTestStore(t, sqlstore.WithCfg(cfg))
// Hack to skip these tests on mysql 5.7
if testDB.GetDialect().DriverName() == migrator.MySQL {
if supported, err := testDB.RecursiveQueriesAreSupported(); !supported || err != nil {
if testStore.GetDialect().DriverName() == migrator.MySQL {
if supported, err := testStore.RecursiveQueriesAreSupported(); !supported || err != nil {
t.Skip("skipping integration test")
}
}
srv := setup(t, testDB, cfg)
srv := setup(t, testStore, cfg)
t.Run("test check", func(t *testing.T) {
testCheck(t, srv)
})
@ -80,6 +86,7 @@ func TestIntegrationServer(t *testing.T) {
func setup(t *testing.T, testDB db.DB, cfg *setting.Cfg) *Server {
t.Helper()
store, err := store.NewEmbeddedStore(cfg, testDB, log.NewNopLogger())
require.NoError(t, err)
openfga, err := NewOpenFGAServer(cfg.ZanzanaServer, store, log.NewNopLogger())
@ -122,6 +129,26 @@ func setup(t *testing.T, testDB db.DB, cfg *setting.Cfg) *Server {
t.Log(w.String())
}
// First, try to delete any existing tuples to avoid conflicts
deletes := make([]*openfgav1.TupleKeyWithoutCondition, 0, len(writes.TupleKeys))
for _, tupleKey := range writes.TupleKeys {
deletes = append(deletes, &openfgav1.TupleKeyWithoutCondition{
User: tupleKey.User,
Relation: tupleKey.Relation,
Object: tupleKey.Object,
})
}
// Try to delete existing tuples (ignore errors if they don't exist)
_, _ = openfga.Write(context.Background(), &openfgav1.WriteRequest{
StoreId: storeInf.ID,
AuthorizationModelId: storeInf.ModelID,
Deletes: &openfgav1.WriteRequestDeletes{
TupleKeys: deletes,
},
})
// Now write the new tuples
_, err = openfga.Write(context.Background(), &openfgav1.WriteRequest{
StoreId: storeInf.ID,
AuthorizationModelId: storeInf.ModelID,

View File

@ -1,19 +1,22 @@
package migration
import (
"embed"
"errors"
"fmt"
"strings"
"github.com/grafana/grafana/pkg/util/xorm"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/util/xorm"
"github.com/openfga/openfga/pkg/storage/migrate"
)
func Run(cfg *setting.Cfg, typ, connStr string, fs embed.FS, path string) error {
engine, err := xorm.NewEngine(typ, connStr)
func Run(cfg *setting.Cfg, dbType string, grafanaDBConfig *sqlstore.DatabaseConfig, logger log.Logger) error {
connStr := grafanaDBConfig.ConnectionString
// running grafana migrations
engine, err := xorm.NewEngine(dbType, connStr)
if err != nil {
return fmt.Errorf("failed to create db engine: %w", err)
}
@ -21,108 +24,124 @@ func Run(cfg *setting.Cfg, typ, connStr string, fs embed.FS, path string) error
m := migrator.NewMigrator(engine, cfg)
m.AddCreateMigration()
if err := RunWithMigrator(m, cfg, fs, path); err != nil {
if err := RunWithMigrator(m, cfg); err != nil {
return err
}
return engine.Close()
}
func RunWithMigrator(m *migrator.Migrator, cfg *setting.Cfg, fs embed.FS, path string) error {
migrations, err := getMigrations(fs, path)
if err != nil {
return err
// running openfga migrations
switch dbType {
case migrator.SQLite:
// openfga expects sqlite but grafana uses sqlite3
dbType = "sqlite"
case migrator.Postgres:
// Parse and transform the connection string to the format OpenFGA expects
connStr = constructPostgresConnStrForOpenFGA(grafanaDBConfig)
}
for _, mig := range migrations {
m.AddMigration(mig.name, mig.migration)
migrationConfig := migrate.MigrationConfig{
URI: connStr,
Engine: dbType,
}
sec := cfg.Raw.Section("database")
return m.Start(
sec.Key("migration_locking").MustBool(true),
sec.Key("locking_attempt_timeout_sec").MustInt(),
)
}
type migration struct {
name string
migration migrator.Migration
}
func getMigrations(fs embed.FS, path string) ([]migration, error) {
entries, err := fs.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("failed to read migration dir: %w", err)
if err := migrate.RunMigrations(migrationConfig); err != nil {
return fmt.Errorf("failed to run openfga migrations: %w", err)
}
// parseStatements extracts statements from a sql file so we can execute
// them as separate migrations. OpenFGA uses Goose as their migration egine
// and Goose uses a single sql file for both up and down migrations.
// Grafana only supports up migration so we strip out the down migration
// and parse each individual statement
parseStatements := func(data []byte) ([]string, error) {
scripts := strings.Split(strings.TrimPrefix(string(data), "-- +goose Up"), "-- +goose Down")
if len(scripts) != 2 {
return nil, errors.New("malformed migration file")
}
// We assume that up migrations are always before down migrations
parts := strings.SplitAfter(scripts[0], ";")
stmts := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
stmts = append(stmts, p)
}
}
return stmts, nil
if err := engine.Close(); err != nil {
logger.Warn("failed to close db engine", "error", err)
}
formatName := func(name string) string {
// Each migration file start with XXX where X is a number.
// We remove that part and prefix each migration with "zanzana".
return strings.TrimSuffix("zanzana"+name[3:], ".sql")
}
migrations := make([]migration, 0, len(entries))
for _, e := range entries {
data, err := fs.ReadFile(path + "/" + e.Name())
if err != nil {
return nil, fmt.Errorf("failed to read migration file: %w", err)
}
stmts, err := parseStatements(data)
if err != nil {
return nil, fmt.Errorf("failed to parse migration: %w", err)
}
migrations = append(migrations, migration{
name: formatName(e.Name()),
migration: &rawMigration{stmts: stmts},
})
}
return migrations, nil
}
var _ migrator.CodeMigration = (*rawMigration)(nil)
type rawMigration struct {
stmts []string
migrator.MigrationBase
}
func (m *rawMigration) Exec(sess *xorm.Session, migrator *migrator.Migrator) error {
for _, stmt := range m.stmts {
if _, err := sess.Exec(stmt); err != nil {
return fmt.Errorf("failed to run migration: %w", err)
}
}
return nil
}
func (m *rawMigration) SQL(dialect migrator.Dialect) string {
return strings.Join(m.stmts, "\n")
func RunWithMigrator(m *migrator.Migrator, cfg *setting.Cfg) error {
openfgaTables := []string{"tuple", "authorization_model", "store", "assertion", "changelog"}
for _, table := range openfgaTables {
m.AddMigration(fmt.Sprintf("Drop existing openfga table %s", table), migrator.NewDropTableMigration(table))
}
sec := cfg.Raw.Section("database")
return m.Start(
sec.Key("migration_locking").MustBool(true),
sec.Key("locking_attempt_timeout_sec").MustInt(30),
)
}
// constructPostgresConnStrForOpenFGA parses a PostgreSQL connection string into a map of key-value pairs
// parses into a format like
// postgresql://grafana:password@127.0.0.1:5432/grafana?sslmode=disable&lock_timeout=2s&statement_timeout=10s
func constructPostgresConnStrForOpenFGA(grafanaDBCfg *sqlstore.DatabaseConfig) string {
var host, port, user, password, dbname, sslmode string
// If individual fields are populated, use them directly
if grafanaDBCfg.Host != "" && grafanaDBCfg.User != "" && grafanaDBCfg.Name != "" {
// Parse host and port from the Host field (which might contain both)
addr, err := util.SplitHostPortDefault(grafanaDBCfg.Host, "127.0.0.1", "5432")
if err != nil {
// If parsing fails, use the host as-is and assume default port
addr = util.NetworkAddress{Host: grafanaDBCfg.Host, Port: "5432"}
}
host = addr.Host
port = addr.Port
user = grafanaDBCfg.User
password = grafanaDBCfg.Pwd
dbname = grafanaDBCfg.Name
sslmode = grafanaDBCfg.SslMode
} else {
// Parse from connection string (test environment case)
// Connection string format: "user=grafanatest password=grafanatest host=127.0.0.1 port=5432 dbname=grafanatest sslmode=disable"
connStr := grafanaDBCfg.ConnectionString
parts := strings.Fields(connStr)
// Set defaults
host = "127.0.0.1"
port = "5432"
sslmode = "disable"
for _, part := range parts {
if strings.Contains(part, "=") {
kv := strings.SplitN(part, "=", 2)
if len(kv) == 2 {
key, value := kv[0], kv[1]
switch key {
case "host":
host = value
case "port":
port = value
case "user":
user = value
case "password":
password = value
case "dbname":
dbname = value
case "sslmode":
sslmode = value
}
}
}
}
}
// Construct the connection string with proper host:port format
connectionStr := fmt.Sprintf("postgresql://%s:%s@%s:%s/%s",
user, password, host, port, dbname)
// Build query parameters - always include sslmode and timeouts
queryParams := fmt.Sprintf("sslmode=%s&lock_timeout=2s&statement_timeout=10s", sslmode)
// Only add SSL certificate parameters if they are not empty
if grafanaDBCfg.ClientCertPath != "" {
queryParams += fmt.Sprintf("&sslcert=%s", grafanaDBCfg.ClientCertPath)
}
if grafanaDBCfg.ClientKeyPath != "" {
queryParams += fmt.Sprintf("&sslkey=%s", grafanaDBCfg.ClientKeyPath)
}
if grafanaDBCfg.CaCertPath != "" {
queryParams += fmt.Sprintf("&sslrootcert=%s", grafanaDBCfg.CaCertPath)
}
finalConnStr := connectionStr + "?" + queryParams
return finalConnStr
}

View File

@ -5,7 +5,6 @@ import (
"strings"
"time"
"github.com/openfga/openfga/assets"
"github.com/openfga/openfga/pkg/storage"
"github.com/openfga/openfga/pkg/storage/mysql"
"github.com/openfga/openfga/pkg/storage/postgres"
@ -31,7 +30,7 @@ func NewStore(cfg *setting.Cfg, logger log.Logger) (storage.OpenFGADatastore, er
switch grafanaDBCfg.Type {
case migrator.SQLite:
connStr := sqliteConnectionString(grafanaDBCfg.ConnectionString)
if err := migration.Run(cfg, migrator.SQLite, connStr, assets.EmbedMigrations, assets.SqliteMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.SQLite, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
@ -39,18 +38,18 @@ func NewStore(cfg *setting.Cfg, logger log.Logger) (storage.OpenFGADatastore, er
case migrator.MySQL:
// For mysql we need to pass parseTime parameter in connection string
connStr := grafanaDBCfg.ConnectionString + "&parseTime=true"
if err := migration.Run(cfg, migrator.MySQL, connStr, assets.EmbedMigrations, assets.MySQLMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.MySQL, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return mysql.New(connStr, zanzanaDBCfg)
case migrator.Postgres:
connStr := grafanaDBCfg.ConnectionString
if err := migration.Run(cfg, migrator.Postgres, connStr, assets.EmbedMigrations, assets.PostgresMigrationDir); err != nil {
// Parse and transform the connection string to the format OpenFGA expects
if err := migration.Run(cfg, migrator.Postgres, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return postgres.New(connStr, zanzanaDBCfg)
return postgres.New(grafanaDBCfg.ConnectionString, zanzanaDBCfg)
}
// Should never happen
@ -66,22 +65,19 @@ func NewEmbeddedStore(cfg *setting.Cfg, db db.DB, logger log.Logger) (storage.Op
switch grafanaDBCfg.Type {
case migrator.SQLite:
grafanaDBCfg.ConnectionString = sqliteConnectionString(grafanaDBCfg.ConnectionString)
if err := migration.Run(cfg, migrator.SQLite, grafanaDBCfg.ConnectionString, assets.EmbedMigrations, assets.SqliteMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.SQLite, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return sqlite.New(grafanaDBCfg.ConnectionString, zanzanaDBCfg)
case migrator.MySQL:
m := migrator.NewMigrator(db.GetEngine(), cfg)
if err := migration.RunWithMigrator(m, cfg, assets.EmbedMigrations, assets.MySQLMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.MySQL, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
// For mysql we need to pass parseTime parameter in connection string
return mysql.New(grafanaDBCfg.ConnectionString+"&parseTime=true", zanzanaDBCfg)
case migrator.Postgres:
m := migrator.NewMigrator(db.GetEngine(), cfg)
if err := migration.RunWithMigrator(m, cfg, assets.EmbedMigrations, assets.PostgresMigrationDir); err != nil {
if err := migration.Run(cfg, migrator.Postgres, grafanaDBCfg, logger); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}