Implement OAuth2 flow (#88)

* Move users router in sub-module

* Factorize UserRouter into EventHandlersRouter

* Implement OAuth registration/login router

* Apply isort/black

* Remove temporary pytest marker

* Fix httpx-oauth version in lock file

* Ensure ON_AFTER_REGISTER event is triggered on OAuth registration

* Add API on FastAPIUsers to generate an OAuth router

* Improve test coverage of FastAPIUsers

* Small fixes

* Write the OAuth documentation

* Fix SQL unit-tests by avoiding collisions in SQLite db files
This commit is contained in:
François Voron
2020-01-17 11:43:17 +01:00
committed by GitHub
parent 54aefea59a
commit 88b133d41c
32 changed files with 1723 additions and 107 deletions

View File

@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
try:
from fastapi_users.db.sqlalchemy import ( # noqa: F401
SQLAlchemyBaseOAuthAccountTable,
SQLAlchemyBaseUserTable,
SQLAlchemyUserDatabase,
)
@ -15,6 +16,7 @@ except ImportError: # pragma: no cover
try:
from fastapi_users.db.tortoise import ( # noqa: F401
TortoiseBaseOAuthAccountModel,
TortoiseBaseUserModel,
TortoiseUserDatabase,
)

View File

@ -30,6 +30,10 @@ class BaseUserDatabase(Generic[UD]):
"""Get a single user by email."""
raise NotImplementedError()
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
"""Get a single user by OAuth account id."""
raise NotImplementedError()
async def create(self, user: UD) -> UD:
"""Create a user."""
raise NotImplementedError()

View File

@ -33,6 +33,15 @@ class MongoDBUserDatabase(BaseUserDatabase[UD]):
user = await self.collection.find_one({"email": email})
return self.user_db_model(**user) if user else None
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
user = await self.collection.find_one(
{
"oauth_accounts.oauth_name": oauth,
"oauth_accounts.account_id": account_id,
}
)
return self.user_db_model(**user) if user else None
async def create(self, user: UD) -> UD:
await self.collection.insert_one(user.dict())
return user

View File

@ -1,7 +1,8 @@
from typing import List, Optional, Type
from typing import List, Mapping, Optional, Type
from databases import Database
from sqlalchemy import Boolean, Column, String, Table
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
from sqlalchemy.ext.declarative import declared_attr
from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import UD
@ -19,6 +20,35 @@ class SQLAlchemyBaseUserTable:
is_superuser = Column(Boolean, default=False, nullable=False)
class SQLAlchemyBaseOAuthAccountTable:
"""Base SQLAlchemy OAuth account table definition."""
__tablename__ = "oauth_account"
id = Column(String, primary_key=True)
oauth_name = Column(String, index=True, nullable=False)
access_token = Column(String, nullable=False)
expires_at = Column(Integer, nullable=False)
refresh_token = Column(String, nullable=True)
account_id = Column(String, index=True, nullable=False)
account_email = Column(String, nullable=False)
@declared_attr
def user_id(cls):
return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False)
class NotSetOAuthAccountTableError(Exception):
"""
OAuth table was not set in DB adapter but was needed.
Raised when trying to create/update a user with OAuth accounts set
but no table were specified in the DB adapter.
"""
pass
class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
"""
Database adapter for SQLAlchemy.
@ -26,43 +56,110 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
:param user_db_model: Pydantic model of a DB representation of a user.
:param database: `Database` instance from `encode/databases`.
:param users: SQLAlchemy users table instance.
:param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance.
"""
database: Database
users: Table
oauth_accounts: Optional[Table]
def __init__(self, user_db_model: Type[UD], database: Database, users: Table):
def __init__(
self,
user_db_model: Type[UD],
database: Database,
users: Table,
oauth_accounts: Optional[Table] = None,
):
super().__init__(user_db_model)
self.database = database
self.users = users
self.oauth_accounts = oauth_accounts
async def list(self) -> List[UD]:
query = self.users.select()
users = await self.database.fetch_all(query)
return [self.user_db_model(**user) for user in users]
return [await self._make_user(user) for user in users]
async def get(self, id: str) -> Optional[UD]:
query = self.users.select().where(self.users.c.id == id)
user = await self.database.fetch_one(query)
return self.user_db_model(**user) if user else None
return await self._make_user(user) if user else None
async def get_by_email(self, email: str) -> Optional[UD]:
query = self.users.select().where(self.users.c.email == email)
user = await self.database.fetch_one(query)
return self.user_db_model(**user) if user else None
return await self._make_user(user) if user else None
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
if self.oauth_accounts is not None:
query = (
select([self.users])
.select_from(self.users.join(self.oauth_accounts))
.where(self.oauth_accounts.c.oauth_name == oauth)
.where(self.oauth_accounts.c.account_id == account_id)
)
user = await self.database.fetch_one(query)
return await self._make_user(user) if user else None
raise NotSetOAuthAccountTableError()
async def create(self, user: UD) -> UD:
query = self.users.insert().values(**user.dict())
await self.database.execute(query)
user_dict = user.dict()
oauth_accounts_values = None
if "oauth_accounts" in user_dict:
oauth_accounts_values = []
oauth_accounts = user_dict.pop("oauth_accounts")
for oauth_account in oauth_accounts:
oauth_accounts_values.append({"user_id": user.id, **oauth_account})
query = self.users.insert()
await self.database.execute(query, user_dict)
if oauth_accounts_values is not None:
if self.oauth_accounts is None:
raise NotSetOAuthAccountTableError()
query = self.oauth_accounts.insert()
await self.database.execute_many(query, oauth_accounts_values)
return user
async def update(self, user: UD) -> UD:
query = (
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
)
user_dict = user.dict()
if "oauth_accounts" in user_dict:
if self.oauth_accounts is None:
raise NotSetOAuthAccountTableError()
query = self.oauth_accounts.delete().where(
self.oauth_accounts.c.user_id == user.id
)
await self.database.execute(query)
oauth_accounts_values = []
oauth_accounts = user_dict.pop("oauth_accounts")
for oauth_account in oauth_accounts:
oauth_accounts_values.append({"user_id": user.id, **oauth_account})
query = self.oauth_accounts.insert()
await self.database.execute_many(query, oauth_accounts_values)
query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
await self.database.execute(query)
return user
async def delete(self, user: UD) -> None:
query = self.users.delete().where(self.users.c.id == user.id)
await self.database.execute(query)
async def _make_user(self, user: Mapping) -> UD:
user_dict = {**user}
if self.oauth_accounts is not None:
query = self.oauth_accounts.select().where(
self.oauth_accounts.c.user_id == user["id"]
)
oauth_accounts = await self.database.fetch_all(query)
user_dict["oauth_accounts"] = oauth_accounts
return self.user_db_model(**user_dict)

