Files

97 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
from typing import Annotated, AsyncGenerator
from uuid import uuid4
from fastapi import Depends
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from backend.common.log import log
from backend.common.model import MappedBase
from backend.core.conf import settings
def create_database_url(unittest: bool = False) -> URL:
    """
    Build the SQLAlchemy connection URL from the application settings.

    :param unittest: when True, target the ``<schema>_test`` database instead of the main schema
    :return: the connection URL; for MySQL it also carries the configured ``charset`` query parameter
    """
    is_mysql = settings.DATABASE_TYPE == 'mysql'
    url = URL.create(
        drivername='mysql+asyncmy' if is_mysql else 'postgresql+asyncpg',
        username=settings.DATABASE_USER,
        password=settings.DATABASE_PASSWORD,
        host=settings.DATABASE_HOST,
        port=settings.DATABASE_PORT,
        database=settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test',
    )
    if is_mysql:
        # URL objects are immutable: update_query_dict() returns a new URL rather than
        # modifying in place, so the result must be rebound or the charset is silently lost.
        url = url.update_query_dict({'charset': settings.DATABASE_CHARSET})
    return url
def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """
    Create the async database engine and a session factory bound to it.

    :param url: database connection URL
    :return: a ``(engine, session factory)`` tuple
    :raises SystemExit: with exit status 1 when the engine cannot be created
    """
    try:
        # Database engine
        engine = create_async_engine(
            url,
            echo=settings.DATABASE_ECHO,
            echo_pool=settings.DATABASE_POOL_ECHO,
            future=True,
            # Pool settings tuned for medium concurrency; the trailing notes give the
            # direction to adjust each value for low / high concurrency workloads.
            pool_size=10,  # low: decrease, high: increase
            max_overflow=20,  # low: decrease, high: increase
            pool_timeout=30,  # low: increase, high: decrease
            pool_recycle=3600,  # low: increase, high: decrease
            pool_pre_ping=True,  # low: False, high: True
            pool_use_lifo=False,  # low: False, high: True
        )
    except Exception as e:
        log.error('❌ 数据库链接失败 {}', e)
        # A bare sys.exit() exits with status 0, which reports success to whatever
        # launched the process; use a non-zero status to signal the failure.
        sys.exit(1)

    db_session = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        autoflush=False,  # disable automatic flushing
        expire_on_commit=False,  # do not expire loaded objects on commit
    )
    return engine, db_session
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Yield a database session opened from the module-level session factory.

    The session's context manager is exited once the generator resumes after the yield.
    """
    async with async_db_session() as db:
        yield db
async def create_tables() -> None:
    """Create the tables registered on ``MappedBase.metadata`` using the module-level async engine."""
    async with async_engine.begin() as conn:
        metadata = MappedBase.metadata
        # create_all is a synchronous API, so it is dispatched through run_sync.
        await conn.run_sync(metadata.create_all)
def uuid4_str() -> str:
    """Return a freshly generated UUID4 as a string (compatibility workaround for database UUID types)."""
    generated = uuid4()
    return str(generated)
# SQLAlchemy database connection URL (built once at import time)
SQLALCHEMY_DATABASE_URL = create_database_url()
# SQLAlchemy async engine and session factory
async_engine, async_db_session = create_async_engine_and_session(SQLALCHEMY_DATABASE_URL)
# Annotated session type: FastAPI resolves it through the get_db dependency
CurrentSession = Annotated[AsyncSession, Depends(get_db)]