mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-04 06:37:51 +08:00 
			
		
		
		
	* Move users router in sub-module * Factorize UserRouter into EventHandlersRouter * Implement OAuth registration/login router * Apply isort/black * Remove temporary pytest marker * Fix httpx-oauth version in lock file * Ensure ON_AFTER_REGISTER event is triggered on OAuth registration * Add API on FastAPIUsers to generate an OAuth router * Improve test coverage of FastAPIUsers * Small fixes * Write the OAuth documentation * Fix SQL unit-tests by avoiding collisions in SQLite db files
		
			
				
	
	
		
			166 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			166 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import List, Mapping, Optional, Type
 | 
						|
 | 
						|
from databases import Database
 | 
						|
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
 | 
						|
from sqlalchemy.ext.declarative import declared_attr
 | 
						|
 | 
						|
from fastapi_users.db.base import BaseUserDatabase
 | 
						|
from fastapi_users.models import UD
 | 
						|
 | 
						|
 | 
						|
class SQLAlchemyBaseUserTable:
    """
    Column definitions shared by every SQLAlchemy users table.

    Declares the ``user`` table name together with the identifier,
    credential and status-flag columns.
    """

    __tablename__ = "user"

    # String primary key.
    id = Column(String, primary_key=True)

    # Credentials: the e-mail is unique and indexed; both columns are required.
    email = Column(String, nullable=False, unique=True, index=True)
    hashed_password = Column(String, nullable=False)

    # Status flags: required booleans with a column-level default.
    is_active = Column(Boolean, nullable=False, default=True)
    is_superuser = Column(Boolean, nullable=False, default=False)
 | 
						|
 | 
						|
 | 
						|
class SQLAlchemyBaseOAuthAccountTable:
    """
    Column definitions shared by every SQLAlchemy OAuth accounts table.

    Declares the ``oauth_account`` table name, the token columns, the
    provider-side account columns and a foreign key to ``user.id``.
    """

    __tablename__ = "oauth_account"

    # String primary key.
    id = Column(String, primary_key=True)

    # Name of the OAuth client the account belongs to (indexed for lookups).
    oauth_name = Column(String, nullable=False, index=True)

    # Token data; only the refresh token may be NULL.
    access_token = Column(String, nullable=False)
    expires_at = Column(Integer, nullable=False)
    refresh_token = Column(String, nullable=True)

    # Account identity on the provider side.
    account_id = Column(String, nullable=False, index=True)
    account_email = Column(String, nullable=False)

    @declared_attr
    def user_id(cls):
        """Required foreign key to ``user.id``, declared with ``ondelete="cascade"``."""
        owner_key = ForeignKey("user.id", ondelete="cascade")
        return Column(String, owner_key, nullable=False)
 | 
						|
 | 
						|
 | 
						|
class NotSetOAuthAccountTableError(Exception):
    """
    The DB adapter needed an OAuth accounts table but none was configured.

    Raised when creating or updating a user that has OAuth accounts set,
    or when looking a user up by OAuth account, while no OAuth accounts
    table was given to the DB adapter.
    """
 | 
						|
 | 
						|
 | 
						|
