mirror of
https://github.com/fastapi/sqlmodel.git
synced 2026-03-13 09:29:54 +08:00
🐛 Fix support for Annotated fields with Pydantic 2.12+ (#1607)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> Co-authored-by: svlandeg <svlandeg@github.com> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import ipaddress
|
||||
import uuid
|
||||
import weakref
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
@@ -200,6 +201,38 @@ class RelationshipInfo(Representation):
|
||||
self.sa_relationship_kwargs = sa_relationship_kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldInfoMetadata:
|
||||
primary_key: Union[bool, UndefinedType] = Undefined
|
||||
nullable: Union[bool, UndefinedType] = Undefined
|
||||
foreign_key: Any = Undefined
|
||||
ondelete: Union[OnDeleteType, UndefinedType] = Undefined
|
||||
unique: Union[bool, UndefinedType] = Undefined
|
||||
index: Union[bool, UndefinedType] = Undefined
|
||||
sa_type: Union[type[Any], UndefinedType] = Undefined
|
||||
sa_column: Union[Column[Any], UndefinedType] = Undefined
|
||||
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
|
||||
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
|
||||
|
||||
|
||||
def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
|
||||
metadata_items = getattr(field_info, "metadata", None)
|
||||
if metadata_items:
|
||||
for meta in metadata_items:
|
||||
if isinstance(meta, FieldInfoMetadata):
|
||||
return meta
|
||||
return None
|
||||
|
||||
|
||||
def _get_sqlmodel_field_value(
|
||||
field_info: Any, attribute: str, default: Any = Undefined
|
||||
) -> Any:
|
||||
metadata = _get_sqlmodel_field_metadata(field_info)
|
||||
if metadata is not None and hasattr(metadata, attribute):
|
||||
return getattr(metadata, attribute)
|
||||
return getattr(field_info, attribute, default)
|
||||
|
||||
|
||||
# include sa_type, sa_column_args, sa_column_kwargs
|
||||
@overload
|
||||
def Field(
|
||||
@@ -423,6 +456,20 @@ def Field(
|
||||
default_factory=default_factory,
|
||||
**field_info_kwargs,
|
||||
)
|
||||
field_metadata = FieldInfoMetadata(
|
||||
primary_key=primary_key,
|
||||
nullable=nullable,
|
||||
foreign_key=foreign_key,
|
||||
ondelete=ondelete,
|
||||
unique=unique,
|
||||
index=index,
|
||||
sa_type=sa_type,
|
||||
sa_column=sa_column,
|
||||
sa_column_args=sa_column_args,
|
||||
sa_column_kwargs=sa_column_kwargs,
|
||||
)
|
||||
if hasattr(field_info, "metadata"):
|
||||
field_info.metadata.append(field_metadata)
|
||||
return field_info
|
||||
|
||||
|
||||
@@ -637,7 +684,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
|
||||
|
||||
def get_sqlalchemy_type(field: Any) -> Any:
|
||||
field_info = field
|
||||
sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009
|
||||
sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined) # noqa: B009
|
||||
if sa_type is not Undefined:
|
||||
return sa_type
|
||||
|
||||
@@ -691,39 +738,39 @@ def get_sqlalchemy_type(field: Any) -> Any:
|
||||
|
||||
def get_column_from_field(field: Any) -> Column: # type: ignore
|
||||
field_info = field
|
||||
sa_column = getattr(field_info, "sa_column", Undefined)
|
||||
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
|
||||
if isinstance(sa_column, Column):
|
||||
return sa_column
|
||||
sa_type = get_sqlalchemy_type(field)
|
||||
primary_key = getattr(field_info, "primary_key", Undefined)
|
||||
primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
|
||||
if primary_key is Undefined:
|
||||
primary_key = False
|
||||
index = getattr(field_info, "index", Undefined)
|
||||
index = _get_sqlmodel_field_value(field_info, "index", Undefined)
|
||||
if index is Undefined:
|
||||
index = False
|
||||
nullable = not primary_key and is_field_noneable(field)
|
||||
# Override derived nullability if the nullable property is set explicitly
|
||||
# on the field
|
||||
field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
|
||||
field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
|
||||
if field_nullable is not Undefined:
|
||||
assert not isinstance(field_nullable, UndefinedType)
|
||||
nullable = field_nullable
|
||||
args = []
|
||||
foreign_key = getattr(field_info, "foreign_key", Undefined)
|
||||
foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
|
||||
if foreign_key is Undefined:
|
||||
foreign_key = None
|
||||
unique = getattr(field_info, "unique", Undefined)
|
||||
unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
|
||||
if unique is Undefined:
|
||||
unique = False
|
||||
if foreign_key:
|
||||
if field_info.ondelete == "SET NULL" and not nullable:
|
||||
ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
|
||||
if ondelete_value is Undefined:
|
||||
ondelete_value = None
|
||||
if ondelete_value == "SET NULL" and not nullable:
|
||||
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
|
||||
assert isinstance(foreign_key, str)
|
||||
ondelete = getattr(field_info, "ondelete", Undefined)
|
||||
if ondelete is Undefined:
|
||||
ondelete = None
|
||||
assert isinstance(ondelete, (str, type(None))) # for typing
|
||||
args.append(ForeignKey(foreign_key, ondelete=ondelete))
|
||||
assert isinstance(ondelete_value, (str, type(None))) # for typing
|
||||
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
|
||||
kwargs = {
|
||||
"primary_key": primary_key,
|
||||
"nullable": nullable,
|
||||
@@ -737,10 +784,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
|
||||
sa_default = field_info.default
|
||||
if sa_default is not Undefined:
|
||||
kwargs["default"] = sa_default
|
||||
sa_column_args = getattr(field_info, "sa_column_args", Undefined)
|
||||
sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
|
||||
if sa_column_args is not Undefined:
|
||||
args.extend(list(cast(Sequence[Any], sa_column_args)))
|
||||
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
|
||||
sa_column_kwargs = _get_sqlmodel_field_value(
|
||||
field_info, "sa_column_kwargs", Undefined
|
||||
)
|
||||
if sa_column_kwargs is not Undefined:
|
||||
kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
|
||||
return Column(sa_type, *args, **kwargs) # type: ignore
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, Integer, String
|
||||
@@ -17,6 +17,17 @@ def test_sa_column_takes_precedence() -> None:
|
||||
assert isinstance(Item.id.type, String) # type: ignore
|
||||
|
||||
|
||||
def test_sa_column_with_annotated_metadata() -> None:
|
||||
class Item(SQLModel, table=True):
|
||||
id: Annotated[Optional[int], "meta"] = Field(
|
||||
default=None,
|
||||
sa_column=Column(String, primary_key=True, nullable=False),
|
||||
)
|
||||
|
||||
assert Item.id.nullable is False # type: ignore
|
||||
assert isinstance(Item.id.type, String) # type: ignore
|
||||
|
||||
|
||||
def test_sa_column_no_sa_args() -> None:
|
||||
with pytest.raises(RuntimeError):
|
||||
|
||||
|
||||
63
tests/test_future_annotations.py
Normal file
63
tests/test_future_annotations.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, select
|
||||
|
||||
|
||||
def test_model_with_future_annotations(clear_sqlmodel):
|
||||
class Hero(SQLModel, table=True):
|
||||
id: Annotated[Optional[int], Field(primary_key=True)] = None
|
||||
name: str
|
||||
secret_name: str
|
||||
age: Optional[int] = None
|
||||
|
||||
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(hero)
|
||||
session.commit()
|
||||
session.refresh(hero)
|
||||
|
||||
assert hero.id is not None
|
||||
assert hero.name == "Deadpond"
|
||||
assert hero.secret_name == "Dive Wilson"
|
||||
assert hero.age == 25
|
||||
|
||||
with Session(engine) as session:
|
||||
heroes = session.exec(select(Hero)).all()
|
||||
assert len(heroes) == 1
|
||||
assert heroes[0].name == "Deadpond"
|
||||
|
||||
|
||||
def test_model_with_string_annotations(clear_sqlmodel):
|
||||
class Team(SQLModel, table=True):
|
||||
id: Annotated[Optional[int], Field(primary_key=True)] = None
|
||||
name: str
|
||||
|
||||
class Player(SQLModel, table=True):
|
||||
id: Annotated[Optional[int], Field(primary_key=True)] = None
|
||||
name: str
|
||||
team_id: Annotated[Optional[int], Field(foreign_key="team.id")] = None
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
team = Team(name="Champions")
|
||||
player = Player(name="Alice", team_id=None)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(team)
|
||||
session.commit()
|
||||
session.refresh(team)
|
||||
|
||||
player.team_id = team.id
|
||||
session.add(player)
|
||||
session.commit()
|
||||
session.refresh(player)
|
||||
|
||||
assert team.id is not None
|
||||
assert player.team_id == team.id
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -125,3 +125,94 @@ def test_sa_relationship_property(clear_sqlmodel):
|
||||
# The next statement should not raise an AttributeError
|
||||
assert hero_rusty_man.team
|
||||
assert hero_rusty_man.team.name == "Preventers"
|
||||
|
||||
|
||||
def test_composite_primary_key(clear_sqlmodel):
|
||||
class UserPermission(SQLModel, table=True):
|
||||
user_id: int = Field(primary_key=True)
|
||||
resource_id: int = Field(primary_key=True)
|
||||
permission: str
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
|
||||
assert pk_column_names == {"user_id", "resource_id"}
|
||||
|
||||
with Session(engine) as session:
|
||||
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
|
||||
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
|
||||
session.add(perm1)
|
||||
session.add(perm2)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
with Session(engine) as session:
|
||||
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
|
||||
session.add(perm3)
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_composite_primary_key_and_validator(clear_sqlmodel):
|
||||
from pydantic import AfterValidator
|
||||
|
||||
def validate_resource_id(value: int) -> int:
|
||||
if value < 1:
|
||||
raise ValueError("Resource ID must be positive")
|
||||
return value
|
||||
|
||||
class UserPermission(SQLModel, table=True):
|
||||
user_id: int = Field(primary_key=True)
|
||||
resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(
|
||||
primary_key=True
|
||||
)
|
||||
permission: str
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
|
||||
assert pk_column_names == {"user_id", "resource_id"}
|
||||
|
||||
with Session(engine) as session:
|
||||
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
|
||||
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
|
||||
session.add(perm1)
|
||||
session.add(perm2)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
with Session(engine) as session:
|
||||
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
|
||||
session.add(perm3)
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):
|
||||
from pydantic import AfterValidator
|
||||
|
||||
def ensure_positive(value: int) -> int:
|
||||
if value < 0:
|
||||
raise ValueError("Team ID must be positive")
|
||||
return value
|
||||
|
||||
class Team(SQLModel, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
name: str
|
||||
|
||||
class Hero(SQLModel, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(
|
||||
foreign_key="team.id",
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
name: str
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined]
|
||||
foreign_keys = list(team_id_column.foreign_keys)
|
||||
assert len(foreign_keys) == 1
|
||||
assert foreign_keys[0].ondelete == "CASCADE"
|
||||
assert team_id_column.nullable is False
|
||||
|
||||
Reference in New Issue
Block a user