Improve typing and make User pydantic models dynamic

This commit is contained in:
François Voron
2019-10-10 13:37:52 +02:00
parent d9e6e93f08
commit 0112e700ac
12 changed files with 153 additions and 76 deletions

View File

@ -3,7 +3,7 @@ from typing import Callable
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
class BaseAuthentication:
@ -13,8 +13,8 @@ class BaseAuthentication:
def __init__(self, user_db: BaseUserDatabase):
self.user_db = user_db
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
raise NotImplementedError()
def get_authentication_method(self) -> Callable[..., UserDB]:
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
raise NotImplementedError()

View File

@ -7,7 +7,7 @@ from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
self.secret = secret
self.lifetime_seconds = lifetime_seconds
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
data = {"user_id": user.id}
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id: str = data.get("user_id")
user_id = data.get("user_id")
if user_id is None:
raise credentials_exception
except jwt.PyJWTError:

View File

@ -1,8 +1,8 @@
from typing import List
from typing import List, Optional
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash, verify_and_update_password
@ -12,22 +12,24 @@ class BaseUserDatabase:
the database.
"""
async def list(self) -> List[UserDB]:
async def list(self) -> List[BaseUserDB]:
raise NotImplementedError()
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> Optional[BaseUserDB]:
raise NotImplementedError()
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
raise NotImplementedError()
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError()
async def update(self, user: UserDB) -> UserDB:
async def update(self, user: BaseUserDB) -> BaseUserDB:
raise NotImplementedError()
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[BaseUserDB]:
user = await self.get_by_email(credentials.username)
# Always run the hasher to mitigate timing attack

View File

@ -1,13 +1,13 @@
from typing import List
from typing import List, cast
from databases import Database
from sqlalchemy import Boolean, Column, String, Table
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
class BaseUser:
class BaseUserTable:
__tablename__ = "user"
id = Column(String, primary_key=True)
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
self.database = database
self.users = users
async def list(self) -> List[UserDB]:
async def list(self) -> List[BaseUserDB]:
query = self.users.select()
return await self.database.fetch_all(query)
return cast(List[BaseUserDB], await self.database.fetch_all(query))
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.id == id)
return await self.database.fetch_one(query)
return cast(BaseUserDB, await self.database.fetch_one(query))
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> BaseUserDB:
query = self.users.select().where(self.users.c.email == email)
return await self.database.fetch_one(query)
return cast(BaseUserDB, await self.database.fetch_one(query))
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
query = self.users.insert().values(**user.dict())
await self.database.execute(query)
return user
async def update(self, user: UserDB) -> UserDB:
async def update(self, user: BaseUserDB) -> BaseUserDB:
query = (
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
)

View File

@ -1,13 +1,13 @@
import uuid
from typing import Optional
from typing import Optional, Type
import pydantic
from pydantic import BaseModel
from pydantic.types import EmailStr
class UserBase(BaseModel):
id: str = None
class BaseUser(BaseModel):
id: Optional[str] = None
email: Optional[EmailStr] = None
is_active: Optional[bool] = True
is_superuser: Optional[bool] = False
@ -17,18 +17,31 @@ class UserBase(BaseModel):
return v or str(uuid.uuid4())
class UserCreate(UserBase):
class BaseUserCreate(BaseUser):
email: EmailStr
password: str
class UserUpdate(UserBase):
class BaseUserUpdate(BaseUser):
pass
class UserDB(UserBase):
class BaseUserDB(BaseUser):
hashed_password: str
class User(UserBase):
pass
class Models:
def __init__(self, user_model: Type[BaseUser]):
class UserCreate(user_model, BaseUserCreate): # type: ignore
pass
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
pass
class UserDB(user_model, BaseUserDB): # type: ignore
pass
self.User = user_model
self.UserCreate = UserCreate
self.UserUpdate = UserUpdate
self.UserDB = UserDB

View File

@ -1,3 +1,5 @@
from typing import Type
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from starlette import status
@ -5,32 +7,34 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import User, UserCreate, UserDB
from fastapi_users.models import BaseUser, Models
from fastapi_users.password import get_password_hash
class UserRouter:
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
router = APIRouter()
def get_user_router(
user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
) -> APIRouter:
router = APIRouter()
models = Models(user_model)
@router.post("/register", response_model=User)
async def register(user: UserCreate):
hashed_password = get_password_hash(user.password)
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
created_user = await user_db.create(db_user)
return created_user
@router.post("/register", response_model=models.User)
async def register(user: models.UserCreate): # type: ignore
hashed_password = get_password_hash(user.password)
db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
created_user = await user_db.create(db_user)
return created_user
@router.post("/login")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
@router.post("/login")
async def login(
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
):
user = await user_db.authenticate(credentials)
if user is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
elif not user.is_active:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
if user is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
elif not user.is_active:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return await auth.get_login_response(user, response)
return await auth.get_login_response(user, response)
return router
return router