Add compatibility layer for Pydantic V2

This commit is contained in:
François Voron
2023-07-12 10:44:22 +02:00
parent d2a633d2f5
commit e17bb609ae
9 changed files with 68 additions and 37 deletions

View File

@ -8,6 +8,7 @@ from fastapi_users.authentication.transport.base import (
TransportLogoutNotSupportedError,
)
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.schemas import model_dump
class BearerResponse(BaseModel):
@ -23,7 +24,7 @@ class BearerTransport(Transport):
async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
return JSONResponse(bearer_response.dict())
return JSONResponse(model_dump(bearer_response))
async def get_logout_response(self) -> Response:
raise TransportLogoutNotSupportedError()

View File

@ -267,6 +267,6 @@ def get_oauth_associate_router(
request,
)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
return router

View File

@ -71,6 +71,6 @@ def get_register_router(
},
)
return user_schema.from_orm(created_user)
return schemas.model_validate(user_schema, created_user)
return router

View File

@ -48,7 +48,7 @@ def get_users_router(
async def me(
user: models.UP = Depends(get_current_active_user),
):
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
@router.patch(
"/me",
@ -96,7 +96,7 @@ def get_users_router(
user = await user_manager.update(
user_update, user, safe=True, request=request
)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -129,7 +129,7 @@ def get_users_router(
},
)
async def get_user(user=Depends(get_user_or_404)):
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
@router.patch(
"/{id}",
@ -183,7 +183,7 @@ def get_users_router(
user = await user_manager.update(
user_update, user, safe=False, request=request
)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -70,7 +70,7 @@ def get_verify_router(
):
try:
user = await user_manager.verify(token, request)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except (exceptions.InvalidVerifyToken, exceptions.UserNotExists):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -1,13 +1,35 @@
from typing import Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, ConfigDict, EmailStr
from pydantic.version import VERSION as PYDANTIC_VERSION
from fastapi_users import models
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
SCHEMA = TypeVar("SCHEMA", bound=BaseModel)
if PYDANTIC_V2:
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.model_dump(*args, **kwargs)
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.model_validate(obj, *args, **kwargs)
else:
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.dict(*args, **kwargs)
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.from_orm(obj)
class CreateUpdateDictModel(BaseModel):
def create_update_dict(self):
return self.dict(
return model_dump(
self,
exclude_unset=True,
exclude={
"id",
@ -19,10 +41,10 @@ class CreateUpdateDictModel(BaseModel):
)
def create_update_dict_superuser(self):
return self.dict(exclude_unset=True, exclude={"id"})
return model_dump(self, exclude_unset=True, exclude={"id"})
class BaseUser(Generic[models.ID], CreateUpdateDictModel):
class BaseUser(CreateUpdateDictModel, Generic[models.ID]):
"""Base User model."""
id: models.ID
@ -31,6 +53,10 @@ class BaseUser(Generic[models.ID], CreateUpdateDictModel):
is_superuser: bool = False
is_verified: bool = False
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True
@ -44,11 +70,11 @@ class BaseUserCreate(CreateUpdateDictModel):
class BaseUserUpdate(CreateUpdateDictModel):
password: Optional[str]
email: Optional[EmailStr]
is_active: Optional[bool]
is_superuser: Optional[bool]
is_verified: Optional[bool]
password: Optional[str] = None
email: Optional[EmailStr] = None
is_active: Optional[bool] = None
is_superuser: Optional[bool] = None
is_verified: Optional[bool] = None
U = TypeVar("U", bound=BaseUser)
@ -56,7 +82,7 @@ UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)
class BaseOAuthAccount(Generic[models.ID], BaseModel):
class BaseOAuthAccount(BaseModel, Generic[models.ID]):
"""Base OAuth account model."""
id: models.ID
@ -67,6 +93,10 @@ class BaseOAuthAccount(Generic[models.ID], BaseModel):
account_id: str
account_email: str
if PYDANTIC_V2:
model_config = ConfigDict(from_attributes=True)
else:
class Config:
orm_mode = True

View File

@ -39,14 +39,14 @@ lancelot_password_hash = password_helper.hash("lancelot")
excalibur_password_hash = password_helper.hash("excalibur")
IDType = uuid.UUID
IDType = UUID4
@dataclasses.dataclass
class UserModel(models.UserProtocol[IDType]):
email: str
hashed_password: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
@ -59,7 +59,7 @@ class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
access_token: str
account_id: str
account_email: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
expires_at: Optional[int] = None
refresh_token: Optional[str] = None
@ -70,15 +70,15 @@ class UserOAuthModel(UserModel):
class User(schemas.BaseUser[IDType]):
first_name: Optional[str]
first_name: Optional[str] = None
class UserCreate(schemas.BaseUserCreate):
first_name: Optional[str]
first_name: Optional[str] = None
class UserUpdate(schemas.BaseUserUpdate):
first_name: Optional[str]
first_name: Optional[str] = None
class UserOAuth(User, schemas.BaseOAuthAccountMixin):

View File

@ -4,7 +4,7 @@ import httpx
import pytest
from fastapi import Depends, FastAPI, status
from fastapi_users import FastAPIUsers
from fastapi_users import FastAPIUsers, schemas
from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate
@ -77,7 +77,7 @@ async def test_app_client(
def optional_current_user(
user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-active-user")
def optional_current_active_user(
@ -85,7 +85,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, active=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-verified-user")
def optional_current_verified_user(
@ -93,7 +93,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, verified=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-superuser")
def optional_current_superuser(
@ -101,7 +101,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, active=True, superuser=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-verified-superuser")
def optional_current_verified_superuser(
@ -111,7 +111,7 @@ async def test_app_client(
)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None
async for client in get_test_client(app):
yield client

View File

@ -1,8 +1,8 @@
import uuid
from typing import Callable
import pytest
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4
from pytest_mock import MockerFixture
from fastapi_users.exceptions import (
@ -77,7 +77,7 @@ def create_oauth2_password_request_form() -> (
class TestGet:
async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]):
with pytest.raises(UserNotExists):
await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
await user_manager.get(uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
async def test_existing_user(
self, user_manager: UserManagerMock[UserModel], user: UserModel