mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-15 03:04:27 +08:00
Use a Base table class for SQLAlchemy adapter
This commit is contained in:
@ -1,16 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
from databases import Database
|
||||
from sqlalchemy import Boolean, Column, String
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import Boolean, Column, String, Table
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class BaseUser(Base):
|
||||
class BaseUser:
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@ -20,34 +17,35 @@ class BaseUser(Base):
|
||||
is_superuser = Column(Boolean, default=False)
|
||||
|
||||
|
||||
users = BaseUser.__table__
|
||||
|
||||
|
||||
class SQLAlchemyUserDatabase(BaseUserDatabase):
|
||||
|
||||
database: Database
|
||||
users: Table
|
||||
|
||||
def __init__(self, database):
|
||||
def __init__(self, database: Database, users: Table):
|
||||
self.database = database
|
||||
self.users = users
|
||||
|
||||
async def list(self) -> List[UserDB]:
|
||||
query = users.select()
|
||||
query = self.users.select()
|
||||
return await self.database.fetch_all(query)
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
query = users.select().where(BaseUser.id == id)
|
||||
query = self.users.select().where(self.users.c.id == id)
|
||||
return await self.database.fetch_one(query)
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
query = users.select().where(BaseUser.email == email)
|
||||
query = self.users.select().where(self.users.c.email == email)
|
||||
return await self.database.fetch_one(query)
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
query = users.insert().values(**user.dict())
|
||||
query = self.users.insert().values(**user.dict())
|
||||
await self.database.execute(query)
|
||||
return user
|
||||
|
||||
async def update(self, user: UserDB) -> UserDB:
|
||||
query = users.update().where(BaseUser.id == user.id).values(**user.dict())
|
||||
query = (
|
||||
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||
)
|
||||
await self.database.execute(query)
|
||||
return user
|
||||
|
Reference in New Issue
Block a user