mirror of
https://github.com/fastapi/sqlmodel.git
synced 2025-12-19 04:58:50 +08:00
✨ Update type annotations and upgrade mypy (#173)
This commit is contained in:
committed by
GitHub
parent
02da85c9ec
commit
e30c7ef4e9
@@ -6,6 +6,4 @@ _T = TypeVar("_T")
|
||||
|
||||
|
||||
class Executable(_Executable, Generic[_T]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.__dict__["_exec_options"] = kwargs.pop("_exec_options", None)
|
||||
super(_Executable, self).__init__(*args, **kwargs)
|
||||
pass
|
||||
|
||||
@@ -45,10 +45,10 @@ else:
|
||||
class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore
|
||||
pass
|
||||
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): # type: ignore
|
||||
class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
|
||||
pass
|
||||
|
||||
# Cast them for editors to work correctly, from several tricks tried, this works
|
||||
@@ -65,9 +65,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
_TScalar_0 = TypeVar(
|
||||
"_TScalar_0",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -83,9 +83,9 @@ _TModel_0 = TypeVar("_TModel_0", bound="SQLModel")
|
||||
|
||||
_TScalar_1 = TypeVar(
|
||||
"_TScalar_1",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -101,9 +101,9 @@ _TModel_1 = TypeVar("_TModel_1", bound="SQLModel")
|
||||
|
||||
_TScalar_2 = TypeVar(
|
||||
"_TScalar_2",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -119,9 +119,9 @@ _TModel_2 = TypeVar("_TModel_2", bound="SQLModel")
|
||||
|
||||
_TScalar_3 = TypeVar(
|
||||
"_TScalar_3",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -446,14 +446,14 @@ def select( # type: ignore
|
||||
# Generated overloads end
|
||||
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
||||
@@ -63,9 +63,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
{% for i in range(number_of_types) %}
|
||||
_TScalar_{{ i }} = TypeVar(
|
||||
"_TScalar_{{ i }}",
|
||||
Column,
|
||||
Sequence,
|
||||
Mapping,
|
||||
Column, # type: ignore
|
||||
Sequence, # type: ignore
|
||||
Mapping, # type: ignore
|
||||
UUID,
|
||||
datetime,
|
||||
float,
|
||||
@@ -106,14 +106,14 @@ def select( # type: ignore
|
||||
|
||||
# Generated overloads end
|
||||
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:
|
||||
def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore
|
||||
if len(entities) == 1:
|
||||
return SelectOfScalar._create(*entities, **kw) # type: ignore
|
||||
return Select._create(*entities, **kw)
|
||||
return Select._create(*entities, **kw) # type: ignore
|
||||
|
||||
|
||||
# TODO: add several @overload from Python types to SQLAlchemy equivalents
|
||||
def col(column_expression: Any) -> ColumnClause:
|
||||
def col(column_expression: Any) -> ColumnClause: # type: ignore
|
||||
if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
|
||||
raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
|
||||
return column_expression
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from sqlalchemy.types import CHAR, TypeDecorator
|
||||
|
||||
|
||||
class AutoString(types.TypeDecorator):
|
||||
class AutoString(types.TypeDecorator): # type: ignore
|
||||
|
||||
impl = types.String
|
||||
cache_ok = True
|
||||
@@ -22,7 +23,7 @@ class AutoString(types.TypeDecorator):
|
||||
|
||||
# Reference form SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
|
||||
# with small modifications
|
||||
class GUID(TypeDecorator):
|
||||
class GUID(TypeDecorator): # type: ignore
|
||||
"""Platform-independent GUID type.
|
||||
|
||||
Uses PostgreSQL's UUID type, otherwise uses
|
||||
@@ -33,13 +34,13 @@ class GUID(TypeDecorator):
|
||||
impl = CHAR
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(UUID())
|
||||
return dialect.type_descriptor(UUID()) # type: ignore
|
||||
else:
|
||||
return dialect.type_descriptor(CHAR(32))
|
||||
return dialect.type_descriptor(CHAR(32)) # type: ignore
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
elif dialect.name == "postgresql":
|
||||
@@ -51,10 +52,10 @@ class GUID(TypeDecorator):
|
||||
# hexstring
|
||||
return f"{value.int:x}"
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]:
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
if not isinstance(value, uuid.UUID):
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
return cast(uuid.UUID, value)
|
||||
|
||||
Reference in New Issue
Block a user