Fix typing errors

This commit is contained in:
François Voron
2021-05-17 09:00:34 +02:00
parent 5267e605f4
commit a690e82408
5 changed files with 24 additions and 9 deletions

View File

@ -93,7 +93,7 @@ class OrmarUserDatabase(BaseUserDatabase[UD]):
self, model: OrmarBaseUserModel, oauth_accounts: List[BaseOAuthAccount]
):
if self.oauth_account_model:
oauth_accounts_db: List[ormar.Model] = [
oauth_accounts_db = [
self.oauth_account_model(user=model, **oacc.dict())
for oacc in oauth_accounts
]

View File

@ -168,21 +168,23 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
if self.oauth_accounts is None:
raise NotSetOAuthAccountTableError()
query = self.oauth_accounts.delete().where(
delete_query = self.oauth_accounts.delete().where(
self.oauth_accounts.c.user_id == user.id
)
await self.database.execute(query)
await self.database.execute(delete_query)
oauth_accounts_values = []
oauth_accounts = user_dict.pop("oauth_accounts")
for oauth_account in oauth_accounts:
oauth_accounts_values.append({"user_id": user.id, **oauth_account})
query = self.oauth_accounts.insert()
await self.database.execute_many(query, oauth_accounts_values)
insert_query = self.oauth_accounts.insert()
await self.database.execute_many(insert_query, oauth_accounts_values)
query = self.users.update().where(self.users.c.id == user.id).values(user_dict)
await self.database.execute(query)
update_query = (
self.users.update().where(self.users.c.id == user.id).values(user_dict)
)
await self.database.execute(update_query)
return user
async def delete(self, user: UD) -> None:

View File

@ -27,7 +27,7 @@ try:
from fastapi_users.router import get_oauth_router
except ModuleNotFoundError: # pragma: no cover
BaseOAuth2 = Type
BaseOAuth2 = Type # type: ignore
class FastAPIUsers:

View File

@ -24,3 +24,4 @@ httpx-oauth
httpx
asgi_lifespan
uvicorn
sqlalchemy-stubs

View File

@ -16,6 +16,18 @@ ignore = D1
[isort]
profile = black
[mypy]
plugins = sqlmypy
[mypy-motor.*]
ignore_missing_imports = True
[mypy-passlib.*]
ignore_missing_imports = True
[mypy-pymongo.*]
ignore_missing_imports = True
[tool:pytest]
addopts = --ignore=test_build.py
markers =