🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

This commit is contained in:
pre-commit-ci[bot]
2025-06-20 07:22:42 +00:00
parent 02d2e7da0a
commit f464ab4b9a
66 changed files with 847 additions and 701 deletions

View File

@ -1,12 +1,11 @@
import importlib import importlib
import types # Add import for types import types # Add import for types
from decimal import Decimal from decimal import Decimal
from unittest.mock import MagicMock # Keep MagicMock for type hint, though not strictly necessary for runtime
import pytest import pytest
from sqlmodel import create_engine from sqlmodel import create_engine
from ...conftest import needs_py310, PrintMock # Import PrintMock for type hint from ...conftest import PrintMock, needs_py310 # Import PrintMock for type hint
expected_calls = [ expected_calls = [
[ [
@ -45,7 +44,9 @@ def get_module(request: pytest.FixtureRequest):
return importlib.import_module(f"docs_src.advanced.decimal.{module_name}") return importlib.import_module(f"docs_src.advanced.decimal.{module_name}")
def test_tutorial(print_mock: PrintMock, module: types.ModuleType): # Use PrintMock for type hint and types.ModuleType def test_tutorial(
print_mock: PrintMock, module: types.ModuleType
): # Use PrintMock for type hint and types.ModuleType
module.sqlite_url = "sqlite://" module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url) module.engine = create_engine(module.sqlite_url)
module.main() module.main()

View File

@ -69,9 +69,7 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.delete.{module_name}")
f"docs_src.tutorial.connect.delete.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod

View File

@ -49,9 +49,7 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.insert.{module_name}")
f"docs_src.tutorial.connect.insert.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod

View File

@ -85,9 +85,7 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod

View File

@ -59,9 +59,7 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod

View File

@ -61,9 +61,7 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod

View File

@ -60,14 +60,14 @@ expected_calls = [
) )
def get_module(request: pytest.FixtureRequest) -> ModuleType: def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param module_name = request.param
mod = importlib.import_module( mod = importlib.import_module(f"docs_src.tutorial.connect.update.{module_name}")
f"docs_src.tutorial.connect.update.{module_name}"
)
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) mod.engine = create_engine(mod.sqlite_url)
return mod return mod
def test_tutorial(clear_sqlmodel: Any, print_mock: PrintMock, module: ModuleType) -> None: def test_tutorial(
clear_sqlmodel: Any, print_mock: PrintMock, module: ModuleType
) -> None:
module.main() module.main()
assert print_mock.calls == expected_calls assert print_mock.calls == expected_calls

View File

@ -10,6 +10,7 @@ from sqlmodel.pool import StaticPool # Keep this for session_fixture
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
# This will be our parametrized fixture providing the versioned 'main' module # This will be our parametrized fixture providing the versioned 'main' module
@pytest.fixture( @pytest.fixture(
name="module", name="module",
@ -20,7 +21,9 @@ from ....conftest import needs_py39, needs_py310
pytest.param("tutorial001_py310", marks=needs_py310), pytest.param("tutorial001_py310", marks=needs_py310),
], ],
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType: # clear_sqlmodel is autouse def get_module(
request: pytest.FixtureRequest, clear_sqlmodel: Any
) -> ModuleType: # clear_sqlmodel is autouse
module_name = f"docs_src.tutorial.fastapi.app_testing.{request.param}.main" module_name = f"docs_src.tutorial.fastapi.app_testing.{request.param}.main"
# Forcing reload to try to get a fresh state for models # Forcing reload to try to get a fresh state for models
@ -30,6 +33,7 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
return module return module
@pytest.fixture(name="session", scope="function") @pytest.fixture(name="session", scope="function")
def session_fixture(module: ModuleType) -> Generator[Session, None, None]: def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
# Store original engine-related attributes from the module # Store original engine-related attributes from the module
@ -45,7 +49,7 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
test_engine = create_engine( test_engine = create_engine(
module.sqlite_url, module.sqlite_url,
connect_args=module.connect_args, connect_args=module.connect_args,
poolclass=StaticPool # Recommended for tests poolclass=StaticPool, # Recommended for tests
) )
module.engine = test_engine module.engine = test_engine
@ -55,7 +59,9 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
# Fallback if the function isn't named create_db_and_tables # Fallback if the function isn't named create_db_and_tables
SQLModel.metadata.create_all(module.engine) SQLModel.metadata.create_all(module.engine)
with Session(module.engine) as session: # Use the module's (now test-configured) engine with Session(
module.engine
) as session: # Use the module's (now test-configured) engine
yield session yield session
# Teardown: drop tables from the module's engine # Teardown: drop tables from the module's engine
@ -74,7 +80,9 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
@pytest.fixture(name="client", scope="function") @pytest.fixture(name="client", scope="function")
def client_fixture(session: Session, module: ModuleType) -> Generator[TestClient, None, None]: def client_fixture(
session: Session, module: ModuleType
) -> Generator[TestClient, None, None]:
def get_session_override() -> Generator[Session, None, None]: # Must be a generator def get_session_override() -> Generator[Session, None, None]: # Must be a generator
yield session yield session

View File

@ -35,18 +35,22 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
module.engine = create_engine( module.engine = create_engine(
module.sqlite_url, module.sqlite_url,
connect_args={"check_same_thread": False}, # connect_args from original main.py connect_args={"check_same_thread": False}, # connect_args from original main.py
poolclass=StaticPool poolclass=StaticPool,
) )
# Assuming the module has a create_db_and_tables or similar, or uses SQLModel.metadata directly # Assuming the module has a create_db_and_tables or similar, or uses SQLModel.metadata directly
if hasattr(module, "create_db_and_tables"): if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables() module.create_db_and_tables()
else: else:
SQLModel.metadata.create_all(module.engine) # Fallback, ensure tables are created SQLModel.metadata.create_all(
module.engine
) # Fallback, ensure tables are created
return module return module
def test_tutorial(clear_sqlmodel: Any, module: ModuleType): # clear_sqlmodel is autouse but explicit for safety def test_tutorial(
clear_sqlmodel: Any, module: ModuleType
): # clear_sqlmodel is autouse but explicit for safety
# The engine and tables are now set up by the 'module' fixture # The engine and tables are now set up by the 'module' fixture
# The app's dependency overrides for get_session will use module.engine # The app's dependency overrides for get_session will use module.engine
@ -71,7 +75,9 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType): # clear_sqlmodel is
response = client.post("/heroes/", json=hero2_data) response = client.post("/heroes/", json=hero2_data)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
hero2 = response.json() hero2 = response.json()
hero2_id = hero2["id"] # This will be the ID assigned by DB, not 9000 if 9000 is not allowed on POST hero2_id = hero2[
"id"
] # This will be the ID assigned by DB, not 9000 if 9000 is not allowed on POST
response = client.post("/heroes/", json=hero3_data) response = client.post("/heroes/", json=hero3_data)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
@ -102,7 +108,9 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType): # clear_sqlmodel is
) )
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.patch(f"/heroes/{non_existent_id_check}", json={"name": "Dragon Cube X"}) response = client.patch(
f"/heroes/{non_existent_id_check}", json={"name": "Dragon Cube X"}
)
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.delete(f"/heroes/{hero2_id}") response = client.delete(f"/heroes/{hero2_id}")

