mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
161 lines
5.5 KiB
Python
161 lines
5.5 KiB
Python
from typing import Mapping, Optional, Type
|
|
|
|
from databases import Database
|
|
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
|
|
from sqlalchemy.ext.declarative import declared_attr
|
|
|
|
from fastapi_users.db.base import BaseUserDatabase
|
|
from fastapi_users.models import UD
|
|
|
|
|
|
class SQLAlchemyBaseUserTable:
    """
    Base SQLAlchemy users table definition.

    Declares the table name and the core user columns. It is meant to be
    combined with a SQLAlchemy declarative base to produce the users table.
    """

    __tablename__ = "user"

    # String primary key.
    id = Column(String, primary_key=True)
    # Unique and indexed: users are looked up by email.
    email = Column(String, nullable=False, unique=True, index=True)
    hashed_password = Column(String, nullable=False)
    # New rows default to an active, non-superuser account.
    is_active = Column(Boolean, nullable=False, default=True)
    is_superuser = Column(Boolean, nullable=False, default=False)
|
|
|
|
|
|
class SQLAlchemyBaseOAuthAccountTable:
    """
    Base SQLAlchemy OAuth account table definition.

    Each row stores one OAuth provider account and references its owner
    through ``user_id`` -> ``user.id``.
    """

    __tablename__ = "oauth_account"

    id = Column(String, primary_key=True)
    # Provider name; indexed because lookups filter on it together with account_id.
    oauth_name = Column(String, nullable=False, index=True)
    access_token = Column(String, nullable=False)
    expires_at = Column(Integer, nullable=False)
    # Nullable, unlike the access token.
    refresh_token = Column(String, nullable=True)
    # Provider-side account identifier; indexed for lookups.
    account_id = Column(String, nullable=False, index=True)
    account_email = Column(String, nullable=False)

    @declared_attr
    def user_id(cls):
        # SQLAlchemy requires mixin columns carrying a ForeignKey to be built via
        # declared_attr, so every mapped subclass gets its own Column object.
        # ondelete="cascade" emits ON DELETE CASCADE in the table DDL.
        return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False)
|
|
|
|
|
|
class NotSetOAuthAccountTableError(Exception):
    """
    OAuth table was not set in DB adapter but was needed.

    Raised when creating or updating a user that carries OAuth accounts, or
    when looking a user up by OAuth account, while the DB adapter was built
    without an OAuth accounts table.
    """
|
|
|
|
|
|
class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for SQLAlchemy.

    Queries are built with SQLAlchemy Core and executed through an
    `encode/databases` `Database`.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param database: `Database` instance from `encode/databases`.
    :param users: SQLAlchemy users table instance.
    :param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
        Operations that need it raise `NotSetOAuthAccountTableError` when it
        is missing.
    """

    database: Database
    users: Table
    oauth_accounts: Optional[Table]

    def __init__(
        self,
        user_db_model: Type[UD],
        database: Database,
        users: Table,
        oauth_accounts: Optional[Table] = None,
    ):
        super().__init__(user_db_model)
        self.database = database
        self.users = users
        self.oauth_accounts = oauth_accounts

    async def get(self, id: str) -> Optional[UD]:
        """Return the user whose primary key is ``id``, or ``None`` if absent."""
        query = self.users.select().where(self.users.c.id == id)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_email(self, email: str) -> Optional[UD]:
        """
        Return the user whose ``email`` column equals ``email``, or ``None``.

        The match is a plain SQL equality, so case sensitivity follows the
        database column's collation.
        """
        query = self.users.select().where(self.users.c.email == email)
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """
        Return the user linked to account ``account_id`` of provider ``oauth``.

        :raises NotSetOAuthAccountTableError: if no OAuth accounts table was set.
        """
        if self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        query = (
            select([self.users])
            .select_from(self.users.join(self.oauth_accounts))
            .where(self.oauth_accounts.c.oauth_name == oauth)
            .where(self.oauth_accounts.c.account_id == account_id)
        )
        user = await self.database.fetch_one(query)
        return await self._make_user(user) if user else None

    async def create(self, user: UD) -> UD:
        """
        Insert ``user`` and, if its model has them, its OAuth accounts.

        :raises NotSetOAuthAccountTableError: if ``user`` carries OAuth accounts
            but no OAuth accounts table was set. Nothing is written in that case.
        """
        user_dict = user.dict()
        oauth_accounts_values = self._pop_oauth_accounts_values(user_dict, user.id)

        # Validate before writing anything: raising after the user insert would
        # leave an orphan user row without its OAuth accounts.
        if oauth_accounts_values is not None and self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        # The user row and its OAuth rows are written atomically.
        async with self.database.transaction():
            await self.database.execute(self.users.insert(), user_dict)
            if oauth_accounts_values:
                await self.database.execute_many(
                    self.oauth_accounts.insert(), oauth_accounts_values
                )

        return user

    async def update(self, user: UD) -> UD:
        """
        Persist ``user``'s fields and, if its model has them, its OAuth accounts.

        The user's stored OAuth accounts are replaced wholesale by the ones
        currently on ``user``.

        :raises NotSetOAuthAccountTableError: if ``user`` carries OAuth accounts
            but no OAuth accounts table was set. Nothing is written in that case.
        """
        user_dict = user.dict()
        oauth_accounts_values = self._pop_oauth_accounts_values(user_dict, user.id)

        if oauth_accounts_values is not None and self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        # Delete + reinsert + user update must not be split by a failure,
        # otherwise the user's OAuth accounts could be lost.
        async with self.database.transaction():
            if oauth_accounts_values is not None:
                delete_query = self.oauth_accounts.delete().where(
                    self.oauth_accounts.c.user_id == user.id
                )
                await self.database.execute(delete_query)
                if oauth_accounts_values:
                    await self.database.execute_many(
                        self.oauth_accounts.insert(), oauth_accounts_values
                    )

            query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
            await self.database.execute(query)

        return user

    async def delete(self, user: UD) -> None:
        """
        Delete ``user``'s row from the users table.

        OAuth account rows are not deleted here. Presumably they are removed by
        the ``ondelete="cascade"`` foreign key on ``user_id``. Confirm that the
        backend enforces foreign keys (SQLite needs ``PRAGMA foreign_keys=ON``).
        """
        query = self.users.delete().where(self.users.c.id == user.id)
        await self.database.execute(query)

    @staticmethod
    def _pop_oauth_accounts_values(user_dict: dict, user_id) -> Optional[list]:
        """
        Remove ``oauth_accounts`` from ``user_dict`` and build OAuth table rows.

        Returns ``None`` when ``user_dict`` has no ``oauth_accounts`` key.
        Otherwise it returns one row dict per account. ``user_id`` is applied
        last, so every row references the owning user.
        """
        if "oauth_accounts" not in user_dict:
            return None
        return [
            {**oauth_account, "user_id": user_id}
            for oauth_account in user_dict.pop("oauth_accounts")
        ]

    async def _make_user(self, user: Mapping) -> UD:
        """
        Build a ``user_db_model`` from a users-table row.

        When an OAuth accounts table is set, the user's OAuth account rows are
        fetched and attached under ``oauth_accounts``.
        """
        user_dict = {**user}

        if self.oauth_accounts is not None:
            query = self.oauth_accounts.select().where(
                self.oauth_accounts.c.user_id == user["id"]
            )
            oauth_accounts = await self.database.fetch_all(query)
            user_dict["oauth_accounts"] = oauth_accounts

        return self.user_db_model(**user_dict)
|