feat: support schema select (#112)

This commit is contained in:
CorrectRoadH
2023-05-31 14:04:10 +08:00
committed by GitHub
parent 1a57c6a899
commit 00022e6bb7
13 changed files with 157 additions and 221 deletions

View File

@ -2,7 +2,7 @@ import { Drawer } from "@mui/material";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { useConnectionStore, useConversationStore, useLayoutStore, ResponsiveWidth, useSettingStore } from "@/store"; import { useConnectionStore, useConversationStore, useLayoutStore, ResponsiveWidth, useSettingStore } from "@/store";
import { Table } from "@/types"; import { Engine, Table, Schema } from "@/types";
import useLoading from "@/hooks/useLoading"; import useLoading from "@/hooks/useLoading";
import Select from "./kit/Select"; import Select from "./kit/Select";
import Icon from "./Icon"; import Icon from "./Icon";
@ -26,11 +26,21 @@ const ConnectionSidebar = () => {
const currentConnectionCtx = connectionStore.currentConnectionCtx; const currentConnectionCtx = connectionStore.currentConnectionCtx;
const databaseList = connectionStore.databaseList.filter((database) => database.connectionId === currentConnectionCtx?.connection.id); const databaseList = connectionStore.databaseList.filter((database) => database.connectionId === currentConnectionCtx?.connection.id);
const [tableList, updateTableList] = useState<Table[]>([]); const [tableList, updateTableList] = useState<Table[]>([]);
const [schemaList, updateSchemaList] = useState<Schema[]>([]);
const [hasSchemaProperty, updateHasSchemaProperty] = useState<boolean>(false);
const selectedTablesName: string[] = const selectedTablesName: string[] =
conversationStore.getConversationById(conversationStore.currentConversationId)?.selectedTablesName || []; conversationStore.getConversationById(conversationStore.currentConversationId)?.selectedTablesName || [];
const selectedSchemaName: string =
conversationStore.getConversationById(conversationStore.currentConversationId)?.selectedSchemaName || "";
const tableSchemaLoadingState = useLoading(); const tableSchemaLoadingState = useLoading();
const currentConversation = conversationStore.getConversationById(conversationStore.currentConversationId); const currentConversation = conversationStore.getConversationById(conversationStore.currentConversationId);
useEffect(() => {
updateHasSchemaProperty(
currentConnectionCtx?.connection.engineType === Engine.PostgreSQL || currentConnectionCtx?.connection.engineType === Engine.MSSQL
);
}, [currentConnectionCtx?.connection]);
useEffect(() => { useEffect(() => {
const handleWindowResize = () => { const handleWindowResize = () => {
if (window.innerWidth < ResponsiveWidth.sm) { if (window.innerWidth < ResponsiveWidth.sm) {
@ -71,14 +81,26 @@ const ConnectionSidebar = () => {
}, [currentConnectionCtx?.connection]); }, [currentConnectionCtx?.connection]);
useEffect(() => { useEffect(() => {
const tableList = const schemaList =
connectionStore.databaseList.find( connectionStore.databaseList.find(
(database) => (database) =>
database.connectionId === currentConnectionCtx?.connection.id && database.name === currentConnectionCtx?.database?.name database.connectionId === currentConnectionCtx?.connection.id && database.name === currentConnectionCtx?.database?.name
)?.tableList || []; )?.schemaList || [];
updateSchemaList(schemaList);
// need to create a conversation. otherwise updateSelectedSchemaName will failed.
createConversation();
if (hasSchemaProperty) {
conversationStore.updateSelectedSchemaName(schemaList[0]?.name || "");
} else {
conversationStore.updateSelectedSchemaName("");
}
}, [connectionStore, hasSchemaProperty, currentConnectionCtx, schemaList]);
useEffect(() => {
const tableList = schemaList.find((schema) => schema.name === selectedSchemaName)?.tables || [];
updateTableList(tableList); updateTableList(tableList);
}, [connectionStore, currentConnectionCtx]); }, [selectedSchemaName, selectedTablesName, schemaList]);
const handleDatabaseNameSelect = async (databaseName: string) => { const handleDatabaseNameSelect = async (databaseName: string) => {
if (!currentConnectionCtx?.connection) { if (!currentConnectionCtx?.connection) {
@ -113,19 +135,23 @@ const ConnectionSidebar = () => {
}; };
const handleTableNameSelect = async (selectedTablesName: string[]) => { const handleTableNameSelect = async (selectedTablesName: string[]) => {
createConversation();
conversationStore.updateSelectedTablesName(selectedTablesName); conversationStore.updateSelectedTablesName(selectedTablesName);
}; };
const handleAllSelect = async () => { const handleAllSelect = async () => {
createConversation();
conversationStore.updateSelectedTablesName(tableList.map((table) => table.name)); conversationStore.updateSelectedTablesName(tableList.map((table) => table.name));
}; };
const handleEmptySelect = async () => { const handleEmptySelect = async () => {
createConversation();
conversationStore.updateSelectedTablesName([]); conversationStore.updateSelectedTablesName([]);
}; };
const handleSchemaNameSelect = async (schemaName: string) => {
// need to empty selectedTablesName when schemaName changed. because selectedTablesName may not exist in new schema.
conversationStore.updateSelectedTablesName([]);
conversationStore.updateSelectedSchemaName(schemaName);
};
return ( return (
<> <>
<Drawer <Drawer
@ -169,6 +195,20 @@ const ConnectionSidebar = () => {
/> />
</div> </div>
)} )}
{hasSchemaProperty && schemaList.length > 0 && (
<Select
className="w-full px-4 py-3 !text-base"
value={selectedSchemaName}
itemList={schemaList.map((schema) => {
return {
label: schema.name,
value: schema.name,
};
})}
onValueChange={(schema) => handleSchemaNameSelect(schema)}
placeholder={t("connection.select-schema") || ""}
/>
)}
{currentConnectionCtx && {currentConnectionCtx &&
(tableSchemaLoadingState.isLoading ? ( (tableSchemaLoadingState.isLoading ? (
<div className="w-full h-12 flex flex-row justify-start items-center px-4 sticky top-0 border z-1 mb-4 mt-2 rounded-lg text-sm text-gray-600 dark:text-gray-400"> <div className="w-full h-12 flex flex-row justify-start items-center px-4 sticky top-0 border z-1 mb-4 mt-2 rounded-lg text-sm text-gray-600 dark:text-gray-400">

View File

@ -150,17 +150,17 @@ const ConversationView = () => {
if (connectionStore.currentConnectionCtx?.database) { if (connectionStore.currentConnectionCtx?.database) {
let schema = ""; let schema = "";
try { try {
const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database); const schemaList = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
// Empty table name(such as []) denote all table. [] and `undefined` both are false in `if` // Empty table name(such as []) denote all table. [] and `undefined` both are false in `if`
const tableList: string[] = []; const tableList: string[] = [];
const selectedSchema = schemaList.find((schema) => schema.name == (currentConversation.selectedSchemaName || ""));
if (currentConversation.selectedTablesName) { if (currentConversation.selectedTablesName) {
currentConversation.selectedTablesName.forEach((tableName: string) => { currentConversation.selectedTablesName.forEach((tableName: string) => {
const table = tables.find((table) => table.name === tableName); const table = selectedSchema?.tables.find((table) => table.name == tableName);
tableList.push(table!.structure); tableList.push(table!.structure);
}); });
} else { } else {
for (const table of tables) { for (const table of selectedSchema?.tables || []) {
tableList.push(table!.structure); tableList.push(table!.structure);
} }
} }

View File

@ -1,4 +1,4 @@
import { Connection, Engine, ExecutionResult } from "@/types"; import { Connection, Engine, ExecutionResult, Schema } from "@/types";
import mysql from "./mysql"; import mysql from "./mysql";
import postgres from "./postgres"; import postgres from "./postgres";
import mssql from "./mssql"; import mssql from "./mssql";
@ -7,17 +7,7 @@ export interface Connector {
testConnection: () => Promise<boolean>; testConnection: () => Promise<boolean>;
execute: (databaseName: string, statement: string) => Promise<ExecutionResult>; execute: (databaseName: string, statement: string) => Promise<ExecutionResult>;
getDatabases: () => Promise<string[]>; getDatabases: () => Promise<string[]>;
getTables: (databaseName: string) => Promise<string[]>; getTableSchema: (databaseName: string) => Promise<Schema[]>;
getTableStructure: (
databaseName: string,
tableName: string,
structureFetched: (tableName: string, structure: string) => void
) => Promise<void>;
getTableStructureBatch: (
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
) => Promise<void>;
} }
export const newConnector = (connection: Connection): Connector => { export const newConnector = (connection: Connection): Connector => {

View File

@ -1,5 +1,5 @@
import { ConnectionPool } from "mssql"; import { ConnectionPool } from "mssql";
import { Connection, ExecutionResult } from "@/types"; import { Connection, ExecutionResult, Schema, Table } from "@/types";
import { Connector } from ".."; import { Connector } from "..";
const systemDatabases = ["master", "tempdb", "model", "msdb"]; const systemDatabases = ["master", "tempdb", "model", "msdb"];
@ -59,62 +59,31 @@ const getDatabases = async (connection: Connection): Promise<string[]> => {
return databaseList; return databaseList;
}; };
const getTables = async (connection: Connection, databaseName: string): Promise<string[]> => { const getTableSchema = async (connection: Connection, databaseName: string): Promise<Schema[]> => {
const pool = await getMSSQLConnection(connection); const pool = await getMSSQLConnection(connection);
const request = pool.request(); const request = pool.request();
const schemaList: Schema[] = [];
const result = await request.query( const result = await request.query(
`SELECT TABLE_NAME as table_name FROM ${databaseName}.INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE';` `SELECT TABLE_NAME as table_name, TABLE_SCHEMA as table_schema FROM ${databaseName}.INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE';`
); );
await pool.close();
const tableList = [];
for (const row of result.recordset) { for (const row of result.recordset) {
if (row["table_name"]) { if (row["table_name"]) {
tableList.push(row["table_name"]); const schema = schemaList.find((schema) => schema.name === row["table_schema"]);
if (schema) {
schema.tables.push({ name: row["table_name"] as string, structure: "" } as Table);
} else {
schemaList.push({
name: row["table_schema"],
tables: [{ name: row["table_name"], structure: "" } as Table],
});
}
} }
} }
return tableList;
};
const getTableStructure = async ( for (const schema of schemaList) {
connection: Connection, for (const table of schema.tables) {
databaseName: string,
tableName: string,
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
const pool = await getMSSQLConnection(connection);
const request = pool.request();
const { recordset } = await request.query( const { recordset } = await request.query(
`SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE FROM ${databaseName}.INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='dbo' AND TABLE_NAME='${tableName}';` `SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE FROM ${databaseName}.INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='${schema.name}' AND TABLE_NAME='${table.name}';`
);
const columnList = [];
// Transform to standard schema string.
for (const row of recordset) {
columnList.push(
`${row["COLUMN_NAME"]} ${row["DATA_TYPE"].toUpperCase()} ${String(row["IS_NULLABLE"]).toUpperCase() === "NO" ? "NOT NULL" : ""}`
);
}
structureFetched(
tableName,
`CREATE TABLE [${tableName}] (
${columnList.join(",\n")}
);`
);
};
const getTableStructureBatch = async (
connection: Connection,
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
const pool = await getMSSQLConnection(connection);
const request = pool.request();
await Promise.all(
tableNameList.map(async (tableName) => {
const { recordset } = await request.query(
`SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE FROM ${databaseName}.INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='dbo' AND TABLE_NAME='${tableName}';`
); );
const columnList = []; const columnList = [];
// Transform to standard schema string. // Transform to standard schema string.
@ -123,14 +92,13 @@ const getTableStructureBatch = async (
`${row["COLUMN_NAME"]} ${row["DATA_TYPE"].toUpperCase()} ${String(row["IS_NULLABLE"]).toUpperCase() === "NO" ? "NOT NULL" : ""}` `${row["COLUMN_NAME"]} ${row["DATA_TYPE"].toUpperCase()} ${String(row["IS_NULLABLE"]).toUpperCase() === "NO" ? "NOT NULL" : ""}`
); );
} }
structureFetched( table.structure = `CREATE TABLE [${table.name}] (
tableName,
`CREATE TABLE [${tableName}] (
${columnList.join(",\n")} ${columnList.join(",\n")}
);` );`;
); }
}) }
); await pool.close();
return schemaList;
}; };
const newConnector = (connection: Connection): Connector => { const newConnector = (connection: Connection): Connector => {
@ -138,14 +106,7 @@ const newConnector = (connection: Connection): Connector => {
testConnection: () => testConnection(connection), testConnection: () => testConnection(connection),
execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement), execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement),
getDatabases: () => getDatabases(connection), getDatabases: () => getDatabases(connection),
getTables: (databaseName: string) => getTables(connection, databaseName), getTableSchema: (databaseName: string) => getTableSchema(connection, databaseName),
getTableStructure: (databaseName: string, tableName: string, structureFetched: (tableName: string, structure: string) => void) =>
getTableStructure(connection, databaseName, tableName, structureFetched),
getTableStructureBatch: (
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
) => getTableStructureBatch(connection, databaseName, tableNameList, structureFetched),
}; };
}; };

View File

@ -1,6 +1,6 @@
import { ConnectionOptions } from "mysql2"; import { ConnectionOptions } from "mysql2";
import mysql, { RowDataPacket } from "mysql2/promise"; import mysql, { RowDataPacket } from "mysql2/promise";
import { Connection, ExecutionResult } from "@/types"; import { Connection, ExecutionResult, Table, Schema } from "@/types";
import { Connector } from ".."; import { Connector } from "..";
const systemDatabases = ["information_schema", "mysql", "performance_schema", "sys"]; const systemDatabases = ["information_schema", "mysql", "performance_schema", "sys"];
@ -69,56 +69,30 @@ const getDatabases = async (connection: Connection): Promise<string[]> => {
return databaseList; return databaseList;
}; };
const getTables = async (connection: Connection, databaseName: string): Promise<string[]> => { const getTableSchema = async (connection: Connection, databaseName: string): Promise<Schema[]> => {
const conn = await getMySQLConnection(connection); const conn = await getMySQLConnection(connection);
// get All tableList from database
const [rows] = await conn.query<RowDataPacket[]>( const [rows] = await conn.query<RowDataPacket[]>(
`SELECT TABLE_NAME as table_name FROM information_schema.tables WHERE TABLE_SCHEMA=? AND TABLE_TYPE='BASE TABLE';`, `SELECT TABLE_NAME as table_name FROM information_schema.tables WHERE TABLE_SCHEMA=? AND TABLE_TYPE='BASE TABLE';`,
[databaseName] [databaseName]
); );
conn.destroy();
const tableList = []; const tableList = [];
for (const row of rows) { for (const row of rows) {
if (row["table_name"]) { if (row["table_name"]) {
tableList.push(row["table_name"]); tableList.push(row["table_name"]);
} }
} }
return tableList; const SchemaList: Schema[] = [{ name: "", tables: [] as Table[] }];
};
const getTableStructure = async ( for (const tableName of tableList) {
connection: Connection,
databaseName: string,
tableName: string,
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
const conn = await getMySQLConnection(connection);
const [rows] = await conn.query<RowDataPacket[]>(`SHOW CREATE TABLE \`${databaseName}\`.\`${tableName}\`;`);
conn.destroy();
if (rows.length !== 1) {
throw new Error("Unexpected number of rows.");
}
structureFetched(tableName, rows[0]["Create Table"] || "");
};
const getTableStructureBatch = async (
connection: Connection,
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
const conn = await getMySQLConnection(connection);
await Promise.all(
tableNameList.map(async (tableName) => {
const [rows] = await conn.query<RowDataPacket[]>(`SHOW CREATE TABLE \`${databaseName}\`.\`${tableName}\`;`); const [rows] = await conn.query<RowDataPacket[]>(`SHOW CREATE TABLE \`${databaseName}\`.\`${tableName}\`;`);
if (rows.length !== 1) { if (rows.length !== 1) {
throw new Error("Unexpected number of rows."); throw new Error("Unexpected number of rows.");
} }
structureFetched(tableName, rows[0]["Create Table"] || "");
}) SchemaList[0].tables.push({ name: tableName, structure: rows[0]["Create Table"] || "" });
).finally(() => { }
conn.destroy(); return SchemaList;
});
}; };
const newConnector = (connection: Connection): Connector => { const newConnector = (connection: Connection): Connector => {
@ -126,14 +100,7 @@ const newConnector = (connection: Connection): Connector => {
testConnection: () => testConnection(connection), testConnection: () => testConnection(connection),
execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement), execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement),
getDatabases: () => getDatabases(connection), getDatabases: () => getDatabases(connection),
getTables: (databaseName: string) => getTables(connection, databaseName), getTableSchema: (databaseName: string) => getTableSchema(connection, databaseName),
getTableStructure: (databaseName: string, tableName: string, structureFetched: (tableName: string, structure: string) => void) =>
getTableStructure(connection, databaseName, tableName, structureFetched),
getTableStructureBatch: (
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
) => getTableStructureBatch(connection, databaseName, tableNameList, structureFetched),
}; };
}; };

View File

@ -1,5 +1,5 @@
import { Client, ClientConfig } from "pg"; import { Client, ClientConfig } from "pg";
import { Connection, ExecutionResult } from "@/types"; import { Connection, ExecutionResult, Table, Schema } from "@/types";
import { Connector } from ".."; import { Connector } from "..";
const systemSchemas = const systemSchemas =
@ -92,86 +92,51 @@ const getDatabases = async (connection: Connection): Promise<string[]> => {
return databaseList; return databaseList;
}; };
const getTables = async (connection: Connection, databaseName: string): Promise<string[]> => { const getTableSchema = async (connection: Connection, databaseName: string): Promise<Schema[]> => {
connection.database = databaseName; connection.database = databaseName;
const client = await newPostgresClient(connection); const client = await newPostgresClient(connection);
const { rows } = await client.query( const { rows } = await client.query(
`SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN (${systemSchemas}) AND table_name NOT IN (${systemTables}) AND table_type='BASE TABLE' AND table_catalog=$1;`, `SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema NOT IN (${systemSchemas}) AND table_name NOT IN (${systemTables}) AND table_type='BASE TABLE' AND table_catalog=$1;`,
[databaseName] [databaseName]
); );
await client.end();
const tableList = []; const schemaList: Schema[] = [];
for (const row of rows) { for (const row of rows) {
if (row["table_name"]) { if (row["table_name"]) {
if (row["table_schema"] !== "public") { const schema = schemaList.find((schema) => schema.name === row["table_schema"]);
tableList.push(`${row["table_schema"]}.${row["table_name"]}`); if (schema) {
continue; schema.tables.push({ name: row["table_name"] as string, structure: "" } as Table);
} } else {
tableList.push(row["table_name"]); schemaList.push({
} name: row["table_schema"],
} tables: [{ name: row["table_name"], structure: "" } as Table],
return tableList;
};
const getTableStructure = async (
connection: Connection,
databaseName: string,
tableName: string,
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
connection.database = databaseName;
const client = await newPostgresClient(connection);
const { rows } = await client.query(
`SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_schema NOT IN (${systemSchemas}) AND table_name=$1;`,
[tableName]
);
await client.end();
const columnList = [];
// TODO(steven): transform it to standard schema string.
for (const row of rows) {
columnList.push(
`${row["column_name"]} ${row["data_type"].toUpperCase()} ${String(row["is_nullable"]).toUpperCase() === "NO" ? "NOT NULL" : ""}`
);
}
structureFetched(
tableName,
`CREATE TABLE \`${tableName}\` (
${columnList.join(",\n")}
);`
);
};
const getTableStructureBatch = async (
connection: Connection,
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
): Promise<void> => {
connection.database = databaseName;
const client = await newPostgresClient(connection);
await Promise.all(
tableNameList.map(async (tableName) => {
const { rows } = await client.query(
`SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_schema NOT IN (${systemSchemas}) AND table_name=$1;`,
[tableName]
);
const columnList = [];
// TODO(steven): transform it to standard schema string.
for (const row of rows) {
columnList.push(
`${row["column_name"]} ${row["data_type"].toUpperCase()} ${String(row["is_nullable"]).toUpperCase() === "NO" ? "NOT NULL" : ""}`
);
}
structureFetched(
tableName,
`CREATE TABLE \`${tableName}\` (
${columnList.join(",\n")}
);`
);
})
).finally(async () => {
await client.end();
}); });
}
}
}
for (const schema of schemaList) {
for (const table of schema.tables) {
const { rows: result } = await client.query(
`SELECT column_name, data_type, is_nullable FROM information_schema.columns WHERE table_schema NOT IN (${systemSchemas}) AND table_name=$1 AND table_schema=$2;`,
[table.name, schema.name]
);
const columnList = [];
// TODO(steven): transform it to standard schema string.
for (const row of result) {
columnList.push(
`${row["column_name"]} ${row["data_type"].toUpperCase()} ${String(row["is_nullable"]).toUpperCase() === "NO" ? "NOT NULL" : ""}`
);
}
table.structure = `CREATE TABLE \`${table.name}\` (
${columnList.join(",\n")}
);`;
}
}
await client.end();
return schemaList;
return [];
}; };
const newConnector = (connection: Connection): Connector => { const newConnector = (connection: Connection): Connector => {
@ -179,14 +144,7 @@ const newConnector = (connection: Connection): Connector => {
testConnection: () => testConnection(connection), testConnection: () => testConnection(connection),
execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement), execute: (databaseName: string, statement: string) => execute(connection, databaseName, statement),
getDatabases: () => getDatabases(connection), getDatabases: () => getDatabases(connection),
getTables: (databaseName: string) => getTables(connection, databaseName), getTableSchema: (databaseName: string) => getTableSchema(connection, databaseName),
getTableStructure: (databaseName: string, tableName: string, structureFetched: (tableName: string, structure: string) => void) =>
getTableStructure(connection, databaseName, tableName, structureFetched),
getTableStructureBatch: (
databaseName: string,
tableNameList: string[],
structureFetched: (tableName: string, structure: string) => void
) => getTableStructureBatch(connection, databaseName, tableNameList, structureFetched),
}; };
}; };

View File

@ -30,6 +30,7 @@
"edit": "Edit Connection", "edit": "Edit Connection",
"select-database": "Select your database", "select-database": "Select your database",
"select-table": "Select your table", "select-table": "Select your table",
"select-schema": "Select your Schema",
"all-tables": "All Tables", "all-tables": "All Tables",
"database-type": "Database type", "database-type": "Database type",
"title": "Title", "title": "Title",

View File

@ -30,6 +30,7 @@
"edit": "编辑连接", "edit": "编辑连接",
"select-database": "选择数据库", "select-database": "选择数据库",
"select-table": "选择数据表", "select-table": "选择数据表",
"select-schema": "选择 Schema",
"all-tables": "全部表", "all-tables": "全部表",
"select-all": "选择全部", "select-all": "选择全部",
"empty-select": "清空选择", "empty-select": "清空选择",

View File

@ -1,6 +1,6 @@
import { NextApiRequest, NextApiResponse } from "next"; import { NextApiRequest, NextApiResponse } from "next";
import { newConnector } from "@/lib/connectors"; import { newConnector } from "@/lib/connectors";
import { Connection, Table } from "@/types"; import { Connection, Schema } from "@/types";
import { changeTiDBConnectionToMySQL } from "@/utils"; import { changeTiDBConnectionToMySQL } from "@/utils";
import { Engine } from "@/types/connection"; import { Engine } from "@/types/connection";
@ -18,20 +18,13 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => {
} }
const db = req.body.db as string; const db = req.body.db as string;
try { try {
const connector = newConnector(connection); const connector = newConnector(connection);
const tableStructures: Table[] = []; const schemaList: Schema[] = await connector.getTableSchema(db);
const rawTableNameList = await connector.getTables(db);
const structureFetched = (tableName: string, structure: string) => {
tableStructures.push({
name: tableName,
structure,
});
};
await connector.getTableStructureBatch(db, rawTableNameList, structureFetched);
res.status(200).json({ res.status(200).json({
data: tableStructures, data: schemaList,
}); });
} catch (error: any) { } catch (error: any) {
res.status(400).json({ res.status(400).json({

View File

@ -2,7 +2,7 @@ import axios from "axios";
import { uniqBy } from "lodash-es"; import { uniqBy } from "lodash-es";
import { create } from "zustand"; import { create } from "zustand";
import { persist } from "zustand/middleware"; import { persist } from "zustand/middleware";
import { Connection, Database, Engine, ResponseObject, Table } from "@/types"; import { Connection, Database, Engine, ResponseObject, Schema } from "@/types";
import { generateUUID } from "@/utils"; import { generateUUID } from "@/utils";
interface ConnectionContext { interface ConnectionContext {
@ -28,7 +28,7 @@ interface ConnectionState {
createConnection: (connection: Connection) => Connection; createConnection: (connection: Connection) => Connection;
setCurrentConnectionCtx: (connectionCtx: ConnectionContext | undefined) => void; setCurrentConnectionCtx: (connectionCtx: ConnectionContext | undefined) => void;
getOrFetchDatabaseList: (connection: Connection, skipCache?: boolean) => Promise<Database[]>; getOrFetchDatabaseList: (connection: Connection, skipCache?: boolean) => Promise<Database[]>;
getOrFetchDatabaseSchema: (database: Database, skipCache?: boolean) => Promise<Table[]>; getOrFetchDatabaseSchema: (database: Database, skipCache?: boolean) => Promise<Schema[]>;
getConnectionById: (connectionId: string) => Connection | undefined; getConnectionById: (connectionId: string) => Connection | undefined;
updateConnection: (connectionId: string, connection: Partial<Connection>) => void; updateConnection: (connectionId: string, connection: Partial<Connection>) => void;
clearConnection: (filter: (connection: Connection) => boolean) => void; clearConnection: (filter: (connection: Connection) => boolean) => void;
@ -67,12 +67,13 @@ export const useConnectionStore = create<ConnectionState>()(
const { data } = await axios.post<string[]>("/api/connection/db", { const { data } = await axios.post<string[]>("/api/connection/db", {
connection, connection,
}); });
const fetchedDatabaseList = data.map( const fetchedDatabaseList = data.map(
(dbName) => (dbName) =>
({ ({
connectionId: connection.id, connectionId: connection.id,
name: dbName, name: dbName,
tableList: [], schemaList: [],
} as Database) } as Database)
); );
const databaseList = uniqBy( const databaseList = uniqBy(
@ -90,8 +91,8 @@ export const useConnectionStore = create<ConnectionState>()(
if (!skipCache) { if (!skipCache) {
const db = state.databaseList.find((db) => db.connectionId === database.connectionId && db.name === database.name); const db = state.databaseList.find((db) => db.connectionId === database.connectionId && db.name === database.name);
if (db !== undefined && Array.isArray(db.tableList) && db.tableList.length !== 0) { if (db !== undefined && Array.isArray(db.schemaList) && db.schemaList.length !== 0) {
return db.tableList; return db.schemaList;
} }
} }
@ -100,19 +101,20 @@ export const useConnectionStore = create<ConnectionState>()(
return []; return [];
} }
const { data: result } = await axios.post<ResponseObject<Table[]>>("/api/connection/db_schema", { const { data: result } = await axios.post<ResponseObject<Schema[]>>("/api/connection/db_schema", {
connection, connection,
db: database.name, db: database.name,
}); });
if (result.message) { if (result.message) {
throw result.message; throw result.message;
} }
const fetchedTableList = result.data; const fetchedTableList: Schema[] = result.data;
set((state) => ({ set((state) => ({
...state, ...state,
databaseList: state.databaseList.map((item) => databaseList: state.databaseList.map((item) =>
item.connectionId === database.connectionId && item.name === database.name ? { ...item, tableList: fetchedTableList } : item item.connectionId === database.connectionId && item.name === database.name ? { ...item, schemaList: fetchedTableList } : item
), ),
})); }));
@ -136,6 +138,16 @@ export const useConnectionStore = create<ConnectionState>()(
}), }),
{ {
name: "connection-storage", name: "connection-storage",
version: 1,
migrate: (persistedState: any, version: number) => {
let state = persistedState as ConnectionState;
if (version === 0) {
console.info(`migrate from ${version} to 1`);
// to clear old data. it will make refetch new schema List
state.databaseList = [];
}
return state;
},
} }
) )
); );

View File

@ -24,6 +24,7 @@ interface ConversationState {
updateConversation: (conversationId: Id, conversation: Partial<Conversation>) => void; updateConversation: (conversationId: Id, conversation: Partial<Conversation>) => void;
clearConversation: (filter: (conversation: Conversation) => boolean) => void; clearConversation: (filter: (conversation: Conversation) => boolean) => void;
updateSelectedTablesName: (selectedTablesName: string[]) => void; updateSelectedTablesName: (selectedTablesName: string[]) => void;
updateSelectedSchemaName: (selectedSchemaName: string) => void;
} }
export const useConversationStore = create<ConversationState>()( export const useConversationStore = create<ConversationState>()(
@ -72,6 +73,14 @@ export const useConversationStore = create<ConversationState>()(
}); });
} }
}, },
updateSelectedSchemaName: (selectedSchemaName: string) => {
const currentConversation = get().getConversationById(get().currentConversationId);
if (currentConversation) {
get().updateConversation(currentConversation.id, {
selectedSchemaName,
});
}
},
}), }),
{ {
name: "conversation-storage", name: "conversation-storage",

View File

@ -5,6 +5,7 @@ export interface Conversation {
connectionId?: Id; connectionId?: Id;
databaseName?: string; databaseName?: string;
selectedTablesName?: string[]; selectedTablesName?: string[];
selectedSchemaName?: string;
assistantId: Id; assistantId: Id;
title: string; title: string;
createdAt: Timestamp; createdAt: Timestamp;

View File

@ -1,9 +1,8 @@
import { Id } from "."; import { Id } from ".";
export interface Database { export interface Database {
connectionId: Id; connectionId: Id;
name: string; name: string;
tableList: Table[]; schemaList: Schema[];
} }
export interface Table { export interface Table {
@ -12,3 +11,7 @@ export interface Table {
// It's mainly used for providing a chat context for the assistant. // It's mainly used for providing a chat context for the assistant.
structure: string; structure: string;
} }
export interface Schema {
name: string;
tables: Table[];
}