from typing import Any, Dict, Iterable, List, Mapping, Optional, Type

from databases import Database
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
from sqlalchemy.ext.declarative import declared_attr

from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import UD


class SQLAlchemyBaseUserTable:
    """Base SQLAlchemy users table definition."""

    __tablename__ = "user"

    id = Column(String, primary_key=True)
    email = Column(String, unique=True, index=True, nullable=False)
    hashed_password = Column(String, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)
    is_superuser = Column(Boolean, default=False, nullable=False)


class SQLAlchemyBaseOAuthAccountTable:
    """Base SQLAlchemy OAuth account table definition."""

    __tablename__ = "oauth_account"

    id = Column(String, primary_key=True)
    oauth_name = Column(String, index=True, nullable=False)
    access_token = Column(String, nullable=False)
    expires_at = Column(Integer, nullable=False)
    refresh_token = Column(String, nullable=True)
    account_id = Column(String, index=True, nullable=False)
    account_email = Column(String, nullable=False)

    @declared_attr
    def user_id(cls):
        # Declared as a `declared_attr` so every mapped subclass gets its own
        # Column instance; rows are removed when the owning user is deleted.
        return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False)


class NotSetOAuthAccountTableError(Exception):
    """
    OAuth table was not set in DB adapter but was needed.

    Raised when trying to create/update a user with OAuth accounts set
    but no table was specified in the DB adapter.
    """

    pass


class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for SQLAlchemy.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param database: `Database` instance from `encode/databases`.
    :param users: SQLAlchemy users table instance.
    :param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
    """

    database: Database
    users: Table
    oauth_accounts: Optional[Table]

    def __init__(
        self,
        user_db_model: Type[UD],
        database: Database,
        users: Table,
        oauth_accounts: Optional[Table] = None,
    ):
        super().__init__(user_db_model)
        self.database = database
        self.users = users
        self.oauth_accounts = oauth_accounts

    async def get(self, id: str) -> Optional[UD]:
        """Return the user whose primary key is *id*, or ``None``."""
        query = self.users.select().where(self.users.c.id == id)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_email(self, email: str) -> Optional[UD]:
        """Return the user whose ``email`` column equals *email*, or ``None``."""
        query = self.users.select().where(self.users.c.email == email)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """
        Return the user linked to a given OAuth account, or ``None``.

        :param oauth: Value matched against the ``oauth_name`` column.
        :param account_id: Value matched against the ``account_id`` column.
        :raises NotSetOAuthAccountTableError: If no OAuth table was configured.
        """
        if self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        query = (
            select([self.users])
            .select_from(self.users.join(self.oauth_accounts))
            .where(self.oauth_accounts.c.oauth_name == oauth)
            .where(self.oauth_accounts.c.account_id == account_id)
        )
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def create(self, user: UD) -> UD:
        """
        Insert *user* (and its OAuth accounts, if the model has them).

        :raises NotSetOAuthAccountTableError: If the user model carries
            ``oauth_accounts`` but no OAuth table was configured. This is
            checked before anything is written.
        :return: The same *user* object that was passed in.
        """
        user_dict = user.dict()
        oauth_accounts_values = None

        if "oauth_accounts" in user_dict:
            # Validate the configuration before writing anything, so a
            # misconfigured adapter cannot leave a half-created user behind.
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()
            oauth_accounts_values = self._oauth_account_rows(
                user.id, user_dict.pop("oauth_accounts")
            )

        # The user row and its OAuth rows are written atomically.
        async with self.database.transaction():
            await self.database.execute(self.users.insert(), user_dict)
            if oauth_accounts_values:
                await self.database.execute_many(
                    self.oauth_accounts.insert(), oauth_accounts_values
                )

        return user

    async def update(self, user: UD) -> UD:
        """
        Persist the current state of *user*.

        If the user model carries ``oauth_accounts``, the stored OAuth rows
        for this user are replaced wholesale: all existing rows are deleted,
        then the current ones are inserted.

        :raises NotSetOAuthAccountTableError: If the user model carries
            ``oauth_accounts`` but no OAuth table was configured.
        :return: The same *user* object that was passed in.
        """
        user_dict = user.dict()
        oauth_accounts_values = None

        if "oauth_accounts" in user_dict:
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()
            oauth_accounts_values = self._oauth_account_rows(
                user.id, user_dict.pop("oauth_accounts")
            )

        # Delete + re-insert + update must succeed or fail together; otherwise
        # a failure after the delete would silently drop the user's accounts.
        async with self.database.transaction():
            if oauth_accounts_values is not None:
                delete_query = self.oauth_accounts.delete().where(
                    self.oauth_accounts.c.user_id == user.id
                )
                await self.database.execute(delete_query)
                if oauth_accounts_values:
                    await self.database.execute_many(
                        self.oauth_accounts.insert(), oauth_accounts_values
                    )

            update_query = (
                self.users.update().where(self.users.c.id == user.id).values(user_dict)
            )
            await self.database.execute(update_query)

        return user

    async def delete(self, user: UD) -> None:
        """Delete the row whose ``id`` matches *user.id*."""
        query = self.users.delete().where(self.users.c.id == user.id)
        await self.database.execute(query)

    @staticmethod
    def _oauth_account_rows(
        user_id: Any, oauth_accounts: Iterable[Mapping[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Build OAuth-table insert rows, each linked to *user_id*."""
        # Keys from the account mapping are spread last, so they take
        # precedence, matching the original construction order.
        return [{"user_id": user_id, **oauth_account} for oauth_account in oauth_accounts]

    async def _make_user(self, user: Mapping) -> UD:
        """
        Build a ``user_db_model`` instance from a fetched users-table row.

        When an OAuth table is configured, the user's OAuth rows are fetched
        and attached under ``oauth_accounts``. Each row is converted to a
        plain dict, the same way the user row itself is converted.
        """
        user_dict = {**user}

        if self.oauth_accounts is not None:
            query = self.oauth_accounts.select().where(
                self.oauth_accounts.c.user_id == user["id"]
            )
            oauth_accounts = await self.database.fetch_all(query)
            user_dict["oauth_accounts"] = [{**account} for account in oauth_accounts]

        return self.user_db_model(**user_dict)