From 848315badc1e9447cfe4a8ff6928a4d77ff5ab70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 27 Aug 2021 17:01:50 +0200 Subject: [PATCH] Remove DB dependencies (#704) * Remove database adapter in favor of external dependencies * Prevent flit from installing all optional dependencies when testing build * Remove MongoDB service during CI build --- .github/workflows/build.yml | 8 +- Makefile | 5 - fastapi_users/db/__init__.py | 8 +- fastapi_users/db/mongodb.py | 72 ----------- fastapi_users/db/ormar.py | 113 ----------------- fastapi_users/db/sqlalchemy.py | 206 ------------------------------- fastapi_users/db/tortoise.py | 147 ---------------------- pyproject.toml | 9 +- requirements.dev.txt | 2 - requirements.txt | 5 - tests/test_db_mongodb.py | 200 ------------------------------ tests/test_db_ormar.py | 214 --------------------------------- tests/test_db_sqlalchemy.py | 203 ------------------------------- tests/test_db_tortoise.py | 183 ---------------------------- 14 files changed, 9 insertions(+), 1366 deletions(-) delete mode 100644 fastapi_users/db/mongodb.py delete mode 100644 fastapi_users/db/ormar.py delete mode 100644 fastapi_users/db/sqlalchemy.py delete mode 100644 fastapi_users/db/tortoise.py delete mode 100644 tests/test_db_mongodb.py delete mode 100644 tests/test_db_ormar.py delete mode 100644 tests/test_db_sqlalchemy.py delete mode 100644 tests/test_db_tortoise.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index affbe9ae..1db408fa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,12 +10,6 @@ jobs: matrix: python_version: [3.7, 3.8, 3.9] - services: - mongo: - image: mongo:4.2 - ports: - - 27017:27017 - steps: - uses: actions/checkout@v1 - name: Set up Python @@ -35,7 +29,7 @@ jobs: - name: Build and install it on system host run: | flit build - flit install --python $(which python) + flit install --deps none --python $(which python) python test_build.py release: diff --git a/Makefile b/Makefile index ea52d1f0..dc81a0d4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,3 @@ -MONGODB_CONTAINER_NAME := fastapi-users-test-mongo - isort-src: isort ./fastapi_users ./tests @@ -10,10 +8,7 @@ format: isort-src isort-docs black . test: - docker stop $(MONGODB_CONTAINER_NAME) || true - docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mongo:4.2 pytest --cov=fastapi_users/ --cov-report=term-missing --cov-fail-under=100 - docker stop $(MONGODB_CONTAINER_NAME) docs-serve: mkdocs serve diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 5988e77e..d853a74b 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,12 +1,12 @@ from fastapi_users.db.base import BaseUserDatabase # noqa: F401 try: - from fastapi_users.db.mongodb import MongoDBUserDatabase # noqa: F401 + from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401 except ImportError: # pragma: no cover pass try: - from fastapi_users.db.sqlalchemy import ( # noqa: F401 + from fastapi_users_db_sqlalchemy import ( # noqa: F401 SQLAlchemyBaseOAuthAccountTable, SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase, @@ -15,7 +15,7 @@ except ImportError: # pragma: no cover pass try: - from fastapi_users.db.tortoise import ( # noqa: F401 + from fastapi_users_db_tortoise import ( # noqa: F401 TortoiseBaseOAuthAccountModel, TortoiseBaseUserModel, TortoiseUserDatabase, @@ -24,7 +24,7 @@ except ImportError: # pragma: no cover pass try: - from fastapi_users.db.ormar import ( # noqa: F401 + from fastapi_users_db_ormar import ( # noqa: F401 OrmarBaseOAuthAccountModel, OrmarBaseUserModel, OrmarUserDatabase, diff --git a/fastapi_users/db/mongodb.py b/fastapi_users/db/mongodb.py deleted file mode 100644 index 30f262f6..00000000 --- a/fastapi_users/db/mongodb.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Optional, Type - -from motor.motor_asyncio import AsyncIOMotorCollection -from pydantic import UUID4 -from pymongo.collation import Collation - -from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import UD - - -class MongoDBUserDatabase(BaseUserDatabase[UD]): - """ - Database adapter for MongoDB. - - :param user_db_model: Pydantic model of a DB representation of a user. - :param collection: Collection instance from `motor`. - """ - - collection: AsyncIOMotorCollection - email_collation: Collation - - def __init__( - self, - user_db_model: Type[UD], - collection: AsyncIOMotorCollection, - email_collation: Optional[Collation] = None, - ): - super().__init__(user_db_model) - self.collection = collection - self.collection.create_index("id", unique=True) - self.collection.create_index("email", unique=True) - - if email_collation: - self.email_collation = email_collation # pragma: no cover - else: - self.email_collation = Collation("en", strength=2) - - self.collection.create_index( - "email", - name="case_insensitive_email_index", - collation=self.email_collation, - ) - - async def get(self, id: UUID4) -> Optional[UD]: - user = await self.collection.find_one({"id": id}) - return self.user_db_model(**user) if user else None - - async def get_by_email(self, email: str) -> Optional[UD]: - user = await self.collection.find_one( - {"email": email}, collation=self.email_collation - ) - return self.user_db_model(**user) if user else None - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - user = await self.collection.find_one( - { - "oauth_accounts.oauth_name": oauth, - "oauth_accounts.account_id": account_id, - } - ) - return self.user_db_model(**user) if user else None - - async def create(self, user: UD) -> UD: - await self.collection.insert_one(user.dict()) - return user - - async def update(self, user: UD) -> UD: - await self.collection.replace_one({"id": user.id}, user.dict()) - return user - - async def delete(self, user: UD) -> None: - await self.collection.delete_one({"id": user.id}) diff --git a/fastapi_users/db/ormar.py b/fastapi_users/db/ormar.py deleted file mode 100644 index 66f56397..00000000 --- a/fastapi_users/db/ormar.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Any, List, Optional, Type, cast - -import ormar -from ormar.exceptions import NoMatch -from pydantic import UUID4 - -from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import UD, BaseOAuthAccount - - -class OrmarBaseUserModel(ormar.Model): - class Meta: - tablename = "users" - abstract = True - - id = ormar.UUID(primary_key=True, uuid_format="string") - email = ormar.String(index=True, unique=True, nullable=False, max_length=255) - hashed_password = ormar.String(nullable=False, max_length=255) - is_active = ormar.Boolean(default=True, nullable=False) - is_superuser = ormar.Boolean(default=False, nullable=False) - is_verified = ormar.Boolean(default=False, nullable=False) - - -class OrmarBaseOAuthAccountModel(ormar.Model): - class Meta: - tablename = "oauth_accounts" - abstract = True - - id = ormar.UUID(primary_key=True, uuid_format="string") - oauth_name = ormar.String(nullable=False, max_length=255) - access_token = ormar.String(nullable=False, max_length=255) - expires_at = ormar.Integer(nullable=True) - refresh_token = ormar.String(nullable=True, max_length=255) - account_id = ormar.String(index=True, nullable=False, max_length=255) - account_email = ormar.String(nullable=False, max_length=255) - - -class OrmarUserDatabase(BaseUserDatabase[UD]): - """ - Database adapter for ormar. - - :param user_db_model: Pydantic model of a DB representation of a user. - :param model: ormar ORM model. - :param oauth_account_model: Optional ormar ORM model of a OAuth account. - """ - - model: Type[OrmarBaseUserModel] - oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]] - - def __init__( - self, - user_db_model: Type[UD], - model: Type[OrmarBaseUserModel], - oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]] = None, - ): - super().__init__(user_db_model) - self.model = model - self.oauth_account_model = oauth_account_model - - async def get(self, id: UUID4) -> Optional[UD]: - return await self._get_user(id=id) - - async def get_by_email(self, email: str) -> Optional[UD]: - return await self._get_user(email__iexact=email) - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - return await self._get_user( - oauth_accounts__oauth_name=oauth, oauth_accounts__account_id=account_id - ) - - async def create(self, user: UD) -> UD: - oauth_accounts = getattr(user, "oauth_accounts", []) - model = await self.model(**user.dict(exclude={"oauth_accounts"})).save() - if oauth_accounts and self.oauth_account_model: - await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts) - user_db = await self._get_user(id=user.id) - return cast(UD, user_db) - - async def update(self, user: UD) -> UD: - oauth_accounts = getattr(user, "oauth_accounts", []) - model = await self._get_db_user(id=user.id) - await model.update(**user.dict(exclude={"oauth_accounts"})) - if oauth_accounts and self.oauth_account_model: - await model.oauth_accounts.clear(keep_reversed=False) - await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts) - user_db = await self._get_user(id=user.id) - return cast(UD, user_db) - - async def delete(self, user: UD) -> None: - await self.model.objects.delete(id=user.id) - - async def _create_oauth_models( - self, model: OrmarBaseUserModel, oauth_accounts: List[BaseOAuthAccount] - ): - if self.oauth_account_model: - oauth_accounts_db = [ - self.oauth_account_model(user=model, **oacc.dict()) - for oacc in oauth_accounts - ] - await self.oauth_account_model.objects.bulk_create(oauth_accounts_db) - - async def _get_db_user(self, **kwargs: Any) -> OrmarBaseUserModel: - query = self.model.objects.filter(**kwargs) - if self.oauth_account_model is not None: - query = query.select_related("oauth_accounts") - return cast(OrmarBaseUserModel, await query.get()) - - async def _get_user(self, **kwargs: Any) -> Optional[UD]: - try: - user = await self._get_db_user(**kwargs) - except NoMatch: - return None - return self.user_db_model(**user.dict()) diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py deleted file mode 100644 index e5ec0270..00000000 --- a/fastapi_users/db/sqlalchemy.py +++ /dev/null @@ -1,206 +0,0 @@ -import uuid -from typing import Mapping, Optional, Type - -from databases import Database -from pydantic import UUID4 -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, func, select -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.types import CHAR, TypeDecorator - -from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import UD - - -class GUID(TypeDecorator): # pragma: no cover - """Platform-independent GUID type. - - Uses PostgreSQL's UUID type, otherwise uses - CHAR(36), storing as regular strings. - """ - class UUIDChar(CHAR): - python_type = UUID4 - - impl = UUIDChar - - def load_dialect_impl(self, dialect): - if dialect.name == "postgresql": - return dialect.type_descriptor(UUID()) - else: - return dialect.type_descriptor(CHAR(36)) - - def process_bind_param(self, value, dialect): - if value is None: - return value - elif dialect.name == "postgresql": - return str(value) - else: - if not isinstance(value, uuid.UUID): - return str(uuid.UUID(value)) - else: - return str(value) - - def process_result_value(self, value, dialect): - if value is None: - return value - else: - if not isinstance(value, uuid.UUID): - value = uuid.UUID(value) - return value - - -class SQLAlchemyBaseUserTable: - """Base SQLAlchemy users table definition.""" - - __tablename__ = "user" - - id = Column(GUID, primary_key=True) - email = Column(String(length=320), unique=True, index=True, nullable=False) - hashed_password = Column(String(length=72), nullable=False) - is_active = Column(Boolean, default=True, nullable=False) - is_superuser = Column(Boolean, default=False, nullable=False) - is_verified = Column(Boolean, default=False, nullable=False) - - -class SQLAlchemyBaseOAuthAccountTable: - """Base SQLAlchemy OAuth account table definition.""" - - __tablename__ = "oauth_account" - - id = Column(GUID, primary_key=True) - oauth_name = Column(String(length=100), index=True, nullable=False) - access_token = Column(String(length=1024), nullable=False) - expires_at = Column(Integer, nullable=True) - refresh_token = Column(String(length=1024), nullable=True) - account_id = Column(String(length=320), index=True, nullable=False) - account_email = Column(String(length=320), nullable=False) - - @declared_attr - def user_id(cls): - return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False) - - -class NotSetOAuthAccountTableError(Exception): - """ - OAuth table was not set in DB adapter but was needed. - - Raised when trying to create/update a user with OAuth accounts set - but no table were specified in the DB adapter. - """ - - pass - - -class SQLAlchemyUserDatabase(BaseUserDatabase[UD]): - """ - Database adapter for SQLAlchemy. - - :param user_db_model: Pydantic model of a DB representation of a user. - :param database: `Database` instance from `encode/databases`. - :param users: SQLAlchemy users table instance. - :param oauth_accounts: Optional SQLAlchemy OAuth accounts table instance. - """ - - database: Database - users: Table - oauth_accounts: Optional[Table] - - def __init__( - self, - user_db_model: Type[UD], - database: Database, - users: Table, - oauth_accounts: Optional[Table] = None, - ): - super().__init__(user_db_model) - self.database = database - self.users = users - self.oauth_accounts = oauth_accounts - - async def get(self, id: UUID4) -> Optional[UD]: - query = self.users.select().where(self.users.c.id == id) - user = await self.database.fetch_one(query) - return await self._make_user(user) if user else None - - async def get_by_email(self, email: str) -> Optional[UD]: - query = self.users.select().where( - func.lower(self.users.c.email) == func.lower(email) - ) - user = await self.database.fetch_one(query) - return await self._make_user(user) if user else None - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - if self.oauth_accounts is not None: - query = ( - select([self.users]) - .select_from(self.users.join(self.oauth_accounts)) - .where(self.oauth_accounts.c.oauth_name == oauth) - .where(self.oauth_accounts.c.account_id == account_id) - ) - user = await self.database.fetch_one(query) - return await self._make_user(user) if user else None - raise NotSetOAuthAccountTableError() - - async def create(self, user: UD) -> UD: - user_dict = user.dict() - oauth_accounts_values = None - - if "oauth_accounts" in user_dict: - oauth_accounts_values = [] - - oauth_accounts = user_dict.pop("oauth_accounts") - for oauth_account in oauth_accounts: - oauth_accounts_values.append({"user_id": user.id, **oauth_account}) - - query = self.users.insert() - await self.database.execute(query, user_dict) - - if oauth_accounts_values is not None: - if self.oauth_accounts is None: - raise NotSetOAuthAccountTableError() - query = self.oauth_accounts.insert() - await self.database.execute_many(query, oauth_accounts_values) - - return user - - async def update(self, user: UD) -> UD: - user_dict = user.dict() - - if "oauth_accounts" in user_dict: - if self.oauth_accounts is None: - raise NotSetOAuthAccountTableError() - - delete_query = self.oauth_accounts.delete().where( - self.oauth_accounts.c.user_id == user.id - ) - await self.database.execute(delete_query) - - oauth_accounts_values = [] - oauth_accounts = user_dict.pop("oauth_accounts") - for oauth_account in oauth_accounts: - oauth_accounts_values.append({"user_id": user.id, **oauth_account}) - - insert_query = self.oauth_accounts.insert() - await self.database.execute_many(insert_query, oauth_accounts_values) - - update_query = ( - self.users.update().where(self.users.c.id == user.id).values(user_dict) - ) - await self.database.execute(update_query) - return user - - async def delete(self, user: UD) -> None: - query = self.users.delete().where(self.users.c.id == user.id) - await self.database.execute(query) - - async def _make_user(self, user: Mapping) -> UD: - user_dict = {**user} - - if self.oauth_accounts is not None: - query = self.oauth_accounts.select().where( - self.oauth_accounts.c.user_id == user["id"] - ) - oauth_accounts = await self.database.fetch_all(query) - user_dict["oauth_accounts"] = [{**a} for a in oauth_accounts] - - return self.user_db_model(**user_dict) diff --git a/fastapi_users/db/tortoise.py b/fastapi_users/db/tortoise.py deleted file mode 100644 index da5e7959..00000000 --- a/fastapi_users/db/tortoise.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Optional, Type, cast - -from pydantic import UUID4 -from tortoise import fields, models -from tortoise.contrib.pydantic import PydanticModel -from tortoise.exceptions import DoesNotExist -from tortoise.queryset import QuerySetSingle - -from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import UD - - -class TortoiseBaseUserModel(models.Model): - id = fields.UUIDField(pk=True, generated=False) - email = fields.CharField(index=True, unique=True, null=False, max_length=255) - hashed_password = fields.CharField(null=False, max_length=255) - is_active = fields.BooleanField(default=True, null=False) - is_superuser = fields.BooleanField(default=False, null=False) - is_verified = fields.BooleanField(default=False, null=False) - - class Meta: - abstract = True - - -class TortoiseBaseOAuthAccountModel(models.Model): - id = fields.UUIDField(pk=True, generated=False, max_length=255) - oauth_name = fields.CharField(null=False, max_length=255) - access_token = fields.CharField(null=False, max_length=255) - expires_at = fields.IntField(null=True) - refresh_token = fields.CharField(null=True, max_length=255) - account_id = fields.CharField(index=True, null=False, max_length=255) - account_email = fields.CharField(null=False, max_length=255) - - class Meta: - abstract = True - - -class TortoiseUserDatabase(BaseUserDatabase[UD]): - """ - Database adapter for Tortoise ORM. - - :param user_db_model: Pydantic model of a DB representation of a user. - :param model: Tortoise ORM model. - :param oauth_account_model: Optional Tortoise ORM model of a OAuth account. - """ - - model: Type[TortoiseBaseUserModel] - oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]] - - def __init__( - self, - user_db_model: Type[UD], - model: Type[TortoiseBaseUserModel], - oauth_account_model: Optional[Type[TortoiseBaseOAuthAccountModel]] = None, - ): - super().__init__(user_db_model) - self.model = model - self.oauth_account_model = oauth_account_model - - async def get(self, id: UUID4) -> Optional[UD]: - try: - query = self.model.get(id=id) - - if self.oauth_account_model is not None: - query = query.prefetch_related("oauth_accounts") - - user = await query - pydantic_user = await cast( - PydanticModel, self.user_db_model - ).from_tortoise_orm(user) - - return cast(UD, pydantic_user) - except DoesNotExist: - return None - - async def get_by_email(self, email: str) -> Optional[UD]: - query = self.model.filter(email__iexact=email).first() - - if self.oauth_account_model is not None: - query = query.prefetch_related("oauth_accounts") - - user = await query - - if user is None: - return None - - pydantic_user = await cast(PydanticModel, self.user_db_model).from_tortoise_orm( - user - ) - - return cast(UD, pydantic_user) - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - try: - query: QuerySetSingle[TortoiseBaseUserModel] = self.model.get( - oauth_accounts__oauth_name=oauth, oauth_accounts__account_id=account_id - ).prefetch_related("oauth_accounts") - - user = await query - pydantic_user = await cast( - PydanticModel, self.user_db_model - ).from_tortoise_orm(user) - - return cast(UD, pydantic_user) - except DoesNotExist: - return None - - async def create(self, user: UD) -> UD: - user_dict = user.dict() - oauth_accounts = user_dict.pop("oauth_accounts", None) - - model = self.model(**user_dict) - await model.save() - - if oauth_accounts and self.oauth_account_model: - oauth_account_objects = [] - for oauth_account in oauth_accounts: - oauth_account_objects.append( - self.oauth_account_model(user=model, **oauth_account) - ) - await self.oauth_account_model.bulk_create(oauth_account_objects) - - return user - - async def update(self, user: UD) -> UD: - user_dict = user.dict() - user_dict.pop("id") # Tortoise complains if we pass the PK again - oauth_accounts = user_dict.pop("oauth_accounts", None) - - model = await self.model.get(id=user.id) - for field in user_dict: - setattr(model, field, user_dict[field]) - await model.save() - - if oauth_accounts and self.oauth_account_model: - await model.oauth_accounts.all().delete() # type: ignore - oauth_account_objects = [] - for oauth_account in oauth_accounts: - oauth_account_objects.append( - self.oauth_account_model(user=model, **oauth_account) - ) - await self.oauth_account_model.bulk_create(oauth_account_objects) - - return user - - async def delete(self, user: UD) -> None: - await self.model.filter(id=user.id).delete() diff --git a/pyproject.toml b/pyproject.toml index 251deeba..29e1d6cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,17 +34,16 @@ requires = [ [tool.flit.metadata.requires-extra] sqlalchemy = [ - "sqlalchemy >=1.3.13,<1.4", - "databases >=0.3.0,<0.5", + "fastapi-users-db-sqlalchemy >=1.0.0", ] mongodb = [ - "motor >=2.2.0,<3.0.0", + "fastapi-users-db-mongodb >=1.0.0", ] tortoise-orm = [ - "tortoise-orm >=0.16.0,<0.18.0" + "fastapi-users-db-tortoise >=1.0.0", ] ormar = [ - "ormar >=0.9.5,<0.11.0" + "fastapi-users-db-ormar >=1.0.0", ] oauth = [ "httpx-oauth >=0.3,<0.4" diff --git a/requirements.dev.txt b/requirements.dev.txt index 02f7f82b..e4acb53f 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -4,7 +4,6 @@ flake8 pytest requests isort -databases pytest-asyncio flake8-docstrings mkdocs @@ -24,4 +23,3 @@ httpx-oauth httpx asgi_lifespan uvicorn -sqlalchemy-stubs diff --git a/requirements.txt b/requirements.txt index 1f2e5a32..e18ce129 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,8 @@ fastapi >=0.65.2,<0.69.0 passlib[bcrypt] ==1.7.4 email-validator >=1.1.2,<1.2 -sqlalchemy >=1.3.23,<1.4 -databases[postgresql, sqlite] >=0.3.2,<0.5 pyjwt ==2.1.0 python-multipart ==0.0.5 -motor >=2.3.1,<3.0.0 -tortoise-orm >=0.17.1,<0.18.0 -ormar >=0.10.1,<0.11.0 makefun >=1.11.2,<1.12 typing-extensions >=3.7.4.3; python_version < '3.8' Deprecated >=1.2.12,<2.0.0 diff --git a/tests/test_db_mongodb.py b/tests/test_db_mongodb.py deleted file mode 100644 index bd3aef3b..00000000 --- a/tests/test_db_mongodb.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import AsyncGenerator - -import pymongo.errors -import pytest -from motor.motor_asyncio import AsyncIOMotorClient - -from fastapi_users.db.mongodb import MongoDBUserDatabase -from fastapi_users.password import get_password_hash -from tests.conftest import UserDB, UserDBOAuth - - -@pytest.fixture(scope="module") -async def mongodb_client(): - client = AsyncIOMotorClient( - "mongodb://localhost:27017", - serverSelectionTimeoutMS=100, - uuidRepresentation="standard", - ) - - try: - await client.server_info() - yield client - client.close() - except pymongo.errors.ServerSelectionTimeoutError: - pytest.skip("MongoDB not available", allow_module_level=True) - return - - -@pytest.fixture -def get_mongodb_user_db(mongodb_client: AsyncIOMotorClient): - async def _get_mongodb_user_db( - user_model, - ) -> AsyncGenerator[MongoDBUserDatabase, None]: - db = mongodb_client["test_database"] - collection = db["users"] - - yield MongoDBUserDatabase(user_model, collection) - - await collection.delete_many({}) - - return _get_mongodb_user_db - - -@pytest.fixture -async def mongodb_user_db(get_mongodb_user_db): - async for u in get_mongodb_user_db(UserDB): - yield u - - -@pytest.fixture -async def mongodb_user_db_oauth(get_mongodb_user_db): - async for u in get_mongodb_user_db(UserDBOAuth): - yield u - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - ) - - # Create - user_db = await mongodb_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email - - # Update - user_db.is_superuser = True - await mongodb_user_db.update(user_db) - - # Get by id - id_user = await mongodb_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.is_superuser is True - - # Get by email - email_user = await mongodb_user_db.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - - # Get by uppercased email - email_user = await mongodb_user_db.get_by_email("Lancelot@camelot.bt") - assert email_user is not None - assert email_user.id == user_db.id - - # Exception when inserting existing email - with pytest.raises(pymongo.errors.DuplicateKeyError): - await mongodb_user_db.create(user) - - # Unknown user - unknown_user = await mongodb_user_db.get_by_email("galahad@camelot.bt") - assert unknown_user is None - - # Delete user - await mongodb_user_db.delete(user) - deleted_user = await mongodb_user_db.get(user.id) - assert deleted_user is None - - -@pytest.mark.asyncio -@pytest.mark.db -@pytest.mark.parametrize( - "email,query,found", - [ - ("lancelot@camelot.bt", "lancelot@camelot.bt", True), - ("lancelot@camelot.bt", "LanceloT@camelot.bt", True), - ("lancelot@camelot.bt", "lancelot.@camelot.bt", False), - ("lancelot@camelot.bt", "lancelot.*", False), - ("lancelot@camelot.bt", "lancelot+guinevere@camelot.bt", False), - ("lancelot+guinevere@camelot.bt", "lancelot+guinevere@camelot.bt", True), - ("lancelot+guinevere@camelot.bt", "lancelot.*", False), - ("квіточка@пошта.укр", "квіточка@пошта.укр", True), - ("квіточка@пошта.укр", "КВІТОЧКА@ПОШТА.УКР", True), - ], -) -async def test_email_query( - mongodb_user_db: MongoDBUserDatabase[UserDB], email: str, query: str, found: bool -): - user = UserDB( - email=email, - hashed_password=get_password_hash("guinevere"), - ) - await mongodb_user_db.create(user) - - email_user = await mongodb_user_db.get_by_email(query) - - if found: - assert email_user is not None - assert email_user.id == user.id - else: - assert email_user is None - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_custom_fields(mongodb_user_db: MongoDBUserDatabase[UserDB]): - """It should output custom fields in query result.""" - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - first_name="Lancelot", - ) - await mongodb_user_db.create(user) - - id_user = await mongodb_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user.id - assert id_user.first_name == user.first_name - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_oauth( - mongodb_user_db_oauth: MongoDBUserDatabase[UserDBOAuth], - oauth_account1, - oauth_account2, -): - user = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - oauth_accounts=[oauth_account1, oauth_account2], - ) - - # Create - user_db = await mongodb_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 - - # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await mongodb_user_db_oauth.update(user_db) - - # Get by id - id_user = await mongodb_user_db_oauth.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" - - # Get by email - email_user = await mongodb_user_db_oauth.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - assert len(email_user.oauth_accounts) == 2 - - # Get by OAuth account - oauth_user = await mongodb_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id - ) - assert oauth_user is not None - assert oauth_user.id == user.id - - # Unknown OAuth account - unknown_oauth_user = await mongodb_user_db_oauth.get_by_oauth_account("foo", "bar") - assert unknown_oauth_user is None diff --git a/tests/test_db_ormar.py b/tests/test_db_ormar.py deleted file mode 100644 index 57a59f90..00000000 --- a/tests/test_db_ormar.py +++ /dev/null @@ -1,214 +0,0 @@ -import uuid -from sqlite3 import IntegrityError -from typing import AsyncGenerator - -import databases -import ormar -import pytest -import sqlalchemy -from ormar.exceptions import NoMatch - -from fastapi_users.db.ormar import ( - OrmarBaseOAuthAccountModel, - OrmarBaseUserModel, - OrmarUserDatabase, -) -from fastapi_users.password import get_password_hash -from tests.conftest import UserDB, UserDBOAuth - -DATABASE_URL = "sqlite:///./test-ormar-user.db" -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - - -class User(OrmarBaseUserModel): - class Meta: - metadata = metadata - database = database - - first_name = ormar.String(nullable=True, max_length=255) - - -class OAuthAccount(OrmarBaseOAuthAccountModel): - class Meta: - metadata = metadata - database = database - - user = ormar.ForeignKey(User, related_name="oauth_accounts") - - -@pytest.fixture -async def ormar_user_db() -> AsyncGenerator[OrmarUserDatabase, None]: - engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) - metadata.create_all(engine) - - await database.connect() - - yield OrmarUserDatabase(user_db_model=UserDB, model=User) - - metadata.drop_all(engine) - await database.disconnect() - - -@pytest.fixture -async def ormar_user_db_oauth() -> AsyncGenerator[OrmarUserDatabase, None]: - engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) - metadata.create_all(engine) - - await database.connect() - - yield OrmarUserDatabase( - user_db_model=UserDBOAuth, model=User, oauth_account_model=OAuthAccount - ) - - metadata.drop_all(engine) - await database.disconnect() - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries(ormar_user_db: OrmarUserDatabase[UserDB]): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - ) - - # Create - user_db = await ormar_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email - - # Update - user_db.is_superuser = True - await ormar_user_db.update(user_db) - - # Exception when updating a user with a not existing id - id_backup = user_db.id - user_db.id = uuid.uuid4() - with pytest.raises(NoMatch): - await ormar_user_db.update(user_db) - user_db.id = id_backup - - # Get by id - id_user = await ormar_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.is_superuser is True - - # Get by email - email_user = await ormar_user_db.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - - # Get by uppercased email - email_user = await ormar_user_db.get_by_email("Lancelot@camelot.bt") - assert email_user is not None - assert email_user.id == user_db.id - - # Exception when inserting existing email - with pytest.raises(IntegrityError): - await ormar_user_db.create(user) - - # Exception when inserting non-nullable fields - with pytest.raises(ValueError): - wrong_user = UserDB(hashed_password="aaa") - await ormar_user_db.create(wrong_user) - - # Unknown user - unknown_user = await ormar_user_db.get_by_email("galahad@camelot.bt") - assert unknown_user is None - - # Delete user - await ormar_user_db.delete(user) - deleted_user = await ormar_user_db.get(user.id) - assert deleted_user is None - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_custom_fields(ormar_user_db: OrmarUserDatabase[UserDB]): - """It should output custom fields in query result.""" - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - first_name="Lancelot", - ) - await ormar_user_db.create(user) - - id_user = await ormar_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user.id - assert id_user.first_name == user.first_name - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_oauth( - ormar_user_db_oauth: OrmarUserDatabase[UserDBOAuth], - oauth_account1, - oauth_account2, - oauth_account3, -): - user = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - oauth_accounts=[oauth_account1, oauth_account2], - ) - - # Create - user_db = await ormar_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 - - # Update - oauth_to_check_id = user_db.oauth_accounts[0].id - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await ormar_user_db_oauth.update(user_db) - - # Get by id - id_user = await ormar_user_db_oauth.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - updated_oauth = next( - (oauth for oauth in id_user.oauth_accounts if oauth.id == oauth_to_check_id), - None, - ) - assert updated_oauth.access_token == "NEW_TOKEN" - - # Add new oauth - id_user.oauth_accounts.append(oauth_account3) - await ormar_user_db_oauth.update(id_user) - id_user_updated = await ormar_user_db_oauth.get(user.id) - assert len(id_user_updated.oauth_accounts) == 3 - - # Remove oauth2 and update - id_user.oauth_accounts = [ - oauth for oauth in id_user.oauth_accounts if oauth.id != oauth_account2.id - ] - await ormar_user_db_oauth.update(id_user) - id_user_updated = await ormar_user_db_oauth.get(user.id) - assert len(id_user_updated.oauth_accounts) == 2 - - # Get by email - email_user = await ormar_user_db_oauth.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - assert len(email_user.oauth_accounts) == 2 - - # Get by OAuth account - oauth_user = await ormar_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id - ) - assert oauth_user is not None - assert oauth_user.id == user.id - - # Unknown OAuth account - unknown_oauth_user = await ormar_user_db_oauth.get_by_oauth_account("foo", "bar") - assert unknown_oauth_user is None diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py deleted file mode 100644 index 92d955ed..00000000 --- a/tests/test_db_sqlalchemy.py +++ /dev/null @@ -1,203 +0,0 @@ -import sqlite3 -from typing import AsyncGenerator - -import pytest -import sqlalchemy -from databases import Database -from sqlalchemy import Column, String -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base - -from fastapi_users.db.sqlalchemy import ( - NotSetOAuthAccountTableError, - SQLAlchemyBaseOAuthAccountTable, - SQLAlchemyBaseUserTable, - SQLAlchemyUserDatabase, -) -from fastapi_users.password import get_password_hash -from tests.conftest import UserDB, UserDBOAuth - - -@pytest.fixture -async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - Base: DeclarativeMeta = declarative_base() - - class User(SQLAlchemyBaseUserTable, Base): - first_name = Column(String, nullable=True) - - DATABASE_URL = "sqlite:///./test-sqlalchemy-user.db" - database = Database(DATABASE_URL) - - engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) - Base.metadata.create_all(engine) - - await database.connect() - - yield SQLAlchemyUserDatabase(UserDB, database, User.__table__) - - Base.metadata.drop_all(engine) - await database.disconnect() - - -@pytest.fixture -async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - Base: DeclarativeMeta = declarative_base() - - class User(SQLAlchemyBaseUserTable, Base): - first_name = Column(String, nullable=True) - - class OAuthAccount(SQLAlchemyBaseOAuthAccountTable, Base): - pass - - DATABASE_URL = "sqlite:///./test-sqlalchemy-user-oauth.db" - database = Database(DATABASE_URL) - - engine = sqlalchemy.create_engine( - DATABASE_URL, connect_args={"check_same_thread": False} - ) - Base.metadata.create_all(engine) - - await database.connect() - - yield SQLAlchemyUserDatabase( - UserDBOAuth, database, User.__table__, OAuthAccount.__table__ - ) - - Base.metadata.drop_all(engine) - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - ) - - # Create - user_db = await sqlalchemy_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email - - # Update - user_db.is_superuser = True - await sqlalchemy_user_db.update(user_db) - - # Get by id - id_user = await sqlalchemy_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.is_superuser is True - - # Get by email - email_user = await sqlalchemy_user_db.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - - # Get by uppercased email - email_user = await sqlalchemy_user_db.get_by_email("Lancelot@camelot.bt") - assert email_user is not None - assert email_user.id == user_db.id - - # Exception when inserting existing email - with pytest.raises(sqlite3.IntegrityError): - await sqlalchemy_user_db.create(user) - - # Exception when inserting non-nullable fields - with pytest.raises(sqlite3.IntegrityError): - wrong_user = UserDB(hashed_password="aaa") - await sqlalchemy_user_db.create(wrong_user) - - # Unknown user - unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt") - assert unknown_user is None - - # Delete user - await sqlalchemy_user_db.delete(user) - deleted_user = await sqlalchemy_user_db.get(user.id) - assert deleted_user is None - - # Exception when creating/updating a OAuth user - user_oauth = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - ) - with pytest.raises(NotSetOAuthAccountTableError): - await sqlalchemy_user_db.create(user_oauth) - with pytest.raises(NotSetOAuthAccountTableError): - await sqlalchemy_user_db.update(user_oauth) - - # Exception when trying to get by OAuth account - with pytest.raises(NotSetOAuthAccountTableError): - await sqlalchemy_user_db.get_by_oauth_account("foo", "bar") - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_custom_fields( - sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB], -): - """It should output custom fields in query result.""" - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - first_name="Lancelot", - ) - await sqlalchemy_user_db.create(user) - - id_user = await sqlalchemy_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user.id - assert id_user.first_name == user.first_name - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_oauth( - sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserDBOAuth], - oauth_account1, - oauth_account2, -): - user = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - oauth_accounts=[oauth_account1, oauth_account2], - ) - - # Create - user_db = await sqlalchemy_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 - - # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await sqlalchemy_user_db_oauth.update(user_db) - - # Get by id - id_user = await sqlalchemy_user_db_oauth.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" - - # Get by email - email_user = await sqlalchemy_user_db_oauth.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - assert len(email_user.oauth_accounts) == 2 - - # Get by OAuth account - oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id - ) - assert oauth_user is not None - assert oauth_user.id == user.id - - # Unknown OAuth account - unknown_oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( - "foo", "bar" - ) - assert unknown_oauth_user is None diff --git a/tests/test_db_tortoise.py b/tests/test_db_tortoise.py deleted file mode 100644 index 0546419e..00000000 --- a/tests/test_db_tortoise.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import AsyncGenerator - -import pytest -from tortoise import Tortoise, fields -from tortoise.contrib.pydantic import PydanticModel -from tortoise.exceptions import IntegrityError - -from fastapi_users.db.tortoise import ( - TortoiseBaseOAuthAccountModel, - TortoiseBaseUserModel, - TortoiseUserDatabase, -) -from fastapi_users.password import get_password_hash -from tests.conftest import UserDB as BaseUserDB -from tests.conftest import UserDBOAuth as BaseUserDBOAuth - - -class User(TortoiseBaseUserModel): - first_name = fields.CharField(null=True, max_length=255) - - -class UserDB(BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = User - - -class OAuthAccount(TortoiseBaseOAuthAccountModel): - user = fields.ForeignKeyField("models.User", related_name="oauth_accounts") - - -class UserDBOAuth(BaseUserDBOAuth, PydanticModel): - class Config: - orm_mode = True - orig_model = OAuthAccount - - -@pytest.fixture -async def tortoise_user_db() -> AsyncGenerator[TortoiseUserDatabase, None]: - DATABASE_URL = "sqlite://./test-tortoise-user.db" - - await Tortoise.init( - db_url=DATABASE_URL, modules={"models": ["tests.test_db_tortoise"]} - ) - await Tortoise.generate_schemas() - - yield TortoiseUserDatabase(UserDB, User) - - await User.all().delete() - await Tortoise.close_connections() - - -@pytest.fixture -async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]: - DATABASE_URL = "sqlite://./test-tortoise-user-oauth.db" - - await Tortoise.init( - db_url=DATABASE_URL, modules={"models": ["tests.test_db_tortoise"]} - ) - await Tortoise.generate_schemas() - - yield TortoiseUserDatabase(UserDBOAuth, User, OAuthAccount) - - await User.all().delete() - await Tortoise.close_connections() - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]): - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - ) - - # Create - user_db = await tortoise_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email - - # Update - user_db.is_superuser = True - await tortoise_user_db.update(user_db) - - # Get by id - id_user = await tortoise_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.is_superuser is True - - # Get by email - email_user = await tortoise_user_db.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - - # Get by uppercased email - email_user = await tortoise_user_db.get_by_email("Lancelot@camelot.bt") - assert email_user is not None - assert email_user.id == user_db.id - - # Exception when inserting existing email - with pytest.raises(IntegrityError): - await tortoise_user_db.create(user) - - # Exception when inserting non-nullable fields - with pytest.raises(ValueError): - wrong_user = UserDB(hashed_password="aaa") - await tortoise_user_db.create(wrong_user) - - # Unknown user - unknown_user = await tortoise_user_db.get_by_email("galahad@camelot.bt") - assert unknown_user is None - - # Delete user - await tortoise_user_db.delete(user) - deleted_user = await tortoise_user_db.get(user.id) - assert deleted_user is None - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_custom_fields(tortoise_user_db: TortoiseUserDatabase[UserDB]): - """It should output custom fields in query result.""" - user = UserDB( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - first_name="Lancelot", - ) - await tortoise_user_db.create(user) - - id_user = await tortoise_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user.id - assert id_user.first_name == user.first_name - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_oauth( - tortoise_user_db_oauth: TortoiseUserDatabase[UserDBOAuth], - oauth_account1, - oauth_account2, -): - user = UserDBOAuth( - email="lancelot@camelot.bt", - hashed_password=get_password_hash("guinevere"), - oauth_accounts=[oauth_account1, oauth_account2], - ) - - # Create - user_db = await tortoise_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 - - # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await tortoise_user_db_oauth.update(user_db) - - # Get by id - id_user = await tortoise_user_db_oauth.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" - - # Get by email - email_user = await tortoise_user_db_oauth.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - assert len(email_user.oauth_accounts) == 2 - - # Get by OAuth account - oauth_user = await tortoise_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id - ) - assert oauth_user is not None - assert oauth_user.id == user.id - - # Unknown OAuth account - unknown_oauth_user = await tortoise_user_db_oauth.get_by_oauth_account("foo", "bar") - assert unknown_oauth_user is None