feat: add a table select to choose a table as prompt (#74)

This commit is contained in:
CorrectRoadH
2023-05-05 16:46:29 +08:00
committed by GitHub
parent 07a10fb5c6
commit fa16f8fd48
8 changed files with 147 additions and 3 deletions

View File

@ -1,7 +1,14 @@
import { Drawer } from "@mui/material"; import { Drawer } from "@mui/material";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { useConnectionStore, useLayoutStore, ResponsiveWidth } from "@/store"; import {
useConnectionStore,
useConversationStore,
useLayoutStore,
ResponsiveWidth,
} from "@/store";
import { Table } from "@/types";
import useLoading from "@/hooks/useLoading";
import Link from "next/link"; import Link from "next/link";
import Select from "./kit/Select"; import Select from "./kit/Select";
import Tooltip from "./kit/Tooltip"; import Tooltip from "./kit/Tooltip";
@ -16,12 +23,15 @@ const ConnectionSidebar = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const layoutStore = useLayoutStore(); const layoutStore = useLayoutStore();
const connectionStore = useConnectionStore(); const connectionStore = useConnectionStore();
const conversationStore = useConversationStore();
const [isRequestingDatabase, setIsRequestingDatabase] = const [isRequestingDatabase, setIsRequestingDatabase] =
useState<boolean>(false); useState<boolean>(false);
const currentConnectionCtx = connectionStore.currentConnectionCtx; const currentConnectionCtx = connectionStore.currentConnectionCtx;
const databaseList = connectionStore.databaseList.filter( const databaseList = connectionStore.databaseList.filter(
(database) => database.connectionId === currentConnectionCtx?.connection.id (database) => database.connectionId === currentConnectionCtx?.connection.id
); );
const [tableList, updateTableList] = useState<Table[]>([]);
const tableSchemaLoadingState = useLoading();
useEffect(() => { useEffect(() => {
const handleWindowResize = () => { const handleWindowResize = () => {
@ -49,12 +59,40 @@ const ConnectionSidebar = () => {
.getOrFetchDatabaseList(currentConnectionCtx.connection) .getOrFetchDatabaseList(currentConnectionCtx.connection)
.finally(() => { .finally(() => {
setIsRequestingDatabase(false); setIsRequestingDatabase(false);
const database = databaseList.find(
(database) =>
database.name ===
useConnectionStore.getState().currentConnectionCtx?.database?.name
);
if (database) {
tableSchemaLoadingState.setLoading();
connectionStore.getOrFetchDatabaseSchema(database).then(() => {
tableSchemaLoadingState.setFinish();
});
}
}); });
} else { } else {
setIsRequestingDatabase(false); setIsRequestingDatabase(false);
} }
}, [currentConnectionCtx?.connection]); }, [currentConnectionCtx?.connection]);
useEffect(() => {
const tableList =
connectionStore.databaseList.find(
(database) =>
database.connectionId === currentConnectionCtx?.connection.id &&
database.name === currentConnectionCtx?.database?.name
)?.tableList || [];
updateTableList([
{
name: "",
structure: "",
} as Table,
...tableList,
]);
}, [connectionStore, currentConnectionCtx]);
const handleDatabaseNameSelect = async (databaseName: string) => { const handleDatabaseNameSelect = async (databaseName: string) => {
if (!currentConnectionCtx?.connection) { if (!currentConnectionCtx?.connection) {
return; return;
@ -70,8 +108,17 @@ const ConnectionSidebar = () => {
connection: currentConnectionCtx.connection, connection: currentConnectionCtx.connection,
database: database, database: database,
}); });
if (database) {
tableSchemaLoadingState.setLoading();
connectionStore.getOrFetchDatabaseSchema(database).then(() => {
tableSchemaLoadingState.setFinish();
});
}
}; };
const handleTableNameSelect = async (tableName: string) => {
conversationStore.updateTableName(tableName);
};
return ( return (
<> <>
<Drawer <Drawer
@ -126,6 +173,38 @@ const ConnectionSidebar = () => {
/> />
</div> </div>
)} )}
{tableSchemaLoadingState.isLoading ? (
<div className="w-full h-12 flex flex-row justify-start items-center px-4 sticky top-0 border z-1 mb-4 mt-2 rounded-lg text-sm text-gray-600 dark:text-gray-400">
<Icon.BiLoaderAlt className="w-4 h-auto animate-spin mr-1" />{" "}
{t("common.loading")}
</div>
) : (
tableList.length > 0 && (
<div className="w-full sticky top-0 z-1 my-4">
<Select
className="w-full px-4 py-3 !text-base"
value={
conversationStore.getConversationById(
conversationStore.currentConversationId
)?.tableName || ""
}
itemList={tableList.map((table) => {
return {
label:
table.name === ""
? t("connection.all-tables")
: table.name,
value: table.name,
};
})}
onValueChange={(tableName) =>
handleTableNameSelect(tableName)
}
placeholder={t("connection.select-table") || ""}
/>
</div>
)
)}
<ConversationList /> <ConversationList />
</div> </div>
<div className="sticky bottom-0 w-full flex flex-col justify-center bg-gray-100 dark:bg-zinc-700 backdrop-blur bg-opacity-60 pb-4 py-2"> <div className="sticky bottom-0 w-full flex flex-col justify-center bg-gray-100 dark:bg-zinc-700 backdrop-blur bg-opacity-60 pb-4 py-2">

