diff --git a/backend/app/api/routers.py b/backend/app/api/routers.py index c18ab54..8e64b8e 100644 --- a/backend/app/api/routers.py +++ b/backend/app/api/routers.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- from fastapi import APIRouter +from backend.app.core.conf import settings from backend.app.api.v1.auth import router as auth_router from backend.app.api.v1.user import router as user_router from backend.app.api.v1.casbin import router as casbin_router @@ -16,7 +17,7 @@ from backend.app.api.v1.task_demo import router as task_demo_router from backend.app.api.v1.dict_type import router as dict_type_router from backend.app.api.v1.dict_data import router as dict_data_router -v1 = APIRouter(prefix='/v1') +v1 = APIRouter(prefix=settings.API_V1_STR) v1.include_router(auth_router) v1.include_router(user_router, prefix='/users', tags=['用户管理']) @@ -26,8 +27,8 @@ v1.include_router(role_router, prefix='/roles', tags=['角色管理']) v1.include_router(menu_router, prefix='/menus', tags=['菜单管理']) v1.include_router(api_router, prefix='/apis', tags=['API管理']) v1.include_router(config_router, prefix='/configs', tags=['系统配置']) +v1.include_router(dict_type_router, prefix='/dict_types', tags=['字典类型管理']) +v1.include_router(dict_data_router, prefix='/dict_datas', tags=['字典数据管理']) v1.include_router(login_log_router, prefix='/login_logs', tags=['登录日志管理']) v1.include_router(opera_log_router, prefix='/opera_logs', tags=['操作日志管理']) v1.include_router(task_demo_router, prefix='/tasks', tags=['任务管理']) -v1.include_router(dict_type_router, prefix='/dict_types', tags=['字典类型管理']) -v1.include_router(dict_data_router, prefix='/dict_datas', tags=['字典数据管理']) diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 696df42..a344671 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query, Request from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.pagination import PageDepends, paging_data from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -15,13 +14,13 @@ from backend.app.services.api_service import ApiService router = APIRouter() -@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth]) +@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsRBAC]) async def get_api(pk: int): api = await ApiService.get(pk=pk) return await response_base.success(data=api) -@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsRBAC, PageDepends]) async def get_all_apis( db: CurrentSession, name: Annotated[str | None, Query()] = None, diff --git a/backend/app/api/v1/auth/auth.py b/backend/app/api/v1/auth/auth.py index 503ce14..5652291 100644 --- a/backend/app/api/v1/auth/auth.py +++ b/backend/app/api/v1/auth/auth.py @@ -7,6 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm from fastapi_limiter.depends import RateLimiter from starlette.background import BackgroundTasks +from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.jwt import DependsJwtAuth from backend.app.common.response.response_schema import response_base from backend.app.schemas.token import GetLoginToken, GetSwaggerToken, GetNewToken diff --git a/backend/app/api/v1/casbin.py b/backend/app/api/v1/casbin.py index 0894929..1cb26bc 100644 --- a/backend/app/api/v1/casbin.py +++ b/backend/app/api/v1/casbin.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.pagination import PageDepends, paging_data from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -22,7 +21,7 @@ from backend.app.services.casbin_service import CasbinService router = APIRouter() -@router.get('', summary='(模糊条件)分页获取所有 casbin 规则', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取所有 casbin 规则', dependencies=[DependsRBAC, PageDepends]) async def get_all_casbin( db: CurrentSession, ptype: Annotated[str | None, Query()] = None, @@ -33,7 +32,7 @@ async def get_all_casbin( return await response_base.success(data=page_data) -@router.get('/policies', summary='获取所有 P 规则', dependencies=[DependsJwtAuth]) +@router.get('/policy', summary='获取所有 P 规则', dependencies=[DependsRBAC]) async def get_all_policies(): policies = await CasbinService.get_policy_list() return await response_base.success(data=policies) @@ -66,7 +65,7 @@ async def delete_policy(p: DeletePolicy): return await response_base.success(data=data) -@router.get('/groups', summary='获取所有 g 规则', dependencies=[DependsJwtAuth]) +@router.get('/group', summary='获取所有 g 规则', dependencies=[DependsRBAC]) async def get_all_groups(): data = await CasbinService.get_group_list() return await response_base.success(data=data) diff --git a/backend/app/api/v1/dept.py b/backend/app/api/v1/dept.py index f992183..6735d91 100644 --- a/backend/app/api/v1/dept.py +++ b/backend/app/api/v1/dept.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query, Request from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.response.response_schema import response_base from backend.app.schemas.dept import CreateDept, GetAllDept, UpdateDept from backend.app.services.dept_service import DeptService @@ -14,21 +13,21 @@ from backend.app.utils.serializers import select_to_json router = APIRouter() -@router.get('/{pk}', summary='获取部门详情', dependencies=[DependsJwtAuth]) +@router.get('/{pk}', summary='获取部门详情', dependencies=[DependsRBAC]) async def get_dept(pk: int): dept = await DeptService.get(pk=pk) data = GetAllDept(**select_to_json(dept)) return await response_base.success(data=data) -@router.get('', summary='获取所有部门展示树', dependencies=[DependsJwtAuth]) +@router.get('', summary='获取所有部门展示树', dependencies=[DependsRBAC]) async def get_all_depts( name: Annotated[str | None, Query()] = None, leader: Annotated[str | None, Query()] = None, phone: Annotated[str | None, Query()] = None, status: Annotated[bool | None, Query()] = None, ): - dept = await DeptService.get_select(name=name, leader=leader, phone=phone, status=status) + dept = await DeptService.get_dept_tree(name=name, leader=leader, phone=phone, status=status) return await response_base.success(data=dept) diff --git a/backend/app/api/v1/login_log.py b/backend/app/api/v1/login_log.py index 6477cf1..613240e 100644 --- a/backend/app/api/v1/login_log.py +++ b/backend/app/api/v1/login_log.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.pagination import paging_data, PageDepends from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -15,7 +14,7 @@ from backend.app.services.login_log_service import LoginLogService router = APIRouter() -@router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsRBAC, PageDepends]) async def get_all_login_logs( db: CurrentSession, username: Annotated[str | None, Query()] = None, diff --git a/backend/app/api/v1/menu.py b/backend/app/api/v1/menu.py index d486043..a029cb2 100644 --- a/backend/app/api/v1/menu.py +++ b/backend/app/api/v1/menu.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query, Request from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.response.response_schema import response_base from backend.app.schemas.menu import GetAllMenu, CreateMenu, UpdateMenu from backend.app.services.menu_service import MenuService @@ -14,29 +13,35 @@ from backend.app.utils.serializers import select_to_json router = APIRouter() -@router.get('/{pk}', summary='获取目录详情', dependencies=[DependsJwtAuth]) +@router.get('/sidebar', summary='获取用户菜单展示树', dependencies=[DependsRBAC]) +async def get_user_menus(request: Request): + menu = await MenuService.get_user_menu_tree(request=request) + return await response_base.success(data=menu) + + +@router.get('/{pk}', summary='获取菜单详情', dependencies=[DependsRBAC]) async def get_menu(pk: int): menu = await MenuService.get(pk=pk) data = GetAllMenu(**select_to_json(menu)) return await response_base.success(data=data) -@router.get('', summary='获取所有目录展示树', dependencies=[DependsJwtAuth]) +@router.get('', summary='获取所有菜单展示树', dependencies=[DependsRBAC]) async def get_all_menus( name: Annotated[str | None, Query()] = None, status: Annotated[bool | None, Query()] = None, ): - menu = await MenuService.get_select(name=name, status=status) + menu = await MenuService.get_menu_tree(name=name, status=status) return await response_base.success(data=menu) -@router.post('', summary='创建目录', dependencies=[DependsRBAC]) +@router.post('', summary='创建菜单', dependencies=[DependsRBAC]) async def create_menu(request: Request, obj: CreateMenu): await MenuService.create(obj=obj, user_id=request.user.id) return await response_base.success() -@router.put('/{pk}', summary='更新目录', dependencies=[DependsRBAC]) +@router.put('/{pk}', summary='更新菜单', dependencies=[DependsRBAC]) async def update_menu(request: Request, pk: int, obj: UpdateMenu): count = await MenuService.update(pk=pk, obj=obj, user_id=request.user.id) if count > 0: @@ -44,7 +49,7 @@ async def update_menu(request: Request, pk: int, obj: UpdateMenu): return await response_base.fail() -@router.delete('{pk}', summary='删除目录', dependencies=[DependsRBAC]) +@router.delete('/{pk}', summary='删除菜单', dependencies=[DependsRBAC]) async def delete_menu(pk: int): count = await MenuService.delete(pk=pk) if count > 0: diff --git a/backend/app/api/v1/opera_log.py b/backend/app/api/v1/opera_log.py index 184ca2d..e4e1e5c 100644 --- a/backend/app/api/v1/opera_log.py +++ b/backend/app/api/v1/opera_log.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.pagination import PageDepends, paging_data from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -15,7 +14,7 @@ from backend.app.services.opera_log_service import OperaLogService router = APIRouter() -@router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsRBAC, PageDepends]) async def get_all_opera_logs( db: CurrentSession, username: Annotated[str | None, Query()] = None, diff --git a/backend/app/api/v1/role.py b/backend/app/api/v1/role.py index 659ad3d..e9ad1d4 100644 --- a/backend/app/api/v1/role.py +++ b/backend/app/api/v1/role.py @@ -5,7 +5,6 @@ from typing import Annotated from fastapi import APIRouter, Query, Request from backend.app.common.casbin_rbac import DependsRBAC -from backend.app.common.jwt import DependsJwtAuth from backend.app.common.pagination import PageDepends, paging_data from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -16,14 +15,14 @@ from backend.app.utils.serializers import select_to_json router = APIRouter() -@router.get('/{pk}', summary='获取角色详情', dependencies=[DependsJwtAuth]) +@router.get('/{pk}', summary='获取角色详情', dependencies=[DependsRBAC]) async def get_role(pk: int): role = await RoleService.get(pk=pk) data = GetAllRole(**select_to_json(role)) return await response_base.success(data=data) -@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsRBAC, PageDepends]) async def get_all_roles( db: CurrentSession, name: Annotated[str | None, Query()] = None, diff --git a/backend/app/api/v1/user.py b/backend/app/api/v1/user.py index f47f4e1..098aca6 100644 --- a/backend/app/api/v1/user.py +++ b/backend/app/api/v1/user.py @@ -4,7 +4,7 @@ from typing import Annotated from fastapi import APIRouter, Query, Request -from backend.app.common.jwt import DependsJwtAuth +from backend.app.common.casbin_rbac import DependsRBAC from backend.app.common.pagination import paging_data, PageDepends from backend.app.common.response.response_schema import response_base from backend.app.database.db_mysql import CurrentSession @@ -21,7 +21,7 @@ async def user_register(obj: CreateUser): return await response_base.success() -@router.post('/password/reset', summary='密码重置', dependencies=[DependsJwtAuth]) +@router.post('/password/reset', summary='密码重置', dependencies=[DependsRBAC]) async def password_reset(request: Request, obj: ResetPassword): count = await UserService.pwd_reset(request=request, obj=obj) if count > 0: @@ -29,14 +29,14 @@ async def password_reset(request: Request, obj: ResetPassword): return await response_base.fail() -@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth]) +@router.get('/{username}', summary='查看用户信息', dependencies=[DependsRBAC]) async def get_user(username: str): current_user = await UserService.get_userinfo(username=username) data = GetAllUserInfo(**select_to_json(current_user)) return await response_base.success(data=data) -@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth]) +@router.put('/{username}', summary='更新用户信息', dependencies=[DependsRBAC]) async def update_userinfo(request: Request, username: str, obj: UpdateUser): count = await UserService.update(request=request, username=username, obj=obj) if count > 0: @@ -44,7 +44,7 @@ async def update_userinfo(request: Request, username: str, obj: UpdateUser): return await response_base.fail() -@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth]) +@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsRBAC]) async def update_avatar(request: Request, username: str, avatar: Avatar): count = await UserService.update_avatar(request=request, username=username, avatar=avatar) if count > 0: @@ -52,7 +52,7 @@ async def update_avatar(request: Request, username: str, avatar: Avatar): return await response_base.fail() -@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends]) +@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsRBAC, PageDepends]) async def get_all_users( db: CurrentSession, username: Annotated[str | None, Query()] = None, @@ -64,7 +64,7 @@ async def get_all_users( return await response_base.success(data=page_data) -@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth]) +@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsRBAC]) async def super_set(request: Request, pk: int): count = await UserService.update_permission(request=request, pk=pk) if count > 0: @@ -72,7 +72,7 @@ async def super_set(request: Request, pk: int): return await response_base.fail() -@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth]) +@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsRBAC]) async def active_set(request: Request, pk: int): count = await UserService.update_active(request=request, pk=pk) if count > 0: @@ -80,7 +80,7 @@ async def active_set(request: Request, pk: int): return await response_base.fail() -@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth]) +@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsRBAC]) async def multi_set(request: Request, pk: int): count = await UserService.update_multi_login(request=request, pk=pk) if count > 0: @@ -92,7 +92,7 @@ async def multi_set(request: Request, pk: int): path='/{username}', summary='用户注销', description='用户注销 != 用户登出,注销之后用户将从数据库删除', - dependencies=[DependsJwtAuth], + dependencies=[DependsRBAC], ) async def delete_user(request: Request, username: str): count = await UserService.delete(request=request, username=username) diff --git a/backend/app/common/casbin_rbac.py b/backend/app/common/casbin_rbac.py index 31291cc..a8a0e76 100644 --- a/backend/app/common/casbin_rbac.py +++ b/backend/app/common/casbin_rbac.py @@ -26,36 +26,45 @@ class RBAC: return enforcer - async def rbac_verify(self, request: Request, _: str = DependsJwtAuth) -> None: + async def rbac_verify(self, request: Request, _: dict = DependsJwtAuth) -> None: """ - 权限校验 + RBAC 权限校验 :param request: :param _: :return: """ + # 超级管理员免校验 super_user = request.user.is_superuser if super_user: return - + # 免鉴权的接口 method = request.method path = request.url.path if (method, path) in settings.CASBIN_EXCLUDE: return - + # 检测角色数据权限范围 user_roles = request.user.roles - data_scope = [role.data_scope for role in user_roles if role.data_scope == 1] + data_scope = any(role.data_scope == 1 for role in user_roles) if data_scope: return - - # TODO: 通过 redis 做鉴权查询优化,减少数据库查询 - user_uuid = request.user.user_uuid - enforcer = self.enforcer() - if not enforcer.enforce(user_uuid, path, method): - raise AuthorizationError + if settings.MENU_PERMISSION: + # 菜单权限校验 + path_auth = request.url.path.replace(f'{settings.API_V1_STR}', '').replace('/', ':') + menu_perms = [] + for role in user_roles: + menu_perms.extend([menu.perms for menu in role.menus]) + if not menu_perms or path_auth not in menu_perms: + raise AuthorizationError + else: + # casbin 权限校验 + user_uuid = request.user.user_uuid + enforcer = self.enforcer() + if not enforcer.enforce(user_uuid, path, method): + raise AuthorizationError RBAC = RBAC() RbacEnforcer = RBAC.enforcer() -# RBAC 依赖注入 +# RBAC 授权依赖注入 DependsRBAC = Depends(RBAC.rbac_verify) diff --git a/backend/app/common/jwt.py b/backend/app/common/jwt.py index f00216b..45b1932 100644 --- a/backend/app/common/jwt.py +++ b/backend/app/common/jwt.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from asgiref.sync import sync_to_async -from fastapi import Depends, Request +from fastapi import Request, Depends from fastapi.security import OAuth2PasswordBearer from fastapi.security.utils import get_authorization_scheme_param from jose import jwt @@ -12,6 +12,7 @@ from pydantic import ValidationError from sqlalchemy.ext.asyncio import AsyncSession from backend.app.common.exception.errors import AuthorizationError, TokenError +from backend.app.common.log import log from backend.app.common.redis import redis_client from backend.app.core.conf import settings from backend.app.crud.crud_user import UserDao @@ -120,7 +121,7 @@ def get_token(request: Request) -> str: authorization = request.headers.get('Authorization') scheme, token = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != 'bearer': - raise TokenError + raise TokenError(msg='token 无效') return token @@ -136,13 +137,13 @@ def jwt_decode(token: str) -> int: payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM]) user_id = int(payload.get('sub')) if not user_id: - raise TokenError + raise TokenError(msg='token 无效') except (jwt.JWTError, ValidationError, Exception): - raise TokenError + raise TokenError(msg='token 无效') return user_id -async def jwt_authentication(token: str) -> dict[str, int]: +async def jwt_authentication(token: str = Depends(oauth2_schema)) -> dict[str, int]: """ JWT authentication @@ -168,12 +169,18 @@ async def get_current_user(db: AsyncSession, data: dict) -> User: user_id = data.get('sub') user = await UserDao.get_with_relation(db, user_id=user_id) if not user: - raise TokenError + raise TokenError(msg='token 无效') if not user.is_active: raise AuthorizationError(msg='用户已锁定') if user.dept_id: if not user.dept.status: raise AuthorizationError(msg='用户所属部门已锁定') + if user.dept.del_flag: + raise AuthorizationError(msg='用户所属部门已删除') + if user.roles: + role_status = [role.status for role in user.roles] + if all(status == 0 for status in role_status): + raise AuthorizationError(msg='用户所属角色已锁定') return user @@ -191,5 +198,6 @@ def superuser_verify(request: Request) -> bool: return is_superuser -# Jwt verify dependency -DependsJwtAuth = Depends(oauth2_schema) +# JWT authorizes dependency injection, which can be used if the interface only +# needs to provide a token instead of RBAC permission control +DependsJwtAuth = Depends(jwt_authentication) diff --git a/backend/app/core/conf.py b/backend/app/core/conf.py index 280b874..cd826c5 100644 --- a/backend/app/core/conf.py +++ b/backend/app/core/conf.py @@ -35,7 +35,7 @@ class Settings(BaseSettings): OPERA_LOG_ENCRYPT_SECRET_KEY: str # 密钥 os.urandom(32), 需使用 bytes.hex() 方法转换为 str # FastAPI - API_V1_STR: str = '/v1' + API_V1_STR: str = '/api/v1' TITLE: str = 'FastAPI' VERSION: str = '0.0.1' DESCRIPTION: str = 'FastAPI Best Architecture' @@ -87,7 +87,7 @@ class Settings(BaseSettings): TOKEN_REDIS_PREFIX: str = 'fba_token' TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token' - # captcha + # Captcha CAPTCHA_LOGIN_REDIS_PREFIX: str = 'fba_login_captcha' CAPTCHA_LOGIN_EXPIRE_SECONDS: int = 60 * 5 # 过期时间,单位:秒 @@ -103,19 +103,28 @@ class Settings(BaseSettings): # Casbin CASBIN_RBAC_MODEL_NAME: str = 'rbac_model.conf' CASBIN_EXCLUDE: set[tuple[str, str]] = { - ('POST', '/v1/auth/swagger_login'), - ('POST', '/v1/auth/login'), - ('POST', '/v1/auth/register'), - ('POST', '/v1/auth/password/reset'), + ('POST', f'{API_V1_STR}/auth/swagger_login'), + ('POST', f'{API_V1_STR}/auth/login'), + ('POST', f'{API_V1_STR}/auth/register'), + ('GET', f'{API_V1_STR}/auth/captcha'), } + # Menu + MENU_PERMISSION: bool = False # 危险行为,开启此功能, Casbin 鉴权将失效,并将使用角色菜单鉴权 (默认关闭) + MENU_EXCLUDE: list[str] = [ + 'auth:swagger_login', + 'auth:login', + 'auth:register', + 'auth:captcha', + ] + # Opera log OPERA_LOG_EXCLUDE: list[str] = [ '/favicon.ico', DOCS_URL, REDOCS_URL, OPENAPI_URL, - '/v1/auth/swagger_login', + f'{API_V1_STR}/auth/swagger_login', ] OPERA_LOG_ENCRYPT: int = 1 # 请求入参加密, 0: AES (高性能损耗), 1: md5, 2: 不加密, other: 替换为 ****** OPERA_LOG_ENCRYPT_INCLUDE: list[str] = ['password', 'old_password', 'new_password', 'confirm_password'] diff --git a/backend/app/core/registrar.py b/backend/app/core/registrar.py index 999ca4d..6dcc69c 100644 --- a/backend/app/core/registrar.py +++ b/backend/app/core/registrar.py @@ -13,8 +13,8 @@ from backend.app.common.redis import redis_client from backend.app.common.task import scheduler from backend.app.core.conf import settings from backend.app.database.db_mysql import create_table -from backend.app.middleware.opera_log_middleware import OperaLogMiddleware from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware +from backend.app.middleware.opera_log_middleware import OperaLogMiddleware from backend.app.utils.health_check import ensure_unique_route_names from backend.app.utils.openapi import simplify_operation_ids @@ -135,6 +135,7 @@ def register_router(app: FastAPI): :param app: FastAPI :return: """ + # API app.include_router(v1) # Extra diff --git a/backend/app/crud/crud_menu.py b/backend/app/crud/crud_menu.py index ae57e44..8ef2501 100644 --- a/backend/app/crud/crud_menu.py +++ b/backend/app/crud/crud_menu.py @@ -26,6 +26,18 @@ class CRUDMenu(CRUDBase[Menu, CreateMenu, UpdateMenu]): menu = await db.execute(se) return menu.scalars().all() + async def get_role_menus(self, db, superuser: bool, menu_ids: list[int]) -> list[Menu]: + se = select(self.model).order_by(asc(self.model.sort)) + where_list = [ + self.model.menu_type.in_([0, 1]), + self.model.status == 1, + ] + if not superuser: + where_list.append(self.model.id.in_(menu_ids)) + se = se.where(and_(*where_list)) + menu = await db.execute(se) + return menu.scalars().all() + async def create(self, db, obj_in: dict) -> None: obj = self.model(**obj_in) db.add(obj) diff --git a/backend/app/middleware/jwt_auth_middleware.py b/backend/app/middleware/jwt_auth_middleware.py index 5e08e91..650daf7 100644 --- a/backend/app/middleware/jwt_auth_middleware.py +++ b/backend/app/middleware/jwt_auth_middleware.py @@ -3,14 +3,13 @@ from typing import Any from fastapi import Request, Response -from starlette.authentication import AuthenticationBackend, AuthenticationError +from starlette.authentication import AuthenticationBackend, AuthenticationError, AuthCredentials from starlette.requests import HTTPConnection from starlette.responses import JSONResponse from backend.app.common import jwt from backend.app.common.exception.errors import TokenError from backend.app.common.log import log -from backend.app.core.conf import settings from backend.app.database.db_mysql import async_db_session @@ -55,4 +54,4 @@ class JwtAuthMiddleware(AuthenticationBackend): # 请注意,此返回使用非标准模式,所以在认证通过时,将丢失某些标准特性 # 标准返回模式请查看:https://www.starlette.io/authentication/ - return auth, user + return AuthCredentials(['authenticated']), user diff --git a/backend/app/middleware/opera_log_middleware.py b/backend/app/middleware/opera_log_middleware.py index a8cdddd..5d0dcf1 100644 --- a/backend/app/middleware/opera_log_middleware.py +++ b/backend/app/middleware/opera_log_middleware.py @@ -33,7 +33,7 @@ class OperaLogMiddleware: # 排除记录白名单 path = request.url.path - if path in settings.OPERA_LOG_EXCLUDE: + if path in settings.OPERA_LOG_EXCLUDE or not path.startswith(f'{settings.API_V1_STR}'): await self.app(scope, receive, send) return @@ -48,7 +48,7 @@ class OperaLogMiddleware: method = request.method args = await self.get_request_args(request) - # 设置附加请求信息(可选) + # 设置附加请求信息 request.state.ip = ip request.state.country = country request.state.region = region @@ -120,7 +120,7 @@ class OperaLogMiddleware: await self.app(request.scope, wrapped_rcv, send) except Exception as e: log.exception(e) - code = getattr(e, 'code', 500) + code = getattr(e, 'code', '500') msg = getattr(e, 'msg', 'Internal Server Error') status = False err = e diff --git a/backend/app/models/sys_opera_log.py b/backend/app/models/sys_opera_log.py index 45a6568..7d9beae 100644 --- a/backend/app/models/sys_opera_log.py +++ b/backend/app/models/sys_opera_log.py @@ -29,7 +29,7 @@ class OperaLog(DataClassBase): device: Mapped[str | None] = mapped_column(String(50), comment='设备') args: Mapped[str | None] = mapped_column(JSON(), comment='请求参数') status: Mapped[bool] = mapped_column(comment='操作状态(0异常 1正常)') - code: Mapped[int | str] = mapped_column(insert_default=200, comment='操作状态码') + code: Mapped[str] = mapped_column(String(20), insert_default='200', comment='操作状态码') msg: Mapped[str | None] = mapped_column(LONGTEXT, comment='提示消息') cost_time: Mapped[float] = mapped_column(insert_default=0.0, comment='请求耗时ms') opera_time: Mapped[datetime] = mapped_column(comment='操作时间') diff --git a/backend/app/schemas/opera_log.py b/backend/app/schemas/opera_log.py index 5094de9..e512b3e 100644 --- a/backend/app/schemas/opera_log.py +++ b/backend/app/schemas/opera_log.py @@ -20,7 +20,7 @@ class OperaLogBase(BaseModel): device: str | None = None args: dict | None = None status: bool - code: int | str + code: str msg: str | None = None cost_time: float opera_time: datetime diff --git a/backend/app/services/dept_service.py b/backend/app/services/dept_service.py index 16fe6cd..f6905f0 100644 --- a/backend/app/services/dept_service.py +++ b/backend/app/services/dept_service.py @@ -19,7 +19,7 @@ class DeptService: return dept @staticmethod - async def get_select( + async def get_dept_tree( *, name: str | None = None, leader: str | None = None, phone: str | None = None, status: bool | None = None ): async with async_db_session() as db: diff --git a/backend/app/services/menu_service.py b/backend/app/services/menu_service.py index 89b579d..4bc1a33 100644 --- a/backend/app/services/menu_service.py +++ b/backend/app/services/menu_service.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from fastapi import Request from backend.app.common.exception import errors from backend.app.crud.crud_menu import MenuDao from backend.app.database.db_mysql import async_db_session @@ -9,7 +10,7 @@ from backend.app.utils.build_tree import get_tree_data class MenuService: @staticmethod - async def get(pk: int): + async def get(*, pk: int): async with async_db_session() as db: menu = await MenuDao.get(db, menu_id=pk) if not menu: @@ -17,14 +18,25 @@ class MenuService: return menu @staticmethod - async def get_select(name: str | None = None, status: bool | None = None): + async def get_menu_tree(*, name: str | None = None, status: bool | None = None): async with async_db_session() as db: menu_select = await MenuDao.get_all(db, name=name, status=status) menu_tree = await get_tree_data(menu_select) return menu_tree @staticmethod - async def create(obj: CreateMenu, user_id: int): + async def get_user_menu_tree(*, request: Request): + async with async_db_session() as db: + roles = request.user.roles + menu_ids = [] + for role in roles: + menu_ids.extend([menu.id for menu in role.menus]) + menu_select = await MenuDao.get_role_menus(db, request.user.is_superuser, menu_ids) + menu_tree = await get_tree_data(menu_select) + return menu_tree + + @staticmethod + async def create(*, obj: CreateMenu, user_id: int): async with async_db_session.begin() as db: menu = await MenuDao.get_by_name(db, obj.name) if menu: @@ -34,7 +46,7 @@ class MenuService: await MenuDao.create(db, new_obj) @staticmethod - async def update(pk: int, obj: UpdateMenu, user_id: int): + async def update(*, pk: int, obj: UpdateMenu, user_id: int): async with async_db_session.begin() as db: menu = await MenuDao.get(db, pk) if not menu: @@ -48,7 +60,7 @@ class MenuService: return count @staticmethod - async def delete(pk: int): + async def delete(*, pk: int): async with async_db_session.begin() as db: children = await MenuDao.get_children(db, pk) if children: diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index 7359876..2fccf97 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -7,7 +7,7 @@ from sqlalchemy import Select from backend.app.common import jwt from backend.app.common.exception import errors -from backend.app.common.jwt import get_token, jwt_decode, password_verify +from backend.app.common.jwt import get_token, password_verify from backend.app.common.redis import redis_client from backend.app.core.conf import settings from backend.app.crud.crud_dept import DeptDao @@ -104,6 +104,8 @@ class UserService: if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: + if pk == request.user.id: + raise errors.ForbiddenError(msg='禁止修改自身权限') count = await UserDao.set_super(db, pk) return count @@ -114,6 +116,8 @@ class UserService: if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: + if pk == request.user.id: + raise errors.ForbiddenError(msg='禁止修改自身状态') count = await UserDao.set_active(db, pk) return count @@ -126,7 +130,7 @@ class UserService: else: count = await UserDao.set_multi_login(db, pk) token = await get_token(request) - user_id = await jwt_decode(token) + user_id = request.user.id latest_multi_login = await UserDao.get_multi_login(db, pk) # TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现 # 当前用户修改自身时(普通/超级),除当前token外,其他token失效