Fix #36: fix token url in auto doc (#38)

* Fix #36: fix token url in auto doc

* Define OAuth scheme in authentication base with default /users/login tokenUrl
* Allow to override it through contructor argument of auth class

* Fix test coverage of BaseAuthentication
This commit is contained in:
François Voron
2019-11-03 09:20:16 +01:00
committed by GitHub
parent 6f8bf57d0a
commit 47ad4ce1cc
3 changed files with 25 additions and 10 deletions

View File

@ -1,6 +1,7 @@
from typing import Callable from typing import Callable
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.security import OAuth2PasswordBearer
from starlette import status from starlette import status
from starlette.responses import Response from starlette.responses import Response
@ -15,8 +16,20 @@ class BaseAuthentication:
Base adapter for generating and decoding authentication tokens. Base adapter for generating and decoding authentication tokens.
Provides dependency injectors to get current active/superuser user. Provides dependency injectors to get current active/superuser user.
:param scheme: Optional authentication scheme for OpenAPI.
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
Override it if your login route lives somewhere else.
""" """
scheme: OAuth2PasswordBearer
def __init__(self, scheme: OAuth2PasswordBearer = None):
if scheme is None:
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
else:
self.scheme = scheme
async def get_login_response(self, user: BaseUserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError() raise NotImplementedError()

View File

@ -1,6 +1,5 @@
import jwt import jwt
from fastapi import Depends from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication from fastapi_users.authentication.base import BaseAuthentication
@ -8,8 +7,6 @@ from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import BaseUserDB from fastapi_users.models import BaseUserDB
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
class JWTAuthentication(BaseAuthentication): class JWTAuthentication(BaseAuthentication):
""" """
@ -23,7 +20,8 @@ class JWTAuthentication(BaseAuthentication):
secret: str secret: str
lifetime_seconds: int lifetime_seconds: int
def __init__(self, secret: str, lifetime_seconds: int): def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.secret = secret self.secret = secret
self.lifetime_seconds = lifetime_seconds self.lifetime_seconds = lifetime_seconds
@ -34,28 +32,28 @@ class JWTAuthentication(BaseAuthentication):
return {"token": token} return {"token": token}
def get_current_user(self, user_db: BaseUserDatabase): def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(oauth2_scheme)): async def _get_current_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token) user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user) return self._get_current_user_base(user)
return _get_current_user return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase): def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(oauth2_scheme)): async def _get_current_active_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token) user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user) return self._get_current_active_user_base(user)
return _get_current_active_user return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase): def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(oauth2_scheme)): async def _get_current_superuser(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token) user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user) return self._get_current_superuser_base(user)
return _get_current_superuser return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase): def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)): async def authentication_method(token: str = Depends(self.scheme)):
try: try:
data = jwt.decode( data = jwt.decode(
token, token,

View File

@ -1,4 +1,5 @@
import pytest import pytest
from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
@ -6,9 +7,12 @@ from fastapi_users.authentication import BaseAuthentication
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.authentication @pytest.mark.authentication
async def test_not_implemented_methods(user, mock_user_db): @pytest.mark.parametrize(
"constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}]
)
async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db):
response = Response() response = Response()
base_authentication = BaseAuthentication() base_authentication = BaseAuthentication(**constructor_kwargs)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
await base_authentication.get_login_response(user, response) await base_authentication.get_login_response(user, response)