From ef6dd2c39c95de92d3482c3c96b9127f2c58a3ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 10 Oct 2019 18:55:11 +0200 Subject: [PATCH] Implement user-facing API --- Makefile | 6 +++ fastapi_users/__init__.py | 36 ++++++++++++++++ fastapi_users/authentication/__init__.py | 10 ++--- fastapi_users/authentication/jwt.py | 8 ++-- fastapi_users/db/__init__.py | 15 +++++-- fastapi_users/db/sqlalchemy.py | 8 ++++ tests/conftest.py | 25 +++++++---- tests/test_authentication_jwt.py | 10 +++-- tests/test_fastapi_users.py | 54 ++++++++++++++++++++++++ 9 files changed, 145 insertions(+), 27 deletions(-) create mode 100644 tests/test_fastapi_users.py diff --git a/Makefile b/Makefile index 5789f4b0..1677490c 100644 --- a/Makefile +++ b/Makefile @@ -6,3 +6,9 @@ format: test: $(PIPENV_RUN) pytest + +docs-serve: + $(PIPENV_RUN) mkdocs serve + +docs-publish: + $(PIPENV_RUN) mkdocs gh-deploy diff --git a/fastapi_users/__init__.py b/fastapi_users/__init__.py index e69de29b..7ee793b4 100644 --- a/fastapi_users/__init__.py +++ b/fastapi_users/__init__.py @@ -0,0 +1,36 @@ +from typing import Callable, Type + +from fastapi import APIRouter + +from fastapi_users.authentication import BaseAuthentication +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import BaseUser, BaseUserDB +from fastapi_users.router import get_user_router + + +class FastAPIUsers: + """ + Main object that ties together the component for users authentication. + + :param db: Database adapter instance. + :param auth: Authentication logic instance. + :param user_model: Pydantic model of a user. + + :attribute router: FastAPI router exposing authentication routes. + :attribute get_current_user: Dependency callable to inject authenticated user. + """ + + db: BaseUserDatabase + auth: BaseAuthentication + router: APIRouter + get_current_user: Callable[..., BaseUserDB] + + def __init__( + self, db: BaseUserDatabase, auth: BaseAuthentication, user_model: Type[BaseUser] + ): + self.db = db + self.auth = auth + self.router = get_user_router(self.db, user_model, self.auth) + + get_current_user = self.auth.get_authentication_method(self.db) + self.get_current_user = get_current_user # type: ignore diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index 2f2bad68..8421596c 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -7,14 +7,10 @@ from fastapi_users.models import BaseUserDB class BaseAuthentication: - - user_db: BaseUserDatabase - - def __init__(self, user_db: BaseUserDatabase): - self.user_db = user_db - async def get_login_response(self, user: BaseUserDB, response: Response): raise NotImplementedError() - def get_authentication_method(self) -> Callable[..., BaseUserDB]: + def get_authentication_method( + self, user_db: BaseUserDatabase + ) -> Callable[..., BaseUserDB]: raise NotImplementedError() diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index ad57779f..99f9d282 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -7,6 +7,7 @@ from starlette import status from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication +from fastapi_users.db import BaseUserDatabase from fastapi_users.models import BaseUserDB oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") @@ -25,8 +26,7 @@ class JWTAuthentication(BaseAuthentication): secret: str lifetime_seconds: int - def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, secret: str, lifetime_seconds: int): self.secret = secret self.lifetime_seconds = lifetime_seconds @@ -36,7 +36,7 @@ class JWTAuthentication(BaseAuthentication): return {"token": token} - def get_authentication_method(self): + def get_authentication_method(self, user_db: BaseUserDatabase): async def authentication_method(token: str = Depends(oauth2_scheme)): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED @@ -50,7 +50,7 @@ class JWTAuthentication(BaseAuthentication): except jwt.PyJWTError: raise credentials_exception - user = await self.user_db.get(user_id) + user = await user_db.get(user_id) if user is None or not user.is_active: raise credentials_exception diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index d09faffd..42204c8b 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -7,29 +7,36 @@ from fastapi_users.password import get_password_hash, verify_and_update_password class BaseUserDatabase: - """ - Common interface exposing methods to list, get, create and update users in - the database. - """ + """Base adapter for retrieving, creating and updating users from a database.""" async def list(self) -> List[BaseUserDB]: + """List all users.""" raise NotImplementedError() async def get(self, id: str) -> Optional[BaseUserDB]: + """Get a single user by id.""" raise NotImplementedError() async def get_by_email(self, email: str) -> Optional[BaseUserDB]: + """Get a single user by email.""" raise NotImplementedError() async def create(self, user: BaseUserDB) -> BaseUserDB: + """Create a user.""" raise NotImplementedError() async def update(self, user: BaseUserDB) -> BaseUserDB: + """Update a user.""" raise NotImplementedError() async def authenticate( self, credentials: OAuth2PasswordRequestForm ) -> Optional[BaseUserDB]: + """ + Authenticate and return a user following an email and a password. + + Will automatically upgrade password hash if necessary. + """ user = await self.get_by_email(credentials.username) # Always run the hasher to mitigate timing attack diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index 1edb4915..8efaec0e 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -8,6 +8,8 @@ from fastapi_users.models import BaseUserDB class BaseUserTable: + """Base SQLAlchemy users table definition.""" + __tablename__ = "user" id = Column(String, primary_key=True) @@ -18,6 +20,12 @@ class BaseUserTable: class SQLAlchemyUserDatabase(BaseUserDatabase): + """ + Database adapter for SQLAlchemy. + + :param database: `Database` instance from `encode/databases`. + :param users: SQLAlchemy users table instance. + """ database: Database users: Table diff --git a/tests/conftest.py b/tests/conftest.py index 05f3ebcc..e78a85b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ from typing import Optional import pytest -from fastapi import HTTPException +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer from starlette import status from starlette.responses import Response @@ -62,13 +63,21 @@ class MockAuthentication(BaseAuthentication): async def get_login_response(self, user: BaseUserDB, response: Response): return {"token": user.id} - async def authenticate(self, token: str) -> BaseUserDB: - user = await self.user_db.get(token) - if user is None or not user.is_active: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - return user + def get_authentication_method(self, user_db: BaseUserDatabase): + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") + + async def authentication_method(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED + ) + user = await user_db.get(token) + if user is None or not user.is_active: + raise credentials_exception + return user + + return authentication_method @pytest.fixture -def mock_authentication(mock_user_db) -> MockAuthentication: - return MockAuthentication(mock_user_db) +def mock_authentication() -> MockAuthentication: + return MockAuthentication() diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index a865eb41..39dc239c 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -14,8 +14,8 @@ LIFETIME = 3600 @pytest.fixture -def jwt_authentication(mock_user_db): - return JWTAuthentication(SECRET, LIFETIME, mock_user_db) +def jwt_authentication(): + return JWTAuthentication(SECRET, LIFETIME) @pytest.fixture @@ -28,12 +28,14 @@ def token(): @pytest.fixture -def test_auth_client(jwt_authentication): +def test_auth_client(jwt_authentication, mock_user_db): app = FastAPI() @app.get("/test-auth") def test_auth( - user: BaseUserDB = Depends(jwt_authentication.get_authentication_method()) + user: BaseUserDB = Depends( + jwt_authentication.get_authentication_method(mock_user_db) + ) ): return user diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py new file mode 100644 index 00000000..78befb36 --- /dev/null +++ b/tests/test_fastapi_users.py @@ -0,0 +1,54 @@ +import pytest +from fastapi import Depends, FastAPI +from starlette import status +from starlette.testclient import TestClient + +from fastapi_users import FastAPIUsers +from fastapi_users.models import BaseUser, BaseUserDB + + +@pytest.fixture +def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers: + class User(BaseUser): + pass + + return FastAPIUsers(mock_user_db, mock_authentication, User) + + +@pytest.fixture +def test_app_client(fastapi_users: FastAPIUsers) -> TestClient: + app = FastAPI() + app.include_router(fastapi_users.router) + + @app.get("/authenticated") + def authenticated(user=Depends(fastapi_users.get_current_user)): + return user + + return TestClient(app) + + +class TestRouter: + def test_routes_exist(self, test_app_client: TestClient): + response = test_app_client.post("/register") + assert response.status_code != status.HTTP_404_NOT_FOUND + + response = test_app_client.post("/login") + assert response.status_code != status.HTTP_404_NOT_FOUND + + +class TestGetCurrentUser: + def test_missing_token(self, test_app_client: TestClient): + response = test_app_client.get("/authenticated") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_app_client: TestClient): + response = test_app_client.get( + "/authenticated", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB): + response = test_app_client.get( + "/authenticated", headers={"Authorization": f"Bearer {user.id}"} + ) + assert response.status_code == status.HTTP_200_OK