Files
Wu Clan 228ad0e82b Simplify the multivariate expression of exceptions handler (#252)
* Simplify the multivariate expression of cors middleware

* Simplify the multivariate expression of assert exception handler
2023-12-18 21:36:24 +08:00

181 lines
7.5 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from pydantic.errors import EnumMemberError, WrongConstantError
from starlette.exceptions import HTTPException
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse
from backend.app.common.exception.errors import BaseExceptionMixin
from backend.app.common.log import log
from backend.app.common.response.response_code import CustomResponseCode, StandardResponseCode
from backend.app.common.response.response_schema import ResponseModel, response_base
from backend.app.core.conf import settings
def register_exception(app: FastAPI):
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""
全局HTTP异常处理
:param request:
:param exc:
:return:
"""
content = ResponseModel(code=exc.status_code, msg=exc.detail).dict()
request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息
return JSONResponse(
status_code=StandardResponseCode.HTTP_400,
content=content if settings.ENVIRONMENT == 'dev' else await response_base.fail(CustomResponseCode.HTTP_400),
headers=exc.headers,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""
数据验证异常处理
:param request:
:param exc:
:return:
"""
message = ''
data = {}
for raw_error in exc.raw_errors:
raw_exc = raw_error.exc
if isinstance(raw_exc, ValidationError):
if hasattr(raw_exc, 'model'):
fields = raw_exc.model.__dict__.get('__fields__')
for field_key in fields.keys():
field_title = fields.get(field_key).field_info.title
data[field_key] = field_title if field_title else field_key
# 处理特殊类型异常模板信息 -> backend/app/schemas/base.py: SCHEMA_ERROR_MSG_TEMPLATES
for sub_raw_error in raw_exc.raw_errors:
sub_raw_exc = sub_raw_error.exc
if isinstance(sub_raw_exc, (EnumMemberError, WrongConstantError)):
if getattr(sub_raw_exc, 'code') == 'enum':
sub_raw_exc.__dict__['permitted'] = ', '.join(
repr(v.value)
for v in sub_raw_exc.enum_values # type: ignore
)
else:
sub_raw_exc.__dict__['permitted'] = ', '.join(
repr(v)
for v in sub_raw_exc.permitted # type: ignore
)
# 处理异常信息
for error in raw_exc.errors()[:1]:
field = str(error.get('loc')[-1])
msg = error.get('msg')
message += f'{data.get(field, field) if field != "__root__" else ""} {msg}' + '.'
elif isinstance(raw_error.exc, json.JSONDecodeError):
message += 'json解析失败'
content = ResponseModel(
code=StandardResponseCode.HTTP_422,
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}',
data={'errors': exc.errors()} if message == '' else None,
).dict()
request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息
return JSONResponse(
status_code=StandardResponseCode.HTTP_422,
content=content if settings.ENVIRONMENT == 'dev' else await response_base.fail(CustomResponseCode.HTTP_422),
)
@app.exception_handler(Exception)
async def all_exception_handler(request: Request, exc: Exception):
"""
全局异常处理
:param request:
:param exc:
:return:
"""
if isinstance(exc, BaseExceptionMixin):
return JSONResponse(
status_code=StandardResponseCode.HTTP_400,
content=ResponseModel(
code=exc.code,
msg=str(exc.msg),
data=exc.data if exc.data else None,
).dict(),
background=exc.background,
)
elif isinstance(exc, AssertionError):
if exc.args:
msg = ','.join(exc.args)
else:
msg = exc.__repr__()
if not exc.__repr__().startswith('AssertionError'):
msg = exc.__doc__
return JSONResponse(
status_code=StandardResponseCode.HTTP_500,
content=ResponseModel(
code=StandardResponseCode.HTTP_500,
msg=str(msg),
).dict()
if settings.ENVIRONMENT == 'dev'
else await response_base.fail(CustomResponseCode.HTTP_500),
)
else:
import traceback
log.error(f'未知异常: {exc}')
log.error(traceback.format_exc())
return JSONResponse(
status_code=StandardResponseCode.HTTP_500,
content=ResponseModel(code=500, msg=str(exc)).dict()
if settings.ENVIRONMENT == 'dev'
else await response_base.fail(CustomResponseCode.HTTP_500),
)
if settings.MIDDLEWARE_CORS:
@app.exception_handler(StandardResponseCode.HTTP_500)
async def cors_status_code_500_exception_handler(request, exc):
"""
跨域 500 异常处理
`Related issue <https://github.com/encode/starlette/issues/1175>`_
:param request:
:param exc:
:return:
"""
if isinstance(exc, BaseExceptionMixin):
content = ResponseModel(code=exc.code, msg=exc.msg, data=exc.data).dict()
else:
content = (
ResponseModel(code=StandardResponseCode.HTTP_500, msg=str(exc)).dict()
if settings.ENVIRONMENT == 'dev'
else await response_base.fail(CustomResponseCode.HTTP_500)
)
response = JSONResponse(
status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500,
content=content,
background=exc.background if isinstance(exc, BaseExceptionMixin) else None,
)
origin = request.headers.get('origin')
if origin:
cors = CORSMiddleware(
app=app,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
response.headers.update(cors.simple_headers)
has_cookie = 'cookie' in request.headers
if cors.allow_all_origins and has_cookie:
response.headers['Access-Control-Allow-Origin'] = origin
elif not cors.allow_all_origins and cors.is_allowed_origin(origin=origin):
response.headers['Access-Control-Allow-Origin'] = origin
response.headers.add_vary_header('Origin')
return response