feat: add a table select to choose a table as prompt (#74)

This commit is contained in:
CorrectRoadH
2023-05-05 16:46:29 +08:00
committed by GitHub
parent 07a10fb5c6
commit fa16f8fd48
8 changed files with 147 additions and 3 deletions

View File

@ -1,7 +1,14 @@
import { Drawer } from "@mui/material";
import { useEffect, useState } from "react";
import { useTranslation } from "react-i18next";
import { useConnectionStore, useLayoutStore, ResponsiveWidth } from "@/store";
import {
useConnectionStore,
useConversationStore,
useLayoutStore,
ResponsiveWidth,
} from "@/store";
import { Table } from "@/types";
import useLoading from "@/hooks/useLoading";
import Link from "next/link";
import Select from "./kit/Select";
import Tooltip from "./kit/Tooltip";
@ -16,12 +23,15 @@ const ConnectionSidebar = () => {
const { t } = useTranslation();
const layoutStore = useLayoutStore();
const connectionStore = useConnectionStore();
const conversationStore = useConversationStore();
const [isRequestingDatabase, setIsRequestingDatabase] =
useState<boolean>(false);
const currentConnectionCtx = connectionStore.currentConnectionCtx;
const databaseList = connectionStore.databaseList.filter(
(database) => database.connectionId === currentConnectionCtx?.connection.id
);
const [tableList, updateTableList] = useState<Table[]>([]);
const tableSchemaLoadingState = useLoading();
useEffect(() => {
const handleWindowResize = () => {
@ -49,12 +59,40 @@ const ConnectionSidebar = () => {
.getOrFetchDatabaseList(currentConnectionCtx.connection)
.finally(() => {
setIsRequestingDatabase(false);
const database = databaseList.find(
(database) =>
database.name ===
useConnectionStore.getState().currentConnectionCtx?.database?.name
);
if (database) {
tableSchemaLoadingState.setLoading();
connectionStore.getOrFetchDatabaseSchema(database).then(() => {
tableSchemaLoadingState.setFinish();
});
}
});
} else {
setIsRequestingDatabase(false);
}
}, [currentConnectionCtx?.connection]);
useEffect(() => {
const tableList =
connectionStore.databaseList.find(
(database) =>
database.connectionId === currentConnectionCtx?.connection.id &&
database.name === currentConnectionCtx?.database?.name
)?.tableList || [];
updateTableList([
{
name: "",
structure: "",
} as Table,
...tableList,
]);
}, [connectionStore, currentConnectionCtx]);
const handleDatabaseNameSelect = async (databaseName: string) => {
if (!currentConnectionCtx?.connection) {
return;
@ -70,8 +108,17 @@ const ConnectionSidebar = () => {
connection: currentConnectionCtx.connection,
database: database,
});
if (database) {
tableSchemaLoadingState.setLoading();
connectionStore.getOrFetchDatabaseSchema(database).then(() => {
tableSchemaLoadingState.setFinish();
});
}
};
const handleTableNameSelect = async (tableName: string) => {
conversationStore.updateTableName(tableName);
};
return (
<>
<Drawer
@ -126,6 +173,38 @@ const ConnectionSidebar = () => {
/>
</div>
)}
{tableSchemaLoadingState.isLoading ? (
<div className="w-full h-12 flex flex-row justify-start items-center px-4 sticky top-0 border z-1 mb-4 mt-2 rounded-lg text-sm text-gray-600 dark:text-gray-400">
<Icon.BiLoaderAlt className="w-4 h-auto animate-spin mr-1" />{" "}
{t("common.loading")}
</div>
) : (
tableList.length > 0 && (
<div className="w-full sticky top-0 z-1 my-4">
<Select
className="w-full px-4 py-3 !text-base"
value={
conversationStore.getConversationById(
conversationStore.currentConversationId
)?.tableName || ""
}
itemList={tableList.map((table) => {
return {
label:
table.name === ""
? t("connection.all-tables")
: table.name,
value: table.name,
};
})}
onValueChange={(tableName) =>
handleTableNameSelect(tableName)
}
placeholder={t("connection.select-table") || ""}
/>
</div>
)
)}
<ConversationList />
</div>
<div className="sticky bottom-0 w-full flex flex-col justify-center bg-gray-100 dark:bg-zinc-700 backdrop-blur bg-opacity-60 pb-4 py-2">

View File

@ -190,11 +190,22 @@ const ConversationView = () => {
const tables = await connectionStore.getOrFetchDatabaseSchema(
connectionStore.currentConnectionCtx?.database
);
for (const table of tables) {
if (tokens < MAX_TOKENS / 2) {
// Empty table name(such as "") denote all table. "" and `undefined` both are false in `if`
if (currentConversation.tableName) {
const table = tables.find((table) => {
return table.name === currentConversation.tableName;
});
if (table) {
tokens += countTextTokens(schema + table.structure);
schema += table.structure;
}
} else {
for (const table of tables) {
if (tokens < MAX_TOKENS / 2) {
tokens += countTextTokens(schema + table.structure);
schema += table.structure;
}
}
}
} catch (error: any) {
toast.error(error.message);

36
src/hooks/useLoading.ts Normal file
View File

@ -0,0 +1,36 @@
import { useState } from "react";
// React state hook that manage a state of loading
const useLoading = (initialState = true) => {
const [state, setState] = useState({ isLoading: initialState, isFailed: false, isSucceed: false });
return {
...state,
setLoading: () => {
setState({
...state,
isLoading: true,
isFailed: false,
isSucceed: false,
});
},
setFinish: () => {
setState({
...state,
isLoading: false,
isFailed: false,
isSucceed: true,
});
},
setError: () => {
setState({
...state,
isLoading: false,
isFailed: true,
isSucceed: false,
});
},
};
};
export default useLoading;

View File

@ -22,6 +22,8 @@
"new": "Create Connection",
"edit": "Edit Connection",
"select-database": "Select your database",
"select-table": "Select your table",
"all-tables":"All Tables",
"database-type": "Database type",
"title": "Title",
"host": "Host",

View File

@ -22,6 +22,8 @@
"new": "建立連線",
"edit": "編輯連線",
"select-database": "選擇您的資料庫",
"select-table": "選擇您的資料",
"all-tables":"全部表",
"database-type": "資料庫類型",
"title": "標題",
"host": "主機",

View File

@ -22,6 +22,8 @@
"new": "创建连接",
"edit": "编辑连接",
"select-database": "选择数据库",
"select-table": "选择数据表",
"all-tables":"全部表",
"database-type": "数据库类型",
"title": "标题",
"host": "主机",

View File

@ -31,6 +31,7 @@ interface ConversationState {
conversation: Partial<Conversation>
) => void;
clearConversation: (filter: (conversation: Conversation) => boolean) => void;
updateTableName: (tableName: string) => void;
}
export const useConversationStore = create<ConversationState>()(
@ -77,6 +78,16 @@ export const useConversationStore = create<ConversationState>()(
conversationList: state.conversationList.filter(filter),
}));
},
updateTableName: (tableName: string) => {
const currentConversation = get().getConversationById(
get().currentConversationId
);
if (currentConversation) {
get().updateConversation(currentConversation.id, {
tableName,
});
}
},
}),
{
name: "conversation-storage",

View File

@ -4,6 +4,7 @@ export interface Conversation {
id: string;
connectionId?: Id;
databaseName?: string;
tableName?: string;
assistantId: Id;
title: string;
createdAt: Timestamp;