diff --git a/Makefile b/Makefile index dac4f4ee..9445ea77 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ format: isort-src isort-docs test: docker stop $(MONGODB_CONTAINER_NAME) || true docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mvertes/alpine-mongo - $(PIPENV_RUN) pytest --cov=fastapi_users/ + $(PIPENV_RUN) pytest --cov=fastapi_users/ --cov-report=term-missing docker stop $(MONGODB_CONTAINER_NAME) docs-serve: diff --git a/docs/configuration/databases/mongodb.md b/docs/configuration/databases/mongodb.md index eeecf4eb..16cddd67 100644 --- a/docs/configuration/databases/mongodb.md +++ b/docs/configuration/databases/mongodb.md @@ -12,11 +12,14 @@ Let's create a MongoDB connection and instantiate a collection. You can choose any name for the database and the collection. +!!! warning + You may have noticed the `uuidRepresentation` parameter. It controls how the UUID values will be encoded in the database. By default, it's set to `pythonLegacy` but new applications should consider setting this to `standard` for cross language compatibility. [Read more about this](https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient). + ## Create the database adapter The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. Create it like this. -```py hl_lines="32" +```py hl_lines="34" {!./src/db_mongodb.py!} ``` @@ -26,7 +29,7 @@ Notice that we pass a reference to your [`UserDB` model](../model.md). The database adapter will automatically create a [unique index](https://docs.mongodb.com/manual/core/index-unique/) on `id` and `email`. !!! warning - **FastAPI Users** will use its defined [`id` UUID-string](../model.md) as unique identifier for the user, rather than the builtin MongoDB `_id`. + **FastAPI Users** will use its defined [`id` UUID](../model.md) as unique identifier for the user, rather than the builtin MongoDB `_id`. ## Next steps diff --git a/docs/configuration/model.md b/docs/configuration/model.md index 1e9030c4..3ac102d7 100644 --- a/docs/configuration/model.md +++ b/docs/configuration/model.md @@ -2,7 +2,7 @@ **FastAPI Users** defines a minimal User model for authentication purposes. It is structured like this: -* `id` (`str`) – Unique identifier of the user. Default to a **UUID4**. +* `id` (`UUID4`) – Unique identifier of the user. Default to a **UUID4**. * `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator). * `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Default to `True`. * `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Default to `False`. diff --git a/docs/configuration/oauth.md b/docs/configuration/oauth.md index d4cd2c0b..33c7c8c7 100644 --- a/docs/configuration/oauth.md +++ b/docs/configuration/oauth.md @@ -56,7 +56,7 @@ class UserDB(User, models.BaseUserDB): Notice that we inherit from the `BaseOAuthAccountMixin`, which adds a `List` of `BaseOAuthAccount` objects. This object is structured like this: -* `id` (`str`) – Unique identifier of the user. Default to a **UUID4**. +* `id` (`UUID4`) – Unique identifier of the OAuth account information. Default to a **UUID4**. * `oauth_name` (`str`) – Name of the OAuth service. It corresponds to the `name` property of the OAuth client. * `access_token` (`str`) – Access token. * `expires_at` (`int`) - Timestamp at which the access token is expired. @@ -100,7 +100,7 @@ class OAuthAccount(TortoiseBaseOAuthAccountModel): ``` !!! warning - Note that you shouls define the foreign key yourself, so that you can point it the user model in your namespace. + Note that you should define the foreign key yourself, so that you can point it the user model in your namespace. Then, you should declare it on the database adapter: diff --git a/docs/migration/08_to_10.md b/docs/migration/08_to_10.md new file mode 100644 index 00000000..25706ac7 --- /dev/null +++ b/docs/migration/08_to_10.md @@ -0,0 +1,82 @@ +# 0.8.x ➡️ 1.0.x + +1.0 version introduces major breaking changes that need you to update some of your code and migrate your data. + +## Id. are UUID + +Users and OAuth accounts id. are now represented as real UUID objects instead of plain strings. This change was introduced to leverage efficient storage and indexing for DBMS that supports UUID (especially PostgreSQL and Mongo). + +### In Python code + +If you were doing comparison betwen a user id. and a string (in unit tests for example), you should now cast the id. to string: + +```py +# Before +assert "d35d213e-f3d8-4f08-954a-7e0d1bea286f" == user.id + +# Now +assert "d35d213e-f3d8-4f08-954a-7e0d1bea286f" == str(user.id) +``` + +If you were refering to user id. in your Pydantic models, the field should now be of `UUID4` type instead of `str`: + +```py +from pydantic import BaseModel, UUID4 + +# Before +class Model(BaseModel): + user_id: str + +# After +class Model(BaseModel): + user_id: UUID4 +``` + +### In database + +Id. were before stored as strings in the database. You should make a migration to convert string data to UUID data. + +!!! danger + Scripts below are provided as guidelines. Please **review them carefully**, **adapt them** and check that they are working on a test database before applying them to production. **BE CAREFUL. THEY CAN DESTROY YOUR DATA.**. + +#### PostgreSQL + +PostgreSQL supports UUID type. If not already, you should enable the `uuid-ossp` extension: + +```sql +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; +``` + +To convert the existing id. string column, we can: + +1. Create a new column with UUID type. +2. Fill it with the id. converted to UUID. +3. Drop the original id. column. +4. Make the new column a primary key and rename it. + +```sql +ALTER TABLE "user" ADD uuid_id UUID; +UPDATE "user" SET uuid_id = uuid(id); +ALTER TABLE "user" DROP id; +ALTER TABLE "user" ADD PRIMARY KEY (uuid_id); +ALTER TABLE "user" RENAME COLUMN uuid_id TO id; +``` + +#### MySQL + +MySQL doesn't support UUID type. We'll just convert the column to `CHAR(36)` type: + +```sql +ALTER TABLE "user" MODIFY id CHAR(36); +``` + +#### MongoDB + +For MongoDB, we can use a `forEach` iterator to convert the id. for each document: + +```js +db.getCollection('users').find().forEach(function(user) { + var uuid = UUID(user.id); + db.getCollection('users').update({_id: user._id}, [{$set: {id: uuid}}]); +}); +``` diff --git a/docs/src/db_mongodb.py b/docs/src/db_mongodb.py index 53adfe47..75227c5b 100644 --- a/docs/src/db_mongodb.py +++ b/docs/src/db_mongodb.py @@ -21,7 +21,9 @@ class UserDB(User, models.BaseUserDB): DATABASE_URL = "mongodb://localhost:27017" -client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL) +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) db = client["database_name"] collection = db["users"] diff --git a/docs/src/full_mongodb.py b/docs/src/full_mongodb.py index 9672a014..f3d1dc15 100644 --- a/docs/src/full_mongodb.py +++ b/docs/src/full_mongodb.py @@ -25,7 +25,9 @@ class UserDB(User, models.BaseUserDB): pass -client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL) +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) db = client["database_name"] collection = db["users"] user_db = MongoDBUserDatabase(UserDB, collection) diff --git a/docs/src/oauth_full_mongodb.py b/docs/src/oauth_full_mongodb.py index 5b0724ff..73b2a5a7 100644 --- a/docs/src/oauth_full_mongodb.py +++ b/docs/src/oauth_full_mongodb.py @@ -29,7 +29,9 @@ class UserDB(User, models.BaseUserDB): pass -client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL) +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) db = client["database_name"] collection = db["users"] user_db = MongoDBUserDatabase(UserDB, collection) diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 7d81d938..ba80229e 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -2,6 +2,7 @@ from typing import Any, Optional import jwt from fastapi.security import OAuth2PasswordBearer +from pydantic import UUID4 from starlette.requests import Request from starlette.responses import Response @@ -56,7 +57,12 @@ class JWTAuthentication(BaseAuthentication): return None except jwt.PyJWTError: return None - return await user_db.get(user_id) + + try: + user_uiid = UUID4(user_id) + return await user_db.get(user_uiid) + except ValueError: + return None async def get_login_response(self, user: BaseUserDB, response: Response) -> Any: token = await self._generate_token(user) @@ -66,5 +72,5 @@ class JWTAuthentication(BaseAuthentication): return await self.scheme.__call__(request) async def _generate_token(self, user: BaseUserDB) -> str: - data = {"user_id": user.id, "aud": self.token_audience} + data = {"user_id": str(user.id), "aud": self.token_audience} return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py index 3df7301e..5af58720 100644 --- a/fastapi_users/db/base.py +++ b/fastapi_users/db/base.py @@ -1,6 +1,7 @@ from typing import Generic, Optional, Type from fastapi.security import OAuth2PasswordRequestForm +from pydantic import UUID4 from fastapi_users import password from fastapi_users.models import UD @@ -18,7 +19,7 @@ class BaseUserDatabase(Generic[UD]): def __init__(self, user_db_model: Type[UD]): self.user_db_model = user_db_model - async def get(self, id: str) -> Optional[UD]: + async def get(self, id: UUID4) -> Optional[UD]: """Get a single user by id.""" raise NotImplementedError() diff --git a/fastapi_users/db/mongodb.py b/fastapi_users/db/mongodb.py index 1841484d..797a407a 100644 --- a/fastapi_users/db/mongodb.py +++ b/fastapi_users/db/mongodb.py @@ -1,6 +1,7 @@ from typing import Optional, Type from motor.motor_asyncio import AsyncIOMotorCollection +from pydantic import UUID4 from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import UD @@ -22,7 +23,7 @@ class MongoDBUserDatabase(BaseUserDatabase[UD]): self.collection.create_index("id", unique=True) self.collection.create_index("email", unique=True) - async def get(self, id: str) -> Optional[UD]: + async def get(self, id: UUID4) -> Optional[UD]: user = await self.collection.find_one({"id": id}) return self.user_db_model(**user) if user else None diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index 106c3117..b5a57b4a 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -1,19 +1,58 @@ +import uuid from typing import Mapping, Optional, Type from databases import Database +from pydantic import UUID4 from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.types import CHAR, TypeDecorator from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import UD +class GUID(TypeDecorator): # pragma: no cover + """Platform-independent GUID type. + + Uses PostgreSQL's UUID type, otherwise uses + CHAR(36), storing as regular strings. + """ + + impl = CHAR + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(36)) + + def process_bind_param(self, value, dialect): + if value is None: + return value + elif dialect.name == "postgresql": + return str(value) + else: + if not isinstance(value, uuid.UUID): + return str(uuid.UUID(value)) + else: + return str(value) + + def process_result_value(self, value, dialect): + if value is None: + return value + else: + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return value + + class SQLAlchemyBaseUserTable: """Base SQLAlchemy users table definition.""" __tablename__ = "user" - id = Column(String, primary_key=True) + id = Column(GUID, primary_key=True) email = Column(String, unique=True, index=True, nullable=False) hashed_password = Column(String, nullable=False) is_active = Column(Boolean, default=True, nullable=False) @@ -25,7 +64,7 @@ class SQLAlchemyBaseOAuthAccountTable: __tablename__ = "oauth_account" - id = Column(String, primary_key=True) + id = Column(GUID, primary_key=True) oauth_name = Column(String, index=True, nullable=False) access_token = Column(String, nullable=False) expires_at = Column(Integer, nullable=False) @@ -35,7 +74,7 @@ class SQLAlchemyBaseOAuthAccountTable: @declared_attr def user_id(cls): - return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False) + return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False) class NotSetOAuthAccountTableError(Exception): @@ -75,7 +114,7 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]): self.users = users self.oauth_accounts = oauth_accounts - async def get(self, id: str) -> Optional[UD]: + async def get(self, id: UUID4) -> Optional[UD]: query = self.users.select().where(self.users.c.id == id) user = await self.database.fetch_one(query) return await self._make_user(user) if user else None diff --git a/fastapi_users/db/tortoise.py b/fastapi_users/db/tortoise.py index 72a97bee..80f04f11 100644 --- a/fastapi_users/db/tortoise.py +++ b/fastapi_users/db/tortoise.py @@ -1,5 +1,6 @@ from typing import Optional, Type +from pydantic import UUID4 from tortoise import fields, models from tortoise.exceptions import DoesNotExist @@ -8,7 +9,7 @@ from fastapi_users.models import UD class TortoiseBaseUserModel(models.Model): - id = fields.CharField(pk=True, generated=False, max_length=255) + id = fields.UUIDField(pk=True, generated=False) email = fields.CharField(index=True, unique=True, null=False, max_length=255) hashed_password = fields.CharField(null=False, max_length=255) is_active = fields.BooleanField(default=True, null=False) @@ -27,7 +28,7 @@ class TortoiseBaseUserModel(models.Model): class TortoiseBaseOAuthAccountModel(models.Model): - id = fields.CharField(pk=True, generated=False, max_length=255) + id = fields.UUIDField(pk=True, generated=False, max_length=255) oauth_name = fields.CharField(null=False, max_length=255) access_token = fields.CharField(null=False, max_length=255) expires_at = fields.IntField(null=False) @@ -61,7 +62,7 @@ class TortoiseUserDatabase(BaseUserDatabase[UD]): self.model = model self.oauth_account_model = oauth_account_model - async def get(self, id: str) -> Optional[UD]: + async def get(self, id: UUID4) -> Optional[UD]: try: query = self.model.get(id=id) diff --git a/fastapi_users/models.py b/fastapi_users/models.py index c51c89a0..0c95ea1a 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -1,21 +1,20 @@ import uuid from typing import List, Optional, TypeVar -import pydantic -from pydantic import BaseModel, EmailStr +from pydantic import UUID4, BaseModel, EmailStr, validator class BaseUser(BaseModel): """Base User model.""" - id: Optional[str] = None + id: Optional[UUID4] = None email: Optional[EmailStr] = None is_active: Optional[bool] = True is_superuser: Optional[bool] = False - @pydantic.validator("id", pre=True, always=True) + @validator("id", pre=True, always=True) def default_id(cls, v): - return v or str(uuid.uuid4()) + return v or uuid.uuid4() def create_update_dict(self): return self.dict( @@ -37,7 +36,7 @@ class BaseUserUpdate(BaseUser): class BaseUserDB(BaseUser): - id: str + id: UUID4 hashed_password: str class Config: @@ -50,7 +49,7 @@ UD = TypeVar("UD", bound=BaseUserDB) class BaseOAuthAccount(BaseModel): """Base OAuth account model.""" - id: Optional[str] = None + id: Optional[UUID4] = None oauth_name: str access_token: str expires_at: int @@ -58,9 +57,9 @@ class BaseOAuthAccount(BaseModel): account_id: str account_email: str - @pydantic.validator("id", pre=True, always=True) + @validator("id", pre=True, always=True) def default_id(cls, v): - return v or str(uuid.uuid4()) + return v or uuid.uuid4() class Config: orm_mode = True diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index 01169e4e..4125c7b9 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Type, cast import jwt from fastapi import Body, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from pydantic import EmailStr +from pydantic import UUID4, EmailStr from starlette import status from starlette.requests import Request from starlette.responses import Response @@ -69,7 +69,7 @@ def get_user_router( get_current_active_user = authenticator.get_current_active_user get_current_superuser = authenticator.get_current_superuser - async def _get_or_404(id: str) -> models.BaseUserDB: + async def _get_or_404(id: UUID4) -> models.BaseUserDB: user = await user_db.get(id) if user is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) @@ -118,7 +118,7 @@ def get_user_router( user = await user_db.get_by_email(email) if user is not None and user.is_active: - token_data = {"user_id": user.id, "aud": reset_password_token_audience} + token_data = {"user_id": str(user.id), "aud": reset_password_token_audience} token = generate_jwt( token_data, reset_password_token_lifetime_seconds, @@ -146,7 +146,15 @@ def get_user_router( detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN, ) - user = await user_db.get(user_id) + try: + user_uiid = UUID4(user_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN, + ) + + user = await user_db.get(user_uiid) if user is None or not user.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -190,7 +198,7 @@ def get_user_router( response_model=user_model, dependencies=[Depends(get_current_superuser)], ) - async def get_user(id: str,): + async def get_user(id: UUID4): return await _get_or_404(id) @router.patch( @@ -199,7 +207,7 @@ def get_user_router( dependencies=[Depends(get_current_superuser)], ) async def update_user( - id: str, updated_user: user_update_model, # type: ignore + id: UUID4, updated_user: user_update_model, # type: ignore ): updated_user = cast( models.BaseUserUpdate, updated_user, @@ -213,7 +221,7 @@ def get_user_router( status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_superuser)], ) - async def delete_user(id: str): + async def delete_user(id: UUID4): user = await _get_or_404(id) await user_db.delete(user) return None diff --git a/mkdocs.yml b/mkdocs.yml index 32ce019d..30df3b7c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -45,3 +45,5 @@ nav: - usage/flow.md - usage/routes.md - usage/dependency-callables.md + - Migration: + - migration/08_to_10.md diff --git a/tests/conftest.py b/tests/conftest.py index bdef51e9..b52f07a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from asgi_lifespan import LifespanManager from fastapi import Depends, FastAPI from fastapi.security import OAuth2PasswordBearer from httpx_oauth.oauth2 import OAuth2 +from pydantic import UUID4 from starlette.applications import ASGIApp from starlette.requests import Request from starlette.responses import Response @@ -58,16 +59,13 @@ def event_loop(): @pytest.fixture def user() -> UserDB: return UserDB( - id="aaa", - email="king.arthur@camelot.bt", - hashed_password=guinevere_password_hash, + email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, ) @pytest.fixture def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: return UserDBOAuth( - id="aaa", email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, oauth_accounts=[oauth_account1, oauth_account2], @@ -77,7 +75,6 @@ def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: @pytest.fixture def inactive_user() -> UserDB: return UserDB( - id="bbb", email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -87,7 +84,6 @@ def inactive_user() -> UserDB: @pytest.fixture def inactive_user_oauth(oauth_account3) -> UserDBOAuth: return UserDBOAuth( - id="bbb", email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -98,7 +94,6 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth: @pytest.fixture def superuser() -> UserDB: return UserDB( - id="ccc", email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -108,7 +103,6 @@ def superuser() -> UserDB: @pytest.fixture def superuser_oauth() -> UserDBOAuth: return UserDBOAuth( - id="ccc", email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -119,7 +113,6 @@ def superuser_oauth() -> UserDBOAuth: @pytest.fixture def oauth_account1() -> BaseOAuthAccount: return BaseOAuthAccount( - id="aaa", oauth_name="service1", access_token="TOKEN", expires_at=1579000751, @@ -131,7 +124,6 @@ def oauth_account1() -> BaseOAuthAccount: @pytest.fixture def oauth_account2() -> BaseOAuthAccount: return BaseOAuthAccount( - id="bbb", oauth_name="service2", access_token="TOKEN", expires_at=1579000751, @@ -143,7 +135,6 @@ def oauth_account2() -> BaseOAuthAccount: @pytest.fixture def oauth_account3() -> BaseOAuthAccount: return BaseOAuthAccount( - id="ccc", oauth_name="service3", access_token="TOKEN", expires_at=1579000751, @@ -155,7 +146,7 @@ def oauth_account3() -> BaseOAuthAccount: @pytest.fixture def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase: class MockUserDatabase(BaseUserDatabase[UserDB]): - async def get(self, id: str) -> Optional[UserDB]: + async def get(self, id: UUID4) -> Optional[UserDB]: if id == user.id: return user if id == inactive_user.id: @@ -190,7 +181,7 @@ def mock_user_db_oauth( user_oauth, inactive_user_oauth, superuser_oauth ) -> BaseUserDatabase: class MockUserDatabase(BaseUserDatabase[UserDBOAuth]): - async def get(self, id: str) -> Optional[UserDBOAuth]: + async def get(self, id: UUID4) -> Optional[UserDBOAuth]: if id == user_oauth.id: return user_oauth if id == inactive_user_oauth.id: @@ -246,7 +237,11 @@ class MockAuthentication(BaseAuthentication): async def __call__(self, request: Request, user_db: BaseUserDatabase): token = await self.scheme.__call__(request) if token is not None: - return await user_db.get(token) + try: + token_uuid = UUID4(token) + return await user_db.get(token_uuid) + except ValueError: + return None return None async def get_login_response(self, user: BaseUserDB, response: Response): diff --git a/tests/test_authentication_cookie.py b/tests/test_authentication_cookie.py index 4a6866de..981caeca 100644 --- a/tests/test_authentication_cookie.py +++ b/tests/test_authentication_cookie.py @@ -31,7 +31,7 @@ def token(): def _token(user=None, lifetime=LIFETIME): data = {"aud": "fastapi-users:auth"} if user is not None: - data["user_id"] = user.id + data["user_id"] = str(user.id) return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _token @@ -131,7 +131,7 @@ async def test_get_login_response( decoded = jwt.decode( cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM] ) - assert decoded["user_id"] == user.id + assert decoded["user_id"] == str(user.id) @pytest.mark.authentication @@ -149,4 +149,4 @@ async def test_get_logout_response(user): cookie = cookies[0][1].decode("latin-1") - assert f"Max-Age=0" in cookie + assert "Max-Age=0" in cookie diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index bd6845f1..d6cf3f27 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -17,10 +17,10 @@ def jwt_authentication(): @pytest.fixture def token(): - def _token(user=None, lifetime=LIFETIME): + def _token(user_id=None, lifetime=LIFETIME): data = {"aud": "fastapi-users:auth"} - if user is not None: - data["user_id"] = user.id + if user_id is not None: + data["user_id"] = str(user_id) return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _token @@ -57,11 +57,19 @@ class TestAuthenticate: authenticated_user = await jwt_authentication(request, mock_user_db) assert authenticated_user is None + @pytest.mark.asyncio + async def test_valid_token_invalid_uuid( + self, jwt_authentication, mock_user_db, request_builder, token + ): + request = request_builder(headers={"Authorization": f"Bearer {token('foo')}"}) + authenticated_user = await jwt_authentication(request, mock_user_db) + assert authenticated_user is None + @pytest.mark.asyncio async def test_valid_token( self, jwt_authentication, mock_user_db, request_builder, token, user ): - request = request_builder(headers={"Authorization": f"Bearer {token(user)}"}) + request = request_builder(headers={"Authorization": f"Bearer {token(user.id)}"}) authenticated_user = await jwt_authentication(request, mock_user_db) assert authenticated_user.id == user.id @@ -77,7 +85,7 @@ async def test_get_login_response(jwt_authentication, user): decoded = jwt.decode( token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM] ) - assert decoded["user_id"] == user.id + assert decoded["user_id"] == str(user.id) @pytest.mark.authentication diff --git a/tests/test_db_mongodb.py b/tests/test_db_mongodb.py index 77d975e8..a9dfc562 100644 --- a/tests/test_db_mongodb.py +++ b/tests/test_db_mongodb.py @@ -15,7 +15,9 @@ def get_mongodb_user_db(): user_model, ) -> AsyncGenerator[MongoDBUserDatabase, None]: client = motor.motor_asyncio.AsyncIOMotorClient( - "mongodb://localhost:27017", serverSelectionTimeoutMS=100 + "mongodb://localhost:27017", + serverSelectionTimeoutMS=100, + uuidRepresentation="standard", ) try: @@ -50,9 +52,7 @@ async def mongodb_user_db_oauth(get_mongodb_user_db): @pytest.mark.db async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]): user = UserDB( - id="111", - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), + email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), ) # Create @@ -96,7 +96,6 @@ async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]): async def test_queries_custom_fields(mongodb_user_db: MongoDBUserDatabase[UserDB]): """It should output custom fields in query result.""" user = UserDB( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), first_name="Lancelot", @@ -117,7 +116,6 @@ async def test_queries_oauth( oauth_account2, ): user = UserDBOAuth( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), oauth_accounts=[oauth_account1, oauth_account2], diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index 7b5d93b6..cb8cd3c9 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -70,9 +70,7 @@ async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, N @pytest.mark.db async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): user = UserDB( - id="111", - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), + email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), ) # Create @@ -103,7 +101,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): # Exception when inserting non-nullable fields with pytest.raises(sqlite3.IntegrityError): - wrong_user = UserDB(id="222", hashed_password="aaa") + wrong_user = UserDB(hashed_password="aaa") await sqlalchemy_user_db.create(wrong_user) # Unknown user @@ -117,9 +115,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): # Exception when creating/updating a OAuth user user_oauth = UserDBOAuth( - id="222", - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), + email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), ) with pytest.raises(NotSetOAuthAccountTableError): await sqlalchemy_user_db.create(user_oauth) @@ -138,7 +134,6 @@ async def test_queries_custom_fields( ): """It should output custom fields in query result.""" user = UserDB( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), first_name="Lancelot", @@ -159,7 +154,6 @@ async def test_queries_oauth( oauth_account2, ): user = UserDBOAuth( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), oauth_accounts=[oauth_account1, oauth_account2], diff --git a/tests/test_db_tortoise.py b/tests/test_db_tortoise.py index e94f58c3..716faff9 100644 --- a/tests/test_db_tortoise.py +++ b/tests/test_db_tortoise.py @@ -55,9 +55,7 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None] @pytest.mark.db async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]): user = UserDB( - id="111", - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), + email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), ) # Create @@ -88,7 +86,7 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]): # Exception when inserting non-nullable fields with pytest.raises(ValueError): - wrong_user = UserDB(id="222", hashed_password="aaa") + wrong_user = UserDB(hashed_password="aaa") await tortoise_user_db.create(wrong_user) # Unknown user @@ -106,7 +104,6 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]): async def test_queries_custom_fields(tortoise_user_db: TortoiseUserDatabase[UserDB]): """It should output custom fields in query result.""" user = UserDB( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), first_name="Lancelot", @@ -127,7 +124,6 @@ async def test_queries_oauth( oauth_account2, ): user = UserDBOAuth( - id="111", email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"), oauth_accounts=[oauth_account1, oauth_account2], diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index f62c8f94..4840d87b 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -204,7 +204,7 @@ class TestCallback: user_update_mock.assert_awaited_once() data = cast(Dict[str, Any], response.json()) - assert data["token"] == user_oauth.id + assert data["token"] == str(user_oauth.id) assert event_handler.called is False @@ -242,7 +242,7 @@ class TestCallback: user_update_mock.assert_awaited_once() data = cast(Dict[str, Any], response.json()) - assert data["token"] == superuser_oauth.id + assert data["token"] == str(superuser_oauth.id) assert event_handler.called is False @@ -283,7 +283,7 @@ class TestCallback: assert event_handler.called is True actual_user = event_handler.call_args[0][0] - assert actual_user.id == data["token"] + assert str(actual_user.id) == data["token"] request = event_handler.call_args[0][1] assert isinstance(request, Request) @@ -348,4 +348,4 @@ class TestCallback: ) data = cast(Dict[str, Any], response.json()) - assert data["token"] == user_oauth.id + assert data["token"] == str(user_oauth.id) diff --git a/tests/test_router_users.py b/tests/test_router_users.py index 403fe26f..a6e6d079 100644 --- a/tests/test_router_users.py +++ b/tests/test_router_users.py @@ -23,7 +23,7 @@ def forgot_password_token(): def _forgot_password_token(user_id=None, lifetime=LIFETIME): data = {"aud": "fastapi-users:reset"} if user_id is not None: - data["user_id"] = user_id + data["user_id"] = str(user_id) return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return _forgot_password_token @@ -117,7 +117,7 @@ class TestRegister: assert data["id"] is not None actual_user = event_handler.call_args[0][0] - assert actual_user.id == data["id"] + assert str(actual_user.id) == data["id"] request = event_handler.call_args[0][1] assert isinstance(request, Request) @@ -190,7 +190,7 @@ class TestLogin: data = {"username": "king.arthur@camelot.bt", "password": "guinevere"} response = await test_app_client.post(path, data=data) assert response.status_code == status.HTTP_200_OK - assert response.json() == {"token": user.id} + assert response.json() == {"token": str(user.id)} async def test_inactive_user(self, path, test_app_client: httpx.AsyncClient): data = {"username": "percival@camelot.bt", "password": "angharad"} @@ -261,7 +261,7 @@ class TestForgotPassword: audience="fastapi-users:reset", algorithms=[JWT_ALGORITHM], ) - assert decoded_token["user_id"] == user.id + assert decoded_token["user_id"] == str(user.id) request = event_handler.call_args[0][2] assert isinstance(request, Request) @@ -306,6 +306,22 @@ class TestResetPassword: assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN assert mock_user_db.update.called is False + async def test_valid_token_invalid_uuid( + self, + mocker, + mock_user_db, + test_app_client: httpx.AsyncClient, + forgot_password_token, + ): + mocker.spy(mock_user_db, "update") + + json = {"token": forgot_password_token("foo"), "password": "holygrail"} + response = await test_app_client.post("/reset-password", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN + assert mock_user_db.update.called is False + async def test_inactive_user( self, mocker, @@ -368,7 +384,7 @@ class TestMe: assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) - assert data["id"] == user.id + assert data["id"] == str(user.id) assert data["email"] == user.email @@ -504,12 +520,13 @@ class TestUpdateMe: @pytest.mark.asyncio class TestGetUser: async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.get("/000") + response = await test_app_client.get("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): response = await test_app_client.get( - "/000", headers={"Authorization": f"Bearer {user.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {user.id}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -517,7 +534,8 @@ class TestGetUser: self, test_app_client: httpx.AsyncClient, superuser: UserDB ): response = await test_app_client.get( - "/000", headers={"Authorization": f"Bearer {superuser.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {superuser.id}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -530,7 +548,7 @@ class TestGetUser: assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) - assert data["id"] == user.id + assert data["id"] == str(user.id) assert "hashed_password" not in data @@ -538,12 +556,13 @@ class TestGetUser: @pytest.mark.asyncio class TestUpdateUser: async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.patch("/000") + response = await test_app_client.patch("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): response = await test_app_client.patch( - "/000", headers={"Authorization": f"Bearer {user.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {user.id}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -551,7 +570,9 @@ class TestUpdateUser: self, test_app_client: httpx.AsyncClient, superuser: UserDB ): response = await test_app_client.patch( - "/000", json={}, headers={"Authorization": f"Bearer {superuser.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + json={}, + headers={"Authorization": f"Bearer {superuser.id}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND @@ -636,12 +657,13 @@ class TestUpdateUser: @pytest.mark.asyncio class TestDeleteUser: async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.delete("/000") + response = await test_app_client.delete("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): response = await test_app_client.delete( - "/000", headers={"Authorization": f"Bearer {user.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {user.id}"}, ) assert response.status_code == status.HTTP_403_FORBIDDEN @@ -649,7 +671,8 @@ class TestDeleteUser: self, test_app_client: httpx.AsyncClient, superuser: UserDB ): response = await test_app_client.delete( - "/000", headers={"Authorization": f"Bearer {superuser.id}"} + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {superuser.id}"}, ) assert response.status_code == status.HTTP_404_NOT_FOUND