mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-02 12:21:53 +08:00
* Revamp authentication to allow multiple backends * Make router generate a login route for each backend * Apply black * Remove unused imports * Complete docstrings * Update documentation * WIP add cookie auth * Complete cookie auth unit tests * Add documentation for cookie auth * Fix cookie backend default name * Don't make cookie return a Response
This commit is contained in:
@ -1,2 +1,60 @@
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from fastapi_users.authentication.base import BaseAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.cookie import CookieAuthentication # noqa: F401
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class Authenticator:
|
||||
"""
|
||||
Provides dependency callables to retrieve authenticated user.
|
||||
|
||||
It performs the authentication against a list of backends
|
||||
defined by the end-developer. The first backend yielding a user wins.
|
||||
If no backend yields a user, an HTTPException is raised.
|
||||
|
||||
:param backends: List of authentication backends.
|
||||
:param user_db: Database adapter instance.
|
||||
"""
|
||||
|
||||
backends: Sequence[BaseAuthentication]
|
||||
user_db: BaseUserDatabase
|
||||
|
||||
def __init__(
|
||||
self, backends: Sequence[BaseAuthentication], user_db: BaseUserDatabase
|
||||
):
|
||||
self.backends = backends
|
||||
self.user_db = user_db
|
||||
|
||||
async def get_current_user(self, request: Request) -> BaseUserDB:
|
||||
return await self._authenticate(request)
|
||||
|
||||
async def get_current_active_user(self, request: Request) -> BaseUserDB:
|
||||
user = await self.get_current_user(request)
|
||||
if not user.is_active:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
async def get_current_superuser(self, request: Request) -> BaseUserDB:
|
||||
user = await self.get_current_active_user(request)
|
||||
if not user.is_superuser:
|
||||
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
|
||||
return user
|
||||
|
||||
async def _authenticate(self, request: Request) -> BaseUserDB:
|
||||
for backend in self.backends:
|
||||
user = await backend(request, self.user_db)
|
||||
if user is not None:
|
||||
return user
|
||||
raise self._get_credentials_exception()
|
||||
|
||||
def _get_credentials_exception(
|
||||
self, status_code: int = status.HTTP_401_UNAUTHORIZED
|
||||
) -> HTTPException:
|
||||
return HTTPException(status_code=status_code)
|
||||
|
||||
@ -1,70 +1,30 @@
|
||||
from typing import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
"""
|
||||
Base adapter for generating and decoding authentication tokens.
|
||||
Base authentication backend.
|
||||
|
||||
Provides dependency injectors to get current active/superuser user.
|
||||
Every backend should derive from this class.
|
||||
|
||||
:param scheme: Optional authentication scheme for OpenAPI.
|
||||
Defaults to `OAuth2PasswordBearer(tokenUrl="/users/login")`.
|
||||
Override it if your login route lives somewhere else.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
scheme: OAuth2PasswordBearer
|
||||
name: str
|
||||
|
||||
def __init__(self, scheme: OAuth2PasswordBearer = None):
|
||||
if scheme is None:
|
||||
self.scheme = OAuth2PasswordBearer(tokenUrl="/users/login")
|
||||
else:
|
||||
self.scheme = scheme
|
||||
def __init__(self, name: str = "base"):
|
||||
self.name = name
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase
|
||||
) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_authentication_method(
|
||||
self, user_db: BaseUserDatabase
|
||||
) -> Callable[..., BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
if user is None:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
user = self._get_current_user_base(user)
|
||||
if not user.is_active:
|
||||
raise self._get_credentials_exception()
|
||||
return user
|
||||
|
||||
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||
user = self._get_current_active_user_base(user)
|
||||
if not user.is_superuser:
|
||||
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
|
||||
return user
|
||||
|
||||
def _get_credentials_exception(
|
||||
self, status_code: int = status.HTTP_401_UNAUTHORIZED
|
||||
) -> HTTPException:
|
||||
return HTTPException(status_code=status_code)
|
||||
|
||||
53
fastapi_users/authentication/cookie.py
Normal file
53
fastapi_users/authentication/cookie.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi.security import APIKeyCookie
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class CookieAuthentication(JWTAuthentication):
|
||||
"""
|
||||
Authentication backend using a cookie.
|
||||
|
||||
Internally, uses a JWT token to store the data.
|
||||
|
||||
:param secret: Secret used to encode the cookie.
|
||||
:param lifetime_seconds: Lifetime duration of the cookie in seconds.
|
||||
:param cookie_name: Name of the cookie.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
lifetime_seconds: int
|
||||
cookie_name: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
lifetime_seconds: int,
|
||||
cookie_name: str = "fastapiusersauth",
|
||||
name: str = "cookie",
|
||||
):
|
||||
super().__init__(secret, lifetime_seconds, name=name)
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.cookie_name = cookie_name
|
||||
self.api_key_cookie = APIKeyCookie(name=self.cookie_name, auto_error=False)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
response.set_cookie(
|
||||
self.cookie_name,
|
||||
token,
|
||||
max_age=self.lifetime_seconds,
|
||||
secure=True,
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
# We shouldn't return directly the response
|
||||
# so that FastAPI can terminate it properly
|
||||
return None
|
||||
|
||||
async def _retrieve_token(self, request: Request) -> Optional[str]:
|
||||
return await self.api_key_cookie.__call__(request)
|
||||
@ -1,5 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication.base import BaseAuthentication
|
||||
@ -10,62 +13,58 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
class JWTAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Authentication using a JWT.
|
||||
Authentication backend using a JWT in a Bearer header.
|
||||
|
||||
:param secret: Secret used to encode the JWT.
|
||||
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
|
||||
:param tokenUrl: Path where to get a token.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
"""
|
||||
|
||||
token_audience: str = "fastapi-users:auth"
|
||||
secret: str
|
||||
lifetime_seconds: int
|
||||
|
||||
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
lifetime_seconds: int,
|
||||
tokenUrl: str = "/users/login",
|
||||
name: str = "jwt",
|
||||
):
|
||||
super().__init__(name)
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.scheme = OAuth2PasswordBearer(tokenUrl, auto_error=False)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
data = {"user_id": user.id, "aud": self.token_audience}
|
||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
||||
async def __call__(
|
||||
self, request: Request, user_db: BaseUserDatabase,
|
||||
) -> Optional[BaseUserDB]:
|
||||
token = await self._retrieve_token(request)
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
return await user_db.get(user_id)
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||
token = await self._generate_token(user)
|
||||
return {"token": token}
|
||||
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_user(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_user_base(user)
|
||||
async def _retrieve_token(self, request: Request) -> Optional[str]:
|
||||
return await self.scheme.__call__(request)
|
||||
|
||||
return _get_current_user
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_active_user(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_active_user_base(user)
|
||||
|
||||
return _get_current_active_user
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_superuser(token: str = Depends(self.scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_superuser_base(user)
|
||||
|
||||
return _get_current_superuser
|
||||
|
||||
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
async def authentication_method(token: str = Depends(self.scheme)):
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
except jwt.PyJWTError:
|
||||
return None
|
||||
return await user_db.get(user_id)
|
||||
|
||||
return authentication_method
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
data = {"user_id": user.id, "aud": self.token_audience}
|
||||
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import Callable, Type
|
||||
from typing import Callable, Sequence, Type
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
from fastapi_users.models import BaseUser
|
||||
from fastapi_users.router import Event, UserRouter, get_user_router
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ class FastAPIUsers:
|
||||
Main object that ties together the component for users authentication.
|
||||
|
||||
:param db: Database adapter instance.
|
||||
:param auth: Authentication logic instance.
|
||||
:param auth_backends: List of authentication backends.
|
||||
:param user_model: Pydantic model of a user.
|
||||
:param reset_password_token_secret: Secret to encode reset password token.
|
||||
:param reset_password_token_lifetime_seconds: Lifetime of reset password token.
|
||||
@ -21,36 +21,30 @@ class FastAPIUsers:
|
||||
"""
|
||||
|
||||
db: BaseUserDatabase
|
||||
auth: BaseAuthentication
|
||||
authenticator: Authenticator
|
||||
router: UserRouter
|
||||
get_current_user: Callable[..., BaseUserDB]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseUserDatabase,
|
||||
auth: BaseAuthentication,
|
||||
auth_backends: Sequence[BaseAuthentication],
|
||||
user_model: Type[BaseUser],
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
):
|
||||
self.db = db
|
||||
self.auth = auth
|
||||
self.authenticator = Authenticator(auth_backends, db)
|
||||
self.router = get_user_router(
|
||||
self.db,
|
||||
user_model,
|
||||
self.auth,
|
||||
self.authenticator,
|
||||
reset_password_token_secret,
|
||||
reset_password_token_lifetime_seconds,
|
||||
)
|
||||
|
||||
get_current_user = self.auth.get_current_user(self.db)
|
||||
self.get_current_user = get_current_user # type: ignore
|
||||
|
||||
get_current_active_user = self.auth.get_current_active_user(self.db)
|
||||
self.get_current_active_user = get_current_active_user # type: ignore
|
||||
|
||||
get_current_superuser = self.auth.get_current_superuser(self.db)
|
||||
self.get_current_superuser = get_current_superuser # type: ignore
|
||||
self.get_current_user = self.authenticator.get_current_user
|
||||
self.get_current_active_user = self.authenticator.get_current_active_user
|
||||
self.get_current_superuser = self.authenticator.get_current_superuser
|
||||
|
||||
def on_after_register(self) -> Callable:
|
||||
"""Add an event handler on successful registration."""
|
||||
|
||||
@ -10,7 +10,7 @@ from pydantic import EmailStr
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUser, Models
|
||||
from fastapi_users.password import get_password_hash
|
||||
@ -46,10 +46,28 @@ class UserRouter(APIRouter):
|
||||
handler(*args, **kwargs)
|
||||
|
||||
|
||||
def _add_login_route(
|
||||
router: UserRouter, user_db: BaseUserDatabase, auth_backend: BaseAuthentication
|
||||
):
|
||||
@router.post(f"/login/{auth_backend.name}")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
return await auth_backend.get_login_response(user, response)
|
||||
|
||||
|
||||
def get_user_router(
|
||||
user_db: BaseUserDatabase,
|
||||
user_model: typing.Type[BaseUser],
|
||||
auth: BaseAuthentication,
|
||||
authenticator: Authenticator,
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
) -> UserRouter:
|
||||
@ -59,8 +77,8 @@ def get_user_router(
|
||||
|
||||
reset_password_token_audience = "fastapi-users:reset"
|
||||
|
||||
get_current_active_user = auth.get_current_active_user(user_db)
|
||||
get_current_superuser = auth.get_current_superuser(user_db)
|
||||
get_current_active_user = authenticator.get_current_active_user
|
||||
get_current_superuser = authenticator.get_current_superuser
|
||||
|
||||
async def _get_or_404(id: str) -> models.UserDB: # type: ignore
|
||||
user = await user_db.get(id)
|
||||
@ -79,6 +97,9 @@ def get_user_router(
|
||||
setattr(user, field, update_dict[field])
|
||||
return await user_db.update(user)
|
||||
|
||||
for auth_backend in authenticator.backends:
|
||||
_add_login_route(router, user_db, auth_backend)
|
||||
|
||||
@router.post(
|
||||
"/register", response_model=models.User, status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
@ -101,20 +122,6 @@ def get_user_router(
|
||||
|
||||
return created_user
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
return await auth.get_login_response(user, response)
|
||||
|
||||
@router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def forgot_password(email: EmailStr = Body(..., embed=True)):
|
||||
user = await user_db.get_by_email(email)
|
||||
|
||||
Reference in New Issue
Block a user