mirror of
				https://github.com/fastapi/sqlmodel.git
				synced 2025-11-04 06:37:29 +08:00 
			
		
		
		
	✨ Add SQLModel core code
This commit is contained in:
		
							
								
								
									
										139
									
								
								sqlmodel/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								sqlmodel/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,139 @@
 | 
			
		||||
__version__ = "0.0.1"
 | 
			
		||||
 | 
			
		||||
# Re-export from SQLAlchemy
 | 
			
		||||
from sqlalchemy.engine import create_mock_engine as create_mock_engine
 | 
			
		||||
from sqlalchemy.engine import engine_from_config as engine_from_config
 | 
			
		||||
from sqlalchemy.inspection import inspect as inspect
 | 
			
		||||
from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA
 | 
			
		||||
from sqlalchemy.schema import CheckConstraint as CheckConstraint
 | 
			
		||||
from sqlalchemy.schema import Column as Column
 | 
			
		||||
from sqlalchemy.schema import ColumnDefault as ColumnDefault
 | 
			
		||||
from sqlalchemy.schema import Computed as Computed
 | 
			
		||||
from sqlalchemy.schema import Constraint as Constraint
 | 
			
		||||
from sqlalchemy.schema import DDL as DDL
 | 
			
		||||
from sqlalchemy.schema import DefaultClause as DefaultClause
 | 
			
		||||
from sqlalchemy.schema import FetchedValue as FetchedValue
 | 
			
		||||
from sqlalchemy.schema import ForeignKey as ForeignKey
 | 
			
		||||
from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint
 | 
			
		||||
from sqlalchemy.schema import Identity as Identity
 | 
			
		||||
from sqlalchemy.schema import Index as Index
 | 
			
		||||
from sqlalchemy.schema import MetaData as MetaData
 | 
			
		||||
from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint
 | 
			
		||||
from sqlalchemy.schema import Sequence as Sequence
 | 
			
		||||
from sqlalchemy.schema import Table as Table
 | 
			
		||||
from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData
 | 
			
		||||
from sqlalchemy.schema import UniqueConstraint as UniqueConstraint
 | 
			
		||||
from sqlalchemy.sql import alias as alias
 | 
			
		||||
from sqlalchemy.sql import all_ as all_
 | 
			
		||||
from sqlalchemy.sql import and_ as and_
 | 
			
		||||
from sqlalchemy.sql import any_ as any_
 | 
			
		||||
from sqlalchemy.sql import asc as asc
 | 
			
		||||
from sqlalchemy.sql import between as between
 | 
			
		||||
from sqlalchemy.sql import bindparam as bindparam
 | 
			
		||||
from sqlalchemy.sql import case as case
 | 
			
		||||
from sqlalchemy.sql import cast as cast
 | 
			
		||||
from sqlalchemy.sql import collate as collate
 | 
			
		||||
from sqlalchemy.sql import column as column
 | 
			
		||||
from sqlalchemy.sql import delete as delete
 | 
			
		||||
from sqlalchemy.sql import desc as desc
 | 
			
		||||
from sqlalchemy.sql import distinct as distinct
 | 
			
		||||
from sqlalchemy.sql import except_ as except_
 | 
			
		||||
from sqlalchemy.sql import except_all as except_all
 | 
			
		||||
from sqlalchemy.sql import exists as exists
 | 
			
		||||
from sqlalchemy.sql import extract as extract
 | 
			
		||||
from sqlalchemy.sql import false as false
 | 
			
		||||
from sqlalchemy.sql import func as func
 | 
			
		||||
from sqlalchemy.sql import funcfilter as funcfilter
 | 
			
		||||
from sqlalchemy.sql import insert as insert
 | 
			
		||||
from sqlalchemy.sql import intersect as intersect
 | 
			
		||||
from sqlalchemy.sql import intersect_all as intersect_all
 | 
			
		||||
from sqlalchemy.sql import join as join
 | 
			
		||||
from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT
 | 
			
		||||
from sqlalchemy.sql import (
 | 
			
		||||
    LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY,
 | 
			
		||||
)
 | 
			
		||||
from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE
 | 
			
		||||
from sqlalchemy.sql import (
 | 
			
		||||
    LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL,
 | 
			
		||||
)
 | 
			
		||||
from sqlalchemy.sql import lambda_stmt as lambda_stmt
 | 
			
		||||
from sqlalchemy.sql import lateral as lateral
 | 
			
		||||
from sqlalchemy.sql import literal as literal
 | 
			
		||||
from sqlalchemy.sql import literal_column as literal_column
 | 
			
		||||
from sqlalchemy.sql import modifier as modifier
 | 
			
		||||
from sqlalchemy.sql import not_ as not_
 | 
			
		||||
from sqlalchemy.sql import null as null
 | 
			
		||||
from sqlalchemy.sql import nulls_first as nulls_first
 | 
			
		||||
from sqlalchemy.sql import nulls_last as nulls_last
 | 
			
		||||
from sqlalchemy.sql import nullsfirst as nullsfirst
 | 
			
		||||
from sqlalchemy.sql import nullslast as nullslast
 | 
			
		||||
from sqlalchemy.sql import or_ as or_
 | 
			
		||||
from sqlalchemy.sql import outerjoin as outerjoin
 | 
			
		||||
from sqlalchemy.sql import outparam as outparam
 | 
			
		||||
from sqlalchemy.sql import over as over
 | 
			
		||||
from sqlalchemy.sql import subquery as subquery
 | 
			
		||||
from sqlalchemy.sql import table as table
 | 
			
		||||
from sqlalchemy.sql import tablesample as tablesample
 | 
			
		||||
from sqlalchemy.sql import text as text
 | 
			
		||||
from sqlalchemy.sql import true as true
 | 
			
		||||
from sqlalchemy.sql import tuple_ as tuple_
 | 
			
		||||
from sqlalchemy.sql import type_coerce as type_coerce
 | 
			
		||||
from sqlalchemy.sql import union as union
 | 
			
		||||
from sqlalchemy.sql import union_all as union_all
 | 
			
		||||
from sqlalchemy.sql import update as update
 | 
			
		||||
from sqlalchemy.sql import values as values
 | 
			
		||||
from sqlalchemy.sql import within_group as within_group
 | 
			
		||||
from sqlalchemy.types import ARRAY as ARRAY
 | 
			
		||||
from sqlalchemy.types import BIGINT as BIGINT
 | 
			
		||||
from sqlalchemy.types import BigInteger as BigInteger
 | 
			
		||||
from sqlalchemy.types import BINARY as BINARY
 | 
			
		||||
from sqlalchemy.types import BLOB as BLOB
 | 
			
		||||
from sqlalchemy.types import BOOLEAN as BOOLEAN
 | 
			
		||||
from sqlalchemy.types import Boolean as Boolean
 | 
			
		||||
from sqlalchemy.types import CHAR as CHAR
 | 
			
		||||
from sqlalchemy.types import CLOB as CLOB
 | 
			
		||||
from sqlalchemy.types import DATE as DATE
 | 
			
		||||
from sqlalchemy.types import Date as Date
 | 
			
		||||
from sqlalchemy.types import DATETIME as DATETIME
 | 
			
		||||
from sqlalchemy.types import DateTime as DateTime
 | 
			
		||||
from sqlalchemy.types import DECIMAL as DECIMAL
 | 
			
		||||
from sqlalchemy.types import Enum as Enum
 | 
			
		||||
from sqlalchemy.types import FLOAT as FLOAT
 | 
			
		||||
from sqlalchemy.types import Float as Float
 | 
			
		||||
from sqlalchemy.types import INT as INT
 | 
			
		||||
from sqlalchemy.types import INTEGER as INTEGER
 | 
			
		||||
from sqlalchemy.types import Integer as Integer
 | 
			
		||||
from sqlalchemy.types import Interval as Interval
 | 
			
		||||
from sqlalchemy.types import JSON as JSON
 | 
			
		||||
from sqlalchemy.types import LargeBinary as LargeBinary
 | 
			
		||||
from sqlalchemy.types import NCHAR as NCHAR
 | 
			
		||||
from sqlalchemy.types import NUMERIC as NUMERIC
 | 
			
		||||
from sqlalchemy.types import Numeric as Numeric
 | 
			
		||||
from sqlalchemy.types import NVARCHAR as NVARCHAR
 | 
			
		||||
from sqlalchemy.types import PickleType as PickleType
 | 
			
		||||
from sqlalchemy.types import REAL as REAL
 | 
			
		||||
from sqlalchemy.types import SMALLINT as SMALLINT
 | 
			
		||||
from sqlalchemy.types import SmallInteger as SmallInteger
 | 
			
		||||
from sqlalchemy.types import String as String
 | 
			
		||||
from sqlalchemy.types import TEXT as TEXT
 | 
			
		||||
from sqlalchemy.types import Text as Text
 | 
			
		||||
from sqlalchemy.types import TIME as TIME
 | 
			
		||||
from sqlalchemy.types import Time as Time
 | 
			
		||||
from sqlalchemy.types import TIMESTAMP as TIMESTAMP
 | 
			
		||||
from sqlalchemy.types import TypeDecorator as TypeDecorator
 | 
			
		||||
from sqlalchemy.types import Unicode as Unicode
 | 
			
		||||
from sqlalchemy.types import UnicodeText as UnicodeText
 | 
			
		||||
from sqlalchemy.types import VARBINARY as VARBINARY
 | 
			
		||||