View File

@ -22,7 +22,9 @@ from ....conftest import needs_py39, needs_py310
], ],
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType: def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType:
module_name = f"docs_src.tutorial.fastapi.limit_and_offset.{request.param}" # No .main module_name = (
f"docs_src.tutorial.fastapi.limit_and_offset.{request.param}" # No .main
)
if module_name in sys.modules: if module_name in sys.modules:
module = importlib.reload(sys.modules[module_name]) module = importlib.reload(sys.modules[module_name])
else: else:
@ -31,8 +33,10 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
module.sqlite_url = "sqlite://" module.sqlite_url = "sqlite://"
module.engine = create_engine( module.engine = create_engine(
module.sqlite_url, module.sqlite_url,
connect_args={"check_same_thread": False}, # Assuming connect_args was in original mod or default connect_args={
poolclass=StaticPool "check_same_thread": False
}, # Assuming connect_args was in original mod or default
poolclass=StaticPool,
) )
if hasattr(module, "create_db_and_tables"): if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables() module.create_db_and_tables()
@ -92,7 +96,9 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
data_limit2 = response.json() data_limit2 = response.json()
assert len(data_limit2) == 2 assert len(data_limit2) == 2
assert data_limit2[0]["name"] == hero1["name"] # Compare with actual created hero data assert (
data_limit2[0]["name"] == hero1["name"]
) # Compare with actual created hero data
assert data_limit2[1]["name"] == hero2["name"] assert data_limit2[1]["name"] == hero2["name"]
response = client.get("/heroes/", params={"offset": 1}) response = client.get("/heroes/", params={"offset": 1})

View File

@ -24,7 +24,9 @@ from ....conftest import needs_py39, needs_py310
], ],
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType: def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType:
module_name = f"docs_src.tutorial.fastapi.multiple_models.{request.param}" # No .main module_name = (
f"docs_src.tutorial.fastapi.multiple_models.{request.param}" # No .main
)
if module_name in sys.modules: if module_name in sys.modules:
module = importlib.reload(sys.modules[module_name]) module = importlib.reload(sys.modules[module_name])
else: else:
@ -38,9 +40,7 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
connect_args["check_same_thread"] = False connect_args["check_same_thread"] = False
module.engine = create_engine( module.engine = create_engine(
module.sqlite_url, module.sqlite_url, connect_args=connect_args, poolclass=StaticPool
connect_args=connect_args,
poolclass=StaticPool
) )
if hasattr(module, "create_db_and_tables"): if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables() module.create_db_and_tables()
@ -80,7 +80,6 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
assert data["age"] is None assert data["age"] is None
hero2_id = data["id"] # Store actual ID hero2_id = data["id"] # Store actual ID
response = client.get("/heroes/") response = client.get("/heroes/")
data = response.json() data = response.json()
@ -95,7 +94,6 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
assert data[1]["name"] == hero2_data["name"] assert data[1]["name"] == hero2_data["name"]
assert data[1]["secret_name"] == hero2_data["secret_name"] assert data[1]["secret_name"] == hero2_data["secret_name"]
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
# OpenAPI schema check - kept as is from original test # OpenAPI schema check - kept as is from original test
@ -255,10 +253,16 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
] ]
# Convert list of dicts to list of tuples of sorted items for order-agnostic comparison # Convert list of dicts to list of tuples of sorted items for order-agnostic comparison
indexes_for_comparison = [tuple(sorted(d.items())) for d in indexes] indexes_for_comparison = [tuple(sorted(d.items())) for d in indexes]
expected_indexes_for_comparison = [tuple(sorted(d.items())) for d in expected_indexes] expected_indexes_for_comparison = [
tuple(sorted(d.items())) for d in expected_indexes
]
for index_data_tuple in expected_indexes_for_comparison: for index_data_tuple in expected_indexes_for_comparison:
assert index_data_tuple in indexes_for_comparison, f"Expected index {index_data_tuple} not found in DB indexes {indexes_for_comparison}" assert index_data_tuple in indexes_for_comparison, (
f"Expected index {index_data_tuple} not found in DB indexes {indexes_for_comparison}"
)
indexes_for_comparison.remove(index_data_tuple) indexes_for_comparison.remove(index_data_tuple)
assert len(indexes_for_comparison) == 0, f"Unexpected extra indexes found in DB: {indexes_for_comparison}" assert len(indexes_for_comparison) == 0, (
f"Unexpected extra indexes found in DB: {indexes_for_comparison}"
)

View File

@ -19,8 +19,12 @@ from ....conftest import needs_py39, needs_py310
scope="function", scope="function",
params=[ params=[
"tutorial002", # Changed to tutorial002 "tutorial002", # Changed to tutorial002
pytest.param("tutorial002_py39", marks=needs_py39), # Changed to tutorial002_py39 pytest.param(
pytest.param("tutorial002_py310", marks=needs_py310), # Changed to tutorial002_py310 "tutorial002_py39", marks=needs_py39
), # Changed to tutorial002_py39
pytest.param(
"tutorial002_py310", marks=needs_py310
), # Changed to tutorial002_py310
], ],
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType: def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType:
@ -36,9 +40,7 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
connect_args["check_same_thread"] = False connect_args["check_same_thread"] = False
module.engine = create_engine( module.engine = create_engine(
module.sqlite_url, module.sqlite_url, connect_args=connect_args, poolclass=StaticPool
connect_args=connect_args,
poolclass=StaticPool
) )
if hasattr(module, "create_db_and_tables"): if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables() module.create_db_and_tables()
@ -75,7 +77,6 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
assert data["age"] is None assert data["age"] is None
hero2_id = data["id"] hero2_id = data["id"]
response = client.get("/heroes/") response = client.get("/heroes/")
data = response.json() data = response.json()
@ -88,7 +89,6 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
assert data[1]["name"] == hero2_data["name"] assert data[1]["name"] == hero2_data["name"]
assert data[1]["secret_name"] == hero2_data["secret_name"] assert data[1]["secret_name"] == hero2_data["secret_name"]
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert response.json() == { assert response.json() == {
@ -246,10 +246,16 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
}, },
] ]
indexes_for_comparison = [tuple(sorted(d.items())) for d in indexes] indexes_for_comparison = [tuple(sorted(d.items())) for d in indexes]
expected_indexes_for_comparison = [tuple(sorted(d.items())) for d in expected_indexes] expected_indexes_for_comparison = [
tuple(sorted(d.items())) for d in expected_indexes
]
for index_data_tuple in expected_indexes_for_comparison: for index_data_tuple in expected_indexes_for_comparison:
assert index_data_tuple in indexes_for_comparison, f"Expected index {index_data_tuple} not found in DB indexes {indexes_for_comparison}" assert index_data_tuple in indexes_for_comparison, (
f"Expected index {index_data_tuple} not found in DB indexes {indexes_for_comparison}"
)
indexes_for_comparison.remove(index_data_tuple) indexes_for_comparison.remove(index_data_tuple)
assert len(indexes_for_comparison) == 0, f"Unexpected extra indexes found in DB: {indexes_for_comparison}" assert len(indexes_for_comparison) == 0, (
f"Unexpected extra indexes found in DB: {indexes_for_comparison}"
)

