from typing import Dict, List, Optional

import jwt
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2

from fastapi_users import models
from fastapi_users.authentication import Authenticator
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
from fastapi_users.manager import BaseUserManager, UserManagerDependency
from fastapi_users.router.common import ErrorCode

# Audience claim stamped on (and required from) every OAuth state token.
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"


def generate_state_token(
    data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
    """Build the signed state token sent along with the authorization request.

    The ``aud`` claim is always set to ``STATE_TOKEN_AUDIENCE`` (overriding any
    ``aud`` key present in ``data``). The caller's dictionary is left untouched.

    :param data: Claims to embed in the token.
    :param secret: Secret forwarded to ``generate_jwt`` for signing.
    :param lifetime_seconds: Lifetime forwarded to ``generate_jwt``.
    :return: The encoded token as returned by ``generate_jwt``.
    """
    # Copy instead of writing into ``data`` so the caller's dict is not mutated.
    payload = {**data, "aud": STATE_TOKEN_AUDIENCE}
    return generate_jwt(payload, secret, lifetime_seconds)


def get_oauth_router(
    oauth_client: BaseOAuth2,
    get_user_manager: UserManagerDependency[models.UC, models.UD],
    authenticator: Authenticator,
    state_secret: SecretType,
    redirect_url: Optional[str] = None,
) -> APIRouter:
    """Generate a router with the OAuth routes.

    Two routes are registered: ``GET /authorize`` returns the provider's
    authorization URL, and ``GET /callback`` (named ``"<client name>-callback"``)
    completes the flow and delegates the login response to the authentication
    backend that was requested at the authorize step.

    :param oauth_client: OAuth2 client used to build the authorization URL and
        to fetch the account id and email.
    :param get_user_manager: Dependency yielding the user manager.
    :param authenticator: Provides the list of available authentication backends.
    :param state_secret: Secret used to sign and verify the state token.
    :param redirect_url: Explicit callback URL. When ``None``, the URL of the
        callback route of this router is used instead.
    :return: The configured ``APIRouter``.
    """
    router = APIRouter()
    callback_route_name = f"{oauth_client.name}-callback"

    if redirect_url is not None:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            redirect_url=redirect_url,
        )
    else:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            route_name=callback_route_name,
        )

    def _find_backend(name):
        """Return the authenticator backend called ``name``, or ``None``."""
        return next(
            (backend for backend in authenticator.backends if backend.name == name),
            None,
        )

    @router.get("/authorize")
    async def authorize(
        request: Request,
        authentication_backend: str,
        scopes: Optional[List[str]] = Query(None),
    ):
        """Return the provider authorization URL for the requested backend.

        Responds with HTTP 400 when ``authentication_backend`` does not match
        the name of any backend known to the authenticator.
        """
        # Check that authentication_backend exists
        if _find_backend(authentication_backend) is None:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

        if redirect_url is not None:
            authorize_redirect_url = redirect_url
        else:
            authorize_redirect_url = request.url_for(callback_route_name)

        # The backend name travels inside the signed state so that the callback
        # knows which backend must produce the login response.
        state_data: Dict[str, str] = {
            "authentication_backend": authentication_backend,
        }
        state = generate_state_token(state_data, state_secret)
        authorization_url = await oauth_client.get_authorization_url(
            authorize_redirect_url,
            state,
            scopes,
        )

        return {"authorization_url": authorization_url}

    @router.get("/callback", name=callback_route_name)
    async def callback(
        request: Request,
        response: Response,
        access_token_state=Depends(oauth2_authorize_callback),
        user_manager: BaseUserManager[models.UC, models.UD] = Depends(
            get_user_manager
        ),
    ):
        """Complete the OAuth flow and return the backend's login response.

        Responds with HTTP 400 when the state token cannot be validated, when
        the backend named in the state is unknown, or when the resulting user
        is inactive.
        """
        token, state = access_token_state

        # Validate the state before using the provider token at all. Catch the
        # PyJWT base error: expired-signature and invalid-audience errors are
        # not subclasses of ``jwt.DecodeError`` and would otherwise escape as
        # unhandled exceptions.
        try:
            state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
        except jwt.PyJWTError as err:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) from err

        # Resolve the backend up front so that an unusable state is rejected
        # before the user manager is asked to handle the account.
        backend = _find_backend(state_data.get("authentication_backend"))
        if backend is None:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

        account_id, account_email = await oauth_client.get_id_email(
            token["access_token"]
        )

        new_oauth_account = models.BaseOAuthAccount(
            oauth_name=oauth_client.name,
            access_token=token["access_token"],
            expires_at=token.get("expires_at"),
            refresh_token=token.get("refresh_token"),
            account_id=account_id,
            account_email=account_email,
        )

        user = await user_manager.oauth_callback(new_oauth_account, request)

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
            )

        # Authenticate
        return await backend.get_login_response(user, response, user_manager)

    return router