mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-01 10:25:45 +08:00
Implement OAuth2 flow (#88)
* Move users router in sub-module * Factorize UserRouter into EventHandlersRouter * Implement OAuth registration/login router * Apply isort/black * Remove temporary pytest marker * Fix httpx-oauth version in lock file * Ensure ON_AFTER_REGISTER event is triggered on OAuth registration * Add API on FastAPIUsers to generate an OAuth router * Improve test coverage of FastAPIUsers * Small fixes * Write the OAuth documentation * Fix SQL unit-tests by avoiding collisions in SQLite db files
This commit is contained in:
@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
|
||||
|
||||
try:
|
||||
from fastapi_users.db.sqlalchemy import ( # noqa: F401
|
||||
SQLAlchemyBaseOAuthAccountTable,
|
||||
SQLAlchemyBaseUserTable,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
@ -15,6 +16,7 @@ except ImportError: # pragma: no cover
|
||||
|
||||
try:
|
||||
from fastapi_users.db.tortoise import ( # noqa: F401
|
||||
TortoiseBaseOAuthAccountModel,
|
||||
TortoiseBaseUserModel,
|
||||
TortoiseUserDatabase,
|
||||
)
|
||||
|
||||
@ -30,6 +30,10 @@ class BaseUserDatabase(Generic[UD]):
|
||||
"""Get a single user by email."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
|
||||
"""Get a single user by OAuth account id."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create(self, user: UD) -> UD:
|
||||
"""Create a user."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -33,6 +33,15 @@ class MongoDBUserDatabase(BaseUserDatabase[UD]):
|
||||
user = await self.collection.find_one({"email": email})
|
||||
return self.user_db_model(**user) if user else None
|
||||
|
||||
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
|
||||
user = await self.collection.find_one(
|
||||
{
|
||||
"oauth_accounts.oauth_name": oauth,
|
||||
"oauth_accounts.account_id": account_id,
|
||||
}
|
||||
)
|
||||
return self.user_db_model(**user) if user else None
|
||||
|
||||
async def create(self, user: UD) -> UD:
|
||||
await self.collection.insert_one(user.dict())
|
||||
return user
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Mapping, Optional, Type
|
||||
|
||||
from databases import Database
|
||||
from sqlalchemy import Boolean, Column, String, Table
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
|
||||
from fastapi_users.db.base import BaseUserDatabase
|
||||
from fastapi_users.models import UD
|
||||
@ -19,6 +20,35 @@ class SQLAlchemyBaseUserTable:
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
|
||||
class SQLAlchemyBaseOAuthAccountTable:
|
||||
"""Base SQLAlchemy OAuth account table definition."""
|
||||
|
||||
__tablename__ = "oauth_account"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
oauth_name = Column(String, index=True, nullable=False)
|
||||
access_token = Column(String, nullable=False)
|
||||
expires_at = Column(Integer, nullable=False)
|
||||
refresh_token = Column(String, nullable=True)
|
||||
account_id = Column(String, index=True, nullable=False)
|
||||
account_email = Column(String, nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def user_id(cls):
|
||||
return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False)
|
||||
|
||||
|
||||
class NotSetOAuthAccountTableError(Exception):
|
||||
"""
|
||||
OAuth table was not set in DB adapter but was needed.
|
||||
|
||||
Raised when trying to create/update a user with OAuth accounts set
|
||||
but no table were specified in the DB adapter.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
|
||||
"""
|
||||
Database adapter for SQLAlchemy.
|
||||
@ -26,43 +56,110 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
|
||||
:param user_db_model: Pydantic model of a DB representation of a user.
|
||||
:param database: `Database` instance from `encode/databases`.
|
||||
:param users: SQLAlchemy users table instance.
|
||||
:param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
|
||||
"""
|
||||
|
||||
database: Database
|
||||
users: Table
|
||||
oauth_accounts: Optional[Table]
|
||||
|
||||
def __init__(self, user_db_model: Type[UD], database: Database, users: Table):
|
||||
def __init__(
|
||||
self,
|
||||
user_db_model: Type[UD],
|
||||
database: Database,
|
||||
users: Table,
|
||||
oauth_accounts: Optional[Table] = None,
|
||||
):
|
||||
super().__init__(user_db_model)
|
||||
self.database = database
|
||||
self.users = users
|
||||
self.oauth_accounts = oauth_accounts
|
||||
|
||||
async def list(self) -> List[UD]:
|
||||
query = self.users.select()
|
||||
users = await self.database.fetch_all(query)
|
||||
return [self.user_db_model(**user) for user in users]
|
||||
return [await self._make_user(user) for user in users]
|
||||
|
||||
async def get(self, id: str) -> Optional[UD]:
|
||||
query = self.users.select().where(self.users.c.id == id)
|
||||
user = await self.database.fetch_one(query)
|
||||
return self.user_db_model(**user) if user else None
|
||||
return await self._make_user(user) if user else None
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[UD]:
|
||||
query = self.users.select().where(self.users.c.email == email)
|
||||
user = await self.database.fetch_one(query)
|
||||
return self.user_db_model(**user) if user else None
|
||||
return await self._make_user(user) if user else None
|
||||
|
||||
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
|
||||
if self.oauth_accounts is not None:
|
||||
query = (
|
||||
select([self.users])
|
||||
.select_from(self.users.join(self.oauth_accounts))
|
||||
.where(self.oauth_accounts.c.oauth_name == oauth)
|
||||
.where(self.oauth_accounts.c.account_id == account_id)
|
||||
)
|
||||
user = await self.database.fetch_one(query)
|
||||
return await self._make_user(user) if user else None
|
||||
raise NotSetOAuthAccountTableError()
|
||||
|
||||
async def create(self, user: UD) -> UD:
|
||||
query = self.users.insert().values(**user.dict())
|
||||
await self.database.execute(query)
|
||||
user_dict = user.dict()
|
||||
oauth_accounts_values = None
|
||||
|
||||
if "oauth_accounts" in user_dict:
|
||||
oauth_accounts_values = []
|
||||
|
||||
oauth_accounts = user_dict.pop("oauth_accounts")
|
||||
for oauth_account in oauth_accounts:
|
||||
oauth_accounts_values.append({"user_id": user.id, **oauth_account})
|
||||
|
||||
query = self.users.insert()
|
||||
await self.database.execute(query, user_dict)
|
||||
|
||||
if oauth_accounts_values is not None:
|
||||
if self.oauth_accounts is None:
|
||||
raise NotSetOAuthAccountTableError()
|
||||
query = self.oauth_accounts.insert()
|
||||
await self.database.execute_many(query, oauth_accounts_values)
|
||||
|
||||
return user
|
||||
|
||||
async def update(self, user: UD) -> UD:
|
||||
query = (
|
||||
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||
)
|
||||
user_dict = user.dict()
|
||||
|
||||
if "oauth_accounts" in user_dict:
|
||||
if self.oauth_accounts is None:
|
||||
raise NotSetOAuthAccountTableError()
|
||||
|
||||
query = self.oauth_accounts.delete().where(
|
||||
self.oauth_accounts.c.user_id == user.id
|
||||
)
|
||||
await self.database.execute(query)
|
||||
|
||||
oauth_accounts_values = []
|
||||
oauth_accounts = user_dict.pop("oauth_accounts")
|
||||
for oauth_account in oauth_accounts:
|
||||
oauth_accounts_values.append({"user_id": user.id, **oauth_account})
|
||||
|
||||
query = self.oauth_accounts.insert()
|
||||
await self.database.execute_many(query, oauth_accounts_values)
|
||||
|
||||
query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
|
||||
await self.database.execute(query)
|
||||
return user
|
||||
|
||||
async def delete(self, user: UD) -> None:
|
||||
query = self.users.delete().where(self.users.c.id == user.id)
|
||||
await self.database.execute(query)
|
||||
|
||||
async def _make_user(self, user: Mapping) -> UD:
|
||||
user_dict = {**user}
|
||||
|
||||
if self.oauth_accounts is not None:
|
||||
query = self.oauth_accounts.select().where(
|
||||
self.oauth_accounts.c.user_id == user["id"]
|
||||
)
|
||||
oauth_accounts = await self.database.fetch_all(query)
|
||||
user_dict["oauth_accounts"] = oauth_accounts
|
||||
|
||||
return self.user_db_model(**user_dict)
|
||||
|
||||
@ -14,6 +14,27 @@ class TortoiseBaseUserModel(Model):
|
||||
is_active = fields.BooleanField(default=True, null=False)
|
||||
is_superuser = fields.BooleanField(default=False, null=False)
|
||||
|
||||
async def to_dict(self):
|
||||
d = {}
|
||||
for field in self._meta.db_fields:
|
||||
d[field] = getattr(self, field)
|
||||
for field in self._meta.backward_fk_fields:
|
||||
d[field] = await getattr(self, field).all().values()
|
||||
return d
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class TortoiseBaseOAuthAccountModel(Model):
|
||||
id = fields.CharField(pk=True, generated=False, max_length=255)
|
||||
oauth_name = fields.CharField(null=False, max_length=255)
|
||||
access_token = fields.CharField(null=False, max_length=255)
|
||||
expires_at = fields.IntField(null=False)
|
||||
refresh_token = fields.CharField(null=True, max_length=255)
|
||||
account_id = fields.CharField(index=True, null=False, max_length=255)
|
||||
account_email = fields.CharField(null=False, max_length=255)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
@ -24,41 +45,109 @@ class TortoiseUserDatabase(BaseUserDatabase[UD]):
|
||||
|
||||
:param user_db_model: Pydantic model of a DB representation of a user.
|
||||
:param model: Tortoise ORM model.
|
||||
:param oauth_account_model: Optional Tortoise ORM model of a OAuth account.
|
||||
"""
|
||||
|
||||
model: Type[TortoiseBaseUserModel]
|
||||
oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]]
|
||||
|
||||
def __init__(self, user_db_model: Type[UD], model: Type[TortoiseBaseUserModel]):
|
||||
def __init__(
|
||||
self,
|
||||
user_db_model: Type[UD],
|
||||
model: Type[TortoiseBaseUserModel],
|
||||
oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]] = None,
|
||||
):
|
||||
super().__init__(user_db_model)
|
||||
self.model = model
|
||||
self.oauth_account_model = oauth_account_model
|
||||
|
||||
async def list(self) -> List[UD]:
|
||||
users = await self.model.all()
|
||||
return [self.user_db_model.from_orm(user) for user in users]
|
||||
query = self.model.all()
|
||||
|
||||
if self.oauth_account_model is not None:
|
||||
query = query.prefetch_related("oauth_accounts")
|
||||
|
||||
users = await query
|
||||
|
||||
return [self.user_db_model(**await user.to_dict()) for user in users]
|
||||
|
||||
async def get(self, id: str) -> Optional[UD]:
|
||||
try:
|
||||
user = await self.model.get(id=id)
|
||||
return self.user_db_model.from_orm(user)
|
||||
query = self.model.get(id=id)
|
||||
|
||||
if self.oauth_account_model is not None:
|
||||
query = query.prefetch_related("oauth_accounts")
|
||||
|
||||
user = await query
|
||||
user_dict = await user.to_dict()
|
||||
|
||||
return self.user_db_model(**user_dict)
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[UD]:
|
||||
try:
|
||||
user = await self.model.get(email=email)
|
||||
return self.user_db_model.from_orm(user)
|
||||
query = self.model.get(email=email)
|
||||
|
||||
if self.oauth_account_model is not None:
|
||||
query = query.prefetch_related("oauth_accounts")
|
||||
|
||||
user = await query
|
||||
user_dict = await user.to_dict()
|
||||
|
||||
return self.user_db_model(**user_dict)
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
|
||||
try:
|
||||
query = self.model.get(
|
||||
oauth_accounts__oauth_name=oauth, oauth_accounts__account_id=account_id
|
||||
).prefetch_related("oauth_accounts")
|
||||
|
||||
user = await query
|
||||
user_dict = await user.to_dict()
|
||||
|
||||
return self.user_db_model(**user_dict)
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
async def create(self, user: UD) -> UD:
|
||||
model = self.model(**user.dict())
|
||||
user_dict = user.dict()
|
||||
oauth_accounts = user_dict.pop("oauth_accounts", None)
|
||||
|
||||
model = self.model(**user_dict)
|
||||
await model.save()
|
||||
|
||||
if oauth_accounts and self.oauth_account_model:
|
||||
oauth_account_objects = []
|
||||
for oauth_account in oauth_accounts:
|
||||
oauth_account_objects.append(
|
||||
self.oauth_account_model(user=model, **oauth_account)
|
||||
)
|
||||
await self.oauth_account_model.bulk_create(oauth_account_objects)
|
||||
|
||||
return user
|
||||
|
||||
async def update(self, user: UD) -> UD:
|
||||
user_dict = user.dict()
|
||||
user_dict.pop("id") # Tortoise complains if we pass the PK again
|
||||
await self.model.filter(id=user.id).update(**user_dict)
|
||||
oauth_accounts = user_dict.pop("oauth_accounts", None)
|
||||
|
||||
model = await self.model.get(id=user.id)
|
||||
for field in user_dict:
|
||||
setattr(model, field, user_dict[field])
|
||||
await model.save()
|
||||
|
||||
if oauth_accounts and self.oauth_account_model:
|
||||
await model.oauth_accounts.all().delete()
|
||||
oauth_account_objects = []
|
||||
for oauth_account in oauth_accounts:
|
||||
oauth_account_objects.append(
|
||||
self.oauth_account_model(user=model, **oauth_account)
|
||||
)
|
||||
await self.oauth_account_model.bulk_create(oauth_account_objects)
|
||||
|
||||
return user
|
||||
|
||||
async def delete(self, user: UD) -> None:
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
from typing import Callable, Sequence, Type
|
||||
from collections import defaultdict
|
||||
from typing import Callable, DefaultDict, List, Sequence, Type
|
||||
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.router import Event, UserRouter, get_user_router
|
||||
from fastapi_users.router import (
|
||||
Event,
|
||||
EventHandlersRouter,
|
||||
get_oauth_router,
|
||||
get_user_router,
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUsers:
|
||||
@ -20,12 +28,16 @@ class FastAPIUsers:
|
||||
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
|
||||
|
||||
:attribute router: Router exposing authentication routes.
|
||||
:attribute oauth_routers: List of OAuth routers created through `get_oauth_router`.
|
||||
:attribute get_current_user: Dependency callable to inject authenticated user.
|
||||
"""
|
||||
|
||||
db: BaseUserDatabase
|
||||
authenticator: Authenticator
|
||||
router: UserRouter
|
||||
router: EventHandlersRouter
|
||||
oauth_routers: List[EventHandlersRouter]
|
||||
_user_db_model: Type[models.BaseUserDB]
|
||||
_event_handlers: DefaultDict[Event, List[Callable]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,6 +62,9 @@ class FastAPIUsers:
|
||||
reset_password_token_secret,
|
||||
reset_password_token_lifetime_seconds,
|
||||
)
|
||||
self.oauth_routers = []
|
||||
self._user_db_model = user_db_model
|
||||
self._event_handlers = defaultdict(list)
|
||||
|
||||
self.get_current_user = self.authenticator.get_current_user
|
||||
self.get_current_active_user = self.authenticator.get_current_active_user
|
||||
@ -63,9 +78,40 @@ class FastAPIUsers:
|
||||
"""Add an event handler on successful forgot password request."""
|
||||
return self._on_event(Event.ON_AFTER_FORGOT_PASSWORD)
|
||||
|
||||
def get_oauth_router(
|
||||
self, oauth_client: BaseOAuth2, state_secret: str, redirect_url: str = None
|
||||
) -> EventHandlersRouter:
|
||||
"""
|
||||
Return an OAuth router for a given OAuth client.
|
||||
|
||||
:param oauth_client: The HTTPX OAuth client instance.
|
||||
:param state_secret: Secret used to encode the state JWT.
|
||||
:param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow.
|
||||
If not given, the URL to the callback endpoint will be generated.
|
||||
"""
|
||||
oauth_router = get_oauth_router(
|
||||
oauth_client,
|
||||
self.db,
|
||||
self._user_db_model,
|
||||
self.authenticator,
|
||||
state_secret,
|
||||
redirect_url,
|
||||
)
|
||||
|
||||
for event_type in self._event_handlers:
|
||||
for handler in self._event_handlers[event_type]:
|
||||
oauth_router.add_event_handler(event_type, handler)
|
||||
|
||||
self.oauth_routers.append(oauth_router)
|
||||
|
||||
return oauth_router
|
||||
|
||||
def _on_event(self, event_type: Event) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self._event_handlers[event_type].append(func)
|
||||
self.router.add_event_handler(event_type, func)
|
||||
for oauth_router in self.oauth_routers:
|
||||
oauth_router.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import Optional, TypeVar
|
||||
from typing import List, Optional, TypeVar
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel, EmailStr
|
||||
@ -19,7 +19,8 @@ class BaseUser(BaseModel):
|
||||
|
||||
def create_update_dict(self):
|
||||
return self.dict(
|
||||
exclude_unset=True, exclude={"id", "is_superuser", "is_active"}
|
||||
exclude_unset=True,
|
||||
exclude={"id", "is_superuser", "is_active", "oauth_accounts"},
|
||||
)
|
||||
|
||||
def create_update_dict_superuser(self):
|
||||
@ -44,3 +45,28 @@ class BaseUserDB(BaseUser):
|
||||
|
||||
|
||||
UD = TypeVar("UD", bound=BaseUserDB)
|
||||
|
||||
|
||||
class BaseOAuthAccount(BaseModel):
|
||||
"""Base OAuth account model."""
|
||||
|
||||
id: Optional[str] = None
|
||||
oauth_name: str
|
||||
access_token: str
|
||||
expires_at: int
|
||||
refresh_token: Optional[str] = None
|
||||
account_id: str
|
||||
account_email: str
|
||||
|
||||
@pydantic.validator("id", pre=True, always=True)
|
||||
def default_id(cls, v):
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class BaseOAuthAccountMixin(BaseModel):
|
||||
"""Adds OAuth accounts list to a User model."""
|
||||
|
||||
oauth_accounts: List[BaseOAuthAccount] = []
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from typing import Tuple
|
||||
|
||||
from passlib import pwd
|
||||
from passlib.context import CryptContext
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
@ -13,3 +14,7 @@ def verify_and_update_password(
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def generate_password() -> str:
|
||||
return pwd.genword()
|
||||
|
||||
7
fastapi_users/router/__init__.py
Normal file
7
fastapi_users/router/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from fastapi_users.router.common import ( # noqa: F401
|
||||
ErrorCode,
|
||||
Event,
|
||||
EventHandlersRouter,
|
||||
)
|
||||
from fastapi_users.router.oauth import get_oauth_router # noqa: F401
|
||||
from fastapi_users.router.users import get_user_router # noqa: F401
|
||||
35
fastapi_users/router/common.py
Normal file
35
fastapi_users/router/common.py
Normal file
@ -0,0 +1,35 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, DefaultDict, List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
class ErrorCode:
|
||||
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
|
||||
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
|
||||
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
|
||||
|
||||
|
||||
class Event(Enum):
|
||||
ON_AFTER_REGISTER = auto()
|
||||
ON_AFTER_FORGOT_PASSWORD = auto()
|
||||
|
||||
|
||||
class EventHandlersRouter(APIRouter):
|
||||
event_handlers: DefaultDict[Event, List[Callable]]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.event_handlers = defaultdict(list)
|
||||
|
||||
def add_event_handler(self, event_type: Event, func: Callable) -> None:
|
||||
self.event_handlers[event_type].append(func)
|
||||
|
||||
async def run_handlers(self, event_type: Event, *args, **kwargs) -> None:
|
||||
for handler in self.event_handlers[event_type]:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(*args, **kwargs)
|
||||
else:
|
||||
handler(*args, **kwargs)
|
||||
144
fastapi_users/router/oauth.py
Normal file
144
fastapi_users/router/oauth.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import Dict, List, Type, cast
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException, Query
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.password import generate_password, get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, Event, EventHandlersRouter
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: str, lifetime_seconds: int = 3600
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
return generate_jwt(data, lifetime_seconds, secret, JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
|
||||
return jwt.decode(
|
||||
token, secret, audience=STATE_TOKEN_AUDIENCE, algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
|
||||
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
user_db: BaseUserDatabase[models.BaseUserDB],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
authenticator: Authenticator,
|
||||
state_secret: str,
|
||||
redirect_url: str = None,
|
||||
) -> EventHandlersRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = EventHandlersRouter()
|
||||
callback_route_name = f"{oauth_client.name}-callback"
|
||||
|
||||
if redirect_url is not None:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client, redirect_url=redirect_url,
|
||||
)
|
||||
else:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client, route_name=callback_route_name,
|
||||
)
|
||||
|
||||
@router.get("/authorize")
|
||||
async def authorize(
|
||||
request: Request, authentication_backend: str, scopes: List[str] = Query(None),
|
||||
):
|
||||
# Check that authentication_backend exists
|
||||
backend_exists = False
|
||||
for backend in authenticator.backends:
|
||||
if backend.name == authentication_backend:
|
||||
backend_exists = True
|
||||
break
|
||||
if not backend_exists:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
redirect_url = request.url_for(callback_route_name)
|
||||
state_data = {
|
||||
"authentication_backend": authentication_backend,
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
redirect_url, state, scopes,
|
||||
)
|
||||
|
||||
return {"authorization_url": authorization_url}
|
||||
|
||||
@router.get("/callback", name=f"{oauth_client.name}-callback")
|
||||
async def callback(
|
||||
response: Response, access_token_state=Depends(oauth2_authorize_callback)
|
||||
):
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_state_token(state, state_secret)
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
user = await user_db.get_by_oauth_account(oauth_client.name, account_id)
|
||||
|
||||
new_oauth_account = models.BaseOAuthAccount(
|
||||
oauth_name=oauth_client.name,
|
||||
access_token=token["access_token"],
|
||||
expires_at=token["expires_at"],
|
||||
refresh_token=token.get("refresh_token"),
|
||||
account_id=account_id,
|
||||
account_email=account_email,
|
||||
)
|
||||
|
||||
if not user:
|
||||
user = await user_db.get_by_email(account_email)
|
||||
if user:
|
||||
# Link account
|
||||
user.oauth_accounts.append(new_oauth_account) # type: ignore
|
||||
await user_db.update(user)
|
||||
else:
|
||||
# Create account
|
||||
password = generate_password()
|
||||
user = user_db_model(
|
||||
email=account_email,
|
||||
hashed_password=get_password_hash(password),
|
||||
oauth_accounts=[new_oauth_account],
|
||||
)
|
||||
await user_db.create(user)
|
||||
await router.run_handlers(Event.ON_AFTER_REGISTER, user)
|
||||
else:
|
||||
# Update oauth
|
||||
updated_oauth_accounts = []
|
||||
for oauth_account in user.oauth_accounts: # type: ignore
|
||||
if oauth_account.account_id == account_id:
|
||||
updated_oauth_accounts.append(new_oauth_account)
|
||||
else:
|
||||
updated_oauth_accounts.append(oauth_account)
|
||||
user.oauth_accounts = updated_oauth_accounts # type: ignore
|
||||
await user_db.update(user)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Authenticate
|
||||
for backend in authenticator.backends:
|
||||
if backend.name == state_data["authentication_backend"]:
|
||||
return await backend.get_login_response(
|
||||
cast(models.BaseUserDB, user), response
|
||||
)
|
||||
|
||||
return router
|
||||
@ -1,10 +1,7 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Type, cast
|
||||
from typing import Any, Dict, List, Type, cast
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import Body, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import EmailStr
|
||||
from starlette import status
|
||||
@ -14,40 +11,14 @@ from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, Event, EventHandlersRouter
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
|
||||
class ErrorCode:
|
||||
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
|
||||
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
|
||||
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
|
||||
|
||||
|
||||
class Event(Enum):
|
||||
ON_AFTER_REGISTER = auto()
|
||||
ON_AFTER_FORGOT_PASSWORD = auto()
|
||||
|
||||
|
||||
class UserRouter(APIRouter):
|
||||
event_handlers: DefaultDict[Event, List[Callable]]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.event_handlers = defaultdict(list)
|
||||
|
||||
def add_event_handler(self, event_type: Event, func: Callable) -> None:
|
||||
self.event_handlers[event_type].append(func)
|
||||
|
||||
async def run_handlers(self, event_type: Event, *args, **kwargs) -> None:
|
||||
for handler in self.event_handlers[event_type]:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(*args, **kwargs)
|
||||
else:
|
||||
handler(*args, **kwargs)
|
||||
|
||||
|
||||
def _add_login_route(
|
||||
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
|
||||
router: EventHandlersRouter,
|
||||
user_db: BaseUserDatabase,
|
||||
auth_backend: BaseAuthentication,
|
||||
):
|
||||
@router.post(f"/login/{auth_backend.name}")
|
||||
async def login(
|
||||
@ -73,9 +44,9 @@ def get_user_router(
|
||||
authenticator: Authenticator,
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
) -> UserRouter:
|
||||
) -> EventHandlersRouter:
|
||||
"""Generate a router with the authentication routes."""
|
||||
router = UserRouter()
|
||||
router = EventHandlersRouter()
|
||||
|
||||
reset_password_token_audience = "fastapi-users:reset"
|
||||
|
||||
Reference in New Issue
Block a user