View File

@ -34,9 +34,7 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
connect_args["check_same_thread"] = False connect_args["check_same_thread"] = False
module.engine = create_engine( module.engine = create_engine(
module.sqlite_url, module.sqlite_url, connect_args=connect_args, poolclass=StaticPool
connect_args=connect_args,
poolclass=StaticPool
) )
if hasattr(module, "create_db_and_tables"): if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables() module.create_db_and_tables()
@ -79,7 +77,9 @@ def test_tutorial(clear_sqlmodel: Any, module: ModuleType):
# Check for a non-existent ID # Check for a non-existent ID
non_existent_id = hero1["id"] + hero2["id"] + 100 # A likely non-existent ID non_existent_id = hero1["id"] + hero2["id"] + 100 # A likely non-existent ID
response_get_non_existent = client.get(f"/heroes/{non_existent_id}") response_get_non_existent = client.get(f"/heroes/{non_existent_id}")
assert response_get_non_existent.status_code == 404, response_get_non_existent.text assert response_get_non_existent.status_code == 404, (
response_get_non_existent.text
)
response_openapi = client.get("/openapi.json") response_openapi = client.get("/openapi.json")
assert response_openapi.status_code == 200, response_openapi.text assert response_openapi.status_code == 200, response_openapi.text

View File

@ -4,9 +4,8 @@ import types
from typing import Any from typing import Any
import pytest import pytest
from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import SQLModel, create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
@ -108,7 +107,9 @@ def test_tutorial(module: types.ModuleType):
response = client.post("/heroes/", json=hero3_data) response = client.post("/heroes/", json=hero3_data)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.get("/heroes/9000") # This might fail if hero2_id is not 9000 response = client.get("/heroes/9000") # This might fail if hero2_id is not 9000
assert response.status_code == 404, response.text # Original test expects 404, this implies ID 9000 is not found after creation. This needs to align with how IDs are handled. assert response.status_code == 404, (
response.text
) # Original test expects 404, this implies ID 9000 is not found after creation. This needs to align with how IDs are handled.
response = client.get("/heroes/") response = client.get("/heroes/")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
@ -120,18 +121,25 @@ def test_tutorial(module: types.ModuleType):
data = response.json() data = response.json()
assert data["name"] == hero1_data["name"] assert data["name"] == hero1_data["name"]
# Ensure team is loaded and correct # Ensure team is loaded and correct
if "team" in data and data["team"] is not None: # Team might not be present if not correctly loaded by the endpoint if (
"team" in data and data["team"] is not None
): # Team might not be present if not correctly loaded by the endpoint
assert data["team"]["name"] == team_z_force["name"] assert data["team"]["name"] == team_z_force["name"]
elif short_module_name != "tutorial001_py310": # tutorial001_py310.py doesn't include team in HeroPublic elif (
short_module_name != "tutorial001_py310"
): # tutorial001_py310.py doesn't include team in HeroPublic
# If team is expected, this is a failure. For tutorial001 and tutorial001_py39, team should be present. # If team is expected, this is a failure. For tutorial001 and tutorial001_py39, team should be present.
assert "team" in data and data["team"] is not None, "Team data missing in hero response" assert "team" in data and data["team"] is not None, (
"Team data missing in hero response"
)
response = client.patch( response = client.patch(
f"/heroes/{hero2_id}", json={"secret_name": "Spider-Youngster"} f"/heroes/{hero2_id}", json={"secret_name": "Spider-Youngster"}
) )
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.patch("/heroes/9001", json={"name": "Dragon Cube X"}) # Test patching non-existent hero response = client.patch(
"/heroes/9001", json={"name": "Dragon Cube X"}
) # Test patching non-existent hero
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.delete(f"/heroes/{hero2_id}") response = client.delete(f"/heroes/{hero2_id}")
@ -177,10 +185,17 @@ def test_tutorial(module: types.ModuleType):
# short_module_name is already defined at the start of the 'with TestClient' block # short_module_name is already defined at the start of the 'with TestClient' block
# All versions (base, py39, py310) use HeroPublicWithTeam for this endpoint based on previous test run. # All versions (base, py39, py310) use HeroPublicWithTeam for this endpoint based on previous test run.
assert get_hero_path["responses"]["200"]["content"]["application/json"]["schema"]["$ref"] == "#/components/schemas/HeroPublicWithTeam" assert (
get_hero_path["responses"]["200"]["content"]["application/json"]["schema"][
"$ref"
]
== "#/components/schemas/HeroPublicWithTeam"
)
# Check HeroCreate schema for age and team_id nullability based on IsDict usage in original # Check HeroCreate schema for age and team_id nullability based on IsDict usage in original
hero_create_props = openapi_schema["components"]["schemas"]["HeroCreate"]["properties"] hero_create_props = openapi_schema["components"]["schemas"]["HeroCreate"][
"properties"
]
# For Pydantic v2 style (anyOf with type and null) vs Pydantic v1 (just type, optionality by not being in required) # For Pydantic v2 style (anyOf with type and null) vs Pydantic v1 (just type, optionality by not being in required)
# This test was written with IsDict which complicates exact schema matching without knowing SQLModel version's Pydantic interaction # This test was written with IsDict which complicates exact schema matching without knowing SQLModel version's Pydantic interaction
# For simplicity, we check if 'age' and 'team_id' are present. Detailed check would need to adapt to SQLModel's Pydantic version. # For simplicity, we check if 'age' and 'team_id' are present. Detailed check would need to adapt to SQLModel's Pydantic version.
@ -203,11 +218,19 @@ def test_tutorial(module: types.ModuleType):
# It's better to check for key components and structures. # It's better to check for key components and structures.
# Check if TeamPublicWithHeroes has heroes list # Check if TeamPublicWithHeroes has heroes list
team_public_with_heroes_props = openapi_schema["components"]["schemas"]["TeamPublicWithHeroes"]["properties"] team_public_with_heroes_props = openapi_schema["components"]["schemas"][
"TeamPublicWithHeroes"
]["properties"]
assert "heroes" in team_public_with_heroes_props assert "heroes" in team_public_with_heroes_props
assert team_public_with_heroes_props["heroes"]["type"] == "array" assert team_public_with_heroes_props["heroes"]["type"] == "array"
# short_module_name is already defined # short_module_name is already defined
if short_module_name == "tutorial001_py310": if short_module_name == "tutorial001_py310":
assert team_public_with_heroes_props["heroes"]["items"]["$ref"] == "#/components/schemas/HeroPublic" # tutorial001_py310 uses HeroPublic for heroes list assert (
team_public_with_heroes_props["heroes"]["items"]["$ref"]
== "#/components/schemas/HeroPublic"
) # tutorial001_py310 uses HeroPublic for heroes list
else: else:
assert team_public_with_heroes_props["heroes"]["items"]["$ref"] == "#/components/schemas/HeroPublic" # Original tutorial001.py seems to imply HeroPublic as well. assert (
team_public_with_heroes_props["heroes"]["items"]["$ref"]
== "#/components/schemas/HeroPublic"
) # Original tutorial001.py seems to imply HeroPublic as well.

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import SQLModel, create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
@ -94,7 +94,9 @@ def test_tutorial(module: types.ModuleType):
# Given typical auto-increment, ID 9000 for hero2 is unlikely unless DB is reset and hero2 is first entry. # Given typical auto-increment, ID 9000 for hero2 is unlikely unless DB is reset and hero2 is first entry.
# The original test implies hero2_data's ID is not necessarily the created ID. # The original test implies hero2_data's ID is not necessarily the created ID.
response = client.get("/heroes/9000") # Check for a potentially non-existent ID response = client.get("/heroes/9000") # Check for a potentially non-existent ID
assert response.status_code == 404, response.text # Expect 404 if 9000 is not hero2_id and not another hero's ID assert response.status_code == 404, (
response.text
) # Expect 404 if 9000 is not hero2_id and not another hero's ID
response = client.get("/heroes/") response = client.get("/heroes/")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
@ -106,7 +108,9 @@ def test_tutorial(module: types.ModuleType):
) )
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.patch("/heroes/9001", json={"name": "Dragon Cube X"}) # Non-existent ID response = client.patch(
"/heroes/9001", json={"name": "Dragon Cube X"}
) # Non-existent ID
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.delete(f"/heroes/{hero2_id}") response = client.delete(f"/heroes/{hero2_id}")
@ -117,7 +121,9 @@ def test_tutorial(module: types.ModuleType):
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
response = client.delete("/heroes/9000") # Non-existent ID (same as the GET check) response = client.delete(
"/heroes/9000"
) # Non-existent ID (same as the GET check)
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.get("/openapi.json") response = client.get("/openapi.json")