View File

@ -14,6 +14,27 @@ class TortoiseBaseUserModel(Model):
is_active = fields.BooleanField(default=True, null=False)
is_superuser = fields.BooleanField(default=False, null=False)
async def to_dict(self):
d = {}
for field in self._meta.db_fields:
d[field] = getattr(self, field)
for field in self._meta.backward_fk_fields:
d[field] = await getattr(self, field).all().values()
return d
class Meta:
abstract = True
class TortoiseBaseOAuthAccountModel(Model):
id = fields.CharField(pk=True, generated=False, max_length=255)
oauth_name = fields.CharField(null=False, max_length=255)
access_token = fields.CharField(null=False, max_length=255)
expires_at = fields.IntField(null=False)
refresh_token = fields.CharField(null=True, max_length=255)
account_id = fields.CharField(index=True, null=False, max_length=255)
account_email = fields.CharField(null=False, max_length=255)
class Meta:
abstract = True
@ -24,41 +45,109 @@ class TortoiseUserDatabase(BaseUserDatabase[UD]):
:param user_db_model: Pydantic model of a DB representation of a user.
:param model: Tortoise ORM model.
:param oauth_account_model: Optional Tortoise ORM model of a OAuth account.
"""
model: Type[TortoiseBaseUserModel]
oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]]
def __init__(self, user_db_model: Type[UD], model: Type[TortoiseBaseUserModel]):
def __init__(
self,
user_db_model: Type[UD],
model: Type[TortoiseBaseUserModel],
oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]] = None,
):
super().__init__(user_db_model)
self.model = model
self.oauth_account_model = oauth_account_model
async def list(self) -> List[UD]:
users = await self.model.all()
return [self.user_db_model.from_orm(user) for user in users]
query = self.model.all()
if self.oauth_account_model is not None:
query = query.prefetch_related("oauth_accounts")
users = await query
return [self.user_db_model(**await user.to_dict()) for user in users]
async def get(self, id: str) -> Optional[UD]:
try:
user = await self.model.get(id=id)
return self.user_db_model.from_orm(user)
query = self.model.get(id=id)
if self.oauth_account_model is not None:
query = query.prefetch_related("oauth_accounts")
user = await query
user_dict = await user.to_dict()
return self.user_db_model(**user_dict)
except DoesNotExist:
return None
async def get_by_email(self, email: str) -> Optional[UD]:
try:
user = await self.model.get(email=email)
return self.user_db_model.from_orm(user)
query = self.model.get(email=email)
if self.oauth_account_model is not None:
query = query.prefetch_related("oauth_accounts")
user = await query
user_dict = await user.to_dict()
return self.user_db_model(**user_dict)
except DoesNotExist:
return None
async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
try:
query = self.model.get(
oauth_accounts__oauth_name=oauth, oauth_accounts__account_id=account_id
).prefetch_related("oauth_accounts")
user = await query
user_dict = await user.to_dict()
return self.user_db_model(**user_dict)
except DoesNotExist:
return None
async def create(self, user: UD) -> UD:
model = self.model(**user.dict())
user_dict = user.dict()
oauth_accounts = user_dict.pop("oauth_accounts", None)
model = self.model(**user_dict)
await model.save()
if oauth_accounts and self.oauth_account_model:
oauth_account_objects = []
for oauth_account in oauth_accounts:
oauth_account_objects.append(
self.oauth_account_model(user=model, **oauth_account)
)
await self.oauth_account_model.bulk_create(oauth_account_objects)
return user
async def update(self, user: UD) -> UD:
user_dict = user.dict()
user_dict.pop("id") # Tortoise complains if we pass the PK again
await self.model.filter(id=user.id).update(**user_dict)
oauth_accounts = user_dict.pop("oauth_accounts", None)
model = await self.model.get(id=user.id)
for field in user_dict:
setattr(model, field, user_dict[field])
await model.save()
if oauth_accounts and self.oauth_account_model:
await model.oauth_accounts.all().delete()
oauth_account_objects = []
for oauth_account in oauth_accounts:
oauth_account_objects.append(
self.oauth_account_model(user=model, **oauth_account)
)
await self.oauth_account_model.bulk_create(oauth_account_objects)
return user
async def delete(self, user: UD) -> None:

View File

@ -1,9 +1,17 @@
from typing import Callable, Sequence, Type
from collections import defaultdict
from typing import Callable, DefaultDict, List, Sequence, Type
from httpx_oauth.oauth2 import BaseOAuth2
from fastapi_users import models
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.router import Event, UserRouter, get_user_router
from fastapi_users.router import (
Event,
EventHandlersRouter,
get_oauth_router,
get_user_router,
)
class FastAPIUsers:
@ -20,12 +28,16 @@ class FastAPIUsers:
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
:attribute router: Router exposing authentication routes.
:attribute oauth_routers: List of OAuth routers created through `get_oauth_router`.
:attribute get_current_user: Dependency callable to inject authenticated user.
"""
db: BaseUserDatabase
authenticator: Authenticator
router: UserRouter
router: EventHandlersRouter
oauth_routers: List[EventHandlersRouter]
_user_db_model: Type[models.BaseUserDB]
_event_handlers: DefaultDict[Event, List[Callable]]
def __init__(
self,
@ -50,6 +62,9 @@ class FastAPIUsers:
reset_password_token_secret,
reset_password_token_lifetime_seconds,
)
self.oauth_routers = []
self._user_db_model = user_db_model
self._event_handlers = defaultdict(list)
self.get_current_user = self.authenticator.get_current_user
self.get_current_active_user = self.authenticator.get_current_active_user
@ -63,9 +78,40 @@ class FastAPIUsers:
"""Add an event handler on successful forgot password request."""
return self._on_event(Event.ON_AFTER_FORGOT_PASSWORD)
def get_oauth_router(
self, oauth_client: BaseOAuth2, state_secret: str, redirect_url: str = None
) -> EventHandlersRouter:
"""
Return an OAuth router for a given OAuth client.
:param oauth_client: The HTTPX OAuth client instance.
:param state_secret: Secret used to encode the state JWT.
:param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow.
If not given, the URL to the callback endpoint will be generated.
"""
oauth_router = get_oauth_router(
oauth_client,
self.db,
self._user_db_model,
self.authenticator,
state_secret,
redirect_url,
)
for event_type in self._event_handlers:
for handler in self._event_handlers[event_type]:
oauth_router.add_event_handler(event_type, handler)
self.oauth_routers.append(oauth_router)
return oauth_router
def _on_event(self, event_type: Event) -> Callable:
def decorator(func: Callable) -> Callable:
self._event_handlers[event_type].append(func)
self.router.add_event_handler(event_type, func)
for oauth_router in self.oauth_routers:
oauth_router.add_event_handler(event_type, func)
return func
return decorator

