Improve typing and make User pydantic models dynamic

This commit is contained in:
François Voron
2019-10-10 13:37:52 +02:00
parent d9e6e93f08
commit 0112e700ac
12 changed files with 153 additions and 76 deletions

View File

@ -14,6 +14,7 @@ flake8-docstrings = "*"
mkdocs = "*" mkdocs = "*"
mkdocs-material = "*" mkdocs-material = "*"
black = "*" black = "*"
mypy = "*"
[packages] [packages]
fastapi = "*" fastapi = "*"

53
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023" "sha256": "c1aa33adc0a5d3c81741b012a2cab88f01697fa74b7259bfbdaaec18ebe60036"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -383,6 +383,29 @@
], ],
"version": "==7.2.0" "version": "==7.2.0"
}, },
"mypy": {
"hashes": [
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
],
"index": "pypi",
"version": "==0.730"
},
"mypy-extensions": {
"hashes": [
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
],
"version": "==0.4.2"
},
"packaging": { "packaging": {
"hashes": [ "hashes": [
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
@ -535,6 +558,34 @@
], ],
"version": "==6.0.3" "version": "==6.0.3"
}, },
"typed-ast": {
"hashes": [
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
],
"version": "==1.4.0"
},
"typing-extensions": {
"hashes": [
"sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95",
"sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87",
"sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed"
],
"version": "==3.7.4"
},
"urllib3": { "urllib3": {
"hashes": [ "hashes": [
"sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398", "sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398",

View File

@ -3,7 +3,7 @@ from typing import Callable
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
class BaseAuthentication: class BaseAuthentication:
@ -13,8 +13,8 @@ class BaseAuthentication:
def __init__(self, user_db: BaseUserDatabase): def __init__(self, user_db: BaseUserDatabase):
self.user_db = user_db self.user_db = user_db
async def get_login_response(self, user: UserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError() raise NotImplementedError()
def get_authentication_method(self) -> Callable[..., UserDB]: def get_authentication_method(self) -> Callable[..., BaseUserDB]:
raise NotImplementedError() raise NotImplementedError()

View File

@ -7,7 +7,7 @@ from starlette import status
from starlette.responses import Response from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
self.secret = secret self.secret = secret
self.lifetime_seconds = lifetime_seconds self.lifetime_seconds = lifetime_seconds
async def get_login_response(self, user: UserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
data = {"user_id": user.id} data = {"user_id": user.id}
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm) token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
try: try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id: str = data.get("user_id") user_id = data.get("user_id")
if user_id is None: if user_id is None:
raise credentials_exception raise credentials_exception
except jwt.PyJWTError: except jwt.PyJWTError:

View File

@ -1,8 +1,8 @@
from typing import List from typing import List, Optional
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash, verify_and_update_password from fastapi_users.password import get_password_hash, verify_and_update_password
@ -12,22 +12,24 @@ class BaseUserDatabase:
the database. the database.
""" """
async def list(self) -> List[UserDB]: async def list(self) -> List[BaseUserDB]:
raise NotImplementedError() raise NotImplementedError()
async def get(self, id: str) -> UserDB: async def get(self, id: str) -> Optional[BaseUserDB]:
raise NotImplementedError() raise NotImplementedError()
async def get_by_email(self, email: str) -> UserDB: async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
raise NotImplementedError() raise NotImplementedError()
async def create(self, user: UserDB) -> UserDB: async def create(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError() raise NotImplementedError()
async def update(self, user: UserDB) -> UserDB: async def update(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError() raise NotImplementedError()
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB: async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[BaseUserDB]:
user = await self.get_by_email(credentials.username) user = await self.get_by_email(credentials.username)
# Always run the hasher to mitigate timing attack # Always run the hasher to mitigate timing attack

View File

@ -1,13 +1,13 @@
from typing import List from typing import List, cast
from databases import Database from databases import Database
from sqlalchemy import Boolean, Column, String, Table from sqlalchemy import Boolean, Column, String, Table
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
class BaseUser: class BaseUserTable:
__tablename__ = "user" __tablename__ = "user"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
self.database = database self.database = database
self.users = users self.users = users
async def list(self) -> List[UserDB]: async def list(self) -> List[BaseUserDB]:
query = self.users.select() query = self.users.select()
return await self.database.fetch_all(query) return cast(List[BaseUserDB], await self.database.fetch_all(query))
async def get(self, id: str) -> UserDB: async def get(self, id: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.id == id) query = self.users.select().where(self.users.c.id == id)
return await self.database.fetch_one(query) return cast(BaseUserDB, await self.database.fetch_one(query))
async def get_by_email(self, email: str) -> UserDB: async def get_by_email(self, email: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.email == email) query = self.users.select().where(self.users.c.email == email)
return await self.database.fetch_one(query) return cast(BaseUserDB, await self.database.fetch_one(query))
async def create(self, user: UserDB) -> UserDB: async def create(self, user: BaseUserDB) -> BaseUserDB:
query = self.users.insert().values(**user.dict()) query = self.users.insert().values(**user.dict())
await self.database.execute(query) await self.database.execute(query)
return user return user
async def update(self, user: UserDB) -> UserDB: async def update(self, user: BaseUserDB) -> BaseUserDB:
query = ( query = (
self.users.update().where(self.users.c.id == user.id).values(**user.dict()) self.users.update().where(self.users.c.id == user.id).values(**user.dict())
) )

View File

@ -1,13 +1,13 @@
import uuid import uuid
from typing import Optional from typing import Optional, Type
import pydantic import pydantic
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.types import EmailStr from pydantic.types import EmailStr
class UserBase(BaseModel): class BaseUser(BaseModel):
id: str = None id: Optional[str] = None
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
is_active: Optional[bool] = True is_active: Optional[bool] = True
is_superuser: Optional[bool] = False is_superuser: Optional[bool] = False
@ -17,18 +17,31 @@ class UserBase(BaseModel):
return v or str(uuid.uuid4()) return v or str(uuid.uuid4())
class UserCreate(UserBase): class BaseUserCreate(BaseUser):
email: EmailStr email: EmailStr
password: str password: str
class UserUpdate(UserBase): class BaseUserUpdate(BaseUser):
pass pass
class UserDB(UserBase): class BaseUserDB(BaseUser):
hashed_password: str hashed_password: str
class User(UserBase): class Models:
pass def __init__(self, user_model: Type[BaseUser]):
class UserCreate(user_model, BaseUserCreate): # type: ignore
pass
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
pass
class UserDB(user_model, BaseUserDB): # type: ignore
pass
self.User = user_model
self.UserCreate = UserCreate
self.UserUpdate = UserUpdate
self.UserDB = UserDB

View File

@ -1,3 +1,5 @@
from typing import Type
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from starlette import status from starlette import status
@ -5,32 +7,34 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import User, UserCreate, UserDB from fastapi_users.models import BaseUser, Models
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
class UserRouter: def get_user_router(
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter: user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
router = APIRouter() ) -> APIRouter:
router = APIRouter()
models = Models(user_model)
@router.post("/register", response_model=User) @router.post("/register", response_model=models.User)
async def register(user: UserCreate): async def register(user: models.UserCreate): # type: ignore
hashed_password = get_password_hash(user.password) hashed_password = get_password_hash(user.password)
db_user = UserDB(**user.dict(), hashed_password=hashed_password) db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
created_user = await user_db.create(db_user) created_user = await user_db.create(db_user)
return created_user return created_user
@router.post("/login") @router.post("/login")
async def login( async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends() response: Response, credentials: OAuth2PasswordRequestForm = Depends()
): ):
user = await user_db.authenticate(credentials) user = await user_db.authenticate(credentials)
if user is None: if user is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
elif not user.is_active: elif not user.is_active:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return await auth.get_login_response(user, response) return await auth.get_login_response(user, response)
return router return router

View File

@ -1,3 +1,5 @@
from typing import Optional
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
from starlette import status from starlette import status
@ -5,16 +7,16 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
active_user_data = UserDB( active_user_data = BaseUserDB(
id="aaa", id="aaa",
email="king.arthur@camelot.bt", email="king.arthur@camelot.bt",
hashed_password=get_password_hash("guinevere"), hashed_password=get_password_hash("guinevere"),
) )
inactive_user_data = UserDB( inactive_user_data = BaseUserDB(
id="bbb", id="bbb",
email="percival@camelot.bt", email="percival@camelot.bt",
hashed_password=get_password_hash("angharad"), hashed_password=get_password_hash("angharad"),
@ -23,31 +25,31 @@ inactive_user_data = UserDB(
@pytest.fixture @pytest.fixture
def user() -> UserDB: def user() -> BaseUserDB:
return active_user_data return active_user_data
@pytest.fixture @pytest.fixture
def inactive_user() -> UserDB: def inactive_user() -> BaseUserDB:
return inactive_user_data return inactive_user_data
class MockUserDatabase(BaseUserDatabase): class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> UserDB: async def get(self, id: str) -> Optional[BaseUserDB]:
if id == active_user_data.id: if id == active_user_data.id:
return active_user_data return active_user_data
elif id == inactive_user_data.id: elif id == inactive_user_data.id:
return inactive_user_data return inactive_user_data
return None return None
async def get_by_email(self, email: str) -> UserDB: async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
if email == active_user_data.email: if email == active_user_data.email:
return active_user_data return active_user_data
elif email == inactive_user_data.email: elif email == inactive_user_data.email:
return inactive_user_data return inactive_user_data
return None return None
async def create(self, user: UserDB) -> UserDB: async def create(self, user: BaseUserDB) -> BaseUserDB:
return user return user
@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase:
class MockAuthentication(BaseAuthentication): class MockAuthentication(BaseAuthentication):
async def get_login_response(self, user: UserDB, response: Response): async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id} return {"token": user.id}
async def authenticate(self, token: str) -> UserDB: async def authenticate(self, token: str) -> BaseUserDB:
user = await self.user_db.get(token) user = await self.user_db.get(token)
if user is None or not user.is_active: if user is None or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

View File

@ -6,7 +6,7 @@ from starlette.responses import Response
from starlette.testclient import TestClient from starlette.testclient import TestClient
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
from fastapi_users.models import UserDB from fastapi_users.models import BaseUserDB
SECRET = "SECRET" SECRET = "SECRET"
ALGORITHM = "HS256" ALGORITHM = "HS256"
@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication):
@app.get("/test-auth") @app.get("/test-auth")
def test_auth( def test_auth(
user: UserDB = Depends(jwt_authentication.get_authentication_method()) user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
): ):
return user return user

View File

@ -1,18 +1,19 @@
import sqlite3 import sqlite3
from typing import AsyncGenerator
import pytest import pytest
import sqlalchemy import sqlalchemy
from databases import Database from databases import Database
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase
@pytest.fixture @pytest.fixture
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase: async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
Base = declarative_base() Base: DeclarativeMeta = declarative_base()
class User(BaseUser, Base): class User(BaseUserTable, Base):
pass pass
DATABASE_URL = "sqlite:///./test.db" DATABASE_URL = "sqlite:///./test.db"

View File

@ -3,13 +3,16 @@ from fastapi import FastAPI
from starlette import status from starlette import status
from starlette.testclient import TestClient from starlette.testclient import TestClient
from fastapi_users.models import UserDB from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.router import UserRouter from fastapi_users.router import get_user_router
@pytest.fixture @pytest.fixture
def test_app_client(mock_user_db, mock_authentication) -> TestClient: def test_app_client(mock_user_db, mock_authentication) -> TestClient:
userRouter = UserRouter(mock_user_db, mock_authentication) class User(BaseUser):
pass
userRouter = get_user_router(mock_user_db, User, mock_authentication)
app = FastAPI() app = FastAPI()
app.include_router(userRouter) app.include_router(userRouter)
@ -68,7 +71,7 @@ class TestLogin:
response = test_app_client.post("/login", data=data) response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB): def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"} data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data) response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK