mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-11-04 06:37:51 +08:00 
			
		
		
		
	* Revamp Transport so they always build a full Response object * Fix linting * Add private methods to set cookies on CookieTransport * Change on_after_login login_return parameter to response
		
			
				
	
	
		
			273 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			273 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import Dict, List, Optional, Tuple, Type
 | 
						|
 | 
						|
import jwt
 | 
						|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
 | 
						|
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
 | 
						|
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
 | 
						|
from pydantic import BaseModel
 | 
						|
 | 
						|
from fastapi_users import models, schemas
 | 
						|
from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy
 | 
						|
from fastapi_users.exceptions import UserAlreadyExists
 | 
						|
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
 | 
						|
from fastapi_users.manager import BaseUserManager, UserManagerDependency
 | 
						|
from fastapi_users.router.common import ErrorCode, ErrorModel
 | 
						|
 | 
						|
# Audience ("aud") claim written into every OAuth state token by
# generate_state_token() and passed to decode_jwt() when the callback routes
# validate the `state` value handed back to them.
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
 | 
						|
 | 
						|
 | 
						|
class OAuth2AuthorizeResponse(BaseModel):
    """Response body returned by the `/authorize` routes of this module."""

    # URL produced by the OAuth client's `get_authorization_url()`.
    authorization_url: str
 | 
						|
 | 
						|
 | 
						|
def generate_state_token(
    data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
    """Build the OAuth `state` token for an authorization request.

    The token is produced by `generate_jwt` from the given claims plus an
    `aud` claim set to `STATE_TOKEN_AUDIENCE`.

    Args:
        data: Claims to embed in the token. The mapping is left untouched.
        secret: Secret forwarded to `generate_jwt`.
        lifetime_seconds: Lifetime forwarded to `generate_jwt`
            (defaults to 3600).

    Returns:
        The encoded token as returned by `generate_jwt`.
    """
    # Copy instead of writing into `data`, so the caller's dict is not
    # modified as a side effect of generating a token.
    payload = {**data, "aud": STATE_TOKEN_AUDIENCE}
    return generate_jwt(payload, secret, lifetime_seconds)
 | 
						|
 | 
						|
 | 
						|
def get_oauth_router(
    oauth_client: BaseOAuth2,
    backend: AuthenticationBackend,
    get_user_manager: UserManagerDependency[models.UP, models.ID],
    state_secret: SecretType,
    redirect_url: Optional[str] = None,
    associate_by_email: bool = False,
    is_verified_by_default: bool = False,
) -> APIRouter:
    """Generate a router with the OAuth routes.

    Two GET routes are registered: `/authorize`, which returns the provider
    authorization URL, and `/callback`, which handles the provider redirect
    and logs the user in through `backend`.

    Args:
        oauth_client: OAuth2 client used to build the authorization URL and
            to fetch the account id and e-mail.
        backend: Authentication backend; provides the strategy dependency and
            builds the login response.
        get_user_manager: Dependency yielding the user manager.
        state_secret: Secret used to encode and decode the `state` token.
        redirect_url: Callback URL handed to the provider. When `None`, the
            URL of this router's own callback route is used.
        associate_by_email: Forwarded to `user_manager.oauth_callback`.
        is_verified_by_default: Forwarded to `user_manager.oauth_callback`.

    Returns:
        The configured `APIRouter`.
    """
    router = APIRouter()
    callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"

    if redirect_url is not None:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            redirect_url=redirect_url,
        )
    else:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            route_name=callback_route_name,
        )

    @router.get(
        "/authorize",
        name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
        response_model=OAuth2AuthorizeResponse,
    )
    async def authorize(
        request: Request, scopes: List[str] = Query(None)
    ) -> OAuth2AuthorizeResponse:
        if redirect_url is not None:
            authorize_redirect_url = redirect_url
        else:
            authorize_redirect_url = str(request.url_for(callback_route_name))

        # No user is known yet, so the state token carries no claims besides
        # those added by generate_state_token().
        state_data: Dict[str, str] = {}
        state = generate_state_token(state_data, state_secret)
        authorization_url = await oauth_client.get_authorization_url(
            authorize_redirect_url,
            state,
            scopes,
        )

        return OAuth2AuthorizeResponse(authorization_url=authorization_url)

    @router.get(
        "/callback",
        name=callback_route_name,
        description="The response varies based on the authentication backend used.",
        responses={
            status.HTTP_400_BAD_REQUEST: {
                "model": ErrorModel,
                "content": {
                    "application/json": {
                        "examples": {
                            "INVALID_STATE_TOKEN": {
                                "summary": "Invalid state token.",
                                "value": None,
                            },
                            ErrorCode.LOGIN_BAD_CREDENTIALS: {
                                "summary": "User is inactive.",
                                "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
                            },
                        }
                    }
                },
            },
        },
    )
    async def callback(
        request: Request,
        access_token_state: Tuple[OAuth2Token, str] = Depends(
            oauth2_authorize_callback
        ),
        user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
        strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
    ):
        token, state = access_token_state
        account_id, account_email = await oauth_client.get_id_email(
            token["access_token"]
        )

        if account_email is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
            )

        try:
            decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
        except jwt.InvalidTokenError:
            # InvalidTokenError is the common base of DecodeError,
            # ExpiredSignatureError and InvalidAudienceError. Catching only
            # DecodeError let an expired or wrong-audience state token escape
            # as an unhandled exception instead of a 400.
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

        try:
            user = await user_manager.oauth_callback(
                oauth_client.name,
                token["access_token"],
                account_id,
                account_email,
                token.get("expires_at"),
                token.get("refresh_token"),
                request,
                associate_by_email=associate_by_email,
                is_verified_by_default=is_verified_by_default,
            )
        except UserAlreadyExists:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
            )

        # Authenticate
        response = await backend.login(strategy, user)
        await user_manager.on_after_login(user, request, response)
        return response

    return router
 | 
						|
 | 
						|
 | 
						|
