#!/usr/bin/env python3 # -*- coding: utf-8 -*- from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from pydantic.errors import PydanticUserError from starlette.exceptions import HTTPException from starlette.middleware.cors import CORSMiddleware from backend.app.common.exception.errors import BaseExceptionMixin from backend.app.common.log import log from backend.app.common.response.response_code import CustomResponseCode, StandardResponseCode from backend.app.common.response.response_schema import response_base from backend.app.core.conf import settings from backend.app.schemas.base import ( CUSTOM_USAGE_ERROR_MESSAGES, CUSTOM_VALIDATION_ERROR_MESSAGES, ) from backend.app.utils.serializers import MsgSpecJSONResponse async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError): """ 数据验证异常处理 :param e: :return: """ errors = [] for error in e.errors(): custom_message = CUSTOM_VALIDATION_ERROR_MESSAGES.get(error['type']) if custom_message: ctx = error.get('ctx') error['msg'] = custom_message.format(**ctx) if ctx else custom_message errors.append(error) error = errors[0] if error.get('type') == 'json_invalid': message = 'json解析失败' else: error_input = error.get('input') field = str(error.get('loc')[-1]) error_msg = error.get('msg') message = f'{field} {error_msg},输入:{error_input}' msg = f'请求参数非法: {message}' data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None content = { 'code': StandardResponseCode.HTTP_422, 'msg': msg, 'data': data, } request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息 return MsgSpecJSONResponse(status_code=422, content=content) def register_exception(app: FastAPI): @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """ 全局HTTP异常处理 :param request: :param exc: :return: """ if settings.ENVIRONMENT == 'dev': content = { 'code': exc.status_code, 'msg': exc.detail, 'data': None, } else: res = await response_base.fail(res=CustomResponseCode.HTTP_400) content = res.model_dump() request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息 return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_400, content=content, headers=exc.headers, ) @app.exception_handler(RequestValidationError) async def fastapi_validation_exception_handler(request: Request, exc: RequestValidationError): """ fastapi 数据验证异常处理 :param request: :param exc: :return: """ return await _validation_exception_handler(request, exc) @app.exception_handler(ValidationError) async def pydantic_validation_exception_handler(request: Request, exc: ValidationError): """ pydantic 数据验证异常处理 :param request: :param exc: :return: """ return await _validation_exception_handler(request, exc) @app.exception_handler(PydanticUserError) async def pydantic_user_error_handler(request: Request, exc: PydanticUserError): """ Pydantic 用户异常处理 :param request: :param exc: :return: """ return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_500, content={ 'code': StandardResponseCode.HTTP_500, 'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code), 'data': None, }, ) @app.exception_handler(AssertionError) async def assertion_error_handler(request: Request, exc: AssertionError): """ 断言错误处理 :param request: :param exc: :return: """ if settings.ENVIRONMENT == 'dev': content = { 'code': StandardResponseCode.HTTP_500, 'msg': str(''.join(exc.args) if exc.args else exc.__doc__), 'data': None, } else: res = await response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_500, content=content, ) @app.exception_handler(Exception) async def all_exception_handler(request: Request, exc: Exception): """ 全局异常处理 :param request: :param exc: :return: """ if isinstance(exc, BaseExceptionMixin): return MsgSpecJSONResponse( status_code=StandardResponseCode.HTTP_400, content={ 'code': exc.code, 'msg': str(exc.msg), 'data': exc.data if exc.data else None, }, background=exc.background, ) else: import traceback log.error(f'未知异常: {exc}') log.error(traceback.format_exc()) if settings.ENVIRONMENT == 'dev': content = { 'code': 500, 'msg': str(exc), 'data': None, } else: res = await response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_500, content=content) if settings.MIDDLEWARE_CORS: @app.exception_handler(StandardResponseCode.HTTP_500) async def cors_status_code_500_exception_handler(request, exc): """ 跨域 500 异常处理 `Related issue `_ :param request: :param exc: :return: """ if isinstance(exc, BaseExceptionMixin): content = { 'code': exc.code, 'msg': exc.msg, 'data': exc.data, } else: if settings.ENVIRONMENT == 'dev': content = { 'code': StandardResponseCode.HTTP_500, 'msg': str(exc), 'data': None, } else: res = await response_base.fail(res=CustomResponseCode.HTTP_500) content = res.model_dump() response = MsgSpecJSONResponse( status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500, content=content, background=exc.background if isinstance(exc, BaseExceptionMixin) else None, ) origin = request.headers.get('origin') if origin: cors = CORSMiddleware( app=app, allow_origins=['*'], allow_credentials=True, allow_methods=['*'], allow_headers=['*'], ) response.headers.update(cors.simple_headers) has_cookie = 'cookie' in request.headers if cors.allow_all_origins and has_cookie: response.headers['Access-Control-Allow-Origin'] = origin elif not cors.allow_all_origins and cors.is_allowed_origin(origin=origin): response.headers['Access-Control-Allow-Origin'] = origin response.headers.add_vary_header('Origin') return response