View File

@ -6,12 +6,14 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
# Adjust the import path based on the file's new location or structure # Adjust the import path based on the file's new location or structure
# Assuming conftest.py is located at tests/conftest.py # Assuming conftest.py is located at tests/conftest.py
from ....conftest import needs_py310 # This needs to be relative to this file's location from ....conftest import (
needs_py310, # This needs to be relative to this file's location
)
@pytest.fixture( @pytest.fixture(
@ -23,9 +25,7 @@ from ....conftest import needs_py310 # This needs to be relative to this file's
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any): def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = ( full_module_name = f"docs_src.tutorial.fastapi.simple_hero_api.{module_name}"
f"docs_src.tutorial.fastapi.simple_hero_api.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])
@ -48,7 +48,9 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any):
return mod return mod
def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used by get_module def test_tutorial(
module: types.ModuleType,
): # clear_sqlmodel is implicitly used by get_module
with TestClient(module.app) as client: with TestClient(module.app) as client:
hero1_data = {"name": "Deadpond", "secret_name": "Dive Wilson"} hero1_data = {"name": "Deadpond", "secret_name": "Dive Wilson"}
hero2_data = { hero2_data = {

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
@ -44,7 +44,9 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any):
return mod return mod
def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used by get_module def test_tutorial(
module: types.ModuleType,
): # clear_sqlmodel is implicitly used by get_module
with TestClient(module.app) as client: with TestClient(module.app) as client:
# Hero Operations # Hero Operations
hero1_data = {"name": "Deadpond", "secret_name": "Dive Wilson"} hero1_data = {"name": "Deadpond", "secret_name": "Dive Wilson"}
@ -69,8 +71,10 @@ def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used
response = client.get(f"/heroes/{hero2_id}") # Use DB generated ID response = client.get(f"/heroes/{hero2_id}") # Use DB generated ID
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.get("/heroes/9000") # Check for ID 9000 specifically (could be hero2_id or not) response = client.get(
if hero2_id == 9000 : # If hero2 got ID 9000 "/heroes/9000"
) # Check for ID 9000 specifically (could be hero2_id or not)
if hero2_id == 9000: # If hero2 got ID 9000
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
else: # If hero2 got a different ID, then 9000 should not exist else: # If hero2 got a different ID, then 9000 should not exist
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
@ -80,10 +84,14 @@ def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used
data = response.json() data = response.json()
assert len(data) == 3 assert len(data) == 3
response = client.patch(f"/heroes/{hero2_id}", json={"secret_name": "Spider-Youngster"}) response = client.patch(
f"/heroes/{hero2_id}", json={"secret_name": "Spider-Youngster"}
)
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
response = client.patch("/heroes/9001", json={"name": "Dragon Cube X"}) # Non-existent ID response = client.patch(
"/heroes/9001", json={"name": "Dragon Cube X"}
) # Non-existent ID
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.delete(f"/heroes/{hero2_id}") response = client.delete(f"/heroes/{hero2_id}")
@ -95,12 +103,18 @@ def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used
assert len(data) == 2 assert len(data) == 2
response = client.delete("/heroes/9000") # Try deleting ID 9000 response = client.delete("/heroes/9000") # Try deleting ID 9000
if hero2_id == 9000 and hero2_id not in [h["id"] for h in data]: # If it was hero2's ID and hero2 was deleted if hero2_id == 9000 and hero2_id not in [
h["id"] for h in data
]: # If it was hero2's ID and hero2 was deleted
assert response.status_code == 404 # Already deleted assert response.status_code == 404 # Already deleted
elif hero2_id != 9000 and 9000 not in [h["id"] for h in data]: # If 9000 was never a valid ID among current heroes elif hero2_id != 9000 and 9000 not in [
h["id"] for h in data
]: # If 9000 was never a valid ID among current heroes
assert response.status_code == 404 assert response.status_code == 404
else: # If 9000 was a valid ID of another hero still present (should not happen with current data) else: # If 9000 was a valid ID of another hero still present (should not happen with current data)
assert response.status_code == 200 # This case is unlikely with current test data assert (
response.status_code == 200
) # This case is unlikely with current test data
# Team Operations # Team Operations
team_preventers_data = {"name": "Preventers", "headquarters": "Sharp Tower"} team_preventers_data = {"name": "Preventers", "headquarters": "Sharp Tower"}
@ -139,7 +153,9 @@ def test_tutorial(module: types.ModuleType): # clear_sqlmodel is implicitly used
assert data["name"] == team_preventers_data["name"] # Name should be unchanged assert data["name"] == team_preventers_data["name"] # Name should be unchanged
assert data["headquarters"] == "Preventers Tower" assert data["headquarters"] == "Preventers Tower"
response = client.patch("/teams/9000", json={"name": "Freedom League"}) # Non-existent response = client.patch(
"/teams/9000", json={"name": "Freedom League"}
) # Non-existent
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.delete(f"/teams/{team_preventers_id}") response = client.delete(f"/teams/{team_preventers_id}")

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
@ -93,7 +93,9 @@ def test_tutorial(module: types.ModuleType):
) )
data = response.json() data = response.json()
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
assert data["name"] == hero2_created["name"] # Name should not change from created state assert (
data["name"] == hero2_created["name"]
) # Name should not change from created state
assert data["secret_name"] == "Spider-Youngster" assert data["secret_name"] == "Spider-Youngster"
response = client.patch(f"/heroes/{hero3_id}", json={"age": None}) response = client.patch(f"/heroes/{hero3_id}", json={"age": None})
@ -102,7 +104,9 @@ def test_tutorial(module: types.ModuleType):
assert data["name"] == hero3_created["name"] assert data["name"] == hero3_created["name"]
assert data["age"] is None assert data["age"] is None
response = client.patch("/heroes/9001", json={"name": "Dragon Cube X"}) # Non-existent ID response = client.patch(
"/heroes/9001", json={"name": "Dragon Cube X"}
) # Non-existent ID
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.get("/openapi.json") response = client.get("/openapi.json")
@ -356,7 +360,10 @@ def test_tutorial(module: types.ModuleType):
} }
) )
| IsDict( | IsDict(
{"title": "Secret Name", "type": "string"} # Pydantic v1 {
"title": "Secret Name",
"type": "string",
} # Pydantic v1
), ),
"age": IsDict( "age": IsDict(
{ {

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from dirty_equals import IsDict from dirty_equals import IsDict
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlmodel import create_engine, SQLModel, Session from sqlmodel import Session, create_engine
from sqlmodel.pool import StaticPool from sqlmodel.pool import StaticPool
from ....conftest import needs_py39, needs_py310 from ....conftest import needs_py39, needs_py310
@ -102,7 +102,9 @@ def test_tutorial(module: types.ModuleType):
with Session(module.engine) as session: with Session(module.engine) as session:
hero1_db = session.get(module.Hero, hero1_id) hero1_db = session.get(module.Hero, hero1_id)
assert hero1_db assert hero1_db
assert not hasattr(hero1_db, "password") # Model should not have 'password' field after read from DB assert not hasattr(
hero1_db, "password"
) # Model should not have 'password' field after read from DB
assert hero1_db.hashed_password == "not really hashed chimichanga hehehe" assert hero1_db.hashed_password == "not really hashed chimichanga hehehe"
hero2_db = session.get(module.Hero, hero2_id) hero2_db = session.get(module.Hero, hero2_id)
@ -128,7 +130,9 @@ def test_tutorial(module: types.ModuleType):
hero2b_db = session.get(module.Hero, hero2_id) hero2b_db = session.get(module.Hero, hero2_id)
assert hero2b_db assert hero2b_db
assert not hasattr(hero2b_db, "password") assert not hasattr(hero2b_db, "password")
assert hero2b_db.hashed_password == "not really hashed auntmay hehehe" # Password shouldn't change on this patch assert (
hero2b_db.hashed_password == "not really hashed auntmay hehehe"
) # Password shouldn't change on this patch
response = client.patch(f"/heroes/{hero3_id}", json={"age": None}) response = client.patch(f"/heroes/{hero3_id}", json={"age": None})
data = response.json() data = response.json()
@ -156,15 +160,21 @@ def test_tutorial(module: types.ModuleType):
hero3c_db = session.get(module.Hero, hero3_id) # Renamed to avoid confusion hero3c_db = session.get(module.Hero, hero3_id) # Renamed to avoid confusion
assert hero3c_db assert hero3c_db
assert not hasattr(hero3c_db, "password") assert not hasattr(hero3c_db, "password")
assert hero3c_db.hashed_password == "not really hashed philantroplayboy hehehe" assert (
hero3c_db.hashed_password == "not really hashed philantroplayboy hehehe"
)
response = client.patch("/heroes/9001", json={"name": "Dragon Cube X"}) # Non-existent response = client.patch(
"/heroes/9001", json={"name": "Dragon Cube X"}
) # Non-existent
assert response.status_code == 404, response.text assert response.status_code == 404, response.text
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200, response.text assert response.status_code == 200, response.text
# OpenAPI schema is consistent # OpenAPI schema is consistent
assert response.json() == { assert (
response.json()
== {
"openapi": "3.1.0", "openapi": "3.1.0",
"info": {"title": "FastAPI", "version": "0.1.0"}, "info": {"title": "FastAPI", "version": "0.1.0"},
"paths": { "paths": {
@ -361,10 +371,16 @@ def test_tutorial(module: types.ModuleType):
"type": "object", "type": "object",
"properties": { "properties": {
"name": {"title": "Name", "type": "string"}, "name": {"title": "Name", "type": "string"},
"secret_name": {"title": "Secret Name", "type": "string"}, "secret_name": {
"title": "Secret Name",
"type": "string",
},
"age": IsDict( "age": IsDict(
{ {
"anyOf": [{"type": "integer"}, {"type": "null"}], "anyOf": [
{"type": "integer"},
{"type": "null"},
],
"title": "Age", "title": "Age",
} }
) )
@ -380,10 +396,16 @@ def test_tutorial(module: types.ModuleType):
"type": "object", "type": "object",
"properties": { "properties": {
"name": {"title": "Name", "type": "string"}, "name": {"title": "Name", "type": "string"},
"secret_name": {"title": "Secret Name", "type": "string"}, "secret_name": {
"title": "Secret Name",
"type": "string",
},
"age": IsDict( "age": IsDict(
{ {
"anyOf": [{"type": "integer"}, {"type": "null"}], "anyOf": [
{"type": "integer"},
{"type": "null"},
],
"title": "Age", "title": "Age",
} }
) )
@ -413,11 +435,17 @@ def test_tutorial(module: types.ModuleType):
} }
) )
| IsDict( | IsDict(
{"title": "Secret Name", "type": "string"} # Pydantic v1 {
"title": "Secret Name",
"type": "string",
} # Pydantic v1
), ),
"age": IsDict( "age": IsDict(
{ {
"anyOf": [{"type": "integer"}, {"type": "null"}], "anyOf": [
{"type": "integer"},
{"type": "null"},
],
"title": "Age", "title": "Age",
} }
) )
@ -431,7 +459,10 @@ def test_tutorial(module: types.ModuleType):
} }
) )
| IsDict( | IsDict(
{"title": "Password", "type": "string"} # Pydantic v1 {
"title": "Password",
"type": "string",
} # Pydantic v1
), ),
}, },
}, },
@ -444,7 +475,10 @@ def test_tutorial(module: types.ModuleType):
"title": "Location", "title": "Location",
"type": "array", "type": "array",
"items": { "items": {
"anyOf": [{"type": "string"}, {"type": "integer"}] "anyOf": [
{"type": "string"},
{"type": "integer"},
]
}, },
}, },
"msg": {"title": "Message", "type": "string"}, "msg": {"title": "Message", "type": "string"},
@ -454,3 +488,4 @@ def test_tutorial(module: types.ModuleType):
} }
}, },
} }
)