View File

@ -190,11 +190,22 @@ const ConversationView = () => {
const tables = await connectionStore.getOrFetchDatabaseSchema( const tables = await connectionStore.getOrFetchDatabaseSchema(
connectionStore.currentConnectionCtx?.database connectionStore.currentConnectionCtx?.database
); );
for (const table of tables) { // Empty table name(such as "") denote all table. "" and `undefined` both are false in `if`
if (tokens < MAX_TOKENS / 2) { if (currentConversation.tableName) {
const table = tables.find((table) => {
return table.name === currentConversation.tableName;
});
if (table) {
tokens += countTextTokens(schema + table.structure); tokens += countTextTokens(schema + table.structure);
schema += table.structure; schema += table.structure;
} }
} else {
for (const table of tables) {
if (tokens < MAX_TOKENS / 2) {
tokens += countTextTokens(schema + table.structure);
schema += table.structure;
}
}
} }
} catch (error: any) { } catch (error: any) {
toast.error(error.message); toast.error(error.message);

36
src/hooks/useLoading.ts Normal file
View File

@ -0,0 +1,36 @@
import { useState } from "react";
// React state hook that manage a state of loading
const useLoading = (initialState = true) => {
const [state, setState] = useState({ isLoading: initialState, isFailed: false, isSucceed: false });
return {
...state,
setLoading: () => {
setState({
...state,
isLoading: true,
isFailed: false,
isSucceed: false,
});
},
setFinish: () => {
setState({
...state,
isLoading: false,
isFailed: false,
isSucceed: true,
});
},
setError: () => {
setState({
...state,
isLoading: false,
isFailed: true,
isSucceed: false,
});
},
};
};
export default useLoading;

View File

@ -22,6 +22,8 @@
"new": "Create Connection", "new": "Create Connection",
"edit": "Edit Connection", "edit": "Edit Connection",
"select-database": "Select your database", "select-database": "Select your database",
"select-table": "Select your table",
"all-tables":"All Tables",
"database-type": "Database type", "database-type": "Database type",
"title": "Title", "title": "Title",
"host": "Host", "host": "Host",

View File

@ -22,6 +22,8 @@
"new": "建立連線", "new": "建立連線",
"edit": "編輯連線", "edit": "編輯連線",
"select-database": "選擇您的資料庫", "select-database": "選擇您的資料庫",
"select-table": "選擇您的資料",
"all-tables":"全部表",
"database-type": "資料庫類型", "database-type": "資料庫類型",
"title": "標題", "title": "標題",
"host": "主機", "host": "主機",

View File

@ -22,6 +22,8 @@
"new": "创建连接", "new": "创建连接",
"edit": "编辑连接", "edit": "编辑连接",
"select-database": "选择数据库", "select-database": "选择数据库",
"select-table": "选择数据表",
"all-tables":"全部表",
"database-type": "数据库类型", "database-type": "数据库类型",
"title": "标题", "title": "标题",
"host": "主机", "host": "主机",

View File

@ -31,6 +31,7 @@ interface ConversationState {
conversation: Partial<Conversation> conversation: Partial<Conversation>
) => void; ) => void;
clearConversation: (filter: (conversation: Conversation) => boolean) => void; clearConversation: (filter: (conversation: Conversation) => boolean) => void;
updateTableName: (tableName: string) => void;
} }
export const useConversationStore = create<ConversationState>()( export const useConversationStore = create<ConversationState>()(
@ -77,6 +78,16 @@ export const useConversationStore = create<ConversationState>()(
conversationList: state.conversationList.filter(filter), conversationList: state.conversationList.filter(filter),
})); }));
}, },
updateTableName: (tableName: string) => {
const currentConversation = get().getConversationById(
get().currentConversationId
);
if (currentConversation) {
get().updateConversation(currentConversation.id, {
tableName,
});
}
},
}), }),
{ {
name: "conversation-storage", name: "conversation-storage",

View File

@ -4,6 +4,7 @@ export interface Conversation {
id: string; id: string;
connectionId?: Id; connectionId?: Id;
databaseName?: string; databaseName?: string;
tableName?: string;
assistantId: Id; assistantId: Id;
title: string; title: string;
createdAt: Timestamp; createdAt: Timestamp;