Allow Discriminator for discriminator in Field

This commit is contained in:
Yurii Motov
2026-01-29 10:00:01 +01:00
parent 5611bda2e5
commit 80f89862ae
3 changed files with 44 additions and 7 deletions

View File

@@ -1,5 +1,9 @@
__version__ = "0.0.31"
# Re-export from Pydantic
from pydantic import Discriminator as Discriminator
from pydantic import Tag as Tag
# Re-export from SQLAlchemy
from sqlalchemy.engine import create_engine as create_engine
from sqlalchemy.engine import create_mock_engine as create_mock_engine

View File

@@ -22,7 +22,7 @@ from typing import (
overload,
)
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, Discriminator, EmailStr
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
Boolean,
@@ -228,7 +228,7 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
discriminator: Union[str, Discriminator, None] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,
@@ -271,7 +271,7 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
discriminator: Union[str, Discriminator, None] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: str,
@@ -323,7 +323,7 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
discriminator: Union[str, Discriminator, None] = None,
repr: bool = True,
sa_column: Union[Column[Any], UndefinedType] = Undefined,
schema_extra: Optional[dict[str, Any]] = None,
@@ -356,7 +356,7 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
discriminator: Optional[str] = None,
discriminator: Union[str, Discriminator, None] = None,
repr: bool = True,
primary_key: Union[bool, UndefinedType] = Undefined,
foreign_key: Any = Undefined,

View File

@@ -1,9 +1,9 @@
from decimal import Decimal
from typing import Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, Union
import pytest
from pydantic import ValidationError
from sqlmodel import Field, SQLModel
from sqlmodel import Discriminator, Field, SQLModel, Tag
def test_decimal():
@@ -47,6 +47,39 @@ def test_discriminator():
Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type]
def test_discriminator_callable():
# Example adapted from
# [Pydantic docs](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator):
class Pie(SQLModel):
pass
class ApplePie(Pie):
fruit: Literal["apple"] = "apple"
class PumpkinPie(Pie):
filling: Literal["pumpkin"] = "pumpkin"
def get_discriminator_value(v: Any) -> str:
if isinstance(v, dict):
return v.get("fruit", v.get("filling"))
return getattr(v, "fruit", getattr(v, "filling", None))
class ThanksgivingDinner(SQLModel):
dessert: Union[
Annotated[ApplePie, Tag("apple")],
Annotated[PumpkinPie, Tag("pumpkin")],
] = Field(
discriminator=Discriminator(get_discriminator_value),
)
apple_pie = ThanksgivingDinner.model_validate({"dessert": {"fruit": "apple"}})
assert isinstance(apple_pie.dessert, ApplePie)
pumpkin_pie = ThanksgivingDinner.model_validate({"dessert": {"filling": "pumpkin"}})
assert isinstance(pumpkin_pie.dessert, PumpkinPie)
def test_repr():
class Model(SQLModel):
id: Optional[int] = Field(primary_key=True)