mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-02 04:05:19 +08:00
Improve typing and make User pydantic models dynamic
This commit is contained in:
1
Pipfile
1
Pipfile
@ -14,6 +14,7 @@ flake8-docstrings = "*"
|
||||
mkdocs = "*"
|
||||
mkdocs-material = "*"
|
||||
black = "*"
|
||||
mypy = "*"
|
||||
|
||||
[packages]
|
||||
fastapi = "*"
|
||||
|
||||
53
Pipfile.lock
generated
53
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023"
|
||||
"sha256": "c1aa33adc0a5d3c81741b012a2cab88f01697fa74b7259bfbdaaec18ebe60036"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@ -383,6 +383,29 @@
|
||||
],
|
||||
"version": "==7.2.0"
|
||||
},
|
||||
"mypy": {
|
||||
"hashes": [
|
||||
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
|
||||
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
|
||||
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
|
||||
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
|
||||
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
|
||||
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
|
||||
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
|
||||
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
|
||||
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
|
||||
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
|
||||
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==0.730"
|
||||
},
|
||||
"mypy-extensions": {
|
||||
"hashes": [
|
||||
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
|
||||
],
|
||||
"version": "==0.4.2"
|
||||
},
|
||||
"packaging": {
|
||||
"hashes": [
|
||||
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
|
||||
@ -535,6 +558,34 @@
|
||||
],
|
||||
"version": "==6.0.3"
|
||||
},
|
||||
"typed-ast": {
|
||||
"hashes": [
|
||||
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
|
||||
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
|
||||
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
|
||||
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
|
||||
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
|
||||
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
|
||||
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
|
||||
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
|
||||
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
|
||||
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
|
||||
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
|
||||
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
|
||||
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
|
||||
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
|
||||
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
|
||||
],
|
||||
"version": "==1.4.0"
|
||||
},
|
||||
"typing-extensions": {
|
||||
"hashes": [
|
||||
"sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95",
|
||||
"sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87",
|
||||
"sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed"
|
||||
],
|
||||
"version": "==3.7.4"
|
||||
},
|
||||
"urllib3": {
|
||||
"hashes": [
|
||||
"sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398",
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Callable
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
@ -13,8 +13,8 @@ class BaseAuthentication:
|
||||
def __init__(self, user_db: BaseUserDatabase):
|
||||
self.user_db = user_db
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_authentication_method(self) -> Callable[..., UserDB]:
|
||||
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -7,7 +7,7 @@ from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
data = {"user_id": user.id}
|
||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
||||
|
||||
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
|
||||
try:
|
||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
user_id: str = data.get("user_id")
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
except jwt.PyJWTError:
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||
|
||||
|
||||
@ -12,22 +12,24 @@ class BaseUserDatabase:
|
||||
the database.
|
||||
"""
|
||||
|
||||
async def list(self) -> List[UserDB]:
|
||||
async def list(self) -> List[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def update(self, user: UserDB) -> UserDB:
|
||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[BaseUserDB]:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
|
||||
# Always run the hasher to mitigate timing attack
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from typing import List
|
||||
from typing import List, cast
|
||||
|
||||
from databases import Database
|
||||
from sqlalchemy import Boolean, Column, String, Table
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseUser:
|
||||
class BaseUserTable:
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
|
||||
self.database = database
|
||||
self.users = users
|
||||
|
||||
async def list(self) -> List[UserDB]:
|
||||
async def list(self) -> List[BaseUserDB]:
|
||||
query = self.users.select()
|
||||
return await self.database.fetch_all(query)
|
||||
return cast(List[BaseUserDB], await self.database.fetch_all(query))
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
async def get(self, id: str) -> BaseUserDB:
|
||||
query = self.users.select().where(self.users.c.id == id)
|
||||
return await self.database.fetch_one(query)
|
||||
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
async def get_by_email(self, email: str) -> BaseUserDB:
|
||||
query = self.users.select().where(self.users.c.email == email)
|
||||
return await self.database.fetch_one(query)
|
||||
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
query = self.users.insert().values(**user.dict())
|
||||
await self.database.execute(query)
|
||||
return user
|
||||
|
||||
async def update(self, user: UserDB) -> UserDB:
|
||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||
query = (
|
||||
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||
)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import EmailStr
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
id: str = None
|
||||
class BaseUser(BaseModel):
|
||||
id: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
is_active: Optional[bool] = True
|
||||
is_superuser: Optional[bool] = False
|
||||
@ -17,18 +17,31 @@ class UserBase(BaseModel):
|
||||
return v or str(uuid.uuid4())
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
class BaseUserCreate(BaseUser):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserUpdate(UserBase):
|
||||
class BaseUserUpdate(BaseUser):
|
||||
pass
|
||||
|
||||
|
||||
class UserDB(UserBase):
|
||||
class BaseUserDB(BaseUser):
|
||||
hashed_password: str
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
class Models:
|
||||
def __init__(self, user_model: Type[BaseUser]):
|
||||
class UserCreate(user_model, BaseUserCreate): # type: ignore
|
||||
pass
|
||||
|
||||
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
|
||||
pass
|
||||
|
||||
class UserDB(user_model, BaseUserDB): # type: ignore
|
||||
pass
|
||||
|
||||
self.User = user_model
|
||||
self.UserCreate = UserCreate
|
||||
self.UserUpdate = UserUpdate
|
||||
self.UserDB = UserDB
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Type
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from starlette import status
|
||||
@ -5,18 +7,20 @@ from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import User, UserCreate, UserDB
|
||||
from fastapi_users.models import BaseUser, Models
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
|
||||
class UserRouter:
|
||||
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
||||
def get_user_router(
|
||||
user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
|
||||
) -> APIRouter:
|
||||
router = APIRouter()
|
||||
models = Models(user_model)
|
||||
|
||||
@router.post("/register", response_model=User)
|
||||
async def register(user: UserCreate):
|
||||
@router.post("/register", response_model=models.User)
|
||||
async def register(user: models.UserCreate): # type: ignore
|
||||
hashed_password = get_password_hash(user.password)
|
||||
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
|
||||
db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
|
||||
created_user = await user_db.create(db_user)
|
||||
return created_user
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
@ -5,16 +7,16 @@ from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
active_user_data = UserDB(
|
||||
active_user_data = BaseUserDB(
|
||||
id="aaa",
|
||||
email="king.arthur@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
|
||||
inactive_user_data = UserDB(
|
||||
inactive_user_data = BaseUserDB(
|
||||
id="bbb",
|
||||
email="percival@camelot.bt",
|
||||
hashed_password=get_password_hash("angharad"),
|
||||
@ -23,31 +25,31 @@ inactive_user_data = UserDB(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user() -> UserDB:
|
||||
def user() -> BaseUserDB:
|
||||
return active_user_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inactive_user() -> UserDB:
|
||||
def inactive_user() -> BaseUserDB:
|
||||
return inactive_user_data
|
||||
|
||||
|
||||
class MockUserDatabase(BaseUserDatabase):
|
||||
async def get(self, id: str) -> UserDB:
|
||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||
if id == active_user_data.id:
|
||||
return active_user_data
|
||||
elif id == inactive_user_data.id:
|
||||
return inactive_user_data
|
||||
return None
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||
if email == active_user_data.email:
|
||||
return active_user_data
|
||||
elif email == inactive_user_data.email:
|
||||
return inactive_user_data
|
||||
return None
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
return user
|
||||
|
||||
|
||||
@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase:
|
||||
|
||||
|
||||
class MockAuthentication(BaseAuthentication):
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
async def authenticate(self, token: str) -> UserDB:
|
||||
async def authenticate(self, token: str) -> BaseUserDB:
|
||||
user = await self.user_db.get(token)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
@ -6,7 +6,7 @@ from starlette.responses import Response
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
ALGORITHM = "HS256"
|
||||
@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication):
|
||||
|
||||
@app.get("/test-auth")
|
||||
def test_auth(
|
||||
user: UserDB = Depends(jwt_authentication.get_authentication_method())
|
||||
user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
|
||||
):
|
||||
return user
|
||||
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
import sqlite3
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from databases import Database
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
||||
|
||||
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
|
||||
from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
||||
Base = declarative_base()
|
||||
async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
|
||||
Base: DeclarativeMeta = declarative_base()
|
||||
|
||||
class User(BaseUser, Base):
|
||||
class User(BaseUserTable, Base):
|
||||
pass
|
||||
|
||||
DATABASE_URL = "sqlite:///./test.db"
|
||||
|
||||
@ -3,13 +3,16 @@ from fastapi import FastAPI
|
||||
from starlette import status
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.router import UserRouter
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
from fastapi_users.router import get_user_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
|
||||
userRouter = UserRouter(mock_user_db, mock_authentication)
|
||||
class User(BaseUser):
|
||||
pass
|
||||
|
||||
userRouter = get_user_router(mock_user_db, User, mock_authentication)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(userRouter)
|
||||
@ -68,7 +71,7 @@ class TestLogin:
|
||||
response = test_app_client.post("/login", data=data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
|
||||
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
|
||||
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||
response = test_app_client.post("/login", data=data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
Reference in New Issue
Block a user