mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-02 03:31:03 +08:00
Improve typing and make User pydantic models dynamic
This commit is contained in:
@ -3,7 +3,7 @@ from typing import Callable
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
@ -13,8 +13,8 @@ class BaseAuthentication:
|
||||
def __init__(self, user_db: BaseUserDatabase):
|
||||
self.user_db = user_db
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_authentication_method(self) -> Callable[..., UserDB]:
|
||||
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -7,7 +7,7 @@ from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
data = {"user_id": user.id}
|
||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
||||
|
||||
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
|
||||
try:
|
||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
user_id: str = data.get("user_id")
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
except jwt.PyJWTError:
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||
|
||||
|
||||
@ -12,22 +12,24 @@ class BaseUserDatabase:
|
||||
the database.
|
||||
"""
|
||||
|
||||
async def list(self) -> List[UserDB]:
|
||||
async def list(self) -> List[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def update(self, user: UserDB) -> UserDB:
|
||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[BaseUserDB]:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
|
||||
# Always run the hasher to mitigate timing attack
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from databases import Database
|
||||
from sqlalchemy import Boolean, Column, String, Table
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseUser:
|
||||
class BaseUserTable:
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
|
||||
self.database = database
|
||||
self.users = users
|
||||
|
||||
async def list(self) -> List[UserDB]:
|
||||
async def list(self) -> List[BaseUserDB]:
|
||||
query = self.users.select()
|
||||
return await self.database.fetch_all(query)
|
||||
return cast(List[BaseUserDB], await self.database.fetch_all(query))
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
async def get(self, id: str) -> BaseUserDB:
|
||||
query = self.users.select().where(self.users.c.id == id)
|
||||
return await self.database.fetch_one(query)
|
||||
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
async def get_by_email(self, email: str) -> BaseUserDB:
|
||||
query = self.users.select().where(self.users.c.email == email)
|
||||
return await self.database.fetch_one(query)
|
||||
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
query = self.users.insert().values(**user.dict())
|
||||
await self.database.execute(query)
|
||||
return user
|
||||
|
||||
async def update(self, user: UserDB) -> UserDB:
|
||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||
query = (
|
||||
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||
)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import EmailStr
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
id: str = None
|
||||
class BaseUser(BaseModel):
|
||||
id: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
is_active: Optional[bool] = True
|
||||
is_superuser: Optional[bool] = False
|
||||
@ -17,18 +17,31 @@ class UserBase(BaseModel):
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
class BaseUserCreate(BaseUser):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserUpdate(UserBase):
|
||||
class BaseUserUpdate(BaseUser):
|
||||
pass
|
||||
|
||||
|
||||
class UserDB(UserBase):
|
||||
class BaseUserDB(BaseUser):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
pass
|
||||
class Models:
|
||||
def __init__(self, user_model: Type[BaseUser]):
|
||||
class UserCreate(user_model, BaseUserCreate): # type: ignore
|
||||
pass
|
||||
|
||||
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
|
||||
pass
|
||||
|
||||
class UserDB(user_model, BaseUserDB): # type: ignore
|
||||
pass
|
||||
|
||||
self.User = user_model
|
||||
self.UserCreate = UserCreate
|
||||
self.UserUpdate = UserUpdate
|
||||
self.UserDB = UserDB
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Type
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from starlette import status
|
||||
@ -5,32 +7,34 @@ from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import User, UserCreate, UserDB
|
||||
from fastapi_users.models import BaseUser, Models
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
|
||||
class UserRouter:
|
||||
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
||||
router = APIRouter()
|
||||
def get_user_router(
|
||||
user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
|
||||
) -> APIRouter:
|
||||
router = APIRouter()
|
||||
models = Models(user_model)
|
||||
|
||||
@router.post("/register", response_model=User)
|
||||
async def register(user: UserCreate):
|
||||
hashed_password = get_password_hash(user.password)
|
||||
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
|
||||
created_user = await user_db.create(db_user)
|
||||
return created_user
|
||||
@router.post("/register", response_model=models.User)
|
||||
async def register(user: models.UserCreate): # type: ignore
|
||||
hashed_password = get_password_hash(user.password)
|
||||
db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
|
||||
created_user = await user_db.create(db_user)
|
||||
return created_user
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||
):
|
||||
user = await user_db.authenticate(credentials)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return await auth.get_login_response(user, response)
|
||||
return await auth.get_login_response(user, response)
|
||||
|
||||
return router
|
||||
return router
|
||||
|
||||
Reference in New Issue
Block a user