From 405ea1949511dc517e72a44655bf7bacd4eb6f18 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Thu, 16 Oct 2025 20:47:09 +0100 Subject: [PATCH] WIP --- pyproject.toml | 5 + requirements-tests.txt | 1 + sqlmodel/cli/__init__.py | 10 + sqlmodel/cli/migrations.py | 453 ++++++++++++++++++ tests/test_cli/__init__.py | 0 .../f1182584-912e-4f31-9d79-2233e5a8a986.py | 27 ++ tests/test_cli/fixtures/__init__.py | 0 tests/test_cli/fixtures/models_initial.py | 10 + tests/test_cli/test_migrations.py | 225 +++++++++ 9 files changed, 731 insertions(+) create mode 100644 sqlmodel/cli/__init__.py create mode 100644 sqlmodel/cli/migrations.py create mode 100644 tests/test_cli/__init__.py create mode 100644 tests/test_cli/__inline_snapshot__/test_migrations/test_create_first_migration/f1182584-912e-4f31-9d79-2233e5a8a986.py create mode 100644 tests/test_cli/fixtures/__init__.py create mode 100644 tests/test_cli/fixtures/models_initial.py create mode 100644 tests/test_cli/test_migrations.py diff --git a/pyproject.toml b/pyproject.toml index cd47f5ec..9c14dfdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,8 +38,13 @@ classifiers = [ dependencies = [ "SQLAlchemy >=2.0.14,<2.1.0", "pydantic >=1.10.13,<3.0.0", + "typer >=0.9.0", + "alembic >=1.13.0", ] +[project.scripts] +sqlmodel = "sqlmodel.cli:main" + [project.urls] Homepage = "https://github.com/fastapi/sqlmodel" Documentation = "https://sqlmodel.tiangolo.com" diff --git a/requirements-tests.txt b/requirements-tests.txt index 6cae1015..e5333545 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -2,6 +2,7 @@ -r requirements-docs-tests.txt pytest >=7.0.1,<9.0.0 coverage[toml] >=6.2,<8.0 +inline-snapshot >=0.13.0 # Remove when support for Python 3.8 is dropped mypy ==1.14.1; python_version < "3.9" mypy ==1.18.2; python_version >= "3.9" diff --git a/sqlmodel/cli/__init__.py b/sqlmodel/cli/__init__.py new file mode 100644 index 00000000..cf2cf7c3 --- /dev/null +++ b/sqlmodel/cli/__init__.py @@ -0,0 +1,10 @@ +import typer + +from .migrations import migrations_app + +app = typer.Typer() +app.add_typer(migrations_app, name="migrations") + + +def main() -> None: + app() diff --git a/sqlmodel/cli/migrations.py b/sqlmodel/cli/migrations.py new file mode 100644 index 00000000..3c5cc9fa --- /dev/null +++ b/sqlmodel/cli/migrations.py @@ -0,0 +1,453 @@ +import os +import re +from pathlib import Path +from typing import Optional + +import typer +from alembic.autogenerate import produce_migrations, render_python_code +from alembic.config import Config +from alembic.runtime.migration import MigrationContext +from sqlalchemy import create_engine, pool + +migrations_app = typer.Typer() + + +def get_migrations_dir(migrations_path: Optional[str] = None) -> Path: + """Get the migrations directory path.""" + if migrations_path: + return Path(migrations_path) + return Path.cwd() / "migrations" + + +def get_alembic_config(migrations_dir: Path) -> Config: + """Create an Alembic config object programmatically without ini file.""" + config = Config() + config.set_main_option("script_location", str(migrations_dir)) + config.set_main_option("sqlalchemy.url", "") # Will be set by env.py + return config + + +def get_next_migration_number(migrations_dir: Path) -> str: + """Get the next sequential migration number.""" + if not migrations_dir.exists(): + return "0001" + + migration_files = list(migrations_dir.glob("*.py")) + if not migration_files: + return "0001" + + numbers = [] + for f in migration_files: + match = re.match(r"^(\d{4})_", f.name) + if match: + numbers.append(int(match.group(1))) + + if not numbers: + return "0001" + + return f"{max(numbers) + 1:04d}" + + +def get_metadata(models_path: str): + """Import and return SQLModel metadata.""" + import sys + from importlib import import_module + + # Add current directory to Python path + sys.path.insert(0, str(Path.cwd())) + + try: + # Import the module containing the models + models_module = import_module(models_path) + + # Get SQLModel from the module or import it + if hasattr(models_module, "SQLModel"): + return models_module.SQLModel.metadata + else: + # Try importing SQLModel from sqlmodel + from sqlmodel import SQLModel + + return SQLModel.metadata + except ImportError as e: + raise ValueError( + f"Failed to import models from '{models_path}': {e}\n" + f"Make sure the module exists and is importable from the current directory." + ) + + +def get_current_revision(db_url: str) -> Optional[str]: + """Get the current revision from the database.""" + import sqlalchemy as sa + + engine = create_engine(db_url, poolclass=pool.NullPool) + + # Create alembic_version table if it doesn't exist + with engine.begin() as connection: + connection.execute( + sa.text(""" + CREATE TABLE IF NOT EXISTS alembic_version ( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) + ) + """) + ) + + # Get current revision + with engine.connect() as connection: + result = connection.execute(sa.text("SELECT version_num FROM alembic_version")) + row = result.first() + return row[0] if row else None + + +def generate_migration_ops(db_url: str, metadata): + """Generate migration operations by comparing metadata to database.""" + + engine = create_engine(db_url, poolclass=pool.NullPool) + + with engine.connect() as connection: + migration_context = MigrationContext.configure(connection) + # Use produce_migrations which returns actual operation objects + migration_script = produce_migrations(migration_context, metadata) + + return migration_script.upgrade_ops, migration_script.downgrade_ops + + +@migrations_app.command() +def create( + message: str = typer.Option(..., "--message", "-m", help="Migration message"), + models: str = typer.Option( + ..., + "--models", + help="Python import path to models module (e.g., 'models' or 'app.models')", + ), + migrations_path: Optional[str] = typer.Option( + None, "--path", "-p", help="Path to migrations directory" + ), +) -> None: + """Create a new migration with autogenerate.""" + migrations_dir = get_migrations_dir(migrations_path) + + # Create migrations directory if it doesn't exist + migrations_dir.mkdir(parents=True, exist_ok=True) + + # Get next migration number + migration_number = get_next_migration_number(migrations_dir) + + # Create slug from message + # TODO: truncate, handle special characters + slug = message.lower().replace(" ", "_") + slug = re.sub(r"[^a-z0-9_]", "", slug) + + filename = f"{migration_number}_{slug}.py" + filepath = migrations_dir / filename + + typer.echo(f"Creating migration: {filename}") + + try: + # Get database URL + db_url = os.getenv("DATABASE_URL") + if not db_url: + raise ValueError( + "DATABASE_URL environment variable is not set. " + "Please set it to your database connection string." + ) + + # Get metadata + metadata = get_metadata(models) + + # Check if there are pending migrations that need to be applied + current_revision = get_current_revision(db_url) + existing_migrations = sorted( + [f for f in migrations_dir.glob("*.py") if f.name != "__init__.py"] + ) + + if existing_migrations: + # Get the latest migration file + latest_migration_file = existing_migrations[-1] + content = latest_migration_file.read_text() + # Extract revision = "..." from the file + match = re.search( + r'^revision = ["\']([^"\']+)["\']', content, re.MULTILINE + ) + if match: + latest_file_revision = match.group(1) + # Check if database is up to date + if current_revision != latest_file_revision: + typer.echo( + f"Error: Database is not up to date. Current revision: {current_revision or 'None'}, " + f"Latest migration: {latest_file_revision}", + err=True, + ) + typer.echo( + "Please run 'sqlmodel migrations migrate' to apply pending migrations before creating a new one.", + err=True, + ) + raise typer.Exit(1) + + # Generate migration operations + upgrade_ops_obj, downgrade_ops_obj = generate_migration_ops(db_url, metadata) + + # Render upgrade + if upgrade_ops_obj: + upgrade_code = render_python_code(upgrade_ops_obj).strip() + else: + upgrade_code = "pass" + + # Render downgrade + if downgrade_ops_obj: + downgrade_code = render_python_code(downgrade_ops_obj).strip() + else: + # TODO: space :) + downgrade_code = "pass" + + # Remove Alembic comments to check if migrations are actually empty + def extract_code_without_comments(code: str) -> str: + """Extract actual code, removing Alembic auto-generated comments.""" + lines = code.split("\n") + actual_lines = [] + for line in lines: + # Skip Alembic comment lines + if not line.strip().startswith("# ###"): + actual_lines.append(line) + return "\n".join(actual_lines).strip() + + upgrade_code_clean = extract_code_without_comments(upgrade_code) + downgrade_code_clean = extract_code_without_comments(downgrade_code) + + # Only reject empty migrations if there are already existing migrations + # (i.e., this is not the first migration) + if ( + upgrade_code_clean == "pass" + and downgrade_code_clean == "pass" + and len(existing_migrations) > 0 + ): + # TODO: better message + typer.echo( + "Empty migrations are not allowed" + ) # TODO: unless you pass `--empty` + raise typer.Exit(1) + + # Generate revision ID from filename (without .py extension) + revision_id = f"{migration_number}_{slug}" + + # Get previous revision by reading the last migration file's revision ID + down_revision = None + if existing_migrations: + # Read the last migration file to get its revision ID + last_migration = existing_migrations[-1] + content = last_migration.read_text() + # Extract revision = "..." from the file + import re as regex_module + + match = regex_module.search( + r'^revision = ["\']([^"\']+)["\']', content, regex_module.MULTILINE + ) + if match: + down_revision = match.group(1) + + # Check if we need to import sqlmodel + needs_sqlmodel = "sqlmodel" in upgrade_code or "sqlmodel" in downgrade_code + + # Generate migration file - build without f-strings to avoid % issues + lines: list[str] = [] + lines.append(f'"""{message}"""') + lines.append("") + lines.append("import sqlalchemy as sa") + if needs_sqlmodel: + lines.append("import sqlmodel") + lines.append("from alembic import op") + lines.append("") + lines.append('revision = "' + revision_id + '"') + lines.append("down_revision = " + repr(down_revision)) + lines.append("depends_on = None") + lines.append("") + lines.append("") + lines.append("def upgrade() -> None:") + + # Add upgrade code with proper indentation + for line in upgrade_code.split("\n"): + lines.append(line) + + lines.append("") + lines.append("") + lines.append("def downgrade() -> None:") + + # Add downgrade code with proper indentation + for line in downgrade_code.split("\n"): + lines.append(line) + + migration_content = "\n".join(lines) + + filepath.write_text(migration_content) + + typer.echo(f"✓ Created migration: {filename}") + + except ValueError as e: + typer.echo(f"Error creating migration: {e}", err=True) + raise typer.Exit(1) + except Exception as e: + import traceback + + typer.echo(f"Error creating migration: {e}", err=True) + typer.echo("\nFull traceback:", err=True) + traceback.print_exc() + raise typer.Exit(1) + + +def get_pending_migrations( + migrations_dir: Path, current_revision: Optional[str] +) -> list[Path]: + """Get list of pending migration files.""" + all_migrations = sorted( + [ + f + for f in migrations_dir.glob("*.py") + if f.name != "__init__.py" and f.name != "env.py" + ] + ) + + if not current_revision: + return all_migrations + + # Find migrations after the current revision + pending = [] + found_current = False + for migration_file in all_migrations: + if found_current: + pending.append(migration_file) + elif migration_file.stem == current_revision: + found_current = True + + return pending + + +def apply_migrations_programmatically( + migrations_dir: Path, db_url: str, models: str +) -> None: + """Apply migrations programmatically without env.py.""" + import importlib.util + + import sqlalchemy as sa + + # Get metadata + metadata = get_metadata(models) + + # Create engine + engine = create_engine(db_url, poolclass=pool.NullPool) + + # Create alembic_version table if it doesn't exist (outside transaction) + with engine.begin() as connection: + connection.execute( + sa.text(""" + CREATE TABLE IF NOT EXISTS alembic_version ( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) + ) + """) + ) + + # Get current revision + with engine.connect() as connection: + result = connection.execute(sa.text("SELECT version_num FROM alembic_version")) + row = result.first() + current_revision = row[0] if row else None + + # Get pending migrations + pending_migrations = get_pending_migrations(migrations_dir, current_revision) + + if not pending_migrations: + typer.echo(" No pending migrations") + return + + # Run each migration + for migration_file in pending_migrations: + revision_id = migration_file.stem + typer.echo(f" Applying: {revision_id}") + + # Execute each migration in its own transaction + with engine.begin() as connection: + # Create migration context + migration_context = MigrationContext.configure( + connection, opts={"target_metadata": metadata} + ) + + # Load the migration module + spec = importlib.util.spec_from_file_location(revision_id, migration_file) + if not spec or not spec.loader: + raise ValueError(f"Could not load migration: {migration_file}") + + module = importlib.util.module_from_spec(spec) + + # Also make sqlalchemy and sqlmodel available + import sqlmodel + + module.sa = sa # type: ignore + module.sqlmodel = sqlmodel # type: ignore + + # Execute the module to define the functions + spec.loader.exec_module(module) + + # Create operations context and run upgrade within ops.invoke_for_target + from alembic.operations import ops + + with ops.Operations.context(migration_context): + # Now op proxy is available via alembic.op + from alembic import op as alembic_op + + module.op = alembic_op # type: ignore + + # Execute upgrade + module.upgrade() + + # Update alembic_version table + if current_revision: + connection.execute( + sa.text("UPDATE alembic_version SET version_num = :version"), + {"version": revision_id}, + ) + else: + connection.execute( + sa.text( + "INSERT INTO alembic_version (version_num) VALUES (:version)" + ), + {"version": revision_id}, + ) + + current_revision = revision_id + + +@migrations_app.command() +def migrate( + models: str = typer.Option( + ..., + "--models", + help="Python import path to models module (e.g., 'models' or 'app.models')", + ), + migrations_path: Optional[str] = typer.Option( + None, "--path", "-p", help="Path to migrations directory" + ), +) -> None: + """Apply all pending migrations to the database.""" + migrations_dir = get_migrations_dir(migrations_path) + + if not migrations_dir.exists(): + typer.echo( + f"Error: {migrations_dir} not found. Run 'sqlmodel migrations init' first.", + err=True, + ) + raise typer.Exit(1) + + # Get database URL + db_url = os.getenv("DATABASE_URL") + if not db_url: + typer.echo("Error: DATABASE_URL environment variable is not set.", err=True) + raise typer.Exit(1) + + typer.echo("Applying migrations...") + + try: + apply_migrations_programmatically(migrations_dir, db_url, models) + typer.echo("✓ Migrations applied successfully") + except Exception as e: + typer.echo(f"Error applying migrations: {e}", err=True) + raise typer.Exit(1) diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cli/__inline_snapshot__/test_migrations/test_create_first_migration/f1182584-912e-4f31-9d79-2233e5a8a986.py b/tests/test_cli/__inline_snapshot__/test_migrations/test_create_first_migration/f1182584-912e-4f31-9d79-2233e5a8a986.py new file mode 100644 index 00000000..55dee2e1 --- /dev/null +++ b/tests/test_cli/__inline_snapshot__/test_migrations/test_create_first_migration/f1182584-912e-4f31-9d79-2233e5a8a986.py @@ -0,0 +1,27 @@ +"""Initial migration""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +revision = "0001_initial_migration" +down_revision = None +depends_on = None + + +def upgrade() -> None: +# ### commands auto generated by Alembic - please adjust! ### + op.create_table('hero', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('secret_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('age', sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: +# ### commands auto generated by Alembic - please adjust! ### + op.drop_table('hero') + # ### end Alembic commands ### \ No newline at end of file diff --git a/tests/test_cli/fixtures/__init__.py b/tests/test_cli/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cli/fixtures/models_initial.py b/tests/test_cli/fixtures/models_initial.py new file mode 100644 index 00000000..12132bb7 --- /dev/null +++ b/tests/test_cli/fixtures/models_initial.py @@ -0,0 +1,10 @@ +from typing import Optional + +from sqlmodel import Field, SQLModel + + +class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None diff --git a/tests/test_cli/test_migrations.py b/tests/test_cli/test_migrations.py new file mode 100644 index 00000000..001a7542 --- /dev/null +++ b/tests/test_cli/test_migrations.py @@ -0,0 +1,225 @@ +import shutil +from pathlib import Path + +import pytest +from inline_snapshot import external, register_format_alias +from sqlmodel.cli import app +from typer.testing import CliRunner + +register_format_alias(".py", ".txt") + +runner = CliRunner() + + +HERE = Path(__file__).parent + +register_format_alias(".html", ".txt") + + +def test_create_first_migration(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Test creating the first migration with an empty database.""" + db_path = tmp_path / "test.db" + db_url = f"sqlite:///{db_path}" + migrations_dir = tmp_path / "migrations" + + model_source = HERE / "./fixtures/models_initial.py" + + models_dir = tmp_path / "test_models" + models_dir.mkdir() + + (models_dir / "__init__.py").write_text("") + models_file = models_dir / "models.py" + + shutil.copy(model_source, models_file) + + monkeypatch.setenv("DATABASE_URL", db_url) + monkeypatch.chdir(tmp_path) + + # Run the create command + result = runner.invoke( + app, + [ + "migrations", + "create", + "-m", + "Initial migration", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + assert result.exit_code == 0, f"Command failed: {result.stdout}" + assert "✓ Created migration:" in result.stdout + + migration_files = sorted( + [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")] + ) + + assert migration_files == [ + "migrations/0001_initial_migration.py", + ] + + migration_file = migrations_dir / "0001_initial_migration.py" + + assert migration_file.read_text() == external( + "uuid:f1182584-912e-4f31-9d79-2233e5a8a986.py" + ) + + +# TODO: to force migration you need to pass `--empty`s +def test_running_migration_twice_only_generates_migration_once( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + db_path = tmp_path / "test.db" + db_url = f"sqlite:///{db_path}" + migrations_dir = tmp_path / "migrations" + + model_source = HERE / "./fixtures/models_initial.py" + + models_dir = tmp_path / "test_models" + models_dir.mkdir() + + (models_dir / "__init__.py").write_text("") + models_file = models_dir / "models.py" + + shutil.copy(model_source, models_file) + + monkeypatch.setenv("DATABASE_URL", db_url) + monkeypatch.chdir(tmp_path) + + # Run the create command + result = runner.invoke( + app, + [ + "migrations", + "create", + "-m", + "Initial migration", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + assert result.exit_code == 0, f"Command failed: {result.stdout}" + assert "✓ Created migration:" in result.stdout + + # Apply the first migration to the database + result = runner.invoke( + app, + [ + "migrations", + "migrate", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + assert result.exit_code == 0, f"Migration failed: {result.stdout}" + + # Run the create command again (should fail with empty migration) + result = runner.invoke( + app, + [ + "migrations", + "create", + "-m", + "Initial migration", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + assert result.exit_code == 1 + assert "Empty migrations are not allowed" in result.stdout + + migration_files = sorted( + [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")] + ) + + assert migration_files == [ + "migrations/0001_initial_migration.py", + ] + + +def test_cannot_create_migration_with_pending_migrations( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + """Test that creating a migration fails if there are unapplied migrations.""" + db_path = tmp_path / "test.db" + db_url = f"sqlite:///{db_path}" + migrations_dir = tmp_path / "migrations" + + model_source = HERE / "./fixtures/models_initial.py" + + models_dir = tmp_path / "test_models" + models_dir.mkdir() + + (models_dir / "__init__.py").write_text("") + models_file = models_dir / "models.py" + + shutil.copy(model_source, models_file) + + monkeypatch.setenv("DATABASE_URL", db_url) + monkeypatch.chdir(tmp_path) + + # Run the create command to create first migration + result = runner.invoke( + app, + [ + "migrations", + "create", + "-m", + "Initial migration", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + assert result.exit_code == 0, f"Command failed: {result.stdout}" + assert "✓ Created migration:" in result.stdout + + # Try to create another migration WITHOUT applying the first one + result = runner.invoke( + app, + [ + "migrations", + "create", + "-m", + "Second migration", + "--models", + "test_models.models", + "--path", + str(migrations_dir), + ], + ) + + # Should fail because database is not up to date + assert result.exit_code == 1 + # Error messages are printed to stderr, which Typer's CliRunner combines into output + assert ( + "Database is not up to date" in result.stdout + or "Database is not up to date" in str(result.output) + ) + assert ( + "Please run 'sqlmodel migrations migrate'" in result.stdout + or "Please run 'sqlmodel migrations migrate'" in str(result.output) + ) + + # Verify only one migration file exists + migration_files = sorted( + [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")] + ) + + assert migration_files == [ + "migrations/0001_initial_migration.py", + ]