feat: implement token counter (#3)

boojack
2023-03-29 00:11:39 +08:00
committed by GitHub
parent f94a273d78
commit 0a878caf4f
5 changed files with 43 additions and 13 deletions

View File

@@ -3,13 +3,17 @@ import { useEffect, useRef, useState } from "react";
 import { toast } from "react-hot-toast";
 import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
 import { CreatorRole, Message } from "@/types";
-import { generateUUID } from "@/utils";
+import { countTextTokens, generateUUID } from "@/utils";
 import Header from "./Header";
 import EmptyView from "../EmptyView";
 import MessageView from "./MessageView";
 import MessageTextarea from "./MessageTextarea";
 import MessageLoader from "../MessageLoader";
 
+// The maximum number of tokens that can be sent to the OpenAI API.
+// reference: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
+const MAX_TOKENS = 4000;
+
 const ChatView = () => {
   const connectionStore = useConnectionStore();
   const chatStore = useChatStore();
@@ -86,24 +90,38 @@
     setIsRequesting(true);
     const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
     let prompt = "";
+    let tokens = 0;
     if (connectionStore.currentConnectionCtx?.database) {
       const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
       const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
-      prompt = promptGenerator(tables.map((table) => table.structure).join("/n"));
+      let schema = "";
+      for (const table of tables) {
+        if (tokens < MAX_TOKENS / 2) {
+          tokens += countTextTokens(schema + table.structure);
+          schema += table.structure;
+        }
+      }
+      prompt = promptGenerator(schema);
     }
+    let formatedMessageList = [];
+    for (let i = messageList.length - 1; i >= 0; i--) {
+      const message = messageList[i];
+      if (tokens < MAX_TOKENS) {
+        tokens += countTextTokens(message.content);
+        formatedMessageList.unshift({
+          role: message.creatorRole,
+          content: message.content,
+        });
+      }
+    }
+    formatedMessageList.unshift({
+      role: CreatorRole.System,
+      content: prompt,
+    });
     const rawRes = await fetch("/api/chat", {
       method: "POST",
       body: JSON.stringify({
-        messages: [
-          {
-            role: CreatorRole.System,
-            content: prompt,
-          },
-          ...messageList.map((message) => ({
-            role: message.creatorRole,
-            content: message.content,
-          })),
-        ],
+        messages: formatedMessageList,
       }),
     });
     setIsRequesting(false);
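
Taken together, the ChatView changes implement a greedy token budget: table schemas may consume roughly half of MAX_TOKENS, and chat history, walked newest-first, fills what is left. Below is a minimal standalone sketch of that strategy in TypeScript. The function names and the simplified ChatMessage shape are illustrative, not from the commit; note also that the committed loop re-counts the accumulated schema via countTextTokens(schema + table.structure) on each pass, while this sketch counts only each new table's tokens.

// Sketch only: a greedy token budget over schema text and chat history,
// assuming the same encoder this commit adds.
import { encode } from "@nem035/gpt-3-encoder";

const MAX_TOKENS = 4000;
const countTextTokens = (text: string): number => encode(text).length;

interface ChatMessage {
  role: string;
  content: string;
}

// Append table DDL until about half of the budget is spent, leaving the
// other half for conversation history.
const buildSchema = (tableStructures: string[]): { schema: string; tokens: number } => {
  let tokens = 0;
  let schema = "";
  for (const structure of tableStructures) {
    if (tokens < MAX_TOKENS / 2) {
      tokens += countTextTokens(structure);
      schema += structure;
    }
  }
  return { schema, tokens };
};

// Walk the history newest-to-oldest, unshifting messages while budget
// remains, so the most recent turns survive truncation.
const trimHistory = (history: ChatMessage[], usedTokens: number): ChatMessage[] => {
  let tokens = usedTokens;
  const kept: ChatMessage[] = [];
  for (let i = history.length - 1; i >= 0; i--) {
    if (tokens < MAX_TOKENS) {
      tokens += countTextTokens(history[i].content);
      kept.unshift(history[i]);
    }
  }
  return kept;
};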

View File

@@ -35,6 +35,7 @@
     "zustand": "^4.3.6"
   },
   "devDependencies": {
+    "@nem035/gpt-3-encoder": "^1.1.7",
     "@tailwindcss/typography": "^0.5.9",
     "@types/lodash-es": "^4.17.7",
     "@types/node": "^18.11.18",

View File

@@ -17,7 +17,6 @@ const handler = async (req: NextRequest) => {
       body: JSON.stringify({
         model: "gpt-3.5-turbo",
         messages: reqBody.messages,
-        max_tokens: 1000,
         temperature: 0,
         frequency_penalty: 0.0,
         presence_penalty: 0.0,
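
With the client now trimming its input to the MAX_TOKENS budget, the handler drops the fixed max_tokens: 1000 cap on completions; when max_tokens is omitted, the OpenAI API lets the completion use whatever remains of the model's context window, so answers are no longer cut off at 1,000 tokens.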

pnpm-lock.yaml generated
View File

@@ -5,6 +5,7 @@ specifiers:
   '@emotion/styled': ^11.10.6
   '@mui/material': ^5.11.14
   '@mui/styled-engine-sc': ^5.11.11
+  '@nem035/gpt-3-encoder': ^1.1.7
   '@tailwindcss/typography': ^0.5.9
   '@types/lodash-es': ^4.17.7
   '@types/node': ^18.11.18
@@ -74,6 +75,7 @@
   zustand: 4.3.6_react@18.2.0
 
 devDependencies:
+  '@nem035/gpt-3-encoder': 1.1.7
   '@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
   '@types/lodash-es': 4.17.7
   '@types/node': 18.15.3
@@ -580,6 +582,10 @@
       react-is: 18.2.0
     dev: false
 
+  /@nem035/gpt-3-encoder/1.1.7:
+    resolution: {integrity: sha512-dtOenP4ZAmsKXkobTDUCcbkQvPJbuJ6Kp/LHqWDYLK//XNgGs3Re8ymcQzyVhtph8JckdI3K8FR5Q+6mX7HnpQ==}
+    dev: true
+
   /@next/env/13.2.4:
     resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
     dev: false

View File

@@ -1 +1,7 @@
+import { encode } from "@nem035/gpt-3-encoder";
+
 export const openAIApiKey = process.env.OPENAI_API_KEY;
+
+export const countTextTokens = (text: string) => {
+  return encode(text).length;
+};
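
For illustration, countTextTokens simply measures the length of the BPE encoding of a string. A quick usage sketch (the sample SQL string is arbitrary):

import { countTextTokens } from "@/utils";

// encode() maps the text to an array of BPE token ids; the array length
// is the token count used for budgeting above.
const n = countTextTokens("SELECT * FROM users WHERE id = 1;");
console.log(n); // a small integer, on the order of one token per word or symbol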