View File

@ -7,9 +7,11 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlmodel import create_engine, SQLModel # Added SQLModel for potential use if main doesn't create tables from sqlmodel import ( # Added SQLModel for potential use if main doesn't create tables
create_engine,
)
from ...conftest import get_testing_print_function, needs_py310, PrintMock from ...conftest import PrintMock, get_testing_print_function, needs_py310
@pytest.fixture( @pytest.fixture(
@ -19,7 +21,9 @@ from ...conftest import get_testing_print_function, needs_py310, PrintMock
pytest.param("tutorial001_py310", marks=needs_py310), pytest.param("tutorial001_py310", marks=needs_py310),
], ],
) )
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any): # clear_sqlmodel ensures fresh DB state def get_module(
request: pytest.FixtureRequest, clear_sqlmodel: Any
): # clear_sqlmodel ensures fresh DB state
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.indexes.{module_name}" full_module_name = f"docs_src.tutorial.indexes.{module_name}"
@ -31,17 +35,20 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any): # clear_sql
# These tests usually define engine in their main() or globally. # These tests usually define engine in their main() or globally.
# We'll ensure it's set up for the test a standard way. # We'll ensure it's set up for the test a standard way.
mod.sqlite_url = "sqlite://" mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url) # connect_args not typically in these non-FastAPI examples mod.engine = create_engine(
mod.sqlite_url
) # connect_args not typically in these non-FastAPI examples
# Ensure tables are created. Some tutorials do it in main, others expect it externally. # Ensure tables are created. Some tutorials do it in main, others expect it externally.
# If mod.main() is expected to create tables, this might be redundant but safe. # If mod.main() is expected to create tables, this might be redundant but safe.
# If Hero model is defined globally, SQLModel.metadata.create_all(mod.engine) can be used. # If Hero model is defined globally, SQLModel.metadata.create_all(mod.engine) can be used.
if hasattr(mod, "Hero") and hasattr(mod.Hero, "metadata"): if hasattr(mod, "Hero") and hasattr(mod.Hero, "metadata"):
mod.Hero.metadata.create_all(mod.engine) mod.Hero.metadata.create_all(mod.engine)
elif hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"): # Fallback if Hero specific metadata not found elif hasattr(mod, "SQLModel") and hasattr(
mod.SQLModel, "metadata"
): # Fallback if Hero specific metadata not found
mod.SQLModel.metadata.create_all(mod.engine) mod.SQLModel.metadata.create_all(mod.engine)
return mod return mod
@ -83,23 +90,30 @@ def test_tutorial(print_mock: PrintMock, module: types.ModuleType):
found_indexes_simplified = [] found_indexes_simplified = []
for index in indexes: for index in indexes:
found_indexes_simplified.append({ found_indexes_simplified.append(
{
"name": index["name"], "name": index["name"],
"column_names": sorted(index["column_names"]), # Sort for consistency "column_names": sorted(index["column_names"]), # Sort for consistency
"unique": index["unique"], "unique": index["unique"],
# Not including dialect_options as it can vary or be empty # Not including dialect_options as it can vary or be empty
}) }
)
expected_indexes_simplified = [] expected_indexes_simplified = []
for index in expected_indexes: for index in expected_indexes:
expected_indexes_simplified.append({ expected_indexes_simplified.append(
{
"name": index["name"], "name": index["name"],
"column_names": sorted(index["column_names"]), "column_names": sorted(index["column_names"]),
"unique": index["unique"], "unique": index["unique"],
}) }
)
for expected_index in expected_indexes_simplified: for expected_index in expected_indexes_simplified:
assert expected_index in found_indexes_simplified, f"Expected index {expected_index['name']} not found or mismatch." assert expected_index in found_indexes_simplified, (
f"Expected index {expected_index['name']} not found or mismatch."
)
assert len(found_indexes_simplified) == len(expected_indexes_simplified), \ assert len(found_indexes_simplified) == len(expected_indexes_simplified), (
f"Mismatch in number of indexes. Found: {len(found_indexes_simplified)}, Expected: {len(expected_indexes_simplified)}" f"Mismatch in number of indexes. Found: {len(found_indexes_simplified)}, Expected: {len(expected_indexes_simplified)}"
)

