mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-14 19:04:00 +08:00

* ✨ feat: 操作日志中间件添加批量插入功能 * Delete GEMINI.md * 🌈 style: 修复格式化错误 * 🐞 fix: 通过asyncio.wait_for兼容py3.10中asyncio.timeout不存在 * 🦄 refactor: 重新组织操作日志批量插入代码逻辑 * 优化代码实现 * 恢复默认配置 * 恢复默认 .gitignore 文件 * 更新队列批处理逻辑
214 lines
7.7 KiB
Python
214 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
import time
|
|
|
|
from asyncio import Queue
|
|
from typing import Any
|
|
|
|
from asgiref.sync import sync_to_async
|
|
from fastapi import Response
|
|
from starlette.datastructures import UploadFile
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
|
|
from backend.app.admin.schema.opera_log import CreateOperaLogParam
|
|
from backend.app.admin.service.opera_log_service import opera_log_service
|
|
from backend.common.enums import OperaLogCipherType, StatusType
|
|
from backend.common.log import log
|
|
from backend.common.queue import batch_dequeue
|
|
from backend.core.conf import settings
|
|
from backend.utils.encrypt import AESCipher, ItsDCipher, Md5Cipher
|
|
from backend.utils.trace_id import get_request_trace_id
|
|
|
|
|
|
class OperaLogMiddleware(BaseHTTPMiddleware):
|
|
"""操作日志中间件"""
|
|
|
|
opera_log_queue: Queue = Queue(maxsize=100000)
|
|
|
|
async def dispatch(self, request: Request, call_next: Any) -> Response:
|
|
"""
|
|
处理请求并记录操作日志
|
|
|
|
:param request: FastAPI 请求对象
|
|
:param call_next: 下一个中间件或路由处理函数
|
|
:return:
|
|
"""
|
|
response = None
|
|
path = request.url.path
|
|
|
|
if path in settings.OPERA_LOG_PATH_EXCLUDE or not path.startswith(f'{settings.FASTAPI_API_V1_PATH}'):
|
|
response = await call_next(request)
|
|
else:
|
|
method = request.method
|
|
args = await self.get_request_args(request)
|
|
|
|
# 执行请求
|
|
elapsed = 0.0
|
|
code = 200
|
|
msg = 'Success'
|
|
status = StatusType.enable
|
|
error = None
|
|
try:
|
|
response = await call_next(request)
|
|
elapsed = (time.perf_counter() - request.state.perf_time) * 1000
|
|
for state in [
|
|
'__request_http_exception__',
|
|
'__request_validation_exception__',
|
|
'__request_assertion_error__',
|
|
'__request_custom_exception__',
|
|
'__request_all_unknown_exception__',
|
|
'__request_cors_500_exception__',
|
|
]:
|
|
exception = getattr(request.state, state, None)
|
|
if exception:
|
|
code = exception.get('code')
|
|
msg = exception.get('msg')
|
|
log.error(f'请求异常: {msg}')
|
|
break
|
|
except Exception as e:
|
|
log.error(f'请求异常: {str(e)}')
|
|
code = getattr(e, 'code', code) # 兼容 SQLAlchemy 异常用法
|
|
msg = getattr(e, 'msg', msg)
|
|
status = StatusType.disable
|
|
error = e
|
|
|
|
# 此信息只能在请求后获取
|
|
_route = request.scope.get('route')
|
|
summary = getattr(_route, 'summary', '')
|
|
|
|
try:
|
|
# 此信息来源于 JWT 认证中间件
|
|
username = request.user.username
|
|
except AttributeError:
|
|
username = None
|
|
|
|
# 日志记录
|
|
log.debug(f'接口摘要:[{summary}]')
|
|
log.debug(f'请求地址:[{request.state.ip}]')
|
|
log.debug(f'请求参数:{args}')
|
|
|
|
# 日志创建
|
|
opera_log_in = CreateOperaLogParam(
|
|
trace_id=get_request_trace_id(request),
|
|
username=username,
|
|
method=method,
|
|
title=summary,
|
|
path=path,
|
|
ip=request.state.ip,
|
|
country=request.state.country,
|
|
region=request.state.region,
|
|
city=request.state.city,
|
|
user_agent=request.state.user_agent,
|
|
os=request.state.os,
|
|
browser=request.state.browser,
|
|
device=request.state.device,
|
|
args=args,
|
|
status=status,
|
|
code=str(code),
|
|
msg=msg,
|
|
cost_time=elapsed, # 可能和日志存在微小差异(可忽略)
|
|
opera_time=request.state.start_time,
|
|
)
|
|
await self.opera_log_queue.put(opera_log_in)
|
|
|
|
# 错误抛出
|
|
if error:
|
|
raise error from None
|
|
|
|
return response
|
|
|
|
async def get_request_args(self, request: Request) -> dict[str, Any] | None:
|
|
"""
|
|
获取请求参数
|
|
|
|
:param request: FastAPI 请求对象
|
|
:return:
|
|
"""
|
|
args = {}
|
|
|
|
# 查询参数
|
|
query_params = dict(request.query_params)
|
|
if query_params:
|
|
args['query_params'] = await self.desensitization(query_params)
|
|
|
|
# 路径参数
|
|
path_params = request.path_params
|
|
if path_params:
|
|
args['path_params'] = await self.desensitization(path_params)
|
|
|
|
# Tip: .body() 必须在 .form() 之前获取
|
|
# https://github.com/encode/starlette/discussions/1933
|
|
content_type = request.headers.get('Content-Type', '').split(';')
|
|
|
|
# 请求体
|
|
body_data = await request.body()
|
|
if body_data:
|
|
# 注意:非 json 数据默认使用 data 作为键
|
|
if 'application/json' not in content_type:
|
|
args['data'] = str(body_data)
|
|
else:
|
|
json_data = await request.json()
|
|
if isinstance(json_data, dict):
|
|
args['json'] = await self.desensitization(json_data)
|
|
else:
|
|
args['data'] = str(body_data)
|
|
|
|
# 表单参数
|
|
form_data = await request.form()
|
|
if len(form_data) > 0:
|
|
for k, v in form_data.items():
|
|
if isinstance(v, UploadFile):
|
|
form_data = {k: v.filename}
|
|
else:
|
|
form_data = {k: v}
|
|
if 'multipart/form-data' not in content_type:
|
|
args['x-www-form-urlencoded'] = await self.desensitization(form_data)
|
|
else:
|
|
args['form-data'] = await self.desensitization(form_data)
|
|
|
|
return None if not args else args
|
|
|
|
@staticmethod
|
|
@sync_to_async
|
|
def desensitization(args: dict[str, Any]) -> dict[str, Any]:
|
|
"""
|
|
脱敏处理
|
|
|
|
:param args: 需要脱敏的参数字典
|
|
:return:
|
|
"""
|
|
for key, value in args.items():
|
|
if key in settings.OPERA_LOG_ENCRYPT_KEY_INCLUDE:
|
|
match settings.OPERA_LOG_ENCRYPT_TYPE:
|
|
case OperaLogCipherType.aes:
|
|
args[key] = (AESCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(value)).hex()
|
|
case OperaLogCipherType.md5:
|
|
args[key] = Md5Cipher.encrypt(value)
|
|
case OperaLogCipherType.itsdangerous:
|
|
args[key] = ItsDCipher(settings.OPERA_LOG_ENCRYPT_SECRET_KEY).encrypt(value)
|
|
case OperaLogCipherType.plan:
|
|
pass
|
|
case _:
|
|
args[key] = '******'
|
|
|
|
return args
|
|
|
|
@classmethod
|
|
async def consumer(cls) -> None:
|
|
"""操作日志消费者"""
|
|
while True:
|
|
logs = await batch_dequeue(
|
|
cls.opera_log_queue,
|
|
max_items=settings.OPERA_LOG_QUEUE_BATCH_CONSUME_SIZE,
|
|
timeout=settings.OPERA_LOG_QUEUE_TIMEOUT,
|
|
)
|
|
if logs:
|
|
try:
|
|
if settings.DATABASE_ECHO:
|
|
log.info('自动执行【操作日志批量创建】任务...')
|
|
await opera_log_service.bulk_create(objs=logs)
|
|
finally:
|
|
if not cls.opera_log_queue.empty():
|
|
cls.opera_log_queue.task_done()
|