#5 Improve test coverage (#6)

* Improve test coverage of BaseUserDatabase

* Improve unit test isolation

* Improve coverage of router and authentication
This commit is contained in:
François Voron
2019-10-15 07:54:53 +02:00
committed by GitHub
parent 66ef56758a
commit f2892aa378
12 changed files with 335 additions and 312 deletions

View File

@ -24,5 +24,5 @@ jobs:
env: env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: | run: |
pipenv run pytest --cov=./ pipenv run pytest --cov=fastapi_users/
pipenv run codecov pipenv run codecov

View File

@ -5,7 +5,7 @@ format:
$(PIPENV_RUN) black . $(PIPENV_RUN) black .
test: test:
$(PIPENV_RUN) pytest $(PIPENV_RUN) pytest --cov=fastapi_users/
docs-serve: docs-serve:
$(PIPENV_RUN) mkdocs serve $(PIPENV_RUN) mkdocs serve

View File

@ -18,6 +18,7 @@ mypy = "*"
codecov = "*" codecov = "*"
pytest-cov = "*" pytest-cov = "*"
pytest-mock = "*" pytest-mock = "*"
asynctest = "*"
[packages] [packages]
fastapi = "*" fastapi = "*"

10
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "a45946502cd269a192b56dc49c22a1e2d2c95002de21c6f56da8599d4082014c" "sha256": "4d1dac037a2a0eae31b75c787118ecef1fc552f381d7f6a112382de8003b2008"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -184,6 +184,14 @@
], ],
"version": "==1.4.3" "version": "==1.4.3"
}, },
"asynctest": {
"hashes": [
"sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676",
"sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"
],
"index": "pypi",
"version": "==0.13.0"
},
"atomicwrites": { "atomicwrites": {
"hashes": [ "hashes": [
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",

View File

@ -2,8 +2,8 @@ from typing import List, Optional
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import password
from fastapi_users.models import BaseUserDB from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash, verify_and_update_password
class BaseUserDatabase: class BaseUserDatabase:
@ -41,12 +41,12 @@ class BaseUserDatabase:
# Always run the hasher to mitigate timing attack # Always run the hasher to mitigate timing attack
# Inspired from Django: https://code.djangoproject.com/ticket/20760 # Inspired from Django: https://code.djangoproject.com/ticket/20760
get_password_hash(credentials.password) password.get_password_hash(credentials.password)
if user is None: if user is None:
return None return None
else: else:
verified, updated_password_hash = verify_and_update_password( verified, updated_password_hash = password.verify_and_update_password(
credentials.password, user.hashed_password credentials.password, user.hashed_password
) )
if not verified: if not verified:

View File

@ -1,4 +1,4 @@
import inspect import asyncio
from typing import Any, Callable, Type from typing import Any, Callable, Type
import jwt import jwt
@ -28,7 +28,7 @@ def get_user_router(
models = Models(user_model) models = Models(user_model)
reset_password_token_audience = "fastapi-users:reset" reset_password_token_audience = "fastapi-users:reset"
is_on_after_forgot_password_async = inspect.iscoroutinefunction( is_on_after_forgot_password_async = asyncio.iscoroutinefunction(
on_after_forgot_password on_after_forgot_password
) )
@ -87,9 +87,8 @@ def get_user_router(
if user is None or not user.is_active: if user is None or not user.is_active:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
updated_user = BaseUserDB(**user.dict()) user.hashed_password = get_password_hash(password)
updated_user.hashed_password = get_password_hash(password) await user_db.update(user)
await user_db.update(updated_user)
except jwt.PyJWTError: except jwt.PyJWTError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

View File

@ -1,116 +1,147 @@
from typing import Optional from typing import Optional
import pytest import pytest
from fastapi import Depends from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
active_user_data = BaseUserDB( guinevere_password_hash = get_password_hash("guinevere")
id="aaa", angharad_password_hash = get_password_hash("angharad")
email="king.arthur@camelot.bt", viviane_password_hash = get_password_hash("viviane")
hashed_password=get_password_hash("guinevere"),
)
inactive_user_data = BaseUserDB(
id="bbb",
email="percival@camelot.bt",
hashed_password=get_password_hash("angharad"),
is_active=False,
)
superuser_data = BaseUserDB(
id="ccc",
email="merlin@camelot.bt",
hashed_password=get_password_hash("viviane"),
is_superuser=True,
)
@pytest.fixture @pytest.fixture
def user() -> BaseUserDB: def user() -> BaseUserDB:
return active_user_data return BaseUserDB(
id="aaa",
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
)
@pytest.fixture @pytest.fixture
def inactive_user() -> BaseUserDB: def inactive_user() -> BaseUserDB:
return inactive_user_data return BaseUserDB(
id="bbb",
email="percival@camelot.bt",
hashed_password=angharad_password_hash,
is_active=False,
)
@pytest.fixture @pytest.fixture
def superuser() -> BaseUserDB: def superuser() -> BaseUserDB:
return superuser_data return BaseUserDB(
id="ccc",
email="merlin@camelot.bt",
class MockUserDatabase(BaseUserDatabase): hashed_password=viviane_password_hash,
async def get(self, id: str) -> Optional[BaseUserDB]: is_superuser=True,
if id == active_user_data.id: )
return active_user_data
elif id == inactive_user_data.id:
return inactive_user_data
elif id == superuser_data.id:
return superuser_data
return None
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
if email == active_user_data.email:
return active_user_data
elif email == inactive_user_data.email:
return inactive_user_data
elif email == superuser_data.email:
return superuser_data
return None
async def create(self, user: BaseUserDB) -> BaseUserDB:
return user
async def update(self, user: BaseUserDB) -> BaseUserDB:
return user
@pytest.fixture @pytest.fixture
def mock_user_db() -> MockUserDatabase: def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> Optional[BaseUserDB]:
if id == user.id:
return user
elif id == inactive_user.id:
return inactive_user
elif id == superuser.id:
return superuser
return None
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
if email == user.email:
return user
elif email == inactive_user.email:
return inactive_user
elif email == superuser.email:
return superuser
return None
async def create(self, user: BaseUserDB) -> BaseUserDB:
return user
async def update(self, user: BaseUserDB) -> BaseUserDB:
return user
return MockUserDatabase() return MockUserDatabase()
class MockAuthentication(BaseAuthentication): @pytest.fixture
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") def mock_authentication():
class MockAuthentication(BaseAuthentication):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
async def get_login_response(self, user: BaseUserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id} return {"token": user.id}
def get_current_user(self, user_db: BaseUserDatabase): def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.oauth2_scheme)): async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token) user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user) return self._get_current_user_base(user)
return _get_current_user return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase): def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)): async def _get_current_active_user(
user = await self._get_authentication_method(user_db)(token) token: str = Depends(self.oauth2_scheme)
return self._get_current_active_user_base(user) ):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase): def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)): async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token) user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user) return self._get_current_superuser_base(user)
return _get_current_superuser return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase): def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.oauth2_scheme)): async def authentication_method(token: str = Depends(self.oauth2_scheme)):
return await user_db.get(token) return await user_db.get(token)
return authentication_method return authentication_method
return MockAuthentication()
@pytest.fixture @pytest.fixture
def mock_authentication() -> MockAuthentication: def get_test_auth_client(mock_user_db):
return MockAuthentication() def _get_test_auth_client(authentication):
app = FastAPI()
@app.get("/test-current-user")
def test_current_user(
user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db))
):
return user
@app.get("/test-current-active-user")
def test_current_active_user(
user: BaseUserDB = Depends(
authentication.get_current_active_user(mock_user_db)
)
):
return user
@app.get("/test-current-superuser")
def test_current_superuser(
user: BaseUserDB = Depends(
authentication.get_current_superuser(mock_user_db)
)
):
return user
return TestClient(app)
return _get_test_auth_client

