This commit is contained in:
Patrick Arminio
2025-10-16 20:47:09 +01:00
parent a85de910d1
commit 405ea19495
9 changed files with 731 additions and 0 deletions

View File

@@ -38,8 +38,13 @@ classifiers = [
dependencies = [
"SQLAlchemy >=2.0.14,<2.1.0",
"pydantic >=1.10.13,<3.0.0",
"typer >=0.9.0",
"alembic >=1.13.0",
]
[project.scripts]
sqlmodel = "sqlmodel.cli:main"
[project.urls]
Homepage = "https://github.com/fastapi/sqlmodel"
Documentation = "https://sqlmodel.tiangolo.com"

View File

@@ -2,6 +2,7 @@
-r requirements-docs-tests.txt
pytest >=7.0.1,<9.0.0
coverage[toml] >=6.2,<8.0
inline-snapshot >=0.13.0
# Remove when support for Python 3.8 is dropped
mypy ==1.14.1; python_version < "3.9"
mypy ==1.18.2; python_version >= "3.9"

10
sqlmodel/cli/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
import typer
from .migrations import migrations_app
app = typer.Typer()
app.add_typer(migrations_app, name="migrations")
def main() -> None:
app()

453
sqlmodel/cli/migrations.py Normal file
View File

@@ -0,0 +1,453 @@
import os
import re
from pathlib import Path
from typing import Optional
import typer
from alembic.autogenerate import produce_migrations, render_python_code
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from sqlalchemy import create_engine, pool
migrations_app = typer.Typer()
def get_migrations_dir(migrations_path: Optional[str] = None) -> Path:
"""Get the migrations directory path."""
if migrations_path:
return Path(migrations_path)
return Path.cwd() / "migrations"
def get_alembic_config(migrations_dir: Path) -> Config:
"""Create an Alembic config object programmatically without ini file."""
config = Config()
config.set_main_option("script_location", str(migrations_dir))
config.set_main_option("sqlalchemy.url", "") # Will be set by env.py
return config
def get_next_migration_number(migrations_dir: Path) -> str:
"""Get the next sequential migration number."""
if not migrations_dir.exists():
return "0001"
migration_files = list(migrations_dir.glob("*.py"))
if not migration_files:
return "0001"
numbers = []
for f in migration_files:
match = re.match(r"^(\d{4})_", f.name)
if match:
numbers.append(int(match.group(1)))
if not numbers:
return "0001"
return f"{max(numbers) + 1:04d}"
def get_metadata(models_path: str):
"""Import and return SQLModel metadata."""
import sys
from importlib import import_module
# Add current directory to Python path
sys.path.insert(0, str(Path.cwd()))
try:
# Import the module containing the models
models_module = import_module(models_path)
# Get SQLModel from the module or import it
if hasattr(models_module, "SQLModel"):
return models_module.SQLModel.metadata
else:
# Try importing SQLModel from sqlmodel
from sqlmodel import SQLModel
return SQLModel.metadata
except ImportError as e:
raise ValueError(
f"Failed to import models from '{models_path}': {e}\n"
f"Make sure the module exists and is importable from the current directory."
)
def get_current_revision(db_url: str) -> Optional[str]:
"""Get the current revision from the database."""
import sqlalchemy as sa
engine = create_engine(db_url, poolclass=pool.NullPool)
# Create alembic_version table if it doesn't exist
with engine.begin() as connection:
connection.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS alembic_version (
version_num VARCHAR(32) NOT NULL,
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
)
""")
)
# Get current revision
with engine.connect() as connection:
result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
row = result.first()
return row[0] if row else None
def generate_migration_ops(db_url: str, metadata):
"""Generate migration operations by comparing metadata to database."""
engine = create_engine(db_url, poolclass=pool.NullPool)
with engine.connect() as connection:
migration_context = MigrationContext.configure(connection)
# Use produce_migrations which returns actual operation objects
migration_script = produce_migrations(migration_context, metadata)
return migration_script.upgrade_ops, migration_script.downgrade_ops
@migrations_app.command()
def create(
message: str = typer.Option(..., "--message", "-m", help="Migration message"),
models: str = typer.Option(
...,
"--models",
help="Python import path to models module (e.g., 'models' or 'app.models')",
),
migrations_path: Optional[str] = typer.Option(
None, "--path", "-p", help="Path to migrations directory"
),
) -> None:
"""Create a new migration with autogenerate."""
migrations_dir = get_migrations_dir(migrations_path)
# Create migrations directory if it doesn't exist
migrations_dir.mkdir(parents=True, exist_ok=True)
# Get next migration number
migration_number = get_next_migration_number(migrations_dir)
# Create slug from message
# TODO: truncate, handle special characters
slug = message.lower().replace(" ", "_")
slug = re.sub(r"[^a-z0-9_]", "", slug)
filename = f"{migration_number}_{slug}.py"
filepath = migrations_dir / filename
typer.echo(f"Creating migration: {filename}")
try:
# Get database URL
db_url = os.getenv("DATABASE_URL")
if not db_url:
raise ValueError(
"DATABASE_URL environment variable is not set. "
"Please set it to your database connection string."
)
# Get metadata
metadata = get_metadata(models)
# Check if there are pending migrations that need to be applied
current_revision = get_current_revision(db_url)
existing_migrations = sorted(
[f for f in migrations_dir.glob("*.py") if f.name != "__init__.py"]
)
if existing_migrations:
# Get the latest migration file
latest_migration_file = existing_migrations[-1]
content = latest_migration_file.read_text()
# Extract revision = "..." from the file
match = re.search(
r'^revision = ["\']([^"\']+)["\']', content, re.MULTILINE
)
if match:
latest_file_revision = match.group(1)
# Check if database is up to date
if current_revision != latest_file_revision:
typer.echo(
f"Error: Database is not up to date. Current revision: {current_revision or 'None'}, "
f"Latest migration: {latest_file_revision}",
err=True,
)
typer.echo(
"Please run 'sqlmodel migrations migrate' to apply pending migrations before creating a new one.",
err=True,
)
raise typer.Exit(1)
# Generate migration operations
upgrade_ops_obj, downgrade_ops_obj = generate_migration_ops(db_url, metadata)
# Render upgrade
if upgrade_ops_obj:
upgrade_code = render_python_code(upgrade_ops_obj).strip()
else:
upgrade_code = "pass"
# Render downgrade
if downgrade_ops_obj:
downgrade_code = render_python_code(downgrade_ops_obj).strip()
else:
# TODO: space :)
downgrade_code = "pass"
# Remove Alembic comments to check if migrations are actually empty
def extract_code_without_comments(code: str) -> str:
"""Extract actual code, removing Alembic auto-generated comments."""
lines = code.split("\n")
actual_lines = []
for line in lines:
# Skip Alembic comment lines
if not line.strip().startswith("# ###"):
actual_lines.append(line)
return "\n".join(actual_lines).strip()
upgrade_code_clean = extract_code_without_comments(upgrade_code)
downgrade_code_clean = extract_code_without_comments(downgrade_code)
# Only reject empty migrations if there are already existing migrations
# (i.e., this is not the first migration)
if (
upgrade_code_clean == "pass"
and downgrade_code_clean == "pass"
and len(existing_migrations) > 0
):
# TODO: better message
typer.echo(
"Empty migrations are not allowed"
) # TODO: unless you pass `--empty`
raise typer.Exit(1)
# Generate revision ID from filename (without .py extension)
revision_id = f"{migration_number}_{slug}"
# Get previous revision by reading the last migration file's revision ID
down_revision = None
if existing_migrations:
# Read the last migration file to get its revision ID
last_migration = existing_migrations[-1]
content = last_migration.read_text()
# Extract revision = "..." from the file
import re as regex_module
match = regex_module.search(
r'^revision = ["\']([^"\']+)["\']', content, regex_module.MULTILINE
)
if match:
down_revision = match.group(1)
# Check if we need to import sqlmodel
needs_sqlmodel = "sqlmodel" in upgrade_code or "sqlmodel" in downgrade_code
# Generate migration file - build without f-strings to avoid % issues
lines: list[str] = []
lines.append(f'"""{message}"""')
lines.append("")
lines.append("import sqlalchemy as sa")
if needs_sqlmodel:
lines.append("import sqlmodel")
lines.append("from alembic import op")
lines.append("")
lines.append('revision = "' + revision_id + '"')
lines.append("down_revision = " + repr(down_revision))
lines.append("depends_on = None")
lines.append("")
lines.append("")
lines.append("def upgrade() -> None:")
# Add upgrade code with proper indentation
for line in upgrade_code.split("\n"):
lines.append(line)
lines.append("")
lines.append("")
lines.append("def downgrade() -> None:")
# Add downgrade code with proper indentation
for line in downgrade_code.split("\n"):
lines.append(line)
migration_content = "\n".join(lines)
filepath.write_text(migration_content)
typer.echo(f"✓ Created migration: {filename}")
except ValueError as e:
typer.echo(f"Error creating migration: {e}", err=True)
raise typer.Exit(1)
except Exception as e:
import traceback
typer.echo(f"Error creating migration: {e}", err=True)
typer.echo("\nFull traceback:", err=True)
traceback.print_exc()
raise typer.Exit(1)
def get_pending_migrations(
migrations_dir: Path, current_revision: Optional[str]
) -> list[Path]:
"""Get list of pending migration files."""
all_migrations = sorted(
[
f
for f in migrations_dir.glob("*.py")
if f.name != "__init__.py" and f.name != "env.py"
]
)
if not current_revision:
return all_migrations
# Find migrations after the current revision
pending = []
found_current = False
for migration_file in all_migrations:
if found_current:
pending.append(migration_file)
elif migration_file.stem == current_revision:
found_current = True
return pending
def apply_migrations_programmatically(
migrations_dir: Path, db_url: str, models: str
) -> None:
"""Apply migrations programmatically without env.py."""
import importlib.util
import sqlalchemy as sa
# Get metadata
metadata = get_metadata(models)
# Create engine
engine = create_engine(db_url, poolclass=pool.NullPool)
# Create alembic_version table if it doesn't exist (outside transaction)
with engine.begin() as connection:
connection.execute(
sa.text("""
CREATE TABLE IF NOT EXISTS alembic_version (
version_num VARCHAR(32) NOT NULL,
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
)
""")
)
# Get current revision
with engine.connect() as connection:
result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
row = result.first()
current_revision = row[0] if row else None
# Get pending migrations
pending_migrations = get_pending_migrations(migrations_dir, current_revision)
if not pending_migrations:
typer.echo(" No pending migrations")
return
# Run each migration
for migration_file in pending_migrations:
revision_id = migration_file.stem
typer.echo(f" Applying: {revision_id}")
# Execute each migration in its own transaction
with engine.begin() as connection:
# Create migration context
migration_context = MigrationContext.configure(
connection, opts={"target_metadata": metadata}
)
# Load the migration module
spec = importlib.util.spec_from_file_location(revision_id, migration_file)
if not spec or not spec.loader:
raise ValueError(f"Could not load migration: {migration_file}")
module = importlib.util.module_from_spec(spec)
# Also make sqlalchemy and sqlmodel available
import sqlmodel
module.sa = sa # type: ignore
module.sqlmodel = sqlmodel # type: ignore
# Execute the module to define the functions
spec.loader.exec_module(module)
# Create operations context and run upgrade within ops.invoke_for_target
from alembic.operations import ops
with ops.Operations.context(migration_context):
# Now op proxy is available via alembic.op
from alembic import op as alembic_op
module.op = alembic_op # type: ignore
# Execute upgrade
module.upgrade()
# Update alembic_version table
if current_revision:
connection.execute(
sa.text("UPDATE alembic_version SET version_num = :version"),
{"version": revision_id},
)
else:
connection.execute(
sa.text(
"INSERT INTO alembic_version (version_num) VALUES (:version)"
),
{"version": revision_id},
)
current_revision = revision_id
@migrations_app.command()
def migrate(
models: str = typer.Option(
...,
"--models",
help="Python import path to models module (e.g., 'models' or 'app.models')",
),
migrations_path: Optional[str] = typer.Option(
None, "--path", "-p", help="Path to migrations directory"
),
) -> None:
"""Apply all pending migrations to the database."""
migrations_dir = get_migrations_dir(migrations_path)
if not migrations_dir.exists():
typer.echo(
f"Error: {migrations_dir} not found. Run 'sqlmodel migrations init' first.",
err=True,
)
raise typer.Exit(1)
# Get database URL
db_url = os.getenv("DATABASE_URL")
if not db_url:
typer.echo("Error: DATABASE_URL environment variable is not set.", err=True)
raise typer.Exit(1)
typer.echo("Applying migrations...")
try:
apply_migrations_programmatically(migrations_dir, db_url, models)
typer.echo("✓ Migrations applied successfully")
except Exception as e:
typer.echo(f"Error applying migrations: {e}", err=True)
raise typer.Exit(1)

View File

View File

@@ -0,0 +1,27 @@
"""Initial migration"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
revision = "0001_initial_migration"
down_revision = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('hero',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('secret_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('age', sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('hero')
# ### end Alembic commands ###

View File

View File

@@ -0,0 +1,10 @@
from typing import Optional
from sqlmodel import Field, SQLModel
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None

View File

@@ -0,0 +1,225 @@
import shutil
from pathlib import Path
import pytest
from inline_snapshot import external, register_format_alias
from sqlmodel.cli import app
from typer.testing import CliRunner
register_format_alias(".py", ".txt")
runner = CliRunner()
HERE = Path(__file__).parent
register_format_alias(".html", ".txt")
def test_create_first_migration(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
"""Test creating the first migration with an empty database."""
db_path = tmp_path / "test.db"
db_url = f"sqlite:///{db_path}"
migrations_dir = tmp_path / "migrations"
model_source = HERE / "./fixtures/models_initial.py"
models_dir = tmp_path / "test_models"
models_dir.mkdir()
(models_dir / "__init__.py").write_text("")
models_file = models_dir / "models.py"
shutil.copy(model_source, models_file)
monkeypatch.setenv("DATABASE_URL", db_url)
monkeypatch.chdir(tmp_path)
# Run the create command
result = runner.invoke(
app,
[
"migrations",
"create",
"-m",
"Initial migration",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
assert result.exit_code == 0, f"Command failed: {result.stdout}"
assert "✓ Created migration:" in result.stdout
migration_files = sorted(
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
)
assert migration_files == [
"migrations/0001_initial_migration.py",
]
migration_file = migrations_dir / "0001_initial_migration.py"
assert migration_file.read_text() == external(
"uuid:f1182584-912e-4f31-9d79-2233e5a8a986.py"
)
# TODO: to force migration you need to pass `--empty`s
def test_running_migration_twice_only_generates_migration_once(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
db_path = tmp_path / "test.db"
db_url = f"sqlite:///{db_path}"
migrations_dir = tmp_path / "migrations"
model_source = HERE / "./fixtures/models_initial.py"
models_dir = tmp_path / "test_models"
models_dir.mkdir()
(models_dir / "__init__.py").write_text("")
models_file = models_dir / "models.py"
shutil.copy(model_source, models_file)
monkeypatch.setenv("DATABASE_URL", db_url)
monkeypatch.chdir(tmp_path)
# Run the create command
result = runner.invoke(
app,
[
"migrations",
"create",
"-m",
"Initial migration",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
assert result.exit_code == 0, f"Command failed: {result.stdout}"
assert "✓ Created migration:" in result.stdout
# Apply the first migration to the database
result = runner.invoke(
app,
[
"migrations",
"migrate",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
assert result.exit_code == 0, f"Migration failed: {result.stdout}"
# Run the create command again (should fail with empty migration)
result = runner.invoke(
app,
[
"migrations",
"create",
"-m",
"Initial migration",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
assert result.exit_code == 1
assert "Empty migrations are not allowed" in result.stdout
migration_files = sorted(
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
)
assert migration_files == [
"migrations/0001_initial_migration.py",
]
def test_cannot_create_migration_with_pending_migrations(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
"""Test that creating a migration fails if there are unapplied migrations."""
db_path = tmp_path / "test.db"
db_url = f"sqlite:///{db_path}"
migrations_dir = tmp_path / "migrations"
model_source = HERE / "./fixtures/models_initial.py"
models_dir = tmp_path / "test_models"
models_dir.mkdir()
(models_dir / "__init__.py").write_text("")
models_file = models_dir / "models.py"
shutil.copy(model_source, models_file)
monkeypatch.setenv("DATABASE_URL", db_url)
monkeypatch.chdir(tmp_path)
# Run the create command to create first migration
result = runner.invoke(
app,
[
"migrations",
"create",
"-m",
"Initial migration",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
assert result.exit_code == 0, f"Command failed: {result.stdout}"
assert "✓ Created migration:" in result.stdout
# Try to create another migration WITHOUT applying the first one
result = runner.invoke(
app,
[
"migrations",
"create",
"-m",
"Second migration",
"--models",
"test_models.models",
"--path",
str(migrations_dir),
],
)
# Should fail because database is not up to date
assert result.exit_code == 1
# Error messages are printed to stderr, which Typer's CliRunner combines into output
assert (
"Database is not up to date" in result.stdout
or "Database is not up to date" in str(result.output)
)
assert (
"Please run 'sqlmodel migrations migrate'" in result.stdout
or "Please run 'sqlmodel migrations migrate'" in str(result.output)
)
# Verify only one migration file exists
migration_files = sorted(
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
)
assert migration_files == [
"migrations/0001_initial_migration.py",
]