Fix MongoDB unit tests

This commit is contained in:
François Voron
2021-03-19 18:18:06 +01:00
parent e7ceb1569c
commit 902bcdb8d2
2 changed files with 22 additions and 18 deletions

View File

@ -47,7 +47,7 @@ class UserDBOAuth(UserOAuth, UserDB):
pass pass
@pytest.fixture @pytest.fixture(scope="session")
def event_loop(): def event_loop():
"""Force the pytest-asyncio loop to be the main one.""" """Force the pytest-asyncio loop to be the main one."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@ -1,20 +1,17 @@
from typing import AsyncGenerator from typing import AsyncGenerator
import motor.motor_asyncio
import pymongo.errors import pymongo.errors
import pytest import pytest
from motor.motor_asyncio import AsyncIOMotorClient
from fastapi_users.db.mongodb import MongoDBUserDatabase from fastapi_users.db.mongodb import MongoDBUserDatabase
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
from tests.conftest import UserDB, UserDBOAuth from tests.conftest import UserDB, UserDBOAuth
@pytest.fixture @pytest.fixture(scope="module")
def get_mongodb_user_db(): async def mongodb_client():
async def _get_mongodb_user_db( client = AsyncIOMotorClient(
user_model,
) -> AsyncGenerator[MongoDBUserDatabase, None]:
client = motor.motor_asyncio.AsyncIOMotorClient(
"mongodb://localhost:27017", "mongodb://localhost:27017",
serverSelectionTimeoutMS=100, serverSelectionTimeoutMS=100,
uuidRepresentation="standard", uuidRepresentation="standard",
@ -22,17 +19,24 @@ def get_mongodb_user_db():
try: try:
await client.server_info() await client.server_info()
yield client
client.close()
except pymongo.errors.ServerSelectionTimeoutError: except pymongo.errors.ServerSelectionTimeoutError:
pytest.skip("MongoDB not available", allow_module_level=True) pytest.skip("MongoDB not available", allow_module_level=True)
return return
db = client["test_database"]
@pytest.fixture
def get_mongodb_user_db(mongodb_client: AsyncIOMotorClient):
async def _get_mongodb_user_db(
user_model,
) -> AsyncGenerator[MongoDBUserDatabase, None]:
db = mongodb_client["test_database"]
collection = db["users"] collection = db["users"]
yield MongoDBUserDatabase(user_model, collection) yield MongoDBUserDatabase(user_model, collection)
await collection.drop() await collection.delete_many({})
client.close()
return _get_mongodb_user_db return _get_mongodb_user_db