View File

@ -7,9 +7,9 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlmodel import create_engine, SQLModel # Added SQLModel from sqlmodel import create_engine # Added SQLModel
from ...conftest import get_testing_print_function, needs_py310, PrintMock from ...conftest import PrintMock, get_testing_print_function, needs_py310
@pytest.fixture( @pytest.fixture(
@ -69,22 +69,29 @@ def test_tutorial(print_mock: PrintMock, module: types.ModuleType):
found_indexes_simplified = [] found_indexes_simplified = []
for index in indexes: for index in indexes:
found_indexes_simplified.append({ found_indexes_simplified.append(
{
"name": index["name"], "name": index["name"],
"column_names": sorted(index["column_names"]), "column_names": sorted(index["column_names"]),
"unique": index["unique"], "unique": index["unique"],
}) }
)
expected_indexes_simplified = [] expected_indexes_simplified = []
for index in expected_indexes: for index in expected_indexes:
expected_indexes_simplified.append({ expected_indexes_simplified.append(
{
"name": index["name"], "name": index["name"],
"column_names": sorted(index["column_names"]), "column_names": sorted(index["column_names"]),
"unique": index["unique"], "unique": index["unique"],
}) }
)
for expected_index in expected_indexes_simplified: for expected_index in expected_indexes_simplified:
assert expected_index in found_indexes_simplified, f"Expected index {expected_index['name']} not found or mismatch." assert expected_index in found_indexes_simplified, (
f"Expected index {expected_index['name']} not found or mismatch."
)
assert len(found_indexes_simplified) == len(expected_indexes_simplified), \ assert len(found_indexes_simplified) == len(expected_indexes_simplified), (
f"Mismatch in number of indexes. Found: {len(found_indexes_simplified)}, Expected: {len(expected_indexes_simplified)}" f"Mismatch in number of indexes. Found: {len(found_indexes_simplified)}, Expected: {len(expected_indexes_simplified)}"
)

View File

@ -4,7 +4,11 @@ import types
from typing import Any from typing import Any
import pytest import pytest
from sqlmodel import create_engine, SQLModel, Session, select # Ensure all necessary SQLModel parts are imported from sqlmodel import ( # Ensure all necessary SQLModel parts are imported
Session,
create_engine,
select,
)
from ...conftest import needs_py310 # Adjusted for typical conftest location from ...conftest import needs_py310 # Adjusted for typical conftest location
@ -39,7 +43,9 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any):
return mod return mod
def test_tutorial(module: types.ModuleType, clear_sqlmodel: Any): # clear_sqlmodel still useful for DB state def test_tutorial(
module: types.ModuleType, clear_sqlmodel: Any
): # clear_sqlmodel still useful for DB state
# If module.main() is responsible for creating data and potentially tables, call it. # If module.main() is responsible for creating data and potentially tables, call it.
# The fixture get_module now ensures the engine is set and tables are created if models are defined. # The fixture get_module now ensures the engine is set and tables are created if models are defined.
# If main() also sets up engine/tables, ensure it's idempotent or adjust. # If main() also sets up engine/tables, ensure it's idempotent or adjust.

View File

