Optimize code generation data processing (#1020)

* Optimize code generation data processing

* Add SQL script generation

* Update code gen table scripts

* Update types

* Fix model jinja and type conversion
This commit is contained in:
Wu Clan
2026-01-20 18:27:42 +08:00
committed by GitHub
parent b2785dd46f
commit 1a6aba6105
22 changed files with 467 additions and 298 deletions

View File

@@ -1,6 +1,6 @@
from collections.abc import Sequence
from sqlalchemy import Row, RowMapping, text
from sqlalchemy import RowMapping, text
from sqlalchemy.ext.asyncio import AsyncSession
from backend.common.enums import DataBaseType
@@ -21,57 +21,75 @@ class CRUDGen:
"""
if DataBaseType.mysql == settings.DATABASE_TYPE:
sql = """
SELECT table_name AS table_name, table_comment AS table_comment
FROM information_schema.tables
WHERE table_name NOT LIKE 'sys_gen_%'
AND table_schema = :table_schema;
SELECT TABLE_NAME AS TABLE_NAME,
table_comment AS table_comment
FROM
information_schema.TABLES
WHERE
TABLE_NAME NOT LIKE'sys_gen_%'
AND table_schema = :table_schema;
"""
stmt = text(sql).bindparams(table_schema=table_schema)
else:
sql = """
SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment
FROM pg_class c
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind = 'r'
AND n.nspname = 'public' -- schema 通常是 'public'
AND c.relname NOT LIKE 'sys_gen_%';
SELECT
c.relname AS TABLE_NAME,
obj_description (c.OID) AS table_comment
FROM
pg_class c
LEFT JOIN pg_namespace n ON n.OID = c.relnamespace
WHERE
c.relkind = 'r'
AND c.relname NOT LIKE'sys_gen_%'
AND n.nspname = :table_schema;
"""
stmt = text(sql)
stmt = text(sql).bindparams(table_schema='public')
result = await db.execute(stmt)
return result.mappings().all()
@staticmethod
async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]:
async def get_table(db: AsyncSession, table_schema: str, table_name: str) -> RowMapping | None:
"""
获取表信息
:param db: 数据库会话
:param table_schema: 数据库 schema 名称
:param table_name: 表名
:return:
"""
if DataBaseType.mysql == settings.DATABASE_TYPE:
sql = """
SELECT table_name AS table_name, table_comment AS table_comment
FROM information_schema.tables
WHERE table_name NOT LIKE 'sys_gen_%'
AND table_name = :table_name;
SELECT TABLE_NAME AS TABLE_NAME,
table_comment AS table_comment
FROM
information_schema.TABLES
WHERE
TABLE_NAME NOT LIKE'sys_gen_%'
AND TABLE_NAME = :table_name
AND table_schema = :table_schema;
"""
stmt = text(sql).bindparams(table_schema=table_schema, table_name=table_name)
else:
sql = """
SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment
FROM pg_class c
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind = 'r'
AND n.nspname = 'public' -- schema 通常是 'public'
AND c.relname = :table_name
AND c.relname NOT LIKE 'sys_gen_%';
SELECT
c.relname AS table_name,
obj_description (c.OID) AS table_comment
FROM
pg_class c
LEFT JOIN pg_namespace n ON n.OID = c.relnamespace
WHERE
c.relkind = 'r'
AND c.relname NOT LIKE'sys_gen_%'
AND c.relname = :table_name
AND n.nspname = :table_schema;
"""
stmt = text(sql).bindparams(table_name=table_name)
stmt = text(sql).bindparams(table_schema='public', table_name=table_name)
result = await db.execute(stmt)
return result.fetchone()
row = result.fetchone()
return row._mapping if row else None
@staticmethod
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]:
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[RowMapping]:
"""
获取所有列信息
@@ -82,54 +100,73 @@ class CRUDGen:
"""
if DataBaseType.mysql == settings.DATABASE_TYPE:
sql = """
SELECT column_name AS column_name,
CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_pk,
CASE WHEN is_nullable = 'NO' OR column_key = 'PRI' THEN 0 ELSE 1 END AS is_nullable,
ordinal_position AS sort, column_comment AS column_comment,
column_type AS column_type FROM information_schema.columns
WHERE table_schema = :table_schema
AND table_name = :table_name
AND column_name <> 'id'
AND column_name <> 'created_time'
AND column_name <> 'updated_time'
ORDER BY sort;
SELECT COLUMN_NAME AS COLUMN_NAME,
CASE
WHEN column_key = 'PRI' THEN
1
ELSE
0
END AS is_pk,
CASE
WHEN is_nullable = 'NO'
OR column_key = 'PRI' THEN
0
ELSE
1
END AS is_nullable,
ordinal_position AS sort,
column_comment AS column_comment,
column_type AS column_type
FROM
information_schema.COLUMNS
WHERE
COLUMN_NAME <> 'id'
AND COLUMN_NAME <> 'created_time'
AND COLUMN_NAME <> 'updated_time'
AND TABLE_NAME = :table_name
AND table_schema = :table_schema
ORDER BY
sort;
"""
stmt = text(sql).bindparams(table_schema=table_schema, table_name=table_name)
else:
sql = """
SELECT a.attname AS column_name,
CASE WHEN EXISTS (
SELECT 1
FROM pg_constraint c
WHERE c.conrelid = t.oid
AND c.contype = 'p'
AND a.attnum = ANY(c.conkey)
) THEN 1 ELSE 0 END AS is_pk,
CASE WHEN a.attnotnull OR EXISTS (
SELECT 1
FROM pg_constraint c
WHERE c.conrelid = t.oid
AND c.contype = 'p'
AND a.attnum = ANY(c.conkey)
) THEN 0 ELSE 1 END AS is_nullable,
a.attnum AS sort,
col_description(t.oid, a.attnum) AS column_comment,
pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type
FROM pg_attribute a
JOIN pg_class t ON a.attrelid = t.oid
JOIN pg_namespace n ON n.oid = t.relnamespace
WHERE n.nspname = 'public' -- 根据你的实际情况修改 schema 名称,通常是 'public'
AND t.relname = :table_name
AND a.attnum > 0
AND NOT a.attisdropped
AND a.attname <> 'id'
AND a.attname <> 'created_time'
AND a.attname <> 'updated_time'
ORDER BY sort;
"""
stmt = text(sql).bindparams(table_name=table_name)
SELECT
a.attname AS COLUMN_NAME,
CASE
WHEN EXISTS (SELECT 1 FROM pg_constraint c WHERE c.conrelid = t.OID AND c.contype = 'p' AND a.attnum = ANY (c.conkey)) THEN
1
ELSE
0
END AS is_pk,
CASE
WHEN a.attnotnull
OR EXISTS (SELECT 1 FROM pg_constraint c WHERE c.conrelid = t.OID AND c.contype = 'p' AND a.attnum = ANY (c.conkey)) THEN
0
ELSE
1
END AS is_nullable,
a.attnum AS sort,
col_description (t.OID, a.attnum) AS column_comment,
pg_catalog.format_type (a.atttypid, a.atttypmod) AS column_type
FROM
pg_attribute a
JOIN pg_class t ON a.attrelid = t.
OID JOIN pg_namespace n ON n.OID = t.relnamespace
WHERE
a.attnum > 0
AND NOT a.attisdropped
AND a.attname <> 'id'
AND a.attname <> 'created_time'
AND a.attname <> 'updated_time'
AND t.relname = :table_name
AND n.nspname = :table_schema
ORDER BY
sort;
""" # noqa: E501
stmt = text(sql).bindparams(table_schema='public', table_name=table_name)
result = await db.execute(stmt)
return result.fetchall()
return result.mappings().all()
gen_dao: CRUDGen = CRUDGen()

View File

@@ -19,34 +19,34 @@ class GenMySQLColumnType(StrEnum):
DateTime = 'datetime' # DATETIME
DECIMAL = 'Decimal'
DOUBLE = 'float'
Double = 'float' # DOUBLE
DOUBLE_PRECISION = 'float'
Double = 'float' # DOUBLE
Enum = 'Enum' # Enum()
FLOAT = 'float'
Float = 'float' # FLOAT
INT = 'int' # INTEGER
INTEGER = 'int'
Integer = 'int' # INTEGER
Interval = 'timedelta' # DATETIME
Interval = 'timedelta' # INTERVAL
JSON = 'dict'
LargeBinary = 'bytes' # BLOB
NCHAR = 'str'
NUMERIC = 'Decimal'
Numeric = 'Decimal' # NUMERIC
NVARCHAR = 'str' # String
Numeric = 'Decimal' # NUMERIC
PickleType = 'bytes' # BLOB
REAL = 'float'
SMALLINT = 'int'
SmallInteger = 'int' # SMALLINT
String = 'str' # String
TEXT = 'str'
Text = 'str' # TEXT
TIME = 'time'
Time = 'time' # TIME
TIMESTAMP = 'datetime'
Text = 'str' # TEXT
Time = 'time' # TIME
UUID = 'str | UUID'
Unicode = 'str' # String
UnicodeText = 'str' # TEXT
UUID = 'str | UUID'
Uuid = 'str' # CHAR(32)
VARBINARY = 'bytes'
VARCHAR = 'str' # String
@@ -77,6 +77,8 @@ class GenPostgreSQLColumnType(StrEnum):
BOOLEAN = 'bool'
Boolean = 'bool' # BOOLEAN
CHAR = 'str'
CHARACTER = 'str' # CHAR
CHARACTER_VARYING = 'str' # CHARACTER VARYING
CLOB = 'str'
DATE = 'date'
Date = 'date' # DATE
@@ -84,8 +86,8 @@ class GenPostgreSQLColumnType(StrEnum):
DateTime = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
DECIMAL = 'Decimal'
DOUBLE = 'float'
Double = 'float' # DOUBLE PRECISION
DOUBLE_PRECISION = 'float' # DOUBLE PRECISION
Double = 'float' # DOUBLE PRECISION
Enum = 'Enum' # Enum(name='enum')
FLOAT = 'float'
Float = 'float' # FLOAT
@@ -97,21 +99,25 @@ class GenPostgreSQLColumnType(StrEnum):
LargeBinary = 'bytes' # BYTEA
NCHAR = 'str'
NUMERIC = 'Decimal'
Numeric = 'Decimal' # NUMERIC
NVARCHAR = 'str' # String
Numeric = 'Decimal' # NUMERIC
PickleType = 'bytes' # BYTEA
REAL = 'float'
SMALLINT = 'int'
SmallInteger = 'int' # SMALLINT
String = 'str' # String
TEXT = 'str'
Text = 'str' # TEXT
TIME = 'time' # TIME WITHOUT TIME ZONE
Time = 'time' # TIME WITHOUT TIME ZONE
TIME_WITHOUT_TIME_ZONE = 'time' # TIME WITHOUT TIME ZONE
TIME_WITH_TIME_ZONE = 'time' # TIME WITH TIME ZONE
TIMESTAMP = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
TIMESTAMP_WITHOUT_TIME_ZONE = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
TIMESTAMP_WITH_TIME_ZONE = 'datetime' # TIMESTAMP WITH TIME ZONE
Text = 'str' # TEXT
Time = 'time' # TIME WITHOUT TIME ZONE
UUID = 'str | UUID'
Unicode = 'str' # String
UnicodeText = 'str' # TEXT
UUID = 'str | UUID'
Uuid = 'str'
VARBINARY = 'bytes'
VARCHAR = 'str' # String

View File

@@ -11,19 +11,14 @@ class GenBusiness(Base):
__tablename__ = 'gen_business'
id: Mapped[id_key] = mapped_column(init=False)
app_name: Mapped[str] = mapped_column(sa.String(64), comment='应用名称(英文)')
table_name: Mapped[str] = mapped_column(sa.String(256), unique=True, comment='表名称(英文)')
doc_comment: Mapped[str] = mapped_column(sa.String(256), comment='文档注释(用于函数/参数文档)')
app_name: Mapped[str] = mapped_column(sa.String(64), comment='应用名称')
table_name: Mapped[str] = mapped_column(sa.String(256), unique=True, comment='表名称')
doc_comment: Mapped[str] = mapped_column(sa.String(256), comment='文档注释')
table_comment: Mapped[str | None] = mapped_column(sa.String(256), default=None, comment='表描述')
# relate_model_fk: Mapped[int | None] = mapped_column(default=None, comment='关联表外键')
class_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础类名(默认为英文表名称')
schema_name: Mapped[str | None] = mapped_column(
sa.String(64), default=None, comment='Schema 名称 (默认为英文表名称)'
)
filename: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础文件名(默认为英文表名称)')
default_datetime_column: Mapped[bool] = mapped_column(default=True, comment='是否存在默认时间列')
api_version: Mapped[str] = mapped_column(sa.String(32), default='v1', comment='代码生成 api 版本,默认为 v1')
gen_path: Mapped[str | None] = mapped_column(
sa.String(256), default=None, comment='代码生成路径(默认为 app 根路径)'
)
class_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础类名')
schema_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='Schema 名称')
filename: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础文件名')
datetime_mixin: Mapped[bool] = mapped_column(default=True, comment='是否包含时间 Mixin 列')
api_version: Mapped[str] = mapped_column(sa.String(32), default='v1', comment='API 版本')
gen_path: Mapped[str | None] = mapped_column(sa.String(256), default=None, comment='生成路径')
remark: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='备注')

