mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-14 10:50:48 +08:00
97 lines
2.9 KiB
Python
97 lines
2.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
import sys
|
||
|
||
from typing import Annotated, AsyncGenerator
|
||
from uuid import uuid4
|
||
|
||
from fastapi import Depends
|
||
from sqlalchemy import URL
|
||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||
|
||
from backend.common.log import log
|
||
from backend.common.model import MappedBase
|
||
from backend.core.conf import settings
|
||
|
||
|
||
def create_database_url(unittest: bool = False) -> URL:
    """
    Build the database connection URL from the application settings.

    The driver is ``mysql+asyncmy`` when ``settings.DATABASE_TYPE`` is ``'mysql'``
    and ``postgresql+asyncpg`` otherwise.

    :param unittest: when true, target the ``<DATABASE_SCHEMA>_test`` database instead of ``DATABASE_SCHEMA``
    :return: the assembled URL; for MySQL it also carries the configured ``charset`` query parameter
    """
    url = URL.create(
        drivername='mysql+asyncmy' if settings.DATABASE_TYPE == 'mysql' else 'postgresql+asyncpg',
        username=settings.DATABASE_USER,
        password=settings.DATABASE_PASSWORD,
        host=settings.DATABASE_HOST,
        port=settings.DATABASE_PORT,
        database=settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test',
    )
    if settings.DATABASE_TYPE == 'mysql':
        # URL objects are immutable: update_query_dict() returns a NEW URL rather than
        # modifying in place, so the result has to be rebound or the charset is lost.
        url = url.update_query_dict({'charset': settings.DATABASE_CHARSET})
    return url
|
||
|
||
|
||
def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """
    Create the async database engine and a session factory bound to it.

    If the engine cannot be constructed, the error is logged and the process
    exits with a non-zero status.

    :param url: database connection URL
    :return: the engine and the ``async_sessionmaker`` bound to it
    """
    try:
        # Database engine
        engine = create_async_engine(
            url,
            echo=settings.DATABASE_ECHO,
            echo_pool=settings.DATABASE_POOL_ECHO,
            future=True,
            # Pool settings for medium concurrency; the trailing notes give the
            # direction to tune each value for low / high concurrency.
            pool_size=10,  # low: decrease, high: increase
            max_overflow=20,  # low: decrease, high: increase
            pool_timeout=30,  # low: increase, high: decrease
            pool_recycle=3600,  # low: increase, high: decrease
            pool_pre_ping=True,  # low: False, high: True
            pool_use_lifo=False,  # low: False, high: True
        )
    except Exception as e:
        log.error('❌ 数据库链接失败 {}', e)
        # A bare sys.exit() exits with status 0, which reports success to the
        # caller; use a non-zero status so the failure is visible.
        sys.exit(1)
    else:
        db_session = async_sessionmaker(
            bind=engine,
            class_=AsyncSession,
            autoflush=False,  # disable automatic flushing
            expire_on_commit=False,  # do not expire instances on commit
        )
        return engine, db_session
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session opened from ``async_db_session`` inside its async context manager."""
    async with async_db_session() as db:
        yield db
|
||
|
||
|
||
async def create_tables() -> None:
    """Create the tables registered on ``MappedBase.metadata`` using the module's async engine."""
    async with async_engine.begin() as connection:
        # create_all is a synchronous API, so it is dispatched through run_sync
        await connection.run_sync(MappedBase.metadata.create_all)
|
||
|
||
|
||
def uuid4_str() -> str:
    """Return a freshly generated UUID4 as a string (compatibility workaround for database engine UUID types)."""
    generated = uuid4()
    return str(generated)
|
||
|
||
|
||
# SQLAlchemy database connection URL
|
||
SQLALCHEMY_DATABASE_URL = create_database_url()
|
||
|
||
# SQLAlchemy async engine and session factory
|
||
async_engine, async_db_session = create_async_engine_and_session(SQLALCHEMY_DATABASE_URL)
|
||
|
||
# Annotated session type: an AsyncSession resolved through Depends(get_db)
|
||
CurrentSession = Annotated[AsyncSession, Depends(get_db)]
|