Files
fastapi_best_architecture/backend/middleware/opera_log_middleware.py
IAseven e09062eb39 Optimize the opera log storage logic through queue (#750)
*  feat: 操作日志中间件添加批量插入功能

* Delete GEMINI.md

* 🌈 style: 修复格式化错误

* 🐞 fix: 通过asyncio.wait_for兼容py3.10中asyncio.timeout不存在

* 🦄 refactor: 重新组织操作日志批量插入代码逻辑

* 优化代码实现

* 恢复默认配置

* 恢复默认 .gitignore 文件

* 更新队列批处理逻辑
2025-08-07 17:36:10 +08:00

214 lines
7.7 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import time
from asyncio import Queue
from typing import Any
from asgiref.sync import sync_to_async
from fastapi import Response
from starlette.datastructures import UploadFile
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from backend.app.admin.schema.opera_log import CreateOperaLogParam
from backend.app.admin.service.opera_log_service import opera_log_service
from backend.common.enums import OperaLogCipherType, StatusType
from backend.common.log import log
from backend.common.queue import batch_dequeue
from backend.core.conf import settings
from backend.utils.encrypt import AESCipher, ItsDCipher, Md5Cipher
from backend.utils.trace_id import get_request_trace_id
class OperaLogMiddleware(BaseHTTPMiddleware):
"""操作日志中间件"""
opera_log_queue: Queue = Queue(maxsize=100000)
async def dispatch(self, request: Request, call_next: Any) -> Response:
"""
处理请求并记录操作日志
:param request: FastAPI 请求对象
:param call_next: 下一个中间件或路由处理函数
:return:
"""
response = None
path = request.url.path
if path in settings.OPERA_LOG_PATH_EXCLUDE or not path.startswith(f'{settings.FASTAPI_API_V1_PATH}'):
response = await call_next(request)
else:
method = request.method
args = await self.get_request_args(request)
# 执行请求
elapsed = 0.0
code = 200
msg = 'Success'
status = StatusType.enable
error = None
try:
response = await call_next(request)
elapsed = (time.perf_counter() - request.state.perf_time) * 1000
for state in [
'__request_http_exception__',
'__request_validation_exception__',
'__request_assertion_error__',
'__request_custom_exception__',
'__request_all_unknown_exception__',
'__request_cors_500_exception__',
]:
exception = getattr(request.state, state, None)
if exception:
code = exception.get('code')
msg = exception.get('msg')
log.error(f'请求异常: {msg}')
break
except Exception as e:
log.error(f'请求异常: {str(e)}')
code = getattr(e, 'code', code) # 兼容 SQLAlchemy 异常用法
msg = getattr(e, 'msg', msg)
status = StatusType.disable
error = e
# 此信息只能在请求后获取
_route = request.scope.get('route')
summary = getattr(_route, 'summary', '')
try:
# 此信息来源于 JWT 认证中间件
username = request.user.username
except AttributeError:
username = None
# 日志记录
log.debug(f'接口摘要:[{summary}]')
log.debug(f'请求地址:[{request.state.ip}]')
log.debug(f'请求参数:{args}')
# 日志创建
opera_log_in = CreateOperaLogParam(
trace_id=get_request_trace_id(request),
username=username,
method=method,
title=summary,
path=path,
ip=request.state.ip,
country=request.state.country,
region=request.state.region,
city=request.state.city,
user_agent=request.state.user_agent,
os=request.state.os,
browser=request.state.browser,
device=request.state.device,
args=args,
status=status,
code=str(code),
msg=msg,
cost_time=elapsed, # 可能和日志存在微小差异(可忽略)
opera_time=request.state.start_time,
)
await self.opera_log_queue.put(opera_log_in)
# 错误抛出
if error:
raise error from None
return response
async def get_request_args(self, request: Request) -> dict[str, Any] | None:
"""
获取请求参数
:param request: FastAPI 请求对象
:return:
"""
args = {}
# 查询参数
query_params = dict(request.query_params)
if query_params:
args['query_params'] = await self.desensitization(query_params)
# 路径参数
path_params = request.path_params
if path_params:
args['path_params'] = await self.desensitization(path_params)
# Tip: .body() 必须在 .form() 之前获取
# https://github.com/encode/starlette/discussions/1933
content_type = request.headers.get('Content-Type', '').split(';')
# 请求体
body_data = await request.body()
if body_data:
# 注意:非 json 数据默认使用 data 作为键
if 'application/json' not in content_type:
args['data'] = str(body_data)
else:
json_data = await request.json()
if isinstance(json_data, dict):
args['json'] = await self.desensitization(json_data)
else:
args['data'] = str(body_data)
# 表单参数
form_data = await request.form()
if len(form_data) > 0:
for k, v in form_data.items():
if isinstance(v, UploadFile):
form_data = {k: v.filename}
else:
form_data = {k: v}
if 'multipart/form-data' not in content_type:
args['x-www-form-urlencoded'] = await self.desensitization(form_data)
else:
args['form-data'] = await self.desensitization(form_data)
return None if not args else args
@staticmethod
@sync_to_async
def desensitization(args: dict[str, Any]) -> dict[str, Any]:
"""
脱敏处理
:param args: 需要脱敏的参数字典
:return:
"""
for key, value in args.items():
if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE:
match settings.OPERA_LOG_ENCRYPT_TYPE:
case OperaLogCipherType.aes:
args[key] = (AESCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(value)).hex()
case OperaLogCipherType.md5:
args[key] = Md5Cipher.encrypt(value)
case OperaLogCipherType.itsdangerous:
args[key] = ItsDCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(value)
case OperaLogCipherType.plan:
pass
case _:
args[key] = '******'
return args
@classmethod
async def consumer(cls) -> None:
"""操作日志消费者"""
while True:
logs = await batch_dequeue(
cls.opera_log_queue,
max_items=settings.OPERA_LOG_QUEUE_BATCH_CONSUME_SIZE,
timeout=settings.OPERA_LOG_QUEUE_TIMEOUT,
)
if logs:
try:
if settings.DATABASE_ECHO:
log.info('自动执行【操作日志批量创建】任务...')
await opera_log_service.bulk_create(objs=logs)
finally:
if not cls.opera_log_queue.empty():
cls.opera_log_queue.task_done()