View File

@ -1,5 +1,5 @@
import uuid
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar
import pydantic
from pydantic import BaseModel, EmailStr
@ -19,7 +19,8 @@ class BaseUser(BaseModel):
def create_update_dict(self):
return self.dict(
exclude_unset=True, exclude={"id", "is_superuser", "is_active"}
exclude_unset=True,
exclude={"id", "is_superuser", "is_active", "oauth_accounts"},
)
def create_update_dict_superuser(self):
@ -44,3 +45,28 @@ class BaseUserDB(BaseUser):
UD = TypeVar("UD", bound=BaseUserDB)
class BaseOAuthAccount(BaseModel):
"""Base OAuth account model."""
id: Optional[str] = None
oauth_name: str
access_token: str
expires_at: int
refresh_token: Optional[str] = None
account_id: str
account_email: str
@pydantic.validator("id", pre=True, always=True)
def default_id(cls, v):
return v or str(uuid.uuid4())
class Config:
orm_mode = True
class BaseOAuthAccountMixin(BaseModel):
"""Adds OAuth accounts list to a User model."""
oauth_accounts: List[BaseOAuthAccount] = []

View File

@ -1,5 +1,6 @@
from typing import Tuple
from passlib import pwd
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@ -13,3 +14,7 @@ def verify_and_update_password(
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
def generate_password() -> str:
return pwd.genword()

View File

@ -0,0 +1,7 @@
from fastapi_users.router.common import ( # noqa: F401
ErrorCode,
Event,
EventHandlersRouter,
)
from fastapi_users.router.oauth import get_oauth_router # noqa: F401
from fastapi_users.router.users import get_user_router # noqa: F401

View File

@ -0,0 +1,35 @@
import asyncio
from collections import defaultdict
from enum import Enum, auto
from typing import Callable, DefaultDict, List
from fastapi import APIRouter
class ErrorCode:
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
class Event(Enum):
ON_AFTER_REGISTER = auto()
ON_AFTER_FORGOT_PASSWORD = auto()
class EventHandlersRouter(APIRouter):
event_handlers: DefaultDict[Event, List[Callable]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.event_handlers = defaultdict(list)
def add_event_handler(self, event_type: Event, func: Callable) -> None:
self.event_handlers[event_type].append(func)
async def run_handlers(self, event_type: Event, *args, **kwargs) -> None:
for handler in self.event_handlers[event_type]:
if asyncio.iscoroutinefunction(handler):
await handler(*args, **kwargs)
else:
handler(*args, **kwargs)

View File

@ -0,0 +1,144 @@
from typing import Dict, List, Type, cast
import jwt
from fastapi import Depends, HTTPException, Query
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from starlette import status
from starlette.requests import Request
from starlette.responses import Response
from fastapi_users import models
from fastapi_users.authentication import Authenticator
from fastapi_users.db import BaseUserDatabase
from fastapi_users.password import generate_password, get_password_hash
from fastapi_users.router.common import ErrorCode, Event, EventHandlersRouter
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
def generate_state_token(
data: Dict[str, str], secret: str, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, lifetime_seconds, secret, JWT_ALGORITHM)
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
return jwt.decode(
token, secret, audience=STATE_TOKEN_AUDIENCE, algorithms=[JWT_ALGORITHM],
)
def get_oauth_router(
oauth_client: BaseOAuth2,
user_db: BaseUserDatabase[models.BaseUserDB],
user_db_model: Type[models.BaseUserDB],
authenticator: Authenticator,
state_secret: str,
redirect_url: str = None,
) -> EventHandlersRouter:
"""Generate a router with the OAuth routes."""
router = EventHandlersRouter()
callback_route_name = f"{oauth_client.name}-callback"
if redirect_url is not None:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client, redirect_url=redirect_url,
)
else:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client, route_name=callback_route_name,
)
@router.get("/authorize")
async def authorize(
request: Request, authentication_backend: str, scopes: List[str] = Query(None),
):
# Check that authentication_backend exists
backend_exists = False
for backend in authenticator.backends:
if backend.name == authentication_backend:
backend_exists = True
break
if not backend_exists:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
redirect_url = request.url_for(callback_route_name)
state_data = {
"authentication_backend": authentication_backend,
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
redirect_url, state, scopes,
)
return {"authorization_url": authorization_url}
@router.get("/callback", name=f"{oauth_client.name}-callback")
async def callback(
response: Response, access_token_state=Depends(oauth2_authorize_callback)
):
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
)
try:
state_data = decode_state_token(state, state_secret)
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
user = await user_db.get_by_oauth_account(oauth_client.name, account_id)
new_oauth_account = models.BaseOAuthAccount(
oauth_name=oauth_client.name,
access_token=token["access_token"],
expires_at=token["expires_at"],
refresh_token=token.get("refresh_token"),
account_id=account_id,
account_email=account_email,
)
if not user:
user = await user_db.get_by_email(account_email)
if user:
# Link account
user.oauth_accounts.append(new_oauth_account) # type: ignore
await user_db.update(user)
else:
# Create account
password = generate_password()
user = user_db_model(
email=account_email,
hashed_password=get_password_hash(password),
oauth_accounts=[new_oauth_account],
)
await user_db.create(user)
await router.run_handlers(Event.ON_AFTER_REGISTER, user)
else:
# Update oauth
updated_oauth_accounts = []
for oauth_account in user.oauth_accounts: # type: ignore
if oauth_account.account_id == account_id:
updated_oauth_accounts.append(new_oauth_account)
else:
updated_oauth_accounts.append(oauth_account)
user.oauth_accounts = updated_oauth_accounts # type: ignore
await user_db.update(user)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Authenticate
for backend in authenticator.backends:
if backend.name == state_data["authentication_backend"]:
return await backend.get_login_response(
cast(models.BaseUserDB, user), response
)
return router

