mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-15 11:11:16 +08:00
Use a Base table class for SQLAlchemy adapter
This commit is contained in:
@ -1,16 +1,13 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from databases import Database
|
from databases import Database
|
||||||
from sqlalchemy import Boolean, Column, String
|
from sqlalchemy import Boolean, Column, String, Table
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import UserDB
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
class BaseUser:
|
||||||
class BaseUser(Base):
|
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
@ -20,34 +17,35 @@ class BaseUser(Base):
|
|||||||
is_superuser = Column(Boolean, default=False)
|
is_superuser = Column(Boolean, default=False)
|
||||||
|
|
||||||
|
|
||||||
users = BaseUser.__table__
|
|
||||||
|
|
||||||
|
|
||||||
class SQLAlchemyUserDatabase(BaseUserDatabase):
|
class SQLAlchemyUserDatabase(BaseUserDatabase):
|
||||||
|
|
||||||
database: Database
|
database: Database
|
||||||
|
users: Table
|
||||||
|
|
||||||
def __init__(self, database):
|
def __init__(self, database: Database, users: Table):
|
||||||
self.database = database
|
self.database = database
|
||||||
|
self.users = users
|
||||||
|
|
||||||
async def list(self) -> List[UserDB]:
|
async def list(self) -> List[UserDB]:
|
||||||
query = users.select()
|
query = self.users.select()
|
||||||
return await self.database.fetch_all(query)
|
return await self.database.fetch_all(query)
|
||||||
|
|
||||||
async def get(self, id: str) -> UserDB:
|
async def get(self, id: str) -> UserDB:
|
||||||
query = users.select().where(BaseUser.id == id)
|
query = self.users.select().where(self.users.c.id == id)
|
||||||
return await self.database.fetch_one(query)
|
return await self.database.fetch_one(query)
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> UserDB:
|
async def get_by_email(self, email: str) -> UserDB:
|
||||||
query = users.select().where(BaseUser.email == email)
|
query = self.users.select().where(self.users.c.email == email)
|
||||||
return await self.database.fetch_one(query)
|
return await self.database.fetch_one(query)
|
||||||
|
|
||||||
async def create(self, user: UserDB) -> UserDB:
|
async def create(self, user: UserDB) -> UserDB:
|
||||||
query = users.insert().values(**user.dict())
|
query = self.users.insert().values(**user.dict())
|
||||||
await self.database.execute(query)
|
await self.database.execute(query)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def update(self, user: UserDB) -> UserDB:
|
async def update(self, user: UserDB) -> UserDB:
|
||||||
query = users.update().where(BaseUser.id == user.id).values(**user.dict())
|
query = (
|
||||||
|
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||||
|
)
|
||||||
await self.database.execute(query)
|
await self.database.execute(query)
|
||||||
return user
|
return user
|
||||||
|
@ -3,12 +3,18 @@ import sqlite3
|
|||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from databases import Database
|
from databases import Database
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
from fastapi_users.db.sqlalchemy import Base, SQLAlchemyUserDatabase
|
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
class User(BaseUser, Base):
|
||||||
|
pass
|
||||||
|
|
||||||
DATABASE_URL = "sqlite:///./test.db"
|
DATABASE_URL = "sqlite:///./test.db"
|
||||||
database = Database(DATABASE_URL)
|
database = Database(DATABASE_URL)
|
||||||
|
|
||||||
@ -19,7 +25,7 @@ async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
|||||||
|
|
||||||
await database.connect()
|
await database.connect()
|
||||||
|
|
||||||
yield SQLAlchemyUserDatabase(database)
|
yield SQLAlchemyUserDatabase(database, User.__table__)
|
||||||
|
|
||||||
Base.metadata.drop_all(engine)
|
Base.metadata.drop_all(engine)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user