Update the database and Redis for easier scaling (#1015)

* Update the database and Redis for easier scaling

* Restore plugin redis variable naming
This commit is contained in:
Wu Clan
2026-01-15 16:54:43 +08:00
committed by GitHub
parent f876162456
commit 383620c899
5 changed files with 81 additions and 40 deletions

View File

@@ -2,14 +2,17 @@ from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio.session import AsyncSession
from backend.database.db import create_async_engine_and_session, create_database_url
from backend.database.db import create_database_async_engine, create_database_async_session, create_database_url
# SQLA 数据库链接
TEST_SQLALCHEMY_DATABASE_URL = create_database_url(unittest=True)
_, async_test_db_session = create_async_engine_and_session(TEST_SQLALCHEMY_DATABASE_URL)
# SALA 异步引擎和会话
async_test_engine = create_database_async_engine(TEST_SQLALCHEMY_DATABASE_URL)
async_test_db_session = create_database_async_session(async_test_engine)
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
"""session 生成器"""
"""获取数据库会话"""
async with async_test_db_session() as session:
yield session

View File

@@ -51,7 +51,7 @@ async def register_init(app: FastAPI) -> AsyncGenerator[None, None]:
await create_tables()
# 初始化 redis
await redis_client.open()
await redis_client.init()
# 初始化 limiter
await FastAPILimiter.init(

View File

@@ -1,7 +1,7 @@
import sys
from collections.abc import AsyncGenerator
from typing import Annotated
from typing import Annotated, Any
from uuid import uuid4
from fastapi import Depends
@@ -19,36 +19,41 @@ from backend.common.model import MappedBase
from backend.core.conf import settings
def create_database_url(*, unittest: bool = False) -> URL:
def create_database_url(*, unittest: bool = False, with_database: bool = True) -> URL:
"""
创建数据库链接
:param unittest: 是否用于单元测试
:param with_database: 是否包含数据库名(创建数据库时不需要)
:return:
"""
if with_database:
database = settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test'
else:
database = None if DataBaseType.mysql == settings.DATABASE_TYPE else 'postgres'
url = URL.create(
drivername='mysql+asyncmy' if DataBaseType.mysql == settings.DATABASE_TYPE else 'postgresql+asyncpg',
username=settings.DATABASE_USER,
password=settings.DATABASE_PASSWORD,
host=settings.DATABASE_HOST,
port=settings.DATABASE_PORT,
database=settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test',
database=database,
)
if DataBaseType.mysql == settings.DATABASE_TYPE:
url.update_query_dict({'charset': settings.DATABASE_CHARSET})
if DataBaseType.mysql == settings.DATABASE_TYPE and with_database:
url = url.update_query_dict({'charset': settings.DATABASE_CHARSET})
return url
def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
def create_database_async_engine(url: str | URL) -> AsyncEngine:
"""
创建数据库引擎和 Session
创建数据库异步引擎
:param url: 数据库连接 URL
:param url: 数据库连接地址
:return:
"""
try:
# 数据库引擎
engine = create_async_engine(
return create_async_engine(
url,
echo=settings.DATABASE_ECHO,
echo_pool=settings.DATABASE_POOL_ECHO,
@@ -62,16 +67,23 @@ def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_
pool_use_lifo=False, # 低False 高True
)
except Exception as e:
log.error('数据库接失败 {}', e)
log.error(f'数据库接失败 {e}')
sys.exit()
else:
db_session = async_sessionmaker(
bind=engine,
class_=AsyncSession,
autoflush=False, # 禁用自动刷新
expire_on_commit=False, # 禁用提交时过期
)
return engine, db_session
def create_database_async_session(engine: AsyncEngine) -> async_sessionmaker[AsyncSession | Any]:
"""
创建数据库异步会话
:param engine: 数据库异步引擎
:return:
"""
return async_sessionmaker(
bind=engine,
class_=AsyncSession,
autoflush=False, # 禁用自动刷新
expire_on_commit=False, # 禁用提交时过期
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
@@ -107,7 +119,8 @@ def uuid4_str() -> str:
SQLALCHEMY_DATABASE_URL = create_database_url()
# SALA 异步引擎和会话
async_engine, async_db_session = create_async_engine_and_session(SQLALCHEMY_DATABASE_URL)
async_engine = create_database_async_engine(SQLALCHEMY_DATABASE_URL)
async_db_session = create_database_async_session(async_engine)
# Session Annotated
CurrentSession = Annotated[AsyncSession, Depends(get_db)]

View File

@@ -10,32 +10,56 @@ from backend.core.conf import settings
class RedisCli(Redis):
"""Redis 客户端"""
def __init__(self) -> None:
"""初始化 Redis 客户端"""
def __init__(
self,
host: str = settings.REDIS_HOST,
port: int = settings.REDIS_PORT,
password: str = settings.REDIS_PASSWORD,
db: int = settings.REDIS_DATABASE,
socket_timeout: int = settings.REDIS_TIMEOUT,
socket_connect_timeout: int = settings.REDIS_TIMEOUT,
*,
socket_keepalive: bool = True,
health_check_interval: int = 30,
decode_responses: bool = True,
) -> None:
"""
初始化 Redis 客户端
:param host: Redis 服务器的主机地址
:param port: Redis 服务器的端口号
:param password: Redis 认证密码
:param db: 使用的 Redis 逻辑数据库索引
:param socket_timeout: Socket 读写操作的超时时间
:param socket_connect_timeout: 建立 TCP 连接时的超时时间
:param socket_keepalive: 是否开启 TCP Keepalive 探测
:param health_check_interval: 健康检查间隔时间(秒)
:param decode_responses: 是否自动将 Redis 返回的字节流bytes解码为字符串utf-8
"""
super().__init__(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
password=settings.REDIS_PASSWORD,
db=settings.REDIS_DATABASE,
socket_timeout=settings.REDIS_TIMEOUT,
socket_connect_timeout=settings.REDIS_TIMEOUT,
socket_keepalive=True, # 保持连接
health_check_interval=30, # 健康检查间隔
decode_responses=True, # 转码 utf-8
host=host,
port=port,
password=password,
db=db,
socket_timeout=socket_timeout,
socket_connect_timeout=socket_connect_timeout,
socket_keepalive=socket_keepalive,
health_check_interval=health_check_interval,
decode_responses=decode_responses,
)
async def open(self) -> None:
"""触发初始化连接"""
async def init(self) -> None:
"""初始化 Redis 服务器"""
try:
await self.ping()
except TimeoutError:
log.error('❌ 数据库 redis 连接超时')
log.error('Redis 服务器连接超时')
sys.exit()
except AuthenticationError:
log.error('❌ 数据库 redis 连接认证失败')
log.error('Redis 服务器连接认证失败')
sys.exit()
except Exception as e:
log.error('❌ 数据库 redis 连接异常 {}', e)
log.error('Redis 服务器连接异常 {}', e)
sys.exit()
async def delete_prefix(self, prefix: str, exclude: str | list[str] | None = None, batch_size: int = 1000) -> None:

View File

@@ -113,6 +113,7 @@ def parse_plugin_config() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
# 使用独立单例,避免与主线程冲突
current_redis_client = RedisCli()
run_await(current_redis_client.init)()
# 清理未知插件信息
run_await(current_redis_client.delete_prefix)(