feat: implement token counter (#3)

This commit is contained in:
boojack
2023-03-29 00:11:39 +08:00
committed by GitHub
parent f94a273d78
commit 0a878caf4f
5 changed files with 43 additions and 13 deletions

View File

@ -3,13 +3,17 @@ import { useEffect, useRef, useState } from "react";
import { toast } from "react-hot-toast";
import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
import { CreatorRole, Message } from "@/types";
import { generateUUID } from "@/utils"; import { countTextTokens, generateUUID } from "@/utils";
import Header from "./Header";
import EmptyView from "../EmptyView";
import MessageView from "./MessageView";
import MessageTextarea from "./MessageTextarea";
import MessageLoader from "../MessageLoader";
// The maximum number of tokens that can be sent to the OpenAI API.
// reference: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
const MAX_TOKENS = 4000;
const ChatView = () => {
const connectionStore = useConnectionStore();
const chatStore = useChatStore();
@ -86,24 +90,38 @@ const ChatView = () => {
setIsRequesting(true);
const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
let prompt = "";
let tokens = 0;
if (connectionStore.currentConnectionCtx?.database) {
const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
prompt = promptGenerator(tables.map((table) => table.structure).join("/n")); let schema = "";
for (const table of tables) {
if (tokens < MAX_TOKENS / 2) {
tokens += countTextTokens(schema + table.structure);
schema += table.structure;
}
}
prompt = promptGenerator(schema);
}
let formatedMessageList = [];
for (let i = messageList.length - 1; i >= 0; i--) {
const message = messageList[i];
if (tokens < MAX_TOKENS) {
tokens += countTextTokens(message.content);
formatedMessageList.unshift({
role: message.creatorRole,
content: message.content,
});
}
}
formatedMessageList.unshift({
role: CreatorRole.System,
content: prompt,
});
const rawRes = await fetch("/api/chat", {
method: "POST",
body: JSON.stringify({
messages: [ messages: formatedMessageList,
{
role: CreatorRole.System,
content: prompt,
},
...messageList.map((message) => ({
role: message.creatorRole,
content: message.content,
})),
],
}),
});
setIsRequesting(false);

View File

@ -35,6 +35,7 @@
"zustand": "^4.3.6"
},
"devDependencies": {
"@nem035/gpt-3-encoder": "^1.1.7",
"@tailwindcss/typography": "^0.5.9",
"@types/lodash-es": "^4.17.7",
"@types/node": "^18.11.18",

View File

@ -17,7 +17,6 @@ const handler = async (req: NextRequest) => {
body: JSON.stringify({
model: "gpt-3.5-turbo",
messages: reqBody.messages,
max_tokens: 1000,
temperature: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,

6
pnpm-lock.yaml generated
View File

@ -5,6 +5,7 @@ specifiers:
'@emotion/styled': ^11.10.6
'@mui/material': ^5.11.14
'@mui/styled-engine-sc': ^5.11.11
'@nem035/gpt-3-encoder': ^1.1.7
'@tailwindcss/typography': ^0.5.9
'@types/lodash-es': ^4.17.7
'@types/node': ^18.11.18
@ -74,6 +75,7 @@ dependencies:
zustand: 4.3.6_react@18.2.0
devDependencies:
'@nem035/gpt-3-encoder': 1.1.7
'@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
'@types/lodash-es': 4.17.7
'@types/node': 18.15.3
@ -580,6 +582,10 @@ packages:
react-is: 18.2.0
dev: false
/@nem035/gpt-3-encoder/1.1.7:
resolution: {integrity: sha512-dtOenP4ZAmsKXkobTDUCcbkQvPJbuJ6Kp/LHqWDYLK//XNgGs3Re8ymcQzyVhtph8JckdI3K8FR5Q+6mX7HnpQ==}
dev: true
/@next/env/13.2.4:
resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
dev: false

View File

@ -1 +1,7 @@
import { encode } from "@nem035/gpt-3-encoder";
export const openAIApiKey = process.env.OPENAI_API_KEY;
export const countTextTokens = (text: string) => {
return encode(text).length;
};