Add request state middleware (#426)

* Add request state middleware

* Fix linux do OAuth2 redirect uri
This commit is contained in:
Wu Clan
2024-09-21 17:57:27 +08:00
committed by GitHub
parent 6792bf9954
commit 8da993655b
6 changed files with 32 additions and 16 deletions

View File

@ -21,7 +21,7 @@ _linux_do_oauth2 = FastAPIOAuth20(_linux_do_client, admin_settings.OAUTH2_LINUX_
@router.get('', summary='获取 Linux Do 授权链接')
async def linux_do_auth2() -> ResponseModel:
auth_url = await _linux_do_client.get_authorization_url(redirect_uri=admin_settings.OAUTH2_GITHUB_REDIRECT_URI)
auth_url = await _linux_do_client.get_authorization_url(redirect_uri=admin_settings.OAUTH2_LINUX_DO_REDIRECT_URI)
return response_base.success(data=auth_url)

View File

@ -29,7 +29,6 @@ class LoginLogService:
msg: str,
) -> None:
try:
# request.state 来自 opera log 中间件定义的扩展参数,详见 opera_log_middleware.py
obj_in = CreateLoginLogParam(
user_uuid=user_uuid,
username=username,

View File

@ -17,6 +17,7 @@ from backend.database.db_mysql import create_table
from backend.database.db_redis import redis_client
from backend.middleware.jwt_auth_middleware import JwtAuthMiddleware
from backend.middleware.opera_log_middleware import OperaLogMiddleware
from backend.middleware.state_middleware import StateMiddleware
from backend.utils.demo_site import demo_site
from backend.utils.health_check import ensure_unique_route_names, http_limit_callback
from backend.utils.openapi import simplify_operation_ids
@ -126,6 +127,8 @@ def register_middleware(app: FastAPI):
from backend.middleware.access_middleware import AccessMiddleware
app.add_middleware(AccessMiddleware)
# State
app.add_middleware(StateMiddleware)
# Trace ID (required)
app.add_middleware(CorrelationIdMiddleware, validator=False)
# CORS: Always at the end

View File

@ -15,7 +15,6 @@ from backend.common.enums import OperaLogCipherType, StatusType
from backend.common.log import log
from backend.core.conf import settings
from backend.utils.encrypt import AESCipher, ItsDCipher, Md5Cipher
from backend.utils.request_parse import parse_ip_info, parse_user_agent_info
from backend.utils.timezone import timezone
from backend.utils.trace_id import get_request_trace_id
@ -30,8 +29,6 @@ class OperaLogMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# 请求解析
ip_info = await parse_ip_info(request)
ua_info = await parse_user_agent_info(request)
try:
# 此信息依赖于 jwt 中间件
username = request.user.username
@ -41,16 +38,6 @@ class OperaLogMiddleware(BaseHTTPMiddleware):
args = await self.get_request_args(request)
args = await self.desensitization(args)
# 设置附加请求信息
request.state.ip = ip_info.ip
request.state.country = ip_info.country
request.state.region = ip_info.region
request.state.city = ip_info.city
request.state.user_agent = ua_info.user_agent
request.state.os = ua_info.os
request.state.browser = ua_info.browser
request.state.device = ua_info.device
# 执行请求
start_time = timezone.now()
request_next = await self.execute_request(request, call_next)

View File

@ -0,0 +1,28 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from backend.utils.request_parse import parse_ip_info, parse_user_agent_info
class StateMiddleware(BaseHTTPMiddleware):
"""请求 state 中间件"""
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
ip_info = await parse_ip_info(request)
ua_info = parse_user_agent_info(request)
# 设置附加请求信息
request.state.ip = ip_info.ip
request.state.country = ip_info.country
request.state.region = ip_info.region
request.state.city = ip_info.city
request.state.user_agent = ua_info.user_agent
request.state.os = ua_info.os
request.state.browser = ua_info.browser
request.state.device = ua_info.device
response = await call_next(request)
return response

View File

@ -100,7 +100,6 @@ async def parse_ip_info(request: Request) -> IpInfo:
return IpInfo(ip=ip, country=country, region=region, city=city)
@sync_to_async
def parse_user_agent_info(request: Request) -> UserAgentInfo:
user_agent = request.headers.get('User-Agent')
_user_agent = parse(user_agent)