diff --git a/src/components/ExecutionView/DataTableView.tsx b/src/components/ExecutionView/DataTableView.tsx new file mode 100644 index 0000000..4e5f6c1 --- /dev/null +++ b/src/components/ExecutionView/DataTableView.tsx @@ -0,0 +1,39 @@ +import { head } from "lodash-es"; +import DataTable from "react-data-table-component"; +import { useTranslation } from "react-i18next"; +import { RawResult } from "@/types"; +import Icon from "../Icon"; + +interface Props { + rawResults: RawResult[]; +} + +const DataTableView = (props: Props) => { + const { rawResults } = props; + const { t } = useTranslation(); + const columns = Object.keys(head(rawResults) || {}).map((key) => { + return { + name: key, + sortable: true, + selector: (row: any) => row[key], + }; + }); + + return rawResults.length === 0 ? ( +
+ + {t("execution.message.no-data")} +
+ ) : ( + + ); +}; + +export default DataTableView; diff --git a/src/components/ExecutionView/ExecutionWarningBanner.tsx b/src/components/ExecutionView/ExecutionWarningBanner.tsx new file mode 100644 index 0000000..904e9dc --- /dev/null +++ b/src/components/ExecutionView/ExecutionWarningBanner.tsx @@ -0,0 +1,22 @@ +import { useTranslation } from "react-i18next"; +import Icon from "../Icon"; + +interface Props { + className?: string; +} + +const ExecutionWarningBanner = (props: Props) => { + const { className } = props; + const { t } = useTranslation(); + + return ( +
+ + + {t("banner.non-select-sql-warning")} + +
+ ); +}; + +export default ExecutionWarningBanner; diff --git a/src/components/ExecutionView/NotificationView.tsx b/src/components/ExecutionView/NotificationView.tsx new file mode 100644 index 0000000..b1c2853 --- /dev/null +++ b/src/components/ExecutionView/NotificationView.tsx @@ -0,0 +1,12 @@ +interface Props { + message: string; + style: "info" | "error"; +} + +const NotificationView = (props: Props) => { + const { message, style } = props; + const additionalStyle = style === "error" ? "text-red-500" : "text-gray-500"; + return

{message}

; +}; + +export default NotificationView; diff --git a/src/components/QueryDrawer.tsx b/src/components/QueryDrawer.tsx index 89520a4..f4f2c27 100644 --- a/src/components/QueryDrawer.tsx +++ b/src/components/QueryDrawer.tsx @@ -1,41 +1,38 @@ import { Drawer } from "@mui/material"; -import { head } from "lodash-es"; import { useEffect, useState } from "react"; -import DataTable from "react-data-table-component"; import { toast } from "react-hot-toast"; import { useTranslation } from "react-i18next"; import TextareaAutosize from "react-textarea-autosize"; import { useQueryStore } from "@/store"; -import { ResponseObject } from "@/types"; +import { ExecutionResult, ResponseObject } from "@/types"; +import { checkStatementIsSelect, getMessageFromExecutionResult } from "@/utils"; import Icon from "./Icon"; import EngineIcon from "./EngineIcon"; - -type RawQueryResult = { - [key: string]: any; -}; +import DataTableView from "./ExecutionView/DataTableView"; +import NotificationView from "./ExecutionView/NotificationView"; +import ExecutionWarningBanner from "./ExecutionView/ExecutionWarningBanner"; const QueryDrawer = () => { const { t } = useTranslation(); const queryStore = useQueryStore(); - const [rawResults, setRawResults] = useState([]); + const [executionResult, setExecutionResult] = useState(undefined); + const [statement, setStatement] = useState(""); + const [isLoading, setIsLoading] = useState(false); const context = queryStore.context; - const [statement, setStatement] = useState(context?.statement || ""); - const [isLoading, setIsLoading] = useState(true); - const columns = Object.keys(head(rawResults) || {}).map((key) => { - return { - name: key, - sortable: true, - selector: (row: RawQueryResult) => row[key], - }; - }); + const executionMessage = executionResult ? getMessageFromExecutionResult(executionResult) : ""; + const showExecutionWarningBanner = !checkStatementIsSelect(statement); useEffect(() => { if (!queryStore.showDrawer) { return; } - setStatement(context?.statement || ""); - executeStatement(context?.statement || ""); + const statement = context?.statement || ""; + setStatement(statement); + if (statement !== "" && checkStatementIsSelect(statement)) { + executeStatement(statement); + } + setExecutionResult(undefined); }, [context, queryStore.showDrawer]); const executeStatement = async (statement: string) => { @@ -47,12 +44,12 @@ const QueryDrawer = () => { if (!context) { toast.error("No execution context found."); setIsLoading(false); - setRawResults([]); + setExecutionResult(undefined); return; } setIsLoading(true); - setRawResults([]); + setExecutionResult(undefined); const { connection, database } = context; try { const response = await fetch("/api/connection/execute", { @@ -66,11 +63,14 @@ const QueryDrawer = () => { statement, }), }); - const result = (await response.json()) as ResponseObject; + const result = (await response.json()) as ResponseObject; if (result.message) { - toast.error(result.message); + setExecutionResult({ + rawResult: [], + error: result.message, + }); } else { - setRawResults(result.data); + setExecutionResult(result.data); } } catch (error) { console.error(error); @@ -101,6 +101,7 @@ const QueryDrawer = () => { {context.database?.name} + {showExecutionWarningBanner && }
{ {t("execution.message.executing")}
- ) : rawResults.length === 0 ? ( -
- - {t("execution.message.no-data")} -
) : ( -
- -
+ <> + {executionResult ? ( + executionMessage ? ( + + ) : ( + + ) + ) : ( + <> + )} + )} diff --git a/src/lib/connectors/index.ts b/src/lib/connectors/index.ts index ff0599b..24b6cd9 100644 --- a/src/lib/connectors/index.ts +++ b/src/lib/connectors/index.ts @@ -1,11 +1,11 @@ -import { Connection, Engine } from "@/types"; +import { Connection, Engine, ExecutionResult } from "@/types"; import mysql from "./mysql"; import postgres from "./postgres"; import mssql from "./mssql"; export interface Connector { testConnection: () => Promise; - execute: (databaseName: string, statement: string) => Promise; + execute: (databaseName: string, statement: string) => Promise; getDatabases: () => Promise; getTables: (databaseName: string) => Promise; getTableStructure: (databaseName: string, tableName: string) => Promise; @@ -17,8 +17,8 @@ export const newConnector = (connection: Connection): Connector => { return mysql(connection); case Engine.PostgreSQL: return postgres(connection); - case Engine.MSSQL: - return mssql(connection); + case Engine.MSSQL: + return mssql(connection); default: throw new Error("Unsupported engine type."); } diff --git a/src/lib/connectors/mssql/index.ts b/src/lib/connectors/mssql/index.ts index fd2a34b..5865547 100644 --- a/src/lib/connectors/mssql/index.ts +++ b/src/lib/connectors/mssql/index.ts @@ -1,5 +1,5 @@ import { ConnectionPool } from "mssql"; -import { Connection } from "@/types"; +import { Connection, ExecutionResult } from "@/types"; import { Connector } from ".."; const systemDatabases = ["master", "tempdb", "model", "msdb"]; @@ -37,7 +37,12 @@ const execute = async (connection: Connection, databaseName: string, statement: const request = pool.request(); const result = await request.query(`USE ${databaseName}; ${statement}`); await pool.close(); - return result.recordset; + + const executionResult: ExecutionResult = { + rawResult: result.recordset, + affectedRows: result.rowsAffected.length, + }; + return executionResult; }; const getDatabases = async (connection: Connection): Promise => { diff --git a/src/lib/connectors/mysql/index.ts b/src/lib/connectors/mysql/index.ts index b98cef2..0ffdc1b 100644 --- a/src/lib/connectors/mysql/index.ts +++ b/src/lib/connectors/mysql/index.ts @@ -1,6 +1,6 @@ import { ConnectionOptions } from "mysql2"; import mysql, { RowDataPacket } from "mysql2/promise"; -import { Connection } from "@/types"; +import { Connection, ExecutionResult } from "@/types"; import { Connector } from ".."; const systemDatabases = ["information_schema", "mysql", "performance_schema", "sys"]; @@ -33,9 +33,19 @@ const testConnection = async (connection: Connection): Promise => { const execute = async (connection: Connection, databaseName: string, statement: string): Promise => { connection.database = databaseName; const conn = await getMySQLConnection(connection); - const [rows] = await conn.query(statement); + const [rows] = await conn.execute(statement); conn.destroy(); - return rows; + + const executionResult: ExecutionResult = { + rawResult: [], + affectedRows: 0, + }; + if (Array.isArray(rows)) { + executionResult.rawResult = rows; + } else { + executionResult.affectedRows = rows.affectedRows; + } + return executionResult; }; const getDatabases = async (connection: Connection): Promise => { diff --git a/src/lib/connectors/postgres/index.ts b/src/lib/connectors/postgres/index.ts index d31c523..8415e2e 100644 --- a/src/lib/connectors/postgres/index.ts +++ b/src/lib/connectors/postgres/index.ts @@ -1,5 +1,5 @@ import { Client, ClientConfig } from "pg"; -import { Connection } from "@/types"; +import { Connection, ExecutionResult } from "@/types"; import { Connector } from ".."; const newPostgresClient = (connection: Connection) => { @@ -27,12 +27,18 @@ const testConnection = async (connection: Connection): Promise => { return true; }; -const execute = async (connection: Connection, _: string, statement: string): Promise => { +const execute = async (connection: Connection, databaseName: string, statement: string): Promise => { + connection.database = databaseName; const client = newPostgresClient(connection); await client.connect(); - const { rows } = await client.query(statement); + const { rows, rowCount } = await client.query(statement); await client.end(); - return rows; + + const executionResult: ExecutionResult = { + rawResult: rows, + affectedRows: rowCount, + }; + return executionResult; }; const getDatabases = async (connection: Connection): Promise => { diff --git a/src/locales/en.json b/src/locales/en.json index 52e6572..79893be 100644 --- a/src/locales/en.json +++ b/src/locales/en.json @@ -57,6 +57,7 @@ "join-wechat-group": "Join WeChat Group" }, "banner": { - "data-storage": "Connection settings are stored in your local browser" + "data-storage": "Connection settings are stored in your local browser", + "non-select-sql-warning": "The current statement may be non-SELECT SQL, which will result in a database schema or data change. Make sure you know what you are doing." } } diff --git a/src/locales/zh.json b/src/locales/zh.json index 1513066..4dcedc3 100644 --- a/src/locales/zh.json +++ b/src/locales/zh.json @@ -57,6 +57,7 @@ "join-wechat-group": "加入微信群" }, "banner": { - "data-storage": "连接设置存储在您的本地浏览器中" + "data-storage": "连接设置存储在您的本地浏览器中", + "non-select-sql-warning": "当前语句可能是非 SELECT SQL,这将导致数据库模式或数据变化。" } } diff --git a/src/pages/api/connection/execute.ts b/src/pages/api/connection/execute.ts index fb4f7b4..6ff613f 100644 --- a/src/pages/api/connection/execute.ts +++ b/src/pages/api/connection/execute.ts @@ -1,7 +1,6 @@ import { NextApiRequest, NextApiResponse } from "next"; import { newConnector } from "@/lib/connectors"; import { Connection } from "@/types"; -import { checkStatementIsSelect } from "@/utils"; // POST /api/connection/execute // req body: { connection: Connection, db: string, statement: string } @@ -13,10 +12,6 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { const connection = req.body.connection as Connection; const db = req.body.db as string; const statement = req.body.statement as string; - // We only support SELECT statements for now. - if (!checkStatementIsSelect(statement)) { - return res.status(400).json([]); - } try { const connector = newConnector(connection); diff --git a/src/types/connector.ts b/src/types/connector.ts new file mode 100644 index 0000000..cead8e9 --- /dev/null +++ b/src/types/connector.ts @@ -0,0 +1,9 @@ +export type RawResult = { + [key: string]: any; +}; + +export interface ExecutionResult { + rawResult: RawResult[]; + affectedRows?: number; + error?: string; +} diff --git a/src/types/index.ts b/src/types/index.ts index e38f56b..4f74f47 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -6,3 +6,4 @@ export * from "./conversation"; export * from "./message"; export * from "./setting"; export * from "./api"; +export * from "./connector"; diff --git a/src/utils/execution.ts b/src/utils/execution.ts new file mode 100644 index 0000000..2a0d6b1 --- /dev/null +++ b/src/utils/execution.ts @@ -0,0 +1,11 @@ +import { ExecutionResult } from "@/types"; + +export const getMessageFromExecutionResult = (result: ExecutionResult): string => { + if (result.error) { + return result.error; + } + if (result.affectedRows) { + return `${result.affectedRows} rows affected.`; + } + return ""; +}; diff --git a/src/utils/index.ts b/src/utils/index.ts index 0a36f17..13975e9 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,3 +1,4 @@ export * from "./id"; export * from "./openai"; export * from "./sql"; +export * from "./execution";