Improve typing and make User pydantic models dynamic

This commit is contained in:
François Voron
2019-10-10 13:37:52 +02:00
parent d9e6e93f08
commit 0112e700ac
12 changed files with 153 additions and 76 deletions

View File

@ -14,6 +14,7 @@ flake8-docstrings = "*"
mkdocs = "*"
mkdocs-material = "*"
black = "*"
mypy = "*"
[packages]
fastapi = "*"

53
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023"
"sha256": "c1aa33adc0a5d3c81741b012a2cab88f01697fa74b7259bfbdaaec18ebe60036"
},
"pipfile-spec": 6,
"requires": {
@ -383,6 +383,29 @@
],
"version": "==7.2.0"
},
"mypy": {
"hashes": [
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
],
"index": "pypi",
"version": "==0.730"
},
"mypy-extensions": {
"hashes": [
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
],
"version": "==0.4.2"
},
"packaging": {
"hashes": [
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
@ -535,6 +558,34 @@
],
"version": "==6.0.3"
},
"typed-ast": {
"hashes": [
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
],
"version": "==1.4.0"
},
"typing-extensions": {
"hashes": [
"sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95",
"sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87",
"sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed"
],
"version": "==3.7.4"
},
"urllib3": {
"hashes": [
"sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398",

View File

@ -3,7 +3,7 @@ from typing import Callable
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
class BaseAuthentication:
@ -13,8 +13,8 @@ class BaseAuthentication:
def __init__(self, user_db: BaseUserDatabase):
self.user_db = user_db
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError()
def get_authentication_method(self) -> Callable[..., UserDB]:
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
raise NotImplementedError()

View File

@ -7,7 +7,7 @@ from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
self.secret = secret
self.lifetime_seconds = lifetime_seconds
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
data = {"user_id": user.id}
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id: str = data.get("user_id")
user_id = data.get("user_id")
if user_id is None:
raise credentials_exception
except jwt.PyJWTError:

View File

@ -1,8 +1,8 @@
from typing import List
from typing import List, Optional
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash, verify_and_update_password
@ -12,22 +12,24 @@ class BaseUserDatabase:
the database.
"""
async def list(self) -> List[UserDB]:
async def list(self) -> List[BaseUserDB]:
raise NotImplementedError()
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> Optional[BaseUserDB]:
raise NotImplementedError()
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
raise NotImplementedError()
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError()
async def update(self, user: UserDB) -> UserDB:
async def update(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError()
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[BaseUserDB]:
user = await self.get_by_email(credentials.username)
# Always run the hasher to mitigate timing attack

View File

@ -1,13 +1,13 @@
from typing import List
from typing import List, cast
from databases import Database
from sqlalchemy import Boolean, Column, String, Table
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
class BaseUser:
class BaseUserTable:
__tablename__ = "user"
id = Column(String, primary_key=True)
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
self.database = database
self.users = users
async def list(self) -> List[UserDB]:
async def list(self) -> List[BaseUserDB]:
query = self.users.select()
return await self.database.fetch_all(query)
return cast(List[BaseUserDB], await self.database.fetch_all(query))
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.id == id)
return await self.database.fetch_one(query)
return cast(BaseUserDB, await self.database.fetch_one(query))
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.email == email)
return await self.database.fetch_one(query)
return cast(BaseUserDB, await self.database.fetch_one(query))
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
query = self.users.insert().values(**user.dict())
await self.database.execute(query)
return user
async def update(self, user: UserDB) -> UserDB:
async def update(self, user: BaseUserDB) -> BaseUserDB:
query = (
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
)

View File

@ -1,13 +1,13 @@
import uuid
from typing import Optional
from typing import Optional, Type
import pydantic
from pydantic import BaseModel
from pydantic.types import EmailStr
class UserBase(BaseModel):
id: str = None
class BaseUser(BaseModel):
id: Optional[str] = None
email: Optional[EmailStr] = None
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
@ -17,18 +17,31 @@ class UserBase(BaseModel):
return v or str(uuid.uuid4())
class UserCreate(UserBase):
class BaseUserCreate(BaseUser):
email: EmailStr
password: str
class UserUpdate(UserBase):
class BaseUserUpdate(BaseUser):
pass
class UserDB(UserBase):
class BaseUserDB(BaseUser):
hashed_password: str
class User(UserBase):
class Models:
def __init__(self, user_model: Type[BaseUser]):
class UserCreate(user_model, BaseUserCreate): # type: ignore
pass
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
pass
class UserDB(user_model, BaseUserDB): # type: ignore
pass
self.User = user_model
self.UserCreate = UserCreate
self.UserUpdate = UserUpdate
self.UserDB = UserDB

View File

@ -1,3 +1,5 @@
from typing import Type
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from starlette import status
@ -5,18 +7,20 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import User, UserCreate, UserDB
from fastapi_users.models import BaseUser, Models
from fastapi_users.password import get_password_hash
class UserRouter:
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
def get_user_router(
user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
) -> APIRouter:
router = APIRouter()
models = Models(user_model)
@router.post("/register", response_model=User)
async def register(user: UserCreate):
@router.post("/register", response_model=models.User)
async def register(user: models.UserCreate): # type: ignore
hashed_password = get_password_hash(user.password)
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
created_user = await user_db.create(db_user)
return created_user

View File

@ -1,3 +1,5 @@
from typing import Optional
import pytest
from fastapi import HTTPException
from starlette import status
@ -5,16 +7,16 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
active_user_data = UserDB(
active_user_data = BaseUserDB(
id="aaa",
email="king.arthur@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
inactive_user_data = UserDB(
inactive_user_data = BaseUserDB(
id="bbb",
email="percival@camelot.bt",
hashed_password=get_password_hash("angharad"),
@ -23,31 +25,31 @@ inactive_user_data = UserDB(
@pytest.fixture
def user() -> UserDB:
def user() -> BaseUserDB:
return active_user_data
@pytest.fixture
def inactive_user() -> UserDB:
def inactive_user() -> BaseUserDB:
return inactive_user_data
class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> Optional[BaseUserDB]:
if id == active_user_data.id:
return active_user_data
elif id == inactive_user_data.id:
return inactive_user_data
return None
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
if email == active_user_data.email:
return active_user_data
elif email == inactive_user_data.email:
return inactive_user_data
return None
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
return user
@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase:
class MockAuthentication(BaseAuthentication):
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
async def authenticate(self, token: str) -> UserDB:
async def authenticate(self, token: str) -> BaseUserDB:
user = await self.user_db.get(token)
if user is None or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

View File

@ -6,7 +6,7 @@ from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
SECRET = "SECRET"
ALGORITHM = "HS256"
@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication):
@app.get("/test-auth")
def test_auth(
user: UserDB = Depends(jwt_authentication.get_authentication_method())
user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
):
return user

View File

@ -1,18 +1,19 @@
import sqlite3
from typing import AsyncGenerator
import pytest
import sqlalchemy
from databases import Database
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase
@pytest.fixture
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
Base = declarative_base()
async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
Base: DeclarativeMeta = declarative_base()
class User(BaseUser, Base):
class User(BaseUserTable, Base):
pass
DATABASE_URL = "sqlite:///./test.db"

View File

@ -3,13 +3,16 @@ from fastapi import FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.models import UserDB
from fastapi_users.router import UserRouter
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.router import get_user_router
@pytest.fixture
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
userRouter = UserRouter(mock_user_db, mock_authentication)
class User(BaseUser):
pass
userRouter = get_user_router(mock_user_db, User, mock_authentication)
app = FastAPI()
app.include_router(userRouter)
@ -68,7 +71,7 @@ class TestLogin:
response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_200_OK