From 0112e700ac686946b4d9243cc46cb9801e39ebc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 10 Oct 2019 13:37:52 +0200 Subject: [PATCH] Improve typing and make User pydantic models dynamic --- Pipfile | 1 + Pipfile.lock | 53 +++++++++++++++++++++++- fastapi_users/authentication/__init__.py | 6 +-- fastapi_users/authentication/jwt.py | 6 +-- fastapi_users/db/__init__.py | 18 ++++---- fastapi_users/db/sqlalchemy.py | 22 +++++----- fastapi_users/models.py | 29 +++++++++---- fastapi_users/router.py | 46 ++++++++++---------- tests/conftest.py | 22 +++++----- tests/test_authentication_jwt.py | 4 +- tests/test_db_sqlalchemy.py | 11 ++--- tests/test_router.py | 11 +++-- 12 files changed, 153 insertions(+), 76 deletions(-) diff --git a/Pipfile b/Pipfile index 32d8226a..7ce83db7 100644 --- a/Pipfile +++ b/Pipfile @@ -14,6 +14,7 @@ flake8-docstrings = "*" mkdocs = "*" mkdocs-material = "*" black = "*" +mypy = "*" [packages] fastapi = "*" diff --git a/Pipfile.lock b/Pipfile.lock index 96f94965..3a456df0 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023" + "sha256": "c1aa33adc0a5d3c81741b012a2cab88f01697fa74b7259bfbdaaec18ebe60036" }, "pipfile-spec": 6, "requires": { @@ -383,6 +383,29 @@ ], "version": "==7.2.0" }, + "mypy": { + "hashes": [ + "sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b", + "sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac", + "sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c", + "sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879", + "sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795", + "sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021", + "sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3", + "sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23", + "sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd", + "sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b", + "sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046" + ], + "index": "pypi", + "version": "==0.730" + }, + "mypy-extensions": { + "hashes": [ + "sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458" + ], + "version": "==0.4.2" + }, "packaging": { "hashes": [ "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", @@ -535,6 +558,34 @@ ], "version": "==6.0.3" }, + "typed-ast": { + "hashes": [ + "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", + "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", + "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", + "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", + "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", + "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", + "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", + "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", + "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", + "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", + "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", + "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", + "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", + "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", + "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" + ], + "version": "==1.4.0" + }, + "typing-extensions": { + "hashes": [ + "sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95", + "sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87", + "sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed" + ], + "version": "==3.7.4" + }, "urllib3": { "hashes": [ "sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398", diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index 5c326f27..2f2bad68 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -3,7 +3,7 @@ from typing import Callable from starlette.responses import Response from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB class BaseAuthentication: @@ -13,8 +13,8 @@ class BaseAuthentication: def __init__(self, user_db: BaseUserDatabase): self.user_db = user_db - async def get_login_response(self, user: UserDB, response: Response): + async def get_login_response(self, user: BaseUserDB, response: Response): raise NotImplementedError() - def get_authentication_method(self) -> Callable[..., UserDB]: + def get_authentication_method(self) -> Callable[..., BaseUserDB]: raise NotImplementedError() diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index deb6af8c..ad57779f 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -7,7 +7,7 @@ from starlette import status from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") @@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication): self.secret = secret self.lifetime_seconds = lifetime_seconds - async def get_login_response(self, user: UserDB, response: Response): + async def get_login_response(self, user: BaseUserDB, response: Response): data = {"user_id": user.id} token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm) @@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication): try: data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) - user_id: str = data.get("user_id") + user_id = data.get("user_id") if user_id is None: raise credentials_exception except jwt.PyJWTError: diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 5370a71c..d09faffd 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,8 +1,8 @@ -from typing import List +from typing import List, Optional from fastapi.security import OAuth2PasswordRequestForm -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB from fastapi_users.password import get_password_hash, verify_and_update_password @@ -12,22 +12,24 @@ class BaseUserDatabase: the database. """ - async def list(self) -> List[UserDB]: + async def list(self) -> List[BaseUserDB]: raise NotImplementedError() - async def get(self, id: str) -> UserDB: + async def get(self, id: str) -> Optional[BaseUserDB]: raise NotImplementedError() - async def get_by_email(self, email: str) -> UserDB: + async def get_by_email(self, email: str) -> Optional[BaseUserDB]: raise NotImplementedError() - async def create(self, user: UserDB) -> UserDB: + async def create(self, user: BaseUserDB) -> BaseUserDB: raise NotImplementedError() - async def update(self, user: UserDB) -> UserDB: + async def update(self, user: BaseUserDB) -> BaseUserDB: raise NotImplementedError() - async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB: + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> Optional[BaseUserDB]: user = await self.get_by_email(credentials.username) # Always run the hasher to mitigate timing attack diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index c0542707..1edb4915 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -1,13 +1,13 @@ -from typing import List +from typing import List, cast from databases import Database from sqlalchemy import Boolean, Column, String, Table from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB -class BaseUser: +class BaseUserTable: __tablename__ = "user" id = Column(String, primary_key=True) @@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase): self.database = database self.users = users - async def list(self) -> List[UserDB]: + async def list(self) -> List[BaseUserDB]: query = self.users.select() - return await self.database.fetch_all(query) + return cast(List[BaseUserDB], await self.database.fetch_all(query)) - async def get(self, id: str) -> UserDB: + async def get(self, id: str) -> BaseUserDB: query = self.users.select().where(self.users.c.id == id) - return await self.database.fetch_one(query) + return cast(BaseUserDB, await self.database.fetch_one(query)) - async def get_by_email(self, email: str) -> UserDB: + async def get_by_email(self, email: str) -> BaseUserDB: query = self.users.select().where(self.users.c.email == email) - return await self.database.fetch_one(query) + return cast(BaseUserDB, await self.database.fetch_one(query)) - async def create(self, user: UserDB) -> UserDB: + async def create(self, user: BaseUserDB) -> BaseUserDB: query = self.users.insert().values(**user.dict()) await self.database.execute(query) return user - async def update(self, user: UserDB) -> UserDB: + async def update(self, user: BaseUserDB) -> BaseUserDB: query = ( self.users.update().where(self.users.c.id == user.id).values(**user.dict()) ) diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 3205198c..e2660a6f 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -1,13 +1,13 @@ import uuid -from typing import Optional +from typing import Optional, Type import pydantic from pydantic import BaseModel from pydantic.types import EmailStr -class UserBase(BaseModel): - id: str = None +class BaseUser(BaseModel): + id: Optional[str] = None email: Optional[EmailStr] = None is_active: Optional[bool] = True is_superuser: Optional[bool] = False @@ -17,18 +17,31 @@ class UserBase(BaseModel): return v or str(uuid.uuid4()) -class UserCreate(UserBase): +class BaseUserCreate(BaseUser): email: EmailStr password: str -class UserUpdate(UserBase): +class BaseUserUpdate(BaseUser): pass -class UserDB(UserBase): +class BaseUserDB(BaseUser): hashed_password: str -class User(UserBase): - pass +class Models: + def __init__(self, user_model: Type[BaseUser]): + class UserCreate(user_model, BaseUserCreate): # type: ignore + pass + + class UserUpdate(user_model, BaseUserUpdate): # type: ignore + pass + + class UserDB(user_model, BaseUserDB): # type: ignore + pass + + self.User = user_model + self.UserCreate = UserCreate + self.UserUpdate = UserUpdate + self.UserDB = UserDB diff --git a/fastapi_users/router.py b/fastapi_users/router.py index 4d93e582..f091bd76 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -1,3 +1,5 @@ +from typing import Type + from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm from starlette import status @@ -5,32 +7,34 @@ from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import User, UserCreate, UserDB +from fastapi_users.models import BaseUser, Models from fastapi_users.password import get_password_hash -class UserRouter: - def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter: - router = APIRouter() +def get_user_router( + user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication +) -> APIRouter: + router = APIRouter() + models = Models(user_model) - @router.post("/register", response_model=User) - async def register(user: UserCreate): - hashed_password = get_password_hash(user.password) - db_user = UserDB(**user.dict(), hashed_password=hashed_password) - created_user = await user_db.create(db_user) - return created_user + @router.post("/register", response_model=models.User) + async def register(user: models.UserCreate): # type: ignore + hashed_password = get_password_hash(user.password) + db_user = models.UserDB(**user.dict(), hashed_password=hashed_password) + created_user = await user_db.create(db_user) + return created_user - @router.post("/login") - async def login( - response: Response, credentials: OAuth2PasswordRequestForm = Depends() - ): - user = await user_db.authenticate(credentials) + @router.post("/login") + async def login( + response: Response, credentials: OAuth2PasswordRequestForm = Depends() + ): + user = await user_db.authenticate(credentials) - if user is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - elif not user.is_active: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + if user is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + elif not user.is_active: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - return await auth.get_login_response(user, response) + return await auth.get_login_response(user, response) - return router + return router diff --git a/tests/conftest.py b/tests/conftest.py index 42025ff5..05f3ebcc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from fastapi import HTTPException from starlette import status @@ -5,16 +7,16 @@ from starlette.responses import Response from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB from fastapi_users.password import get_password_hash -active_user_data = UserDB( +active_user_data = BaseUserDB( id="aaa", email="king.arthur@camelot.bt", hashed_password=get_password_hash("guinevere"), ) -inactive_user_data = UserDB( +inactive_user_data = BaseUserDB( id="bbb", email="percival@camelot.bt", hashed_password=get_password_hash("angharad"), @@ -23,31 +25,31 @@ inactive_user_data = UserDB( @pytest.fixture -def user() -> UserDB: +def user() -> BaseUserDB: return active_user_data @pytest.fixture -def inactive_user() -> UserDB: +def inactive_user() -> BaseUserDB: return inactive_user_data class MockUserDatabase(BaseUserDatabase): - async def get(self, id: str) -> UserDB: + async def get(self, id: str) -> Optional[BaseUserDB]: if id == active_user_data.id: return active_user_data elif id == inactive_user_data.id: return inactive_user_data return None - async def get_by_email(self, email: str) -> UserDB: + async def get_by_email(self, email: str) -> Optional[BaseUserDB]: if email == active_user_data.email: return active_user_data elif email == inactive_user_data.email: return inactive_user_data return None - async def create(self, user: UserDB) -> UserDB: + async def create(self, user: BaseUserDB) -> BaseUserDB: return user @@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase: class MockAuthentication(BaseAuthentication): - async def get_login_response(self, user: UserDB, response: Response): + async def get_login_response(self, user: BaseUserDB, response: Response): return {"token": user.id} - async def authenticate(self, token: str) -> UserDB: + async def authenticate(self, token: str) -> BaseUserDB: user = await self.user_db.get(token) if user is None or not user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index 52bd3d59..a865eb41 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -6,7 +6,7 @@ from starlette.responses import Response from starlette.testclient import TestClient from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt -from fastapi_users.models import UserDB +from fastapi_users.models import BaseUserDB SECRET = "SECRET" ALGORITHM = "HS256" @@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication): @app.get("/test-auth") def test_auth( - user: UserDB = Depends(jwt_authentication.get_authentication_method()) + user: BaseUserDB = Depends(jwt_authentication.get_authentication_method()) ): return user diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index f67c1938..aaf5641f 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -1,18 +1,19 @@ import sqlite3 +from typing import AsyncGenerator import pytest import sqlalchemy from databases import Database -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base -from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase +from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase @pytest.fixture -async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase: - Base = declarative_base() +async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: + Base: DeclarativeMeta = declarative_base() - class User(BaseUser, Base): + class User(BaseUserTable, Base): pass DATABASE_URL = "sqlite:///./test.db" diff --git a/tests/test_router.py b/tests/test_router.py index 399cb589..997d49dd 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -3,13 +3,16 @@ from fastapi import FastAPI from starlette import status from starlette.testclient import TestClient -from fastapi_users.models import UserDB -from fastapi_users.router import UserRouter +from fastapi_users.models import BaseUser, BaseUserDB +from fastapi_users.router import get_user_router @pytest.fixture def test_app_client(mock_user_db, mock_authentication) -> TestClient: - userRouter = UserRouter(mock_user_db, mock_authentication) + class User(BaseUser): + pass + + userRouter = get_user_router(mock_user_db, User, mock_authentication) app = FastAPI() app.include_router(userRouter) @@ -68,7 +71,7 @@ class TestLogin: response = test_app_client.post("/login", data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST - def test_valid_credentials(self, test_app_client: TestClient, user: UserDB): + def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB): data = {"username": "king.arthur@camelot.bt", "password": "guinevere"} response = test_app_client.post("/login", data=data) assert response.status_code == status.HTTP_200_OK