mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-16 03:40:23 +08:00
* Improve test coverage of BaseUserDatabase * Improve unit test isolation * Improve coverage of router and authentication
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -24,5 +24,5 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
run: |
|
run: |
|
||||||
pipenv run pytest --cov=./
|
pipenv run pytest --cov=fastapi_users/
|
||||||
pipenv run codecov
|
pipenv run codecov
|
||||||
|
2
Makefile
2
Makefile
@ -5,7 +5,7 @@ format:
|
|||||||
$(PIPENV_RUN) black .
|
$(PIPENV_RUN) black .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
$(PIPENV_RUN) pytest
|
$(PIPENV_RUN) pytest --cov=fastapi_users/
|
||||||
|
|
||||||
docs-serve:
|
docs-serve:
|
||||||
$(PIPENV_RUN) mkdocs serve
|
$(PIPENV_RUN) mkdocs serve
|
||||||
|
1
Pipfile
1
Pipfile
@ -18,6 +18,7 @@ mypy = "*"
|
|||||||
codecov = "*"
|
codecov = "*"
|
||||||
pytest-cov = "*"
|
pytest-cov = "*"
|
||||||
pytest-mock = "*"
|
pytest-mock = "*"
|
||||||
|
asynctest = "*"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
fastapi = "*"
|
fastapi = "*"
|
||||||
|
10
Pipfile.lock
generated
10
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "a45946502cd269a192b56dc49c22a1e2d2c95002de21c6f56da8599d4082014c"
|
"sha256": "4d1dac037a2a0eae31b75c787118ecef1fc552f381d7f6a112382de8003b2008"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
@ -184,6 +184,14 @@
|
|||||||
],
|
],
|
||||||
"version": "==1.4.3"
|
"version": "==1.4.3"
|
||||||
},
|
},
|
||||||
|
"asynctest": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676",
|
||||||
|
"sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"version": "==0.13.0"
|
||||||
|
},
|
||||||
"atomicwrites": {
|
"atomicwrites": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
||||||
|
@ -2,8 +2,8 @@ from typing import List, Optional
|
|||||||
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from fastapi_users import password
|
||||||
from fastapi_users.models import BaseUserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
from fastapi_users.password import get_password_hash, verify_and_update_password
|
|
||||||
|
|
||||||
|
|
||||||
class BaseUserDatabase:
|
class BaseUserDatabase:
|
||||||
@ -41,12 +41,12 @@ class BaseUserDatabase:
|
|||||||
|
|
||||||
# Always run the hasher to mitigate timing attack
|
# Always run the hasher to mitigate timing attack
|
||||||
# Inspired from Django: https://code.djangoproject.com/ticket/20760
|
# Inspired from Django: https://code.djangoproject.com/ticket/20760
|
||||||
get_password_hash(credentials.password)
|
password.get_password_hash(credentials.password)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
verified, updated_password_hash = verify_and_update_password(
|
verified, updated_password_hash = password.verify_and_update_password(
|
||||||
credentials.password, user.hashed_password
|
credentials.password, user.hashed_password
|
||||||
)
|
)
|
||||||
if not verified:
|
if not verified:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import inspect
|
import asyncio
|
||||||
from typing import Any, Callable, Type
|
from typing import Any, Callable, Type
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
@ -28,7 +28,7 @@ def get_user_router(
|
|||||||
models = Models(user_model)
|
models = Models(user_model)
|
||||||
|
|
||||||
reset_password_token_audience = "fastapi-users:reset"
|
reset_password_token_audience = "fastapi-users:reset"
|
||||||
is_on_after_forgot_password_async = inspect.iscoroutinefunction(
|
is_on_after_forgot_password_async = asyncio.iscoroutinefunction(
|
||||||
on_after_forgot_password
|
on_after_forgot_password
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -87,9 +87,8 @@ def get_user_router(
|
|||||||
if user is None or not user.is_active:
|
if user is None or not user.is_active:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
updated_user = BaseUserDB(**user.dict())
|
user.hashed_password = get_password_hash(password)
|
||||||
updated_user.hashed_password = get_password_hash(password)
|
await user_db.update(user)
|
||||||
await user_db.update(updated_user)
|
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
@ -1,116 +1,147 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends
|
from fastapi import Depends, FastAPI
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import BaseUserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
from fastapi_users.password import get_password_hash
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
active_user_data = BaseUserDB(
|
guinevere_password_hash = get_password_hash("guinevere")
|
||||||
id="aaa",
|
angharad_password_hash = get_password_hash("angharad")
|
||||||
email="king.arthur@camelot.bt",
|
viviane_password_hash = get_password_hash("viviane")
|
||||||
hashed_password=get_password_hash("guinevere"),
|
|
||||||
)
|
|
||||||
|
|
||||||
inactive_user_data = BaseUserDB(
|
|
||||||
id="bbb",
|
|
||||||
email="percival@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("angharad"),
|
|
||||||
is_active=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
superuser_data = BaseUserDB(
|
|
||||||
id="ccc",
|
|
||||||
email="merlin@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("viviane"),
|
|
||||||
is_superuser=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user() -> BaseUserDB:
|
def user() -> BaseUserDB:
|
||||||
return active_user_data
|
return BaseUserDB(
|
||||||
|
id="aaa",
|
||||||
|
email="king.arthur@camelot.bt",
|
||||||
|
hashed_password=guinevere_password_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def inactive_user() -> BaseUserDB:
|
def inactive_user() -> BaseUserDB:
|
||||||
return inactive_user_data
|
return BaseUserDB(
|
||||||
|
id="bbb",
|
||||||
|
email="percival@camelot.bt",
|
||||||
|
hashed_password=angharad_password_hash,
|
||||||
|
is_active=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def superuser() -> BaseUserDB:
|
def superuser() -> BaseUserDB:
|
||||||
return superuser_data
|
return BaseUserDB(
|
||||||
|
id="ccc",
|
||||||
|
email="merlin@camelot.bt",
|
||||||
class MockUserDatabase(BaseUserDatabase):
|
hashed_password=viviane_password_hash,
|
||||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
is_superuser=True,
|
||||||
if id == active_user_data.id:
|
)
|
||||||
return active_user_data
|
|
||||||
elif id == inactive_user_data.id:
|
|
||||||
return inactive_user_data
|
|
||||||
elif id == superuser_data.id:
|
|
||||||
return superuser_data
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
|
||||||
if email == active_user_data.email:
|
|
||||||
return active_user_data
|
|
||||||
elif email == inactive_user_data.email:
|
|
||||||
return inactive_user_data
|
|
||||||
elif email == superuser_data.email:
|
|
||||||
return superuser_data
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_user_db() -> MockUserDatabase:
|
def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
|
||||||
|
class MockUserDatabase(BaseUserDatabase):
|
||||||
|
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||||
|
if id == user.id:
|
||||||
|
return user
|
||||||
|
elif id == inactive_user.id:
|
||||||
|
return inactive_user
|
||||||
|
elif id == superuser.id:
|
||||||
|
return superuser
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||||
|
if email == user.email:
|
||||||
|
return user
|
||||||
|
elif email == inactive_user.email:
|
||||||
|
return inactive_user
|
||||||
|
elif email == superuser.email:
|
||||||
|
return superuser
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
|
return user
|
||||||
|
|
||||||
return MockUserDatabase()
|
return MockUserDatabase()
|
||||||
|
|
||||||
|
|
||||||
class MockAuthentication(BaseAuthentication):
|
@pytest.fixture
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
def mock_authentication():
|
||||||
|
class MockAuthentication(BaseAuthentication):
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||||
|
|
||||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
return {"token": user.id}
|
return {"token": user.id}
|
||||||
|
|
||||||
def get_current_user(self, user_db: BaseUserDatabase):
|
def get_current_user(self, user_db: BaseUserDatabase):
|
||||||
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
|
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
|
||||||
user = await self._get_authentication_method(user_db)(token)
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
return self._get_current_user_base(user)
|
return self._get_current_user_base(user)
|
||||||
|
|
||||||
return _get_current_user
|
return _get_current_user
|
||||||
|
|
||||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||||
async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)):
|
async def _get_current_active_user(
|
||||||
user = await self._get_authentication_method(user_db)(token)
|
token: str = Depends(self.oauth2_scheme)
|
||||||
return self._get_current_active_user_base(user)
|
):
|
||||||
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
|
return self._get_current_active_user_base(user)
|
||||||
|
|
||||||
return _get_current_active_user
|
return _get_current_active_user
|
||||||
|
|
||||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||||
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
|
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
|
||||||
user = await self._get_authentication_method(user_db)(token)
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
return self._get_current_superuser_base(user)
|
return self._get_current_superuser_base(user)
|
||||||
|
|
||||||
return _get_current_superuser
|
return _get_current_superuser
|
||||||
|
|
||||||
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||||
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
|
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
|
||||||
return await user_db.get(token)
|
return await user_db.get(token)
|
||||||
|
|
||||||
return authentication_method
|
return authentication_method
|
||||||
|
|
||||||
|
return MockAuthentication()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_authentication() -> MockAuthentication:
|
def get_test_auth_client(mock_user_db):
|
||||||
return MockAuthentication()
|
def _get_test_auth_client(authentication):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test-current-user")
|
||||||
|
def test_current_user(
|
||||||
|
user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db))
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/test-current-active-user")
|
||||||
|
def test_current_active_user(
|
||||||
|
user: BaseUserDB = Depends(
|
||||||
|
authentication.get_current_active_user(mock_user_db)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/test-current-superuser")
|
||||||
|
def test_current_superuser(
|
||||||
|
user: BaseUserDB = Depends(
|
||||||
|
authentication.get_current_superuser(mock_user_db)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
return _get_test_auth_client
|
||||||
|
@ -1,130 +1,25 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, FastAPI
|
from starlette.responses import Response
|
||||||
from starlette import status
|
|
||||||
from starlette.testclient import TestClient
|
|
||||||
|
|
||||||
from fastapi_users.models import BaseUserDB
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.mark.asyncio
|
||||||
def test_auth_client(mock_authentication, mock_user_db):
|
async def test_not_implemented_methods(user, mock_user_db):
|
||||||
app = FastAPI()
|
response = Response()
|
||||||
|
base_authentication = BaseAuthentication()
|
||||||
|
|
||||||
@app.get("/test-current-user")
|
with pytest.raises(NotImplementedError):
|
||||||
def test_current_user(
|
await base_authentication.get_login_response(user, response)
|
||||||
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
|
|
||||||
):
|
|
||||||
return user
|
|
||||||
|
|
||||||
@app.get("/test-current-active-user")
|
with pytest.raises(NotImplementedError):
|
||||||
def test_current_active_user(
|
await base_authentication.get_current_user(mock_user_db)
|
||||||
user: BaseUserDB = Depends(
|
|
||||||
mock_authentication.get_current_active_user(mock_user_db)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return user
|
|
||||||
|
|
||||||
@app.get("/test-current-superuser")
|
with pytest.raises(NotImplementedError):
|
||||||
def test_current_superuser(
|
await base_authentication.get_current_active_user(mock_user_db)
|
||||||
user: BaseUserDB = Depends(
|
|
||||||
mock_authentication.get_current_superuser(mock_user_db)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return user
|
|
||||||
|
|
||||||
return TestClient(app)
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_authentication.get_current_superuser(mock_user_db)
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
class TestGetCurrentUser:
|
await base_authentication._get_authentication_method(mock_user_db)
|
||||||
def test_missing_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get("/test-current-user")
|
|
||||||
print(response.json())
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_invalid_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-user", headers={"Authorization": "Bearer foo"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-user",
|
|
||||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
|
|
||||||
response_json = response.json()
|
|
||||||
assert response_json["id"] == inactive_user.id
|
|
||||||
|
|
||||||
def test_valid_token(self, test_auth_client, user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
|
|
||||||
response_json = response.json()
|
|
||||||
assert response_json["id"] == user.id
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetCurrentActiveUser:
|
|
||||||
def test_missing_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get("/test-current-active-user")
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_invalid_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-active-user",
|
|
||||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_valid_token(self, test_auth_client, user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
|
|
||||||
response_json = response.json()
|
|
||||||
assert response_json["id"] == user.id
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetCurrentSuperuser:
|
|
||||||
def test_missing_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get("/test-current-superuser")
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_invalid_token(self, test_auth_client):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-superuser",
|
|
||||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
||||||
|
|
||||||
def test_valid_token_regular_user(self, test_auth_client, user):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
|
||||||
|
|
||||||
def test_valid_token_superuser(self, test_auth_client, superuser):
|
|
||||||
response = test_auth_client.get(
|
|
||||||
"/test-current-superuser",
|
|
||||||
headers={"Authorization": f"Bearer {superuser.id}"},
|
|
||||||
)
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
|
||||||
|
|
||||||
response_json = response.json()
|
|
||||||
assert response_json["id"] == superuser.id
|
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, FastAPI
|
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.testclient import TestClient
|
|
||||||
|
|
||||||
from fastapi_users.authentication.jwt import JWTAuthentication
|
from fastapi_users.authentication.jwt import JWTAuthentication
|
||||||
from fastapi_users.models import BaseUserDB
|
|
||||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||||
|
|
||||||
SECRET = "SECRET"
|
SECRET = "SECRET"
|
||||||
@ -20,24 +17,18 @@ def jwt_authentication():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token():
|
def token():
|
||||||
def _token(user, lifetime=LIFETIME):
|
def _token(user=None, lifetime=LIFETIME):
|
||||||
data = {"user_id": user.id, "aud": "fastapi-users:auth"}
|
data = {"aud": "fastapi-users:auth"}
|
||||||
|
if user is not None:
|
||||||
|
data["user_id"] = user.id
|
||||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||||
|
|
||||||
return _token
|
return _token
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_auth_client(jwt_authentication, mock_user_db):
|
def test_auth_client(get_test_auth_client, jwt_authentication):
|
||||||
app = FastAPI()
|
return get_test_auth_client(jwt_authentication)
|
||||||
|
|
||||||
@app.get("/test-auth")
|
|
||||||
def test_auth(
|
|
||||||
user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
|
|
||||||
):
|
|
||||||
return user
|
|
||||||
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -55,18 +46,26 @@ async def test_get_login_response(jwt_authentication, user):
|
|||||||
|
|
||||||
class TestGetCurrentUser:
|
class TestGetCurrentUser:
|
||||||
def test_missing_token(self, test_auth_client):
|
def test_missing_token(self, test_auth_client):
|
||||||
response = test_auth_client.get("/test-auth")
|
response = test_auth_client.get("/test-current-user")
|
||||||
|
print(response.json())
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_invalid_token(self, test_auth_client):
|
def test_invalid_token(self, test_auth_client):
|
||||||
response = test_auth_client.get(
|
response = test_auth_client.get(
|
||||||
"/test-auth", headers={"Authorization": "Bearer foo"}
|
"/test-current-user", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_missing_user_payload(self, test_auth_client, token):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-user", headers={"Authorization": f"Bearer {token()}"}
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||||
response = test_auth_client.get(
|
response = test_auth_client.get(
|
||||||
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
|
"/test-current-user",
|
||||||
|
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
@ -75,9 +74,74 @@ class TestGetCurrentUser:
|
|||||||
|
|
||||||
def test_valid_token(self, test_auth_client, token, user):
|
def test_valid_token(self, test_auth_client, token, user):
|
||||||
response = test_auth_client.get(
|
response = test_auth_client.get(
|
||||||
"/test-auth", headers={"Authorization": f"Bearer {token(user)}"}
|
"/test-current-user", headers={"Authorization": f"Bearer {token(user)}"}
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
assert response_json["id"] == user.id
|
assert response_json["id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentActiveUser:
|
||||||
|
def test_missing_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get("/test-current-active-user")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user",
|
||||||
|
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token(self, test_auth_client, token, user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user",
|
||||||
|
headers={"Authorization": f"Bearer {token(user)}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentSuperuser:
|
||||||
|
def test_missing_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get("/test-current-superuser")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser",
|
||||||
|
headers={"Authorization": f"Bearer {token(inactive_user)}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_regular_user(self, test_auth_client, token, user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser",
|
||||||
|
headers={"Authorization": f"Bearer {token(user)}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
def test_valid_token_superuser(self, test_auth_client, token, superuser):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser",
|
||||||
|
headers={"Authorization": f"Bearer {token(superuser)}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == superuser.id
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from fastapi_users.db import BaseUserDatabase
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_oauth2_password_request_form():
|
def create_oauth2_password_request_form():
|
||||||
@ -10,6 +12,26 @@ def create_oauth2_password_request_form():
|
|||||||
return _create_oauth2_password_request_form
|
return _create_oauth2_password_request_form
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_not_implemented_methods(user):
|
||||||
|
base_user_db = BaseUserDatabase()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_user_db.list()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_user_db.get("aaa")
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_user_db.get_by_email("lancelot@camelot.bt")
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_user_db.create(user)
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await base_user_db.update(user)
|
||||||
|
|
||||||
|
|
||||||
class TestAuthenticate:
|
class TestAuthenticate:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_user(
|
async def test_unknown_user(
|
||||||
@ -37,3 +59,21 @@ class TestAuthenticate:
|
|||||||
user = await mock_user_db.authenticate(form)
|
user = await mock_user_db.authenticate(form)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == "king.arthur@camelot.bt"
|
assert user.email == "king.arthur@camelot.bt"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upgrade_password_hash(
|
||||||
|
self, mocker, create_oauth2_password_request_form, mock_user_db
|
||||||
|
):
|
||||||
|
verify_and_update_password_patch = mocker.patch(
|
||||||
|
"fastapi_users.password.verify_and_update_password"
|
||||||
|
)
|
||||||
|
verify_and_update_password_patch.return_value = (True, "updated_hash")
|
||||||
|
mocker.spy(mock_user_db, "update")
|
||||||
|
|
||||||
|
form = create_oauth2_password_request_form(
|
||||||
|
"king.arthur@camelot.bt", "guinevere"
|
||||||
|
)
|
||||||
|
user = await mock_user_db.authenticate(form)
|
||||||
|
assert user is not None
|
||||||
|
assert user.email == "king.arthur@camelot.bt"
|
||||||
|
assert mock_user_db.update.called is True
|
||||||
|
@ -9,21 +9,22 @@ from fastapi_users.models import BaseUser, BaseUserDB
|
|||||||
SECRET = "SECRET"
|
SECRET = "SECRET"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def sync_on_after_forgot_password():
|
||||||
def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers:
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def async_on_after_forgot_password():
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[sync_on_after_forgot_password, async_on_after_forgot_password])
|
||||||
|
def test_app_client(request, mock_user_db, mock_authentication) -> TestClient:
|
||||||
class User(BaseUser):
|
class User(BaseUser):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_after_forgot_password(user, token):
|
fastapi_users = FastAPIUsers(
|
||||||
pass
|
mock_user_db, mock_authentication, User, request.param, SECRET
|
||||||
|
|
||||||
return FastAPIUsers(
|
|
||||||
mock_user_db, mock_authentication, User, on_after_forgot_password, SECRET
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(fastapi_users.router)
|
app.include_router(fastapi_users.router)
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import asyncio
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import asynctest
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@ -17,57 +17,48 @@ LIFETIME = 3600
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def forgot_password_token():
|
def forgot_password_token():
|
||||||
def _forgot_password_token(user_id, lifetime=LIFETIME):
|
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
|
||||||
data = {"user_id": user_id, "aud": "fastapi-users:reset"}
|
data = {"aud": "fastapi-users:reset"}
|
||||||
|
if user_id is not None:
|
||||||
|
data["user_id"] = user_id
|
||||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||||
|
|
||||||
return _forgot_password_token
|
return _forgot_password_token
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def on_after_forgot_password_sync():
|
def on_after_forgot_password_sync():
|
||||||
on_after_forgot_password_mock = MagicMock(return_value=None)
|
return MagicMock(return_value=None)
|
||||||
return on_after_forgot_password_mock
|
|
||||||
|
|
||||||
|
def on_after_forgot_password_async():
|
||||||
|
return asynctest.CoroutineMock(return_value=None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[on_after_forgot_password_sync, on_after_forgot_password_async])
|
||||||
|
def on_after_forgot_password(request):
|
||||||
|
return request.param()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def on_after_forgot_password_async():
|
def test_app_client(
|
||||||
on_after_forgot_password_mock = MagicMock(return_value=asyncio.Future())
|
mock_user_db, mock_authentication, on_after_forgot_password
|
||||||
on_after_forgot_password_mock.return_value.set_result(None)
|
) -> TestClient:
|
||||||
return on_after_forgot_password_mock
|
class User(BaseUser):
|
||||||
|
pass
|
||||||
|
|
||||||
|
userRouter = get_user_router(
|
||||||
|
mock_user_db,
|
||||||
|
User,
|
||||||
|
mock_authentication,
|
||||||
|
on_after_forgot_password,
|
||||||
|
SECRET,
|
||||||
|
LIFETIME,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture
|
app = FastAPI()
|
||||||
def get_test_app_client(mock_user_db, mock_authentication):
|
app.include_router(userRouter)
|
||||||
def _get_test_app_client(on_after_forgot_password) -> TestClient:
|
|
||||||
class User(BaseUser):
|
|
||||||
pass
|
|
||||||
|
|
||||||
userRouter = get_user_router(
|
return TestClient(app)
|
||||||
mock_user_db,
|
|
||||||
User,
|
|
||||||
mock_authentication,
|
|
||||||
on_after_forgot_password,
|
|
||||||
SECRET,
|
|
||||||
LIFETIME,
|
|
||||||
)
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(userRouter)
|
|
||||||
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
return _get_test_app_client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_app_client(get_test_app_client, on_after_forgot_password_sync):
|
|
||||||
return get_test_app_client(on_after_forgot_password_sync)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_app_client_async(get_test_app_client, on_after_forgot_password_async):
|
|
||||||
return get_test_app_client(on_after_forgot_password_async)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegister:
|
class TestRegister:
|
||||||
@ -134,59 +125,36 @@ class TestLogin:
|
|||||||
|
|
||||||
|
|
||||||
class TestForgotPassword:
|
class TestForgotPassword:
|
||||||
def test_empty_body(
|
def test_empty_body(self, test_app_client: TestClient, on_after_forgot_password):
|
||||||
self, test_app_client: TestClient, on_after_forgot_password_sync
|
|
||||||
):
|
|
||||||
response = test_app_client.post("/forgot-password", json={})
|
response = test_app_client.post("/forgot-password", json={})
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
assert on_after_forgot_password_sync.called is False
|
assert on_after_forgot_password.called is False
|
||||||
|
|
||||||
def test_not_existing_user(
|
def test_not_existing_user(
|
||||||
self, test_app_client: TestClient, on_after_forgot_password_sync
|
self, test_app_client: TestClient, on_after_forgot_password
|
||||||
):
|
):
|
||||||
json = {"email": "lancelot@camelot.bt"}
|
json = {"email": "lancelot@camelot.bt"}
|
||||||
response = test_app_client.post("/forgot-password", json=json)
|
response = test_app_client.post("/forgot-password", json=json)
|
||||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
assert on_after_forgot_password_sync.called is False
|
assert on_after_forgot_password.called is False
|
||||||
|
|
||||||
def test_inactive_user(
|
def test_inactive_user(self, test_app_client: TestClient, on_after_forgot_password):
|
||||||
self, test_app_client: TestClient, on_after_forgot_password_sync
|
|
||||||
):
|
|
||||||
json = {"email": "percival@camelot.bt"}
|
json = {"email": "percival@camelot.bt"}
|
||||||
response = test_app_client.post("/forgot-password", json=json)
|
response = test_app_client.post("/forgot-password", json=json)
|
||||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
assert on_after_forgot_password_sync.called is False
|
assert on_after_forgot_password.called is False
|
||||||
|
|
||||||
def test_existing_user_sync_hook(
|
def test_existing_user(
|
||||||
self, test_app_client: TestClient, on_after_forgot_password_sync, user
|
self, test_app_client: TestClient, on_after_forgot_password, user
|
||||||
):
|
):
|
||||||
json = {"email": "king.arthur@camelot.bt"}
|
json = {"email": "king.arthur@camelot.bt"}
|
||||||
response = test_app_client.post("/forgot-password", json=json)
|
response = test_app_client.post("/forgot-password", json=json)
|
||||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||||
assert on_after_forgot_password_sync.called is True
|
assert on_after_forgot_password.called is True
|
||||||
|
|
||||||
actual_user = on_after_forgot_password_sync.call_args[0][0]
|
actual_user = on_after_forgot_password.call_args[0][0]
|
||||||
assert actual_user.id == user.id
|
assert actual_user.id == user.id
|
||||||
actual_token = on_after_forgot_password_sync.call_args[0][1]
|
actual_token = on_after_forgot_password.call_args[0][1]
|
||||||
decoded_token = jwt.decode(
|
|
||||||
actual_token,
|
|
||||||
SECRET,
|
|
||||||
audience="fastapi-users:reset",
|
|
||||||
algorithms=[JWT_ALGORITHM],
|
|
||||||
)
|
|
||||||
assert decoded_token["user_id"] == user.id
|
|
||||||
|
|
||||||
def test_existing_user_async_hook(
|
|
||||||
self, test_app_client_async: TestClient, on_after_forgot_password_async, user
|
|
||||||
):
|
|
||||||
json = {"email": "king.arthur@camelot.bt"}
|
|
||||||
response = test_app_client_async.post("/forgot-password", json=json)
|
|
||||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
|
||||||
assert on_after_forgot_password_async.called is True
|
|
||||||
|
|
||||||
actual_user = on_after_forgot_password_async.call_args[0][0]
|
|
||||||
assert actual_user.id == user.id
|
|
||||||
actual_token = on_after_forgot_password_async.call_args[0][1]
|
|
||||||
decoded_token = jwt.decode(
|
decoded_token = jwt.decode(
|
||||||
actual_token,
|
actual_token,
|
||||||
SECRET,
|
SECRET,
|
||||||
@ -217,6 +185,21 @@ class TestResetPassword:
|
|||||||
print(response.json(), response.status_code)
|
print(response.json(), response.status_code)
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
|
def test_valid_token_missing_user_id_payload(
|
||||||
|
self,
|
||||||
|
mocker,
|
||||||
|
mock_user_db,
|
||||||
|
test_app_client: TestClient,
|
||||||
|
forgot_password_token,
|
||||||
|
inactive_user: BaseUserDB,
|
||||||
|
):
|
||||||
|
mocker.spy(mock_user_db, "update")
|
||||||
|
|
||||||
|
json = {"token": forgot_password_token(), "password": "holygrail"}
|
||||||
|
response = test_app_client.post("/reset-password", json=json)
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert mock_user_db.update.called is False
|
||||||
|
|
||||||
def test_inactive_user(
|
def test_inactive_user(
|
||||||
self,
|
self,
|
||||||
mocker,
|
mocker,
|
||||||
@ -244,6 +227,7 @@ class TestResetPassword:
|
|||||||
user: BaseUserDB,
|
user: BaseUserDB,
|
||||||
):
|
):
|
||||||
mocker.spy(mock_user_db, "update")
|
mocker.spy(mock_user_db, "update")
|
||||||
|
current_hashed_passord = user.hashed_password
|
||||||
|
|
||||||
json = {"token": forgot_password_token(user.id), "password": "holygrail"}
|
json = {"token": forgot_password_token(user.id), "password": "holygrail"}
|
||||||
response = test_app_client.post("/reset-password", json=json)
|
response = test_app_client.post("/reset-password", json=json)
|
||||||
@ -251,4 +235,4 @@ class TestResetPassword:
|
|||||||
assert mock_user_db.update.called is True
|
assert mock_user_db.update.called is True
|
||||||
|
|
||||||
updated_user = mock_user_db.update.call_args[0][0]
|
updated_user = mock_user_db.update.call_args[0][0]
|
||||||
assert updated_user.hashed_password != user.hashed_password
|
assert updated_user.hashed_password != current_hashed_passord
|
||||||
|
Reference in New Issue
Block a user