@ -4,9 +4,9 @@ import types
from typing import Any from typing import Any
import pytest import pytest
from sqlmodel import create_engine, SQLModel, Session, select from sqlmodel import Session, SQLModel, create_engine, select
from ...conftest import needs_py310, clear_sqlmodel as clear_sqlmodel_fixture # Use aliased import from ...conftest import needs_py310 # Use aliased import
@pytest.fixture( @pytest.fixture(
@ -76,7 +76,9 @@ def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel_fixture: Any):
return mod_tut002 return mod_tut002
def test_tutorial(module: types.ModuleType, clear_sqlmodel_fixture: Any): # `module` is tutorial002 with .Team attached def test_tutorial(
module: types.ModuleType, clear_sqlmodel_fixture: Any
): # `module` is tutorial002 with .Team attached
module.main() # Executes the tutorial002's data insertion logic module.main() # Executes the tutorial002's data insertion logic
with Session(module.engine) as session: with Session(module.engine) as session:
@ -88,7 +90,9 @@ def test_tutorial(module: types.ModuleType, clear_sqlmodel_fixture: Any): # `mod
select(module.Team).where(module.Team.name == "Preventers") select(module.Team).where(module.Team.name == "Preventers")
).one() ).one()
assert hero_spider_boy.team_id == team_preventers.id assert hero_spider_boy.team_id == team_preventers.id
assert hero_spider_boy.team == team_preventers # This checks the relationship resolves assert (
hero_spider_boy.team == team_preventers
) # This checks the relationship resolves
heroes = session.exec(select(module.Hero)).all() heroes = session.exec(select(module.Hero)).all()

View File

