diff --git a/backend/app/task/celery.py b/backend/app/task/celery.py index e0f403b3..24b62c35 100644 --- a/backend/app/task/celery.py +++ b/backend/app/task/celery.py @@ -4,6 +4,7 @@ import celery import celery_aio_pool from backend.app.task.tasks.beat import LOCAL_BEAT_SCHEDULE +from backend.common.enums import DataBaseType from backend.core.conf import settings from backend.core.path_conf import BASE_PATH @@ -32,7 +33,7 @@ def init_celery() -> celery.Celery: broker_url = f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER_REDIS_DATABASE}' result_backend = f'db+postgresql+psycopg://{settings.DATABASE_USER}:{settings.DATABASE_PASSWORD}@{settings.DATABASE_HOST}:{settings.DATABASE_PORT}/{settings.DATABASE_SCHEMA}' - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: result_backend = result_backend.replace('postgresql+psycopg', 'mysql+pymysql') # https://docs.celeryq.dev/en/stable/userguide/configuration.html diff --git a/backend/common/model.py b/backend/common/model.py index 4f8b799f..239c3071 100644 --- a/backend/common/model.py +++ b/backend/common/model.py @@ -6,6 +6,7 @@ from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column +from backend.common.enums import DataBaseType, PrimaryKeyType from backend.core.conf import settings from backend.utils.snowflake import snowflake from backend.utils.timezone import timezone @@ -23,15 +24,11 @@ id_key = Annotated[ autoincrement=True, sort_order=-999, comment='主键 ID', - ), -] - - -# 雪花算法 Mapped 类型主键,使用方法与 id_key 相同 -# 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html -snowflake_id_key = Annotated[ - int, - mapped_column( + ) + if PrimaryKeyType.autoincrement == settings.DATABASE_PK_MODE + # 雪花算法 Mapped 类型主键 + # 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html + else mapped_column( BigInteger, primary_key=True, unique=True, @@ -46,7 +43,7 @@ snowflake_id_key = Annotated[ class UniversalText(TypeDecorator[str]): """PostgreSQL、MySQL 兼容性(长)文本类型""" - impl = LONGTEXT if settings.DATABASE_TYPE == 'mysql' else Text + impl = LONGTEXT if DataBaseType.mysql == settings.DATABASE_TYPE else Text cache_ok = True def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001 diff --git a/backend/common/schema.py b/backend/common/schema.py index e3b29810..37b0fe54 100644 --- a/backend/common/schema.py +++ b/backend/common/schema.py @@ -3,6 +3,7 @@ from typing import Annotated, Any from pydantic import BaseModel, ConfigDict, EmailStr, Field, validate_email +from backend.core.conf import settings from backend.utils.timezone import timezone CustomPhoneNumber = Annotated[str, Field(pattern=r'^1[3-9]\d{9}$')] @@ -28,6 +29,14 @@ class SchemaBase(BaseModel): }, ) + if settings.DATABASE_PK_MODE: + from pydantic import field_serializer + + # 详情:https://fastapi-practices.github.io/fastapi_best_architecture_docs/backend/reference/pk.html#%E6%B3%A8%E6%84%8F%E4%BA%8B%E9%A1%B9 + @field_serializer('id', check_fields=False) + def serialize_id(self, value: int) -> str: + return str(value) + def ser_string(value: Any) -> str | None: if value: diff --git a/backend/core/conf.py b/backend/core/conf.py index 03a08c20..6b443576 100644 --- a/backend/core/conf.py +++ b/backend/core/conf.py @@ -42,6 +42,7 @@ class Settings(BaseSettings): DATABASE_POOL_ECHO: bool | Literal['debug'] = False DATABASE_SCHEMA: str = 'fba' DATABASE_CHARSET: str = 'utf8mb4' + DATABASE_PK_MODE: Literal['autoincrement', 'snowflake'] = 'autoincrement' # .env Redis REDIS_HOST: str diff --git a/backend/database/db.py b/backend/database/db.py index 60ee9ef9..09d3ac4c 100644 --- a/backend/database/db.py +++ b/backend/database/db.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import ( create_async_engine, ) +from backend.common.enums import DataBaseType from backend.common.log import log from backend.common.model import MappedBase from backend.core.conf import settings @@ -26,14 +27,14 @@ def create_database_url(*, unittest: bool = False) -> URL: :return: """ url = URL.create( - drivername='mysql+asyncmy' if settings.DATABASE_TYPE == 'mysql' else 'postgresql+asyncpg', + drivername='mysql+asyncmy' if DataBaseType.mysql == settings.DATABASE_TYPE else 'postgresql+asyncpg', username=settings.DATABASE_USER, password=settings.DATABASE_PASSWORD, host=settings.DATABASE_HOST, port=settings.DATABASE_PORT, database=settings.DATABASE_SCHEMA if not unittest else f'{settings.DATABASE_SCHEMA}_test', ) - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: url.update_query_dict({'charset': settings.DATABASE_CHARSET}) return url diff --git a/backend/plugin/code_generator/crud/crud_code.py b/backend/plugin/code_generator/crud/crud_code.py index 83aa1450..306750eb 100644 --- a/backend/plugin/code_generator/crud/crud_code.py +++ b/backend/plugin/code_generator/crud/crud_code.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from sqlalchemy import Row, RowMapping, text from sqlalchemy.ext.asyncio import AsyncSession +from backend.common.enums import DataBaseType from backend.core.conf import settings @@ -18,7 +19,7 @@ class CRUDGen: :param table_schema: 数据库 schema 名称 :return: """ - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: sql = """ SELECT table_name AS table_name, table_comment AS table_comment FROM information_schema.tables @@ -48,7 +49,7 @@ class CRUDGen: :param table_name: 表名 :return: """ - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: sql = """ SELECT table_name AS table_name, table_comment AS table_comment FROM information_schema.tables @@ -79,7 +80,7 @@ class CRUDGen: :param table_name: 表名 :return: """ - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: sql = """ SELECT column_name AS column_name, CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_pk, diff --git a/backend/plugin/code_generator/service/column_service.py b/backend/plugin/code_generator/service/column_service.py index 46707b74..18e09370 100644 --- a/backend/plugin/code_generator/service/column_service.py +++ b/backend/plugin/code_generator/service/column_service.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from sqlalchemy.ext.asyncio import AsyncSession +from backend.common.enums import DataBaseType from backend.common.exception import errors from backend.core.conf import settings from backend.plugin.code_generator.crud.crud_column import gen_column_dao @@ -32,7 +33,7 @@ class GenColumnService: @staticmethod async def get_types() -> list[str]: """获取所有列类型""" - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: types = GenMySQLColumnType.get_member_keys() else: types = GenPostgreSQLColumnType.get_member_keys() diff --git a/backend/plugin/code_generator/utils/type_conversion.py b/backend/plugin/code_generator/utils/type_conversion.py index aae1717e..3e66118b 100644 --- a/backend/plugin/code_generator/utils/type_conversion.py +++ b/backend/plugin/code_generator/utils/type_conversion.py @@ -1,5 +1,6 @@ from functools import lru_cache +from backend.common.enums import DataBaseType from backend.core.conf import settings from backend.plugin.code_generator.enums import GenMySQLColumnType, GenPostgreSQLColumnType @@ -12,7 +13,7 @@ def sql_type_to_sqlalchemy(typing: str) -> str: :param typing: SQL 类型字符串 :return: """ - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: if typing in GenMySQLColumnType.get_member_keys(): return typing else: @@ -30,7 +31,7 @@ def sql_type_to_pydantic(typing: str) -> str: :return: """ try: - if settings.DATABASE_TYPE == 'mysql': + if DataBaseType.mysql == settings.DATABASE_TYPE: return GenMySQLColumnType[typing].value if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名 return 'str' diff --git a/backend/plugin/tools.py b/backend/plugin/tools.py index 61675028..cc6d34ed 100644 --- a/backend/plugin/tools.py +++ b/backend/plugin/tools.py @@ -79,16 +79,16 @@ async def get_plugin_sql(plugin: str, db_type: DataBaseType, pk_type: PrimaryKey """ if db_type == DataBaseType.mysql: mysql_dir = PLUGIN_DIR / plugin / 'sql' / 'mysql' - if pk_type == PrimaryKeyType.autoincrement: - sql_file = mysql_dir / 'init.sql' - else: - sql_file = mysql_dir / 'init_snowflake.sql' + sql_file = ( + mysql_dir / 'init.sql' if pk_type == PrimaryKeyType.autoincrement else mysql_dir / 'init_snowflake.sql' + ) else: postgresql_dir = PLUGIN_DIR / plugin / 'sql' / 'postgresql' - if pk_type == PrimaryKeyType.autoincrement: - sql_file = postgresql_dir / 'init.sql' - else: - sql_file = postgresql_dir / 'init_snowflake.sql' + sql_file = ( + postgresql_dir / 'init.sql' + if pk_type == PrimaryKeyType.autoincrement + else postgresql_dir / 'init_snowflake.sql' + ) path = anyio.Path(sql_file) if not await path.exists():