View File

@@ -1,6 +1,6 @@
[plugin]
summary = '代码生成'
version = '0.0.7'
version = '0.1.0'
description = '生成通用业务代码'
author = 'wu-clan'

View File

@@ -1,8 +1,10 @@
from datetime import datetime
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, field_validator
from backend.common.exception import errors
from backend.common.schema import SchemaBase
from backend.utils.pattern_validate import is_english_identifier
class GenBusinessSchemaBase(SchemaBase):
@@ -15,11 +17,19 @@ class GenBusinessSchemaBase(SchemaBase):
class_name: str | None = Field(None, description='用于 python 代码基础类名')
schema_name: str | None = Field(None, description='用于 python Schema 代码基础类名')
filename: str | None = Field(None, description='用于 python 代码基础文件名')
default_datetime_column: bool = Field(True, description='是否存在默认时间')
api_version: str = Field('v1', description='代码生成 api 版本')
gen_path: str | None = Field(None, description='代码生成路径')
datetime_mixin: bool = Field(True, description='是否包含时间 Mixin ')
api_version: str = Field('v1', description='API 版本')
gen_path: str | None = Field(None, description='生成路径(默认在 backend/app 目录下)')
remark: str | None = Field(None, description='备注')
@field_validator('app_name', 'table_name')
@classmethod
def validate_english_only(cls, v: str) -> str:
"""验证英文字段"""
if not is_english_identifier(v):
raise errors.RequestError(msg='必须以英文字母开头且只能包含英文字母和下划线')
return v
class CreateGenBusinessParam(GenBusinessSchemaBase):
"""创建代码生成业务参数"""

View File

@@ -19,8 +19,8 @@ class GenColumnSchemaBase(SchemaBase):
@field_validator('type')
@classmethod
def type_update(cls, v: str) -> str:
"""更新列类型"""
def normalize_type(cls, v: str) -> str:
"""规范化类型"""
return sql_type_to_sqlalchemy(v)

View File

@@ -21,6 +21,7 @@ from backend.plugin.code_generator.schema.business import CreateGenBusinessParam
from backend.plugin.code_generator.schema.column import CreateGenColumnParam
from backend.plugin.code_generator.schema.gen import ImportParam
from backend.plugin.code_generator.service.column_service import gen_column_service
from backend.plugin.code_generator.utils.format_code import format_python_code
from backend.plugin.code_generator.utils.gen_template import gen_template
from backend.plugin.code_generator.utils.type_conversion import sql_type_to_pydantic
@@ -50,7 +51,7 @@ class GenService:
:return:
"""
table_info = await gen_dao.get_table(db, obj.table_name)
table_info = await gen_dao.get_table(db, obj.table_schema, obj.table_name)
if not table_info:
raise errors.NotFoundError(msg='数据库表不存在')
@@ -58,13 +59,13 @@ class GenService:
if business_info:
raise errors.ConflictError(msg='已存在相同数据库表业务')
table_name = table_info[0]
table_name = table_info['table_name']
new_business = GenBusiness(
**CreateGenBusinessParam(
app_name=obj.app,
table_name=table_name,
doc_comment=table_info[1] or table_name.split('_')[-1],
table_comment=table_info[1],
doc_comment=table_info['table_comment'] or table_name.split('_')[-1],
table_comment=table_info['table_comment'],
class_name=to_pascal(table_name),
schema_name=to_pascal(table_name),
filename=table_name,
@@ -75,25 +76,27 @@ class GenService:
column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name)
for column in column_info:
column_type = column[-1].split('(')[0].upper()
column_type = column['column_type'].split('(')[0].upper()
pd_type = sql_type_to_pydantic(column_type)
await gen_column_dao.create(
db,
CreateGenColumnParam(
name=column[0],
comment=column[-2],
name=column['column_name'],
comment=column['column_comment'],
type=column_type,
sort=column[-3],
length=column[-1].split('(')[1][:-1] if pd_type == 'str' and '(' in column[-1] else 0,
is_pk=column[1],
is_nullable=column[2],
sort=column['sort'],
length=column['column_type'].split('(')[1][:-1]
if pd_type == 'str' and '(' in column['column_type']
else 0,
is_pk=column['is_pk'],
is_nullable=column['is_nullable'],
gen_business_id=new_business.id,
),
pd_type=pd_type,
)
@staticmethod
async def render_tpl_code(*, db: AsyncSession, business: GenBusiness) -> dict[str, str]:
async def _render_tpl_code(*, db: AsyncSession, business: GenBusiness) -> dict[str, str]:
"""
渲染模板代码
@@ -106,10 +109,16 @@ class GenService:
raise errors.NotFoundError(msg='代码生成模型表为空')
gen_vars = gen_template.get_vars(business, gen_models)
return {
tpl_path: await gen_template.get_template(tpl_path).render_async(**gen_vars)
for tpl_path in gen_template.get_template_files()
}
template_mapping = gen_template.get_template_path_mapping(business)
rendered_codes = {}
for template_path, output_path in template_mapping.items():
code = await gen_template.get_template(template_path).render_async(**gen_vars)
if output_path.endswith('.py'):
code = await format_python_code(code)
rendered_codes[output_path] = code
return rendered_codes
async def preview(self, *, db: AsyncSession, pk: int) -> dict[str, bytes]:
"""
@@ -119,33 +128,20 @@ class GenService:
:param pk: 业务 ID
:return:
"""
business = await gen_business_dao.get(db, pk)
if not business:
raise errors.NotFoundError(msg='业务不存在')
tpl_code_map = await self.render_tpl_code(db=db, business=business)
codes = {}
for tpl_path, code in tpl_code_map.items():
if tpl_path.startswith('python'):
rootpath = f'fastapi_best_architecture/backend/app/{business.app_name}'
template_name = tpl_path.split('/')[-1]
filepath = None
match template_name:
case 'api.jinja':
filepath = f'{rootpath}/api/{business.api_version}/{business.filename}.py'
case 'crud.jinja':
filepath = f'{rootpath}/crud/crud_{business.filename}.py'
case 'model.jinja':
filepath = f'{rootpath}/model/{business.filename}.py'
case 'schema.jinja':
filepath = f'{rootpath}/schema/{business.filename}.py'
case 'service.jinja':
filepath = f'{rootpath}/service/{business.filename}_service.py'
backend_path = 'fastapi_best_architecture/backend/app/'
if filepath:
codes[filepath] = code.encode('utf-8')
init_files = gen_template.get_init_files(business)
for filepath, content in init_files.items():
codes[f'{backend_path}{filepath}'] = content.encode('utf-8')
rendered_codes = await self._render_tpl_code(db=db, business=business)
for filepath, code in rendered_codes.items():
codes[f'{backend_path}{filepath}'] = code.encode('utf-8')
return codes
@@ -158,15 +154,20 @@ class GenService:
:param pk: 业务 ID
:return:
"""
business = await gen_business_dao.get(db, pk)
if not business:
raise errors.NotFoundError(msg='业务不存在')
gen_path = business.gen_path or '.../backend/app/'
target_files = gen_template.get_code_gen_paths(business)
gen_path = business.gen_path or '<project_root>/backend/app'
paths = []
return [os.path.join(gen_path, *target_file.split('/')) for target_file in target_files]
init_files = gen_template.get_init_files(business)
paths.extend(os.path.join(gen_path, *filepath.split('/')) for filepath in init_files.keys())
template_mapping = gen_template.get_template_path_mapping(business)
paths.extend(os.path.join(gen_path, *filepath.split('/')) for filepath in template_mapping.values())
return paths
async def generate(self, *, db: AsyncSession, pk: int) -> str:
"""
@@ -176,51 +177,26 @@ class GenService:
:param pk: 业务 ID
:return:
"""
business = await gen_business_dao.get(db, pk)
if not business:
raise errors.NotFoundError(msg='业务不存在')
tpl_code_map = await self.render_tpl_code(db=db, business=business)
gen_path = business.gen_path or BASE_PATH / 'app'
gen_path = business.gen_path or str(BASE_PATH / 'app')
for tpl_path, code in tpl_code_map.items():
code_filepath = os.path.join(
gen_path,
*gen_template.get_code_gen_path(tpl_path, business).split('/'),
)
init_files = gen_template.get_init_files(business)
for init_filepath, init_content in init_files.items():
full_path = os.path.join(gen_path, *init_filepath.split('/'))
init_folder = anyio.Path(full_path).parent
await init_folder.mkdir(parents=True, exist_ok=True)
async with await open_file(full_path, 'w', encoding='utf-8') as f:
await f.write(init_content)
# 写入 init 文件
code_folder = anyio.Path(code_filepath).parent
rendered_codes = await self._render_tpl_code(db=db, business=business)
for code_filepath, code in rendered_codes.items():
full_path = os.path.join(gen_path, *code_filepath.split('/'))
code_folder = anyio.Path(full_path).parent
await code_folder.mkdir(parents=True, exist_ok=True)
init_filepath = code_folder.joinpath('__init__.py')
if not await init_filepath.exists():
async with await open_file(init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# api __init__.py
if 'api' in code_filepath:
api_init_filepath = code_folder.parent.joinpath('__init__.py')
async with await open_file(api_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# app __init__.py
if 'service' in code_filepath:
app_init_filepath = code_folder.parent.joinpath('__init__.py')
async with await open_file(app_init_filepath, 'w', encoding='utf-8') as f:
await f.write(gen_template.init_content)
# model init 文件补充
if code_folder.name == 'model':
async with await open_file(init_filepath, 'a', encoding='utf-8') as f:
await f.write(
f'from backend.app.{business.app_name}.model.{business.table_name} '
f'import {to_pascal(business.table_name)}\n',
)
# 写入代码文件
async with await open_file(code_filepath, 'w', encoding='utf-8') as f:
async with await open_file(full_path, 'w', encoding='utf-8') as f:
await f.write(code)
return gen_path
@@ -233,41 +209,18 @@ class GenService:
:param pk: 业务 ID
:return:
"""
business = await gen_business_dao.get(db, pk)
if not business:
raise errors.NotFoundError(msg='业务不存在')
bio = io.BytesIO()
with zipfile.ZipFile(bio, 'w') as zf:
tpl_code_map = await self.render_tpl_code(db=db, business=business)
for tpl_path, code in tpl_code_map.items():
code_filepath = gen_template.get_code_gen_path(tpl_path, business)
init_files = gen_template.get_init_files(business)
for init_filepath, init_content in init_files.items():
zf.writestr(init_filepath, init_content)
# 写入 init 文件
code_dir = os.path.dirname(code_filepath)
init_filepath = os.path.join(code_dir, '__init__.py')
if 'model' not in code_filepath.split('/'):
zf.writestr(init_filepath, gen_template.init_content)
else:
zf.writestr(
init_filepath,
f'{gen_template.init_content}'
f'from backend.app.{business.app_name}.model.{business.table_name} '
f'import {to_pascal(business.table_name)}\n',
)
# api __init__.py
if 'api' in code_dir:
api_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
zf.writestr(api_init_filepath, gen_template.init_content)
# app __init__.py
if 'service' in code_dir:
app_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
zf.writestr(app_init_filepath, gen_template.init_content)
# 写入代码文件
rendered_codes = await self._render_tpl_code(db=db, business=business)
for code_filepath, code in rendered_codes.items():
zf.writestr(code_filepath, code)
bio.seek(0)

View File

@@ -1,4 +1,4 @@
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
values (1, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
insert into gen_column (id, name, comment, type, pd_type, `default`, sort, `length`, is_pk, is_nullable, gen_business_id)

View File

@@ -1,4 +1,4 @@
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
values (2112248797819043840, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
insert into gen_column (id, name, comment, type, pd_type, `default`, sort, `length`, is_pk, is_nullable, gen_business_id)

View File

@@ -1,4 +1,4 @@
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
values (1, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
insert into gen_column (id, name, comment, type, pd_type, "default", sort, "length", is_pk, is_nullable, gen_business_id)

View File

@@ -1,4 +1,4 @@
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
values (2112248797819043840, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
insert into gen_column (id, name, comment, type, pd_type, "default", sort, "length", is_pk, is_nullable, gen_business_id)

View File

@@ -8,7 +8,7 @@ from backend.app.{{ app_name }}.model import {{ class_name }}
from backend.app.{{ app_name }}.schema.{{ table_name }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
class CRUD{{ class_name }}(CRUDPlus[{{ schema_name }}]):
class CRUD{{ class_name }}(CRUDPlus[{{ class_name }}]):
async def get(self, db: AsyncSession, pk: int) -> {{ class_name }} | None:
"""
获取{{ doc_comment }}

View File

@@ -5,29 +5,48 @@
'HSTORE', 'INET', 'INT4MULTIRANGE', 'INT4RANGE', 'INT8MULTIRANGE', 'INT8RANGE', 'INTERVAL', 'JSONB', 'JSONPATH',
'MACADDR', 'MACADDR8', 'MONEY', 'NUMMULTIRANGE', 'NUMRANGE', 'OID', 'REGCLASS', 'REGCONFIG', 'TSMULTIRANGE', 'TSQUERY',
'TSRANGE', 'TSTZMULTIRANGE', 'TSTZRANGE', 'TSVECTOR'] %}
{% if default_datetime_column %}
from datetime import datetime
{% set FACTORY_TYPES = ['dict', 'list', 'list[str]', 'list[int]'] %}
{% set pd_types = models|map(attribute='pd_type')|list %}
{% set NEED_MYSQL_DIALECT = database_type == 'mysql' and model_types|select('in', MYSQL_TYPES)|list|length > 0 %}
{% set NEED_PGSQL_DIALECT = database_type == 'postgresql' and model_types|select('in', POSTGRESQL_TYPES)|list|length > 0 %}
{% set NEED_DATE = 'date' in pd_types %}
{% set NEED_DATETIME = datetime_mixin or 'datetime' in pd_types %}
{% set NEED_UNIVERSAL_TEXT = 'TEXT' in model_types or 'Text' in model_types or 'LONGTEXT' in model_types %}
{% set NEED_TIMEZONE = 'TIMESTAMP' in model_types or 'DateTime' in model_types or 'TIMESTAMP WITHOUT TIME ZONE' in model_types or 'TIMESTAMP WITH TIME ZONE' in model_types %}
{% if NEED_DATETIME or NEED_DATE %}
from datetime import {% if NEED_DATETIME %}datetime{% endif %}{% if NEED_DATE %}{% if NEED_DATETIME %}, {% endif %}date{% endif %}
{% endif %}
{% if model_types|select('in', DECIMAL_TYPES)|first %}
from decimal import Decimal
{% endif %}
{% if 'Uuid' in model_types or 'UUID' in model_types %}
from uuid import UUID
{% endif %}
{% endif %}
import sqlalchemy as sa
{% if database_type == 'mysql' -%}
{% if NEED_MYSQL_DIALECT %}
from sqlalchemy.dialects import mysql
{% else -%}
{% endif %}
{% if NEED_PGSQL_DIALECT %}
from sqlalchemy.dialects import postgresql
{% endif -%}
{% endif %}
from sqlalchemy.orm import Mapped, mapped_column
from backend.common.model import {% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}, id_key
from backend.common.model import {% if datetime_mixin %}Base{% else %}DataClassBase{% endif %}, id_key
{%- if NEED_UNIVERSAL_TEXT or NEED_TIMEZONE %}, {% endif %}
{%- if NEED_UNIVERSAL_TEXT %}UniversalText{% endif %}
{%- if NEED_UNIVERSAL_TEXT and NEED_TIMEZONE %}, {% endif %}
{%- if NEED_TIMEZONE %}TimeZone{% endif %}
{% if NEED_TIMEZONE %}
from backend.utils.timezone import timezone
{% endif %}
class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}):
class {{ class_name }}({% if datetime_mixin %}Base{% else %}DataClassBase{% endif %}):
"""{{ table_comment }}"""
__tablename__ = '{{ table_name }}'
@@ -38,41 +57,36 @@ class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBa
{%- if model.is_nullable %} Mapped[{{ model.pd_type }} | None]
{%- else %} Mapped[{{ model.pd_type }}]
{%- endif %} = mapped_column(
{%- if model.type in ['NVARCHAR', 'String', 'Unicode', 'VARCHAR'] -%}
{%- if model.type in ['TEXT', 'Text', 'LONGTEXT'] -%}
UniversalText
{%- elif model.type in ['TIMESTAMP', 'DateTime', 'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'] -%}
TimeZone
{%- elif model.type in ['NVARCHAR', 'String', 'Unicode', 'VARCHAR', 'CHARACTER VARYING'] -%}
sa.String({{ model.length }})
{%- elif database_type == 'mysql' and model.type in MYSQL_TYPES -%}
mysql.{{ model.type }}()
mysql.{{ model.type|sqlalchemy_type }}()
{%- elif database_type == 'postgresql' and model.type in POSTGRESQL_TYPES -%}
{%- else -%}
sa.{{ model.type }}()
{%- endif -%}, default=
{%- if model.is_nullable and model.default == None -%}
None
{%- else -%}
{%- if model.default != None -%}
'{{ model.default }}'
{%- if model.type == 'ARRAY' -%}
postgresql.ARRAY(sa.String)
{%- else -%}
{%- if model.pd_type == 'str' -%}
''
{%- elif model.pd_type == 'int' -%}
0
{%- elif model.pd_type == 'bytes' -%}
b''
{%- elif model.pd_type == 'bool' -%}
True
{%- elif model.pd_type == 'float' -%}
0.0
{%- elif model.pd_type == 'dict' -%}
{}
{%- elif model.pd_type == 'date' or model.pd_type == 'datetime' -%}
timezone.now()
{%- elif model.pd_type == 'list[str]' -%}
()
{%- else -%}
''
{%- endif -%}
postgresql.{{ model.type|sqlalchemy_type }}()
{%- endif -%}
{%- endif -%}{% if model.sort != 0 %}, sort_order={{ model.sort }}{% endif %}, comment=
{%- else -%}
sa.{{ model.type|sqlalchemy_type }}()
{%- endif -%},
{%- if model.is_nullable and model.default == None %} default=None
{%- elif model.default != None %} default='{{ model.default }}'
{%- elif model.pd_type in FACTORY_TYPES %} default_factory={{ model.pd_type.split('[')[0] }}
{%- elif model.pd_type == 'str' %} default=''
{%- elif model.pd_type == 'int' %} default=0
{%- elif model.pd_type == 'bytes' %} default=b''
{%- elif model.pd_type == 'bool' %} default=True
{%- elif model.pd_type == 'float' %} default=0.0
{%- elif model.pd_type == 'date' %} default_factory=date.today
{%- elif model.pd_type == 'datetime' and model.type in ['TIMESTAMP', 'DateTime', 'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'] %} default_factory=timezone.now
{%- elif model.pd_type == 'datetime' %} default_factory=datetime.now
{%- else %} default=None
{%- endif -%}, comment=
{%- if model.comment != None -%}
'{{ model.comment }}')
{% else -%}

View File

@@ -1,5 +1,20 @@
{% if default_datetime_column %}
from datetime import datetime
{% set pd_types = models|map(attribute='pd_type')|list %}
{% set NEED_DATETIME = datetime_mixin or 'datetime' in pd_types %}
{% set NEED_DATE = 'date' in pd_types %}
{% set NEED_TIME = 'time' in pd_types %}
{% set NEED_TIMEDELTA = 'timedelta' in pd_types %}
{% set NEED_DECIMAL = 'Decimal' in pd_types %}
{% set NEED_UUID = 'UUID' in pd_types or 'str | UUID' in pd_types %}
{% if NEED_DATETIME or NEED_DATE or NEED_TIME or NEED_TIMEDELTA %}
from datetime import {% if NEED_DATETIME %}datetime{% endif %}{% if NEED_DATE %}{% if NEED_DATETIME %}, {% endif %}date{% endif %}{% if NEED_TIME %}{% if NEED_DATETIME or NEED_DATE %}, {% endif %}time{% endif %}{% if NEED_TIMEDELTA %}{% if NEED_DATETIME or NEED_DATE or NEED_TIME %}, {% endif %}timedelta{% endif %}
{% endif %}
{% if NEED_DECIMAL %}
from decimal import Decimal
{% endif %}
{% if NEED_UUID %}
from uuid import UUID
{% endif %}
from pydantic import ConfigDict, Field
@@ -35,7 +50,7 @@ class Get{{ schema_name }}Detail({{ schema_name }}SchemaBase):
model_config = ConfigDict(from_attributes=True)
id: int
{% if default_datetime_column %}
{% if datetime_mixin %}
created_time: datetime
updated_time: datetime | None = None
{% endif %}

View File

@@ -0,0 +1,11 @@
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values ('{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
set @parent_menu_id = LAST_INSERT_ID();
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values
('新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, @parent_menu_id, now(), null),
('修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, @parent_menu_id, now(), null),
('删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, @parent_menu_id, now(), null),
('查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, @parent_menu_id, now(), null);

View File

@@ -0,0 +1,9 @@
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values ({{ parent_menu_id }}, '{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values
({{ button_ids[0] }}, '新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[1] }}, '修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[2] }}, '删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[3] }}, '查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null);

View File

@@ -0,0 +1,17 @@
do $$
declare
parent_menu_id bigint;
begin
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values ('{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null)
returning id into parent_menu_id;
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values
('新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, parent_menu_id, now(), null),
('修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, parent_menu_id, now(), null),
('删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, parent_menu_id, now(), null),
('查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, parent_menu_id, now(), null);
end $$;
select setval(pg_get_serial_sequence('sys_menu', 'id'),coalesce(max(id), 0) + 1, true) from sys_menu;

View File

@@ -0,0 +1,9 @@
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values ({{ parent_menu_id }}, '{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
values
({{ button_ids[0] }}, '新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[1] }}, '修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[2] }}, '删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
({{ button_ids[3] }}, '查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null);

View File

@@ -0,0 +1,43 @@
import asyncio
import tempfile
import uuid
import anyio
from anyio import open_file
async def format_python_code(code: str) -> str:
"""
使用 ruff 格式化 Python 代码
:param code: 原始代码
:return:
"""
temp_dir = anyio.Path(tempfile.gettempdir())
temp_file = temp_dir / f'fba_codegen_{uuid.uuid4().hex}.py'
try:
async with await open_file(temp_file, 'w', encoding='utf-8') as f:
await f.write(code)
process = await asyncio.create_subprocess_exec(
'ruff',
'format',
str(temp_file),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await process.communicate()
if process.returncode == 0:
async with await open_file(temp_file, encoding='utf-8') as f:
formatted_code = await f.read()
else:
formatted_code = code
except (FileNotFoundError, OSError):
return code
finally:
await temp_file.unlink(missing_ok=True)
return formatted_code

View File

@@ -1,10 +1,14 @@
from collections.abc import Sequence
from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
from pydantic.alias_generators import to_pascal
from backend.core.conf import settings
from backend.plugin.code_generator.model import GenBusiness, GenColumn
from backend.plugin.code_generator.path_conf import JINJA2_TEMPLATE_DIR
from backend.plugin.code_generator.utils.type_conversion import sql_type_to_sqlalchemy_name
from backend.utils.snowflake import snowflake
from backend.utils.timezone import timezone
class GenTemplate:
@@ -18,60 +22,64 @@ class GenTemplate:
keep_trailing_newline=True,
enable_async=True,
)
self.env.filters['sqlalchemy_type'] = sql_type_to_sqlalchemy_name
self.init_content = ''
def get_template(self, jinja_file: str) -> Template:
"""
获取模板文件
获取 Jinja2 模板对象
:param jinja_file: Jinja2 模板文件
:return:
:param jinja_file: Jinja2 模板文件路径
:return: Template 对象
"""
return self.env.get_template(jinja_file)
@staticmethod
def get_template_files() -> list[str]:
def get_template_path_mapping(business: GenBusiness) -> dict[str, str]:
"""
获取模板文件列表
:return:
"""
return [
'python/api.jinja',
'python/crud.jinja',
'python/model.jinja',
'python/schema.jinja',
'python/service.jinja',
]
@staticmethod
def get_code_gen_paths(business: GenBusiness) -> list[str]:
"""
获取代码生成路径列表
获取模板文件到生成文件的路径映射
:param business: 代码生成业务对象
:return:
:return: {模板路径: 生成文件路径}
"""
app_name = business.app_name
filename = business.filename
return [
f'{app_name}/api/{business.api_version}/{filename}.py',
f'{app_name}/crud/crud_{filename}.py',
f'{app_name}/model/{filename}.py',
f'{app_name}/schema/{filename}.py',
f'{app_name}/service/{filename}_service.py',
]
api_version = business.api_version
pk_suffix = '_snowflake' if settings.DATABASE_PK_MODE == 'snowflake' else ''
def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
"""
获取代码生成路径
return {
'python/api.jinja': f'{app_name}/api/{api_version}/{filename}.py',
'python/crud.jinja': f'{app_name}/crud/crud_{filename}.py',
'python/model.jinja': f'{app_name}/model/{filename}.py',
'python/schema.jinja': f'{app_name}/schema/{filename}.py',
'python/service.jinja': f'{app_name}/service/{filename}_service.py',
f'sql/mysql/init{pk_suffix}.jinja': f'{app_name}/sql/mysql/init{pk_suffix}.sql',
f'sql/postgresql/init{pk_suffix}.jinja': f'{app_name}/sql/postgresql/init{pk_suffix}.sql',
}
:param tpl_path: 模板文件路径
:param business: 代码生成业务对象
:return:
def get_init_files(self, business: GenBusiness) -> dict[str, str]:
"""
code_gen_path_mapping = dict(zip(self.get_template_files(), self.get_code_gen_paths(business)))
return code_gen_path_mapping[tpl_path]
获取需要生成的 __init__.py 文件及其内容
:param business: 业务对象
:return: {相对路径: 文件内容}
"""
app_name = business.app_name
table_name = business.table_name
class_name = business.class_name or to_pascal(table_name)
return {
f'{app_name}/__init__.py': self.init_content,
f'{app_name}/api/__init__.py': self.init_content,
f'{app_name}/api/{business.api_version}/__init__.py': self.init_content,
f'{app_name}/crud/__init__.py': self.init_content,
f'{app_name}/model/__init__.py': (
f'{self.init_content}'
f'from backend.app.{app_name}.model.{table_name} import {class_name} as {class_name}\n'
),
f'{app_name}/schema/__init__.py': self.init_content,
f'{app_name}/service/__init__.py': self.init_content,
}
@staticmethod
def get_vars(business: GenBusiness, models: Sequence[GenColumn]) -> dict[str, str | Sequence[GenColumn]]:
@@ -82,19 +90,26 @@ class GenTemplate:
:param models: 代码生成模型对象列表
:return:
"""
return {
vars_dict = {
'app_name': business.app_name,
'table_name': business.table_name,
'doc_comment': business.doc_comment,
'table_comment': business.table_comment,
'class_name': business.class_name,
'schema_name': business.schema_name,
'default_datetime_column': business.default_datetime_column,
'permission': str(business.table_name.replace('_', ':')),
'datetime_mixin': business.datetime_mixin,
'permission': business.table_name.replace('_', ':'),
'database_type': settings.DATABASE_TYPE,
'models': models,
'model_types': [model.type for model in models],
'now': timezone.now(),
}
if settings.DATABASE_PK_MODE == 'snowflake':
vars_dict['parent_menu_id'] = snowflake.generate()
vars_dict['button_ids'] = [snowflake.generate() for _ in range(4)]
return vars_dict
gen_template: GenTemplate = GenTemplate()

View File

@@ -17,7 +17,8 @@ def sql_type_to_sqlalchemy(typing: str) -> str:
if typing in GenMySQLColumnType.get_member_keys():
return typing
else:
if typing in GenPostgreSQLColumnType.get_member_keys():
normalized_type = typing.replace(' ', '_')
if normalized_type in GenPostgreSQLColumnType.get_member_keys():
return typing
return 'String'
@@ -33,8 +34,31 @@ def sql_type_to_pydantic(typing: str) -> str:
try:
if DataBaseType.mysql == settings.DATABASE_TYPE:
return GenMySQLColumnType[typing].value
if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名
return 'str'
return GenPostgreSQLColumnType[typing].value
normalized_type = typing.replace(' ', '_')
return GenPostgreSQLColumnType[normalized_type].value
except KeyError:
return 'str'
@lru_cache(maxsize=128)
def sql_type_to_sqlalchemy_name(typing: str) -> str:
"""
将 SQL 类型转换为有效的 SQLAlchemy 类型名称(用于代码生成)
:param typing: SQL 类型字符串
:return:
"""
pg_type_mapping = {
'CHARACTER VARYING': 'String',
'CHARACTER': 'CHAR',
'TIMESTAMP WITHOUT TIME ZONE': 'TIMESTAMP',
'TIMESTAMP WITH TIME ZONE': 'TIMESTAMP',
'TIME WITHOUT TIME ZONE': 'TIME',
'TIME WITH TIME ZONE': 'TIME',
'DOUBLE PRECISION': 'Double',
}
if DataBaseType.postgresql == settings.DATABASE_TYPE and typing in pg_type_mapping:
return pg_type_mapping[typing]
return typing

View File

@@ -76,3 +76,14 @@ def is_has_special_char(value: str) -> re.Match[str]:
"""
special_char_pattern = r'[!@#$%^&*()_+\-=\[\]{};:\'",.<>?/\\|`~]'
return search_string(special_char_pattern, value)
def is_english_identifier(value: str) -> re.Match[str]:
"""
检查英文标识符
:param value: 待检查的值
:return:
"""
identifier_pattern = r'^[a-zA-Z][a-zA-Z_]*$'
return match_string(identifier_pattern, value)