@ -4,7 +4,7 @@ import types
from typing import Any from typing import Any
import pytest import pytest
from sqlmodel import create_engine, SQLModel, Session, select from sqlmodel import Session, create_engine, select
from ...conftest import needs_py310 from ...conftest import needs_py310
@ -32,9 +32,13 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any):
# It's likely main() handles this. If not, direct creation is a fallback. # It's likely main() handles this. If not, direct creation is a fallback.
if hasattr(mod, "create_db_and_tables"): # Some tutorials use this helper if hasattr(mod, "create_db_and_tables"): # Some tutorials use this helper
mod.create_db_and_tables() mod.create_db_and_tables()
elif hasattr(mod, "Hero") and hasattr(mod.Hero, "metadata"): # Check for Hero model metadata elif hasattr(mod, "Hero") and hasattr(
mod.Hero, "metadata"
): # Check for Hero model metadata
mod.Hero.metadata.create_all(mod.engine) mod.Hero.metadata.create_all(mod.engine)
elif hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"): # Generic fallback elif hasattr(mod, "SQLModel") and hasattr(
mod.SQLModel, "metadata"
): # Generic fallback
mod.SQLModel.metadata.create_all(mod.engine) mod.SQLModel.metadata.create_all(mod.engine)
return mod return mod

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel # Added SQLModel for table creation from sqlmodel import create_engine # Added SQLModel for table creation
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial001 = [ # Renamed to be specific expected_calls_tutorial001 = [ # Renamed to be specific
[ [
@ -33,7 +32,9 @@ expected_calls_tutorial001 = [ # Renamed to be specific
pytest.param("tutorial001_py310", marks=needs_py310), pytest.param("tutorial001_py310", marks=needs_py310),
], ],
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): # Changed name for clarity def module_fixture(
request: pytest.FixtureRequest, clear_sqlmodel: Any
): # Changed name for clarity
module_name = request.param module_name = request.param
# Corrected module path # Corrected module path
full_module_name = f"docs_src.tutorial.offset_and_limit.{module_name}" full_module_name = f"docs_src.tutorial.offset_and_limit.{module_name}"

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial002 = [ # Renamed for specificity expected_calls_tutorial002 = [ # Renamed for specificity
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial003 = [ # Renamed for specificity expected_calls_tutorial003 = [ # Renamed for specificity
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial004 = [ # Renamed for specificity expected_calls_tutorial004 = [ # Renamed for specificity
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ # Renamed for specificity expected_calls_tutorial001 = [ # Renamed for specificity
[ [
@ -68,7 +67,9 @@ def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
# We assume it's called by main() or the test setup is fine if it's not explicitly called here. # We assume it's called by main() or the test setup is fine if it's not explicitly called here.
pass pass
elif hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"): elif hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"):
mod.SQLModel.metadata.create_all(mod.engine) # Create all tables known to this module's metadata mod.SQLModel.metadata.create_all(
mod.engine
) # Create all tables known to this module's metadata
return mod return mod

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial002 = [ # Renamed for specificity expected_calls_tutorial002 = [ # Renamed for specificity
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial003 = [ # Renamed for specificity expected_calls_tutorial003 = [ # Renamed for specificity
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel # Added SQLModel from sqlmodel import create_engine # Added SQLModel
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial002 = [["Hero:", None]] expected_calls_tutorial002 = [["Hero:", None]]

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial003 = [ expected_calls_tutorial003 = [
[ [

View File

@ -6,10 +6,13 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy.exc import MultipleResultsFound # Keep this import from sqlalchemy.exc import MultipleResultsFound # Keep this import
from sqlmodel import create_engine, SQLModel, Session, delete # Ensure Session and delete are imported from sqlmodel import ( # Ensure Session and delete are imported
Session,
from ...conftest import get_testing_print_function, needs_py310, PrintMock create_engine,
delete,
)
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial004 = [ expected_calls_tutorial004 = [
[ [
@ -69,7 +72,9 @@ def test_tutorial(module: types.ModuleType, print_mock: PrintMock, clear_sqlmode
with Session(module.engine) as session: with Session(module.engine) as session:
# The delete statement needs the actual Hero class from the module # The delete statement needs the actual Hero class from the module
session.exec(delete(module.Hero)) session.exec(delete(module.Hero))
session.add(module.Hero(name="Test Hero", secret_name="Secret Test Hero", age=24)) session.add(
module.Hero(name="Test Hero", secret_name="Secret Test Hero", age=24)
)
session.commit() session.commit()
# Now, test the select_heroes function part # Now, test the select_heroes function part

View File

@ -6,10 +6,13 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy.exc import NoResultFound # Keep this import from sqlalchemy.exc import NoResultFound # Keep this import
from sqlmodel import create_engine, SQLModel, Session, delete # Ensure Session and delete from sqlmodel import ( # Ensure Session and delete
Session,
from ...conftest import get_testing_print_function, needs_py310, PrintMock create_engine,
delete,
)
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial005 = [ expected_calls_tutorial005 = [
[ [
@ -74,8 +77,12 @@ def test_tutorial(module: types.ModuleType, print_mock: PrintMock, clear_sqlmode
# Phase 2: Test select_heroes() after manually adding a hero # Phase 2: Test select_heroes() after manually adding a hero
# This part matches the original test's logic after the expected exception. # This part matches the original test's logic after the expected exception.
with Session(module.engine) as session: with Session(module.engine) as session:
session.exec(delete(module.Hero)) # Clear any heroes if main() somehow added them session.exec(
session.add(module.Hero(name="Test Hero", secret_name="Secret Test Hero", age=24)) delete(module.Hero)
) # Clear any heroes if main() somehow added them
session.add(
module.Hero(name="Test Hero", secret_name="Secret Test Hero", age=24)
)
session.commit() session.commit()
with patch("builtins.print", new=get_testing_print_function(print_mock.calls)): with patch("builtins.print", new=get_testing_print_function(print_mock.calls)):

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial006 = [ expected_calls_tutorial006 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial007 = [ expected_calls_tutorial007 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial008 = [ expected_calls_tutorial008 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial009 = [["Hero:", None]] expected_calls_tutorial009 = [["Hero:", None]]

View File

@ -6,10 +6,9 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy.exc import SAWarning # Keep this import from sqlalchemy.exc import SAWarning # Keep this import
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [
@ -287,7 +286,9 @@ expected_calls_tutorial001 = [
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}" full_module_name = (
f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])

View File

@ -5,11 +5,11 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
# SAWarning is not expected in this tutorial's test, so not importing it from sqlalchemy.exc # SAWarning is not expected in this tutorial's test, so not importing it from sqlalchemy.exc
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial002 = [ expected_calls_tutorial002 = [
[ [
@ -280,7 +280,9 @@ expected_calls_tutorial002 = [
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}" full_module_name = (
f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])

View File

@ -6,7 +6,7 @@ from typing import Any
import pytest import pytest
from sqlalchemy import inspect # Keep this from sqlalchemy import inspect # Keep this
from sqlalchemy.engine.reflection import Inspector # Keep this from sqlalchemy.engine.reflection import Inspector # Keep this
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import needs_py39, needs_py310 # Keep conftest imports from ....conftest import needs_py39, needs_py310 # Keep conftest imports
@ -21,7 +21,9 @@ from ....conftest import needs_py39, needs_py310 # Keep conftest imports
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}" full_module_name = (
f"docs_src.tutorial.relationship_attributes.back_populates.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])
@ -41,7 +43,9 @@ def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
return mod return mod
def test_tutorial(module: types.ModuleType, clear_sqlmodel: Any): # print_mock not needed def test_tutorial(
module: types.ModuleType, clear_sqlmodel: Any
): # print_mock not needed
# The main() function in the tutorial module is expected to create tables. # The main() function in the tutorial module is expected to create tables.
module.main() module.main()

View File

@ -5,11 +5,10 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
# Assuming conftest.py is at tests/conftest.py, the path should be ....conftest # Assuming conftest.py is at tests/conftest.py, the path should be ....conftest
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [

View File

@ -5,11 +5,10 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
# Adjust the import path based on the file's new location or structure # Adjust the import path based on the file's new location or structure
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial002 = [ expected_calls_tutorial002 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial003 = [ expected_calls_tutorial003 = [
[ [

View File

@ -6,10 +6,11 @@ from unittest.mock import patch
import pytest import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel import create_engine, SQLModel, Session, select, delete # Added Session, select, delete just in case module uses them from sqlmodel import ( # Added Session, select, delete just in case module uses them
create_engine,
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock )
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial004 = [ expected_calls_tutorial004 = [
[ [
@ -138,7 +139,9 @@ def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
# However, if other functions from module were tested independently, tables would need to exist. # However, if other functions from module were tested independently, tables would need to exist.
# For safety and consistency with other fixtures: # For safety and consistency with other fixtures:
if hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"): if hasattr(mod, "SQLModel") and hasattr(mod.SQLModel, "metadata"):
mod.SQLModel.metadata.create_all(mod.engine) # Ensure tables are there before main might use them. mod.SQLModel.metadata.create_all(
mod.engine
) # Ensure tables are there before main might use them.
return mod return mod

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial005 = [ expected_calls_tutorial005 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [
@ -106,7 +105,9 @@ expected_calls_tutorial001 = [
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.relationship_attributes.read_relationships.{module_name}" full_module_name = (
f"docs_src.tutorial.relationship_attributes.read_relationships.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ....conftest import get_testing_print_function, needs_py39, needs_py310, PrintMock
from ....conftest import PrintMock, get_testing_print_function, needs_py39, needs_py310
expected_calls_tutorial002 = [ expected_calls_tutorial002 = [
[ [
@ -148,7 +147,9 @@ expected_calls_tutorial002 = [
) )
def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any): def module_fixture(request: pytest.FixtureRequest, clear_sqlmodel: Any):
module_name = request.param module_name = request.param
full_module_name = f"docs_src.tutorial.relationship_attributes.read_relationships.{module_name}" full_module_name = (
f"docs_src.tutorial.relationship_attributes.read_relationships.{module_name}"
)
if full_module_name in sys.modules: if full_module_name in sys.modules:
mod = importlib.reload(sys.modules[full_module_name]) mod = importlib.reload(sys.modules[full_module_name])

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial001 = [ expected_calls_tutorial001 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial002 = [ expected_calls_tutorial002 = [
[ [

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
# expected_calls is defined within the test_tutorial function in the original test # expected_calls is defined within the test_tutorial function in the original test
# This is fine as it's used only there. # This is fine as it's used only there.
@ -58,7 +57,9 @@ def test_tutorial(module: types.ModuleType, print_mock: PrintMock, clear_sqlmode
], ],
] ]
# Preserve the original assertion logic # Preserve the original assertion logic
for call_item in expected_calls: # Renamed to avoid conflict with outer scope 'calls' if any for (
call_item
) in expected_calls: # Renamed to avoid conflict with outer scope 'calls' if any
assert call_item in print_mock.calls, "This expected item should be in the list" assert call_item in print_mock.calls, "This expected item should be in the list"
print_mock.calls.pop(print_mock.calls.index(call_item)) print_mock.calls.pop(print_mock.calls.index(call_item))
assert len(print_mock.calls) == 0, "The list should only have the expected items" assert len(print_mock.calls) == 0, "The list should only have the expected items"

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
# expected_calls is defined within the test_tutorial function in the original test # expected_calls is defined within the test_tutorial function in the original test

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial005 = [ expected_calls_tutorial005 = [
[{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}] [{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}]

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial006 = [ expected_calls_tutorial006 = [
[{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}], [{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}],

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial007 = [ expected_calls_tutorial007 = [
[{"id": 5, "name": "Black Lion", "secret_name": "Trevor Challa", "age": 35}], [{"id": 5, "name": "Black Lion", "secret_name": "Trevor Challa", "age": 35}],

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial008 = [ expected_calls_tutorial008 = [
[{"id": 5, "name": "Black Lion", "secret_name": "Trevor Challa", "age": 35}], [{"id": 5, "name": "Black Lion", "secret_name": "Trevor Challa", "age": 35}],

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial009 = [ expected_calls_tutorial009 = [
[{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}], [{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}],

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
expected_calls_tutorial010 = [ expected_calls_tutorial010 = [
[{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}], [{"name": "Tarantula", "secret_name": "Natalia Roman-on", "age": 32, "id": 4}],

View File

@ -5,10 +5,9 @@ from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlmodel import create_engine, SQLModel from sqlmodel import create_engine
from ...conftest import get_testing_print_function, needs_py310, PrintMock
from ...conftest import PrintMock, get_testing_print_function, needs_py310
# expected_calls is defined within the test_tutorial function in the original test # expected_calls is defined within the test_tutorial function in the original test