mirror of
https://github.com/fastapi/sqlmodel.git
synced 2026-02-04 11:44:01 +08:00
Allow Discriminator for discriminator in Field
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
__version__ = "0.0.31"
|
||||
|
||||
# Re-export from Pydantic
|
||||
from pydantic import Discriminator as Discriminator
|
||||
from pydantic import Tag as Tag
|
||||
|
||||
# Re-export from SQLAlchemy
|
||||
from sqlalchemy.engine import create_engine as create_engine
|
||||
from sqlalchemy.engine import create_mock_engine as create_mock_engine
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from pydantic import BaseModel, Discriminator, EmailStr
|
||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
@@ -228,7 +228,7 @@ def Field(
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator: Union[str, Discriminator, None] = None,
|
||||
repr: bool = True,
|
||||
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||
foreign_key: Any = Undefined,
|
||||
@@ -271,7 +271,7 @@ def Field(
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator: Union[str, Discriminator, None] = None,
|
||||
repr: bool = True,
|
||||
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||
foreign_key: str,
|
||||
@@ -323,7 +323,7 @@ def Field(
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator: Union[str, Discriminator, None] = None,
|
||||
repr: bool = True,
|
||||
sa_column: Union[Column[Any], UndefinedType] = Undefined,
|
||||
schema_extra: Optional[dict[str, Any]] = None,
|
||||
@@ -356,7 +356,7 @@ def Field(
|
||||
max_length: Optional[int] = None,
|
||||
allow_mutation: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
discriminator: Optional[str] = None,
|
||||
discriminator: Union[str, Discriminator, None] = None,
|
||||
repr: bool = True,
|
||||
primary_key: Union[bool, UndefinedType] = Undefined,
|
||||
foreign_key: Any = Undefined,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from decimal import Decimal
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel import Discriminator, Field, SQLModel, Tag
|
||||
|
||||
|
||||
def test_decimal():
|
||||
@@ -47,6 +47,39 @@ def test_discriminator():
|
||||
Model(pet={"pet_type": "dog"}, n=1) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_discriminator_callable():
|
||||
# Example adapted from
|
||||
# [Pydantic docs](https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator):
|
||||
|
||||
class Pie(SQLModel):
|
||||
pass
|
||||
|
||||
class ApplePie(Pie):
|
||||
fruit: Literal["apple"] = "apple"
|
||||
|
||||
class PumpkinPie(Pie):
|
||||
filling: Literal["pumpkin"] = "pumpkin"
|
||||
|
||||
def get_discriminator_value(v: Any) -> str:
|
||||
if isinstance(v, dict):
|
||||
return v.get("fruit", v.get("filling"))
|
||||
return getattr(v, "fruit", getattr(v, "filling", None))
|
||||
|
||||
class ThanksgivingDinner(SQLModel):
|
||||
dessert: Union[
|
||||
Annotated[ApplePie, Tag("apple")],
|
||||
Annotated[PumpkinPie, Tag("pumpkin")],
|
||||
] = Field(
|
||||
discriminator=Discriminator(get_discriminator_value),
|
||||
)
|
||||
|
||||
apple_pie = ThanksgivingDinner.model_validate({"dessert": {"fruit": "apple"}})
|
||||
assert isinstance(apple_pie.dessert, ApplePie)
|
||||
|
||||
pumpkin_pie = ThanksgivingDinner.model_validate({"dessert": {"filling": "pumpkin"}})
|
||||
assert isinstance(pumpkin_pie.dessert, PumpkinPie)
|
||||
|
||||
|
||||
def test_repr():
|
||||
class Model(SQLModel):
|
||||
id: Optional[int] = Field(primary_key=True)
|
||||
|
||||
Reference in New Issue
Block a user