diff --git a/fastapi_users/authentication/base.py b/fastapi_users/authentication/base.py index 73a21cdd..41616680 100644 --- a/fastapi_users/authentication/base.py +++ b/fastapi_users/authentication/base.py @@ -1,6 +1,7 @@ from typing import Callable from fastapi import HTTPException +from fastapi.security import OAuth2PasswordBearer from starlette import status from starlette.responses import Response @@ -15,8 +16,20 @@ class BaseAuthentication: Base adapter for generating and decoding authentication tokens. Provides dependency injectors to get current active/superuser user. + + :param scheme: Optional authentication scheme for OpenAPI. + Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`. + Override it if your login route lives somewhere else. """ + scheme: OAuth2PasswordBearer + + def __init__(self, scheme: OAuth2PasswordBearer = None): + if scheme is None: + self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login") + else: + self.scheme = scheme + async def get_login_response(self, user: BaseUserDB, response: Response): raise NotImplementedError() diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 6fc3f4f0..ce51618b 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -1,6 +1,5 @@ import jwt from fastapi import Depends -from fastapi.security import OAuth2PasswordBearer from starlette.responses import Response from fastapi_users.authentication.base import BaseAuthentication @@ -8,8 +7,6 @@ from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import BaseUserDB from fastapi_users.utils import JWT_ALGORITHM, generate_jwt -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") - class JWTAuthentication(BaseAuthentication): """ @@ -23,7 +20,8 @@ class JWTAuthentication(BaseAuthentication): secret: str lifetime_seconds: int - def __init__(self, secret: str, lifetime_seconds: int): + def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs): + super().__init__(*args, **kwargs) self.secret = secret self.lifetime_seconds = lifetime_seconds @@ -34,28 +32,28 @@ class JWTAuthentication(BaseAuthentication): return {"token": token} def get_current_user(self, user_db: BaseUserDatabase): - async def _get_current_user(token: str = Depends(oauth2_scheme)): + async def _get_current_user(token: str = Depends(self.scheme)): user = await self._get_authentication_method(user_db)(token) return self._get_current_user_base(user) return _get_current_user def get_current_active_user(self, user_db: BaseUserDatabase): - async def _get_current_active_user(token: str = Depends(oauth2_scheme)): + async def _get_current_active_user(token: str = Depends(self.scheme)): user = await self._get_authentication_method(user_db)(token) return self._get_current_active_user_base(user) return _get_current_active_user def get_current_superuser(self, user_db: BaseUserDatabase): - async def _get_current_superuser(token: str = Depends(oauth2_scheme)): + async def _get_current_superuser(token: str = Depends(self.scheme)): user = await self._get_authentication_method(user_db)(token) return self._get_current_superuser_base(user) return _get_current_superuser def _get_authentication_method(self, user_db: BaseUserDatabase): - async def authentication_method(token: str = Depends(oauth2_scheme)): + async def authentication_method(token: str = Depends(self.scheme)): try: data = jwt.decode( token, diff --git a/tests/test_authentication_base.py b/tests/test_authentication_base.py index 05dec909..794e0757 100644 --- a/tests/test_authentication_base.py +++ b/tests/test_authentication_base.py @@ -1,4 +1,5 @@ import pytest +from fastapi.security import OAuth2PasswordBearer from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication @@ -6,9 +7,12 @@ from fastapi_users.authentication import BaseAuthentication @pytest.mark.asyncio @pytest.mark.authentication -async def test_not_implemented_methods(user, mock_user_db): +@pytest.mark.parametrize( + "constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}] +) +async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db): response = Response() - base_authentication = BaseAuthentication() + base_authentication = BaseAuthentication(**constructor_kwargs) with pytest.raises(NotImplementedError): await base_authentication.get_login_response(user, response)