diff --git a/components/ConnectionSidebar.tsx b/components/ConnectionSidebar.tsx index 9ff347a..af1d5b8 100644 --- a/components/ConnectionSidebar.tsx +++ b/components/ConnectionSidebar.tsx @@ -7,13 +7,11 @@ import Icon from "./Icon"; import EngineIcon from "./EngineIcon"; import CreateConnectionModal from "./CreateConnectionModal"; import SettingModal from "./SettingModal"; -import ActionConfirmModal, { ActionConfirmModalProps } from "./ActionConfirmModal"; import EditChatTitleModal from "./EditChatTitleModal"; interface State { showCreateConnectionModal: boolean; showSettingModal: boolean; - showDeleteConnectionModal: boolean; showEditChatTitleModal: boolean; } @@ -24,10 +22,9 @@ const ConnectionSidebar = () => { const [state, setState] = useState({ showCreateConnectionModal: false, showSettingModal: false, - showDeleteConnectionModal: false, showEditChatTitleModal: false, }); - const [deleteConnectionModalContext, setDeleteConnectionModalContext] = useState(); + const [editConnectionModalContext, setEditConnectionModalContext] = useState(); const [editChatTitleModalContext, setEditChatTitleModalContext] = useState(); const connectionList = connectionStore.connectionList; const currentConnectionCtx = connectionStore.currentConnectionCtx; @@ -41,6 +38,7 @@ const ConnectionSidebar = () => { ...state, showCreateConnectionModal: show, }); + setEditConnectionModalContext(undefined); }; const toggleSettingModal = (show = true) => { @@ -65,28 +63,12 @@ const ConnectionSidebar = () => { }); }; - const handleDeleteConnection = (connection: Connection) => { + const handleEditConnection = (connection: Connection) => { setState({ ...state, - showDeleteConnectionModal: true, - }); - setDeleteConnectionModalContext({ - title: "Delete Connection", - content: "Are you sure to delete this connection?", - confirmButtonStyle: "btn-error", - close: () => { - setState({ - ...state, - showDeleteConnectionModal: false, - }); - }, - confirm: () => { - connectionStore.clearConnection((item) => item.id !== connection.id); - if (currentConnectionCtx?.connection.id === connection.id) { - connectionStore.setCurrentConnectionCtx(undefined); - } - }, + showCreateConnectionModal: true, }); + setEditConnectionModalContext(connection); }; const handleDatabaseNameSelect = async (databaseName: string) => { @@ -145,19 +127,19 @@ const ConnectionSidebar = () => { {connectionList.map((connection) => ( @@ -239,7 +221,7 @@ const ConnectionSidebar = () => { New Chat -
+
{ {createPortal( - toggleCreateConnectionModal(false)} />, + toggleCreateConnectionModal(false)} + />, document.body )} {createPortal( toggleSettingModal(false)} />, document.body)} - {state.showDeleteConnectionModal && - createPortal( - {})} - confirm={deleteConnectionModalContext?.confirm ?? (() => {})} - />, - document.body - )} - {state.showEditChatTitleModal && editChatTitleModalContext && createPortal( toggleEditChatTitleModal(false)} chat={editChatTitleModalContext} />, document.body)} diff --git a/components/CreateConnectionModal.tsx b/components/CreateConnectionModal.tsx index fea9c58..5859a17 100644 --- a/components/CreateConnectionModal.tsx +++ b/components/CreateConnectionModal.tsx @@ -1,13 +1,16 @@ import { cloneDeep, head } from "lodash-es"; import { useEffect, useState } from "react"; +import { createPortal } from "react-dom"; import { toast } from "react-hot-toast"; import { testConnection, useConnectionStore } from "@/store"; import { Connection, Engine } from "@/types"; import Icon from "./Icon"; import DataStorageBanner from "./DataStorageBanner"; +import ActionConfirmModal from "./ActionConfirmModal"; interface Props { show: boolean; + connection?: Connection; close: () => void; } @@ -22,15 +25,17 @@ const defaultConnection: Connection = { }; const CreateConnectionModal = (props: Props) => { - const { show, close } = props; + const { show, connection: editConnection, close } = props; const connectionStore = useConnectionStore(); const [connection, setConnection] = useState(defaultConnection); + const [showDeleteConnectionModal, setShowDeleteConnectionModal] = useState(false); const [isRequesting, setIsRequesting] = useState(false); const showDatabaseField = connection.engineType === Engine.PostgreSQL; + const isEditing = editConnection !== undefined; useEffect(() => { if (show) { - setConnection(defaultConnection); + setConnection(isEditing ? editConnection : defaultConnection); } }, [show]); @@ -47,22 +52,32 @@ const CreateConnectionModal = (props: Props) => { } setIsRequesting(true); - const connectionCreate = cloneDeep(connection); + const tempConnection = cloneDeep(connection); if (!showDatabaseField) { - connectionCreate.database = undefined; + tempConnection.database = undefined; } + try { - const result = await testConnection(connectionCreate); - if (!result) { - setIsRequesting(false); - toast.error("Failed to test connection"); - return; + await testConnection(tempConnection); + } catch (error) { + setIsRequesting(false); + toast.error("Failed to test connection, please check your connection settings"); + return; + } + + try { + let connection: Connection; + if (isEditing) { + connectionStore.updateConnection(tempConnection.id, tempConnection); + connection = tempConnection; + } else { + connection = connectionStore.createConnection(tempConnection); } - const createdConnection = connectionStore.createConnection(connectionCreate); + // Set the created connection as the current connection. - const databaseList = await connectionStore.getOrFetchDatabaseList(createdConnection); + const databaseList = await connectionStore.getOrFetchDatabaseList(connection, true); connectionStore.setCurrentConnectionCtx({ - connection: createdConnection, + connection: connection, database: head(databaseList), }); } catch (error) { @@ -76,90 +91,121 @@ const CreateConnectionModal = (props: Props) => { close(); }; + const handleDeleteConnection = () => { + connectionStore.clearConnection((item) => item.id !== connection.id); + if (connectionStore.currentConnectionCtx?.connection.id === connection.id) { + connectionStore.setCurrentConnectionCtx(undefined); + } + close(); + }; + return ( -
-
-

Create Connection

- -
- -
- - -
-
- - setPartialConnection({ host: e.target.value })} - /> -
-
- - setPartialConnection({ port: e.target.value })} - /> -
-
- - setPartialConnection({ username: e.target.value })} - /> -
-
- - setPartialConnection({ password: e.target.value })} - /> -
- {showDatabaseField && ( + <> +
+
+

{isEditing ? "Edit Connection" : "Create Connection"}

+ +
+
- + + +
+
+ setPartialConnection({ database: e.target.value })} + value={connection.host} + onChange={(e) => setPartialConnection({ host: e.target.value })} />
- )} -
-
- - +
+ + setPartialConnection({ port: e.target.value })} + /> +
+ {showDatabaseField && ( +
+ + setPartialConnection({ database: e.target.value })} + /> +
+ )} +
+ + setPartialConnection({ username: e.target.value })} + /> +
+
+ + setPartialConnection({ password: e.target.value })} + /> +
+
+
+
+ {isEditing && ( + + )} +
+
+ + +
+
-
+ + {showDeleteConnectionModal && + createPortal( + setShowDeleteConnectionModal(false)} + confirm={() => handleDeleteConnection()} + />, + document.body + )} + ); }; diff --git a/lib/connectors/mysql/index.ts b/lib/connectors/mysql/index.ts index c541186..ec817b7 100644 --- a/lib/connectors/mysql/index.ts +++ b/lib/connectors/mysql/index.ts @@ -11,13 +11,9 @@ const convertToConnectionUrl = (connection: Connection): string => { const testConnection = async (connection: Connection): Promise => { const connectionUrl = convertToConnectionUrl(connection); - try { - const conn = await mysql.createConnection(connectionUrl); - conn.destroy(); - return true; - } catch (error) { - return false; - } + const conn = await mysql.createConnection(connectionUrl); + conn.destroy(); + return true; }; const execute = async (connection: Connection, databaseName: string, statement: string): Promise => { diff --git a/lib/connectors/postgres/index.ts b/lib/connectors/postgres/index.ts index 5310482..6a79c4b 100644 --- a/lib/connectors/postgres/index.ts +++ b/lib/connectors/postgres/index.ts @@ -13,18 +13,10 @@ const newPostgresClient = (connection: Connection) => { }; const testConnection = async (connection: Connection): Promise => { - if (!connection.database) { - return false; - } - - try { - const client = newPostgresClient(connection); - await client.connect(); - await client.end(); - return true; - } catch (error) { - return false; - } + const client = newPostgresClient(connection); + await client.connect(); + await client.end(); + return true; }; const execute = async (connection: Connection, _: string, statement: string): Promise => { diff --git a/pages/api/connection/test.ts b/pages/api/connection/test.ts index f36b66d..e2dd750 100644 --- a/pages/api/connection/test.ts +++ b/pages/api/connection/test.ts @@ -13,10 +13,10 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { const connection = req.body.connection as Connection; try { const connector = newConnector(connection); - const result = await connector.testConnection(); - res.status(200).json(result); + await connector.testConnection(); + res.status(200).json({}); } catch (error) { - res.status(400).json(false); + res.status(400).json({}); } }; diff --git a/store/chat.ts b/store/chat.ts index bce0335..da83239 100644 --- a/store/chat.ts +++ b/store/chat.ts @@ -1,5 +1,4 @@ import dayjs from "dayjs"; -import { uniqBy } from "lodash-es"; import { create } from "zustand"; import { persist } from "zustand/middleware"; import { Chat, Id } from "@/types"; @@ -43,13 +42,9 @@ export const useChatStore = create()( }, setCurrentChat: (chat: Chat | undefined) => set(() => ({ currentChat: chat })), updateChat: (chatId: Id, chat: Partial) => { - const rawChat = get().chatList.find((chat) => chat.id === chatId); - if (!rawChat) { - return; - } - Object.assign(rawChat, chat); set((state) => ({ - chatList: uniqBy([...state.chatList], (chat) => chat.id), + ...state, + chatList: state.chatList.map((item) => (item.id === chatId ? { ...item, ...chat } : item)), })); }, clearChat: (filter: (chat: Chat) => boolean) => { diff --git a/store/connection.ts b/store/connection.ts index 632692a..95df435 100644 --- a/store/connection.ts +++ b/store/connection.ts @@ -2,7 +2,7 @@ import axios from "axios"; import { uniqBy } from "lodash-es"; import { create } from "zustand"; import { persist } from "zustand/middleware"; -import { Connection, Database, Table } from "@/types"; +import { Connection, Database, Engine, Table } from "@/types"; import { generateUUID } from "@/utils"; interface ConnectionContext { @@ -10,21 +10,33 @@ interface ConnectionContext { database?: Database; } +const samplePGConnection: Connection = { + id: "sample-pg", + title: "Sample PostgreSQL", + engineType: Engine.PostgreSQL, + host: "db.aqbxmomjsyqbacfsujwd.supabase.co", + port: "", + username: "readonly_user", + password: "bytebase-sqlchat", + database: "employee", +}; + interface ConnectionState { connectionList: Connection[]; databaseList: Database[]; currentConnectionCtx?: ConnectionContext; createConnection: (connection: Connection) => Connection; setCurrentConnectionCtx: (connectionCtx: ConnectionContext | undefined) => void; - getOrFetchDatabaseList: (connection: Connection) => Promise; + getOrFetchDatabaseList: (connection: Connection, skipCache?: boolean) => Promise; getOrFetchDatabaseSchema: (database: Database) => Promise; + updateConnection: (connectionId: string, connection: Partial) => void; clearConnection: (filter: (connection: Connection) => boolean) => void; } export const useConnectionStore = create()( persist( (set, get) => ({ - connectionList: [], + connectionList: [samplePGConnection], databaseList: [], createConnection: (connection: Connection) => { const createdConnection = { @@ -42,10 +54,13 @@ export const useConnectionStore = create()( ...state, currentConnectionCtx: connectionCtx, })), - getOrFetchDatabaseList: async (connection: Connection) => { + getOrFetchDatabaseList: async (connection: Connection, skipCache = false) => { const state = get(); - if (state.databaseList.some((database) => database.connectionId === connection.id)) { - return state.databaseList.filter((database) => database.connectionId === connection.id); + + if (!skipCache) { + if (state.databaseList.some((database) => database.connectionId === connection.id)) { + return state.databaseList.filter((database) => database.connectionId === connection.id); + } } const { data } = await axios.post("/api/connection/db", { @@ -82,6 +97,12 @@ export const useConnectionStore = create()( }); return data; }, + updateConnection: (connectionId: string, connection: Partial) => { + set((state) => ({ + ...state, + connectionList: state.connectionList.map((item) => (item.id === connectionId ? { ...item, ...connection } : item)), + })); + }, clearConnection: (filter: (connection: Connection) => boolean) => { set((state) => ({ ...state, diff --git a/store/message.ts b/store/message.ts index 76c7c04..fbcaa08 100644 --- a/store/message.ts +++ b/store/message.ts @@ -1,4 +1,3 @@ -import { uniqBy } from "lodash-es"; import { create } from "zustand"; import { persist } from "zustand/middleware"; import { Id, Message } from "@/types"; @@ -18,21 +17,9 @@ export const useMessageStore = create()( getState: () => get(), addMessage: (message: Message) => set((state) => ({ messageList: [...state.messageList, message] })), updateMessage: (messageId: Id, message: Partial) => { - const rawMessage = get().messageList.find((message) => message.id === messageId); - if (!rawMessage) { - return; - } set((state) => ({ - messageList: uniqBy( - [ - ...state.messageList, - { - ...rawMessage, - ...message, - }, - ], - (message) => message.id - ), + ...state, + messageList: state.messageList.map((item) => (item.id === messageId ? { ...item, ...message } : item)), })); }, clearMessage: (filter: (message: Message) => boolean) => set((state) => ({ messageList: state.messageList.filter(filter) })),