Merge pull request #1249 from fastapi-users/pydantic-v2

Pydantic V2 support
This commit is contained in:
François Voron
2023-07-12 11:02:44 +02:00
committed by GitHub
11 changed files with 105 additions and 45 deletions

View File

@ -4,6 +4,26 @@ on: [push, pull_request]
jobs: jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python_version: [3.8, 3.9, '3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install hatch
- name: Lint and typecheck
run: |
hatch run lint-check
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
@ -20,13 +40,9 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install hatch pip install hatch
hatch env create
- name: Lint and typecheck
run: |
hatch run lint-check
- name: Test - name: Test
run: | run: |
hatch run test-cov-xml hatch run test:test-cov-xml
- uses: codecov/codecov-action@v3 - uses: codecov/codecov-action@v3
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
@ -40,7 +56,7 @@ jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: test needs: [lint, test]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:

View File

@ -8,6 +8,7 @@ from fastapi_users.authentication.transport.base import (
TransportLogoutNotSupportedError, TransportLogoutNotSupportedError,
) )
from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.schemas import model_dump
class BearerResponse(BaseModel): class BearerResponse(BaseModel):
@ -23,7 +24,7 @@ class BearerTransport(Transport):
async def get_login_response(self, token: str) -> Response: async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer") bearer_response = BearerResponse(access_token=token, token_type="bearer")
return JSONResponse(bearer_response.dict()) return JSONResponse(model_dump(bearer_response))
async def get_logout_response(self) -> Response: async def get_logout_response(self) -> Response:
raise TransportLogoutNotSupportedError() raise TransportLogoutNotSupportedError()

View File

@ -267,6 +267,6 @@ def get_oauth_associate_router(
request, request,
) )
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
return router return router

View File

@ -71,6 +71,6 @@ def get_register_router(
}, },
) )
return user_schema.from_orm(created_user) return schemas.model_validate(user_schema, created_user)
return router return router

View File