class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for SQLAlchemy.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param database: `Database` instance from `encode/databases`.
    :param users: SQLAlchemy users table instance.
    :param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
    """

    database: Database
    users: Table
    oauth_accounts: Optional[Table]

    def __init__(
        self,
        user_db_model: Type[UD],
        database: Database,
        users: Table,
        oauth_accounts: Optional[Table] = None,
    ):
        super().__init__(user_db_model)
        self.database = database
        self.users = users
        self.oauth_accounts = oauth_accounts

    async def list(self) -> List[UD]:
        """Return every row of the users table as a user model."""
        query = self.users.select()
        rows = await self.database.fetch_all(query)
        # NOTE: `_make_user` issues one extra query per user when an OAuth
        # accounts table is configured.
        return [await self._make_user(row) for row in rows]

    async def get(self, id: str) -> Optional[UD]:
        """Return the user whose `id` column matches, or `None`."""
        query = self.users.select().where(self.users.c.id == id)
        row = await self.database.fetch_one(query)
        return await self._make_user(row) if row else None

    async def get_by_email(self, email: str) -> Optional[UD]:
        """Return the user whose `email` column matches, or `None`."""
        query = self.users.select().where(self.users.c.email == email)
        row = await self.database.fetch_one(query)
        return await self._make_user(row) if row else None

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """
        Return the user linked to the given OAuth account, or `None`.

        :param oauth: Value matched against the `oauth_name` column.
        :param account_id: Value matched against the `account_id` column.
        :raises NotSetOAuthAccountTableError: No OAuth accounts table was set.
        """
        if self.oauth_accounts is None:
            raise NotSetOAuthAccountTableError()

        query = (
            select([self.users])
            .select_from(self.users.join(self.oauth_accounts))
            .where(self.oauth_accounts.c.oauth_name == oauth)
            .where(self.oauth_accounts.c.account_id == account_id)
        )
        row = await self.database.fetch_one(query)
        return await self._make_user(row) if row else None

    async def create(self, user: UD) -> UD:
        """
        Insert the user and, if present, its OAuth accounts.

        :param user: User model to store.
        :return: The same `user` object that was passed in.
        :raises NotSetOAuthAccountTableError: The user has an `oauth_accounts`
            field but no OAuth accounts table was set. Raised before anything
            is written.
        """
        user_dict = user.dict()
        oauth_accounts_values = None

        if "oauth_accounts" in user_dict:
            # Validate the configuration *before* inserting the user row, so a
            # misconfigured adapter cannot leave a half-created user behind.
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()

            oauth_accounts_values = [
                {"user_id": user.id, **oauth_account}
                for oauth_account in user_dict.pop("oauth_accounts")
            ]

        query = self.users.insert()
        await self.database.execute(query, user_dict)

        if oauth_accounts_values is not None:
            query = self.oauth_accounts.insert()
            await self.database.execute_many(query, oauth_accounts_values)

        return user

    async def update(self, user: UD) -> UD:
        """
        Update the user row and replace its OAuth accounts, if present.

        When the user has an `oauth_accounts` field, every existing OAuth
        account row of that user is deleted and the current ones re-inserted.

        :param user: User model carrying the new values.
        :return: The same `user` object that was passed in.
        :raises NotSetOAuthAccountTableError: The user has an `oauth_accounts`
            field but no OAuth accounts table was set.
        """
        user_dict = user.dict()

        if "oauth_accounts" in user_dict:
            if self.oauth_accounts is None:
                raise NotSetOAuthAccountTableError()

            # NOTE(review): the delete and the re-insert below are separate
            # statements and are not wrapped in an explicit transaction here.
            query = self.oauth_accounts.delete().where(
                self.oauth_accounts.c.user_id == user.id
            )
            await self.database.execute(query)

            oauth_accounts_values = [
                {"user_id": user.id, **oauth_account}
                for oauth_account in user_dict.pop("oauth_accounts")
            ]

            query = self.oauth_accounts.insert()
            await self.database.execute_many(query, oauth_accounts_values)

        query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
        await self.database.execute(query)
        return user

    async def delete(self, user: UD) -> None:
        """Delete the row of the users table whose `id` equals `user.id`."""
        query = self.users.delete().where(self.users.c.id == user.id)
        await self.database.execute(query)

    async def _make_user(self, user: Mapping) -> UD:
        """
        Build a user model from a users table row.

        When an OAuth accounts table is set, the rows belonging to this user
        are fetched and passed to the model as `oauth_accounts`.

        :param user: Mapping of column names to values for one user row.
        :return: Instance of `user_db_model`.
        """
        user_dict = {**user}

        if self.oauth_accounts is not None:
            query = self.oauth_accounts.select().where(
                self.oauth_accounts.c.user_id == user["id"]
            )
            oauth_accounts = await self.database.fetch_all(query)
            user_dict["oauth_accounts"] = oauth_accounts

        return self.user_db_model(**user_dict)
 |