Implement variant of dep injections to get active/super user

This commit is contained in:
François Voron
2019-10-11 08:09:47 +02:00
parent ef796abb55
commit 76bb7bf6a5
8 changed files with 330 additions and 42 deletions

View File

@ -1,9 +1,8 @@
from typing import Optional
import pytest
from fastapi import Depends, HTTPException
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
@ -24,6 +23,13 @@ inactive_user_data = BaseUserDB(
is_active=False,
)
superuser_data = BaseUserDB(
id="ccc",
email="merlin@camelot.bt",
hashed_password=get_password_hash("viviane"),
is_superuser=True,
)
@pytest.fixture
def user() -> BaseUserDB:
@ -35,12 +41,19 @@ def inactive_user() -> BaseUserDB:
return inactive_user_data
@pytest.fixture
def superuser() -> BaseUserDB:
return superuser_data
class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> Optional[BaseUserDB]:
if id == active_user_data.id:
return active_user_data
elif id == inactive_user_data.id:
return inactive_user_data
elif id == superuser_data.id:
return superuser_data
return None
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
@ -48,6 +61,8 @@ class MockUserDatabase(BaseUserDatabase):
return active_user_data
elif email == inactive_user_data.email:
return inactive_user_data
elif email == superuser_data.email:
return superuser_data
return None
async def create(self, user: BaseUserDB) -> BaseUserDB:
@ -60,20 +75,35 @@ def mock_user_db() -> MockUserDatabase:
class MockAuthentication(BaseAuthentication):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
def get_authentication_method(self, user_db: BaseUserDatabase):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
async def authentication_method(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED
)
user = await user_db.get(token)
if user is None or not user.is_active:
raise credentials_exception
return user
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
return await user_db.get(token)
return authentication_method

View File

@ -0,0 +1,130 @@
import pytest
from fastapi import Depends, FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.models import BaseUserDB
@pytest.fixture
def test_auth_client(mock_authentication, mock_user_db):
app = FastAPI()
@app.get("/test-current-user")
def test_current_user(
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
):
return user
@app.get("/test-current-active-user")
def test_current_active_user(
user: BaseUserDB = Depends(
mock_authentication.get_current_active_user(mock_user_db)
)
):
return user
@app.get("/test-current-superuser")
def test_current_superuser(
user: BaseUserDB = Depends(
mock_authentication.get_current_superuser(mock_user_db)
)
):
return user
return TestClient(app)
class TestGetCurrentUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-user")
print(response.json())
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -33,9 +33,7 @@ def test_auth_client(jwt_authentication, mock_user_db):
@app.get("/test-auth")
def test_auth(
user: BaseUserDB = Depends(
jwt_authentication.get_authentication_method(mock_user_db)
)
user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
):
return user
@ -53,7 +51,7 @@ async def test_get_login_response(jwt_authentication, user):
assert decoded["user_id"] == user.id
class TestGetAuthenticationMethod:
class TestGetCurrentUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-auth")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -68,7 +66,10 @@ class TestGetAuthenticationMethod:
response = test_auth_client.get(
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(

View File

@ -7,6 +7,8 @@ from databases import Database
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
@pytest.fixture
@ -32,7 +34,13 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
@pytest.mark.asyncio
async def test_queries(user, sqlalchemy_user_db):
async def test_queries(sqlalchemy_user_db):
user = BaseUserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
# Create
user_db = await sqlalchemy_user_db.create(user)
assert user_db.id is not None
@ -64,5 +72,5 @@ async def test_queries(user, sqlalchemy_user_db):
await sqlalchemy_user_db.create(user)
# Unknown user
unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt")
unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt")
assert unknown_user is None

View File

@ -20,8 +20,16 @@ def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
app = FastAPI()
app.include_router(fastapi_users.router)
@app.get("/authenticated")
def authenticated(user=Depends(fastapi_users.get_current_user)):
@app.get("/current-user")
def current_user(user=Depends(fastapi_users.get_current_user)):
return user
@app.get("/current-active-user")
def current_active_user(user=Depends(fastapi_users.get_current_active_user)):
return user
@app.get("/current-superuser")
def current_superuser(user=Depends(fastapi_users.get_current_superuser)):
return user
return TestClient(app)
@ -38,17 +46,72 @@ class TestRouter:
class TestGetCurrentUser:
def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/authenticated")
response = test_app_client.get("/current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get(
"/authenticated", headers={"Authorization": "Bearer foo"}
"/current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
response = test_app_client.get(
"/authenticated", headers={"Authorization": f"Bearer {user.id}"}
"/current-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
class TestGetCurrentActiveUser:
def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get(
"/current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(
self, test_app_client: TestClient, inactive_user: BaseUserDB
):
response = test_app_client.get(
"/current-active-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
response = test_app_client.get(
"/current-active-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
class TestGetCurrentSuperuser:
def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(
self, test_app_client: TestClient, user: BaseUserDB
):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(
self, test_app_client: TestClient, superuser: BaseUserDB
):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"}
)
assert response.status_code == status.HTTP_200_OK