Files
François Voron 48d9a399e0 Fix type errors
2021-02-07 09:58:52 +01:00

114 lines
4.3 KiB
Python

from typing import cast, Any, List, Optional, Type
import ormar
from ormar.exceptions import NoMatch
from pydantic import UUID4
from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import UD, BaseOAuthAccount
class OrmarBaseUserModel(ormar.Model):
    # Base ormar model holding the columns of the user table. It is declared
    # abstract (see Meta), so it is meant to be subclassed by a concrete model
    # rather than used directly.

    class Meta:
        abstract = True
        tablename = "users"

    # Primary key, declared with ormar's "string" UUID format.
    id = ormar.UUID(uuid_format="string", primary_key=True)
    # Unique, indexed email address; required.
    email = ormar.String(max_length=255, nullable=False, unique=True, index=True)
    # Password hash (never the plain-text password); required.
    hashed_password = ormar.String(max_length=255, nullable=False)
    # Account status flags, all non-nullable with explicit defaults.
    is_active = ormar.Boolean(nullable=False, default=True)
    is_superuser = ormar.Boolean(nullable=False, default=False)
    is_verified = ormar.Boolean(nullable=False, default=False)
class OrmarBaseOAuthAccountModel(ormar.Model):
    # Base ormar model holding the columns of an OAuth account linked to a
    # user. It is declared abstract (see Meta), so a concrete subclass is
    # expected to add the relation to the user model.

    class Meta:
        tablename = "oauth_accounts"
        abstract = True

    id = ormar.UUID(primary_key=True, uuid_format="string")
    # Name of the OAuth provider/client this account belongs to.
    oauth_name = ormar.String(nullable=False, max_length=255)
    # Provider-issued tokens (JWT access tokens in particular) routinely exceed
    # 255 characters; with the previous max_length=255 such tokens were
    # rejected by the length constraint when the account was saved.
    # NOTE(review): tables already created with the old 255 width need a
    # migration to benefit from the wider column.
    access_token = ormar.String(nullable=False, max_length=1024)
    # Token expiry; presumably a Unix timestamp in seconds — confirm.
    expires_at = ormar.Integer(nullable=True)
    refresh_token = ormar.String(nullable=True, max_length=1024)
    # Provider-side account identifier; indexed because users are looked up by it.
    account_id = ormar.String(index=True, nullable=False, max_length=255)
    account_email = ormar.String(nullable=False, max_length=255)
class OrmarUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for ormar.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param model: ormar ORM model.
    :param oauth_account_model: Optional ormar ORM model of a OAuth account.
    """

    model: Type[OrmarBaseUserModel]
    oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]]

    def __init__(
        self,
        user_db_model: Type[UD],
        model: Type[OrmarBaseUserModel],
        oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]] = None,
    ):
        super().__init__(user_db_model)
        self.model = model
        self.oauth_account_model = oauth_account_model

    async def get(self, id: UUID4) -> Optional[UD]:
        """Return the user whose primary key is ``id``, or ``None``."""
        return await self._get_user(id=id)

    async def get_by_email(self, email: str) -> Optional[UD]:
        """Return the user whose email matches ``email`` case-insensitively, or ``None``."""
        return await self._get_user(email__iexact=email)

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """
        Return the user owning the OAuth account ``(oauth, account_id)``, or ``None``.

        :param oauth: OAuth provider name (``oauth_name`` column).
        :param account_id: Provider-side account identifier.
        """
        try:
            matched = await self._get_db_user(
                oauth_accounts__oauth_name=oauth,
                oauth_accounts__account_id=account_id,
            )
        except NoMatch:
            return None
        # The lookup above filters on the joined ``oauth_accounts`` rows, so in
        # that single query the WHERE clause also restricts which account rows
        # are loaded: the instance only carries the account(s) that matched.
        # Reload by primary key so the caller gets the complete list; otherwise
        # passing this user to ``update`` (which clears and re-creates accounts
        # from the in-memory list) would drop the user's other OAuth accounts.
        return await self._get_user(id=matched.id)

    async def create(self, user: UD) -> UD:
        """
        Insert ``user`` (and its OAuth accounts, if any) and return it as stored.

        User models without an ``oauth_accounts`` attribute are treated as
        having no OAuth accounts.
        """
        oauth_accounts = getattr(user, "oauth_accounts", [])
        model = await self.model(**user.dict(exclude={"oauth_accounts"})).save()
        if oauth_accounts and self.oauth_account_model:
            await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts)
        # Re-read so the returned object reflects the stored row and relations.
        user_db = await self._get_user(id=user.id)
        return cast(UD, user_db)

    async def update(self, user: UD) -> UD:
        """
        Persist the fields of ``user`` and return the user as stored.

        When ``user`` carries a non-empty ``oauth_accounts`` list (and an OAuth
        model is configured), the stored accounts are replaced by that list.
        An empty or missing list leaves the stored accounts untouched.
        """
        oauth_accounts = getattr(user, "oauth_accounts", [])
        model = await self._get_db_user(id=user.id)
        await model.update(**user.dict(exclude={"oauth_accounts"}))
        if oauth_accounts and self.oauth_account_model:
            # Replace-all strategy: drop existing account rows, then re-create
            # them from the in-memory list.
            await model.oauth_accounts.clear(keep_reversed=False)
            await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts)
        user_db = await self._get_user(id=user.id)
        return cast(UD, user_db)

    async def delete(self, user: UD) -> None:
        """Delete the user row whose primary key is ``user.id``."""
        await self.model.objects.delete(id=user.id)

    async def _create_oauth_models(
        self, model: OrmarBaseUserModel, oauth_accounts: List[BaseOAuthAccount]
    ) -> None:
        """
        Bulk-insert ``oauth_accounts`` linked to the user row ``model``.

        Does nothing when no OAuth model is configured. The concrete OAuth
        model is expected to expose its relation to the user as ``user``.
        """
        if self.oauth_account_model:
            oauth_accounts_db: List[ormar.Model] = [
                self.oauth_account_model(user=model, **oacc.dict())
                for oacc in oauth_accounts
            ]
            await self.oauth_account_model.objects.bulk_create(oauth_accounts_db)

    async def _get_db_user(self, **kwargs: Any) -> OrmarBaseUserModel:
        """
        Fetch a single ormar user row matching the ``kwargs`` filters.

        OAuth accounts are loaded alongside when an OAuth model is configured.
        Raises ``NoMatch`` when no row matches.
        """
        query = self.model.objects.filter(**kwargs)
        if self.oauth_account_model is not None:
            query = query.select_related("oauth_accounts")
        return cast(OrmarBaseUserModel, await query.get())

    async def _get_user(self, **kwargs: Any) -> Optional[UD]:
        """Like ``_get_db_user`` but returns a ``user_db_model`` instance, or ``None``."""
        try:
            user = await self._get_db_user(**kwargs)
        except NoMatch:
            return None
        return self.user_db_model(**user.dict())