@ -48,7 +48,7 @@ def get_users_router(
async def me( async def me(
user: models.UP = Depends(get_current_active_user), user: models.UP = Depends(get_current_active_user),
): ):
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
@router.patch( @router.patch(
"/me", "/me",
@ -96,7 +96,7 @@ def get_users_router(
user = await user_manager.update( user = await user_manager.update(
user_update, user, safe=True, request=request user_update, user, safe=True, request=request
) )
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e: except exceptions.InvalidPasswordException as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -129,7 +129,7 @@ def get_users_router(
}, },
) )
async def get_user(user=Depends(get_user_or_404)): async def get_user(user=Depends(get_user_or_404)):
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
@router.patch( @router.patch(
"/{id}", "/{id}",
@ -183,7 +183,7 @@ def get_users_router(
user = await user_manager.update( user = await user_manager.update(
user_update, user, safe=False, request=request user_update, user, safe=False, request=request
) )
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e: except exceptions.InvalidPasswordException as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -70,7 +70,7 @@ def get_verify_router(
): ):
try: try:
user = await user_manager.verify(token, request) user = await user_manager.verify(token, request)
return user_schema.from_orm(user) return schemas.model_validate(user_schema, user)
except (exceptions.InvalidVerifyToken, exceptions.UserNotExists): except (exceptions.InvalidVerifyToken, exceptions.UserNotExists):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -1,13 +1,35 @@
from typing import Generic, List, Optional, TypeVar from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, ConfigDict, EmailStr
from pydantic.version import VERSION as PYDANTIC_VERSION
from fastapi_users import models from fastapi_users import models
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
SCHEMA = TypeVar("SCHEMA", bound=BaseModel)
if PYDANTIC_V2: # pragma: no cover
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.model_dump(*args, **kwargs) # type: ignore
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.model_validate(obj, *args, **kwargs) # type: ignore
else: # pragma: no cover # type: ignore
def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.dict(*args, **kwargs) # type: ignore
def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.from_orm(obj) # type: ignore
class CreateUpdateDictModel(BaseModel): class CreateUpdateDictModel(BaseModel):
def create_update_dict(self): def create_update_dict(self):
return self.dict( return model_dump(
self,
exclude_unset=True, exclude_unset=True,
exclude={ exclude={
"id", "id",
@ -19,10 +41,10 @@ class CreateUpdateDictModel(BaseModel):
) )
def create_update_dict_superuser(self): def create_update_dict_superuser(self):
return self.dict(exclude_unset=True, exclude={"id"}) return model_dump(self, exclude_unset=True, exclude={"id"})
class BaseUser(Generic[models.ID], CreateUpdateDictModel): class BaseUser(CreateUpdateDictModel, Generic[models.ID]):
"""Base User model.""" """Base User model."""
id: models.ID id: models.ID
@ -31,8 +53,12 @@ class BaseUser(Generic[models.ID], CreateUpdateDictModel):
is_superuser: bool = False is_superuser: bool = False
is_verified: bool = False is_verified: bool = False
class Config: if PYDANTIC_V2: # pragma: no cover
orm_mode = True model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover
class Config:
orm_mode = True
class BaseUserCreate(CreateUpdateDictModel): class BaseUserCreate(CreateUpdateDictModel):
@ -44,11 +70,11 @@ class BaseUserCreate(CreateUpdateDictModel):
class BaseUserUpdate(CreateUpdateDictModel): class BaseUserUpdate(CreateUpdateDictModel):
password: Optional[str] password: Optional[str] = None
email: Optional[EmailStr] email: Optional[EmailStr] = None
is_active: Optional[bool] is_active: Optional[bool] = None
is_superuser: Optional[bool] is_superuser: Optional[bool] = None
is_verified: Optional[bool] is_verified: Optional[bool] = None
U = TypeVar("U", bound=BaseUser) U = TypeVar("U", bound=BaseUser)
@ -56,7 +82,7 @@ UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate) UU = TypeVar("UU", bound=BaseUserUpdate)
class BaseOAuthAccount(Generic[models.ID], BaseModel): class BaseOAuthAccount(BaseModel, Generic[models.ID]):
"""Base OAuth account model.""" """Base OAuth account model."""
id: models.ID id: models.ID
@ -67,8 +93,12 @@ class BaseOAuthAccount(Generic[models.ID], BaseModel):
account_id: str account_id: str
account_email: str account_email: str
class Config: if PYDANTIC_V2: # pragma: no cover
orm_mode = True model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover
class Config:
orm_mode = True
class BaseOAuthAccountMixin(BaseModel): class BaseOAuthAccountMixin(BaseModel):

View File

@ -74,8 +74,6 @@ dependencies = [
] ]
[tool.hatch.envs.default.scripts] [tool.hatch.envs.default.scripts]
test = "pytest --cov=fastapi_users/ --cov-report=term-missing --cov-fail-under=100"
test-cov-xml = "pytest --cov=fastapi_users/ --cov-report=xml --cov-fail-under=100"
lint = [ lint = [
"isort ./fastapi_users ./tests", "isort ./fastapi_users ./tests",
"isort ./docs/src -o fastapi_users", "isort ./docs/src -o fastapi_users",
@ -94,6 +92,21 @@ lint-check = [
] ]
docs = "mkdocs serve" docs = "mkdocs serve"
[tool.hatch.envs.test]
[tool.hatch.envs.test.scripts]
test = "pytest --cov=fastapi_users/ --cov-report=term-missing --cov-fail-under=100"
test-cov-xml = "pytest --cov=fastapi_users/ --cov-report=xml --cov-fail-under=100"
[[tool.hatch.envs.test.matrix]]
pydantic = ["v1", "v2"]
[tool.hatch.envs.test.overrides]
matrix.pydantic.extra-dependencies = [
{value = "pydantic<2.0", if = ["v1"]},
{value = "pydantic>=2.0", if = ["v2"]},
]
[tool.hatch.build.targets.sdist] [tool.hatch.build.targets.sdist]
support-legacy = true # Create setup.py support-legacy = true # Create setup.py

View File

@ -39,14 +39,14 @@ lancelot_password_hash = password_helper.hash("lancelot")
excalibur_password_hash = password_helper.hash("excalibur") excalibur_password_hash = password_helper.hash("excalibur")
IDType = uuid.UUID IDType = UUID4
@dataclasses.dataclass @dataclasses.dataclass
class UserModel(models.UserProtocol[IDType]): class UserModel(models.UserProtocol[IDType]):
email: str email: str
hashed_password: str hashed_password: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) id: IDType = dataclasses.field(default_factory=uuid.uuid4)
is_active: bool = True is_active: bool = True
is_superuser: bool = False is_superuser: bool = False
is_verified: bool = False is_verified: bool = False
@ -59,7 +59,7 @@ class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
access_token: str access_token: str
account_id: str account_id: str
account_email: str account_email: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) id: IDType = dataclasses.field(default_factory=uuid.uuid4)
expires_at: Optional[int] = None expires_at: Optional[int] = None
refresh_token: Optional[str] = None refresh_token: Optional[str] = None
@ -70,15 +70,15 @@ class UserOAuthModel(UserModel):
class User(schemas.BaseUser[IDType]): class User(schemas.BaseUser[IDType]):
first_name: Optional[str] first_name: Optional[str] = None
class UserCreate(schemas.BaseUserCreate): class UserCreate(schemas.BaseUserCreate):
first_name: Optional[str] first_name: Optional[str] = None
class UserUpdate(schemas.BaseUserUpdate): class UserUpdate(schemas.BaseUserUpdate):
first_name: Optional[str] first_name: Optional[str] = None
class UserOAuth(User, schemas.BaseOAuthAccountMixin): class UserOAuth(User, schemas.BaseOAuthAccountMixin):

View File

@ -4,7 +4,7 @@ import httpx
import pytest import pytest
from fastapi import Depends, FastAPI, status from fastapi import Depends, FastAPI, status
from fastapi_users import FastAPIUsers from fastapi_users import FastAPIUsers, schemas
from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate
@ -77,7 +77,7 @@ async def test_app_client(
def optional_current_user( def optional_current_user(
user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)), user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)),
): ):
return User.from_orm(user) if user else None return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-active-user") @app.get("/optional-current-active-user")
def optional_current_active_user( def optional_current_active_user(
@ -85,7 +85,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, active=True) fastapi_users.current_user(optional=True, active=True)
), ),
): ):
return User.from_orm(user) if user else None return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-verified-user") @app.get("/optional-current-verified-user")
def optional_current_verified_user( def optional_current_verified_user(
@ -93,7 +93,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, verified=True) fastapi_users.current_user(optional=True, verified=True)
), ),
): ):
return User.from_orm(user) if user else None return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-superuser") @app.get("/optional-current-superuser")
def optional_current_superuser( def optional_current_superuser(
@ -101,7 +101,7 @@ async def test_app_client(
fastapi_users.current_user(optional=True, active=True, superuser=True) fastapi_users.current_user(optional=True, active=True, superuser=True)
), ),
): ):
return User.from_orm(user) if user else None return schemas.model_validate(User, user) if user else None
@app.get("/optional-current-verified-superuser") @app.get("/optional-current-verified-superuser")
def optional_current_verified_superuser( def optional_current_verified_superuser(
@ -111,7 +111,7 @@ async def test_app_client(
) )
), ),
): ):
return User.from_orm(user) if user else None return schemas.model_validate(User, user) if user else None
async for client in get_test_client(app): async for client in get_test_client(app):
yield client yield client

View File

@ -1,8 +1,8 @@
import uuid
from typing import Callable from typing import Callable
import pytest import pytest
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from fastapi_users.exceptions import ( from fastapi_users.exceptions import (
@ -77,7 +77,7 @@ def create_oauth2_password_request_form() -> (
class TestGet: class TestGet:
async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]): async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]):
with pytest.raises(UserNotExists): with pytest.raises(UserNotExists):
await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f")) await user_manager.get(uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
async def test_existing_user( async def test_existing_user(
self, user_manager: UserManagerMock[UserModel], user: UserModel self, user_manager: UserManagerMock[UserModel], user: UserModel