Subject: Count tokens client-side and trim the chat payload to MAX_TOKENS

Reconstructed unified diff — the original paste had its newlines collapsed,
which made the patch unappliable; all hunk line counts below have been
re-verified against the reconstructed content.

Change summary: add a gpt-3 token counter (countTextTokens via
@nem035/gpt-3-encoder), budget the schema prompt to MAX_TOKENS/2 and the
message history to MAX_TOKENS before POSTing to /api/chat, and drop the
server-side max_tokens response cap.

Review fixes folded in (same line counts, hunks still apply):
  * The schema loop did `tokens += countTextTokens(schema + table.structure)`,
    re-counting the whole accumulated schema every iteration (quadratic
    over-count that exhausts the budget far too early); it now counts only the
    newly appended table.structure.
  * `let formatedMessageList` renamed to `const formattedMessageList`
    (typo; never reassigned, only mutated via unshift).

NOTE(review): @nem035/gpt-3-encoder is imported by runtime code
(utils/openai.ts) but registered under devDependencies — presumably works
because the bundler inlines it, but consider moving it to dependencies.

diff --git a/components/ChatView/index.tsx b/components/ChatView/index.tsx
index 08d7776..96c4461 100644
--- a/components/ChatView/index.tsx
+++ b/components/ChatView/index.tsx
@@ -3,13 +3,17 @@ import { useEffect, useRef, useState } from "react";
 import { toast } from "react-hot-toast";
 import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
 import { CreatorRole, Message } from "@/types";
-import { generateUUID } from "@/utils";
+import { countTextTokens, generateUUID } from "@/utils";
 import Header from "./Header";
 import EmptyView from "../EmptyView";
 import MessageView from "./MessageView";
 import MessageTextarea from "./MessageTextarea";
 import MessageLoader from "../MessageLoader";
 
+// The maximum number of tokens that can be sent to the OpenAI API.
+// reference: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
+const MAX_TOKENS = 4000;
+
 const ChatView = () => {
   const connectionStore = useConnectionStore();
   const chatStore = useChatStore();
@@ -86,24 +90,38 @@ const ChatView = () => {
     setIsRequesting(true);
     const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
     let prompt = "";
+    let tokens = 0;
     if (connectionStore.currentConnectionCtx?.database) {
       const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
       const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
-      prompt = promptGenerator(tables.map((table) => table.structure).join("/n"));
+      let schema = "";
+      for (const table of tables) {
+        if (tokens < MAX_TOKENS / 2) {
+          tokens += countTextTokens(table.structure);
+          schema += table.structure;
+        }
+      }
+      prompt = promptGenerator(schema);
     }
+    const formattedMessageList = [];
+    for (let i = messageList.length - 1; i >= 0; i--) {
+      const message = messageList[i];
+      if (tokens < MAX_TOKENS) {
+        tokens += countTextTokens(message.content);
+        formattedMessageList.unshift({
+          role: message.creatorRole,
+          content: message.content,
+        });
+      }
+    }
+    formattedMessageList.unshift({
+      role: CreatorRole.System,
+      content: prompt,
+    });
     const rawRes = await fetch("/api/chat", {
       method: "POST",
       body: JSON.stringify({
-        messages: [
-          {
-            role: CreatorRole.System,
-            content: prompt,
-          },
-          ...messageList.map((message) => ({
-            role: message.creatorRole,
-            content: message.content,
-          })),
-        ],
+        messages: formattedMessageList,
       }),
     });
     setIsRequesting(false);
diff --git a/package.json b/package.json
index 564308e..05acdfa 100644
--- a/package.json
+++ b/package.json
@@ -35,6 +35,7 @@
     "zustand": "^4.3.6"
   },
   "devDependencies": {
+    "@nem035/gpt-3-encoder": "^1.1.7",
     "@tailwindcss/typography": "^0.5.9",
     "@types/lodash-es": "^4.17.7",
     "@types/node": "^18.11.18",
diff --git a/pages/api/chat.ts b/pages/api/chat.ts
index 01232cf..fdfaa23 100644
--- a/pages/api/chat.ts
+++ b/pages/api/chat.ts
@@ -17,7 +17,6 @@ const handler = async (req: NextRequest) => {
     body: JSON.stringify({
       model: "gpt-3.5-turbo",
       messages: reqBody.messages,
-      max_tokens: 1000,
       temperature: 0,
       frequency_penalty: 0.0,
       presence_penalty: 0.0,
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d021451..e620aed 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -5,6 +5,7 @@ specifiers:
   '@emotion/styled': ^11.10.6
   '@mui/material': ^5.11.14
   '@mui/styled-engine-sc': ^5.11.11
+  '@nem035/gpt-3-encoder': ^1.1.7
   '@tailwindcss/typography': ^0.5.9
   '@types/lodash-es': ^4.17.7
   '@types/node': ^18.11.18
@@ -74,6 +75,7 @@ dependencies:
   zustand: 4.3.6_react@18.2.0
 
 devDependencies:
+  '@nem035/gpt-3-encoder': 1.1.7
   '@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
   '@types/lodash-es': 4.17.7
   '@types/node': 18.15.3
@@ -580,6 +582,10 @@ packages:
       react-is: 18.2.0
     dev: false
 
+  /@nem035/gpt-3-encoder/1.1.7:
+    resolution: {integrity: sha512-dtOenP4ZAmsKXkobTDUCcbkQvPJbuJ6Kp/LHqWDYLK//XNgGs3Re8ymcQzyVhtph8JckdI3K8FR5Q+6mX7HnpQ==}
+    dev: true
+
   /@next/env/13.2.4:
     resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
     dev: false
diff --git a/utils/openai.ts b/utils/openai.ts
index 3c8dd5b..e7e80dd 100644
--- a/utils/openai.ts
+++ b/utils/openai.ts
@@ -1 +1,7 @@
+import { encode } from "@nem035/gpt-3-encoder";
+
 export const openAIApiKey = process.env.OPENAI_API_KEY;
+
+export const countTextTokens = (text: string) => {
+  return encode(text).length;
+};