diff --git a/fastapi_users/authentication/base.py b/fastapi_users/authentication/base.py index a152b145..73a21cdd 100644 --- a/fastapi_users/authentication/base.py +++ b/fastapi_users/authentication/base.py @@ -1,18 +1,57 @@ from typing import Callable +from fastapi import HTTPException +from starlette import status from starlette.responses import Response from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUserDB +credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + class BaseAuthentication: - """Base adapter for generating and decoding authentication tokens.""" + """ + Base adapter for generating and decoding authentication tokens. + + Provides dependency injectors to get current active/superuser user. + """ async def get_login_response(self, user: BaseUserDB, response: Response): raise NotImplementedError() - def get_authentication_method( + def get_current_user(self, user_db: BaseUserDatabase): + raise NotImplementedError() + + def get_current_active_user(self, user_db: BaseUserDatabase): + raise NotImplementedError() + + def get_current_superuser(self, user_db: BaseUserDatabase): + raise NotImplementedError() + + def _get_authentication_method( self, user_db: BaseUserDatabase ) -> Callable[..., BaseUserDB]: raise NotImplementedError() + + def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB: + if user is None: + raise self._get_credentials_exception() + return user + + def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB: + user = self._get_current_user_base(user) + if not user.is_active: + raise self._get_credentials_exception() + return user + + def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB: + user = self._get_current_active_user_base(user) + if not user.is_superuser: + raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN) + return user + + def _get_credentials_exception( + self, status_code: int = status.HTTP_401_UNAUTHORIZED + ) -> HTTPException: + return HTTPException(status_code=status_code) diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 73110ffd..768886cc 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -1,9 +1,8 @@ from datetime import datetime, timedelta import jwt -from fastapi import Depends, HTTPException +from fastapi import Depends from fastapi.security import OAuth2PasswordBearer -from starlette import status from starlette.responses import Response from fastapi_users.authentication.base import BaseAuthentication @@ -42,24 +41,36 @@ class JWTAuthentication(BaseAuthentication): return {"token": token} - def get_authentication_method(self, user_db: BaseUserDatabase): - async def authentication_method(token: str = Depends(oauth2_scheme)): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED - ) + def get_current_user(self, user_db: BaseUserDatabase): + async def _get_current_user(token: str = Depends(oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_user_base(user) + return _get_current_user + + def get_current_active_user(self, user_db: BaseUserDatabase): + async def _get_current_active_user(token: str = Depends(oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_active_user_base(user) + + return _get_current_active_user + + def get_current_superuser(self, user_db: BaseUserDatabase): + async def _get_current_superuser(token: str = Depends(oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_superuser_base(user) + + return _get_current_superuser + + def _get_authentication_method(self, user_db: BaseUserDatabase): + async def authentication_method(token: str = Depends(oauth2_scheme)): try: data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) user_id = data.get("user_id") if user_id is None: - raise credentials_exception + return None except jwt.PyJWTError: - raise credentials_exception - - user = await user_db.get(user_id) - if user is None or not user.is_active: - raise credentials_exception - - return user + return None + return await user_db.get(user_id) return authentication_method diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index a1df00b4..dc57e238 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -34,5 +34,11 @@ class FastAPIUsers: self.auth = auth self.router = get_user_router(self.db, user_model, self.auth) - get_current_user = self.auth.get_authentication_method(self.db) + get_current_user = self.auth.get_current_user(self.db) self.get_current_user = get_current_user # type: ignore + + get_current_active_user = self.auth.get_current_active_user(self.db) + self.get_current_active_user = get_current_active_user # type: ignore + + get_current_superuser = self.auth.get_current_superuser(self.db) + self.get_current_superuser = get_current_superuser # type: ignore diff --git a/tests/conftest.py b/tests/conftest.py index e78a85b9..60d89086 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,8 @@ from typing import Optional import pytest -from fastapi import Depends, HTTPException +from fastapi import Depends from fastapi.security import OAuth2PasswordBearer -from starlette import status from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication @@ -24,6 +23,13 @@ inactive_user_data = BaseUserDB( is_active=False, ) +superuser_data = BaseUserDB( + id="ccc", + email="merlin@camelot.bt", + hashed_password=get_password_hash("viviane"), + is_superuser=True, +) + @pytest.fixture def user() -> BaseUserDB: @@ -35,12 +41,19 @@ def inactive_user() -> BaseUserDB: return inactive_user_data +@pytest.fixture +def superuser() -> BaseUserDB: + return superuser_data + + class MockUserDatabase(BaseUserDatabase): async def get(self, id: str) -> Optional[BaseUserDB]: if id == active_user_data.id: return active_user_data elif id == inactive_user_data.id: return inactive_user_data + elif id == superuser_data.id: + return superuser_data return None async def get_by_email(self, email: str) -> Optional[BaseUserDB]: @@ -48,6 +61,8 @@ class MockUserDatabase(BaseUserDatabase): return active_user_data elif email == inactive_user_data.email: return inactive_user_data + elif email == superuser_data.email: + return superuser_data return None async def create(self, user: BaseUserDB) -> BaseUserDB: @@ -60,20 +75,35 @@ def mock_user_db() -> MockUserDatabase: class MockAuthentication(BaseAuthentication): + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") + async def get_login_response(self, user: BaseUserDB, response: Response): return {"token": user.id} - def get_authentication_method(self, user_db: BaseUserDatabase): - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") + def get_current_user(self, user_db: BaseUserDatabase): + async def _get_current_user(token: str = Depends(self.oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_user_base(user) - async def authentication_method(token: str = Depends(oauth2_scheme)): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED - ) - user = await user_db.get(token) - if user is None or not user.is_active: - raise credentials_exception - return user + return _get_current_user + + def get_current_active_user(self, user_db: BaseUserDatabase): + async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_active_user_base(user) + + return _get_current_active_user + + def get_current_superuser(self, user_db: BaseUserDatabase): + async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)): + user = await self._get_authentication_method(user_db)(token) + return self._get_current_superuser_base(user) + + return _get_current_superuser + + def _get_authentication_method(self, user_db: BaseUserDatabase): + async def authentication_method(token: str = Depends(self.oauth2_scheme)): + return await user_db.get(token) return authentication_method diff --git a/tests/test_authentication_base.py b/tests/test_authentication_base.py new file mode 100644 index 00000000..8f4cf792 --- /dev/null +++ b/tests/test_authentication_base.py @@ -0,0 +1,130 @@ +import pytest +from fastapi import Depends, FastAPI +from starlette import status +from starlette.testclient import TestClient + +from fastapi_users.models import BaseUserDB + + +@pytest.fixture +def test_auth_client(mock_authentication, mock_user_db): + app = FastAPI() + + @app.get("/test-current-user") + def test_current_user( + user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db)) + ): + return user + + @app.get("/test-current-active-user") + def test_current_active_user( + user: BaseUserDB = Depends( + mock_authentication.get_current_active_user(mock_user_db) + ) + ): + return user + + @app.get("/test-current-superuser") + def test_current_superuser( + user: BaseUserDB = Depends( + mock_authentication.get_current_superuser(mock_user_db) + ) + ): + return user + + return TestClient(app) + + +class TestGetCurrentUser: + def test_missing_token(self, test_auth_client): + response = test_auth_client.get("/test-current-user") + print(response.json()) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get( + "/test-current-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, inactive_user): + response = test_auth_client.get( + "/test-current-user", + headers={"Authorization": f"Bearer {inactive_user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == inactive_user.id + + def test_valid_token(self, test_auth_client, user): + response = test_auth_client.get( + "/test-current-user", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == user.id + + +class TestGetCurrentActiveUser: + def test_missing_token(self, test_auth_client): + response = test_auth_client.get("/test-current-active-user") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get( + "/test-current-active-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, inactive_user): + response = test_auth_client.get( + "/test-current-active-user", + headers={"Authorization": f"Bearer {inactive_user.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token(self, test_auth_client, user): + response = test_auth_client.get( + "/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == user.id + + +class TestGetCurrentSuperuser: + def test_missing_token(self, test_auth_client): + response = test_auth_client.get("/test-current-superuser") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get( + "/test-current-superuser", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, inactive_user): + response = test_auth_client.get( + "/test-current-superuser", + headers={"Authorization": f"Bearer {inactive_user.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_regular_user(self, test_auth_client, user): + response = test_auth_client.get( + "/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_valid_token_superuser(self, test_auth_client, superuser): + response = test_auth_client.get( + "/test-current-superuser", + headers={"Authorization": f"Bearer {superuser.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == superuser.id diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index 39dc239c..f2c73994 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -33,9 +33,7 @@ def test_auth_client(jwt_authentication, mock_user_db): @app.get("/test-auth") def test_auth( - user: BaseUserDB = Depends( - jwt_authentication.get_authentication_method(mock_user_db) - ) + user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db)) ): return user @@ -53,7 +51,7 @@ async def test_get_login_response(jwt_authentication, user): assert decoded["user_id"] == user.id -class TestGetAuthenticationMethod: +class TestGetCurrentUser: def test_missing_token(self, test_auth_client): response = test_auth_client.get("/test-auth") assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -68,7 +66,10 @@ class TestGetAuthenticationMethod: response = test_auth_client.get( "/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"} ) - assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert response.status_code == status.HTTP_200_OK + + response_json = response.json() + assert response_json["id"] == inactive_user.id def test_valid_token(self, test_auth_client, token, user): response = test_auth_client.get( diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index 51da0c3f..81f1fbbd 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -7,6 +7,8 @@ from databases import Database from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase +from fastapi_users.models import BaseUserDB +from fastapi_users.password import get_password_hash @pytest.fixture @@ -32,7 +34,13 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: @pytest.mark.asyncio -async def test_queries(user, sqlalchemy_user_db): +async def test_queries(sqlalchemy_user_db): + user = BaseUserDB( + id="111", + email="lancelot@camelot.bt", + hashed_password=get_password_hash("guinevere"), + ) + # Create user_db = await sqlalchemy_user_db.create(user) assert user_db.id is not None @@ -64,5 +72,5 @@ async def test_queries(user, sqlalchemy_user_db): await sqlalchemy_user_db.create(user) # Unknown user - unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt") + unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt") assert unknown_user is None diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 78befb36..a7fa1442 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -20,8 +20,16 @@ def test_app_client(fastapi_users: FastAPIUsers) -> TestClient: app = FastAPI() app.include_router(fastapi_users.router) - @app.get("/authenticated") - def authenticated(user=Depends(fastapi_users.get_current_user)): + @app.get("/current-user") + def current_user(user=Depends(fastapi_users.get_current_user)): + return user + + @app.get("/current-active-user") + def current_active_user(user=Depends(fastapi_users.get_current_active_user)): + return user + + @app.get("/current-superuser") + def current_superuser(user=Depends(fastapi_users.get_current_superuser)): return user return TestClient(app) @@ -38,17 +46,72 @@ class TestRouter: class TestGetCurrentUser: def test_missing_token(self, test_app_client: TestClient): - response = test_app_client.get("/authenticated") + response = test_app_client.get("/current-user") assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_invalid_token(self, test_app_client: TestClient): response = test_app_client.get( - "/authenticated", headers={"Authorization": "Bearer foo"} + "/current-user", headers={"Authorization": "Bearer foo"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB): response = test_app_client.get( - "/authenticated", headers={"Authorization": f"Bearer {user.id}"} + "/current-user", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_200_OK + + +class TestGetCurrentActiveUser: + def test_missing_token(self, test_app_client: TestClient): + response = test_app_client.get("/current-active-user") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_app_client: TestClient): + response = test_app_client.get( + "/current-active-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user( + self, test_app_client: TestClient, inactive_user: BaseUserDB + ): + response = test_app_client.get( + "/current-active-user", + headers={"Authorization": f"Bearer {inactive_user.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB): + response = test_app_client.get( + "/current-active-user", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_200_OK + + +class TestGetCurrentSuperuser: + def test_missing_token(self, test_app_client: TestClient): + response = test_app_client.get("/current-superuser") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_app_client: TestClient): + response = test_app_client.get( + "/current-superuser", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_regular_user( + self, test_app_client: TestClient, user: BaseUserDB + ): + response = test_app_client.get( + "/current-superuser", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_valid_token_superuser( + self, test_app_client: TestClient, superuser: BaseUserDB + ): + response = test_app_client.get( + "/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"} ) assert response.status_code == status.HTTP_200_OK