from sqlalchemy.types import VARCHAR as VARCHAR
 | 
			
		||||
 | 
			
		||||
# Extensions and modifications of SQLAlchemy in SQLModel
 | 
			
		||||
from .engine.create import create_engine as create_engine
 | 
			
		||||
from .orm.session import Session as Session
 | 
			
		||||
from .sql.expression import select as select
 | 
			
		||||
from .sql.expression import col as col
 | 
			
		||||
from .sql.sqltypes import AutoString as AutoString
 | 
			
		||||
 | 
			
		||||
# Export SQLModel specifics (equivalent to Pydantic)
 | 
			
		||||
from .main import SQLModel as SQLModel
 | 
			
		||||
from .main import Field as Field
 | 
			
		||||
from .main import Relationship as Relationship
 | 
			
		||||
							
								
								
									
										32
									
								
								sqlmodel/default.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								sqlmodel/default.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,32 @@
 | 
			
		||||
from typing import Any, TypeVar
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _DefaultPlaceholder:
 | 
			
		||||
    """
 | 
			
		||||
    You shouldn't use this class directly.
 | 
			
		||||
 | 
			
		||||
    It's used internally to recognize when a default value has been overwritten, even
 | 
			
		||||
    if the overriden default value was truthy.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, value: Any):
 | 
			
		||||
        self.value = value
 | 
			
		||||
 | 
			
		||||
    def __bool__(self) -> bool:
 | 
			
		||||
        return bool(self.value)
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, o: object) -> bool:
 | 
			
		||||
        return isinstance(o, _DefaultPlaceholder) and o.value == self.value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TDefaultType = TypeVar("_TDefaultType")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def Default(value: _TDefaultType) -> _TDefaultType:
 | 
			
		||||
    """
 | 
			
		||||
    You shouldn't use this function directly.
 | 
			
		||||
 | 
			
		||||
    It's used internally to recognize when a default value has been overwritten, even
 | 
			
		||||
    if the overriden default value was truthy.
 | 
			
		||||
    """
 | 
			
		||||
    return _DefaultPlaceholder(value)  # type: ignore
 | 
			
		||||
							
								
								
									
										0
									
								
								sqlmodel/engine/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlmodel/engine/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										139
									
								
								sqlmodel/engine/create.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								sqlmodel/engine/create.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,139 @@
 | 
			
		||||
import json
 | 
			
		||||
import sqlite3
 | 
			
		||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import create_engine as _create_engine
 | 
			
		||||
from sqlalchemy.engine.url import URL
 | 
			
		||||
from sqlalchemy.future import Engine as _FutureEngine
 | 
			
		||||
from sqlalchemy.pool import Pool
 | 
			
		||||
from typing_extensions import Literal, TypedDict
 | 
			
		||||
 | 
			
		||||
from ..default import Default, _DefaultPlaceholder
 | 
			
		||||
 | 
			
		||||
# Types defined in sqlalchemy2-stubs, but can't be imported, so re-define here
 | 
			
		||||
 | 
			
		||||
_Debug = Literal["debug"]
 | 
			
		||||
 | 
			
		||||
_IsolationLevel = Literal[
 | 
			
		||||
    "SERIALIZABLE",
 | 
			
		||||
    "REPEATABLE READ",
 | 
			
		||||
    "READ COMMITTED",
 | 
			
		||||
    "READ UNCOMMITTED",
 | 
			
		||||
    "AUTOCOMMIT",
 | 
			
		||||
]
 | 
			
		||||
_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
 | 
			
		||||
_ResetOnReturn = Literal["rollback", "commit"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _SQLiteConnectArgs(TypedDict, total=False):
 | 
			
		||||
    timeout: float
 | 
			
		||||
    detect_types: Any
 | 
			
		||||
    isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
 | 
			
		||||
    check_same_thread: bool
 | 
			
		||||
    factory: Type[sqlite3.Connection]
 | 
			
		||||
    cached_statements: int
 | 
			
		||||
    uri: bool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Re-define create_engine to have by default future=True, and assume that's what is used
 | 
			
		||||
# Also show the default values used for each parameter, but don't set them unless
 | 
			
		||||
# explicitly passed as arguments by the user to prevent errors. E.g. SQLite doesn't
 | 
			
		||||
# support pool connection arguments.
 | 
			
		||||
def create_engine(
 | 
			
		||||
    url: Union[str, URL],
 | 
			
		||||
    *,
 | 
			
		||||
    connect_args: _ConnectArgs = Default({}),  # type: ignore
 | 
			
		||||
    echo: Union[bool, _Debug] = Default(False),
 | 
			
		||||
    echo_pool: Union[bool, _Debug] = Default(False),
 | 
			
		||||
    enable_from_linting: bool = Default(True),
 | 
			
		||||
    encoding: str = Default("utf-8"),
 | 
			
		||||
    execution_options: Dict[Any, Any] = Default({}),
 | 
			
		||||
    future: bool = True,
 | 
			
		||||
    hide_parameters: bool = Default(False),
 | 
			
		||||
    implicit_returning: bool = Default(True),
 | 
			
		||||
    isolation_level: Optional[_IsolationLevel] = Default(None),
 | 
			
		||||
    json_deserializer: Callable[..., Any] = Default(json.loads),
 | 
			
		||||
    json_serializer: Callable[..., Any] = Default(json.dumps),
 | 
			
		||||
    label_length: Optional[int] = Default(None),
 | 
			
		||||
    logging_name: Optional[str] = Default(None),
 | 
			
		||||
    max_identifier_length: Optional[int] = Default(None),
 | 
			
		||||
    max_overflow: int = Default(10),
 | 
			
		||||
    module: Optional[Any] = Default(None),
 | 
			
		||||
    paramstyle: Optional[_ParamStyle] = Default(None),
 | 
			
		||||
    pool: Optional[Pool] = Default(None),
 | 
			
		||||
    poolclass: Optional[Type[Pool]] = Default(None),
 | 
			
		||||
    pool_logging_name: Optional[str] = Default(None),
 | 
			
		||||
    pool_pre_ping: bool = Default(False),
 | 
			
		||||
    pool_size: int = Default(5),
 | 
			
		||||
    pool_recycle: int = Default(-1),
 | 
			
		||||
    pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"),
 | 
			
		||||
    pool_timeout: float = Default(30),
 | 
			
		||||
    pool_use_lifo: bool = Default(False),
 | 
			
		||||
    plugins: Optional[List[str]] = Default(None),
 | 
			
		||||
    query_cache_size: Optional[int] = Default(None),
 | 
			
		||||
    **kwargs: Any,
 | 
			
		||||
) -> _FutureEngine:
 | 
			
		||||
    current_kwargs: Dict[str, Any] = {
 | 
			
		||||
        "future": future,
 | 
			
		||||
    }
 | 
			
		||||
    if not isinstance(echo, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["echo"] = echo
 | 
			
		||||
    if not isinstance(echo_pool, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["echo_pool"] = echo_pool
 | 
			
		||||
    if not isinstance(enable_from_linting, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["enable_from_linting"] = enable_from_linting
 | 
			
		||||
    if not isinstance(connect_args, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["connect_args"] = connect_args
 | 
			
		||||
    if not isinstance(encoding, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["encoding"] = encoding
 | 
			
		||||
    if not isinstance(execution_options, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["execution_options"] = execution_options
 | 
			
		||||
    if not isinstance(hide_parameters, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["hide_parameters"] = hide_parameters
 | 
			
		||||
    if not isinstance(implicit_returning, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["implicit_returning"] = implicit_returning
 | 
			
		||||
    if not isinstance(isolation_level, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["isolation_level"] = isolation_level
 | 
			
		||||
    if not isinstance(json_deserializer, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["json_deserializer"] = json_deserializer
 | 
			
		||||
    if not isinstance(json_serializer, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["json_serializer"] = json_serializer
 | 
			
		||||
    if not isinstance(label_length, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["label_length"] = label_length
 | 
			
		||||
    if not isinstance(logging_name, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["logging_name"] = logging_name
 | 
			
		||||
    if not isinstance(max_identifier_length, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["max_identifier_length"] = max_identifier_length
 | 
			
		||||
    if not isinstance(max_overflow, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["max_overflow"] = max_overflow
 | 
			
		||||
    if not isinstance(module, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["module"] = module
 | 
			
		||||
    if not isinstance(paramstyle, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["paramstyle"] = paramstyle
 | 
			
		||||
    if not isinstance(pool, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool"] = pool
 | 
			
		||||
    if not isinstance(poolclass, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["poolclass"] = poolclass
 | 
			
		||||
    if not isinstance(pool_logging_name, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_logging_name"] = pool_logging_name
 | 
			
		||||
    if not isinstance(pool_pre_ping, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_pre_ping"] = pool_pre_ping
 | 
			
		||||
    if not isinstance(pool_size, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_size"] = pool_size
 | 
			
		||||
    if not isinstance(pool_recycle, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_recycle"] = pool_recycle
 | 
			
		||||
    if not isinstance(pool_reset_on_return, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_reset_on_return"] = pool_reset_on_return
 | 
			
		||||
    if not isinstance(pool_timeout, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_timeout"] = pool_timeout
 | 
			
		||||
    if not isinstance(pool_use_lifo, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["pool_use_lifo"] = pool_use_lifo
 | 
			
		||||
    if not isinstance(plugins, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["plugins"] = plugins
 | 
			
		||||
    if not isinstance(query_cache_size, _DefaultPlaceholder):
 | 
			
		||||
        current_kwargs["query_cache_size"] = query_cache_size
 | 
			
		||||
    current_kwargs.update(kwargs)
 | 
			
		||||
    return _create_engine(url, **current_kwargs)
 | 
			
		||||
							
								
								
									
										79
									
								
								sqlmodel/engine/result.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								sqlmodel/engine/result.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,79 @@
 | 
			
		||||
from typing import Generic, Iterator, List, Optional, TypeVar
 | 
			
		||||
 | 
			
		||||
from sqlalchemy.engine.result import Result as _Result
 | 
			
		||||
from sqlalchemy.engine.result import ScalarResult as _ScalarResult
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScalarResult(_ScalarResult, Generic[_T]):
 | 
			
		||||
    def all(self) -> List[_T]:
 | 
			
		||||
        return super().all()
 | 
			
		||||
 | 
			
		||||
    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:
 | 
			
		||||
        return super().partitions(size)
 | 
			
		||||
 | 
			
		||||
    def fetchall(self) -> List[_T]:
 | 
			
		||||
        return super().fetchall()
 | 
			
		||||
 | 
			
		||||
    def fetchmany(self, size: Optional[int] = None) -> List[_T]:
 | 
			
		||||
        return super().fetchmany(size)
 | 
			
		||||
 | 
			
		||||
    def __iter__(self) -> Iterator[_T]:
 | 
			
		||||
        return super().__iter__()
 | 
			
		||||
 | 
			
		||||
    def __next__(self) -> _T:
 | 
			
		||||
        return super().__next__()
 | 
			
		||||
 | 
			
		||||
    def first(self) -> Optional[_T]:
 | 
			
		||||
        return super().first()
 | 
			
		||||
 | 
			
		||||
    def one_or_none(self) -> Optional[_T]:
 | 
			
		||||
        return super().one_or_none()
 | 
			
		||||
 | 
			
		||||
    def one(self) -> _T:
 | 
			
		||||
        return super().one()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Result(_Result, Generic[_T]):
 | 
			
		||||
    def scalars(self, index: int = 0) -> ScalarResult[_T]:
 | 
			
		||||
        return super().scalars(index)  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def __iter__(self) -> Iterator[_T]:  # type: ignore
 | 
			
		||||
        return super().__iter__()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def __next__(self) -> _T:  # type: ignore
 | 
			
		||||
        return super().__next__()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]:  # type: ignore
 | 
			
		||||
        return super().partitions(size)  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def fetchall(self) -> List[_T]:  # type: ignore
 | 
			
		||||
        return super().fetchall()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def fetchone(self) -> Optional[_T]:  # type: ignore
 | 
			
		||||
        return super().fetchone()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def fetchmany(self, size: Optional[int] = None) -> List[_T]:  # type: ignore
 | 
			
		||||
        return super().fetchmany()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def all(self) -> List[_T]:  # type: ignore
 | 
			
		||||
        return super().all()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def first(self) -> Optional[_T]:  # type: ignore
 | 
			
		||||
        return super().first()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def one_or_none(self) -> Optional[_T]:  # type: ignore
 | 
			
		||||
        return super().one_or_none()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def scalar_one(self) -> _T:
 | 
			
		||||
        return super().scalar_one()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def scalar_one_or_none(self) -> Optional[_T]:
 | 
			
		||||
        return super().scalar_one_or_none()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def one(self) -> _T:  # type: ignore
 | 
			
		||||
        return super().one()  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def scalar(self) -> Optional[_T]:
 | 
			
		||||
        return super().scalar()  # type: ignore
 | 
			
		||||
							
								
								
									
										0
									
								
								sqlmodel/ext/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlmodel/ext/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								sqlmodel/ext/asyncio/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlmodel/ext/asyncio/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										62
									
								
								sqlmodel/ext/asyncio/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								sqlmodel/ext/asyncio/session.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,62 @@
 | 
			
		||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import util
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
 | 
			
		||||
from sqlalchemy.ext.asyncio import engine
 | 
			
		||||
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
 | 
			
		||||
from sqlalchemy.util.concurrency import greenlet_spawn
 | 
			
		||||
from sqlmodel.sql.base import Executable
 | 
			
		||||
 | 
			
		||||
from ...engine.result import ScalarResult
 | 
			
		||||
from ...orm.session import Session
 | 
			
		||||
from ...sql.expression import Select
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AsyncSession(_AsyncSession):
 | 
			
		||||
    sync_session: Session
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
 | 
			
		||||
        binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
 | 
			
		||||
        **kw,
 | 
			
		||||
    ):
 | 
			
		||||
        # All the same code of the original AsyncSession
 | 
			
		||||
        kw["future"] = True
 | 
			
		||||
        if bind:
 | 
			
		||||
            self.bind = bind
 | 
			
		||||
            bind = engine._get_sync_engine_or_connection(bind)  # type: ignore
 | 
			
		||||
 | 
			
		||||
        if binds:
 | 
			
		||||
            self.binds = binds
 | 
			
		||||
            binds = {
 | 
			
		||||
                key: engine._get_sync_engine_or_connection(b)  # type: ignore
 | 
			
		||||
                for key, b in binds.items()
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        self.sync_session = self._proxied = self._assign_proxied(  # type: ignore
 | 
			
		||||
            Session(bind=bind, binds=binds, **kw)  # type: ignore
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def exec(
 | 
			
		||||
        self,
 | 
			
		||||
        statement: Union[Select[_T], Executable[_T]],
 | 
			
		||||
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
 | 
			
		||||
        execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
 | 
			
		||||
        bind_arguments: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
        **kw: Any,
 | 
			
		||||
    ) -> ScalarResult[_T]:
 | 
			
		||||
        # TODO: the documentation says execution_options accepts a dict, but only
 | 
			
		||||
        # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
 | 
			
		||||
        execution_options = execution_options.union({"prebuffer_rows": True})  # type: ignore
 | 
			
		||||
 | 
			
		||||
        return await greenlet_spawn(  # type: ignore
 | 
			
		||||
            self.sync_session.exec,
 | 
			
		||||
            statement,
 | 
			
		||||
            params=params,
 | 
			
		||||
            execution_options=execution_options,
 | 
			
		||||
            bind_arguments=bind_arguments,
 | 
			
		||||
            **kw,
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										631
									
								
								sqlmodel/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										631
									
								
								sqlmodel/main.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,631 @@
 | 
			
		||||
import ipaddress
 | 
			
		||||
import uuid
 | 
			
		||||
import weakref
 | 
			
		||||
from datetime import date, datetime, time, timedelta
 | 
			
		||||
from decimal import Decimal
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import (
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
    AbstractSet,
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    ClassVar,
 | 
			
		||||
    Dict,
 | 
			
		||||
    ForwardRef,
 | 
			
		||||
    List,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Set,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
    cast,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from pydantic.errors import ConfigError, DictError
 | 
			
		||||
from pydantic.fields import FieldInfo as PydanticFieldInfo
 | 
			
		||||
from pydantic.fields import ModelField, Undefined, UndefinedType
 | 
			
		||||
from pydantic.main import BaseConfig, ModelMetaclass, validate_model
 | 
			
		||||
from pydantic.typing import NoArgAnyCallable, resolve_annotations
 | 
			
		||||
from pydantic.utils import ROOT_KEY, Representation, ValueItems
 | 
			
		||||
from sqlalchemy import (
 | 
			
		||||
    Boolean,
 | 
			
		||||
    Column,
 | 
			
		||||
    Date,
 | 
			
		||||
    DateTime,
 | 
			
		||||
    Float,
 | 
			
		||||
    ForeignKey,
 | 
			
		||||
    Integer,
 | 
			
		||||
    Interval,
 | 
			
		||||
    Numeric,
 | 
			
		||||
    inspect,
 | 
			
		||||
)
 | 
			
		||||
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
 | 
			
		||||
from sqlalchemy.orm.attributes import set_attribute
 | 
			
		||||
from sqlalchemy.orm.decl_api import DeclarativeMeta
 | 
			
		||||
from sqlalchemy.orm.instrumentation import is_instrumented
 | 
			
		||||
from sqlalchemy.sql.schema import MetaData
 | 
			
		||||
from sqlalchemy.sql.sqltypes import LargeBinary, Time
 | 
			
		||||
 | 
			
		||||
from .sql.sqltypes import GUID, AutoString
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def __dataclass_transform__(
 | 
			
		||||
    *,
 | 
			
		||||
    eq_default: bool = True,
 | 
			
		||||
    order_default: bool = False,
 | 
			
		||||
    kw_only_default: bool = False,
 | 
			
		||||
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
 | 
			
		||||
) -> Callable[[_T], _T]:
 | 
			
		||||
    return lambda a: a
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FieldInfo(PydanticFieldInfo):
 | 
			
		||||
    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
 | 
			
		||||
        primary_key = kwargs.pop("primary_key", False)
 | 
			
		||||
        nullable = kwargs.pop("nullable", Undefined)
 | 
			
		||||
        foreign_key = kwargs.pop("foreign_key", Undefined)
 | 
			
		||||
        index = kwargs.pop("index", Undefined)
 | 
			
		||||
        sa_column = kwargs.pop("sa_column", Undefined)
 | 
			
		||||
        sa_column_args = kwargs.pop("sa_column_args", Undefined)
 | 
			
		||||
        sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
 | 
			
		||||
        if sa_column is not Undefined:
 | 
			
		||||
            if sa_column_args is not Undefined:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    "Passing sa_column_args is not supported when "
 | 
			
		||||
                    "also passing a sa_column"
 | 
			
		||||
                )
 | 
			
		||||
            if sa_column_kwargs is not Undefined:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    "Passing sa_column_kwargs is not supported when "
 | 
			
		||||
                    "also passing a sa_column"
 | 
			
		||||
                )
 | 
			
		||||
        super().__init__(default=default, **kwargs)
 | 
			
		||||
        self.primary_key = primary_key
 | 
			
		||||
        self.nullable = nullable
 | 
			
		||||
        self.foreign_key = foreign_key
 | 
			
		||||
        self.index = index
 | 
			
		||||
        self.sa_column = sa_column
 | 
			
		||||
        self.sa_column_args = sa_column_args
 | 
			
		||||
        self.sa_column_kwargs = sa_column_kwargs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RelationshipInfo(Representation):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        *,
 | 
			
		||||
        back_populates: Optional[str] = None,
 | 
			
		||||
        link_model: Optional[Any] = None,
 | 
			
		||||
        sa_relationship: Optional[RelationshipProperty] = None,
 | 
			
		||||
        sa_relationship_args: Optional[Sequence[Any]] = None,
 | 
			
		||||
        sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        if sa_relationship is not None:
 | 
			
		||||
            if sa_relationship_args is not None:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    "Passing sa_relationship_args is not supported when "
 | 
			
		||||
                    "also passing a sa_relationship"
 | 
			
		||||
                )
 | 
			
		||||
            if sa_relationship_kwargs is not None:
 | 
			
		||||
                raise RuntimeError(
 | 
			
		||||
                    "Passing sa_relationship_kwargs is not supported when "
 | 
			
		||||
                    "also passing a sa_relationship"
 | 
			
		||||
                )
 | 
			
		||||
        self.back_populates = back_populates
 | 
			
		||||
        self.link_model = link_model
 | 
			
		||||
        self.sa_relationship = sa_relationship
 | 
			
		||||
        self.sa_relationship_args = sa_relationship_args
 | 
			
		||||
        self.sa_relationship_kwargs = sa_relationship_kwargs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def Field(
 | 
			
		||||
    default: Any = Undefined,
 | 
			
		||||
    *,
 | 
			
		||||
    default_factory: Optional[NoArgAnyCallable] = None,
 | 
			
		||||
    alias: str = None,
 | 
			
		||||
    title: str = None,
 | 
			
		||||
    description: str = None,
 | 
			
		||||
    exclude: Union[
 | 
			
		||||
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
 | 
			
		||||
    ] = None,
 | 
			
		||||
    include: Union[
 | 
			
		||||
        AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any
 | 
			
		||||
    ] = None,
 | 
			
		||||
    const: bool = None,
 | 
			
		||||
    gt: float = None,
 | 
			
		||||
    ge: float = None,
 | 
			
		||||
    lt: float = None,
 | 
			
		||||
    le: float = None,
 | 
			
		||||
    multiple_of: float = None,
 | 
			
		||||
    min_items: int = None,
 | 
			
		||||
    max_items: int = None,
 | 
			
		||||
    min_length: int = None,
 | 
			
		||||
    max_length: int = None,
 | 
			
		||||
    allow_mutation: bool = True,
 | 
			
		||||
    regex: str = None,
 | 
			
		||||
    primary_key: bool = False,
 | 
			
		||||
    foreign_key: Optional[Any] = None,
 | 
			
		||||
    nullable: Union[bool, UndefinedType] = Undefined,
 | 
			
		||||
    index: Union[bool, UndefinedType] = Undefined,
 | 
			
		||||
    sa_column: Union[Column, UndefinedType] = Undefined,
 | 
			
		||||
    sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
 | 
			
		||||
    sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
 | 
			
		||||
    schema_extra: Optional[Dict[str, Any]] = None,
 | 
			
		||||
) -> Any:
 | 
			
		||||
    current_schema_extra = schema_extra or {}
 | 
			
		||||
    field_info = FieldInfo(
 | 
			
		||||
        default,
 | 
			
		||||
        default_factory=default_factory,
 | 
			
		||||
        alias=alias,
 | 
			
		||||
        title=title,
 | 
			
		||||
        description=description,
 | 
			
		||||
        exclude=exclude,
 | 
			
		||||
        include=include,
 | 
			
		||||
        const=const,
 | 
			
		||||
        gt=gt,
 | 
			
		||||
        ge=ge,
 | 
			
		||||
        lt=lt,
 | 
			
		||||
        le=le,
 | 
			
		||||
        multiple_of=multiple_of,
 | 
			
		||||
        min_items=min_items,
 | 
			
		||||
        max_items=max_items,
 | 
			
		||||
        min_length=min_length,
 | 
			
		||||
        max_length=max_length,
 | 
			
		||||
        allow_mutation=allow_mutation,
 | 
			
		||||
        regex=regex,
 | 
			
		||||
        primary_key=primary_key,
 | 
			
		||||
        foreign_key=foreign_key,
 | 
			
		||||
        nullable=nullable,
 | 
			
		||||
        index=index,
 | 
			
		||||
        sa_column=sa_column,
 | 
			
		||||
        sa_column_args=sa_column_args,
 | 
			
		||||
        sa_column_kwargs=sa_column_kwargs,
 | 
			
		||||
        **current_schema_extra,
 | 
			
		||||
    )
 | 
			
		||||
    field_info._validate()
 | 
			
		||||
    return field_info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def Relationship(
 | 
			
		||||
    *,
 | 
			
		||||
    back_populates: Optional[str] = None,
 | 
			
		||||
    link_model: Optional[Any] = None,
 | 
			
		||||
    sa_relationship: Optional[RelationshipProperty] = None,
 | 
			
		||||
    sa_relationship_args: Optional[Sequence[Any]] = None,
 | 
			
		||||
    sa_relationship_kwargs: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
) -> Any:
 | 
			
		||||
    relationship_info = RelationshipInfo(
 | 
			
		||||
        back_populates=back_populates,
 | 
			
		||||
        link_model=link_model,
 | 
			
		||||
        sa_relationship=sa_relationship,
 | 
			
		||||
        sa_relationship_args=sa_relationship_args,
 | 
			
		||||
        sa_relationship_kwargs=sa_relationship_kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    return relationship_info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
 | 
			
		||||
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
 | 
			
		||||
    __sqlmodel_relationships__: Dict[str, RelationshipInfo]
 | 
			
		||||
    __config__: Type[BaseConfig]
 | 
			
		||||
    __fields__: Dict[str, ModelField]
 | 
			
		||||
 | 
			
		||||
    # Replicate SQLAlchemy
 | 
			
		||||
    def __setattr__(cls, name: str, value: Any) -> None:
 | 
			
		||||
        if getattr(cls.__config__, "table", False):  # type: ignore
 | 
			
		||||
            DeclarativeMeta.__setattr__(cls, name, value)
 | 
			
		||||
        else:
 | 
			
		||||
            super().__setattr__(name, value)
 | 
			
		||||
 | 
			
		||||
    def __delattr__(cls, name: str) -> None:
 | 
			
		||||
        if getattr(cls.__config__, "table", False):  # type: ignore
 | 
			
		||||
            DeclarativeMeta.__delattr__(cls, name)
 | 
			
		||||
        else:
 | 
			
		||||
            super().__delattr__(name)
 | 
			
		||||
 | 
			
		||||
    # From Pydantic
 | 
			
		||||
    def __new__(cls, name, bases, class_dict: dict, **kwargs) -> Any:
 | 
			
		||||
        relationships: Dict[str, RelationshipInfo] = {}
 | 
			
		||||
        dict_for_pydantic = {}
 | 
			
		||||
        original_annotations = resolve_annotations(
 | 
			
		||||
            class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
 | 
			
		||||
        )
 | 
			
		||||
        pydantic_annotations = {}
 | 
			
		||||
        relationship_annotations = {}
 | 
			
		||||
        for k, v in class_dict.items():
 | 
			
		||||
            if isinstance(v, RelationshipInfo):
 | 
			
		||||
                relationships[k] = v
 | 
			
		||||
            else:
 | 
			
		||||
                dict_for_pydantic[k] = v
 | 
			
		||||
        for k, v in original_annotations.items():
 | 
			
		||||
            if k in relationships:
 | 
			
		||||
                relationship_annotations[k] = v
 | 
			
		||||
            else:
 | 
			
		||||
                pydantic_annotations[k] = v
 | 
			
		||||
        dict_used = {
 | 
			
		||||
            **dict_for_pydantic,
 | 
			
		||||
            "__weakref__": None,
 | 
			
		||||
            "__sqlmodel_relationships__": relationships,
 | 
			
		||||
            "__annotations__": pydantic_annotations,
 | 
			
		||||
        }
 | 
			
		||||
        # Duplicate logic from Pydantic to filter config kwargs because if they are
 | 
			
		||||
        # passed directly including the registry Pydantic will pass them over to the
 | 
			
		||||
        # superclass causing an error
 | 
			
		||||
        allowed_config_kwargs: Set[str] = {
 | 
			
		||||
            key
 | 
			
		||||
            for key in dir(BaseConfig)
 | 
			
		||||
            if not (
 | 
			
		||||
                key.startswith("__") and key.endswith("__")
 | 
			
		||||
            )  # skip dunder methods and attributes
 | 
			
		||||
        }
 | 
			
		||||
        pydantic_kwargs = kwargs.copy()
 | 
			
		||||
        config_kwargs = {
 | 
			
		||||
            key: pydantic_kwargs.pop(key)
 | 
			
		||||
            for key in pydantic_kwargs.keys() & allowed_config_kwargs
 | 
			
		||||
        }
 | 
			
		||||
        new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
 | 
			
		||||
        new_cls.__annotations__ = {
 | 
			
		||||
            **relationship_annotations,
 | 
			
		||||
            **pydantic_annotations,
 | 
			
		||||
            **new_cls.__annotations__,
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        def get_config(name: str) -> Any:
 | 
			
		||||
            config_class_value = getattr(new_cls.__config__, name, Undefined)
 | 
			
		||||
            if config_class_value is not Undefined:
 | 
			
		||||
                return config_class_value
 | 
			
		||||
            kwarg_value = kwargs.get(name, Undefined)
 | 
			
		||||
            if kwarg_value is not Undefined:
 | 
			
		||||
                return kwarg_value
 | 
			
		||||
            return Undefined
 | 
			
		||||
 | 
			
		||||
        config_table = get_config("table")
 | 
			
		||||
        if config_table is True:
 | 
			
		||||
            # If it was passed by kwargs, ensure it's also set in config
 | 
			
		||||
            new_cls.__config__.table = config_table
 | 
			
		||||
            for k, v in new_cls.__fields__.items():
 | 
			
		||||
                col = get_column_from_field(v)
 | 
			
		||||
                setattr(new_cls, k, col)
 | 
			
		||||
            # Set a config flag to tell FastAPI that this should be read with a field
 | 
			
		||||
            # in orm_mode instead of preemptively converting it to a dict.
 | 
			
		||||
            # This could be done by reading new_cls.__config__.table in FastAPI, but
 | 
			
		||||
            # that's very specific about SQLModel, so let's have another config that
 | 
			
		||||
            # other future tools based on Pydantic can use.
 | 
			
		||||
            new_cls.__config__.read_with_orm_mode = True
 | 
			
		||||
 | 
			
		||||
        config_registry = get_config("registry")
 | 
			
		||||
        if config_registry is not Undefined:
 | 
			
		||||
            config_registry = cast(registry, config_registry)
 | 
			
		||||
            # If it was passed by kwargs, ensure it's also set in config
 | 
			
		||||
            new_cls.__config__.registry = config_table
 | 
			
		||||
            setattr(new_cls, "_sa_registry", config_registry)
 | 
			
		||||
            setattr(new_cls, "metadata", config_registry.metadata)
 | 
			
		||||
            setattr(new_cls, "__abstract__", True)
 | 
			
		||||
        return new_cls
 | 
			
		||||
 | 
			
		||||
    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
 | 
			
		||||
    def __init__(
 | 
			
		||||
        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # Only one of the base classes (or the current one) should be a table model
 | 
			
		||||
        # this allows FastAPI cloning a SQLModel for the response_model without
 | 
			
		||||
        # trying to create a new SQLAlchemy, for a new table, with the same name, that
 | 
			
		||||
        # triggers an error
 | 
			
		||||
        base_is_table = False
 | 
			
		||||
        for base in bases:
 | 
			
		||||
            config = getattr(base, "__config__")
 | 
			
		||||
            if config and getattr(config, "table", False):
 | 
			
		||||
                base_is_table = True
 | 
			
		||||
                break
 | 
			
		||||
        if getattr(cls.__config__, "table", False) and not base_is_table:
 | 
			
		||||
            dict_used = dict_.copy()
 | 
			
		||||
            for field_name, field_value in cls.__fields__.items():
 | 
			
		||||
                dict_used[field_name] = get_column_from_field(field_value)
 | 
			
		||||
            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
 | 
			
		||||
                if rel_info.sa_relationship:
 | 
			
		||||
                    # There's a SQLAlchemy relationship declared, that takes precedence
 | 
			
		||||
                    # over anything else, use that and continue with the next attribute
 | 
			
		||||
                    dict_used[rel_name] = rel_info.sa_relationship
 | 
			
		||||
                    continue
 | 
			
		||||
                ann = cls.__annotations__[rel_name]
 | 
			
		||||
                temp_field = ModelField.infer(
 | 
			
		||||
                    name=rel_name,
 | 
			
		||||
                    value=rel_info,
 | 
			
		||||
                    annotation=ann,
 | 
			
		||||
                    class_validators=None,
 | 
			
		||||
                    config=BaseConfig,
 | 
			
		||||
                )
 | 
			
		||||
                relationship_to = temp_field.type_
 | 
			
		||||
                if isinstance(temp_field.type_, ForwardRef):
 | 
			
		||||
                    relationship_to = temp_field.type_.__forward_arg__
 | 
			
		||||
                rel_kwargs: Dict[str, Any] = {}
 | 
			
		||||
                if rel_info.back_populates:
 | 
			
		||||
                    rel_kwargs["back_populates"] = rel_info.back_populates
 | 
			
		||||
                if rel_info.link_model:
 | 
			
		||||
                    ins = inspect(rel_info.link_model)
 | 
			
		||||
                    local_table = getattr(ins, "local_table")
 | 
			
		||||
                    if local_table is None:
 | 
			
		||||
                        raise RuntimeError(
 | 
			
		||||
                            "Couldn't find the secondary table for "
 | 
			
		||||
                            f"model {rel_info.link_model}"
 | 
			
		||||
                        )
 | 
			
		||||
                    rel_kwargs["secondary"] = local_table
 | 
			
		||||
                rel_args: List[Any] = []
 | 
			
		||||
                if rel_info.sa_relationship_args:
 | 
			
		||||
                    rel_args.extend(rel_info.sa_relationship_args)
 | 
			
		||||
                if rel_info.sa_relationship_kwargs:
 | 
			
		||||
                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
 | 
			
		||||
                rel_value: RelationshipProperty = relationship(
 | 
			
		||||
                    relationship_to, *rel_args, **rel_kwargs
 | 
			
		||||
                )
 | 
			
		||||
                dict_used[rel_name] = rel_value
 | 
			
		||||
            DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
 | 
			
		||||
        else:
 | 
			
		||||
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_sqlachemy_type(field: ModelField) -> Any:
 | 
			
		||||
    if issubclass(field.type_, str):
 | 
			
		||||
        if field.field_info.max_length:
 | 
			
		||||
            return AutoString(length=field.field_info.max_length)
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, float):
 | 
			
		||||
        return Float
 | 
			
		||||
    if issubclass(field.type_, bool):
 | 
			
		||||
        return Boolean
 | 
			
		||||
    if issubclass(field.type_, int):
 | 
			
		||||
        return Integer
 | 
			
		||||
    if issubclass(field.type_, datetime):
 | 
			
		||||
        return DateTime
 | 
			
		||||
    if issubclass(field.type_, date):
 | 
			
		||||
        return Date
 | 
			
		||||
    if issubclass(field.type_, timedelta):
 | 
			
		||||
        return Interval
 | 
			
		||||
    if issubclass(field.type_, time):
 | 
			
		||||
        return Time
 | 
			
		||||
    if issubclass(field.type_, Enum):
 | 
			
		||||
        return Enum
 | 
			
		||||
    if issubclass(field.type_, bytes):
 | 
			
		||||
        return LargeBinary
 | 
			
		||||
    if issubclass(field.type_, Decimal):
 | 
			
		||||
        return Numeric
 | 
			
		||||
    if issubclass(field.type_, ipaddress.IPv4Address):
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, ipaddress.IPv4Network):
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, ipaddress.IPv6Address):
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, ipaddress.IPv6Network):
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, Path):
 | 
			
		||||
        return AutoString
 | 
			
		||||
    if issubclass(field.type_, uuid.UUID):
 | 
			
		||||
        return GUID
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_column_from_field(field: ModelField) -> Column:
 | 
			
		||||
    sa_column = getattr(field.field_info, "sa_column", Undefined)
 | 
			
		||||
    if isinstance(sa_column, Column):
 | 
			
		||||
        return sa_column
 | 
			
		||||
    sa_type = get_sqlachemy_type(field)
 | 
			
		||||
    primary_key = getattr(field.field_info, "primary_key", False)
 | 
			
		||||
    nullable = not field.required
 | 
			
		||||
    index = getattr(field.field_info, "index", Undefined)
 | 
			
		||||
    if index is Undefined:
 | 
			
		||||
        index = True
 | 
			
		||||
    if hasattr(field.field_info, "nullable"):
 | 
			
		||||
        field_nullable = getattr(field.field_info, "nullable")
 | 
			
		||||
        if field_nullable != Undefined:
 | 
			
		||||
            nullable = field_nullable
 | 
			
		||||
    args = []
 | 
			
		||||
    foreign_key = getattr(field.field_info, "foreign_key", None)
 | 
			
		||||
    if foreign_key:
 | 
			
		||||
        args.append(ForeignKey(foreign_key))
 | 
			
		||||
    kwargs = {
 | 
			
		||||
        "primary_key": primary_key,
 | 
			
		||||
        "nullable": nullable,
 | 
			
		||||
        "index": index,
 | 
			
		||||
    }
 | 
			
		||||
    sa_default = Undefined
 | 
			
		||||
    if field.field_info.default_factory:
 | 
			
		||||
        sa_default = field.field_info.default_factory
 | 
			
		||||
    elif field.field_info.default is not Undefined:
 | 
			
		||||
        sa_default = field.field_info.default
 | 
			
		||||
    if sa_default is not Undefined:
 | 
			
		||||
        kwargs["default"] = sa_default
 | 
			
		||||
    sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
 | 
			
		||||
    if sa_column_args is not Undefined:
 | 
			
		||||
        args.extend(list(cast(Sequence, sa_column_args)))
 | 
			
		||||
    sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
 | 
			
		||||
    if sa_column_kwargs is not Undefined:
 | 
			
		||||
        kwargs.update(cast(dict, sa_column_kwargs))
 | 
			
		||||
    return Column(sa_type, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class_registry = weakref.WeakValueDictionary()  # type: ignore
 | 
			
		||||
 | 
			
		||||
default_registry = registry()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
 | 
			
		||||
    # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
 | 
			
		||||
    __slots__ = ("__weakref__",)
 | 
			
		||||
    __tablename__: ClassVar[Union[str, Callable[..., str]]]
 | 
			
		||||
    __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]]
 | 
			
		||||
    __name__: ClassVar[str]
 | 
			
		||||
    metadata: ClassVar[MetaData]
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        orm_mode = True
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, *args, **kwargs) -> Any:
 | 
			
		||||
        new_object = super().__new__(cls)
 | 
			
		||||
        # SQLAlchemy doesn't call __init__ on the base class
 | 
			
		||||
        # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
 | 
			
		||||
        # Set __fields_set__ here, that would have been set when calling __init__
 | 
			
		||||
        # in the Pydantic model so that when SQLAlchemy sets attributes that are
 | 
			
		||||
        # added (e.g. when querying from DB) to the __fields_set__, this already exists
 | 
			
		||||
        object.__setattr__(new_object, "__fields_set__", set())
 | 
			
		||||
        return new_object
 | 
			
		||||
 | 
			
		||||
    def __init__(__pydantic_self__, **data: Any) -> None:
 | 
			
		||||
        # Uses something other than `self` the first arg to allow "self" as a
 | 
			
		||||
        # settable attribute
 | 
			
		||||
        if TYPE_CHECKING:
 | 
			
		||||
            __pydantic_self__.__dict__: Dict[str, Any] = {}
 | 
			
		||||
            __pydantic_self__.__fields_set__: Set[str] = set()
 | 
			
		||||
        values, fields_set, validation_error = validate_model(
 | 
			
		||||
            __pydantic_self__.__class__, data
 | 
			
		||||
        )
 | 
			
		||||
        # Only raise errors if not a SQLModel model
 | 
			
		||||
        if (
 | 
			
		||||
            not getattr(__pydantic_self__.__config__, "table", False)
 | 
			
		||||
            and validation_error
 | 
			
		||||
        ):
 | 
			
		||||
            raise validation_error
 | 
			
		||||
        # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy
 | 
			
		||||
        # can handle them
 | 
			
		||||
        # object.__setattr__(__pydantic_self__, '__dict__', values)
 | 
			
		||||
        object.__setattr__(__pydantic_self__, "__fields_set__", fields_set)
 | 
			
		||||
        for key, value in values.items():
 | 
			
		||||
            setattr(__pydantic_self__, key, value)
 | 
			
		||||
        non_pydantic_keys = data.keys() - values.keys()
 | 
			
		||||
        for key in non_pydantic_keys:
 | 
			
		||||
            if key in __pydantic_self__.__sqlmodel_relationships__:
 | 
			
		||||
                setattr(__pydantic_self__, key, data[key])
 | 
			
		||||
 | 
			
		||||
    def __setattr__(self, name: str, value: Any) -> None:
 | 
			
		||||
        if name in {"_sa_instance_state"}:
 | 
			
		||||
            self.__dict__[name] = value
 | 
			
		||||
            return
 | 
			
		||||
        else:
 | 
			
		||||
            # Set in SQLAlchemy, before Pydantic to trigger events and updates
 | 
			
		||||
            if getattr(self.__config__, "table", False):
 | 
			
		||||
                if is_instrumented(self, name):
 | 
			
		||||
                    set_attribute(self, name, value)
 | 
			
		||||
            # Set in Pydantic model to trigger possible validation changes, only for
 | 
			
		||||
            # non relationship values
 | 
			
		||||
            if name not in self.__sqlmodel_relationships__:
 | 
			
		||||
                super().__setattr__(name, value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_orm(cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None):
 | 
			
		||||
        # Duplicated from Pydantic
 | 
			
		||||
        if not cls.__config__.orm_mode:
 | 
			
		||||
            raise ConfigError(
 | 
			
		||||
                "You must have the config attribute orm_mode=True to use from_orm"
 | 
			
		||||
            )
 | 
			
		||||
        obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj)
 | 
			
		||||
        # SQLModel, support update dict
 | 
			
		||||
        if update is not None:
 | 
			
		||||
            obj = {**obj, **update}
 | 
			
		||||
        # End SQLModel support dict
 | 
			
		||||
        if not getattr(cls.__config__, "table", False):
 | 
			
		||||
            # If not table, normal Pydantic code
 | 
			
		||||
            m = cls.__new__(cls)
 | 
			
		||||
        else:
 | 
			
		||||
            # If table, create the new instance normally to make SQLAlchemy create
 | 
			
		||||
            # the _sa_instance_state attribute
 | 
			
		||||
            m = cls()
 | 
			
		||||
        values, fields_set, validation_error = validate_model(cls, obj)
 | 
			
		||||
        if validation_error:
 | 
			
		||||
            raise validation_error
 | 
			
		||||
        # Updated to trigger SQLAlchemy internal handling
 | 
			
		||||
        if not getattr(cls.__config__, "table", False):
 | 
			
		||||
            object.__setattr__(m, "__dict__", values)
 | 
			
		||||
        else:
 | 
			
		||||
            for key, value in values.items():
 | 
			
		||||
                setattr(m, key, value)
 | 
			
		||||
        # Continue with standard Pydantic logic
 | 
			
		||||
        object.__setattr__(m, "__fields_set__", fields_set)
 | 
			
		||||
        m._init_private_attributes()
 | 
			
		||||
        return m
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def parse_obj(
 | 
			
		||||
        cls: Type["SQLModel"], obj: Any, update: Dict[str, Any] = None
 | 
			
		||||
    ) -> "SQLModel":
 | 
			
		||||
        obj = cls._enforce_dict_if_root(obj)
 | 
			
		||||
        # SQLModel, support update dict
 | 
			
		||||
        if update is not None:
 | 
			
		||||
            obj = {**obj, **update}
 | 
			
		||||
        # End SQLModel support dict
 | 
			
		||||
        return super().parse_obj(obj)
 | 
			
		||||
 | 
			
		||||
    def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
 | 
			
		||||
        # Don't show SQLAlchemy private attributes
 | 
			
		||||
        return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")]
 | 
			
		||||
 | 
			
		||||
    # From Pydantic, override to enforce validation with dict
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def validate(cls: Type["SQLModel"], value: Any) -> "SQLModel":
 | 
			
		||||
        if isinstance(value, cls):
 | 
			
		||||
            return value.copy() if cls.__config__.copy_on_model_validation else value
 | 
			
		||||
 | 
			
		||||
        value = cls._enforce_dict_if_root(value)
 | 
			
		||||
        if isinstance(value, dict):
 | 
			
		||||
            values, fields_set, validation_error = validate_model(cls, value)
 | 
			
		||||
            if validation_error:
 | 
			
		||||
                raise validation_error
 | 
			
		||||
            model = cls(**values)
 | 
			
		||||
            # Reset fields set, this would have been done in Pydantic in __init__
 | 
			
		||||
            object.__setattr__(model, "__fields_set__", fields_set)
 | 
			
		||||
            return model
 | 
			
		||||
        elif cls.__config__.orm_mode:
 | 
			
		||||
            return cls.from_orm(value)
 | 
			
		||||
        elif cls.__custom_root_type__:
 | 
			
		||||
            return cls.parse_obj(value)
 | 
			
		||||
        else:
 | 
			
		||||
            try:
 | 
			
		||||
                value_as_dict = dict(value)
 | 
			
		||||
            except (TypeError, ValueError) as e:
 | 
			
		||||
                raise DictError() from e
 | 
			
		||||
            return cls(**value_as_dict)
 | 
			
		||||
 | 
			
		||||
    # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes
 | 
			
		||||
    def _calculate_keys(
 | 
			
		||||
        self,
 | 
			
		||||
        include: Optional[Mapping[Union[int, str], Any]],
 | 
			
		||||
        exclude: Optional[Mapping[Union[int, str], Any]],
 | 
			
		||||
        exclude_unset: bool,
 | 
			
		||||
        update: Optional[Dict[str, Any]] = None,
 | 
			
		||||
    ) -> Optional[AbstractSet[str]]:
 | 
			
		||||
        if include is None and exclude is None and exclude_unset is False:
 | 
			
		||||
            # Original in Pydantic:
 | 
			
		||||
            # return None
 | 
			
		||||
            # Updated to not return SQLAlchemy attributes
 | 
			
		||||
            # Do not include relationships as that would easily lead to infinite
 | 
			
		||||
            # recursion, or traversing the whole database
 | 
			
		||||
            return self.__fields__.keys()  # | self.__sqlmodel_relationships__.keys()
 | 
			
		||||
 | 
			
		||||
        keys: AbstractSet[str]
 | 
			
		||||
        if exclude_unset:
 | 
			
		||||
            keys = self.__fields_set__.copy()
 | 
			
		||||
        else:
 | 
			
		||||
            # Original in Pydantic:
 | 
			
		||||
            # keys = self.__dict__.keys()
 | 
			
		||||
            # Updated to not return SQLAlchemy attributes
 | 
			
		||||
            # Do not include relationships as that would easily lead to infinite
 | 
			
		||||
            # recursion, or traversing the whole database
 | 
			
		||||
            keys = self.__fields__.keys()  # | self.__sqlmodel_relationships__.keys()
 | 
			
		||||
 | 
			
		||||
        if include is not None:
 | 
			
		||||
            keys &= include.keys()
 | 
			
		||||
 | 
			
		||||
        if update:
 | 
			
		||||
            keys -= update.keys()
 | 
			
		||||
 | 
			
		||||
        if exclude:
 | 
			
		||||
            keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)}
 | 
			
		||||
 | 
			
		||||
        return keys
 | 
			
		||||
 | 
			
		||||
    @declared_attr  # type: ignore
 | 
			
		||||
    def __tablename__(cls) -> str:
 | 
			
		||||
        return cls.__name__.lower()
 | 
			
		||||
							
								
								
									
										0
									
								
								sqlmodel/orm/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlmodel/orm/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										135
									
								
								sqlmodel/orm/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								sqlmodel/orm/session.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,135 @@
 | 
			
		||||
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import util
 | 
			
		||||
from sqlalchemy.orm import Query as _Query
 | 
			
		||||
from sqlalchemy.orm import Session as _Session
 | 
			
		||||
from sqlalchemy.sql.base import Executable as _Executable
 | 
			
		||||
from sqlmodel.sql.expression import Select, SelectOfScalar
 | 
			
		||||
from typing_extensions import Literal
 | 
			
		||||
 | 
			
		||||
from ..engine.result import Result, ScalarResult
 | 
			
		||||
from ..sql.base import Executable
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Session(_Session):
 | 
			
		||||
    @overload
 | 
			
		||||
    def exec(
 | 
			
		||||
        self,
 | 
			
		||||
        statement: Select[_T],
 | 
			
		||||
        *,
 | 
			
		||||
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
 | 
			
		||||
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
 | 
			
		||||
        bind_arguments: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
        _parent_execute_state: Optional[Any] = None,
 | 
			
		||||
        _add_event: Optional[Any] = None,
 | 
			
		||||
        **kw: Any,
 | 
			
		||||
    ) -> Union[Result[_T]]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    @overload
 | 
			
		||||
    def exec(
 | 
			
		||||
        self,
 | 
			
		||||
        statement: SelectOfScalar[_T],
 | 
			
		||||
        *,
 | 
			
		||||
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
 | 
			
		||||
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
 | 
			
		||||
        bind_arguments: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
        _parent_execute_state: Optional[Any] = None,
 | 
			
		||||
        _add_event: Optional[Any] = None,
 | 
			
		||||
        **kw: Any,
 | 
			
		||||
    ) -> Union[ScalarResult[_T]]:
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def exec(
 | 
			
		||||
        self,
 | 
			
		||||
        statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]],
 | 
			
		||||
        *,
 | 
			
		||||
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
 | 
			
		||||
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
 | 
			
		||||
        bind_arguments: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
        _parent_execute_state: Optional[Any] = None,
 | 
			
		||||
        _add_event: Optional[Any] = None,
 | 
			
		||||
        **kw: Any,
 | 
			
		||||
    ) -> Union[Result[_T], ScalarResult[_T]]:
 | 
			
		||||
        results = super().execute(
 | 
			
		||||
            statement,
 | 
			
		||||
            params=params,
 | 
			
		||||
            execution_options=execution_options,  # type: ignore
 | 
			
		||||
            bind_arguments=bind_arguments,
 | 
			
		||||
            _parent_execute_state=_parent_execute_state,
 | 
			
		||||
            _add_event=_add_event,
 | 
			
		||||
            **kw,
 | 
			
		||||
        )
 | 
			
		||||
        if isinstance(statement, SelectOfScalar):
 | 
			
		||||
            return results.scalars()  # type: ignore
 | 
			
		||||
        return results  # type: ignore
 | 
			
		||||
 | 
			
		||||
    def execute(
 | 
			
		||||
        self,
 | 
			
		||||
        statement: _Executable,
 | 
			
		||||
        params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
 | 
			
		||||
        execution_options: Mapping[str, Any] = util.EMPTY_DICT,
 | 
			
		||||
        bind_arguments: Optional[Mapping[str, Any]] = None,
 | 
			
		||||
        _parent_execute_state: Optional[Any] = None,
 | 
			
		||||
        _add_event: Optional[Any] = None,
 | 
			
		||||
        **kw: Any,
 | 
			
		||||
    ) -> Result[Any]:
 | 
			
		||||
        """
 | 
			
		||||
        🚨 You probably want to use `session.exec()` instead of `session.execute()`.
 | 
			
		||||
 | 
			
		||||
        This is the original SQLAlchemy `session.execute()` method that returns objects
 | 
			
		||||
        of type `Row`, and that you have to call `scalars()` to get the model objects.
 | 
			
		||||
 | 
			
		||||
        For example:
 | 
			
		||||
 | 
			
		||||
        ```Python
 | 
			
		||||
        heroes = session.execute(select(Hero)).scalars().all()
 | 
			
		||||
        ```
 | 
			
		||||
 | 
			
		||||
        instead you could use `exec()`:
 | 
			
		||||
 | 
			
		||||
        ```Python
 | 
			
		||||
        heroes = session.exec(select(Hero)).all()
 | 
			
		||||
        ```
 | 
			
		||||
        """
 | 
			
		||||
        return super().execute(  # type: ignore
 | 
			
		||||
            statement,
 | 
			
		||||
            params=params,
 | 
			
		||||
            execution_options=execution_options,  # type: ignore
 | 
			
		||||
            bind_arguments=bind_arguments,
 | 
			
		||||
            _parent_execute_state=_parent_execute_state,
 | 
			
		||||
            _add_event=_add_event,
 | 
			
		||||
            **kw,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]":
 | 
			
		||||
        """
 | 
			
		||||
        🚨 You probably want to use `session.exec()` instead of `session.query()`.
 | 
			
		||||
 | 
			
		||||
        `session.exec()` is SQLModel's own short version with increased type
 | 
			
		||||
        annotations.
 | 
			
		||||
 | 
			
		||||
        Or otherwise you might want to use `session.execute()` instead of
 | 
			
		||||
        `session.query()`.
 | 
			
		||||
        """
 | 
			
		||||
        return super().query(*entities, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def get(
 | 
			
		||||
        self,
 | 
			
		||||
        entity: _T,
 | 
			
		||||
        ident: Any,
 | 
			
		||||
        options: Optional[Sequence[Any]] = None,
 | 
			
		||||
        populate_existing: bool = False,
 | 
			
		||||
        with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None,
 | 
			
		||||
        identity_token: Optional[Any] = None,
 | 
			
		||||
    ) -> _T:
 | 
			
		||||
        return super().get(
 | 
			
		||||
            entity,
 | 
			
		||||
            ident,
 | 
			
		||||
            options=options,
 | 
			
		||||
            populate_existing=populate_existing,
 | 
			
		||||
            with_for_update=with_for_update,
 | 
			
		||||
            identity_token=identity_token,
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										1
									
								
								sqlmodel/pool/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								sqlmodel/pool/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from sqlalchemy.pool import StaticPool as StaticPool  # noqa: F401
 | 
			
		||||
							
								
								
									
										0
									
								
								sqlmodel/sql/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								sqlmodel/sql/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										11
									
								
								sqlmodel/sql/base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								sqlmodel/sql/base.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
from typing import Generic, TypeVar
 | 
			
		||||
 | 
			
		||||
from sqlalchemy.sql.base import Executable as _Executable
 | 
			
		||||
 | 
			
		||||
_T = TypeVar("_T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Executable(_Executable, Generic[_T]):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
 | 
			
		||||
        super(_Executable, self).__init__(*args, **kwargs)
 | 
			
		||||
							
								
								
									
										459
									
								
								sqlmodel/sql/expression.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										459
									
								
								sqlmodel/sql/expression.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,459 @@
 | 
			
		||||
# WARNING: do not modify this code, it is generated by expression.py.jinja2
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from typing import (
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
    Any,
 | 
			
		||||
    Generic,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
    cast,
 | 
			
		||||
    overload,
 | 
			
		||||
)
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import Column
 | 
			
		||||
from sqlalchemy.orm import InstrumentedAttribute
 | 
			
		||||
from sqlalchemy.sql.elements import ColumnClause
 | 
			
		||||
from sqlalchemy.sql.expression import Select as _Select
 | 
			
		||||
 | 
			
		||||
_TSelect = TypeVar("_TSelect")
 | 
			
		||||
 | 
			
		||||
# Workaround Generics incompatibility in Python 3.6
 | 
			
		||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
 | 
			
		||||
if sys.version_info.minor >= 7:
 | 
			
		||||
 | 
			
		||||
    class Select(_Select, Generic[_TSelect]):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
 | 
			
		||||
    # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
 | 
			
		||||
    # entity, so the result will be converted to a scalar by default. This way writing
 | 
			
		||||
    # for loops on the results will feel natural.
 | 
			
		||||
    class SelectOfScalar(_Select, Generic[_TSelect]):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
else:
 | 
			
		||||
    from typing import GenericMeta  # type: ignore
 | 
			
		||||
 | 
			
		||||
    class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    # Cast them for editors to work correctly, from several tricks tried, this works
 | 
			
		||||
    # for both VS Code and PyCharm
 | 
			
		||||
    Select = cast("Select", _Py36Select)  # type: ignore
 | 
			
		||||
    SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar)  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:  # pragma: no cover
 | 
			
		||||
    from ..main import SQLModel
 | 
			
		||||
 | 
			
		||||
# Generated TypeVars start
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TScalar_0 = TypeVar(
 | 
			
		||||
    "_TScalar_0",
 | 
			
		||||
    Column,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    UUID,
 | 
			
		||||
    datetime,
 | 
			
		||||
    float,
 | 
			
		||||
    int,
 | 
			
		||||
    bool,
 | 
			
		||||
    bytes,
 | 
			
		||||
    str,
 | 
			
		||||
    None,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
_TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TScalar_1 = TypeVar(
 | 
			
		||||
    "_TScalar_1",
 | 
			
		||||
    Column,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    UUID,
 | 
			
		||||
    datetime,
 | 
			
		||||
    float,
 | 
			
		||||
    int,
 | 
			
		||||
    bool,
 | 
			
		||||
    bytes,
 | 
			
		||||
    str,
 | 
			
		||||
    None,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
_TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TScalar_2 = TypeVar(
 | 
			
		||||
    "_TScalar_2",
 | 
			
		||||
    Column,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    UUID,
 | 
			
		||||
    datetime,
 | 
			
		||||
    float,
 | 
			
		||||
    int,
 | 
			
		||||
    bool,
 | 
			
		||||
    bytes,
 | 
			
		||||
    str,
 | 
			
		||||
    None,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
_TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_TScalar_3 = TypeVar(
 | 
			
		||||
    "_TScalar_3",
 | 
			
		||||
    Column,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    UUID,
 | 
			
		||||
    datetime,
 | 
			
		||||
    float,
 | 
			
		||||
    int,
 | 
			
		||||
    bool,
 | 
			
		||||
    bytes,
 | 
			
		||||
    str,
 | 
			
		||||
    None,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
_TModel_3 = TypeVar("_TModel_3", bound="SQLModel")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Generated TypeVars end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Generated overloads start
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: _TScalar_0,
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: _TScalar_1,
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: _TScalar_2,
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: _TScalar_3,
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    entity_0: Type[_TModel_0],
 | 
			
		||||
    entity_1: Type[_TModel_1],
 | 
			
		||||
    entity_2: Type[_TModel_2],
 | 
			
		||||
    entity_3: Type[_TModel_3],
 | 
			
		||||
    **kw: Any,
 | 
			
		||||
) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Generated overloads end
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
 | 
			
		||||
    if len(entities) == 1:
 | 
			
		||||
        return SelectOfScalar._create(*entities, **kw)  # type: ignore
 | 
			
		||||
    return Select._create(*entities, **kw)  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
 | 
			
		||||
def col(column_expression: Any) -> ColumnClause:
 | 
			
		||||
    if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
 | 
			
		||||
        raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
 | 
			
		||||
    return column_expression
 | 
			
		||||
							
								
								
									
										119
									
								
								sqlmodel/sql/expression.py.jinja2
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								sqlmodel/sql/expression.py.jinja2
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,119 @@
 | 
			
		||||
import sys
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from typing import (
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
    Any,
 | 
			
		||||
    Generic,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    Type,
 | 
			
		||||
    TypeVar,
 | 
			
		||||
    Union,
 | 
			
		||||
    cast,
 | 
			
		||||
    overload,
 | 
			
		||||
)
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import Column
 | 
			
		||||
from sqlalchemy.orm import InstrumentedAttribute
 | 
			
		||||
from sqlalchemy.sql.elements import ColumnClause
 | 
			
		||||
from sqlalchemy.sql.expression import Select as _Select
 | 
			
		||||
 | 
			
		||||
_TSelect = TypeVar("_TSelect")
 | 
			
		||||
 | 
			
		||||
# Workaround Generics incompatibility in Python 3.6
 | 
			
		||||
# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322
 | 
			
		||||
if sys.version_info.minor >= 7:
 | 
			
		||||
 | 
			
		||||
    class Select(_Select, Generic[_TSelect]):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different
 | 
			
		||||
    # purpose. This is the same as a normal SQLAlchemy Select class where there's only one
 | 
			
		||||
    # entity, so the result will be converted to a scalar by default. This way writing
 | 
			
		||||
    # for loops on the results will feel natural.
 | 
			
		||||
    class SelectOfScalar(_Select, Generic[_TSelect]):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
else:
 | 
			
		||||
    from typing import GenericMeta  # type: ignore
 | 
			
		||||
 | 
			
		||||
    class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    # Cast them for editors to work correctly, from several tricks tried, this works
 | 
			
		||||
    # for both VS Code and PyCharm
 | 
			
		||||
    Select = cast("Select", _Py36Select)  # type: ignore
 | 
			
		||||
    SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar)  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:  # pragma: no cover
 | 
			
		||||
    from ..main import SQLModel
 | 
			
		||||
 | 
			
		||||
# Generated TypeVars start
 | 
			
		||||
 | 
			
		||||
{% for i in range(number_of_types) %}
 | 
			
		||||
_TScalar_{{ i }} = TypeVar(
 | 
			
		||||
    "_TScalar_{{ i }}",
 | 
			
		||||
    Column,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Mapping,
 | 
			
		||||
    UUID,
 | 
			
		||||
    datetime,
 | 
			
		||||
    float,
 | 
			
		||||
    int,
 | 
			
		||||
    bool,
 | 
			
		||||
    bytes,
 | 
			
		||||
    str,
 | 
			
		||||
    None,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
 | 
			
		||||
 | 
			
		||||
{% endfor %}
 | 
			
		||||
 | 
			
		||||
# Generated TypeVars end
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Generated overloads start
 | 
			
		||||
 | 
			
		||||
{% for signature in signatures %}
 | 
			
		||||
 | 
			
		||||
@overload
 | 
			
		||||
def select(  # type: ignore
 | 
			
		||||
    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
 | 
			
		||||
    ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
{% endfor %}
 | 
			
		||||
 | 
			
		||||
# Generated overloads end
 | 
			
		||||
 | 
			
		||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
 | 
			
		||||
    if len(entities) == 1:
 | 
			
		||||
        return SelectOfScalar._create(*entities, **kw)  # type: ignore
 | 
			
		||||
    return Select._create(*entities, **kw)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
 | 
			
		||||
def col(column_expression: Any) -> ColumnClause:
 | 
			
		||||
    if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
 | 
			
		||||
        raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
 | 
			
		||||
    return column_expression
 | 
			
		||||
							
								
								
									
										60
									
								
								sqlmodel/sql/sqltypes.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								sqlmodel/sql/sqltypes.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,60 @@
 | 
			
		||||
import uuid
 | 
			
		||||
from typing import Any, cast
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import types
 | 
			
		||||
from sqlalchemy.dialects.postgresql import UUID
 | 
			
		||||
from sqlalchemy.engine.interfaces import Dialect
 | 
			
		||||
from sqlalchemy.types import CHAR, TypeDecorator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AutoString(types.TypeDecorator):
 | 
			
		||||
 | 
			
		||||
    impl = types.String
 | 
			
		||||
    cache_ok = True
 | 
			
		||||
    mysql_default_length = 255
 | 
			
		||||
 | 
			
		||||
    def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
 | 
			
		||||
        impl = cast(types.String, self.impl)
 | 
			
		||||
        if impl.length is None and dialect.name == "mysql":
 | 
			
		||||
            return dialect.type_descriptor(types.String(self.mysql_default_length))  # type: ignore
 | 
			
		||||
        return super().load_dialect_impl(dialect)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
 | 
			
		||||
# with small modifications
 | 
			
		||||
class GUID(TypeDecorator):
 | 
			
		||||
    """Platform-independent GUID type.
 | 
			
		||||
 | 
			
		||||
    Uses PostgreSQL's UUID type, otherwise uses
 | 
			
		||||
    CHAR(32), storing as stringified hex values.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    impl = CHAR
 | 
			
		||||
    cache_ok = True
 | 
			
		||||
 | 
			
		||||
    def load_dialect_impl(self, dialect):
 | 
			
		||||
        if dialect.name == "postgresql":
 | 
			
		||||
            return dialect.type_descriptor(UUID())
 | 
			
		||||
        else:
 | 
			
		||||
            return dialect.type_descriptor(CHAR(32))
 | 
			
		||||
 | 
			
		||||
    def process_bind_param(self, value, dialect):
 | 
			
		||||
        if value is None:
 | 
			
		||||
            return value
 | 
			
		||||
        elif dialect.name == "postgresql":
 | 
			
		||||
            return str(value)
 | 
			
		||||
        else:
 | 
			
		||||
            if not isinstance(value, uuid.UUID):
 | 
			
		||||
                return f"{uuid.UUID(value).int:x}"
 | 
			
		||||
            else:
 | 
			
		||||
                # hexstring
 | 
			
		||||
                return f"{value.int:x}"
 | 
			
		||||
 | 
			
		||||
    def process_result_value(self, value, dialect):
 | 
			
		||||
        if value is None:
 | 
			
		||||
            return value
 | 
			
		||||
        else:
 | 
			
		||||
            if not isinstance(value, uuid.UUID):
 | 
			
		||||
                value = uuid.UUID(value)
 | 
			
		||||
            return value
 | 
			
		||||
		Reference in New Issue
	
	Block a user