View File

@ -1,10 +1,7 @@
import asyncio
from collections import defaultdict
from enum import Enum, auto
from typing import Any, Callable, DefaultDict, Dict, List, Type, cast
from typing import Any, Dict, List, Type, cast
import jwt
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import Body, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import EmailStr
from starlette import status
@ -14,40 +11,14 @@ from fastapi_users import models
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.password import get_password_hash
from fastapi_users.router.common import ErrorCode, Event, EventHandlersRouter
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
class ErrorCode:
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
class Event(Enum):
ON_AFTER_REGISTER = auto()
ON_AFTER_FORGOT_PASSWORD = auto()
class UserRouter(APIRouter):
event_handlers: DefaultDict[Event, List[Callable]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.event_handlers = defaultdict(list)
def add_event_handler(self, event_type: Event, func: Callable) -> None:
self.event_handlers[event_type].append(func)
async def run_handlers(self, event_type: Event, *args, **kwargs) -> None:
for handler in self.event_handlers[event_type]:
if asyncio.iscoroutinefunction(handler):
await handler(*args, **kwargs)
else:
handler(*args, **kwargs)
def _add_login_route(
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
router: EventHandlersRouter,
user_db: BaseUserDatabase,
auth_backend: BaseAuthentication,
):
@router.post(f"/login/{auth_backend.name}")
async def login(
@ -73,9 +44,9 @@ def get_user_router(
authenticator: Authenticator,
reset_password_token_secret: str,
reset_password_token_lifetime_seconds: int = 3600,
) -> UserRouter:
) -> EventHandlersRouter:
"""Generate a router with the authentication routes."""
router = UserRouter()
router = EventHandlersRouter()
reset_password_token_audience = "fastapi-users:reset"