diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 768886cc..6fc3f4f0 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -1,5 +1,3 @@ -from datetime import datetime, timedelta - import jwt from fastapi import Depends from fastapi.security import OAuth2PasswordBearer @@ -8,17 +6,11 @@ from starlette.responses import Response from fastapi_users.authentication.base import BaseAuthentication from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import BaseUserDB +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") -def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str: - payload = data.copy() - expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) - payload["exp"] = expire - return jwt.encode(payload, secret, algorithm=algorithm).decode("utf-8") - - class JWTAuthentication(BaseAuthentication): """ Authentication using a JWT. @@ -27,7 +19,7 @@ class JWTAuthentication(BaseAuthentication): :param lifetime_seconds: Lifetime duration of the JWT in seconds. """ - algorithm: str = "HS256" + token_audience: str = "fastapi-users:auth" secret: str lifetime_seconds: int @@ -36,8 +28,8 @@ class JWTAuthentication(BaseAuthentication): self.lifetime_seconds = lifetime_seconds async def get_login_response(self, user: BaseUserDB, response: Response): - data = {"user_id": user.id} - token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm) + data = {"user_id": user.id, "aud": self.token_audience} + token = generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) return {"token": token} @@ -65,7 +57,12 @@ class JWTAuthentication(BaseAuthentication): def _get_authentication_method(self, user_db: BaseUserDatabase): async def authentication_method(token: str = Depends(oauth2_scheme)): try: - data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) + data = jwt.decode( + token, + self.secret, + audience=self.token_audience, + algorithms=[JWT_ALGORITHM], + ) user_id = data.get("user_id") if user_id is None: return None diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index dc57e238..c9b46896 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -1,6 +1,6 @@ """Ready-to-use and customizable users management for FastAPI.""" -from typing import Callable, Type +from typing import Any, Callable, Type from fastapi import APIRouter @@ -17,6 +17,9 @@ class FastAPIUsers: :param db: Database adapter instance. :param auth: Authentication logic instance. :param user_model: Pydantic model of a user. + :param on_after_forgot_password: Hook called after a forgot password request. + :param reset_password_token_secret: Secret to encode reset password token. + :param reset_password_token_lifetime_seconds: Lifetime of reset password token. :attribute router: FastAPI router exposing authentication routes. :attribute get_current_user: Dependency callable to inject authenticated user. @@ -28,11 +31,24 @@ class FastAPIUsers: get_current_user: Callable[..., BaseUserDB] def __init__( - self, db: BaseUserDatabase, auth: BaseAuthentication, user_model: Type[BaseUser] + self, + db: BaseUserDatabase, + auth: BaseAuthentication, + user_model: Type[BaseUser], + on_after_forgot_password: Callable[[BaseUserDB, str], Any], + reset_password_token_secret: str, + reset_password_token_lifetime_seconds: int = 3600, ): self.db = db self.auth = auth - self.router = get_user_router(self.db, user_model, self.auth) + self.router = get_user_router( + self.db, + user_model, + self.auth, + on_after_forgot_password, + reset_password_token_secret, + reset_password_token_lifetime_seconds, + ) get_current_user = self.auth.get_current_user(self.db) self.get_current_user = get_current_user # type: ignore diff --git a/fastapi_users/router.py b/fastapi_users/router.py index b0db5a2f..515472c0 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -1,23 +1,37 @@ -from typing import Type +import inspect +from typing import Any, Callable, Type -from fastapi import APIRouter, Depends, HTTPException +import jwt +from fastapi import APIRouter, Body, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm +from pydantic.types import EmailStr from starlette import status from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import BaseUser, Models +from fastapi_users.models import BaseUser, BaseUserDB, Models from fastapi_users.password import get_password_hash +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt def get_user_router( - user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication + user_db: BaseUserDatabase, + user_model: Type[BaseUser], + auth: BaseAuthentication, + on_after_forgot_password: Callable[[BaseUserDB, str], Any], + reset_password_token_secret: str, + reset_password_token_lifetime_seconds: int = 3600, ) -> APIRouter: """Generate a router with the authentication routes.""" router = APIRouter() models = Models(user_model) + reset_password_token_audience = "fastapi-users:reset" + is_on_after_forgot_password_async = inspect.iscoroutinefunction( + on_after_forgot_password + ) + @router.post("/register", response_model=models.User) async def register(user: models.UserCreate): # type: ignore hashed_password = get_password_hash(user.password) @@ -38,4 +52,45 @@ def get_user_router( return await auth.get_login_response(user, response) + @router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED) + async def forgot_password(email: EmailStr = Body(..., embed=True)): + user = await user_db.get_by_email(email) + + if user is not None and user.is_active: + token_data = {"user_id": user.id, "aud": reset_password_token_audience} + token = generate_jwt( + token_data, + reset_password_token_lifetime_seconds, + reset_password_token_secret, + ) + if is_on_after_forgot_password_async: + await on_after_forgot_password(user, token) + else: + on_after_forgot_password(user, token) + + return None + + @router.post("/reset-password") + async def reset_password(token: str = Body(...), password: str = Body(...)): + try: + data = jwt.decode( + token, + reset_password_token_secret, + audience=reset_password_token_audience, + algorithms=[JWT_ALGORITHM], + ) + user_id = data.get("user_id") + if user_id is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + user = await user_db.get(user_id) + if user is None or not user.is_active: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + updated_user = BaseUserDB(**user.dict()) + updated_user.hashed_password = get_password_hash(password) + await user_db.update(updated_user) + except jwt.PyJWTError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + return router diff --git a/fastapi_users/utils.py b/fastapi_users/utils.py new file mode 100644 index 00000000..7a61faee --- /dev/null +++ b/fastapi_users/utils.py @@ -0,0 +1,14 @@ +from datetime import datetime, timedelta + +import jwt + +JWT_ALGORITHM = "HS256" + + +def generate_jwt( + data: dict, lifetime_seconds: int, secret: str, algorithm: str = JWT_ALGORITHM +) -> str: + payload = data.copy() + expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) + payload["exp"] = expire + return jwt.encode(payload, secret, algorithm=algorithm).decode("utf-8") diff --git a/tests/conftest.py b/tests/conftest.py index 60d89086..9ff2a917 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,6 +68,9 @@ class MockUserDatabase(BaseUserDatabase): async def create(self, user: BaseUserDB) -> BaseUserDB: return user + async def update(self, user: BaseUserDB) -> BaseUserDB: + return user + @pytest.fixture def mock_user_db() -> MockUserDatabase: diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index f2c73994..6a5511f2 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -5,11 +5,11 @@ from starlette import status from starlette.responses import Response from starlette.testclient import TestClient -from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt +from fastapi_users.authentication.jwt import JWTAuthentication from fastapi_users.models import BaseUserDB +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt SECRET = "SECRET" -ALGORITHM = "HS256" LIFETIME = 3600 @@ -21,8 +21,8 @@ def jwt_authentication(): @pytest.fixture def token(): def _token(user, lifetime=LIFETIME): - data = {"user_id": user.id} - return generate_jwt(data, lifetime, SECRET, ALGORITHM) + data = {"user_id": user.id, "aud": "fastapi-users:auth"} + return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _token @@ -47,7 +47,9 @@ async def test_get_login_response(jwt_authentication, user): assert "token" in login_response token = login_response["token"] - decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM]) + decoded = jwt.decode( + token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM] + ) assert decoded["user_id"] == user.id diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index d6f36bf7..47a4d589 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -73,10 +73,7 @@ async def test_queries(sqlalchemy_user_db): # Exception when inserting non-nullable fields with pytest.raises(sqlite3.IntegrityError): - wrong_user = BaseUserDB( - id="222", - hashed_password="aaa" - ) + wrong_user = BaseUserDB(id="222", hashed_password="aaa") await sqlalchemy_user_db.create(wrong_user) # Unknown user diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index a7fa1442..1e06d827 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -6,13 +6,20 @@ from starlette.testclient import TestClient from fastapi_users import FastAPIUsers from fastapi_users.models import BaseUser, BaseUserDB +SECRET = "SECRET" + @pytest.fixture def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers: class User(BaseUser): pass - return FastAPIUsers(mock_user_db, mock_authentication, User) + def on_after_forgot_password(user, token): + pass + + return FastAPIUsers( + mock_user_db, mock_authentication, User, on_after_forgot_password, SECRET + ) @pytest.fixture @@ -43,6 +50,12 @@ class TestRouter: response = test_app_client.post("/login") assert response.status_code != status.HTTP_404_NOT_FOUND + response = test_app_client.post("/forgot-password") + assert response.status_code != status.HTTP_404_NOT_FOUND + + response = test_app_client.post("/reset-password") + assert response.status_code != status.HTTP_404_NOT_FOUND + class TestGetCurrentUser: def test_missing_token(self, test_app_client: TestClient): diff --git a/tests/test_router.py b/tests/test_router.py index 997d49dd..4cf42605 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,3 +1,7 @@ +import asyncio +from unittest.mock import MagicMock + +import jwt import pytest from fastapi import FastAPI from starlette import status @@ -5,19 +9,65 @@ from starlette.testclient import TestClient from fastapi_users.models import BaseUser, BaseUserDB from fastapi_users.router import get_user_router +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt + +SECRET = "SECRET" +LIFETIME = 3600 @pytest.fixture -def test_app_client(mock_user_db, mock_authentication) -> TestClient: - class User(BaseUser): - pass +def forgot_password_token(): + def _forgot_password_token(user_id, lifetime=LIFETIME): + data = {"user_id": user_id, "aud": "fastapi-users:reset"} + return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) - userRouter = get_user_router(mock_user_db, User, mock_authentication) + return _forgot_password_token - app = FastAPI() - app.include_router(userRouter) - return TestClient(app) +@pytest.fixture() +def on_after_forgot_password_sync(): + on_after_forgot_password_mock = MagicMock(return_value=None) + return on_after_forgot_password_mock + + +@pytest.fixture() +def on_after_forgot_password_async(): + on_after_forgot_password_mock = MagicMock(return_value=asyncio.Future()) + on_after_forgot_password_mock.return_value.set_result(None) + return on_after_forgot_password_mock + + +@pytest.fixture +def get_test_app_client(mock_user_db, mock_authentication): + def _get_test_app_client(on_after_forgot_password) -> TestClient: + class User(BaseUser): + pass + + userRouter = get_user_router( + mock_user_db, + User, + mock_authentication, + on_after_forgot_password, + SECRET, + LIFETIME, + ) + + app = FastAPI() + app.include_router(userRouter) + + return TestClient(app) + + return _get_test_app_client + + +@pytest.fixture +def test_app_client(get_test_app_client, on_after_forgot_password_sync): + return get_test_app_client(on_after_forgot_password_sync) + + +@pytest.fixture +def test_app_client_async(get_test_app_client, on_after_forgot_password_async): + return get_test_app_client(on_after_forgot_password_async) class TestRegister: @@ -81,3 +131,124 @@ class TestLogin: data = {"username": "percival@camelot.bt", "password": "angharad"} response = test_app_client.post("/login", data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST + + +class TestForgotPassword: + def test_empty_body( + self, test_app_client: TestClient, on_after_forgot_password_sync + ): + response = test_app_client.post("/forgot-password", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert on_after_forgot_password_sync.called is False + + def test_not_existing_user( + self, test_app_client: TestClient, on_after_forgot_password_sync + ): + json = {"email": "lancelot@camelot.bt"} + response = test_app_client.post("/forgot-password", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert on_after_forgot_password_sync.called is False + + def test_inactive_user( + self, test_app_client: TestClient, on_after_forgot_password_sync + ): + json = {"email": "percival@camelot.bt"} + response = test_app_client.post("/forgot-password", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert on_after_forgot_password_sync.called is False + + def test_existing_user_sync_hook( + self, test_app_client: TestClient, on_after_forgot_password_sync, user + ): + json = {"email": "king.arthur@camelot.bt"} + response = test_app_client.post("/forgot-password", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert on_after_forgot_password_sync.called is True + + actual_user = on_after_forgot_password_sync.call_args[0][0] + assert actual_user.id == user.id + actual_token = on_after_forgot_password_sync.call_args[0][1] + decoded_token = jwt.decode( + actual_token, + SECRET, + audience="fastapi-users:reset", + algorithms=[JWT_ALGORITHM], + ) + assert decoded_token["user_id"] == user.id + + def test_existing_user_async_hook( + self, test_app_client_async: TestClient, on_after_forgot_password_async, user + ): + json = {"email": "king.arthur@camelot.bt"} + response = test_app_client_async.post("/forgot-password", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert on_after_forgot_password_async.called is True + + actual_user = on_after_forgot_password_async.call_args[0][0] + assert actual_user.id == user.id + actual_token = on_after_forgot_password_async.call_args[0][1] + decoded_token = jwt.decode( + actual_token, + SECRET, + audience="fastapi-users:reset", + algorithms=[JWT_ALGORITHM], + ) + assert decoded_token["user_id"] == user.id + + +class TestResetPassword: + def test_empty_body(self, test_app_client: TestClient): + response = test_app_client.post("/reset-password", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_missing_token(self, test_app_client: TestClient): + json = {"password": "guinevere"} + response = test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_missing_password(self, test_app_client: TestClient): + json = {"token": "foo"} + response = test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_invalid_token(self, test_app_client: TestClient): + json = {"token": "foo", "password": "guinevere"} + response = test_app_client.post("/reset-password", json=json) + print(response.json(), response.status_code) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_inactive_user( + self, + mocker, + mock_user_db, + test_app_client: TestClient, + forgot_password_token, + inactive_user: BaseUserDB, + ): + mocker.spy(mock_user_db, "update") + + json = { + "token": forgot_password_token(inactive_user.id), + "password": "holygrail", + } + response = test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert mock_user_db.update.called is False + + def test_existing_user( + self, + mocker, + mock_user_db, + test_app_client: TestClient, + forgot_password_token, + user: BaseUserDB, + ): + mocker.spy(mock_user_db, "update") + + json = {"token": forgot_password_token(user.id), "password": "holygrail"} + response = test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_200_OK + assert mock_user_db.update.called is True + + updated_user = mock_user_db.update.call_args[0][0] + assert updated_user.hashed_password != user.hashed_password