def get_oauth_associate_router(
    oauth_client: BaseOAuth2,
    authenticator: Authenticator,
    get_user_manager: UserManagerDependency[models.UP, models.ID],
    user_schema: Type[schemas.U],
    state_secret: SecretType,
    redirect_url: Optional[str] = None,
    requires_verification: bool = False,
) -> APIRouter:
    """Generate a router with the OAuth routes to associate an authenticated user.

    Both routes depend on the current active user. `/authorize` returns the
    provider authorization URL with a state token bound to that user's id;
    `/callback` checks the binding and links the OAuth account to the user.

    Args:
        oauth_client: OAuth2 client used to build the authorization URL and
            to fetch the account id and e-mail.
        authenticator: Provides the current-user dependency.
        get_user_manager: Dependency yielding the user manager.
        user_schema: Schema used to serialize the user returned by the
            callback route.
        state_secret: Secret used to encode and decode the `state` token.
        redirect_url: Callback URL handed to the provider. When `None`, the
            URL of this router's own callback route is used.
        requires_verification: Passed as `verified` to
            `authenticator.current_user`.

    Returns:
        The configured `APIRouter`.
    """
    router = APIRouter()

    get_current_active_user = authenticator.current_user(
        active=True, verified=requires_verification
    )

    callback_route_name = f"oauth-associate:{oauth_client.name}.callback"

    if redirect_url is not None:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            redirect_url=redirect_url,
        )
    else:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            route_name=callback_route_name,
        )

    @router.get(
        "/authorize",
        name=f"oauth-associate:{oauth_client.name}.authorize",
        response_model=OAuth2AuthorizeResponse,
    )
    async def authorize(
        request: Request,
        scopes: List[str] = Query(None),
        user: models.UP = Depends(get_current_active_user),
    ) -> OAuth2AuthorizeResponse:
        if redirect_url is not None:
            authorize_redirect_url = redirect_url
        else:
            authorize_redirect_url = str(request.url_for(callback_route_name))

        # Bind the state token to the requesting user so the callback can
        # refuse a token that was issued for someone else.
        state_data: Dict[str, str] = {"sub": str(user.id)}
        state = generate_state_token(state_data, state_secret)
        authorization_url = await oauth_client.get_authorization_url(
            authorize_redirect_url,
            state,
            scopes,
        )

        return OAuth2AuthorizeResponse(authorization_url=authorization_url)

    @router.get(
        "/callback",
        response_model=user_schema,
        name=callback_route_name,
        description="The response varies based on the authentication backend used.",
        responses={
            status.HTTP_400_BAD_REQUEST: {
                "model": ErrorModel,
                "content": {
                    "application/json": {
                        "examples": {
                            "INVALID_STATE_TOKEN": {
                                "summary": "Invalid state token.",
                                "value": None,
                            },
                        }
                    }
                },
            },
        },
    )
    async def callback(
        request: Request,
        user: models.UP = Depends(get_current_active_user),
        access_token_state: Tuple[OAuth2Token, str] = Depends(
            oauth2_authorize_callback
        ),
        user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
    ):
        token, state = access_token_state
        account_id, account_email = await oauth_client.get_id_email(
            token["access_token"]
        )

        if account_email is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
            )

        try:
            state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
        except jwt.InvalidTokenError:
            # InvalidTokenError is the common base of DecodeError,
            # ExpiredSignatureError and InvalidAudienceError. Catching only
            # DecodeError let an expired or wrong-audience state token escape
            # as an unhandled exception instead of a 400.
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

        # Use .get(): a validly signed state token without a "sub" claim
        # (e.g. one issued with an empty payload) must be rejected with a 400
        # rather than raising KeyError.
        if state_data.get("sub") != str(user.id):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

        user = await user_manager.oauth_associate_callback(
            user,
            oauth_client.name,
            token["access_token"],
            account_id,
            account_email,
            token.get("expires_at"),
            token.get("refresh_token"),
            request,
        )

        return user_schema.from_orm(user)

    return router
 |