Fix #36: fix token url in auto doc (#38)

* Fix #36: fix token url in auto doc

* Define OAuth scheme in authentication base with default /users/login tokenUrl
* Allow to override it through contructor argument of auth class

* Fix test coverage of BaseAuthentication
This commit is contained in:
François Voron
2019-11-03 09:20:16 +01:00
committed by GitHub
parent 6f8bf57d0a
commit 47ad4ce1cc
3 changed files with 25 additions and 10 deletions

View File

@ -1,6 +1,7 @@
from typing import Callable
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response
@ -15,8 +16,20 @@ class BaseAuthentication:
Base adapter for generating and decoding authentication tokens.
Provides dependency injectors to get current active/superuser user.
:param scheme: Optional authentication scheme for OpenAPI.
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
Override it if your login route lives somewhere else.
"""
scheme: OAuth2PasswordBearer
def __init__(self, scheme: OAuth2PasswordBearer = None):
if scheme is None:
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
else:
self.scheme = scheme
async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError()

View File

@ -1,6 +1,5 @@
import jwt
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication
@ -8,8 +7,6 @@ from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import BaseUserDB
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
class JWTAuthentication(BaseAuthentication):
"""
@ -23,7 +20,8 @@ class JWTAuthentication(BaseAuthentication):
secret: str
lifetime_seconds: int
def __init__(self, secret: str, lifetime_seconds: int):
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.secret = secret
self.lifetime_seconds = lifetime_seconds
@ -34,28 +32,28 @@ class JWTAuthentication(BaseAuthentication):
return {"token": token}
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(oauth2_scheme)):
async def _get_current_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(oauth2_scheme)):
async def _get_current_active_user(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(oauth2_scheme)):
async def _get_current_superuser(token: str = Depends(self.scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)):
async def authentication_method(token: str = Depends(self.scheme)):
try:
data = jwt.decode(
token,