🐛 Fix support for Annotated fields with Pydantic 2.12+ (#1607)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: svlandeg <svlandeg@github.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
This commit is contained in:
Victor Mota
2026-02-01 13:14:55 -05:00
committed by GitHub
parent 5611bda2e5
commit 66533c96d9
4 changed files with 231 additions and 17 deletions

View File

@@ -5,6 +5,7 @@ import ipaddress
import uuid
import weakref
from collections.abc import Mapping, Sequence, Set
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
@@ -200,6 +201,38 @@ class RelationshipInfo(Representation):
self.sa_relationship_kwargs = sa_relationship_kwargs
@dataclass
class FieldInfoMetadata:
primary_key: Union[bool, UndefinedType] = Undefined
nullable: Union[bool, UndefinedType] = Undefined
foreign_key: Any = Undefined
ondelete: Union[OnDeleteType, UndefinedType] = Undefined
unique: Union[bool, UndefinedType] = Undefined
index: Union[bool, UndefinedType] = Undefined
sa_type: Union[type[Any], UndefinedType] = Undefined
sa_column: Union[Column[Any], UndefinedType] = Undefined
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined
def _get_sqlmodel_field_metadata(field_info: Any) -> Optional[FieldInfoMetadata]:
metadata_items = getattr(field_info, "metadata", None)
if metadata_items:
for meta in metadata_items:
if isinstance(meta, FieldInfoMetadata):
return meta
return None
def _get_sqlmodel_field_value(
field_info: Any, attribute: str, default: Any = Undefined
) -> Any:
metadata = _get_sqlmodel_field_metadata(field_info)
if metadata is not None and hasattr(metadata, attribute):
return getattr(metadata, attribute)
return getattr(field_info, attribute, default)
# include sa_type, sa_column_args, sa_column_kwargs
@overload
def Field(
@@ -423,6 +456,20 @@ def Field(
default_factory=default_factory,
**field_info_kwargs,
)
field_metadata = FieldInfoMetadata(
primary_key=primary_key,
nullable=nullable,
foreign_key=foreign_key,
ondelete=ondelete,
unique=unique,
index=index,
sa_type=sa_type,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
)
if hasattr(field_info, "metadata"):
field_info.metadata.append(field_metadata)
return field_info
@@ -637,7 +684,7 @@ class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
def get_sqlalchemy_type(field: Any) -> Any:
field_info = field
sa_type = getattr(field_info, "sa_type", Undefined) # noqa: B009
sa_type = _get_sqlmodel_field_value(field_info, "sa_type", Undefined) # noqa: B009
if sa_type is not Undefined:
return sa_type
@@ -691,39 +738,39 @@ def get_sqlalchemy_type(field: Any) -> Any:
def get_column_from_field(field: Any) -> Column: # type: ignore
field_info = field
sa_column = getattr(field_info, "sa_column", Undefined)
sa_column = _get_sqlmodel_field_value(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
primary_key = _get_sqlmodel_field_value(field_info, "primary_key", Undefined)
if primary_key is Undefined:
primary_key = False
index = getattr(field_info, "index", Undefined)
index = _get_sqlmodel_field_value(field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
field_nullable = getattr(field_info, "nullable", Undefined) # noqa: B009
field_nullable = _get_sqlmodel_field_value(field_info, "nullable", Undefined)
if field_nullable is not Undefined:
assert not isinstance(field_nullable, UndefinedType)
nullable = field_nullable
args = []
foreign_key = getattr(field_info, "foreign_key", Undefined)
foreign_key = _get_sqlmodel_field_value(field_info, "foreign_key", Undefined)
if foreign_key is Undefined:
foreign_key = None
unique = getattr(field_info, "unique", Undefined)
unique = _get_sqlmodel_field_value(field_info, "unique", Undefined)
if unique is Undefined:
unique = False
if foreign_key:
if field_info.ondelete == "SET NULL" and not nullable:
ondelete_value = _get_sqlmodel_field_value(field_info, "ondelete", Undefined)
if ondelete_value is Undefined:
ondelete_value = None
if ondelete_value == "SET NULL" and not nullable:
raise RuntimeError('ondelete="SET NULL" requires nullable=True')
assert isinstance(foreign_key, str)
ondelete = getattr(field_info, "ondelete", Undefined)
if ondelete is Undefined:
ondelete = None
assert isinstance(ondelete, (str, type(None))) # for typing
args.append(ForeignKey(foreign_key, ondelete=ondelete))
assert isinstance(ondelete_value, (str, type(None))) # for typing
args.append(ForeignKey(foreign_key, ondelete=ondelete_value))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
@@ -737,10 +784,12 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
sa_default = field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default
sa_column_args = getattr(field_info, "sa_column_args", Undefined)
sa_column_args = _get_sqlmodel_field_value(field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field_info, "sa_column_kwargs", Undefined)
sa_column_kwargs = _get_sqlmodel_field_value(
field_info, "sa_column_kwargs", Undefined
)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) # type: ignore

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Annotated, Optional
import pytest
from sqlalchemy import Column, Integer, String
@@ -17,6 +17,17 @@ def test_sa_column_takes_precedence() -> None:
assert isinstance(Item.id.type, String) # type: ignore
def test_sa_column_with_annotated_metadata() -> None:
class Item(SQLModel, table=True):
id: Annotated[Optional[int], "meta"] = Field(
default=None,
sa_column=Column(String, primary_key=True, nullable=False),
)
assert Item.id.nullable is False # type: ignore
assert isinstance(Item.id.type, String) # type: ignore
def test_sa_column_no_sa_args() -> None:
with pytest.raises(RuntimeError):

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from typing import Annotated, Optional
from sqlmodel import Field, Session, SQLModel, create_engine, select
def test_model_with_future_annotations(clear_sqlmodel):
class Hero(SQLModel, table=True):
id: Annotated[Optional[int], Field(primary_key=True)] = None
name: str
secret_name: str
age: Optional[int] = None
hero = Hero(name="Deadpond", secret_name="Dive Wilson", age=25)
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
session.add(hero)
session.commit()
session.refresh(hero)
assert hero.id is not None
assert hero.name == "Deadpond"
assert hero.secret_name == "Dive Wilson"
assert hero.age == 25
with Session(engine) as session:
heroes = session.exec(select(Hero)).all()
assert len(heroes) == 1
assert heroes[0].name == "Deadpond"
def test_model_with_string_annotations(clear_sqlmodel):
class Team(SQLModel, table=True):
id: Annotated[Optional[int], Field(primary_key=True)] = None
name: str
class Player(SQLModel, table=True):
id: Annotated[Optional[int], Field(primary_key=True)] = None
name: str
team_id: Annotated[Optional[int], Field(foreign_key="team.id")] = None
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
team = Team(name="Champions")
player = Player(name="Alice", team_id=None)
with Session(engine) as session:
session.add(team)
session.commit()
session.refresh(team)
player.team_id = team.id
session.add(player)
session.commit()
session.refresh(player)
assert team.id is not None
assert player.team_id == team.id

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Annotated, Optional
import pytest
from sqlalchemy.exc import IntegrityError
@@ -125,3 +125,94 @@ def test_sa_relationship_property(clear_sqlmodel):
# The next statement should not raise an AttributeError
assert hero_rusty_man.team
assert hero_rusty_man.team.name == "Preventers"
def test_composite_primary_key(clear_sqlmodel):
class UserPermission(SQLModel, table=True):
user_id: int = Field(primary_key=True)
resource_id: int = Field(primary_key=True)
permission: str
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
assert pk_column_names == {"user_id", "resource_id"}
with Session(engine) as session:
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
session.add(perm1)
session.add(perm2)
session.commit()
with pytest.raises(IntegrityError):
with Session(engine) as session:
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
session.add(perm3)
session.commit()
def test_composite_primary_key_and_validator(clear_sqlmodel):
from pydantic import AfterValidator
def validate_resource_id(value: int) -> int:
if value < 1:
raise ValueError("Resource ID must be positive")
return value
class UserPermission(SQLModel, table=True):
user_id: int = Field(primary_key=True)
resource_id: Annotated[int, AfterValidator(validate_resource_id)] = Field(
primary_key=True
)
permission: str
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
pk_column_names = {column.name for column in UserPermission.__table__.primary_key}
assert pk_column_names == {"user_id", "resource_id"}
with Session(engine) as session:
perm1 = UserPermission(user_id=1, resource_id=1, permission="read")
perm2 = UserPermission(user_id=1, resource_id=2, permission="write")
session.add(perm1)
session.add(perm2)
session.commit()
with pytest.raises(IntegrityError):
with Session(engine) as session:
perm3 = UserPermission(user_id=1, resource_id=1, permission="admin")
session.add(perm3)
session.commit()
def test_foreign_key_ondelete_with_annotated(clear_sqlmodel):
from pydantic import AfterValidator
def ensure_positive(value: int) -> int:
if value < 0:
raise ValueError("Team ID must be positive")
return value
class Team(SQLModel, table=True):
id: int = Field(primary_key=True)
name: str
class Hero(SQLModel, table=True):
id: int = Field(primary_key=True)
team_id: Annotated[int, AfterValidator(ensure_positive)] = Field(
foreign_key="team.id",
ondelete="CASCADE",
)
name: str
engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
team_id_column = Hero.__table__.c.team_id # type: ignore[attr-defined]
foreign_keys = list(team_id_column.foreign_keys)
assert len(foreign_keys) == 1
assert foreign_keys[0].ondelete == "CASCADE"
assert team_id_column.nullable is False