mirror of
https://github.com/sqlchat/sqlchat.git
synced 2025-07-31 11:13:02 +08:00
feat: implement token counter (#3)
This commit is contained in:
@ -3,13 +3,17 @@ import { useEffect, useRef, useState } from "react";
|
||||
import { toast } from "react-hot-toast";
|
||||
import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
|
||||
import { CreatorRole, Message } from "@/types";
|
||||
import { generateUUID } from "@/utils";
|
||||
import { countTextTokens, generateUUID } from "@/utils";
|
||||
import Header from "./Header";
|
||||
import EmptyView from "../EmptyView";
|
||||
import MessageView from "./MessageView";
|
||||
import MessageTextarea from "./MessageTextarea";
|
||||
import MessageLoader from "../MessageLoader";
|
||||
|
||||
// The maximum number of tokens that can be sent to the OpenAI API.
|
||||
// reference: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
|
||||
const MAX_TOKENS = 4000;
|
||||
|
||||
const ChatView = () => {
|
||||
const connectionStore = useConnectionStore();
|
||||
const chatStore = useChatStore();
|
||||
@ -86,24 +90,38 @@ const ChatView = () => {
|
||||
setIsRequesting(true);
|
||||
const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
|
||||
let prompt = "";
|
||||
let tokens = 0;
|
||||
if (connectionStore.currentConnectionCtx?.database) {
|
||||
const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
|
||||
const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
|
||||
prompt = promptGenerator(tables.map((table) => table.structure).join("/n"));
|
||||
let schema = "";
|
||||
for (const table of tables) {
|
||||
if (tokens < MAX_TOKENS / 2) {
|
||||
tokens += countTextTokens(schema + table.structure);
|
||||
schema += table.structure;
|
||||
}
|
||||
}
|
||||
prompt = promptGenerator(schema);
|
||||
}
|
||||
let formatedMessageList = [];
|
||||
for (let i = messageList.length - 1; i >= 0; i--) {
|
||||
const message = messageList[i];
|
||||
if (tokens < MAX_TOKENS) {
|
||||
tokens += countTextTokens(message.content);
|
||||
formatedMessageList.unshift({
|
||||
role: message.creatorRole,
|
||||
content: message.content,
|
||||
});
|
||||
}
|
||||
}
|
||||
formatedMessageList.unshift({
|
||||
role: CreatorRole.System,
|
||||
content: prompt,
|
||||
});
|
||||
const rawRes = await fetch("/api/chat", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
messages: [
|
||||
{
|
||||
role: CreatorRole.System,
|
||||
content: prompt,
|
||||
},
|
||||
...messageList.map((message) => ({
|
||||
role: message.creatorRole,
|
||||
content: message.content,
|
||||
})),
|
||||
],
|
||||
messages: formatedMessageList,
|
||||
}),
|
||||
});
|
||||
setIsRequesting(false);
|
||||
|
@ -35,6 +35,7 @@
|
||||
"zustand": "^4.3.6"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@nem035/gpt-3-encoder": "^1.1.7",
|
||||
"@tailwindcss/typography": "^0.5.9",
|
||||
"@types/lodash-es": "^4.17.7",
|
||||
"@types/node": "^18.11.18",
|
||||
|
@ -17,7 +17,6 @@ const handler = async (req: NextRequest) => {
|
||||
body: JSON.stringify({
|
||||
model: "gpt-3.5-turbo",
|
||||
messages: reqBody.messages,
|
||||
max_tokens: 1000,
|
||||
temperature: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
|
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
@ -5,6 +5,7 @@ specifiers:
|
||||
'@emotion/styled': ^11.10.6
|
||||
'@mui/material': ^5.11.14
|
||||
'@mui/styled-engine-sc': ^5.11.11
|
||||
'@nem035/gpt-3-encoder': ^1.1.7
|
||||
'@tailwindcss/typography': ^0.5.9
|
||||
'@types/lodash-es': ^4.17.7
|
||||
'@types/node': ^18.11.18
|
||||
@ -74,6 +75,7 @@ dependencies:
|
||||
zustand: 4.3.6_react@18.2.0
|
||||
|
||||
devDependencies:
|
||||
'@nem035/gpt-3-encoder': 1.1.7
|
||||
'@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
|
||||
'@types/lodash-es': 4.17.7
|
||||
'@types/node': 18.15.3
|
||||
@ -580,6 +582,10 @@ packages:
|
||||
react-is: 18.2.0
|
||||
dev: false
|
||||
|
||||
/@nem035/gpt-3-encoder/1.1.7:
|
||||
resolution: {integrity: sha512-dtOenP4ZAmsKXkobTDUCcbkQvPJbuJ6Kp/LHqWDYLK//XNgGs3Re8ymcQzyVhtph8JckdI3K8FR5Q+6mX7HnpQ==}
|
||||
dev: true
|
||||
|
||||
/@next/env/13.2.4:
|
||||
resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
|
||||
dev: false
|
||||
|
@ -1 +1,7 @@
|
||||
import { encode } from "@nem035/gpt-3-encoder";
|
||||
|
||||
export const openAIApiKey = process.env.OPENAI_API_KEY;
|
||||
|
||||
export const countTextTokens = (text: string) => {
|
||||
return encode(text).length;
|
||||
};
|
||||
|
Reference in New Issue
Block a user