Files

102 lines
3.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Dict, Generic, Type, TypeVar, NoReturn
from pydantic import BaseModel
from sqlalchemy import select, update, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.models.base import MappedBase
# Type variable for the ORM model class a CRUDBase is bound to (any MappedBase subclass).
ModelType = TypeVar('ModelType', bound=MappedBase)
# Type variable for the pydantic schema accepted when creating a row.
CreateSchemaType = TypeVar('CreateSchemaType', bound=BaseModel)
# Type variable for the pydantic schema accepted when updating a row.
UpdateSchemaType = TypeVar('UpdateSchemaType', bound=BaseModel)
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """
    Generic asynchronous CRUD helper bound to one SQLAlchemy model class.

    The methods read ``id``, ``name``, ``status`` and ``del_flag`` from the
    model, and set ``create_user`` / ``update_user`` when a ``user_id`` is
    supplied. No method here commits the session; that is left to the caller.
    """

    def __init__(self, model: Type[ModelType]):
        """
        Bind the helper to a model class.

        :param model: mapped model class that every query in this helper targets
        """
        self.model = model

    async def get_(
        self,
        db: AsyncSession,
        *,
        pk: int | None = None,
        name: str | None = None,
        status: int | None = None,
        del_flag: int | None = None,
    ) -> ModelType | None:
        """
        Fetch a single row by primary key ``id`` or by ``name``.

        When both ``pk`` and ``name`` are given, only ``pk`` is used.

        :param db: session used to execute the query
        :param pk: value matched against ``model.id``
        :param name: value matched against ``model.name`` (only used when ``pk`` is None)
        :param status: optional extra filter on ``model.status``; must be 0 or 1
        :param del_flag: optional extra filter on ``model.del_flag``; must be 0 or 1
        :return: the first matching model instance, or None when nothing matches
        :raises AssertionError: if ``pk`` and ``name`` are both None, or if
            ``status`` / ``del_flag`` is given but is not 0 or 1
        """
        # Raised explicitly rather than through ``assert`` so the validation
        # is not stripped when Python runs with -O; the exception type is unchanged.
        if pk is None and name is None:
            raise AssertionError('查询错误, pk 和 name 参数不能同时为空')
        where_list = [self.model.id == pk] if pk is not None else [self.model.name == name]
        if status is not None:
            if status not in (0, 1):
                raise AssertionError('查询错误, status 参数只能为 0 或 1')
            where_list.append(self.model.status == status)
        if del_flag is not None:
            if del_flag not in (0, 1):
                raise AssertionError('查询错误, del_flag 参数只能为 0 或 1')
            where_list.append(self.model.del_flag == del_flag)
        result = await db.execute(select(self.model).where(and_(*where_list)))
        return result.scalars().first()

    async def create_(self, db: AsyncSession, obj_in: CreateSchemaType, user_id: int | None = None) -> None:
        """
        Add one new row to the session.

        The instance is only added via ``db.add``; this method does not flush
        or commit.

        :param db: session the new instance is added to
        :param obj_in: pydantic model whose fields become the model's constructor arguments
        :param user_id: when not None, passed to the model as ``create_user``
        :return: None
        """
        # ``is not None`` (not truthiness) so that an id of 0 is still recorded.
        if user_id is not None:
            create_data = self.model(**obj_in.dict(), create_user=user_id)
        else:
            create_data = self.model(**obj_in.dict())
        db.add(create_data)

    async def update_(
        self, db: AsyncSession, pk: int, obj_in: UpdateSchemaType | Dict[str, Any], user_id: int | None = None
    ) -> int:
        """
        Update the row whose primary key ``id`` equals ``pk``.

        :param db: session used to execute the statement
        :param pk: primary key of the row to update
        :param obj_in: pydantic model (only explicitly-set fields are used) or a
            dict keyed by the model's column names
        :param user_id: when not None, also written to the ``update_user`` column
        :return: the ``rowcount`` reported for the UPDATE statement
        """
        if isinstance(obj_in, dict):
            # Copy so that adding ``update_user`` below never mutates the caller's dict.
            update_data = dict(obj_in)
        else:
            update_data = obj_in.dict(exclude_unset=True)
        if user_id is not None:
            update_data['update_user'] = user_id
        result = await db.execute(update(self.model).where(self.model.id == pk).values(**update_data))
        return result.rowcount

    async def delete_(self, db: AsyncSession, pk: int, *, del_flag: int | None = None) -> int:
        """
        Delete the row whose primary key ``id`` equals ``pk``.

        With ``del_flag`` left as None a DELETE statement is issued. With
        ``del_flag=1`` the row is kept and its ``del_flag`` column is set to 1
        through an UPDATE statement instead.

        :param db: session used to execute the statement
        :param pk: primary key of the row to delete
        :param del_flag: None for a DELETE, or 1 to only mark the row via ``del_flag``
        :return: the ``rowcount`` reported for the executed statement
        :raises AssertionError: if ``del_flag`` is given and is not 1
        """
        if del_flag is None:
            result = await db.execute(delete(self.model).where(self.model.id == pk))
        else:
            # Explicit raise keeps the check active under -O (see get_).
            if del_flag != 1:
                raise AssertionError('删除错误, del_flag 参数只能为 1')
            result = await db.execute(update(self.model).where(self.model.id == pk).values(del_flag=del_flag))
        return result.rowcount