#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any

from fastapi import Request, Response
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError
from starlette.requests import HTTPConnection

from backend.common.exception.errors import TokenError
from backend.common.log import log
from backend.common.security import jwt
from backend.core.conf import settings
from backend.database.db_mysql import async_db_session
from backend.utils.serializers import MsgSpecJSONResponse


class _AuthenticationError(AuthenticationError):
    """Override of the internal authentication error class.

    Carries the response ``code``, a human-readable ``msg`` and optional
    ``headers`` so the error handler can build a JSON response from them.
    """

    def __init__(self, *, code: int | None = None, msg: str | None = None, headers: dict[str, Any] | None = None):
        # Initialise the base exception so that ``str(exc)`` and ``exc.args``
        # carry the message instead of being empty.
        if msg is None:
            super().__init__()
        else:
            super().__init__(msg)
        self.code = code
        self.msg = msg
        self.headers = headers


class JwtAuthMiddleware(AuthenticationBackend):
    """JWT authentication backend."""

    @staticmethod
    def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> Response:
        """Override the internal authentication error handling.

        :param conn: the connection on which authentication failed (unused)
        :param exc: the authentication error raised by :meth:`authenticate`
        :return: a JSON response with ``code`` / ``msg`` / ``data`` keys
        """
        # ``exc.code`` may be None (constructor default) or an arbitrary value
        # copied from a foreign exception's ``code`` attribute; only use it as
        # the HTTP status when it is a plausible status code.
        code = exc.code
        status_code = code if isinstance(code, int) and 100 <= code <= 599 else 500
        # NOTE(review): ``exc.headers`` is not forwarded to the response —
        # confirm whether that is intentional.
        return MsgSpecJSONResponse(content={'code': exc.code, 'msg': exc.msg, 'data': None}, status_code=status_code)

    async def authenticate(self, request: Request) -> tuple[AuthCredentials, Any] | None:
        """Authenticate a request from its ``Authorization`` bearer token.

        :param request: the incoming request
        :return: ``None`` when the request carries no usable bearer credentials
            or its path is listed in ``settings.TOKEN_EXCLUDE``; otherwise a
            ``(credentials, user)`` pair
        :raises _AuthenticationError: when token validation or the user lookup fails
        """
        auth = request.headers.get('Authorization')
        if not auth:
            return None

        if request.url.path in settings.TOKEN_EXCLUDE:
            return None

        # A well-formed header is exactly "<scheme> <token>". Anything else
        # (e.g. "Bearer" alone, or extra fields) is treated as unsupported
        # credentials instead of raising ValueError on unpacking.
        parts = auth.split()
        if len(parts) != 2:
            return None
        scheme, token = parts
        if scheme.lower() != 'bearer':
            return None

        try:
            sub = await jwt.jwt_authentication(token)
            async with async_db_session() as db:
                user = await jwt.get_current_user(db, data=sub)
        except TokenError as exc:
            raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers) from exc
        except Exception as e:
            # Boundary handler: log the unexpected failure, then surface it as
            # an authentication error so the error handler can respond.
            log.exception(e)
            raise _AuthenticationError(
                code=getattr(e, 'code', 500), msg=getattr(e, 'msg', 'Internal Server Error')
            ) from e

        # Note: this return uses a non-standard pattern, so some standard
        # features are lost when authentication succeeds.
        # For the standard return pattern see: https://www.starlette.io/authentication/
        return AuthCredentials(['authenticated']), user