View File

@ -1,130 +1,25 @@
import pytest import pytest
from fastapi import Depends, FastAPI from starlette.responses import Response
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.models import BaseUserDB from fastapi_users.authentication import BaseAuthentication
@pytest.fixture @pytest.mark.asyncio
def test_auth_client(mock_authentication, mock_user_db): async def test_not_implemented_methods(user, mock_user_db):
app = FastAPI() response = Response()
base_authentication = BaseAuthentication()
@app.get("/test-current-user") with pytest.raises(NotImplementedError):
def test_current_user( await base_authentication.get_login_response(user, response)
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
):
return user
@app.get("/test-current-active-user") with pytest.raises(NotImplementedError):
def test_current_active_user( await base_authentication.get_current_user(mock_user_db)
user: BaseUserDB = Depends(
mock_authentication.get_current_active_user(mock_user_db)
)
):
return user
@app.get("/test-current-superuser") with pytest.raises(NotImplementedError):
def test_current_superuser( await base_authentication.get_current_active_user(mock_user_db)
user: BaseUserDB = Depends(
mock_authentication.get_current_superuser(mock_user_db)
)
):
return user
return TestClient(app) with pytest.raises(NotImplementedError):
await base_authentication.get_current_superuser(mock_user_db)
with pytest.raises(NotImplementedError):
class TestGetCurrentUser: await base_authentication._get_authentication_method(mock_user_db)
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-user")
print(response.json())
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -1,12 +1,9 @@
import jwt import jwt
import pytest import pytest
from fastapi import Depends, FastAPI
from starlette import status from starlette import status
from starlette.responses import Response from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication.jwt import JWTAuthentication from fastapi_users.authentication.jwt import JWTAuthentication
from fastapi_users.models import BaseUserDB
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
SECRET = "SECRET" SECRET = "SECRET"
@ -20,24 +17,18 @@ def jwt_authentication():
@pytest.fixture @pytest.fixture
def token(): def token():
def _token(user, lifetime=LIFETIME): def _token(user=None, lifetime=LIFETIME):
data = {"user_id": user.id, "aud": "fastapi-users:auth"} data = {"aud": "fastapi-users:auth"}
if user is not None:
data["user_id"] = user.id
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _token return _token
@pytest.fixture @pytest.fixture
def test_auth_client(jwt_authentication, mock_user_db): def test_auth_client(get_test_auth_client, jwt_authentication):
app = FastAPI() return get_test_auth_client(jwt_authentication)
@app.get("/test-auth")
def test_auth(
user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
):
return user
return TestClient(app)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -55,18 +46,26 @@ async def test_get_login_response(jwt_authentication, user):
class TestGetCurrentUser: class TestGetCurrentUser:
def test_missing_token(self, test_auth_client): def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-auth") response = test_auth_client.get("/test-current-user")
print(response.json())
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client): def test_invalid_token(self, test_auth_client):
response = test_auth_client.get( response = test_auth_client.get(
"/test-auth", headers={"Authorization": "Bearer foo"} "/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_missing_user_payload(self, test_auth_client, token):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {token()}"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get( response = test_auth_client.get(
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"} "/test-current-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -75,9 +74,74 @@ class TestGetCurrentUser:
def test_valid_token(self, test_auth_client, token, user): def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get( response = test_auth_client.get(
"/test-auth", headers={"Authorization": f"Bearer {token(user)}"} "/test-current-user", headers={"Authorization": f"Bearer {token(user)}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
response_json = response.json() response_json = response.json()
assert response_json["id"] == user.id assert response_json["id"] == user.id
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, token, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(superuser)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -1,6 +1,8 @@
import pytest import pytest
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.db import BaseUserDatabase
@pytest.fixture @pytest.fixture
def create_oauth2_password_request_form(): def create_oauth2_password_request_form():
@ -10,6 +12,26 @@ def create_oauth2_password_request_form():
return _create_oauth2_password_request_form return _create_oauth2_password_request_form
@pytest.mark.asyncio
async def test_not_implemented_methods(user):
base_user_db = BaseUserDatabase()
with pytest.raises(NotImplementedError):
await base_user_db.list()
with pytest.raises(NotImplementedError):
await base_user_db.get("aaa")
with pytest.raises(NotImplementedError):
await base_user_db.get_by_email("lancelot@camelot.bt")
with pytest.raises(NotImplementedError):
await base_user_db.create(user)
with pytest.raises(NotImplementedError):
await base_user_db.update(user)
class TestAuthenticate: class TestAuthenticate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unknown_user( async def test_unknown_user(
@ -37,3 +59,21 @@ class TestAuthenticate:
user = await mock_user_db.authenticate(form) user = await mock_user_db.authenticate(form)
assert user is not None assert user is not None
assert user.email == "king.arthur@camelot.bt" assert user.email == "king.arthur@camelot.bt"
@pytest.mark.asyncio
async def test_upgrade_password_hash(
self, mocker, create_oauth2_password_request_form, mock_user_db
):
verify_and_update_password_patch = mocker.patch(
"fastapi_users.password.verify_and_update_password"
)
verify_and_update_password_patch.return_value = (True, "updated_hash")
mocker.spy(mock_user_db, "update")
form = create_oauth2_password_request_form(
"king.arthur@camelot.bt", "guinevere"
)
user = await mock_user_db.authenticate(form)
assert user is not None
assert user.email == "king.arthur@camelot.bt"
assert mock_user_db.update.called is True

View File

@ -9,21 +9,22 @@ from fastapi_users.models import BaseUser, BaseUserDB
SECRET = "SECRET" SECRET = "SECRET"
@pytest.fixture def sync_on_after_forgot_password():
def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers: return None
async def async_on_after_forgot_password():
return None
@pytest.fixture(params=[sync_on_after_forgot_password, async_on_after_forgot_password])
def test_app_client(request, mock_user_db, mock_authentication) -> TestClient:
class User(BaseUser): class User(BaseUser):
pass pass
def on_after_forgot_password(user, token): fastapi_users = FastAPIUsers(
pass mock_user_db, mock_authentication, User, request.param, SECRET
return FastAPIUsers(
mock_user_db, mock_authentication, User, on_after_forgot_password, SECRET
) )
@pytest.fixture
def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
app = FastAPI() app = FastAPI()
app.include_router(fastapi_users.router) app.include_router(fastapi_users.router)

View File

@ -1,6 +1,6 @@
import asyncio
from unittest.mock import MagicMock from unittest.mock import MagicMock
import asynctest
import jwt import jwt
import pytest import pytest
from fastapi import FastAPI from fastapi import FastAPI
@ -17,57 +17,48 @@ LIFETIME = 3600
@pytest.fixture @pytest.fixture
def forgot_password_token(): def forgot_password_token():
def _forgot_password_token(user_id, lifetime=LIFETIME): def _forgot_password_token(user_id=None, lifetime=LIFETIME):
data = {"user_id": user_id, "aud": "fastapi-users:reset"} data = {"aud": "fastapi-users:reset"}
if user_id is not None:
data["user_id"] = user_id
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _forgot_password_token return _forgot_password_token
@pytest.fixture()
def on_after_forgot_password_sync(): def on_after_forgot_password_sync():
on_after_forgot_password_mock = MagicMock(return_value=None) return MagicMock(return_value=None)
return on_after_forgot_password_mock
def on_after_forgot_password_async():
return asynctest.CoroutineMock(return_value=None)
@pytest.fixture(params=[on_after_forgot_password_sync, on_after_forgot_password_async])
def on_after_forgot_password(request):
return request.param()
@pytest.fixture() @pytest.fixture()
def on_after_forgot_password_async(): def test_app_client(
on_after_forgot_password_mock = MagicMock(return_value=asyncio.Future()) mock_user_db, mock_authentication, on_after_forgot_password
on_after_forgot_password_mock.return_value.set_result(None) ) -> TestClient:
return on_after_forgot_password_mock class User(BaseUser):
pass
userRouter = get_user_router(
mock_user_db,
User,
mock_authentication,
on_after_forgot_password,
SECRET,
LIFETIME,
)
@pytest.fixture app = FastAPI()
def get_test_app_client(mock_user_db, mock_authentication): app.include_router(userRouter)
def _get_test_app_client(on_after_forgot_password) -> TestClient:
class User(BaseUser):
pass
userRouter = get_user_router( return TestClient(app)
mock_user_db,
User,
mock_authentication,
on_after_forgot_password,
SECRET,
LIFETIME,
)
app = FastAPI()
app.include_router(userRouter)
return TestClient(app)
return _get_test_app_client
@pytest.fixture
def test_app_client(get_test_app_client, on_after_forgot_password_sync):
return get_test_app_client(on_after_forgot_password_sync)
@pytest.fixture
def test_app_client_async(get_test_app_client, on_after_forgot_password_async):
return get_test_app_client(on_after_forgot_password_async)
class TestRegister: class TestRegister:
@ -134,59 +125,36 @@ class TestLogin:
class TestForgotPassword: class TestForgotPassword:
def test_empty_body( def test_empty_body(self, test_app_client: TestClient, on_after_forgot_password):
self, test_app_client: TestClient, on_after_forgot_password_sync
):
response = test_app_client.post("/forgot-password", json={}) response = test_app_client.post("/forgot-password", json={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
assert on_after_forgot_password_sync.called is False assert on_after_forgot_password.called is False
def test_not_existing_user( def test_not_existing_user(
self, test_app_client: TestClient, on_after_forgot_password_sync self, test_app_client: TestClient, on_after_forgot_password
): ):
json = {"email": "lancelot@camelot.bt"} json = {"email": "lancelot@camelot.bt"}
response = test_app_client.post("/forgot-password", json=json) response = test_app_client.post("/forgot-password", json=json)
assert response.status_code == status.HTTP_202_ACCEPTED assert response.status_code == status.HTTP_202_ACCEPTED
assert on_after_forgot_password_sync.called is False assert on_after_forgot_password.called is False
def test_inactive_user( def test_inactive_user(self, test_app_client: TestClient, on_after_forgot_password):
self, test_app_client: TestClient, on_after_forgot_password_sync
):
json = {"email": "percival@camelot.bt"} json = {"email": "percival@camelot.bt"}
response = test_app_client.post("/forgot-password", json=json) response = test_app_client.post("/forgot-password", json=json)
assert response.status_code == status.HTTP_202_ACCEPTED assert response.status_code == status.HTTP_202_ACCEPTED
assert on_after_forgot_password_sync.called is False assert on_after_forgot_password.called is False
def test_existing_user_sync_hook( def test_existing_user(
self, test_app_client: TestClient, on_after_forgot_password_sync, user self, test_app_client: TestClient, on_after_forgot_password, user
): ):
json = {"email": "king.arthur@camelot.bt"} json = {"email": "king.arthur@camelot.bt"}
response = test_app_client.post("/forgot-password", json=json) response = test_app_client.post("/forgot-password", json=json)
assert response.status_code == status.HTTP_202_ACCEPTED assert response.status_code == status.HTTP_202_ACCEPTED
assert on_after_forgot_password_sync.called is True assert on_after_forgot_password.called is True
actual_user = on_after_forgot_password_sync.call_args[0][0] actual_user = on_after_forgot_password.call_args[0][0]
assert actual_user.id == user.id assert actual_user.id == user.id
actual_token = on_after_forgot_password_sync.call_args[0][1] actual_token = on_after_forgot_password.call_args[0][1]
decoded_token = jwt.decode(
actual_token,
SECRET,
audience="fastapi-users:reset",
algorithms=[JWT_ALGORITHM],
)
assert decoded_token["user_id"] == user.id
def test_existing_user_async_hook(
self, test_app_client_async: TestClient, on_after_forgot_password_async, user
):
json = {"email": "king.arthur@camelot.bt"}
response = test_app_client_async.post("/forgot-password", json=json)
assert response.status_code == status.HTTP_202_ACCEPTED
assert on_after_forgot_password_async.called is True
actual_user = on_after_forgot_password_async.call_args[0][0]
assert actual_user.id == user.id
actual_token = on_after_forgot_password_async.call_args[0][1]
decoded_token = jwt.decode( decoded_token = jwt.decode(
actual_token, actual_token,
SECRET, SECRET,
@ -217,6 +185,21 @@ class TestResetPassword:
print(response.json(), response.status_code) print(response.json(), response.status_code)
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_valid_token_missing_user_id_payload(
self,
mocker,
mock_user_db,
test_app_client: TestClient,
forgot_password_token,
inactive_user: BaseUserDB,
):
mocker.spy(mock_user_db, "update")
json = {"token": forgot_password_token(), "password": "holygrail"}
response = test_app_client.post("/reset-password", json=json)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert mock_user_db.update.called is False
def test_inactive_user( def test_inactive_user(
self, self,
mocker, mocker,
@ -244,6 +227,7 @@ class TestResetPassword:
user: BaseUserDB, user: BaseUserDB,
): ):
mocker.spy(mock_user_db, "update") mocker.spy(mock_user_db, "update")
current_hashed_passord = user.hashed_password
json = {"token": forgot_password_token(user.id), "password": "holygrail"} json = {"token": forgot_password_token(user.id), "password": "holygrail"}
response = test_app_client.post("/reset-password", json=json) response = test_app_client.post("/reset-password", json=json)
@ -251,4 +235,4 @@ class TestResetPassword:
assert mock_user_db.update.called is True assert mock_user_db.update.called is True
updated_user = mock_user_db.update.call_args[0][0] updated_user = mock_user_db.update.call_args[0][0]
assert updated_user.hashed_password != user.hashed_password assert updated_user.hashed_password != current_hashed_passord