diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index a344671..75315f0 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Query, Request +from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.pagination import PageDepends, paging_data @@ -33,14 +33,14 @@ async def get_all_apis( @router.post('', summary='创建接口', dependencies=[DependsRBAC]) -async def create_api(request: Request, obj: CreateApi): - await ApiService.create(obj=obj, user_id=request.user.id) +async def create_api(obj: CreateApi): + await ApiService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC]) -async def update_api(request: Request, pk: int, obj: UpdateApi): - count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_api(pk: int, obj: UpdateApi): + count = await ApiService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/api/v1/casbin.py b/backend/app/api/v1/casbin.py index 1cb26bc..e1fd747 100644 --- a/backend/app/api/v1/casbin.py +++ b/backend/app/api/v1/casbin.py @@ -21,7 +21,7 @@ from backend.app.services.casbin_service import CasbinService router = APIRouter() -@router.get('', summary='(模糊条件)分页获取所有 casbin 规则', dependencies=[DependsRBAC, PageDepends]) +@router.get('', summary='(模糊条件)分页获取所有权限规则', dependencies=[DependsRBAC, PageDepends]) async def get_all_casbin( db: CurrentSession, ptype: Annotated[str | None, Query()] = None, @@ -32,13 +32,13 @@ async def get_all_casbin( return await response_base.success(data=page_data) -@router.get('/policy', summary='获取所有 P 规则', dependencies=[DependsRBAC]) +@router.get('/policy', summary='获取所有访问权限规则', dependencies=[DependsRBAC]) async def get_all_policies(): policies = await CasbinService.get_policy_list() return await response_base.success(data=policies) -@router.post('/policy', summary='添加基于角色(主)/用户(次)的访问权限', dependencies=[DependsRBAC]) +@router.post('/policy', summary='添加访问权限', dependencies=[DependsRBAC]) async def create_policy(p: CreatePolicy): """ p 规则: @@ -53,25 +53,25 @@ async def create_policy(p: CreatePolicy): return await response_base.success(data=data) -@router.put('/policy', summary='更新基于角色(主)/用户(次)的访问权限', dependencies=[DependsRBAC]) +@router.put('/policy', summary='更新访问权限', dependencies=[DependsRBAC]) async def update_policy(old: UpdatePolicy, new: UpdatePolicy): data = await CasbinService.update_policy(old=old, new=new) return await response_base.success(data=data) -@router.delete('/policy', summary='删除基于角色(主)/用户的访问权限', dependencies=[DependsRBAC]) +@router.delete('/policy', summary='删除访问权限', dependencies=[DependsRBAC]) async def delete_policy(p: DeletePolicy): data = await CasbinService.delete_policy(p=p) return await response_base.success(data=data) -@router.get('/group', summary='获取所有 g 规则', dependencies=[DependsRBAC]) +@router.get('/group', summary='获取所有组访问权限规则', dependencies=[DependsRBAC]) async def get_all_groups(): data = await CasbinService.get_group_list() return await response_base.success(data=data) -@router.post('/group', summary='添加基于用户组的访问权限', dependencies=[DependsRBAC]) +@router.post('/group', summary='添加组访问权限', dependencies=[DependsRBAC]) async def create_group(g: CreateUserRole): """ g 规则 (**依赖 p 规则**): @@ -79,14 +79,14 @@ async def create_group(g: CreateUserRole): - 如果在 p 规则中添加了基于角色的访问权限, 则还需要在 g 规则中添加基于用户组的访问权限, 才能真正拥有访问权限
**格式**: 用户 uuid + 角色 role - - 如果在p策略中添加了基于用户的访问权限, 则不添加相应的 g 规则能直接拥有访问权限
+ - 如果在 p 策略中添加了基于用户的访问权限, 则不添加相应的 g 规则能直接拥有访问权限
但是拥有的不是用户角色的所有权限, 而只是单一的对应的 p 规则所添加的访问权限 """ data = await CasbinService.create_group(g=g) return await response_base.success(data=data) -@router.delete('/group', summary='删除基于用户组的访问权限', dependencies=[DependsRBAC]) +@router.delete('/group', summary='删除组访问权限', dependencies=[DependsRBAC]) async def delete_group(g: DeleteUserRole): data = await CasbinService.delete_group(g=g) return await response_base.success(data=data) diff --git a/backend/app/api/v1/dept.py b/backend/app/api/v1/dept.py index 9a4861d..f152044 100644 --- a/backend/app/api/v1/dept.py +++ b/backend/app/api/v1/dept.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Query, Request +from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.response.response_schema import response_base @@ -33,14 +33,14 @@ async def get_all_depts( @router.post('', summary='创建部门', dependencies=[DependsRBAC]) -async def create_dept(request: Request, obj: CreateDept): - await DeptService.create(obj=obj, user_id=request.user.id) +async def create_dept(obj: CreateDept): + await DeptService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新部门', dependencies=[DependsRBAC]) -async def update_dept(request: Request, pk: int, obj: UpdateDept): - count = await DeptService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_dept(pk: int, obj: UpdateDept): + count = await DeptService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/api/v1/dict_data.py b/backend/app/api/v1/dict_data.py index c9cdcf8..1978c28 100644 --- a/backend/app/api/v1/dict_data.py +++ b/backend/app/api/v1/dict_data.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Query, Request +from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.pagination import PageDepends, paging_data @@ -35,14 +35,14 @@ async def get_all_dict_datas( @router.post('', summary='创建字典', dependencies=[DependsRBAC]) -async def create_dict_data(request: Request, obj: CreateDictData): - await DictDataService.create(obj=obj, user_id=request.user.id) +async def create_dict_data(obj: CreateDictData): + await DictDataService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新字典', dependencies=[DependsRBAC]) -async def update_dict_data(request: Request, pk: int, obj: UpdateDictData): - count = await DictDataService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_dict_data(pk: int, obj: UpdateDictData): + count = await DictDataService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/api/v1/dict_type.py b/backend/app/api/v1/dict_type.py index c69f3f4..0c85c86 100644 --- a/backend/app/api/v1/dict_type.py +++ b/backend/app/api/v1/dict_type.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Query, Request +from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.pagination import PageDepends, paging_data @@ -27,14 +27,14 @@ async def get_all_dict_types( @router.post('', summary='创建字典类型', dependencies=[DependsRBAC]) -async def create_dict_type(request: Request, obj: CreateDictType): - await DictTypeService.create(obj=obj, user_id=request.user.id) +async def create_dict_type(obj: CreateDictType): + await DictTypeService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新字典类型', dependencies=[DependsRBAC]) -async def update_dict_type(request: Request, pk: int, obj: UpdateDictType): - count = await DictTypeService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_dict_type(pk: int, obj: UpdateDictType): + count = await DictTypeService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/api/v1/menu.py b/backend/app/api/v1/menu.py index 1ef84c9..cf725b0 100644 --- a/backend/app/api/v1/menu.py +++ b/backend/app/api/v1/menu.py @@ -36,14 +36,14 @@ async def get_all_menus( @router.post('', summary='创建菜单', dependencies=[DependsRBAC]) -async def create_menu(request: Request, obj: CreateMenu): - await MenuService.create(obj=obj, user_id=request.user.id) +async def create_menu(obj: CreateMenu): + await MenuService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新菜单', dependencies=[DependsRBAC]) -async def update_menu(request: Request, pk: int, obj: UpdateMenu): - count = await MenuService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_menu(pk: int, obj: UpdateMenu): + count = await MenuService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/api/v1/role.py b/backend/app/api/v1/role.py index e9ad1d4..f0ed846 100644 --- a/backend/app/api/v1/role.py +++ b/backend/app/api/v1/role.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import Annotated -from fastapi import APIRouter, Query, Request +from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.pagination import PageDepends, paging_data @@ -34,14 +34,14 @@ async def get_all_roles( @router.post('', summary='创建角色', dependencies=[DependsRBAC]) -async def create_role(request: Request, obj: CreateRole): - await RoleService.create(obj=obj, user_id=request.user.id) +async def create_role(obj: CreateRole): + await RoleService.create(obj=obj) return await response_base.success() @router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC]) -async def update_role(request: Request, pk: int, obj: UpdateRole): - count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id) +async def update_role(pk: int, obj: UpdateRole): + count = await RoleService.update(pk=pk, obj=obj) if count > 0: return await response_base.success() return await response_base.fail() diff --git a/backend/app/common/casbin_rbac.py b/backend/app/common/casbin_rbac.py index c007d15..b3f4bf0 100644 --- a/backend/app/common/casbin_rbac.py +++ b/backend/app/common/casbin_rbac.py @@ -61,7 +61,7 @@ class RBAC: raise AuthorizationError else: # casbin 权限校验 - user_uuid = request.user.user_uuid + user_uuid = request.user.uuid enforcer = self.enforcer() if not enforcer.enforce(user_uuid, path, method): raise AuthorizationError diff --git a/backend/app/common/exception/exception_handler.py b/backend/app/common/exception/exception_handler.py index 4668cc9..f500d64 100644 --- a/backend/app/common/exception/exception_handler.py +++ b/backend/app/common/exception/exception_handler.py @@ -86,9 +86,7 @@ def register_exception(app: FastAPI): for error in raw_exc.errors()[:1]: field = str(error.get('loc')[-1]) msg = error.get('msg') - message += ( - f'{data.get(field, field) if field != "__root__" else ""} {msg}' + '.' - ) + message += f'{data.get(field, field) if field != "__root__" else ""} {msg}' + '.' elif isinstance(raw_error.exc, json.JSONDecodeError): message += 'json解析失败' content = { diff --git a/backend/app/common/jwt.py b/backend/app/common/jwt.py index 25fe3ee..742d4be 100644 --- a/backend/app/common/jwt.py +++ b/backend/app/common/jwt.py @@ -16,6 +16,7 @@ from backend.app.common.redis import redis_client from backend.app.core.conf import settings from backend.app.crud.crud_user import UserDao from backend.app.models import User +from backend.app.utils.timezone import timezone_utils pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto') @@ -54,10 +55,10 @@ async def create_access_token(sub: str, expires_delta: timedelta | None = None, :return: """ if expires_delta: - expire = datetime.now() + expires_delta + expire = timezone_utils.get_timezone_expire_time(expires_delta) expire_seconds = int(expires_delta.total_seconds()) else: - expire = datetime.now() + timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS) + expire = timezone_utils.get_timezone_expire_time(timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)) expire_seconds = settings.TOKEN_EXPIRE_SECONDS multi_login = kwargs.pop('multi_login', None) to_encode = {'exp': expire, 'sub': sub, **kwargs} @@ -80,9 +81,9 @@ async def create_refresh_token(sub: str, expire_time: datetime | None = None, ** """ if expire_time: expire = expire_time + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS) - expire_seconds = int((expire - datetime.now()).total_seconds()) + expire_seconds = timezone_utils.get_timezone_expire_seconds(expire_time) else: - expire = datetime.now() + timedelta(seconds=settings.TOKEN_REFRESH_EXPIRE_SECONDS) + expire = timezone_utils.get_timezone_expire_time(timedelta(seconds=settings.TOKEN_EXPIRE_SECONDS)) expire_seconds = settings.TOKEN_REFRESH_EXPIRE_SECONDS multi_login = kwargs.pop('multi_login', None) to_encode = {'exp': expire, 'sub': sub, **kwargs} diff --git a/backend/app/common/response/response_code.py b/backend/app/common/response/response_code.py index 3a9aa06..ae341e0 100644 --- a/backend/app/common/response/response_code.py +++ b/backend/app/common/response/response_code.py @@ -8,7 +8,7 @@ class CustomCode(Enum): 自定义错误码 """ - CAPTCHA_ERROR = (40001, '图形验证码错误') + CAPTCHA_ERROR = (40001, '验证码错误') @property def code(self): diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py index b8b7754..dfd37da 100644 --- a/backend/app/crud/base.py +++ b/backend/app/crud/base.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from sqlalchemy import select, update, delete, and_ from sqlalchemy.ext.asyncio import AsyncSession -from backend.app.database.base_class import MappedBase +from backend.app.models.base import MappedBase ModelType = TypeVar('ModelType', bound=MappedBase) CreateSchemaType = TypeVar('CreateSchemaType', bound=BaseModel) diff --git a/backend/app/crud/crud_api.py b/backend/app/crud/crud_api.py index 7d5fd0d..c315051 100644 --- a/backend/app/crud/crud_api.py +++ b/backend/app/crud/crud_api.py @@ -31,11 +31,11 @@ class CRUDApi(CRUDBase[Api, CreateApi, UpdateApi]): api = await db.execute(select(self.model).where(self.model.name == name)) return api.scalars().first() - async def create(self, db: AsyncSession, obj_in: CreateApi, user_id: int) -> NoReturn: - await self.create_(db, obj_in, user_id) + async def create(self, db: AsyncSession, obj_in: CreateApi) -> NoReturn: + await self.create_(db, obj_in) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateApi, user_id: int) -> int: - return await self.update_(db, pk, obj_in, user_id) + async def update(self, db: AsyncSession, pk: int, obj_in: UpdateApi) -> int: + return await self.update_(db, pk, obj_in) async def delete(self, db: AsyncSession, pk: list[int]) -> int: apis = await db.execute(delete(self.model).where(self.model.id.in_(pk))) diff --git a/backend/app/crud/crud_dept.py b/backend/app/crud/crud_dept.py index 7418c6d..0661d68 100644 --- a/backend/app/crud/crud_dept.py +++ b/backend/app/crud/crud_dept.py @@ -36,13 +36,11 @@ class CRUDDept(CRUDBase[Dept, CreateDept, UpdateDept]): dept = await db.execute(se) return dept.scalars().all() - async def create(self, db: AsyncSession, obj_in: dict, user_id: int) -> None: - obj_in.update({'create_user': user_id}) - obj = self.model(**obj_in) - db.add(obj) + async def create(self, db: AsyncSession, obj_in: CreateDept) -> None: + await self.create_(db, obj_in) - async def update(self, db: AsyncSession, dept_id: int, obj_in: dict, user_id: int) -> int: - return await self.update_(db, dept_id, obj_in, user_id=user_id) + async def update(self, db: AsyncSession, dept_id: int, obj_in: UpdateDept) -> int: + return await self.update_(db, dept_id, obj_in) async def delete(self, db: AsyncSession, dept_id: int) -> int: return await self.delete_(db, dept_id, del_flag=1) diff --git a/backend/app/crud/crud_dict_data.py b/backend/app/crud/crud_dict_data.py index 231b1f1..0b0ae7c 100644 --- a/backend/app/crud/crud_dict_data.py +++ b/backend/app/crud/crud_dict_data.py @@ -30,11 +30,11 @@ class CRUDDictData(CRUDBase[DictData, CreateDictData, UpdateDictData]): api = await db.execute(select(self.model).where(self.model.label == label)) return api.scalars().first() - async def create(self, db: AsyncSession, obj_in: CreateDictData, user_id: int) -> None: - await self.create_(db, obj_in, user_id) + async def create(self, db: AsyncSession, obj_in: CreateDictData) -> None: + await self.create_(db, obj_in) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictData, user_id: int) -> int: - return await self.update_(db, pk, obj_in, user_id) + async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictData) -> int: + return await self.update_(db, pk, obj_in) async def delete(self, db: AsyncSession, pk: list[int]) -> int: apis = await db.execute(delete(self.model).where(self.model.id.in_(pk))) diff --git a/backend/app/crud/crud_dict_type.py b/backend/app/crud/crud_dict_type.py index 33b334c..934eeaf 100644 --- a/backend/app/crud/crud_dict_type.py +++ b/backend/app/crud/crud_dict_type.py @@ -29,11 +29,11 @@ class CRUDDictType(CRUDBase[DictType, CreateDictType, UpdateDictType]): dept = await db.execute(select(self.model).where(self.model.code == code)) return dept.scalars().first() - async def create(self, db: AsyncSession, obj_in: CreateDictType, user_id: int) -> None: - await self.create_(db, obj_in, user_id) + async def create(self, db: AsyncSession, obj_in: CreateDictType) -> None: + await self.create_(db, obj_in) - async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictType, user_id: int) -> int: - return await self.update_(db, pk, obj_in, user_id) + async def update(self, db: AsyncSession, pk: int, obj_in: UpdateDictType) -> int: + return await self.update_(db, pk, obj_in) async def delete(self, db: AsyncSession, pk: list[int]) -> int: apis = await db.execute(delete(self.model).where(self.model.id.in_(pk))) diff --git a/backend/app/crud/crud_login_log.py b/backend/app/crud/crud_login_log.py index 11af50c..5dea793 100644 --- a/backend/app/crud/crud_login_log.py +++ b/backend/app/crud/crud_login_log.py @@ -12,7 +12,7 @@ from backend.app.schemas.login_log import CreateLoginLog, UpdateLoginLog class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]): async def get_all(self, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: - se = select(self.model).order_by(desc(self.model.create_time)) + se = select(self.model).order_by(desc(self.model.created_time)) where_list = [] if username: where_list.append(self.model.username.like(f'%{username}%')) diff --git a/backend/app/crud/crud_menu.py b/backend/app/crud/crud_menu.py index a7dbc5c..09a96c1 100644 --- a/backend/app/crud/crud_menu.py +++ b/backend/app/crud/crud_menu.py @@ -38,13 +38,11 @@ class CRUDMenu(CRUDBase[Menu, CreateMenu, UpdateMenu]): menu = await db.execute(se) return menu.scalars().all() - async def create(self, db, obj_in: dict, user_id: int) -> None: - obj_in.update({'create_user': user_id}) - obj = self.model(**obj_in) - db.add(obj) + async def create(self, db, obj_in: CreateMenu) -> None: + await self.create_(db, obj_in) - async def update(self, db, menu_id: int, obj_in: dict, user_id: int) -> int: - return await self.update_(db, menu_id, obj_in, user_id) + async def update(self, db, menu_id: int, obj_in: UpdateMenu) -> int: + return await self.update_(db, menu_id, obj_in) async def delete(self, db, menu_id: int) -> int: return await self.delete_(db, menu_id) diff --git a/backend/app/crud/crud_opera_log.py b/backend/app/crud/crud_opera_log.py index 51635f1..04046a7 100644 --- a/backend/app/crud/crud_opera_log.py +++ b/backend/app/crud/crud_opera_log.py @@ -12,7 +12,7 @@ from backend.app.schemas.opera_log import CreateOperaLog, UpdateOperaLog class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]): async def get_all(self, username: str | None = None, status: int | None = None, ip: str | None = None) -> Select: - se = select(self.model).order_by(desc(self.model.create_time)) + se = select(self.model).order_by(desc(self.model.created_time)) where_list = [] if username: where_list.append(self.model.username.like(f'%{username}%')) diff --git a/backend/app/crud/crud_role.py b/backend/app/crud/crud_role.py index b8e2e84..1db5693 100644 --- a/backend/app/crud/crud_role.py +++ b/backend/app/crud/crud_role.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from typing import NoReturn -from sqlalchemy import select, update, delete +from sqlalchemy import select, update, delete, desc from sqlalchemy.orm import selectinload from backend.app.crud.base import CRUDBase @@ -21,7 +21,7 @@ class CRUDRole(CRUDBase[Role, CreateRole, UpdateRole]): return role.scalars().first() async def get_all(self, name: str = None, data_scope: int = None): - se = select(self.model).options(selectinload(self.model.menus)).order_by(self.model.created_time.desc()) + se = select(self.model).options(selectinload(self.model.menus)).order_by(desc(self.model.created_time)) where_list = [] if name: where_list.append(self.model.name.like(f'%{name}%')) @@ -35,19 +35,17 @@ class CRUDRole(CRUDBase[Role, CreateRole, UpdateRole]): role = await db.execute(select(self.model).where(self.model.name == name)) return role.scalars().first() - async def create(self, db, obj_in: CreateRole, user_id: int) -> NoReturn: - new_role = self.model(**obj_in.dict(exclude={'menus'}), create_user=user_id) + async def create(self, db, obj_in: CreateRole) -> NoReturn: + new_role = self.model(**obj_in.dict(exclude={'menus'})) menu_list = [] for menu_id in obj_in.menus: menu_list.append(await db.get(Menu, menu_id)) new_role.menus.append(*menu_list) db.add(new_role) - async def update(self, db, role_id: int, obj_in: UpdateRole, user_id: int) -> int: + async def update(self, db, role_id: int, obj_in: UpdateRole) -> int: role = await db.execute( - update(self.model) - .where(self.model.id == role_id) - .values(**obj_in.dict(exclude={'menus'}), update_user=user_id) + update(self.model).where(self.model.id == role_id).values(**obj_in.dict(exclude={'menus'})) ) current_role = await self.get_with_relation(db, role_id) # 删除角色所有菜单 diff --git a/backend/app/crud/crud_user.py b/backend/app/crud/crud_user.py index aa1bf9f..94491ac 100644 --- a/backend/app/crud/crud_user.py +++ b/backend/app/crud/crud_user.py @@ -23,7 +23,9 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]): return user.scalars().first() async def update_login_time(self, db: AsyncSession, username: str, login_time: datetime) -> int: - user = await db.execute(update(self.model).where(self.model.username == username).values(last_login=login_time)) + user = await db.execute( + update(self.model).where(self.model.username == username).values(last_login_time=login_time) + ) await db.commit() return user.rowcount @@ -72,7 +74,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]): select(self.model) .options(selectinload(self.model.dept)) .options(selectinload(self.model.roles).selectinload(Role.menus)) - .order_by(desc(self.model.time_joined)) + .order_by(desc(self.model.join_time)) ) where_list = [] if username: diff --git a/backend/app/database/db_mysql.py b/backend/app/database/db_mysql.py index 4313626..1448ed3 100644 --- a/backend/app/database/db_mysql.py +++ b/backend/app/database/db_mysql.py @@ -9,11 +9,7 @@ from typing_extensions import Annotated from backend.app.common.log import log from backend.app.core.conf import settings -from backend.app.database.base_class import MappedBase - -""" -说明:SqlAlchemy -""" +from backend.app.models.base import MappedBase def create_engine_and_session(url: str | URL): @@ -38,11 +34,7 @@ async_engine, async_db_session = create_engine_and_session(SQLALCHEMY_DATABASE_U async def get_db() -> AsyncSession: - """ - session 生成器 - - :return: - """ + """session 生成器""" session = async_db_session() try: yield session @@ -58,8 +50,6 @@ CurrentSession = Annotated[AsyncSession, Depends(get_db)] async def create_table(): - """ - 创建数据库表 - """ + """创建数据库表""" async with async_engine.begin() as coon: await coon.run_sync(MappedBase.metadata.create_all) diff --git a/backend/app/init_test_data.py b/backend/app/init_test_data.py index 802c0cc..7cef7c9 100644 --- a/backend/app/init_test_data.py +++ b/backend/app/init_test_data.py @@ -21,15 +21,15 @@ class InitTestData: async def create_dept(self): """自动创建部门""" async with self.session.begin() as db: - department_obj = Dept(name='test', create_user=1) + department_obj = Dept(name='test') db.add(department_obj) log.info('部门 test 创建成功') async def create_role(self): """自动创建角色""" async with self.session.begin() as db: - role_obj = Role(name='test', create_user=1) - role_obj.menus.append(Menu(name='test', create_user=1)) + role_obj = Role(name='test') + role_obj.menus.append(Menu(name='test')) db.add(role_obj) log.info('角色 test 创建成功') diff --git a/backend/app/middleware/access_middleware.py b/backend/app/middleware/access_middleware.py index f85c76c..24594a6 100644 --- a/backend/app/middleware/access_middleware.py +++ b/backend/app/middleware/access_middleware.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from datetime import datetime - from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from backend.app.common.log import log +from backend.app.utils.timezone import timezone_utils class AccessMiddleware(BaseHTTPMiddleware): """记录请求日志中间件""" async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - start_time = datetime.now() + start_time = timezone_utils.get_timezone_datetime() response = await call_next(request) - end_time = datetime.now() + end_time = timezone_utils.get_timezone_datetime() log.info(f'{response.status_code} {request.client.host} {request.method} {request.url} {end_time - start_time}') return response diff --git a/backend/app/middleware/opera_log_middleware.py b/backend/app/middleware/opera_log_middleware.py index 3e3e7f8..2da8fdb 100644 --- a/backend/app/middleware/opera_log_middleware.py +++ b/backend/app/middleware/opera_log_middleware.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from datetime import datetime from typing import Any from asgiref.sync import sync_to_async @@ -16,6 +15,7 @@ from backend.app.schemas.opera_log import CreateOperaLog from backend.app.services.opera_log_service import OperaLogService from backend.app.utils.encrypt import AESCipher, Md5Cipher from backend.app.utils.request_parse import parse_user_agent_info, parse_ip_info +from backend.app.utils.timezone import timezone_utils class OperaLogMiddleware: @@ -59,9 +59,9 @@ class OperaLogMiddleware: request.state.device = device # 执行请求 - start_time = datetime.now() + start_time = timezone_utils.get_timezone_datetime() code, msg, status, err = await self.execute_request(request, send) - end_time = datetime.now() + end_time = timezone_utils.get_timezone_datetime() cost_time = (end_time - start_time).total_seconds() * 1000.0 router = request.scope.get('route') diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index b3b4a8d..c0b8b24 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -4,7 +4,7 @@ # 导入所有模型,并将 Base 放在最前面, 以便 Base 拥有它们 # imported by Alembic """ -from backend.app.database.base_class import MappedBase +from backend.app.models.base import MappedBase from backend.app.models.sys_api import Api from backend.app.models.sys_casbin_rule import CasbinRule from backend.app.models.sys_dept import Dept diff --git a/backend/app/database/base_class.py b/backend/app/models/base.py similarity index 67% rename from backend/app/database/base_class.py rename to backend/app/models/base.py index 0bb7637..2345986 100644 --- a/backend/app/database/base_class.py +++ b/backend/app/models/base.py @@ -1,32 +1,36 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import uuid from datetime import datetime -from sqlalchemy import func from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr, MappedAsDataclass from typing_extensions import Annotated +from backend.app.utils.timezone import timezone_utils + # 通用 Mapped 类型主键, 需手动添加,参考以下使用方式 # MappedBase -> id: Mapped[id_key] # DataClassBase && Base -> id: Mapped[id_key] = mapped_column(init=False) id_key = Annotated[ - int, mapped_column(primary_key=True, index=True, autoincrement=True, sort_order=-9999, comment='主键id') + int, mapped_column(primary_key=True, index=True, autoincrement=True, sort_order=-999, comment='主键id') ] -class _BaseMixin(MappedAsDataclass): - """ - Mixin 数据类 +# Mixin: 一种面向对象编程概念, 使结构变得更加清晰, `Wiki `__ +class UserMixin(MappedAsDataclass): + """用户 Mixin 数据类""" - Mixin: 一种面向对象编程概念, 使结构变得更加清晰, `Wiki `__ - """ + create_user: Mapped[int] = mapped_column(sort_order=998, comment='创建者') + update_user: Mapped[int | None] = mapped_column(init=False, default=None, sort_order=998, comment='修改者') - create_user: Mapped[int] = mapped_column(sort_order=9999, comment='创建者') - update_user: Mapped[int | None] = mapped_column(init=False, default=None, sort_order=9999, comment='修改者') - created_time: Mapped[datetime] = mapped_column(init=False, default=func.now(), sort_order=9999, comment='创建时间') + +class DateTimeMixin(MappedAsDataclass): + """日期时间 Mixin 数据类""" + + created_time: Mapped[datetime] = mapped_column( + init=False, default_factory=timezone_utils.get_timezone_datetime, sort_order=999, comment='创建时间' + ) updated_time: Mapped[datetime | None] = mapped_column( - init=False, onupdate=func.now(), sort_order=9999, comment='更新时间' + init=False, onupdate=timezone_utils.get_timezone_datetime, sort_order=999, comment='更新时间' ) @@ -53,18 +57,9 @@ class DataClassBase(MappedAsDataclass, MappedBase): __abstract__ = True -class Base(DataClassBase, _BaseMixin): +class Base(DataClassBase, DateTimeMixin): """ 声明性 Mixin 数据类基类, 带有数据类集成, 并包含 MiXin 数据类基础表结构, 你可以简单的理解它为含有基础表结构的数据类基类 """ # noqa: E501 __abstract__ = True - - -def use_uuid() -> str: - """ - 使用uuid - - :return: - """ - return uuid.uuid4().hex diff --git a/backend/app/models/sys_api.py b/backend/app/models/sys_api.py index 4f4a243..147e0ea 100644 --- a/backend/app/models/sys_api.py +++ b/backend/app/models/sys_api.py @@ -5,7 +5,7 @@ from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key class Api(Base): diff --git a/backend/app/models/sys_casbin_rule.py b/backend/app/models/sys_casbin_rule.py index d6ab000..e3ea5ca 100644 --- a/backend/app/models/sys_casbin_rule.py +++ b/backend/app/models/sys_casbin_rule.py @@ -1,17 +1,14 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column -from backend.app.database.base_class import id_key, MappedBase +from backend.app.models.base import id_key, MappedBase class CasbinRule(MappedBase): - """ - 重写 casbin_sqlalchemy_adapter 中的 casbinRule model类, 使用自定义 MappedBase, 避免产生 alembic 迁移问题 - """ + """重写 casbin 中的 casbinRule model 类, 使用自定义 Base, 避免产生 alembic 迁移问题""" __tablename__ = 'sys_casbin_rule' @@ -23,3 +20,14 @@ class CasbinRule(MappedBase): v3: Mapped[str | None] = mapped_column(String(255)) v4: Mapped[str | None] = mapped_column(String(255)) v5: Mapped[str | None] = mapped_column(String(255)) + + def __str__(self): + arr = [self.ptype] + for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5): + if v is None: + break + arr.append(v) + return ', '.join(arr) + + def __repr__(self): + return ''.format(self.id, str(self)) diff --git a/backend/app/models/sys_dept.py b/backend/app/models/sys_dept.py index f84565c..bcc487b 100644 --- a/backend/app/models/sys_dept.py +++ b/backend/app/models/sys_dept.py @@ -5,7 +5,7 @@ from typing import Union from sqlalchemy import String, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key class Dept(Base): diff --git a/backend/app/models/sys_dict_data.py b/backend/app/models/sys_dict_data.py index fe6fe5f..618c39a 100644 --- a/backend/app/models/sys_dict_data.py +++ b/backend/app/models/sys_dict_data.py @@ -4,7 +4,7 @@ from sqlalchemy import String, ForeignKey from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key class DictData(Base): diff --git a/backend/app/models/sys_dict_type.py b/backend/app/models/sys_dict_type.py index 0f4f243..98b4134 100644 --- a/backend/app/models/sys_dict_type.py +++ b/backend/app/models/sys_dict_type.py @@ -4,7 +4,7 @@ from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key class DictType(Base): diff --git a/backend/app/models/sys_login_log.py b/backend/app/models/sys_login_log.py index c3ba72d..1a06c96 100644 --- a/backend/app/models/sys_login_log.py +++ b/backend/app/models/sys_login_log.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- from datetime import datetime -from sqlalchemy import String, func +from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column -from backend.app.database.base_class import DataClassBase, id_key +from backend.app.models.base import DataClassBase, id_key +from backend.app.utils.timezone import timezone_utils class LoginLog(DataClassBase): @@ -28,4 +29,6 @@ class LoginLog(DataClassBase): device: Mapped[str | None] = mapped_column(String(50), comment='设备') msg: Mapped[str] = mapped_column(LONGTEXT, comment='提示消息') login_time: Mapped[datetime] = mapped_column(comment='登录时间') - create_time: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='创建时间') + created_time: Mapped[datetime] = mapped_column( + init=False, default_factory=timezone_utils.get_timezone_datetime, comment='创建时间' + ) diff --git a/backend/app/models/sys_menu.py b/backend/app/models/sys_menu.py index 0401887..7bda51c 100644 --- a/backend/app/models/sys_menu.py +++ b/backend/app/models/sys_menu.py @@ -6,7 +6,7 @@ from sqlalchemy import String, ForeignKey from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key from backend.app.models.sys_role_menu import sys_role_menu diff --git a/backend/app/models/sys_opera_log.py b/backend/app/models/sys_opera_log.py index 593e94a..1ef047f 100644 --- a/backend/app/models/sys_opera_log.py +++ b/backend/app/models/sys_opera_log.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- from datetime import datetime -from sqlalchemy import String, func +from sqlalchemy import String from sqlalchemy.dialects.mysql import JSON, LONGTEXT from sqlalchemy.orm import Mapped, mapped_column -from backend.app.database.base_class import DataClassBase, id_key +from backend.app.models.base import DataClassBase, id_key +from backend.app.utils.timezone import timezone_utils class OperaLog(DataClassBase): @@ -33,4 +34,6 @@ class OperaLog(DataClassBase): msg: Mapped[str | None] = mapped_column(LONGTEXT, comment='提示消息') cost_time: Mapped[float] = mapped_column(insert_default=0.0, comment='请求耗时ms') opera_time: Mapped[datetime] = mapped_column(comment='操作时间') - create_time: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='创建时间') + created_time: Mapped[datetime] = mapped_column( + init=False, default_factory=timezone_utils.get_timezone_datetime, comment='创建时间' + ) diff --git a/backend/app/models/sys_role.py b/backend/app/models/sys_role.py index f2fb03e..efdb737 100644 --- a/backend/app/models/sys_role.py +++ b/backend/app/models/sys_role.py @@ -4,7 +4,7 @@ from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import Base, id_key +from backend.app.models.base import Base, id_key from backend.app.models.sys_role_menu import sys_role_menu from backend.app.models.sys_user_role import sys_user_role diff --git a/backend/app/models/sys_role_menu.py b/backend/app/models/sys_role_menu.py index 73835df..48bff2c 100644 --- a/backend/app/models/sys_role_menu.py +++ b/backend/app/models/sys_role_menu.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from sqlalchemy import Table, Column, ForeignKey, INT, Integer -from backend.app.database.base_class import MappedBase +from backend.app.models.base import MappedBase sys_role_menu = Table( 'sys_role_menu', diff --git a/backend/app/models/sys_user.py b/backend/app/models/sys_user.py index 3b7614f..2d913ef 100644 --- a/backend/app/models/sys_user.py +++ b/backend/app/models/sys_user.py @@ -2,21 +2,23 @@ # -*- coding: utf-8 -*- from datetime import datetime from typing import Union +from uuid import uuid4 -from sqlalchemy import func, String, ForeignKey +from sqlalchemy import String, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship -from backend.app.database.base_class import use_uuid, id_key, DataClassBase +from backend.app.models.base import id_key, Base from backend.app.models.sys_user_role import sys_user_role +from backend.app.utils.timezone import timezone_utils -class User(DataClassBase): +class User(Base): """用户表""" __tablename__ = 'sys_user' id: Mapped[id_key] = mapped_column(init=False) - user_uuid: Mapped[str] = mapped_column(String(50), init=False, insert_default=use_uuid, unique=True) + uuid: Mapped[str] = mapped_column(String(50), init=False, default_factory=uuid4, unique=True) username: Mapped[str] = mapped_column(String(20), unique=True, index=True, comment='用户名') nickname: Mapped[str] = mapped_column(String(20), unique=True, comment='昵称') password: Mapped[str] = mapped_column(String(255), comment='密码') @@ -26,8 +28,12 @@ class User(DataClassBase): is_multi_login: Mapped[bool] = mapped_column(default=False, comment='是否重复登陆(0否 1是)') avatar: Mapped[str | None] = mapped_column(String(255), default=None, comment='头像') phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号') - time_joined: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='注册时间') - last_login: Mapped[datetime | None] = mapped_column(init=False, onupdate=func.now(), comment='上次登录') + join_time: Mapped[datetime] = mapped_column( + init=False, default_factory=timezone_utils.get_timezone_datetime, comment='注册时间' + ) + last_login_time: Mapped[datetime | None] = mapped_column( + init=False, onupdate=timezone_utils.get_timezone_datetime, comment='上次登录' + ) # 部门用户一对多 dept_id: Mapped[int | None] = mapped_column( ForeignKey('sys_dept.id', ondelete='SET NULL'), default=None, comment='部门关联ID' diff --git a/backend/app/models/sys_user_role.py b/backend/app/models/sys_user_role.py index c29d16f..8174111 100644 --- a/backend/app/models/sys_user_role.py +++ b/backend/app/models/sys_user_role.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from sqlalchemy import Table, Column, ForeignKey, INT, Integer -from backend.app.database.base_class import MappedBase +from backend.app.models.base import MappedBase sys_user_role = Table( 'sys_user_role', diff --git a/backend/app/schemas/api.py b/backend/app/schemas/api.py index 09cb425..708d870 100644 --- a/backend/app/schemas/api.py +++ b/backend/app/schemas/api.py @@ -31,8 +31,6 @@ class UpdateApi(ApiBase): class GetAllApi(ApiBase): id: int - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None diff --git a/backend/app/schemas/dept.py b/backend/app/schemas/dept.py index 096400d..5cfb379 100644 --- a/backend/app/schemas/dept.py +++ b/backend/app/schemas/dept.py @@ -48,8 +48,6 @@ class UpdateDept(DeptBase): class GetAllDept(DeptBase): id: int del_flag: bool - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None diff --git a/backend/app/schemas/dict_data.py b/backend/app/schemas/dict_data.py index 8ad33f2..51c6776 100644 --- a/backend/app/schemas/dict_data.py +++ b/backend/app/schemas/dict_data.py @@ -29,8 +29,6 @@ class UpdateDictData(DictDataBase): class GetAllDictData(DictDataBase): id: int type: GetAllDictType - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None diff --git a/backend/app/schemas/dict_type.py b/backend/app/schemas/dict_type.py index 94264c8..41de459 100644 --- a/backend/app/schemas/dict_type.py +++ b/backend/app/schemas/dict_type.py @@ -25,8 +25,6 @@ class UpdateDictType(DictTypeBase): class GetAllDictType(DictTypeBase): id: int - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None diff --git a/backend/app/schemas/login_log.py b/backend/app/schemas/login_log.py index ac0d257..d7556b8 100644 --- a/backend/app/schemas/login_log.py +++ b/backend/app/schemas/login_log.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- from datetime import datetime - from backend.app.schemas.base import SchemaBase @@ -32,7 +31,7 @@ class UpdateLoginLog(LoginLogBase): class GetAllLoginLog(LoginLogBase): id: int - create_time: datetime + created_time: datetime class Config: orm_mode = True diff --git a/backend/app/schemas/menu.py b/backend/app/schemas/menu.py index 8b5b1b2..00961b2 100644 --- a/backend/app/schemas/menu.py +++ b/backend/app/schemas/menu.py @@ -31,8 +31,6 @@ class UpdateMenu(MenuBase): class GetAllMenu(MenuBase): id: int - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None diff --git a/backend/app/schemas/opera_log.py b/backend/app/schemas/opera_log.py index bbc46ab..35553f9 100644 --- a/backend/app/schemas/opera_log.py +++ b/backend/app/schemas/opera_log.py @@ -39,7 +39,7 @@ class UpdateOperaLog(OperaLogBase): class GetAllOperaLog(OperaLogBase): id: int - create_time: datetime + created_time: datetime class Config: orm_mode = True diff --git a/backend/app/schemas/role.py b/backend/app/schemas/role.py index b4b1ef3..2ecfcd4 100644 --- a/backend/app/schemas/role.py +++ b/backend/app/schemas/role.py @@ -26,8 +26,6 @@ class UpdateRole(RoleBase): class GetAllRole(RoleBase): id: int - create_user: int - update_user: int = None created_time: datetime updated_time: datetime | None = None menus: list[GetAllMenu] diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index c4e4045..242880a 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- from datetime import datetime - from backend.app.schemas.base import SchemaBase from backend.app.schemas.user import GetUserInfoNoRelation diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index edba120..7a42c2a 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -68,13 +68,13 @@ class Avatar(SchemaBase): class GetUserInfoNoRelation(_UserInfoBase): dept_id: int | None = None id: int - user_uuid: str + uuid: str avatar: str | None = None status: StatusType = Field(default=StatusType.enable) is_superuser: bool is_multi_login: bool - time_joined: datetime = None - last_login: datetime | None = None + join_time: datetime = None + last_login_time: datetime | None = None class Config: orm_mode = True diff --git a/backend/app/services/api_service.py b/backend/app/services/api_service.py index 4da6caa..776d50b 100644 --- a/backend/app/services/api_service.py +++ b/backend/app/services/api_service.py @@ -23,17 +23,17 @@ class ApiService: return await ApiDao.get_all(name=name, method=method, path=path) @staticmethod - async def create(*, obj: CreateApi, user_id: int) -> None: + async def create(*, obj: CreateApi) -> None: async with async_db_session.begin() as db: api = await ApiDao.get_by_name(db, obj.name) if api: raise errors.ForbiddenError(msg='接口已存在') - await ApiDao.create(db, obj, user_id) + await ApiDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateApi, user_id: int) -> int: + async def update(*, pk: int, obj: UpdateApi) -> int: async with async_db_session.begin() as db: - count = await ApiDao.update(db, pk, obj, user_id) + count = await ApiDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 6de1f04..a0aac8c 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -5,7 +5,6 @@ from typing import NoReturn from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm -from pydantic.datetime_parse import parse_datetime from starlette.background import BackgroundTasks, BackgroundTask from backend.app.common import jwt @@ -19,10 +18,11 @@ from backend.app.crud.crud_user import UserDao from backend.app.database.db_mysql import async_db_session from backend.app.schemas.user import AuthLogin from backend.app.services.login_log_service import LoginLogService +from backend.app.utils.timezone import timezone_utils class AuthService: - login_time = parse_datetime(datetime.now()) + login_time = timezone_utils.get_timezone_datetime() async def swagger_login(self, *, form_data: OAuth2PasswordRequestForm): async with async_db_session() as db: @@ -89,6 +89,7 @@ class AuthService: msg='登录成功', ) background_tasks.add_task(LoginLogService.create, **log_info) + await redis_client.delete(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}') return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user @staticmethod diff --git a/backend/app/services/dept_service.py b/backend/app/services/dept_service.py index 3cc4a00..8d8572d 100644 --- a/backend/app/services/dept_service.py +++ b/backend/app/services/dept_service.py @@ -28,7 +28,7 @@ class DeptService: return tree_data @staticmethod - async def create(*, obj: CreateDept, user_id: int): + async def create(*, obj: CreateDept): async with async_db_session.begin() as db: dept = await DeptDao.get_by_name(db, obj.name) if dept: @@ -37,11 +37,10 @@ class DeptService: parent_dept = await DeptDao.get(db, obj.parent_id) if not parent_dept: raise errors.NotFoundError(msg='父级部门不存在') - new_obj = obj.dict() - await DeptDao.create(db, new_obj, user_id) + await DeptDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateDept, user_id: int): + async def update(*, pk: int, obj: UpdateDept): async with async_db_session.begin() as db: dept = await DeptDao.get(db, pk) if not dept: @@ -53,8 +52,7 @@ class DeptService: parent_dept = await DeptDao.get(db, obj.parent_id) if not parent_dept: raise errors.NotFoundError(msg='父级部门不存在') - new_obj = obj.dict() - count = await DeptDao.update(db, pk, new_obj, user_id) + count = await DeptDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/services/dict_data_service.py b/backend/app/services/dict_data_service.py index d534512..f19a040 100644 --- a/backend/app/services/dict_data_service.py +++ b/backend/app/services/dict_data_service.py @@ -24,7 +24,7 @@ class DictDataService: return await DictDataDao.get_all(label=label, value=value, status=status) @staticmethod - async def create(*, obj: CreateDictData, user_id: int) -> None: + async def create(*, obj: CreateDictData) -> None: async with async_db_session.begin() as db: dict_data = await DictDataDao.get_by_label(db, obj.label) if dict_data: @@ -32,10 +32,10 @@ class DictDataService: dict_type = await DictTypeDao.get(db, obj.type_id) if not dict_type: raise errors.ForbiddenError(msg='字典类型不存在') - await DictDataDao.create(db, obj, user_id) + await DictDataDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateDictData, user_id: int) -> int: + async def update(*, pk: int, obj: UpdateDictData) -> int: async with async_db_session.begin() as db: dict_data = await DictDataDao.get(db, pk) if not dict_data: @@ -46,7 +46,7 @@ class DictDataService: dict_type = await DictTypeDao.get(db, obj.type_id) if not dict_type: raise errors.ForbiddenError(msg='字典类型不存在') - count = await DictDataDao.update(db, pk, obj, user_id) + count = await DictDataDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/services/dict_type_service.py b/backend/app/services/dict_type_service.py index a471d8b..dcf4081 100644 --- a/backend/app/services/dict_type_service.py +++ b/backend/app/services/dict_type_service.py @@ -14,15 +14,15 @@ class DictTypeService: return await DictTypeDao.get_all(name=name, code=code, status=status) @staticmethod - async def create(*, obj: CreateDictType, user_id: int) -> None: + async def create(*, obj: CreateDictType) -> None: async with async_db_session.begin() as db: dict_type = await DictTypeDao.get_by_code(db, obj.code) if dict_type: raise errors.ForbiddenError(msg='字典类型已存在') - await DictTypeDao.create(db, obj, user_id) + await DictTypeDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateDictType, user_id: int) -> int: + async def update(*, pk: int, obj: UpdateDictType) -> int: async with async_db_session.begin() as db: dict_type = await DictTypeDao.get(db, pk) if not dict_type: @@ -30,7 +30,7 @@ class DictTypeService: if dict_type.code != obj.code: if await DictTypeDao.get_by_code(db, obj.code): raise errors.ForbiddenError(msg='字典类型已存在') - count = await DictTypeDao.update(db, pk, obj, user_id) + count = await DictTypeDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/services/login_log_service.py b/backend/app/services/login_log_service.py index f0b93da..cda1579 100644 --- a/backend/app/services/login_log_service.py +++ b/backend/app/services/login_log_service.py @@ -26,7 +26,7 @@ class LoginLogService: try: # request.state 来自 opera log 中间件定义的扩展参数,详见 opera_log_middleware.py obj_in = CreateLoginLog( - user_uuid=user.user_uuid, + user_uuid=user.uuid, username=user.username, status=status, ip=request.state.ip, diff --git a/backend/app/services/menu_service.py b/backend/app/services/menu_service.py index 4dad414..1d8bdd6 100644 --- a/backend/app/services/menu_service.py +++ b/backend/app/services/menu_service.py @@ -36,7 +36,7 @@ class MenuService: return menu_tree @staticmethod - async def create(*, obj: CreateMenu, user_id: int): + async def create(*, obj: CreateMenu): async with async_db_session.begin() as db: menu = await MenuDao.get_by_name(db, obj.name) if menu: @@ -45,11 +45,10 @@ class MenuService: parent_menu = await MenuDao.get(db, obj.parent_id) if not parent_menu: raise errors.NotFoundError(msg='父级菜单不存在') - new_obj = obj.dict() - await MenuDao.create(db, new_obj, user_id) + await MenuDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateMenu, user_id: int): + async def update(*, pk: int, obj: UpdateMenu): async with async_db_session.begin() as db: menu = await MenuDao.get(db, pk) if not menu: @@ -61,8 +60,7 @@ class MenuService: parent_menu = await MenuDao.get(db, obj.parent_id) if not parent_menu: raise errors.NotFoundError(msg='父级菜单不存在') - new_obj = obj.dict() - count = await MenuDao.update(db, pk, new_obj, user_id) + count = await MenuDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/services/role_service.py b/backend/app/services/role_service.py index 80f5758..551f2bd 100644 --- a/backend/app/services/role_service.py +++ b/backend/app/services/role_service.py @@ -24,7 +24,7 @@ class RoleService: return await RoleDao.get_all(name=name, data_scope=data_scope) @staticmethod - async def create(*, obj: CreateRole, user_id: int) -> None: + async def create(*, obj: CreateRole) -> None: async with async_db_session.begin() as db: role = await RoleDao.get_by_name(db, obj.name) if role: @@ -33,10 +33,10 @@ class RoleService: menu = await MenuDao.get(db, menu_id) if not menu: raise errors.ForbiddenError(msg='菜单不存在') - await RoleDao.create(db, obj, user_id) + await RoleDao.create(db, obj) @staticmethod - async def update(*, pk: int, obj: UpdateRole, user_id: int) -> int: + async def update(*, pk: int, obj: UpdateRole) -> int: async with async_db_session.begin() as db: role = await RoleDao.get(db, pk) if not role: @@ -49,7 +49,7 @@ class RoleService: menu = await MenuDao.get(db, menu_id) if not menu: raise errors.ForbiddenError(msg='菜单不存在') - count = await RoleDao.update(db, pk, obj, user_id) + count = await RoleDao.update(db, pk, obj) return count @staticmethod diff --git a/backend/app/tests/utils/db_mysql.py b/backend/app/tests/utils/db_mysql.py index d349b11..280d981 100644 --- a/backend/app/tests/utils/db_mysql.py +++ b/backend/app/tests/utils/db_mysql.py @@ -1,11 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - from sqlalchemy.ext.asyncio import AsyncSession - from backend.app.core.conf import settings -from backend.app.database.base_class import MappedBase +from backend.app.models.base import MappedBase from backend.app.database.db_mysql import create_engine_and_session TEST_DB_DATABASE = settings.DB_DATABASE + '_test' @@ -19,11 +17,7 @@ async_engine, async_db_session = create_engine_and_session(SQLALCHEMY_DATABASE_U async def override_get_db() -> AsyncSession: - """ - session 生成器 - - :return: - """ + """session 生成器""" session = async_db_session() try: yield session @@ -35,8 +29,6 @@ async def override_get_db() -> AsyncSession: async def create_table(): - """ - 创建数据库表 - """ + """创建数据库表""" async with async_engine.begin() as coon: await coon.run_sync(MappedBase.metadata.create_all) diff --git a/backend/app/utils/data_factory.py b/backend/app/utils/data_factory.py deleted file mode 100644 index ef60556..0000000 --- a/backend/app/utils/data_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import datetime -import uuid - - -def get_uuid_str() -> str: - """ - 生成uuid - - :return: str(uuid) - """ - return str(uuid.uuid4()) - - -def get_current_timestamp() -> float: - """ - 生成当前时间戳 - - :return: - """ - return datetime.datetime.now().timestamp() diff --git a/backend/app/utils/datetime.py b/backend/app/utils/datetime.py deleted file mode 100644 index 33a6d71..0000000 --- a/backend/app/utils/datetime.py +++ /dev/null @@ -1,224 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import datetime -import pytz - -from backend.app.core.conf import settings - - -class DateTimeUtils: - def __init__(self, timezone_str=settings.DATETIME_TIMEZONE): - """ - 初始化函数,设置时区 - - :param timezone_str: 时区字符串,默认为 UTC - """ - self.timezone_str = timezone_str - self.timezone = pytz.timezone(self.timezone_str) - - def get_current_time(self) -> datetime.datetime: - """ - 获取当前时间 - - :return: 当前时间的 datetime 对象 - """ - return datetime.datetime.now(self.timezone) - - @staticmethod - def get_current_timestamp() -> int: - """ - 获取当前时间戳 (秒) - - :return: 当前时间戳 (秒) - """ - return int(datetime.datetime.now().timestamp()) - - @staticmethod - def get_current_milliseconds() -> int: - """ - 获取当前时间戳 (毫秒) - - :return: 当前时间戳 (毫秒) - """ - return int(datetime.datetime.now().timestamp() * 1000) - - def timestamp_to_datetime(self, timestamp: int) -> datetime.datetime: - """ - 时间戳转 datetime 对象 - - :param timestamp: 时间戳 (秒) - :return: datetime 对象 - """ - return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=self.timezone) - - def datetime_to_timestamp(self, dt: datetime.datetime) -> int: - """ - datetime 对象转时间戳(秒) - - :param dt: datetime 对象 - :return: 时间戳 (秒) - """ - return int(dt.astimezone(self.timezone).timestamp()) - - def datetime_to_milliseconds(self, dt: datetime.datetime) -> int: - """ - datetime 对象转时间戳(毫秒) - - :param dt: datetime 对象 - :return: 时间戳 (毫秒) - """ - return int(dt.astimezone(self.timezone).timestamp() * 1000) - - def str_to_datetime(self, time_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime.datetime: - """ - 时间字符串转 datetime 对象 - - :param time_str: 时间字符串 - :param format_str: 时间字符串的格式,默认为 '%Y-%m-%d %H:%M:%S' - :return: datetime 对象 - """ - return datetime.datetime.strptime(time_str, format_str).replace(tzinfo=self.timezone) - - def datetime_to_str(self, dt: datetime.datetime, format_str: str = settings.DATETIME_FORMAT) -> str: - """ - datetime 对象转时间字符串 - - :param dt: datetime 对象 - :param format_str: 时间字符串的格式,默认为 '%Y-%m-%d %H:%M:%S' - :return: 时间字符串 - """ - return dt.astimezone(self.timezone).strftime(format_str) - - @staticmethod - def get_timezone(timezone_str: str) -> pytz.timezone: - """ - 获取指定时区的 pytz.timezone 对象 - - :param timezone_str: 时区字符串 - :return: pytz.timezone 对象 - """ - return pytz.timezone(timezone_str) - - def get_timezone_time(self, timezone_str: str) -> datetime.datetime: - """ - 获取指定时区的当前时间 - - :param timezone_str: 时区字符串 - :return: 当前时间的 datetime 对象 - """ - timezone = self.get_timezone(timezone_str) - return datetime.datetime.now(timezone) - - def datetime_to_timezone(self, dt: datetime.datetime, timezone_str: str) -> datetime.datetime: - """ - 将 datetime 对象转换为指定时区的 datetime 对象 - - :param dt: datetime 对象 - :param timezone_str: 目标时区字符串 - :return: 目标时区的 datetime 对象 - """ - timezone = self.get_timezone(timezone_str) - return dt.astimezone(timezone) - - def datetime_to_timezone_str( - self, dt: datetime.datetime, timezone_str: str, format_str: str = settings.DATETIME_FORMAT - ) -> str: - """ - 将 datetime 对象转换为指定时区的时间字符串 - - :param dt: datetime 对象 - :param timezone_str: 目标时区字符串 - :param format_str: 时间字符串的格式,默认为 '%Y-%m-%d %H:%M:%S' - :return: 目标时区的时间字符串 - """ - dt_timezone = self.datetime_to_timezone(dt, timezone_str) - return dt_timezone.strftime(format_str) - - def str_to_timezone( - self, time_str: str, timezone_str: str, format_str: str = settings.DATETIME_FORMAT - ) -> datetime.datetime: - """ - 将指定时区的时间字符串转换为 datetime 对象 - - :param time_str: 指定时区的时间字符串 - :param timezone_str: 指定时区字符串 - :param format_str: 时间字符串的格式,默认为 '%Y-%m-%d %H:%M:%S' - :return: datetime 对象 - """ - dt = datetime.datetime.strptime(time_str, format_str).replace(tzinfo=self.timezone) - return self.datetime_to_timezone(dt, timezone_str) - - @staticmethod - def datetime_to_utc(dt: datetime.datetime) -> datetime.datetime: - """ - 将 datetime 对象转换为 UTC 时间 - - :param dt: datetime 对象 - :return: UTC 时间的 datetime 对象 - """ - return dt.astimezone(pytz.utc) - - def str_to_utc(self, time_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime.datetime: - """ - 将时间字符串转换为 UTC 时间的 datetime 对象 - - :param time_str: 时间字符串 - :param format_str: 时间字符串的格式,默认为 '%Y-%m-%d %H:%M:%S' - :return: UTC 时间的 datetime 对象 - """ - dt = datetime.datetime.strptime(time_str, format_str).replace(tzinfo=self.timezone) - return self.datetime_to_utc(dt) - - def utc_to_datetime(self, utc_time: datetime.datetime) -> datetime.datetime: - """ - 将 UTC 时间的 datetime 对象转换为指定时区的 datetime 对象 - - :param utc_time: UTC 时间的 datetime 对象 - :return: 目标时区的 datetime 对象 - """ - return utc_time.replace(tzinfo=pytz.utc).astimezone(self.timezone).replace(tzinfo=None) - - def get_expire_time(self, expires_delta: datetime.timedelta) -> datetime: - """ - 获取过期时间 - - :param expires_delta: 时间间隔对象 - :return: 过期时间的 datetime 对象 - """ - return self.get_current_time() + expires_delta - - @staticmethod - def get_expire_time_from_datetime(expire_time: datetime, seconds: int) -> datetime: - """ - 获取从指定时间开始一定时间后的过期时间 - - :param expire_time: 指定时间的 datetime 对象 - :param seconds: 时间间隔(秒) - :return: 过期时间的 datetime 对象 - """ - return expire_time + datetime.timedelta(seconds=seconds) - - @staticmethod - def get_expire_seconds(expires_delta: datetime.timedelta) -> int: - """ - 获取过期时间(秒) - - :param expires_delta: 时间间隔对象 - :return: 过期时间(秒) - """ - return int(expires_delta.total_seconds()) - - def get_expire_seconds_from_datetime(self, expire_datetime: datetime) -> int: - """ - 获取从指定时间开始到当前时间的时间间隔(秒) - - :param expire_datetime: 指定时间的 datetime 对象 - :return: 时间间隔(秒) - """ - current_time = self.get_current_time() - if expire_datetime < current_time: - return 0 - return int((expire_datetime - current_time).total_seconds()) - - -datetime_utils = DateTimeUtils() diff --git a/backend/app/utils/server_info.py b/backend/app/utils/server_info.py index 313c693..0a98d98 100644 --- a/backend/app/utils/server_info.py +++ b/backend/app/utils/server_info.py @@ -2,12 +2,12 @@ import os import platform import socket import sys -from datetime import datetime, timedelta +from datetime import timedelta from typing import List import psutil -from backend.app.core.conf import settings +from backend.app.utils.timezone import timezone_utils class ServerInfo: @@ -23,16 +23,24 @@ class ServerInfo: @staticmethod def fmt_timedelta(td: timedelta) -> str: - """格式化时间戳""" - days, rem = divmod(td.seconds, 86400) + """格式化时间差""" + total_seconds = round(td.total_seconds()) + days, rem = divmod(total_seconds, 86400) hours, rem = divmod(rem, 3600) - minutes, _ = divmod(rem, 60) - res = f'{minutes} 分钟' - if hours: - res = f'{hours} 小时 {res}' + minutes, seconds = divmod(rem, 60) + parts = [] if days: - res = f'{days} 天 {res}' - return res + parts.append('{} 天'.format(days)) + if hours: + parts.append('{} 小时'.format(hours)) + if minutes: + parts.append('{} 分钟'.format(minutes)) + if seconds: + parts.append('{} 秒'.format(seconds)) + if len(parts) == 0: + return '0 秒' + else: + return ' '.join(parts) @staticmethod def get_cpu_info() -> dict: @@ -96,7 +104,7 @@ class ServerInfo: """获取服务信息""" process = psutil.Process(os.getpid()) mem_info = process.memory_info() - start_time = datetime.fromtimestamp(process.create_time()) + start_time = timezone_utils.utc_timestamp_to_timezone_datetime(process.create_time()) return { 'name': 'Python3', 'version': platform.python_version(), @@ -105,6 +113,6 @@ class ServerInfo: 'mem_vms': ServerInfo.format_bytes(mem_info.vms), # 虚拟内存, 即当前进程申请的虚拟内存 'mem_rss': ServerInfo.format_bytes(mem_info.rss), # 常驻内存, 即当前进程实际使用的物理内存 'mem_free': ServerInfo.format_bytes(mem_info.vms - mem_info.rss), # 空闲内存 - 'startup': start_time.strftime(settings.DATETIME_FORMAT), - 'elapsed': f'{ServerInfo.fmt_timedelta(datetime.now() - start_time)}', + 'startup': start_time, + 'elapsed': f'{ServerInfo.fmt_timedelta(timezone_utils.get_timezone_datetime() - start_time)}', } diff --git a/backend/app/utils/timezone.py b/backend/app/utils/timezone.py new file mode 100644 index 0000000..a1fbdaa --- /dev/null +++ b/backend/app/utils/timezone.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import datetime + +import pytz + +from backend.app.core.conf import settings + + +class TimeZoneUtils: + def __init__(self, timezone_str=settings.DATETIME_TIMEZONE): + self.timezone = pytz.timezone(timezone_str) + + def get_timezone_datetime(self) -> datetime.datetime: + """ + 获取时区时间 + + :return: + """ + return datetime.datetime.now(self.timezone) + + def get_timezone_timestamp(self) -> int: + """ + 获取时区时间戳 (秒) + + :return: + """ + return int(self.get_timezone_datetime().timestamp()) + + def get_timezone_milliseconds(self) -> int: + """ + 获取时区时间戳 (毫秒) + + :return: + """ + return int(self.get_timezone_datetime().timestamp() * 1000) + + def datetime_to_timezone_str(self, dt: datetime.datetime, format_str: str = settings.DATETIME_FORMAT) -> str: + """ + datetime 对象转时区时间字符串 + + :param dt: + :param format_str: + :return: + """ + return dt.astimezone(self.timezone).strftime(format_str) + + def datetime_to_timezone_datetime(self, dt: datetime.datetime) -> datetime.datetime: + """ + datetime 对象转 datetime 时区对象 + + :param dt: + :return: + """ + return dt.astimezone(self.timezone) + + @staticmethod + def datetime_to_timezone_utc(dt: datetime.datetime) -> datetime.datetime: + """ + datetime 对象转 datetime UTC 对象 + + :param dt: + :return: + """ + return dt.astimezone(pytz.utc) + + def datetime_to_timezone_timestamp(self, dt: datetime.datetime) -> int: + """ + datetime 对象转时区时间戳(秒) + + :param dt: + :return: + """ + return int(dt.astimezone(self.timezone).timestamp()) + + def datetime_to_timezone_milliseconds(self, dt: datetime.datetime) -> int: + """ + datetime 对象转时区时间戳(毫秒) + + :param dt: + :return: + """ + return int(dt.astimezone(self.timezone).timestamp() * 1000) + + def str_to_timezone_utc(self, time_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime.datetime: + """ + 时间字符串转时区 datetime UTC 对象 + + :param time_str: + :param format_str: + :return: + """ + dt = datetime.datetime.strptime(time_str, format_str).replace(tzinfo=self.timezone) + return self.datetime_to_timezone_utc(dt) + + def str_to_timezone_datetime(self, time_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime.datetime: + """ + 时间字符串转 datetime 时区对象 + + :param time_str: + :param format_str: + :return: + """ + return datetime.datetime.strptime(time_str, format_str).replace(tzinfo=self.timezone) + + def utc_datetime_to_timezone_datetime(self, utc_time: datetime.datetime) -> datetime.datetime: + """ + datetime UTC 对象转 datetime 时区对象 + + :param utc_time: + :return: + """ + return utc_time.replace(tzinfo=pytz.utc).astimezone(self.timezone) + + def utc_timestamp_to_timezone_datetime(self, timestamp: int) -> datetime.datetime: + """ + 时间戳转 datetime 时区对象 + + :param timestamp: + :return: + """ + utc_datetime = datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=pytz.utc) + return self.datetime_to_timezone_datetime(utc_datetime) + + def get_timezone_expire_time(self, expires_delta: datetime.timedelta) -> datetime.datetime: + """ + 获取时区过期时间 + + :param expires_delta: + :return: + """ + return self.get_timezone_datetime() + expires_delta + + def get_timezone_expire_seconds(self, expire_datetime: datetime.datetime) -> int: + """ + 获取从指定时间开始到当前时间的时间间隔(秒) + + :param expire_datetime: 指定时间的 datetime 对象 + :return: 时间间隔(秒) + """ + timezone_datetime = self.get_timezone_datetime() + expire_datetime = self.datetime_to_timezone_datetime(expire_datetime) + if expire_datetime < timezone_datetime: + return 0 + return int((expire_datetime - timezone_datetime).total_seconds()) + + +timezone_utils = TimeZoneUtils()