Use a Base table class for SQLAlchemy adapter

This commit is contained in:
François Voron
2019-10-09 18:03:10 +02:00
parent f2bd2c6485
commit 9f41a8b9a7
2 changed files with 20 additions and 16 deletions

View File

@ -1,16 +1,13 @@
from typing import List from typing import List
from databases import Database from databases import Database
from sqlalchemy import Boolean, Column, String from sqlalchemy import Boolean, Column, String, Table
from sqlalchemy.ext.declarative import declarative_base
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB from fastapi_users.models import UserDB
Base = declarative_base()
class BaseUser:
class BaseUser(Base):
__tablename__ = "user" __tablename__ = "user"
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
@ -20,34 +17,35 @@ class BaseUser(Base):
is_superuser = Column(Boolean, default=False) is_superuser = Column(Boolean, default=False)
users = BaseUser.__table__
class SQLAlchemyUserDatabase(BaseUserDatabase): class SQLAlchemyUserDatabase(BaseUserDatabase):
database: Database database: Database
users: Table
def __init__(self, database): def __init__(self, database: Database, users: Table):
self.database = database self.database = database
self.users = users
async def list(self) -> List[UserDB]: async def list(self) -> List[UserDB]:
query = users.select() query = self.users.select()
return await self.database.fetch_all(query) return await self.database.fetch_all(query)
async def get(self, id: str) -> UserDB: async def get(self, id: str) -> UserDB:
query = users.select().where(BaseUser.id == id) query = self.users.select().where(self.users.c.id == id)
return await self.database.fetch_one(query) return await self.database.fetch_one(query)
async def get_by_email(self, email: str) -> UserDB: async def get_by_email(self, email: str) -> UserDB:
query = users.select().where(BaseUser.email == email) query = self.users.select().where(self.users.c.email == email)
return await self.database.fetch_one(query) return await self.database.fetch_one(query)
async def create(self, user: UserDB) -> UserDB: async def create(self, user: UserDB) -> UserDB:
query = users.insert().values(**user.dict()) query = self.users.insert().values(**user.dict())
await self.database.execute(query) await self.database.execute(query)
return user return user
async def update(self, user: UserDB) -> UserDB: async def update(self, user: UserDB) -> UserDB:
query = users.update().where(BaseUser.id == user.id).values(**user.dict()) query = (
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
)
await self.database.execute(query) await self.database.execute(query)
return user return user

View File

@ -3,12 +3,18 @@ import sqlite3
import pytest import pytest
import sqlalchemy import sqlalchemy
from databases import Database from databases import Database
from sqlalchemy.ext.declarative import declarative_base
from fastapi_users.db.sqlalchemy import Base, SQLAlchemyUserDatabase from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
@pytest.fixture @pytest.fixture
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase: async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
Base = declarative_base()
class User(BaseUser, Base):
pass
DATABASE_URL = "sqlite:///./test.db" DATABASE_URL = "sqlite:///./test.db"
database = Database(DATABASE_URL) database = Database(DATABASE_URL)
@ -19,7 +25,7 @@ async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
await database.connect() await database.connect()
yield SQLAlchemyUserDatabase(database) yield SQLAlchemyUserDatabase(database, User.__table__)
Base.metadata.drop_all(engine) Base.metadata.drop_all(engine)