diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index a2187caf..3079b797 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -13,6 +13,7 @@ from fastapi_users.router import ( get_users_router, get_verify_router, ) +from fastapi_users.router.oauth import get_oauth_associate_router try: from httpx_oauth.oauth2 import BaseOAuth2 @@ -78,7 +79,7 @@ class FastAPIUsers(Generic[models.UP, models.ID]): :param backend: The authentication backend instance. :param requires_verification: Whether the authentication - require the user to be verified or not. + require the user to be verified or not. Defaults to False. """ return get_auth_router( backend, @@ -115,6 +116,35 @@ class FastAPIUsers(Generic[models.UP, models.ID]): associate_by_email, ) + def get_oauth_associate_router( + self, + oauth_client: BaseOAuth2, + user_schema: Type[schemas.U], + state_secret: SecretType, + redirect_url: str = None, + requires_verification: bool = False, + ) -> APIRouter: + """ + Return an OAuth association router for a given OAuth client. + + :param oauth_client: The HTTPX OAuth client instance. + :param user_schema: Pydantic schema of a public user. + :param state_secret: Secret used to encode the state JWT. + :param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow. + If not given, the URL to the callback endpoint will be generated. + :param requires_verification: Whether the endpoints + require the users to be verified or not. Defaults to False. + """ + return get_oauth_associate_router( + oauth_client, + self.authenticator, + self.get_user_manager, + user_schema, + state_secret, + redirect_url, + requires_verification, + ) + def get_users_router( self, user_schema: Type[schemas.U], @@ -127,7 +157,7 @@ class FastAPIUsers(Generic[models.UP, models.ID]): :param user_schema: Pydantic schema of a public user. :param user_update_schema: Pydantic schema for updating a user. :param requires_verification: Whether the endpoints - require the users to be verified or not. + require the users to be verified or not. Defaults to False. """ return get_users_router( self.get_user_manager, diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index a5eb46d9..dc78277c 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -224,6 +224,48 @@ class BaseUserManager(Generic[models.UP, models.ID]): return user + async def oauth_associate_callback( + self: "BaseUserManager[models.UOAP, models.ID]", + user: models.UOAP, + oauth_name: str, + access_token: str, + account_id: str, + account_email: str, + expires_at: Optional[int] = None, + refresh_token: Optional[str] = None, + request: Optional[Request] = None, + ) -> models.UOAP: + """ + Handle the callback after a successful OAuth association. + + We add this new OAuth account to the given user. + + :param oauth_name: Name of the OAuth client. + :param access_token: Valid access token for the service provider. + :param account_id: models.ID of the user on the service provider. + :param account_email: E-mail of the user on the service provider. + :param expires_at: Optional timestamp at which the access token expires. + :param refresh_token: Optional refresh token to get a + fresh access token from the service provider. + :param request: Optional FastAPI request that + triggered the operation, defaults to None + :return: A user. + """ + oauth_account_dict = { + "oauth_name": oauth_name, + "access_token": access_token, + "account_id": account_id, + "account_email": account_email, + "expires_at": expires_at, + "refresh_token": refresh_token, + } + + user = await self.user_db.add_oauth_account(user, oauth_account_dict) + + await self.on_after_update(user, {}, request) + + return user + async def request_verify( self, user: models.UP, request: Optional[Request] = None ) -> None: diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index d3be6b05..2a54882f 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Type import jwt from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status @@ -6,8 +6,8 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token from pydantic import BaseModel -from fastapi_users import models -from fastapi_users.authentication import AuthenticationBackend, Strategy +from fastapi_users import models, schemas +from fastapi_users.authentication import AuthenticationBackend, Authenticator, Strategy from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.manager import BaseUserManager, UserManagerDependency from fastapi_users.router.common import ErrorCode, ErrorModel @@ -136,3 +136,115 @@ def get_oauth_router( return await backend.login(strategy, user, response) return router + + +def get_oauth_associate_router( + oauth_client: BaseOAuth2, + authenticator: Authenticator, + get_user_manager: UserManagerDependency[models.UP, models.ID], + user_schema: Type[schemas.U], + state_secret: SecretType, + redirect_url: str = None, + requires_verification: bool = False, +) -> APIRouter: + """Generate a router with the OAuth routes to associate an authenticated user.""" + router = APIRouter() + + get_current_active_user = authenticator.current_user( + active=True, verified=requires_verification + ) + + callback_route_name = f"oauth-associate:{oauth_client.name}.callback" + + if redirect_url is not None: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + redirect_url=redirect_url, + ) + else: + oauth2_authorize_callback = OAuth2AuthorizeCallback( + oauth_client, + route_name=callback_route_name, + ) + + @router.get( + "/authorize", + name=f"oauth-associate:{oauth_client.name}.authorize", + response_model=OAuth2AuthorizeResponse, + ) + async def authorize( + request: Request, + scopes: List[str] = Query(None), + user: models.UP = Depends(get_current_active_user), + ) -> OAuth2AuthorizeResponse: + if redirect_url is not None: + authorize_redirect_url = redirect_url + else: + authorize_redirect_url = request.url_for(callback_route_name) + + state_data: Dict[str, str] = {"sub": str(user.id)} + state = generate_state_token(state_data, state_secret) + authorization_url = await oauth_client.get_authorization_url( + authorize_redirect_url, + state, + scopes, + ) + + return OAuth2AuthorizeResponse(authorization_url=authorization_url) + + @router.get( + "/callback", + response_model=user_schema, + name=callback_route_name, + description="The response varies based on the authentication backend used.", + responses={ + status.HTTP_400_BAD_REQUEST: { + "model": ErrorModel, + "content": { + "application/json": { + "examples": { + "INVALID_STATE_TOKEN": { + "summary": "Invalid state token.", + "value": None, + }, + } + } + }, + }, + }, + ) + async def callback( + request: Request, + user: models.UP = Depends(get_current_active_user), + access_token_state: Tuple[OAuth2Token, str] = Depends( + oauth2_authorize_callback + ), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + ): + token, state = access_token_state + account_id, account_email = await oauth_client.get_id_email( + token["access_token"] + ) + + try: + state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE]) + except jwt.DecodeError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + if state_data["sub"] != str(user.id): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + user = await user_manager.oauth_associate_callback( + user, + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, + ) + + return user_schema.from_orm(user) + + return router diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index f4c60858..29109931 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -28,6 +28,9 @@ async def test_app_client( app.include_router( fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret) ) + app.include_router( + fastapi_users.get_oauth_associate_router(oauth_client, User, secret) + ) @app.delete("/users/me") def custom_users_route(): diff --git a/tests/test_manager.py b/tests/test_manager.py index 60229312..bfafad0a 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -238,6 +238,31 @@ class TestOAuthCallback: assert user_manager_oauth.on_after_register.called is True +@pytest.mark.asyncio +@pytest.mark.manager +class TestOAuthAssociateCallback: + async def test_existing_user_without_oauth_associate( + self, + user_manager_oauth: UserManagerMock[UserOAuthModel], + superuser_oauth: UserOAuthModel, + ): + + user = await user_manager_oauth.oauth_associate_callback( + superuser_oauth, + "service1", + "TOKEN", + "superuser_oauth1", + superuser_oauth.email, + 1579000751, + ) + + assert user.id == user.id + assert len(user.oauth_accounts) == 1 + assert user.oauth_accounts[0].id is not None + + assert user_manager_oauth.on_after_update.called is True + + @pytest.mark.asyncio @pytest.mark.manager class TestRequestVerifyUser: diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index bc30648b..e18bce27 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -5,14 +5,29 @@ import pytest from fastapi import FastAPI, status from httpx_oauth.oauth2 import BaseOAuth2, OAuth2 -from fastapi_users.authentication import AuthenticationBackend -from fastapi_users.router.oauth import generate_state_token, get_oauth_router -from tests.conftest import AsyncMethodMocker, UserManagerMock, UserOAuthModel +from fastapi_users.authentication import AuthenticationBackend, Authenticator +from fastapi_users.router.oauth import ( + generate_state_token, + get_oauth_associate_router, + get_oauth_router, +) +from tests.conftest import ( + AsyncMethodMocker, + User, + UserManagerMock, + UserModel, + UserOAuth, + UserOAuthModel, +) @pytest.fixture def app_factory(secret, get_user_manager_oauth, mock_authentication, oauth_client): - def _app_factory(redirect_url: str = None) -> FastAPI: + def _app_factory( + redirect_url: str = None, requires_verification: bool = False + ) -> FastAPI: + authenticator = Authenticator([mock_authentication], get_user_manager_oauth) + oauth_router = get_oauth_router( oauth_client, mock_authentication, @@ -20,8 +35,19 @@ def app_factory(secret, get_user_manager_oauth, mock_authentication, oauth_clien secret, redirect_url, ) + oauth_associate_router = get_oauth_associate_router( + oauth_client, + authenticator, + get_user_manager_oauth, + User, + secret, + redirect_url, + requires_verification, + ) + app = FastAPI() - app.include_router(oauth_router) + app.include_router(oauth_router, prefix="/oauth") + app.include_router(oauth_associate_router, prefix="/oauth-associate") return app return _app_factory @@ -37,6 +63,11 @@ def test_app_redirect_url(app_factory): return app_factory("http://www.tintagel.bt/callback") +@pytest.fixture +def test_app_requires_verification(app_factory): + return app_factory(requires_verification=True) + + @pytest.fixture @pytest.mark.asyncio async def test_app_client(test_app, get_test_client): @@ -66,7 +97,7 @@ class TestAuthorize: ) response = await test_app_client.get( - "/authorize", params={"scopes": ["scope1", "scope2"]} + "/oauth/authorize", params={"scopes": ["scope1", "scope2"]} ) assert response.status_code == status.HTTP_200_OK @@ -86,7 +117,7 @@ class TestAuthorize: ) response = await test_app_client_redirect_url.get( - "/authorize", params={"scopes": ["scope1", "scope2"]} + "/oauth/authorize", params={"scopes": ["scope1", "scope2"]} ) assert response.status_code == status.HTTP_200_OK @@ -121,7 +152,7 @@ class TestCallback: ) response = await test_app_client.get( - "/callback", + "/oauth/callback", params={"code": "CODE", "state": "STATE"}, ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -147,7 +178,7 @@ class TestCallback: ) response = await test_app_client.get( - "/callback", + "/oauth/callback", params={"code": "CODE", "state": state_jwt}, ) @@ -177,7 +208,7 @@ class TestCallback: ) response = await test_app_client.get( - "/callback", + "/oauth/callback", params={"code": "CODE", "state": state_jwt}, ) @@ -204,7 +235,7 @@ class TestCallback: ) response = await test_app_client_redirect_url.get( - "/callback", + "/oauth/callback", params={"code": "CODE", "state": state_jwt}, ) @@ -218,6 +249,221 @@ class TestCallback: assert data["access_token"] == str(user_oauth.id) +@pytest.mark.router +@pytest.mark.oauth +@pytest.mark.asyncio +class TestAssociateAuthorize: + async def test_missing_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get( + "/oauth-associate/authorize", params={"scopes": ["scope1", "scope2"]} + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_inactive_user( + self, test_app_client: httpx.AsyncClient, inactive_user_oauth: UserOAuthModel + ): + response = await test_app_client.get( + "/oauth-associate/authorize", + params={"scopes": ["scope1", "scope2"]}, + headers={"Authorization": f"Bearer {inactive_user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_active_user( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + ): + get_authorization_url_mock = async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client.get( + "/oauth-associate/authorize", + params={"scopes": ["scope1", "scope2"]}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_200_OK + get_authorization_url_mock.assert_called_once() + + data = response.json() + assert "authorization_url" in data + + async def test_with_redirect_url( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client_redirect_url: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + ): + get_authorization_url_mock = async_method_mocker( + oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL" + ) + + response = await test_app_client_redirect_url.get( + "/oauth-associate/authorize", + params={"scopes": ["scope1", "scope2"]}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_200_OK + get_authorization_url_mock.assert_called_once() + + data = response.json() + assert "authorization_url" in data + + +@pytest.mark.router +@pytest.mark.oauth +@pytest.mark.asyncio +@pytest.mark.parametrize( + "access_token", + [ + ({"access_token": "TOKEN", "expires_at": 1579179542}), + ({"access_token": "TOKEN"}), + ], +) +class TestAssociateCallback: + async def test_missing_token( + self, test_app_client: httpx.AsyncClient, access_token: str + ): + response = await test_app_client.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": "STATE"}, + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_inactive_user( + self, + test_app_client: httpx.AsyncClient, + inactive_user_oauth: UserOAuthModel, + access_token: str, + ): + response = await test_app_client.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": "STATE"}, + headers={"Authorization": f"Bearer {inactive_user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_invalid_state( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + access_token: str, + ): + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + get_id_email_mock = async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + + response = await test_app_client.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": "STATE"}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + get_id_email_mock.assert_called_once_with("TOKEN") + + async def test_state_with_different_user_id( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user: UserModel, + access_token: str, + ): + state_jwt = generate_state_token({"sub": str(user.id)}, "SECRET") + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + get_id_email_mock = async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + + response = await test_app_client.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": state_jwt}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + get_id_email_mock.assert_called_once_with("TOKEN") + + async def test_active_user( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "SECRET") + async_method_mocker(oauth_client, "get_access_token", return_value=access_token) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + async_method_mocker( + user_manager_oauth, "oauth_callback", return_value=user_oauth + ) + + response = await test_app_client.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": state_jwt}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["id"] == str(user_oauth.id) + + async def test_redirect_url_router( + self, + async_method_mocker: AsyncMethodMocker, + test_app_client_redirect_url: httpx.AsyncClient, + oauth_client: BaseOAuth2, + user_oauth: UserOAuthModel, + user_manager_oauth: UserManagerMock, + access_token: str, + ): + state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "SECRET") + get_access_token_mock = async_method_mocker( + oauth_client, "get_access_token", return_value=access_token + ) + async_method_mocker( + oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email) + ) + async_method_mocker( + user_manager_oauth, "oauth_callback", return_value=user_oauth + ) + + response = await test_app_client_redirect_url.get( + "/oauth-associate/callback", + params={"code": "CODE", "state": state_jwt}, + headers={"Authorization": f"Bearer {user_oauth.id}"}, + ) + + assert response.status_code == status.HTTP_200_OK + + get_access_token_mock.assert_called_once_with( + "CODE", "http://www.tintagel.bt/callback", None + ) + + data = cast(Dict[str, Any], response.json()) + assert data["id"] == str(user_oauth.id) + + @pytest.mark.asyncio @pytest.mark.oauth @pytest.mark.router @@ -227,9 +473,9 @@ async def test_route_names( authorize_route_name = ( f"oauth:{oauth_client.name}.{mock_authentication.name}.authorize" ) - assert test_app.url_path_for(authorize_route_name) == "/authorize" + assert test_app.url_path_for(authorize_route_name) == "/oauth/authorize" callback_route_name = ( f"oauth:{oauth_client.name}.{mock_authentication.name}.callback" ) - assert test_app.url_path_for(callback_route_name) == "/callback" + assert test_app.url_path_for(callback_route_name) == "/oauth/callback"