Implement variant of dep injections to get active/super user

This commit is contained in:
François Voron
2019-10-11 08:09:47 +02:00
parent ef796abb55
commit 76bb7bf6a5
8 changed files with 330 additions and 42 deletions

View File

@ -1,18 +1,57 @@
from typing import Callable from typing import Callable
from fastapi import HTTPException
from starlette import status
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB from fastapi_users.models import BaseUserDB
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
class BaseAuthentication: class BaseAuthentication:
"""Base adapter for generating and decoding authentication tokens.""" """
Base adapter for generating and decoding authentication tokens.
Provides dependency injectors to get current active/superuser user.
"""
async def get_login_response(self, user: BaseUserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError() raise NotImplementedError()
def get_authentication_method( def get_current_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_active_user(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def get_current_superuser(self, user_db: BaseUserDatabase):
raise NotImplementedError()
def _get_authentication_method(
self, user_db: BaseUserDatabase self, user_db: BaseUserDatabase
) -> Callable[..., BaseUserDB]: ) -> Callable[..., BaseUserDB]:
raise NotImplementedError() raise NotImplementedError()
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
if user is None:
raise self._get_credentials_exception()
return user
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_user_base(user)
if not user.is_active:
raise self._get_credentials_exception()
return user
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
user = self._get_current_active_user_base(user)
if not user.is_superuser:
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
return user
def _get_credentials_exception(
self, status_code: int = status.HTTP_401_UNAUTHORIZED
) -> HTTPException:
return HTTPException(status_code=status_code)

View File

@ -1,9 +1,8 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import jwt import jwt
from fastapi import Depends, HTTPException from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.authentication.base import BaseAuthentication from fastapi_users.authentication.base import BaseAuthentication
@ -42,24 +41,36 @@ class JWTAuthentication(BaseAuthentication):
return {"token": token} return {"token": token}
def get_authentication_method(self, user_db: BaseUserDatabase): def get_current_user(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)): async def _get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException( user = await self._get_authentication_method(user_db)(token)
status_code=status.HTTP_401_UNAUTHORIZED return self._get_current_user_base(user)
)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(token: str = Depends(oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(oauth2_scheme)):
try: try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id = data.get("user_id") user_id = data.get("user_id")
if user_id is None: if user_id is None:
raise credentials_exception return None
except jwt.PyJWTError: except jwt.PyJWTError:
raise credentials_exception return None
return await user_db.get(user_id)
user = await user_db.get(user_id)
if user is None or not user.is_active:
raise credentials_exception
return user
return authentication_method return authentication_method

View File

@ -34,5 +34,11 @@ class FastAPIUsers:
self.auth = auth self.auth = auth
self.router = get_user_router(self.db, user_model, self.auth) self.router = get_user_router(self.db, user_model, self.auth)
get_current_user = self.auth.get_authentication_method(self.db) get_current_user = self.auth.get_current_user(self.db)
self.get_current_user = get_current_user # type: ignore self.get_current_user = get_current_user # type: ignore
get_current_active_user = self.auth.get_current_active_user(self.db)
self.get_current_active_user = get_current_active_user # type: ignore
get_current_superuser = self.auth.get_current_superuser(self.db)
self.get_current_superuser = get_current_superuser # type: ignore

View File

@ -1,9 +1,8 @@
from typing import Optional from typing import Optional
import pytest import pytest
from fastapi import Depends, HTTPException from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
@ -24,6 +23,13 @@ inactive_user_data = BaseUserDB(
is_active=False, is_active=False,
) )
superuser_data = BaseUserDB(
id="ccc",
email="merlin@camelot.bt",
hashed_password=get_password_hash("viviane"),
is_superuser=True,
)
@pytest.fixture @pytest.fixture
def user() -> BaseUserDB: def user() -> BaseUserDB:
@ -35,12 +41,19 @@ def inactive_user() -> BaseUserDB:
return inactive_user_data return inactive_user_data
@pytest.fixture
def superuser() -> BaseUserDB:
return superuser_data
class MockUserDatabase(BaseUserDatabase): class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> Optional[BaseUserDB]: async def get(self, id: str) -> Optional[BaseUserDB]:
if id == active_user_data.id: if id == active_user_data.id:
return active_user_data return active_user_data
elif id == inactive_user_data.id: elif id == inactive_user_data.id:
return inactive_user_data return inactive_user_data
elif id == superuser_data.id:
return superuser_data
return None return None
async def get_by_email(self, email: str) -> Optional[BaseUserDB]: async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
@ -48,6 +61,8 @@ class MockUserDatabase(BaseUserDatabase):
return active_user_data return active_user_data
elif email == inactive_user_data.email: elif email == inactive_user_data.email:
return inactive_user_data return inactive_user_data
elif email == superuser_data.email:
return superuser_data
return None return None
async def create(self, user: BaseUserDB) -> BaseUserDB: async def create(self, user: BaseUserDB) -> BaseUserDB:
@ -60,20 +75,35 @@ def mock_user_db() -> MockUserDatabase:
class MockAuthentication(BaseAuthentication): class MockAuthentication(BaseAuthentication):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
async def get_login_response(self, user: BaseUserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id} return {"token": user.id}
def get_authentication_method(self, user_db: BaseUserDatabase): def get_current_user(self, user_db: BaseUserDatabase):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
async def authentication_method(token: str = Depends(oauth2_scheme)): return _get_current_user
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED def get_current_active_user(self, user_db: BaseUserDatabase):
) async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)):
user = await user_db.get(token) user = await self._get_authentication_method(user_db)(token)
if user is None or not user.is_active: return self._get_current_active_user_base(user)
raise credentials_exception
return user return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
return await user_db.get(token)
return authentication_method return authentication_method

