mirror of
https://github.com/fastapi/sqlmodel.git
synced 2025-12-12 23:46:59 +08:00
WIP
This commit is contained in:
@@ -38,8 +38,13 @@ classifiers = [
|
||||
dependencies = [
|
||||
"SQLAlchemy >=2.0.14,<2.1.0",
|
||||
"pydantic >=1.10.13,<3.0.0",
|
||||
"typer >=0.9.0",
|
||||
"alembic >=1.13.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
sqlmodel = "sqlmodel.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/fastapi/sqlmodel"
|
||||
Documentation = "https://sqlmodel.tiangolo.com"
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
-r requirements-docs-tests.txt
|
||||
pytest >=7.0.1,<9.0.0
|
||||
coverage[toml] >=6.2,<8.0
|
||||
inline-snapshot >=0.13.0
|
||||
# Remove when support for Python 3.8 is dropped
|
||||
mypy ==1.14.1; python_version < "3.9"
|
||||
mypy ==1.18.2; python_version >= "3.9"
|
||||
|
||||
10
sqlmodel/cli/__init__.py
Normal file
10
sqlmodel/cli/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import typer
|
||||
|
||||
from .migrations import migrations_app
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(migrations_app, name="migrations")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
app()
|
||||
453
sqlmodel/cli/migrations.py
Normal file
453
sqlmodel/cli/migrations.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from alembic.autogenerate import produce_migrations, render_python_code
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import create_engine, pool
|
||||
|
||||
migrations_app = typer.Typer()
|
||||
|
||||
|
||||
def get_migrations_dir(migrations_path: Optional[str] = None) -> Path:
|
||||
"""Get the migrations directory path."""
|
||||
if migrations_path:
|
||||
return Path(migrations_path)
|
||||
return Path.cwd() / "migrations"
|
||||
|
||||
|
||||
def get_alembic_config(migrations_dir: Path) -> Config:
|
||||
"""Create an Alembic config object programmatically without ini file."""
|
||||
config = Config()
|
||||
config.set_main_option("script_location", str(migrations_dir))
|
||||
config.set_main_option("sqlalchemy.url", "") # Will be set by env.py
|
||||
return config
|
||||
|
||||
|
||||
def get_next_migration_number(migrations_dir: Path) -> str:
|
||||
"""Get the next sequential migration number."""
|
||||
if not migrations_dir.exists():
|
||||
return "0001"
|
||||
|
||||
migration_files = list(migrations_dir.glob("*.py"))
|
||||
if not migration_files:
|
||||
return "0001"
|
||||
|
||||
numbers = []
|
||||
for f in migration_files:
|
||||
match = re.match(r"^(\d{4})_", f.name)
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
if not numbers:
|
||||
return "0001"
|
||||
|
||||
return f"{max(numbers) + 1:04d}"
|
||||
|
||||
|
||||
def get_metadata(models_path: str):
|
||||
"""Import and return SQLModel metadata."""
|
||||
import sys
|
||||
from importlib import import_module
|
||||
|
||||
# Add current directory to Python path
|
||||
sys.path.insert(0, str(Path.cwd()))
|
||||
|
||||
try:
|
||||
# Import the module containing the models
|
||||
models_module = import_module(models_path)
|
||||
|
||||
# Get SQLModel from the module or import it
|
||||
if hasattr(models_module, "SQLModel"):
|
||||
return models_module.SQLModel.metadata
|
||||
else:
|
||||
# Try importing SQLModel from sqlmodel
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
return SQLModel.metadata
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
f"Failed to import models from '{models_path}': {e}\n"
|
||||
f"Make sure the module exists and is importable from the current directory."
|
||||
)
|
||||
|
||||
|
||||
def get_current_revision(db_url: str) -> Optional[str]:
|
||||
"""Get the current revision from the database."""
|
||||
import sqlalchemy as sa
|
||||
|
||||
engine = create_engine(db_url, poolclass=pool.NullPool)
|
||||
|
||||
# Create alembic_version table if it doesn't exist
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alembic_version (
|
||||
version_num VARCHAR(32) NOT NULL,
|
||||
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Get current revision
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
|
||||
row = result.first()
|
||||
return row[0] if row else None
|
||||
|
||||
|
||||
def generate_migration_ops(db_url: str, metadata):
|
||||
"""Generate migration operations by comparing metadata to database."""
|
||||
|
||||
engine = create_engine(db_url, poolclass=pool.NullPool)
|
||||
|
||||
with engine.connect() as connection:
|
||||
migration_context = MigrationContext.configure(connection)
|
||||
# Use produce_migrations which returns actual operation objects
|
||||
migration_script = produce_migrations(migration_context, metadata)
|
||||
|
||||
return migration_script.upgrade_ops, migration_script.downgrade_ops
|
||||
|
||||
|
||||
@migrations_app.command()
|
||||
def create(
|
||||
message: str = typer.Option(..., "--message", "-m", help="Migration message"),
|
||||
models: str = typer.Option(
|
||||
...,
|
||||
"--models",
|
||||
help="Python import path to models module (e.g., 'models' or 'app.models')",
|
||||
),
|
||||
migrations_path: Optional[str] = typer.Option(
|
||||
None, "--path", "-p", help="Path to migrations directory"
|
||||
),
|
||||
) -> None:
|
||||
"""Create a new migration with autogenerate."""
|
||||
migrations_dir = get_migrations_dir(migrations_path)
|
||||
|
||||
# Create migrations directory if it doesn't exist
|
||||
migrations_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get next migration number
|
||||
migration_number = get_next_migration_number(migrations_dir)
|
||||
|
||||
# Create slug from message
|
||||
# TODO: truncate, handle special characters
|
||||
slug = message.lower().replace(" ", "_")
|
||||
slug = re.sub(r"[^a-z0-9_]", "", slug)
|
||||
|
||||
filename = f"{migration_number}_{slug}.py"
|
||||
filepath = migrations_dir / filename
|
||||
|
||||
typer.echo(f"Creating migration: {filename}")
|
||||
|
||||
try:
|
||||
# Get database URL
|
||||
db_url = os.getenv("DATABASE_URL")
|
||||
if not db_url:
|
||||
raise ValueError(
|
||||
"DATABASE_URL environment variable is not set. "
|
||||
"Please set it to your database connection string."
|
||||
)
|
||||
|
||||
# Get metadata
|
||||
metadata = get_metadata(models)
|
||||
|
||||
# Check if there are pending migrations that need to be applied
|
||||
current_revision = get_current_revision(db_url)
|
||||
existing_migrations = sorted(
|
||||
[f for f in migrations_dir.glob("*.py") if f.name != "__init__.py"]
|
||||
)
|
||||
|
||||
if existing_migrations:
|
||||
# Get the latest migration file
|
||||
latest_migration_file = existing_migrations[-1]
|
||||
content = latest_migration_file.read_text()
|
||||
# Extract revision = "..." from the file
|
||||
match = re.search(
|
||||
r'^revision = ["\']([^"\']+)["\']', content, re.MULTILINE
|
||||
)
|
||||
if match:
|
||||
latest_file_revision = match.group(1)
|
||||
# Check if database is up to date
|
||||
if current_revision != latest_file_revision:
|
||||
typer.echo(
|
||||
f"Error: Database is not up to date. Current revision: {current_revision or 'None'}, "
|
||||
f"Latest migration: {latest_file_revision}",
|
||||
err=True,
|
||||
)
|
||||
typer.echo(
|
||||
"Please run 'sqlmodel migrations migrate' to apply pending migrations before creating a new one.",
|
||||
err=True,
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Generate migration operations
|
||||
upgrade_ops_obj, downgrade_ops_obj = generate_migration_ops(db_url, metadata)
|
||||
|
||||
# Render upgrade
|
||||
if upgrade_ops_obj:
|
||||
upgrade_code = render_python_code(upgrade_ops_obj).strip()
|
||||
else:
|
||||
upgrade_code = "pass"
|
||||
|
||||
# Render downgrade
|
||||
if downgrade_ops_obj:
|
||||
downgrade_code = render_python_code(downgrade_ops_obj).strip()
|
||||
else:
|
||||
# TODO: space :)
|
||||
downgrade_code = "pass"
|
||||
|
||||
# Remove Alembic comments to check if migrations are actually empty
|
||||
def extract_code_without_comments(code: str) -> str:
|
||||
"""Extract actual code, removing Alembic auto-generated comments."""
|
||||
lines = code.split("\n")
|
||||
actual_lines = []
|
||||
for line in lines:
|
||||
# Skip Alembic comment lines
|
||||
if not line.strip().startswith("# ###"):
|
||||
actual_lines.append(line)
|
||||
return "\n".join(actual_lines).strip()
|
||||
|
||||
upgrade_code_clean = extract_code_without_comments(upgrade_code)
|
||||
downgrade_code_clean = extract_code_without_comments(downgrade_code)
|
||||
|
||||
# Only reject empty migrations if there are already existing migrations
|
||||
# (i.e., this is not the first migration)
|
||||
if (
|
||||
upgrade_code_clean == "pass"
|
||||
and downgrade_code_clean == "pass"
|
||||
and len(existing_migrations) > 0
|
||||
):
|
||||
# TODO: better message
|
||||
typer.echo(
|
||||
"Empty migrations are not allowed"
|
||||
) # TODO: unless you pass `--empty`
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Generate revision ID from filename (without .py extension)
|
||||
revision_id = f"{migration_number}_{slug}"
|
||||
|
||||
# Get previous revision by reading the last migration file's revision ID
|
||||
down_revision = None
|
||||
if existing_migrations:
|
||||
# Read the last migration file to get its revision ID
|
||||
last_migration = existing_migrations[-1]
|
||||
content = last_migration.read_text()
|
||||
# Extract revision = "..." from the file
|
||||
import re as regex_module
|
||||
|
||||
match = regex_module.search(
|
||||
r'^revision = ["\']([^"\']+)["\']', content, regex_module.MULTILINE
|
||||
)
|
||||
if match:
|
||||
down_revision = match.group(1)
|
||||
|
||||
# Check if we need to import sqlmodel
|
||||
needs_sqlmodel = "sqlmodel" in upgrade_code or "sqlmodel" in downgrade_code
|
||||
|
||||
# Generate migration file - build without f-strings to avoid % issues
|
||||
lines: list[str] = []
|
||||
lines.append(f'"""{message}"""')
|
||||
lines.append("")
|
||||
lines.append("import sqlalchemy as sa")
|
||||
if needs_sqlmodel:
|
||||
lines.append("import sqlmodel")
|
||||
lines.append("from alembic import op")
|
||||
lines.append("")
|
||||
lines.append('revision = "' + revision_id + '"')
|
||||
lines.append("down_revision = " + repr(down_revision))
|
||||
lines.append("depends_on = None")
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
lines.append("def upgrade() -> None:")
|
||||
|
||||
# Add upgrade code with proper indentation
|
||||
for line in upgrade_code.split("\n"):
|
||||
lines.append(line)
|
||||
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
lines.append("def downgrade() -> None:")
|
||||
|
||||
# Add downgrade code with proper indentation
|
||||
for line in downgrade_code.split("\n"):
|
||||
lines.append(line)
|
||||
|
||||
migration_content = "\n".join(lines)
|
||||
|
||||
filepath.write_text(migration_content)
|
||||
|
||||
typer.echo(f"✓ Created migration: {filename}")
|
||||
|
||||
except ValueError as e:
|
||||
typer.echo(f"Error creating migration: {e}", err=True)
|
||||
raise typer.Exit(1)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
typer.echo(f"Error creating migration: {e}", err=True)
|
||||
typer.echo("\nFull traceback:", err=True)
|
||||
traceback.print_exc()
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def get_pending_migrations(
|
||||
migrations_dir: Path, current_revision: Optional[str]
|
||||
) -> list[Path]:
|
||||
"""Get list of pending migration files."""
|
||||
all_migrations = sorted(
|
||||
[
|
||||
f
|
||||
for f in migrations_dir.glob("*.py")
|
||||
if f.name != "__init__.py" and f.name != "env.py"
|
||||
]
|
||||
)
|
||||
|
||||
if not current_revision:
|
||||
return all_migrations
|
||||
|
||||
# Find migrations after the current revision
|
||||
pending = []
|
||||
found_current = False
|
||||
for migration_file in all_migrations:
|
||||
if found_current:
|
||||
pending.append(migration_file)
|
||||
elif migration_file.stem == current_revision:
|
||||
found_current = True
|
||||
|
||||
return pending
|
||||
|
||||
|
||||
def apply_migrations_programmatically(
|
||||
migrations_dir: Path, db_url: str, models: str
|
||||
) -> None:
|
||||
"""Apply migrations programmatically without env.py."""
|
||||
import importlib.util
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
# Get metadata
|
||||
metadata = get_metadata(models)
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(db_url, poolclass=pool.NullPool)
|
||||
|
||||
# Create alembic_version table if it doesn't exist (outside transaction)
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
CREATE TABLE IF NOT EXISTS alembic_version (
|
||||
version_num VARCHAR(32) NOT NULL,
|
||||
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Get current revision
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
|
||||
row = result.first()
|
||||
current_revision = row[0] if row else None
|
||||
|
||||
# Get pending migrations
|
||||
pending_migrations = get_pending_migrations(migrations_dir, current_revision)
|
||||
|
||||
if not pending_migrations:
|
||||
typer.echo(" No pending migrations")
|
||||
return
|
||||
|
||||
# Run each migration
|
||||
for migration_file in pending_migrations:
|
||||
revision_id = migration_file.stem
|
||||
typer.echo(f" Applying: {revision_id}")
|
||||
|
||||
# Execute each migration in its own transaction
|
||||
with engine.begin() as connection:
|
||||
# Create migration context
|
||||
migration_context = MigrationContext.configure(
|
||||
connection, opts={"target_metadata": metadata}
|
||||
)
|
||||
|
||||
# Load the migration module
|
||||
spec = importlib.util.spec_from_file_location(revision_id, migration_file)
|
||||
if not spec or not spec.loader:
|
||||
raise ValueError(f"Could not load migration: {migration_file}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# Also make sqlalchemy and sqlmodel available
|
||||
import sqlmodel
|
||||
|
||||
module.sa = sa # type: ignore
|
||||
module.sqlmodel = sqlmodel # type: ignore
|
||||
|
||||
# Execute the module to define the functions
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Create operations context and run upgrade within ops.invoke_for_target
|
||||
from alembic.operations import ops
|
||||
|
||||
with ops.Operations.context(migration_context):
|
||||
# Now op proxy is available via alembic.op
|
||||
from alembic import op as alembic_op
|
||||
|
||||
module.op = alembic_op # type: ignore
|
||||
|
||||
# Execute upgrade
|
||||
module.upgrade()
|
||||
|
||||
# Update alembic_version table
|
||||
if current_revision:
|
||||
connection.execute(
|
||||
sa.text("UPDATE alembic_version SET version_num = :version"),
|
||||
{"version": revision_id},
|
||||
)
|
||||
else:
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"INSERT INTO alembic_version (version_num) VALUES (:version)"
|
||||
),
|
||||
{"version": revision_id},
|
||||
)
|
||||
|
||||
current_revision = revision_id
|
||||
|
||||
|
||||
@migrations_app.command()
|
||||
def migrate(
|
||||
models: str = typer.Option(
|
||||
...,
|
||||
"--models",
|
||||
help="Python import path to models module (e.g., 'models' or 'app.models')",
|
||||
),
|
||||
migrations_path: Optional[str] = typer.Option(
|
||||
None, "--path", "-p", help="Path to migrations directory"
|
||||
),
|
||||
) -> None:
|
||||
"""Apply all pending migrations to the database."""
|
||||
migrations_dir = get_migrations_dir(migrations_path)
|
||||
|
||||
if not migrations_dir.exists():
|
||||
typer.echo(
|
||||
f"Error: {migrations_dir} not found. Run 'sqlmodel migrations init' first.",
|
||||
err=True,
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Get database URL
|
||||
db_url = os.getenv("DATABASE_URL")
|
||||
if not db_url:
|
||||
typer.echo("Error: DATABASE_URL environment variable is not set.", err=True)
|
||||
raise typer.Exit(1)
|
||||
|
||||
typer.echo("Applying migrations...")
|
||||
|
||||
try:
|
||||
apply_migrations_programmatically(migrations_dir, db_url, models)
|
||||
typer.echo("✓ Migrations applied successfully")
|
||||
except Exception as e:
|
||||
typer.echo(f"Error applying migrations: {e}", err=True)
|
||||
raise typer.Exit(1)
|
||||
0
tests/test_cli/__init__.py
Normal file
0
tests/test_cli/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Initial migration"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
revision = "0001_initial_migration"
|
||||
down_revision = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('hero',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('secret_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('age', sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('hero')
|
||||
# ### end Alembic commands ###
|
||||
0
tests/test_cli/fixtures/__init__.py
Normal file
0
tests/test_cli/fixtures/__init__.py
Normal file
10
tests/test_cli/fixtures/models_initial.py
Normal file
10
tests/test_cli/fixtures/models_initial.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
225
tests/test_cli/test_migrations.py
Normal file
225
tests/test_cli/test_migrations.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import external, register_format_alias
|
||||
from sqlmodel.cli import app
|
||||
from typer.testing import CliRunner
|
||||
|
||||
register_format_alias(".py", ".txt")
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
|
||||
register_format_alias(".html", ".txt")
|
||||
|
||||
|
||||
def test_create_first_migration(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test creating the first migration with an empty database."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
migrations_dir = tmp_path / "migrations"
|
||||
|
||||
model_source = HERE / "./fixtures/models_initial.py"
|
||||
|
||||
models_dir = tmp_path / "test_models"
|
||||
models_dir.mkdir()
|
||||
|
||||
(models_dir / "__init__.py").write_text("")
|
||||
models_file = models_dir / "models.py"
|
||||
|
||||
shutil.copy(model_source, models_file)
|
||||
|
||||
monkeypatch.setenv("DATABASE_URL", db_url)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Run the create command
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"create",
|
||||
"-m",
|
||||
"Initial migration",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"Command failed: {result.stdout}"
|
||||
assert "✓ Created migration:" in result.stdout
|
||||
|
||||
migration_files = sorted(
|
||||
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
|
||||
)
|
||||
|
||||
assert migration_files == [
|
||||
"migrations/0001_initial_migration.py",
|
||||
]
|
||||
|
||||
migration_file = migrations_dir / "0001_initial_migration.py"
|
||||
|
||||
assert migration_file.read_text() == external(
|
||||
"uuid:f1182584-912e-4f31-9d79-2233e5a8a986.py"
|
||||
)
|
||||
|
||||
|
||||
# TODO: to force migration you need to pass `--empty`s
|
||||
def test_running_migration_twice_only_generates_migration_once(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
db_path = tmp_path / "test.db"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
migrations_dir = tmp_path / "migrations"
|
||||
|
||||
model_source = HERE / "./fixtures/models_initial.py"
|
||||
|
||||
models_dir = tmp_path / "test_models"
|
||||
models_dir.mkdir()
|
||||
|
||||
(models_dir / "__init__.py").write_text("")
|
||||
models_file = models_dir / "models.py"
|
||||
|
||||
shutil.copy(model_source, models_file)
|
||||
|
||||
monkeypatch.setenv("DATABASE_URL", db_url)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Run the create command
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"create",
|
||||
"-m",
|
||||
"Initial migration",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"Command failed: {result.stdout}"
|
||||
assert "✓ Created migration:" in result.stdout
|
||||
|
||||
# Apply the first migration to the database
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"migrate",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"Migration failed: {result.stdout}"
|
||||
|
||||
# Run the create command again (should fail with empty migration)
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"create",
|
||||
"-m",
|
||||
"Initial migration",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Empty migrations are not allowed" in result.stdout
|
||||
|
||||
migration_files = sorted(
|
||||
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
|
||||
)
|
||||
|
||||
assert migration_files == [
|
||||
"migrations/0001_initial_migration.py",
|
||||
]
|
||||
|
||||
|
||||
def test_cannot_create_migration_with_pending_migrations(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Test that creating a migration fails if there are unapplied migrations."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
migrations_dir = tmp_path / "migrations"
|
||||
|
||||
model_source = HERE / "./fixtures/models_initial.py"
|
||||
|
||||
models_dir = tmp_path / "test_models"
|
||||
models_dir.mkdir()
|
||||
|
||||
(models_dir / "__init__.py").write_text("")
|
||||
models_file = models_dir / "models.py"
|
||||
|
||||
shutil.copy(model_source, models_file)
|
||||
|
||||
monkeypatch.setenv("DATABASE_URL", db_url)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
# Run the create command to create first migration
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"create",
|
||||
"-m",
|
||||
"Initial migration",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, f"Command failed: {result.stdout}"
|
||||
assert "✓ Created migration:" in result.stdout
|
||||
|
||||
# Try to create another migration WITHOUT applying the first one
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"migrations",
|
||||
"create",
|
||||
"-m",
|
||||
"Second migration",
|
||||
"--models",
|
||||
"test_models.models",
|
||||
"--path",
|
||||
str(migrations_dir),
|
||||
],
|
||||
)
|
||||
|
||||
# Should fail because database is not up to date
|
||||
assert result.exit_code == 1
|
||||
# Error messages are printed to stderr, which Typer's CliRunner combines into output
|
||||
assert (
|
||||
"Database is not up to date" in result.stdout
|
||||
or "Database is not up to date" in str(result.output)
|
||||
)
|
||||
assert (
|
||||
"Please run 'sqlmodel migrations migrate'" in result.stdout
|
||||
or "Please run 'sqlmodel migrations migrate'" in str(result.output)
|
||||
)
|
||||
|
||||
# Verify only one migration file exists
|
||||
migration_files = sorted(
|
||||
[str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
|
||||
)
|
||||
|
||||
assert migration_files == [
|
||||
"migrations/0001_initial_migration.py",
|
||||
]
|
||||
Reference in New Issue
Block a user