mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-04 14:45:50 +08:00
Add compatibility layer for Pydantic V2
This commit is contained in:
@ -8,6 +8,7 @@ from fastapi_users.authentication.transport.base import (
|
|||||||
TransportLogoutNotSupportedError,
|
TransportLogoutNotSupportedError,
|
||||||
)
|
)
|
||||||
from fastapi_users.openapi import OpenAPIResponseType
|
from fastapi_users.openapi import OpenAPIResponseType
|
||||||
|
from fastapi_users.schemas import model_dump
|
||||||
|
|
||||||
|
|
||||||
class BearerResponse(BaseModel):
|
class BearerResponse(BaseModel):
|
||||||
@ -23,7 +24,7 @@ class BearerTransport(Transport):
|
|||||||
|
|
||||||
async def get_login_response(self, token: str) -> Response:
|
async def get_login_response(self, token: str) -> Response:
|
||||||
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
||||||
return JSONResponse(bearer_response.dict())
|
return JSONResponse(model_dump(bearer_response))
|
||||||
|
|
||||||
async def get_logout_response(self) -> Response:
|
async def get_logout_response(self) -> Response:
|
||||||
raise TransportLogoutNotSupportedError()
|
raise TransportLogoutNotSupportedError()
|
||||||
|
|||||||
@ -267,6 +267,6 @@ def get_oauth_associate_router(
|
|||||||
request,
|
request,
|
||||||
)
|
)
|
||||||
|
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@ -71,6 +71,6 @@ def get_register_router(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return user_schema.from_orm(created_user)
|
return schemas.model_validate(user_schema, created_user)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def get_users_router(
|
|||||||
async def me(
|
async def me(
|
||||||
user: models.UP = Depends(get_current_active_user),
|
user: models.UP = Depends(get_current_active_user),
|
||||||
):
|
):
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/me",
|
"/me",
|
||||||
@ -96,7 +96,7 @@ def get_users_router(
|
|||||||
user = await user_manager.update(
|
user = await user_manager.update(
|
||||||
user_update, user, safe=True, request=request
|
user_update, user, safe=True, request=request
|
||||||
)
|
)
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
except exceptions.InvalidPasswordException as e:
|
except exceptions.InvalidPasswordException as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -129,7 +129,7 @@ def get_users_router(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_user(user=Depends(get_user_or_404)):
|
async def get_user(user=Depends(get_user_or_404)):
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/{id}",
|
"/{id}",
|
||||||
@ -183,7 +183,7 @@ def get_users_router(
|
|||||||
user = await user_manager.update(
|
user = await user_manager.update(
|
||||||
user_update, user, safe=False, request=request
|
user_update, user, safe=False, request=request
|
||||||
)
|
)
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
except exceptions.InvalidPasswordException as e:
|
except exceptions.InvalidPasswordException as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|||||||
@ -70,7 +70,7 @@ def get_verify_router(
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
user = await user_manager.verify(token, request)
|
user = await user_manager.verify(token, request)
|
||||||
return user_schema.from_orm(user)
|
return schemas.model_validate(user_schema, user)
|
||||||
except (exceptions.InvalidVerifyToken, exceptions.UserNotExists):
|
except (exceptions.InvalidVerifyToken, exceptions.UserNotExists):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
|||||||
@ -1,13 +1,35 @@
|
|||||||
from typing import Generic, List, Optional, TypeVar
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, ConfigDict, EmailStr
|
||||||
|
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||||
|
|
||||||
from fastapi_users import models
|
from fastapi_users import models
|
||||||
|
|
||||||
|
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||||
|
|
||||||
|
SCHEMA = TypeVar("SCHEMA", bound=BaseModel)
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
|
||||||
|
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
|
||||||
|
return model.model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
|
||||||
|
return schema.model_validate(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
|
||||||
|
return model.dict(*args, **kwargs)
|
||||||
|
|
||||||
|
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
|
||||||
|
return schema.from_orm(obj)
|
||||||
|
|
||||||
|
|
||||||
class CreateUpdateDictModel(BaseModel):
|
class CreateUpdateDictModel(BaseModel):
|
||||||
def create_update_dict(self):
|
def create_update_dict(self):
|
||||||
return self.dict(
|
return model_dump(
|
||||||
|
self,
|
||||||
exclude_unset=True,
|
exclude_unset=True,
|
||||||
exclude={
|
exclude={
|
||||||
"id",
|
"id",
|
||||||
@ -19,10 +41,10 @@ class CreateUpdateDictModel(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_update_dict_superuser(self):
|
def create_update_dict_superuser(self):
|
||||||
return self.dict(exclude_unset=True, exclude={"id"})
|
return model_dump(self, exclude_unset=True, exclude={"id"})
|
||||||
|
|
||||||
|
|
||||||
class BaseUser(Generic[models.ID], CreateUpdateDictModel):
|
class BaseUser(CreateUpdateDictModel, Generic[models.ID]):
|
||||||
"""Base User model."""
|
"""Base User model."""
|
||||||
|
|
||||||
id: models.ID
|
id: models.ID
|
||||||
@ -31,6 +53,10 @@ class BaseUser(Generic[models.ID], CreateUpdateDictModel):
|
|||||||
is_superuser: bool = False
|
is_superuser: bool = False
|
||||||
is_verified: bool = False
|
is_verified: bool = False
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
else:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
@ -44,11 +70,11 @@ class BaseUserCreate(CreateUpdateDictModel):
|
|||||||
|
|
||||||
|
|
||||||
class BaseUserUpdate(CreateUpdateDictModel):
|
class BaseUserUpdate(CreateUpdateDictModel):
|
||||||
password: Optional[str]
|
password: Optional[str] = None
|
||||||
email: Optional[EmailStr]
|
email: Optional[EmailStr] = None
|
||||||
is_active: Optional[bool]
|
is_active: Optional[bool] = None
|
||||||
is_superuser: Optional[bool]
|
is_superuser: Optional[bool] = None
|
||||||
is_verified: Optional[bool]
|
is_verified: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
U = TypeVar("U", bound=BaseUser)
|
U = TypeVar("U", bound=BaseUser)
|
||||||
@ -56,7 +82,7 @@ UC = TypeVar("UC", bound=BaseUserCreate)
|
|||||||
UU = TypeVar("UU", bound=BaseUserUpdate)
|
UU = TypeVar("UU", bound=BaseUserUpdate)
|
||||||
|
|
||||||
|
|
||||||
class BaseOAuthAccount(Generic[models.ID], BaseModel):
|
class BaseOAuthAccount(BaseModel, Generic[models.ID]):
|
||||||
"""Base OAuth account model."""
|
"""Base OAuth account model."""
|
||||||
|
|
||||||
id: models.ID
|
id: models.ID
|
||||||
@ -67,6 +93,10 @@ class BaseOAuthAccount(Generic[models.ID], BaseModel):
|
|||||||
account_id: str
|
account_id: str
|
||||||
account_email: str
|
account_email: str
|
||||||
|
|
||||||
|
if PYDANTIC_V2:
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
else:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|||||||
@ -39,14 +39,14 @@ lancelot_password_hash = password_helper.hash("lancelot")
|
|||||||
excalibur_password_hash = password_helper.hash("excalibur")
|
excalibur_password_hash = password_helper.hash("excalibur")
|
||||||
|
|
||||||
|
|
||||||
IDType = uuid.UUID
|
IDType = UUID4
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class UserModel(models.UserProtocol[IDType]):
|
class UserModel(models.UserProtocol[IDType]):
|
||||||
email: str
|
email: str
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
|
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
is_superuser: bool = False
|
is_superuser: bool = False
|
||||||
is_verified: bool = False
|
is_verified: bool = False
|
||||||
@ -59,7 +59,7 @@ class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
|
|||||||
access_token: str
|
access_token: str
|
||||||
account_id: str
|
account_id: str
|
||||||
account_email: str
|
account_email: str
|
||||||
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
|
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
|
||||||
expires_at: Optional[int] = None
|
expires_at: Optional[int] = None
|
||||||
refresh_token: Optional[str] = None
|
refresh_token: Optional[str] = None
|
||||||
|
|
||||||
@ -70,15 +70,15 @@ class UserOAuthModel(UserModel):
|
|||||||
|
|
||||||
|
|
||||||
class User(schemas.BaseUser[IDType]):
|
class User(schemas.BaseUser[IDType]):
|
||||||
first_name: Optional[str]
|
first_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(schemas.BaseUserCreate):
|
class UserCreate(schemas.BaseUserCreate):
|
||||||
first_name: Optional[str]
|
first_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(schemas.BaseUserUpdate):
|
class UserUpdate(schemas.BaseUserUpdate):
|
||||||
first_name: Optional[str]
|
first_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UserOAuth(User, schemas.BaseOAuthAccountMixin):
|
class UserOAuth(User, schemas.BaseOAuthAccountMixin):
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import httpx
|
|||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, FastAPI, status
|
from fastapi import Depends, FastAPI, status
|
||||||
|
|
||||||
from fastapi_users import FastAPIUsers
|
from fastapi_users import FastAPIUsers, schemas
|
||||||
from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate
|
from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ async def test_app_client(
|
|||||||
def optional_current_user(
|
def optional_current_user(
|
||||||
user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)),
|
user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)),
|
||||||
):
|
):
|
||||||
return User.from_orm(user) if user else None
|
return schemas.model_validate(User, user) if user else None
|
||||||
|
|
||||||
@app.get("/optional-current-active-user")
|
@app.get("/optional-current-active-user")
|
||||||
def optional_current_active_user(
|
def optional_current_active_user(
|
||||||
@ -85,7 +85,7 @@ async def test_app_client(
|
|||||||
fastapi_users.current_user(optional=True, active=True)
|
fastapi_users.current_user(optional=True, active=True)
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return User.from_orm(user) if user else None
|
return schemas.model_validate(User, user) if user else None
|
||||||
|
|
||||||
@app.get("/optional-current-verified-user")
|
@app.get("/optional-current-verified-user")
|
||||||
def optional_current_verified_user(
|
def optional_current_verified_user(
|
||||||
@ -93,7 +93,7 @@ async def test_app_client(
|
|||||||
fastapi_users.current_user(optional=True, verified=True)
|
fastapi_users.current_user(optional=True, verified=True)
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return User.from_orm(user) if user else None
|
return schemas.model_validate(User, user) if user else None
|
||||||
|
|
||||||
@app.get("/optional-current-superuser")
|
@app.get("/optional-current-superuser")
|
||||||
def optional_current_superuser(
|
def optional_current_superuser(
|
||||||
@ -101,7 +101,7 @@ async def test_app_client(
|
|||||||
fastapi_users.current_user(optional=True, active=True, superuser=True)
|
fastapi_users.current_user(optional=True, active=True, superuser=True)
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return User.from_orm(user) if user else None
|
return schemas.model_validate(User, user) if user else None
|
||||||
|
|
||||||
@app.get("/optional-current-verified-superuser")
|
@app.get("/optional-current-verified-superuser")
|
||||||
def optional_current_verified_superuser(
|
def optional_current_verified_superuser(
|
||||||
@ -111,7 +111,7 @@ async def test_app_client(
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return User.from_orm(user) if user else None
|
return schemas.model_validate(User, user) if user else None
|
||||||
|
|
||||||
async for client in get_test_client(app):
|
async for client in get_test_client(app):
|
||||||
yield client
|
yield client
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
|
import uuid
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from pydantic import UUID4
|
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from fastapi_users.exceptions import (
|
from fastapi_users.exceptions import (
|
||||||
@ -77,7 +77,7 @@ def create_oauth2_password_request_form() -> (
|
|||||||
class TestGet:
|
class TestGet:
|
||||||
async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]):
|
async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]):
|
||||||
with pytest.raises(UserNotExists):
|
with pytest.raises(UserNotExists):
|
||||||
await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
|
await user_manager.get(uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
|
||||||
|
|
||||||
async def test_existing_user(
|
async def test_existing_user(
|
||||||
self, user_manager: UserManagerMock[UserModel], user: UserModel
|
self, user_manager: UserManagerMock[UserModel], user: UserModel
|
||||||
|
|||||||
Reference in New Issue
Block a user