View File

@ -0,0 +1,130 @@
import pytest
from fastapi import Depends, FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.models import BaseUserDB
@pytest.fixture
def test_auth_client(mock_authentication, mock_user_db):
app = FastAPI()
@app.get("/test-current-user")
def test_current_user(
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
):
return user
@app.get("/test-current-active-user")
def test_current_active_user(
user: BaseUserDB = Depends(
mock_authentication.get_current_active_user(mock_user_db)
)
):
return user
@app.get("/test-current-superuser")
def test_current_superuser(
user: BaseUserDB = Depends(
mock_authentication.get_current_superuser(mock_user_db)
)
):
return user
return TestClient(app)
class TestGetCurrentUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-user")
print(response.json())
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, user):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -33,9 +33,7 @@ def test_auth_client(jwt_authentication, mock_user_db):
@app.get("/test-auth") @app.get("/test-auth")
def test_auth( def test_auth(
user: BaseUserDB = Depends( user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
jwt_authentication.get_authentication_method(mock_user_db)
)
): ):
return user return user
@ -53,7 +51,7 @@ async def test_get_login_response(jwt_authentication, user):
assert decoded["user_id"] == user.id assert decoded["user_id"] == user.id
class TestGetAuthenticationMethod: class TestGetCurrentUser:
def test_missing_token(self, test_auth_client): def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-auth") response = test_auth_client.get("/test-auth")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -68,7 +66,10 @@ class TestGetAuthenticationMethod:
response = test_auth_client.get( response = test_auth_client.get(
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"} "/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, token, user): def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get( response = test_auth_client.get(

View File

@ -7,6 +7,8 @@ from databases import Database
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
@pytest.fixture @pytest.fixture
@ -32,7 +34,13 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queries(user, sqlalchemy_user_db): async def test_queries(sqlalchemy_user_db):
user = BaseUserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
# Create # Create
user_db = await sqlalchemy_user_db.create(user) user_db = await sqlalchemy_user_db.create(user)
assert user_db.id is not None assert user_db.id is not None
@ -64,5 +72,5 @@ async def test_queries(user, sqlalchemy_user_db):
await sqlalchemy_user_db.create(user) await sqlalchemy_user_db.create(user)
# Unknown user # Unknown user
unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt") unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt")
assert unknown_user is None assert unknown_user is None

View File

@ -20,8 +20,16 @@ def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
app = FastAPI() app = FastAPI()
app.include_router(fastapi_users.router) app.include_router(fastapi_users.router)
@app.get("/authenticated") @app.get("/current-user")
def authenticated(user=Depends(fastapi_users.get_current_user)): def current_user(user=Depends(fastapi_users.get_current_user)):
return user
@app.get("/current-active-user")
def current_active_user(user=Depends(fastapi_users.get_current_active_user)):
return user
@app.get("/current-superuser")
def current_superuser(user=Depends(fastapi_users.get_current_superuser)):
return user return user
return TestClient(app) return TestClient(app)
@ -38,17 +46,72 @@ class TestRouter:
class TestGetCurrentUser: class TestGetCurrentUser:
def test_missing_token(self, test_app_client: TestClient): def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/authenticated") response = test_app_client.get("/current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient): def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get( response = test_app_client.get(
"/authenticated", headers={"Authorization": "Bearer foo"} "/current-user", headers={"Authorization": "Bearer foo"}
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB): def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
response = test_app_client.get( response = test_app_client.get(
"/authenticated", headers={"Authorization": f"Bearer {user.id}"} "/current-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
class TestGetCurrentActiveUser:
def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get(
"/current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(
self, test_app_client: TestClient, inactive_user: BaseUserDB
):
response = test_app_client.get(
"/current-active-user",
headers={"Authorization": f"Bearer {inactive_user.id}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
response = test_app_client.get(
"/current-active-user", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_200_OK
class TestGetCurrentSuperuser:
def test_missing_token(self, test_app_client: TestClient):
response = test_app_client.get("/current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_app_client: TestClient):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(
self, test_app_client: TestClient, user: BaseUserDB
):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": f"Bearer {user.id}"}
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(
self, test_app_client: TestClient, superuser: BaseUserDB
):
response = test_app_client.get(
"/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK