mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2026-03-13 09:31:31 +08:00
Optimize code generation data processing (#1020)
* Optimize code generation data processing * Add SQL script generation * Update code gen table scripts * Update types * Fix model jinja and type conversion
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import Row, RowMapping, text
|
||||
from sqlalchemy import RowMapping, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.common.enums import DataBaseType
|
||||
@@ -21,57 +21,75 @@ class CRUDGen:
|
||||
"""
|
||||
if DataBaseType.mysql == settings.DATABASE_TYPE:
|
||||
sql = """
|
||||
SELECT table_name AS table_name, table_comment AS table_comment
|
||||
FROM information_schema.tables
|
||||
WHERE table_name NOT LIKE 'sys_gen_%'
|
||||
AND table_schema = :table_schema;
|
||||
SELECT TABLE_NAME AS TABLE_NAME,
|
||||
table_comment AS table_comment
|
||||
FROM
|
||||
information_schema.TABLES
|
||||
WHERE
|
||||
TABLE_NAME NOT LIKE'sys_gen_%'
|
||||
AND table_schema = :table_schema;
|
||||
"""
|
||||
stmt = text(sql).bindparams(table_schema=table_schema)
|
||||
else:
|
||||
sql = """
|
||||
SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment
|
||||
FROM pg_class c
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = 'r'
|
||||
AND n.nspname = 'public' -- schema 通常是 'public'
|
||||
AND c.relname NOT LIKE 'sys_gen_%';
|
||||
SELECT
|
||||
c.relname AS TABLE_NAME,
|
||||
obj_description (c.OID) AS table_comment
|
||||
FROM
|
||||
pg_class c
|
||||
LEFT JOIN pg_namespace n ON n.OID = c.relnamespace
|
||||
WHERE
|
||||
c.relkind = 'r'
|
||||
AND c.relname NOT LIKE'sys_gen_%'
|
||||
AND n.nspname = :table_schema;
|
||||
"""
|
||||
stmt = text(sql)
|
||||
stmt = text(sql).bindparams(table_schema='public')
|
||||
result = await db.execute(stmt)
|
||||
return result.mappings().all()
|
||||
|
||||
@staticmethod
|
||||
async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]:
|
||||
async def get_table(db: AsyncSession, table_schema: str, table_name: str) -> RowMapping | None:
|
||||
"""
|
||||
获取表信息
|
||||
|
||||
:param db: 数据库会话
|
||||
:param table_schema: 数据库 schema 名称
|
||||
:param table_name: 表名
|
||||
:return:
|
||||
"""
|
||||
if DataBaseType.mysql == settings.DATABASE_TYPE:
|
||||
sql = """
|
||||
SELECT table_name AS table_name, table_comment AS table_comment
|
||||
FROM information_schema.tables
|
||||
WHERE table_name NOT LIKE 'sys_gen_%'
|
||||
AND table_name = :table_name;
|
||||
SELECT TABLE_NAME AS TABLE_NAME,
|
||||
table_comment AS table_comment
|
||||
FROM
|
||||
information_schema.TABLES
|
||||
WHERE
|
||||
TABLE_NAME NOT LIKE'sys_gen_%'
|
||||
AND TABLE_NAME = :table_name
|
||||
AND table_schema = :table_schema;
|
||||
"""
|
||||
stmt = text(sql).bindparams(table_schema=table_schema, table_name=table_name)
|
||||
else:
|
||||
sql = """
|
||||
SELECT c.relname AS table_name, obj_description(c.oid) AS table_comment
|
||||
FROM pg_class c
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = 'r'
|
||||
AND n.nspname = 'public' -- schema 通常是 'public'
|
||||
AND c.relname = :table_name
|
||||
AND c.relname NOT LIKE 'sys_gen_%';
|
||||
SELECT
|
||||
c.relname AS table_name,
|
||||
obj_description (c.OID) AS table_comment
|
||||
FROM
|
||||
pg_class c
|
||||
LEFT JOIN pg_namespace n ON n.OID = c.relnamespace
|
||||
WHERE
|
||||
c.relkind = 'r'
|
||||
AND c.relname NOT LIKE'sys_gen_%'
|
||||
AND c.relname = :table_name
|
||||
AND n.nspname = :table_schema;
|
||||
"""
|
||||
stmt = text(sql).bindparams(table_name=table_name)
|
||||
stmt = text(sql).bindparams(table_schema='public', table_name=table_name)
|
||||
result = await db.execute(stmt)
|
||||
return result.fetchone()
|
||||
row = result.fetchone()
|
||||
return row._mapping if row else None
|
||||
|
||||
@staticmethod
|
||||
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]:
|
||||
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[RowMapping]:
|
||||
"""
|
||||
获取所有列信息
|
||||
|
||||
@@ -82,54 +100,73 @@ class CRUDGen:
|
||||
"""
|
||||
if DataBaseType.mysql == settings.DATABASE_TYPE:
|
||||
sql = """
|
||||
SELECT column_name AS column_name,
|
||||
CASE WHEN column_key = 'PRI' THEN 1 ELSE 0 END AS is_pk,
|
||||
CASE WHEN is_nullable = 'NO' OR column_key = 'PRI' THEN 0 ELSE 1 END AS is_nullable,
|
||||
ordinal_position AS sort, column_comment AS column_comment,
|
||||
column_type AS column_type FROM information_schema.columns
|
||||
WHERE table_schema = :table_schema
|
||||
AND table_name = :table_name
|
||||
AND column_name <> 'id'
|
||||
AND column_name <> 'created_time'
|
||||
AND column_name <> 'updated_time'
|
||||
ORDER BY sort;
|
||||
SELECT COLUMN_NAME AS COLUMN_NAME,
|
||||
CASE
|
||||
WHEN column_key = 'PRI' THEN
|
||||
1
|
||||
ELSE
|
||||
0
|
||||
END AS is_pk,
|
||||
CASE
|
||||
WHEN is_nullable = 'NO'
|
||||
OR column_key = 'PRI' THEN
|
||||
0
|
||||
ELSE
|
||||
1
|
||||
END AS is_nullable,
|
||||
ordinal_position AS sort,
|
||||
column_comment AS column_comment,
|
||||
column_type AS column_type
|
||||
FROM
|
||||
information_schema.COLUMNS
|
||||
WHERE
|
||||
COLUMN_NAME <> 'id'
|
||||
AND COLUMN_NAME <> 'created_time'
|
||||
AND COLUMN_NAME <> 'updated_time'
|
||||
AND TABLE_NAME = :table_name
|
||||
AND table_schema = :table_schema
|
||||
ORDER BY
|
||||
sort;
|
||||
"""
|
||||
stmt = text(sql).bindparams(table_schema=table_schema, table_name=table_name)
|
||||
else:
|
||||
sql = """
|
||||
SELECT a.attname AS column_name,
|
||||
CASE WHEN EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint c
|
||||
WHERE c.conrelid = t.oid
|
||||
AND c.contype = 'p'
|
||||
AND a.attnum = ANY(c.conkey)
|
||||
) THEN 1 ELSE 0 END AS is_pk,
|
||||
CASE WHEN a.attnotnull OR EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_constraint c
|
||||
WHERE c.conrelid = t.oid
|
||||
AND c.contype = 'p'
|
||||
AND a.attnum = ANY(c.conkey)
|
||||
) THEN 0 ELSE 1 END AS is_nullable,
|
||||
a.attnum AS sort,
|
||||
col_description(t.oid, a.attnum) AS column_comment,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) AS column_type
|
||||
FROM pg_attribute a
|
||||
JOIN pg_class t ON a.attrelid = t.oid
|
||||
JOIN pg_namespace n ON n.oid = t.relnamespace
|
||||
WHERE n.nspname = 'public' -- 根据你的实际情况修改 schema 名称,通常是 'public'
|
||||
AND t.relname = :table_name
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
AND a.attname <> 'id'
|
||||
AND a.attname <> 'created_time'
|
||||
AND a.attname <> 'updated_time'
|
||||
ORDER BY sort;
|
||||
"""
|
||||
stmt = text(sql).bindparams(table_name=table_name)
|
||||
SELECT
|
||||
a.attname AS COLUMN_NAME,
|
||||
CASE
|
||||
WHEN EXISTS (SELECT 1 FROM pg_constraint c WHERE c.conrelid = t.OID AND c.contype = 'p' AND a.attnum = ANY (c.conkey)) THEN
|
||||
1
|
||||
ELSE
|
||||
0
|
||||
END AS is_pk,
|
||||
CASE
|
||||
WHEN a.attnotnull
|
||||
OR EXISTS (SELECT 1 FROM pg_constraint c WHERE c.conrelid = t.OID AND c.contype = 'p' AND a.attnum = ANY (c.conkey)) THEN
|
||||
0
|
||||
ELSE
|
||||
1
|
||||
END AS is_nullable,
|
||||
a.attnum AS sort,
|
||||
col_description (t.OID, a.attnum) AS column_comment,
|
||||
pg_catalog.format_type (a.atttypid, a.atttypmod) AS column_type
|
||||
FROM
|
||||
pg_attribute a
|
||||
JOIN pg_class t ON a.attrelid = t.
|
||||
OID JOIN pg_namespace n ON n.OID = t.relnamespace
|
||||
WHERE
|
||||
a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
AND a.attname <> 'id'
|
||||
AND a.attname <> 'created_time'
|
||||
AND a.attname <> 'updated_time'
|
||||
AND t.relname = :table_name
|
||||
AND n.nspname = :table_schema
|
||||
ORDER BY
|
||||
sort;
|
||||
""" # noqa: E501
|
||||
stmt = text(sql).bindparams(table_schema='public', table_name=table_name)
|
||||
result = await db.execute(stmt)
|
||||
return result.fetchall()
|
||||
return result.mappings().all()
|
||||
|
||||
|
||||
gen_dao: CRUDGen = CRUDGen()
|
||||
|
||||
@@ -19,34 +19,34 @@ class GenMySQLColumnType(StrEnum):
|
||||
DateTime = 'datetime' # DATETIME
|
||||
DECIMAL = 'Decimal'
|
||||
DOUBLE = 'float'
|
||||
Double = 'float' # DOUBLE
|
||||
DOUBLE_PRECISION = 'float'
|
||||
Double = 'float' # DOUBLE
|
||||
Enum = 'Enum' # Enum()
|
||||
FLOAT = 'float'
|
||||
Float = 'float' # FLOAT
|
||||
INT = 'int' # INTEGER
|
||||
INTEGER = 'int'
|
||||
Integer = 'int' # INTEGER
|
||||
Interval = 'timedelta' # DATETIME
|
||||
Interval = 'timedelta' # INTERVAL
|
||||
JSON = 'dict'
|
||||
LargeBinary = 'bytes' # BLOB
|
||||
NCHAR = 'str'
|
||||
NUMERIC = 'Decimal'
|
||||
Numeric = 'Decimal' # NUMERIC
|
||||
NVARCHAR = 'str' # String
|
||||
Numeric = 'Decimal' # NUMERIC
|
||||
PickleType = 'bytes' # BLOB
|
||||
REAL = 'float'
|
||||
SMALLINT = 'int'
|
||||
SmallInteger = 'int' # SMALLINT
|
||||
String = 'str' # String
|
||||
TEXT = 'str'
|
||||
Text = 'str' # TEXT
|
||||
TIME = 'time'
|
||||
Time = 'time' # TIME
|
||||
TIMESTAMP = 'datetime'
|
||||
Text = 'str' # TEXT
|
||||
Time = 'time' # TIME
|
||||
UUID = 'str | UUID'
|
||||
Unicode = 'str' # String
|
||||
UnicodeText = 'str' # TEXT
|
||||
UUID = 'str | UUID'
|
||||
Uuid = 'str' # CHAR(32)
|
||||
VARBINARY = 'bytes'
|
||||
VARCHAR = 'str' # String
|
||||
@@ -77,6 +77,8 @@ class GenPostgreSQLColumnType(StrEnum):
|
||||
BOOLEAN = 'bool'
|
||||
Boolean = 'bool' # BOOLEAN
|
||||
CHAR = 'str'
|
||||
CHARACTER = 'str' # CHAR
|
||||
CHARACTER_VARYING = 'str' # CHARACTER VARYING
|
||||
CLOB = 'str'
|
||||
DATE = 'date'
|
||||
Date = 'date' # DATE
|
||||
@@ -84,8 +86,8 @@ class GenPostgreSQLColumnType(StrEnum):
|
||||
DateTime = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
|
||||
DECIMAL = 'Decimal'
|
||||
DOUBLE = 'float'
|
||||
Double = 'float' # DOUBLE PRECISION
|
||||
DOUBLE_PRECISION = 'float' # DOUBLE PRECISION
|
||||
Double = 'float' # DOUBLE PRECISION
|
||||
Enum = 'Enum' # Enum(name='enum')
|
||||
FLOAT = 'float'
|
||||
Float = 'float' # FLOAT
|
||||
@@ -97,21 +99,25 @@ class GenPostgreSQLColumnType(StrEnum):
|
||||
LargeBinary = 'bytes' # BYTEA
|
||||
NCHAR = 'str'
|
||||
NUMERIC = 'Decimal'
|
||||
Numeric = 'Decimal' # NUMERIC
|
||||
NVARCHAR = 'str' # String
|
||||
Numeric = 'Decimal' # NUMERIC
|
||||
PickleType = 'bytes' # BYTEA
|
||||
REAL = 'float'
|
||||
SMALLINT = 'int'
|
||||
SmallInteger = 'int' # SMALLINT
|
||||
String = 'str' # String
|
||||
TEXT = 'str'
|
||||
Text = 'str' # TEXT
|
||||
TIME = 'time' # TIME WITHOUT TIME ZONE
|
||||
Time = 'time' # TIME WITHOUT TIME ZONE
|
||||
TIME_WITHOUT_TIME_ZONE = 'time' # TIME WITHOUT TIME ZONE
|
||||
TIME_WITH_TIME_ZONE = 'time' # TIME WITH TIME ZONE
|
||||
TIMESTAMP = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
|
||||
TIMESTAMP_WITHOUT_TIME_ZONE = 'datetime' # TIMESTAMP WITHOUT TIME ZONE
|
||||
TIMESTAMP_WITH_TIME_ZONE = 'datetime' # TIMESTAMP WITH TIME ZONE
|
||||
Text = 'str' # TEXT
|
||||
Time = 'time' # TIME WITHOUT TIME ZONE
|
||||
UUID = 'str | UUID'
|
||||
Unicode = 'str' # String
|
||||
UnicodeText = 'str' # TEXT
|
||||
UUID = 'str | UUID'
|
||||
Uuid = 'str'
|
||||
VARBINARY = 'bytes'
|
||||
VARCHAR = 'str' # String
|
||||
|
||||
@@ -11,19 +11,14 @@ class GenBusiness(Base):
|
||||
__tablename__ = 'gen_business'
|
||||
|
||||
id: Mapped[id_key] = mapped_column(init=False)
|
||||
app_name: Mapped[str] = mapped_column(sa.String(64), comment='应用名称(英文)')
|
||||
table_name: Mapped[str] = mapped_column(sa.String(256), unique=True, comment='表名称(英文)')
|
||||
doc_comment: Mapped[str] = mapped_column(sa.String(256), comment='文档注释(用于函数/参数文档)')
|
||||
app_name: Mapped[str] = mapped_column(sa.String(64), comment='应用名称')
|
||||
table_name: Mapped[str] = mapped_column(sa.String(256), unique=True, comment='表名称')
|
||||
doc_comment: Mapped[str] = mapped_column(sa.String(256), comment='文档注释')
|
||||
table_comment: Mapped[str | None] = mapped_column(sa.String(256), default=None, comment='表描述')
|
||||
# relate_model_fk: Mapped[int | None] = mapped_column(default=None, comment='关联表外键')
|
||||
class_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础类名(默认为英文表名称)')
|
||||
schema_name: Mapped[str | None] = mapped_column(
|
||||
sa.String(64), default=None, comment='Schema 名称 (默认为英文表名称)'
|
||||
)
|
||||
filename: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础文件名(默认为英文表名称)')
|
||||
default_datetime_column: Mapped[bool] = mapped_column(default=True, comment='是否存在默认时间列')
|
||||
api_version: Mapped[str] = mapped_column(sa.String(32), default='v1', comment='代码生成 api 版本,默认为 v1')
|
||||
gen_path: Mapped[str | None] = mapped_column(
|
||||
sa.String(256), default=None, comment='代码生成路径(默认为 app 根路径)'
|
||||
)
|
||||
class_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础类名')
|
||||
schema_name: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='Schema 名称')
|
||||
filename: Mapped[str | None] = mapped_column(sa.String(64), default=None, comment='基础文件名')
|
||||
datetime_mixin: Mapped[bool] = mapped_column(default=True, comment='是否包含时间 Mixin 列')
|
||||
api_version: Mapped[str] = mapped_column(sa.String(32), default='v1', comment='API 版本')
|
||||
gen_path: Mapped[str | None] = mapped_column(sa.String(256), default=None, comment='生成路径')
|
||||
remark: Mapped[str | None] = mapped_column(UniversalText, default=None, comment='备注')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[plugin]
|
||||
summary = '代码生成'
|
||||
version = '0.0.7'
|
||||
version = '0.1.0'
|
||||
description = '生成通用业务代码'
|
||||
author = 'wu-clan'
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import ConfigDict, Field
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
|
||||
from backend.common.exception import errors
|
||||
from backend.common.schema import SchemaBase
|
||||
from backend.utils.pattern_validate import is_english_identifier
|
||||
|
||||
|
||||
class GenBusinessSchemaBase(SchemaBase):
|
||||
@@ -15,11 +17,19 @@ class GenBusinessSchemaBase(SchemaBase):
|
||||
class_name: str | None = Field(None, description='用于 python 代码基础类名')
|
||||
schema_name: str | None = Field(None, description='用于 python Schema 代码基础类名')
|
||||
filename: str | None = Field(None, description='用于 python 代码基础文件名')
|
||||
default_datetime_column: bool = Field(True, description='是否存在默认时间列')
|
||||
api_version: str = Field('v1', description='代码生成 api 版本')
|
||||
gen_path: str | None = Field(None, description='代码生成路径')
|
||||
datetime_mixin: bool = Field(True, description='是否包含时间 Mixin 列')
|
||||
api_version: str = Field('v1', description='API 版本')
|
||||
gen_path: str | None = Field(None, description='生成路径(默认在 backend/app 目录下)')
|
||||
remark: str | None = Field(None, description='备注')
|
||||
|
||||
@field_validator('app_name', 'table_name')
|
||||
@classmethod
|
||||
def validate_english_only(cls, v: str) -> str:
|
||||
"""验证英文字段"""
|
||||
if not is_english_identifier(v):
|
||||
raise errors.RequestError(msg='必须以英文字母开头且只能包含英文字母和下划线')
|
||||
return v
|
||||
|
||||
|
||||
class CreateGenBusinessParam(GenBusinessSchemaBase):
|
||||
"""创建代码生成业务参数"""
|
||||
|
||||
@@ -19,8 +19,8 @@ class GenColumnSchemaBase(SchemaBase):
|
||||
|
||||
@field_validator('type')
|
||||
@classmethod
|
||||
def type_update(cls, v: str) -> str:
|
||||
"""更新列类型"""
|
||||
def normalize_type(cls, v: str) -> str:
|
||||
"""规范化类型"""
|
||||
return sql_type_to_sqlalchemy(v)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.plugin.code_generator.schema.business import CreateGenBusinessParam
|
||||
from backend.plugin.code_generator.schema.column import CreateGenColumnParam
|
||||
from backend.plugin.code_generator.schema.gen import ImportParam
|
||||
from backend.plugin.code_generator.service.column_service import gen_column_service
|
||||
from backend.plugin.code_generator.utils.format_code import format_python_code
|
||||
from backend.plugin.code_generator.utils.gen_template import gen_template
|
||||
from backend.plugin.code_generator.utils.type_conversion import sql_type_to_pydantic
|
||||
|
||||
@@ -50,7 +51,7 @@ class GenService:
|
||||
:return:
|
||||
"""
|
||||
|
||||
table_info = await gen_dao.get_table(db, obj.table_name)
|
||||
table_info = await gen_dao.get_table(db, obj.table_schema, obj.table_name)
|
||||
if not table_info:
|
||||
raise errors.NotFoundError(msg='数据库表不存在')
|
||||
|
||||
@@ -58,13 +59,13 @@ class GenService:
|
||||
if business_info:
|
||||
raise errors.ConflictError(msg='已存在相同数据库表业务')
|
||||
|
||||
table_name = table_info[0]
|
||||
table_name = table_info['table_name']
|
||||
new_business = GenBusiness(
|
||||
**CreateGenBusinessParam(
|
||||
app_name=obj.app,
|
||||
table_name=table_name,
|
||||
doc_comment=table_info[1] or table_name.split('_')[-1],
|
||||
table_comment=table_info[1],
|
||||
doc_comment=table_info['table_comment'] or table_name.split('_')[-1],
|
||||
table_comment=table_info['table_comment'],
|
||||
class_name=to_pascal(table_name),
|
||||
schema_name=to_pascal(table_name),
|
||||
filename=table_name,
|
||||
@@ -75,25 +76,27 @@ class GenService:
|
||||
|
||||
column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name)
|
||||
for column in column_info:
|
||||
column_type = column[-1].split('(')[0].upper()
|
||||
column_type = column['column_type'].split('(')[0].upper()
|
||||
pd_type = sql_type_to_pydantic(column_type)
|
||||
await gen_column_dao.create(
|
||||
db,
|
||||
CreateGenColumnParam(
|
||||
name=column[0],
|
||||
comment=column[-2],
|
||||
name=column['column_name'],
|
||||
comment=column['column_comment'],
|
||||
type=column_type,
|
||||
sort=column[-3],
|
||||
length=column[-1].split('(')[1][:-1] if pd_type == 'str' and '(' in column[-1] else 0,
|
||||
is_pk=column[1],
|
||||
is_nullable=column[2],
|
||||
sort=column['sort'],
|
||||
length=column['column_type'].split('(')[1][:-1]
|
||||
if pd_type == 'str' and '(' in column['column_type']
|
||||
else 0,
|
||||
is_pk=column['is_pk'],
|
||||
is_nullable=column['is_nullable'],
|
||||
gen_business_id=new_business.id,
|
||||
),
|
||||
pd_type=pd_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def render_tpl_code(*, db: AsyncSession, business: GenBusiness) -> dict[str, str]:
|
||||
async def _render_tpl_code(*, db: AsyncSession, business: GenBusiness) -> dict[str, str]:
|
||||
"""
|
||||
渲染模板代码
|
||||
|
||||
@@ -106,10 +109,16 @@ class GenService:
|
||||
raise errors.NotFoundError(msg='代码生成模型表为空')
|
||||
|
||||
gen_vars = gen_template.get_vars(business, gen_models)
|
||||
return {
|
||||
tpl_path: await gen_template.get_template(tpl_path).render_async(**gen_vars)
|
||||
for tpl_path in gen_template.get_template_files()
|
||||
}
|
||||
template_mapping = gen_template.get_template_path_mapping(business)
|
||||
|
||||
rendered_codes = {}
|
||||
for template_path, output_path in template_mapping.items():
|
||||
code = await gen_template.get_template(template_path).render_async(**gen_vars)
|
||||
if output_path.endswith('.py'):
|
||||
code = await format_python_code(code)
|
||||
rendered_codes[output_path] = code
|
||||
|
||||
return rendered_codes
|
||||
|
||||
async def preview(self, *, db: AsyncSession, pk: int) -> dict[str, bytes]:
|
||||
"""
|
||||
@@ -119,33 +128,20 @@ class GenService:
|
||||
:param pk: 业务 ID
|
||||
:return:
|
||||
"""
|
||||
|
||||
business = await gen_business_dao.get(db, pk)
|
||||
if not business:
|
||||
raise errors.NotFoundError(msg='业务不存在')
|
||||
|
||||
tpl_code_map = await self.render_tpl_code(db=db, business=business)
|
||||
|
||||
codes = {}
|
||||
for tpl_path, code in tpl_code_map.items():
|
||||
if tpl_path.startswith('python'):
|
||||
rootpath = f'fastapi_best_architecture/backend/app/{business.app_name}'
|
||||
template_name = tpl_path.split('/')[-1]
|
||||
filepath = None
|
||||
match template_name:
|
||||
case 'api.jinja':
|
||||
filepath = f'{rootpath}/api/{business.api_version}/{business.filename}.py'
|
||||
case 'crud.jinja':
|
||||
filepath = f'{rootpath}/crud/crud_{business.filename}.py'
|
||||
case 'model.jinja':
|
||||
filepath = f'{rootpath}/model/{business.filename}.py'
|
||||
case 'schema.jinja':
|
||||
filepath = f'{rootpath}/schema/{business.filename}.py'
|
||||
case 'service.jinja':
|
||||
filepath = f'{rootpath}/service/{business.filename}_service.py'
|
||||
backend_path = 'fastapi_best_architecture/backend/app/'
|
||||
|
||||
if filepath:
|
||||
codes[filepath] = code.encode('utf-8')
|
||||
init_files = gen_template.get_init_files(business)
|
||||
for filepath, content in init_files.items():
|
||||
codes[f'{backend_path}{filepath}'] = content.encode('utf-8')
|
||||
|
||||
rendered_codes = await self._render_tpl_code(db=db, business=business)
|
||||
for filepath, code in rendered_codes.items():
|
||||
codes[f'{backend_path}{filepath}'] = code.encode('utf-8')
|
||||
|
||||
return codes
|
||||
|
||||
@@ -158,15 +154,20 @@ class GenService:
|
||||
:param pk: 业务 ID
|
||||
:return:
|
||||
"""
|
||||
|
||||
business = await gen_business_dao.get(db, pk)
|
||||
if not business:
|
||||
raise errors.NotFoundError(msg='业务不存在')
|
||||
|
||||
gen_path = business.gen_path or '.../backend/app/'
|
||||
target_files = gen_template.get_code_gen_paths(business)
|
||||
gen_path = business.gen_path or '<project_root>/backend/app'
|
||||
paths = []
|
||||
|
||||
return [os.path.join(gen_path, *target_file.split('/')) for target_file in target_files]
|
||||
init_files = gen_template.get_init_files(business)
|
||||
paths.extend(os.path.join(gen_path, *filepath.split('/')) for filepath in init_files.keys())
|
||||
|
||||
template_mapping = gen_template.get_template_path_mapping(business)
|
||||
paths.extend(os.path.join(gen_path, *filepath.split('/')) for filepath in template_mapping.values())
|
||||
|
||||
return paths
|
||||
|
||||
async def generate(self, *, db: AsyncSession, pk: int) -> str:
|
||||
"""
|
||||
@@ -176,51 +177,26 @@ class GenService:
|
||||
:param pk: 业务 ID
|
||||
:return:
|
||||
"""
|
||||
|
||||
business = await gen_business_dao.get(db, pk)
|
||||
if not business:
|
||||
raise errors.NotFoundError(msg='业务不存在')
|
||||
|
||||
tpl_code_map = await self.render_tpl_code(db=db, business=business)
|
||||
gen_path = business.gen_path or BASE_PATH / 'app'
|
||||
gen_path = business.gen_path or str(BASE_PATH / 'app')
|
||||
|
||||
for tpl_path, code in tpl_code_map.items():
|
||||
code_filepath = os.path.join(
|
||||
gen_path,
|
||||
*gen_template.get_code_gen_path(tpl_path, business).split('/'),
|
||||
)
|
||||
init_files = gen_template.get_init_files(business)
|
||||
for init_filepath, init_content in init_files.items():
|
||||
full_path = os.path.join(gen_path, *init_filepath.split('/'))
|
||||
init_folder = anyio.Path(full_path).parent
|
||||
await init_folder.mkdir(parents=True, exist_ok=True)
|
||||
async with await open_file(full_path, 'w', encoding='utf-8') as f:
|
||||
await f.write(init_content)
|
||||
|
||||
# 写入 init 文件
|
||||
code_folder = anyio.Path(code_filepath).parent
|
||||
rendered_codes = await self._render_tpl_code(db=db, business=business)
|
||||
for code_filepath, code in rendered_codes.items():
|
||||
full_path = os.path.join(gen_path, *code_filepath.split('/'))
|
||||
code_folder = anyio.Path(full_path).parent
|
||||
await code_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
init_filepath = code_folder.joinpath('__init__.py')
|
||||
if not await init_filepath.exists():
|
||||
async with await open_file(init_filepath, 'w', encoding='utf-8') as f:
|
||||
await f.write(gen_template.init_content)
|
||||
|
||||
# api __init__.py
|
||||
if 'api' in code_filepath:
|
||||
api_init_filepath = code_folder.parent.joinpath('__init__.py')
|
||||
async with await open_file(api_init_filepath, 'w', encoding='utf-8') as f:
|
||||
await f.write(gen_template.init_content)
|
||||
|
||||
# app __init__.py
|
||||
if 'service' in code_filepath:
|
||||
app_init_filepath = code_folder.parent.joinpath('__init__.py')
|
||||
async with await open_file(app_init_filepath, 'w', encoding='utf-8') as f:
|
||||
await f.write(gen_template.init_content)
|
||||
|
||||
# model init 文件补充
|
||||
if code_folder.name == 'model':
|
||||
async with await open_file(init_filepath, 'a', encoding='utf-8') as f:
|
||||
await f.write(
|
||||
f'from backend.app.{business.app_name}.model.{business.table_name} '
|
||||
f'import {to_pascal(business.table_name)}\n',
|
||||
)
|
||||
|
||||
# 写入代码文件
|
||||
async with await open_file(code_filepath, 'w', encoding='utf-8') as f:
|
||||
async with await open_file(full_path, 'w', encoding='utf-8') as f:
|
||||
await f.write(code)
|
||||
|
||||
return gen_path
|
||||
@@ -233,41 +209,18 @@ class GenService:
|
||||
:param pk: 业务 ID
|
||||
:return:
|
||||
"""
|
||||
|
||||
business = await gen_business_dao.get(db, pk)
|
||||
if not business:
|
||||
raise errors.NotFoundError(msg='业务不存在')
|
||||
|
||||
bio = io.BytesIO()
|
||||
with zipfile.ZipFile(bio, 'w') as zf:
|
||||
tpl_code_map = await self.render_tpl_code(db=db, business=business)
|
||||
for tpl_path, code in tpl_code_map.items():
|
||||
code_filepath = gen_template.get_code_gen_path(tpl_path, business)
|
||||
init_files = gen_template.get_init_files(business)
|
||||
for init_filepath, init_content in init_files.items():
|
||||
zf.writestr(init_filepath, init_content)
|
||||
|
||||
# 写入 init 文件
|
||||
code_dir = os.path.dirname(code_filepath)
|
||||
init_filepath = os.path.join(code_dir, '__init__.py')
|
||||
if 'model' not in code_filepath.split('/'):
|
||||
zf.writestr(init_filepath, gen_template.init_content)
|
||||
else:
|
||||
zf.writestr(
|
||||
init_filepath,
|
||||
f'{gen_template.init_content}'
|
||||
f'from backend.app.{business.app_name}.model.{business.table_name} '
|
||||
f'import {to_pascal(business.table_name)}\n',
|
||||
)
|
||||
|
||||
# api __init__.py
|
||||
if 'api' in code_dir:
|
||||
api_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
|
||||
zf.writestr(api_init_filepath, gen_template.init_content)
|
||||
|
||||
# app __init__.py
|
||||
if 'service' in code_dir:
|
||||
app_init_filepath = os.path.join(os.path.dirname(code_dir), '__init__.py')
|
||||
zf.writestr(app_init_filepath, gen_template.init_content)
|
||||
|
||||
# 写入代码文件
|
||||
rendered_codes = await self._render_tpl_code(db=db, business=business)
|
||||
for code_filepath, code in rendered_codes.items():
|
||||
zf.writestr(code_filepath, code)
|
||||
|
||||
bio.seek(0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
|
||||
values (1, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
|
||||
|
||||
insert into gen_column (id, name, comment, type, pd_type, `default`, sort, `length`, is_pk, is_nullable, gen_business_id)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
|
||||
values (2112248797819043840, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
|
||||
|
||||
insert into gen_column (id, name, comment, type, pd_type, `default`, sort, `length`, is_pk, is_nullable, gen_business_id)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
|
||||
values (1, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
|
||||
|
||||
insert into gen_column (id, name, comment, type, pd_type, "default", sort, "length", is_pk, is_nullable, gen_business_id)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, default_datetime_column, api_version, gen_path, remark, created_time, updated_time)
|
||||
insert into gen_business (id, app_name, table_name, doc_comment, table_comment, class_name, schema_name, filename, datetime_mixin, api_version, gen_path, remark, created_time, updated_time)
|
||||
values (2112248797819043840, 'test', 'sys_opera_log', '操作日志表', '操作日志表', 'SysOperaLog', 'SysOperaLog', 'sys_opera_log', true, 'v1', null, null, '2025-12-15 15:30:33', null);
|
||||
|
||||
insert into gen_column (id, name, comment, type, pd_type, "default", sort, "length", is_pk, is_nullable, gen_business_id)
|
||||
|
||||
@@ -8,7 +8,7 @@ from backend.app.{{ app_name }}.model import {{ class_name }}
|
||||
from backend.app.{{ app_name }}.schema.{{ table_name }} import Create{{ schema_name }}Param, Update{{ schema_name }}Param
|
||||
|
||||
|
||||
class CRUD{{ class_name }}(CRUDPlus[{{ schema_name }}]):
|
||||
class CRUD{{ class_name }}(CRUDPlus[{{ class_name }}]):
|
||||
async def get(self, db: AsyncSession, pk: int) -> {{ class_name }} | None:
|
||||
"""
|
||||
获取{{ doc_comment }}
|
||||
|
||||
@@ -5,29 +5,48 @@
|
||||
'HSTORE', 'INET', 'INT4MULTIRANGE', 'INT4RANGE', 'INT8MULTIRANGE', 'INT8RANGE', 'INTERVAL', 'JSONB', 'JSONPATH',
|
||||
'MACADDR', 'MACADDR8', 'MONEY', 'NUMMULTIRANGE', 'NUMRANGE', 'OID', 'REGCLASS', 'REGCONFIG', 'TSMULTIRANGE', 'TSQUERY',
|
||||
'TSRANGE', 'TSTZMULTIRANGE', 'TSTZRANGE', 'TSVECTOR'] %}
|
||||
{% if default_datetime_column %}
|
||||
from datetime import datetime
|
||||
{% set FACTORY_TYPES = ['dict', 'list', 'list[str]', 'list[int]'] %}
|
||||
{% set pd_types = models|map(attribute='pd_type')|list %}
|
||||
{% set NEED_MYSQL_DIALECT = database_type == 'mysql' and model_types|select('in', MYSQL_TYPES)|list|length > 0 %}
|
||||
{% set NEED_PGSQL_DIALECT = database_type == 'postgresql' and model_types|select('in', POSTGRESQL_TYPES)|list|length > 0 %}
|
||||
{% set NEED_DATE = 'date' in pd_types %}
|
||||
{% set NEED_DATETIME = datetime_mixin or 'datetime' in pd_types %}
|
||||
{% set NEED_UNIVERSAL_TEXT = 'TEXT' in model_types or 'Text' in model_types or 'LONGTEXT' in model_types %}
|
||||
{% set NEED_TIMEZONE = 'TIMESTAMP' in model_types or 'DateTime' in model_types or 'TIMESTAMP WITHOUT TIME ZONE' in model_types or 'TIMESTAMP WITH TIME ZONE' in model_types %}
|
||||
{% if NEED_DATETIME or NEED_DATE %}
|
||||
from datetime import {% if NEED_DATETIME %}datetime{% endif %}{% if NEED_DATE %}{% if NEED_DATETIME %}, {% endif %}date{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% if model_types|select('in', DECIMAL_TYPES)|first %}
|
||||
from decimal import Decimal
|
||||
|
||||
{% endif %}
|
||||
{% if 'Uuid' in model_types or 'UUID' in model_types %}
|
||||
from uuid import UUID
|
||||
{% endif %}
|
||||
|
||||
{% endif %}
|
||||
import sqlalchemy as sa
|
||||
|
||||
{% if database_type == 'mysql' -%}
|
||||
{% if NEED_MYSQL_DIALECT %}
|
||||
from sqlalchemy.dialects import mysql
|
||||
{% else -%}
|
||||
{% endif %}
|
||||
{% if NEED_PGSQL_DIALECT %}
|
||||
from sqlalchemy.dialects import postgresql
|
||||
{% endif -%}
|
||||
{% endif %}
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from backend.common.model import {% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}, id_key
|
||||
from backend.common.model import {% if datetime_mixin %}Base{% else %}DataClassBase{% endif %}, id_key
|
||||
{%- if NEED_UNIVERSAL_TEXT or NEED_TIMEZONE %}, {% endif %}
|
||||
{%- if NEED_UNIVERSAL_TEXT %}UniversalText{% endif %}
|
||||
{%- if NEED_UNIVERSAL_TEXT and NEED_TIMEZONE %}, {% endif %}
|
||||
{%- if NEED_TIMEZONE %}TimeZone{% endif %}
|
||||
|
||||
{% if NEED_TIMEZONE %}
|
||||
from backend.utils.timezone import timezone
|
||||
{% endif %}
|
||||
|
||||
|
||||
class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBase{% endif %}):
|
||||
class {{ class_name }}({% if datetime_mixin %}Base{% else %}DataClassBase{% endif %}):
|
||||
"""{{ table_comment }}"""
|
||||
|
||||
__tablename__ = '{{ table_name }}'
|
||||
@@ -38,41 +57,36 @@ class {{ class_name }}({% if default_datetime_column %}Base{% else %}DataClassBa
|
||||
{%- if model.is_nullable %} Mapped[{{ model.pd_type }} | None]
|
||||
{%- else %} Mapped[{{ model.pd_type }}]
|
||||
{%- endif %} = mapped_column(
|
||||
{%- if model.type in ['NVARCHAR', 'String', 'Unicode', 'VARCHAR'] -%}
|
||||
{%- if model.type in ['TEXT', 'Text', 'LONGTEXT'] -%}
|
||||
UniversalText
|
||||
{%- elif model.type in ['TIMESTAMP', 'DateTime', 'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'] -%}
|
||||
TimeZone
|
||||
{%- elif model.type in ['NVARCHAR', 'String', 'Unicode', 'VARCHAR', 'CHARACTER VARYING'] -%}
|
||||
sa.String({{ model.length }})
|
||||
{%- elif database_type == 'mysql' and model.type in MYSQL_TYPES -%}
|
||||
mysql.{{ model.type }}()
|
||||
mysql.{{ model.type|sqlalchemy_type }}()
|
||||
{%- elif database_type == 'postgresql' and model.type in POSTGRESQL_TYPES -%}
|
||||
{%- else -%}
|
||||
sa.{{ model.type }}()
|
||||
{%- endif -%}, default=
|
||||
{%- if model.is_nullable and model.default == None -%}
|
||||
None
|
||||
{%- else -%}
|
||||
{%- if model.default != None -%}
|
||||
'{{ model.default }}'
|
||||
{%- if model.type == 'ARRAY' -%}
|
||||
postgresql.ARRAY(sa.String)
|
||||
{%- else -%}
|
||||
{%- if model.pd_type == 'str' -%}
|
||||
''
|
||||
{%- elif model.pd_type == 'int' -%}
|
||||
0
|
||||
{%- elif model.pd_type == 'bytes' -%}
|
||||
b''
|
||||
{%- elif model.pd_type == 'bool' -%}
|
||||
True
|
||||
{%- elif model.pd_type == 'float' -%}
|
||||
0.0
|
||||
{%- elif model.pd_type == 'dict' -%}
|
||||
{}
|
||||
{%- elif model.pd_type == 'date' or model.pd_type == 'datetime' -%}
|
||||
timezone.now()
|
||||
{%- elif model.pd_type == 'list[str]' -%}
|
||||
()
|
||||
{%- else -%}
|
||||
''
|
||||
{%- endif -%}
|
||||
postgresql.{{ model.type|sqlalchemy_type }}()
|
||||
{%- endif -%}
|
||||
{%- endif -%}{% if model.sort != 0 %}, sort_order={{ model.sort }}{% endif %}, comment=
|
||||
{%- else -%}
|
||||
sa.{{ model.type|sqlalchemy_type }}()
|
||||
{%- endif -%},
|
||||
{%- if model.is_nullable and model.default == None %} default=None
|
||||
{%- elif model.default != None %} default='{{ model.default }}'
|
||||
{%- elif model.pd_type in FACTORY_TYPES %} default_factory={{ model.pd_type.split('[')[0] }}
|
||||
{%- elif model.pd_type == 'str' %} default=''
|
||||
{%- elif model.pd_type == 'int' %} default=0
|
||||
{%- elif model.pd_type == 'bytes' %} default=b''
|
||||
{%- elif model.pd_type == 'bool' %} default=True
|
||||
{%- elif model.pd_type == 'float' %} default=0.0
|
||||
{%- elif model.pd_type == 'date' %} default_factory=date.today
|
||||
{%- elif model.pd_type == 'datetime' and model.type in ['TIMESTAMP', 'DateTime', 'TIMESTAMP WITHOUT TIME ZONE', 'TIMESTAMP WITH TIME ZONE'] %} default_factory=timezone.now
|
||||
{%- elif model.pd_type == 'datetime' %} default_factory=datetime.now
|
||||
{%- else %} default=None
|
||||
{%- endif -%}, comment=
|
||||
{%- if model.comment != None -%}
|
||||
'{{ model.comment }}')
|
||||
{% else -%}
|
||||
|
||||
@@ -1,5 +1,20 @@
|
||||
{% if default_datetime_column %}
|
||||
from datetime import datetime
|
||||
{% set pd_types = models|map(attribute='pd_type')|list %}
|
||||
{% set NEED_DATETIME = datetime_mixin or 'datetime' in pd_types %}
|
||||
{% set NEED_DATE = 'date' in pd_types %}
|
||||
{% set NEED_TIME = 'time' in pd_types %}
|
||||
{% set NEED_TIMEDELTA = 'timedelta' in pd_types %}
|
||||
{% set NEED_DECIMAL = 'Decimal' in pd_types %}
|
||||
{% set NEED_UUID = 'UUID' in pd_types or 'str | UUID' in pd_types %}
|
||||
{% if NEED_DATETIME or NEED_DATE or NEED_TIME or NEED_TIMEDELTA %}
|
||||
from datetime import {% if NEED_DATETIME %}datetime{% endif %}{% if NEED_DATE %}{% if NEED_DATETIME %}, {% endif %}date{% endif %}{% if NEED_TIME %}{% if NEED_DATETIME or NEED_DATE %}, {% endif %}time{% endif %}{% if NEED_TIMEDELTA %}{% if NEED_DATETIME or NEED_DATE or NEED_TIME %}, {% endif %}timedelta{% endif %}
|
||||
|
||||
{% endif %}
|
||||
{% if NEED_DECIMAL %}
|
||||
from decimal import Decimal
|
||||
|
||||
{% endif %}
|
||||
{% if NEED_UUID %}
|
||||
from uuid import UUID
|
||||
|
||||
{% endif %}
|
||||
from pydantic import ConfigDict, Field
|
||||
@@ -35,7 +50,7 @@ class Get{{ schema_name }}Detail({{ schema_name }}SchemaBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
{% if default_datetime_column %}
|
||||
{% if datetime_mixin %}
|
||||
created_time: datetime
|
||||
updated_time: datetime | None = None
|
||||
{% endif %}
|
||||
|
||||
11
backend/plugin/code_generator/templates/sql/mysql/init.jinja
Normal file
11
backend/plugin/code_generator/templates/sql/mysql/init.jinja
Normal file
@@ -0,0 +1,11 @@
|
||||
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values ('{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
|
||||
|
||||
set @parent_menu_id = LAST_INSERT_ID();
|
||||
|
||||
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values
|
||||
('新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, @parent_menu_id, now(), null),
|
||||
('修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, @parent_menu_id, now(), null),
|
||||
('删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, @parent_menu_id, now(), null),
|
||||
('查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, @parent_menu_id, now(), null);
|
||||
@@ -0,0 +1,9 @@
|
||||
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values ({{ parent_menu_id }}, '{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
|
||||
|
||||
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values
|
||||
({{ button_ids[0] }}, '新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[1] }}, '修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[2] }}, '删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[3] }}, '查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null);
|
||||
@@ -0,0 +1,17 @@
|
||||
do $$
|
||||
declare
|
||||
parent_menu_id bigint;
|
||||
begin
|
||||
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values ('{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null)
|
||||
returning id into parent_menu_id;
|
||||
|
||||
insert into sys_menu (title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values
|
||||
('新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, parent_menu_id, now(), null),
|
||||
('修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, parent_menu_id, now(), null),
|
||||
('删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, parent_menu_id, now(), null),
|
||||
('查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, parent_menu_id, now(), null);
|
||||
end $$;
|
||||
|
||||
select setval(pg_get_serial_sequence('sys_menu', 'id'),coalesce(max(id), 0) + 1, true) from sys_menu;
|
||||
@@ -0,0 +1,9 @@
|
||||
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values ({{ parent_menu_id }}, '{{ doc_comment }}', '{{ schema_name }}', '/{{ app_name }}/{{ table_name|replace("_", "-") }}', 0, 'tabler:list', 1, '/{{ app_name }}/{{ table_name }}/views/index', null, 1, 1, 1, '', null, null, now(), null);
|
||||
|
||||
insert into sys_menu (id, title, name, path, sort, icon, type, component, perms, status, display, cache, link, remark, parent_id, created_time, updated_time)
|
||||
values
|
||||
({{ button_ids[0] }}, '新增', 'Add{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:add', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[1] }}, '修改', 'Edit{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:edit', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[2] }}, '删除', 'Delete{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:del', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null),
|
||||
({{ button_ids[3] }}, '查询', 'Get{{ schema_name }}', null, 0, null, 2, null, '{{ permission }}:get', 1, 0, 1, '', null, {{ parent_menu_id }}, now(), null);
|
||||
43
backend/plugin/code_generator/utils/format_code.py
Normal file
43
backend/plugin/code_generator/utils/format_code.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import anyio
|
||||
|
||||
from anyio import open_file
|
||||
|
||||
|
||||
async def format_python_code(code: str) -> str:
|
||||
"""
|
||||
使用 ruff 格式化 Python 代码
|
||||
|
||||
:param code: 原始代码
|
||||
:return:
|
||||
"""
|
||||
temp_dir = anyio.Path(tempfile.gettempdir())
|
||||
temp_file = temp_dir / f'fba_codegen_{uuid.uuid4().hex}.py'
|
||||
|
||||
try:
|
||||
async with await open_file(temp_file, 'w', encoding='utf-8') as f:
|
||||
await f.write(code)
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'ruff',
|
||||
'format',
|
||||
str(temp_file),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
async with await open_file(temp_file, encoding='utf-8') as f:
|
||||
formatted_code = await f.read()
|
||||
else:
|
||||
formatted_code = code
|
||||
except (FileNotFoundError, OSError):
|
||||
return code
|
||||
finally:
|
||||
await temp_file.unlink(missing_ok=True)
|
||||
|
||||
return formatted_code
|
||||
@@ -1,10 +1,14 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, Template, select_autoescape
|
||||
from pydantic.alias_generators import to_pascal
|
||||
|
||||
from backend.core.conf import settings
|
||||
from backend.plugin.code_generator.model import GenBusiness, GenColumn
|
||||
from backend.plugin.code_generator.path_conf import JINJA2_TEMPLATE_DIR
|
||||
from backend.plugin.code_generator.utils.type_conversion import sql_type_to_sqlalchemy_name
|
||||
from backend.utils.snowflake import snowflake
|
||||
from backend.utils.timezone import timezone
|
||||
|
||||
|
||||
class GenTemplate:
|
||||
@@ -18,60 +22,64 @@ class GenTemplate:
|
||||
keep_trailing_newline=True,
|
||||
enable_async=True,
|
||||
)
|
||||
self.env.filters['sqlalchemy_type'] = sql_type_to_sqlalchemy_name
|
||||
self.init_content = ''
|
||||
|
||||
def get_template(self, jinja_file: str) -> Template:
|
||||
"""
|
||||
获取模板文件
|
||||
获取 Jinja2 模板对象
|
||||
|
||||
:param jinja_file: Jinja2 模板文件
|
||||
:return:
|
||||
:param jinja_file: Jinja2 模板文件路径
|
||||
:return: Template 对象
|
||||
"""
|
||||
return self.env.get_template(jinja_file)
|
||||
|
||||
@staticmethod
|
||||
def get_template_files() -> list[str]:
|
||||
def get_template_path_mapping(business: GenBusiness) -> dict[str, str]:
|
||||
"""
|
||||
获取模板文件列表
|
||||
|
||||
:return:
|
||||
"""
|
||||
return [
|
||||
'python/api.jinja',
|
||||
'python/crud.jinja',
|
||||
'python/model.jinja',
|
||||
'python/schema.jinja',
|
||||
'python/service.jinja',
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_code_gen_paths(business: GenBusiness) -> list[str]:
|
||||
"""
|
||||
获取代码生成路径列表
|
||||
获取模板文件到生成文件的路径映射
|
||||
|
||||
:param business: 代码生成业务对象
|
||||
:return:
|
||||
:return: {模板路径: 生成文件路径}
|
||||
"""
|
||||
app_name = business.app_name
|
||||
filename = business.filename
|
||||
return [
|
||||
f'{app_name}/api/{business.api_version}/{filename}.py',
|
||||
f'{app_name}/crud/crud_{filename}.py',
|
||||
f'{app_name}/model/{filename}.py',
|
||||
f'{app_name}/schema/{filename}.py',
|
||||
f'{app_name}/service/{filename}_service.py',
|
||||
]
|
||||
api_version = business.api_version
|
||||
pk_suffix = '_snowflake' if settings.DATABASE_PK_MODE == 'snowflake' else ''
|
||||
|
||||
def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str:
|
||||
"""
|
||||
获取代码生成路径
|
||||
return {
|
||||
'python/api.jinja': f'{app_name}/api/{api_version}/{filename}.py',
|
||||
'python/crud.jinja': f'{app_name}/crud/crud_{filename}.py',
|
||||
'python/model.jinja': f'{app_name}/model/{filename}.py',
|
||||
'python/schema.jinja': f'{app_name}/schema/{filename}.py',
|
||||
'python/service.jinja': f'{app_name}/service/{filename}_service.py',
|
||||
f'sql/mysql/init{pk_suffix}.jinja': f'{app_name}/sql/mysql/init{pk_suffix}.sql',
|
||||
f'sql/postgresql/init{pk_suffix}.jinja': f'{app_name}/sql/postgresql/init{pk_suffix}.sql',
|
||||
}
|
||||
|
||||
:param tpl_path: 模板文件路径
|
||||
:param business: 代码生成业务对象
|
||||
:return:
|
||||
def get_init_files(self, business: GenBusiness) -> dict[str, str]:
|
||||
"""
|
||||
code_gen_path_mapping = dict(zip(self.get_template_files(), self.get_code_gen_paths(business)))
|
||||
return code_gen_path_mapping[tpl_path]
|
||||
获取需要生成的 __init__.py 文件及其内容
|
||||
|
||||
:param business: 业务对象
|
||||
:return: {相对路径: 文件内容}
|
||||
"""
|
||||
app_name = business.app_name
|
||||
table_name = business.table_name
|
||||
class_name = business.class_name or to_pascal(table_name)
|
||||
|
||||
return {
|
||||
f'{app_name}/__init__.py': self.init_content,
|
||||
f'{app_name}/api/__init__.py': self.init_content,
|
||||
f'{app_name}/api/{business.api_version}/__init__.py': self.init_content,
|
||||
f'{app_name}/crud/__init__.py': self.init_content,
|
||||
f'{app_name}/model/__init__.py': (
|
||||
f'{self.init_content}'
|
||||
f'from backend.app.{app_name}.model.{table_name} import {class_name} as {class_name}\n'
|
||||
),
|
||||
f'{app_name}/schema/__init__.py': self.init_content,
|
||||
f'{app_name}/service/__init__.py': self.init_content,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_vars(business: GenBusiness, models: Sequence[GenColumn]) -> dict[str, str | Sequence[GenColumn]]:
|
||||
@@ -82,19 +90,26 @@ class GenTemplate:
|
||||
:param models: 代码生成模型对象列表
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
vars_dict = {
|
||||
'app_name': business.app_name,
|
||||
'table_name': business.table_name,
|
||||
'doc_comment': business.doc_comment,
|
||||
'table_comment': business.table_comment,
|
||||
'class_name': business.class_name,
|
||||
'schema_name': business.schema_name,
|
||||
'default_datetime_column': business.default_datetime_column,
|
||||
'permission': str(business.table_name.replace('_', ':')),
|
||||
'datetime_mixin': business.datetime_mixin,
|
||||
'permission': business.table_name.replace('_', ':'),
|
||||
'database_type': settings.DATABASE_TYPE,
|
||||
'models': models,
|
||||
'model_types': [model.type for model in models],
|
||||
'now': timezone.now(),
|
||||
}
|
||||
|
||||
if settings.DATABASE_PK_MODE == 'snowflake':
|
||||
vars_dict['parent_menu_id'] = snowflake.generate()
|
||||
vars_dict['button_ids'] = [snowflake.generate() for _ in range(4)]
|
||||
|
||||
return vars_dict
|
||||
|
||||
|
||||
gen_template: GenTemplate = GenTemplate()
|
||||
|
||||
@@ -17,7 +17,8 @@ def sql_type_to_sqlalchemy(typing: str) -> str:
|
||||
if typing in GenMySQLColumnType.get_member_keys():
|
||||
return typing
|
||||
else:
|
||||
if typing in GenPostgreSQLColumnType.get_member_keys():
|
||||
normalized_type = typing.replace(' ', '_')
|
||||
if normalized_type in GenPostgreSQLColumnType.get_member_keys():
|
||||
return typing
|
||||
return 'String'
|
||||
|
||||
@@ -33,8 +34,31 @@ def sql_type_to_pydantic(typing: str) -> str:
|
||||
try:
|
||||
if DataBaseType.mysql == settings.DATABASE_TYPE:
|
||||
return GenMySQLColumnType[typing].value
|
||||
if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名
|
||||
return 'str'
|
||||
return GenPostgreSQLColumnType[typing].value
|
||||
normalized_type = typing.replace(' ', '_')
|
||||
return GenPostgreSQLColumnType[normalized_type].value
|
||||
except KeyError:
|
||||
return 'str'
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def sql_type_to_sqlalchemy_name(typing: str) -> str:
|
||||
"""
|
||||
将 SQL 类型转换为有效的 SQLAlchemy 类型名称(用于代码生成)
|
||||
|
||||
:param typing: SQL 类型字符串
|
||||
:return:
|
||||
"""
|
||||
pg_type_mapping = {
|
||||
'CHARACTER VARYING': 'String',
|
||||
'CHARACTER': 'CHAR',
|
||||
'TIMESTAMP WITHOUT TIME ZONE': 'TIMESTAMP',
|
||||
'TIMESTAMP WITH TIME ZONE': 'TIMESTAMP',
|
||||
'TIME WITHOUT TIME ZONE': 'TIME',
|
||||
'TIME WITH TIME ZONE': 'TIME',
|
||||
'DOUBLE PRECISION': 'Double',
|
||||
}
|
||||
|
||||
if DataBaseType.postgresql == settings.DATABASE_TYPE and typing in pg_type_mapping:
|
||||
return pg_type_mapping[typing]
|
||||
|
||||
return typing
|
||||
|
||||
@@ -76,3 +76,14 @@ def is_has_special_char(value: str) -> re.Match[str]:
|
||||
"""
|
||||
special_char_pattern = r'[!@#$%^&*()_+\-=\[\]{};:\'",.<>?/\\|`~]'
|
||||
return search_string(special_char_pattern, value)
|
||||
|
||||
|
||||
def is_english_identifier(value: str) -> re.Match[str]:
|
||||
"""
|
||||
检查英文标识符
|
||||
|
||||
:param value: 待检查的值
|
||||
:return:
|
||||
"""
|
||||
identifier_pattern = r'^[a-zA-Z][a-zA-Z_]*$'
|
||||
return match_string(identifier_pattern, value)
|
||||
|
||||
Reference in New Issue
Block a user