This commit is contained in:
Patrick Arminio
2025-11-24 19:46:20 +00:00
parent 7c98f59a23
commit e625317014
2 changed files with 28 additions and 18 deletions

View File

@@ -17,8 +17,8 @@ except ImportError:
migrations_app = typer.Typer() migrations_app = typer.Typer()
def get_models_path_from_config() -> str: def get_config_from_pyproject() -> dict:
"""Get the models path from pyproject.toml configuration.""" """Load and return the [tool.sqlmodel] configuration from pyproject.toml."""
pyproject_path = Path.cwd() / "pyproject.toml" pyproject_path = Path.cwd() / "pyproject.toml"
if not pyproject_path.exists(): if not pyproject_path.exists():
@@ -30,7 +30,7 @@ def get_models_path_from_config() -> str:
with open(pyproject_path, "rb") as f: with open(pyproject_path, "rb") as f:
config = tomllib.load(f) config = tomllib.load(f)
# Try to get models path from [tool.sqlmodel] # Try to get [tool.sqlmodel] section
if "tool" not in config or "sqlmodel" not in config["tool"]: if "tool" not in config or "sqlmodel" not in config["tool"]:
raise ValueError( raise ValueError(
"No [tool.sqlmodel] section found in pyproject.toml. " "No [tool.sqlmodel] section found in pyproject.toml. "
@@ -39,7 +39,12 @@ def get_models_path_from_config() -> str:
"models = \"your.models.path\"\n" "models = \"your.models.path\"\n"
) )
sqlmodel_config = config["tool"]["sqlmodel"] return config["tool"]["sqlmodel"]
def get_models_path_from_config() -> str:
"""Get the models path from pyproject.toml configuration."""
sqlmodel_config = get_config_from_pyproject()
if "models" not in sqlmodel_config: if "models" not in sqlmodel_config:
raise ValueError( raise ValueError(
@@ -53,9 +58,25 @@ def get_models_path_from_config() -> str:
def get_migrations_dir(migrations_path: Optional[str] = None) -> Path: def get_migrations_dir(migrations_path: Optional[str] = None) -> Path:
"""Get the migrations directory path.""" """Get the migrations directory path.
Priority:
1. Explicit migrations_path parameter
2. migrations_path in [tool.sqlmodel] in pyproject.toml
3. Default to ./migrations
"""
if migrations_path: if migrations_path:
return Path(migrations_path) return Path(migrations_path)
# Try to get from config
try:
sqlmodel_config = get_config_from_pyproject()
if "migrations_path" in sqlmodel_config:
return Path(sqlmodel_config["migrations_path"])
except ValueError:
# No pyproject.toml or no [tool.sqlmodel] section, use default
pass
return Path.cwd() / "migrations" return Path.cwd() / "migrations"

View File

@@ -45,9 +45,10 @@ def migration_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> MigrationT
shutil.copy(model_source, models_file) shutil.copy(model_source, models_file)
# Create pyproject.toml with [tool.sqlmodel] configuration # Create pyproject.toml with [tool.sqlmodel] configuration
pyproject_content = """\ pyproject_content = f"""\
[tool.sqlmodel] [tool.sqlmodel]
models = "test_models.models" models = "test_models.models"
migrations_path = "{migrations_dir}"
""" """
(tmp_path / "pyproject.toml").write_text(pyproject_content) (tmp_path / "pyproject.toml").write_text(pyproject_content)
@@ -74,8 +75,6 @@ def test_create_first_migration(migration_env: MigrationTestEnv):
"create", "create",
"-m", "-m",
"Initial migration", "Initial migration",
"--path",
str(migration_env.migrations_dir),
], ],
) )
@@ -110,8 +109,6 @@ def test_running_migration_twice_only_generates_migration_once(
"create", "create",
"-m", "-m",
"Initial migration", "Initial migration",
"--path",
str(migration_env.migrations_dir),
], ],
) )
@@ -124,8 +121,6 @@ def test_running_migration_twice_only_generates_migration_once(
[ [
"migrations", "migrations",
"migrate", "migrate",
"--path",
str(migration_env.migrations_dir),
], ],
) )
@@ -139,8 +134,6 @@ def test_running_migration_twice_only_generates_migration_once(
"create", "create",
"-m", "-m",
"Initial migration", "Initial migration",
"--path",
str(migration_env.migrations_dir),
], ],
) )
@@ -168,8 +161,6 @@ def test_cannot_create_migration_with_pending_migrations(
"create", "create",
"-m", "-m",
"Initial migration", "Initial migration",
"--path",
str(migration_env.migrations_dir),
], ],
) )
@@ -184,8 +175,6 @@ def test_cannot_create_migration_with_pending_migrations(
"create", "create",
"-m", "-m",
"Second migration", "Second migration",
"--path",
str(migration_env.migrations_dir),
], ],
) )