import uuid
from typing import Mapping, Optional, Type

from databases import Database
from pydantic import UUID4
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.types import CHAR, TypeDecorator

from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import UD


class GUID(TypeDecorator):  # pragma: no cover
    """Platform-independent GUID type.

    Uses PostgreSQL's UUID type, otherwise uses
    CHAR(36), storing as regular strings.
    """

    impl = CHAR
    # The type carries no per-instance configuration, so SQLAlchemy 1.4+ may
    # safely include it in statement cache keys; older versions ignore this.
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Pick the native UUID type on PostgreSQL, CHAR(36) elsewhere."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_bind_param(self, value, dialect):
        """Convert a Python value to its string form before sending it to the DB."""
        if value is None:
            return value
        if dialect.name == "postgresql":
            return str(value)
        if not isinstance(value, uuid.UUID):
            # Round-trip through uuid.UUID to validate the input and emit the
            # canonical 36-character representation.
            return str(uuid.UUID(value))
        return str(value)

    def process_result_value(self, value, dialect):
        """Convert a value read from the DB back into a ``uuid.UUID``."""
        if value is None:
            return value
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        return value


class SQLAlchemyBaseUserTable:
    """Base SQLAlchemy users table definition."""

    __tablename__ = "user"

    id = Column(GUID, primary_key=True)
    email = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)
    is_superuser = Column(Boolean, default=False, nullable=False)


class SQLAlchemyBaseOAuthAccountTable:
    """Base SQLAlchemy OAuth account table definition."""

    __tablename__ = "oauth_account"

    id = Column(GUID, primary_key=True)
    oauth_name = Column(String, index=True, nullable=False)
    access_token = Column(String, nullable=False)
    expires_at = Column(Integer, nullable=False)
    refresh_token = Column(String, nullable=True)
    account_id = Column(String, index=True, nullable=False)
    account_email = Column(String, nullable=False)

    @declared_attr
    def user_id(cls):
        """Foreign key to ``user.id``, declared with ``ondelete="cascade"``."""
        return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False)


class NotSetOAuthAccountTableError(Exception):
    """
    OAuth table was not set in DB adapter but was needed.

    Raised when trying to create/update a user with OAuth accounts set
    but no table was specified in the DB adapter.
    """


class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for SQLAlchemy.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param database: `Database` instance from `encode/databases`.
    :param users: SQLAlchemy users table instance.
    :param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
    """

    database: Database
    users: Table
    oauth_accounts: Optional[Table]

    def __init__(
        self,
        user_db_model: Type[UD],
        database: Database,
        users: Table,
        oauth_accounts: Optional[Table] = None,
    ):
        super().__init__(user_db_model)
        self.database = database
        self.users = users
        self.oauth_accounts = oauth_accounts

    async def get(self, id: UUID4) -> Optional[UD]:
        """Return the user with the given id, or ``None`` if there is none."""
        query = self.users.select().where(self.users.c.id == id)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_email(self, email: str) -> Optional[UD]:
        """Return the user with the given email, or ``None`` if there is none."""
        query = self.users.select().where(self.users.c.email == email)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """Return the user owning the given OAuth account, or ``None``.

        :raises NotSetOAuthAccountTableError: if no OAuth table was configured.
        """
        if self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        query = (
            select([self.users])
            .select_from(self.users.join(self.oauth_accounts))
            .where(self.oauth_accounts.c.oauth_name == oauth)
            .where(self.oauth_accounts.c.account_id == account_id)
        )
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def create(self, user: UD) -> UD:
        """Insert ``user`` and, when the model has them, its OAuth accounts.

        :raises NotSetOAuthAccountTableError: if the model carries
            ``oauth_accounts`` but no OAuth table was configured. The check
            runs before any write, so no user row is left behind.
        """
        user_dict = user.dict()
        oauth_accounts_values = None

        if "oauth_accounts" in user_dict:
            # Validate configuration up front; previously the user row was
            # inserted first and the error was raised afterwards.
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()
            # Popping also keeps the nested list out of the users INSERT.
            oauth_accounts_values = [
                {"user_id": user.id, **oauth_account}
                for oauth_account in user_dict.pop("oauth_accounts")
            ]

        await self.database.execute(self.users.insert(), user_dict)

        if oauth_accounts_values:
            await self.database.execute_many(
                self.oauth_accounts.insert(), oauth_accounts_values
            )

        return user

    async def update(self, user: UD) -> UD:
        """Persist ``user``'s current state, replacing its OAuth accounts.

        :raises NotSetOAuthAccountTableError: if the model carries
            ``oauth_accounts`` but no OAuth table was configured.
        """
        user_dict = user.dict()

        if "oauth_accounts" in user_dict:
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()

            # OAuth accounts are replaced wholesale: delete every stored row
            # for this user, then insert the model's current list.
            delete_query = self.oauth_accounts.delete().where(
                self.oauth_accounts.c.user_id == user.id
            )
            await self.database.execute(delete_query)

            oauth_accounts_values = [
                {"user_id": user.id, **oauth_account}
                for oauth_account in user_dict.pop("oauth_accounts")
            ]
            if oauth_accounts_values:
                await self.database.execute_many(
                    self.oauth_accounts.insert(), oauth_accounts_values
                )

        query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
        await self.database.execute(query)
        return user

    async def delete(self, user: UD) -> None:
        """Delete the users-table row whose id matches ``user.id``."""
        query = self.users.delete().where(self.users.c.id == user.id)
        await self.database.execute(query)

    async def _make_user(self, user: Mapping) -> UD:
        """Build a ``user_db_model`` from a users row.

        When an OAuth table is configured, the user's OAuth account rows are
        fetched and attached under ``oauth_accounts``.
        """
        user_dict = {**user}

        if self.oauth_accounts is not None:
            query = self.oauth_accounts.select().where(
                self.oauth_accounts.c.user_id == user["id"]
            )
            oauth_accounts = await self.database.fetch_all(query)
            user_dict["oauth_accounts"] = oauth_accounts

        return self.user_db_model(**user_dict)