mirror of
https://github.com/sqlchat/sqlchat.git
synced 2025-08-01 05:36:11 +08:00
feat: implement token counter (#3)
This commit is contained in:
@ -3,13 +3,17 @@ import { useEffect, useRef, useState } from "react";
|
|||||||
import { toast } from "react-hot-toast";
|
import { toast } from "react-hot-toast";
|
||||||
import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
|
import { getAssistantById, getPromptGeneratorOfAssistant, useChatStore, useMessageStore, useConnectionStore } from "@/store";
|
||||||
import { CreatorRole, Message } from "@/types";
|
import { CreatorRole, Message } from "@/types";
|
||||||
import { generateUUID } from "@/utils";
|
import { countTextTokens, generateUUID } from "@/utils";
|
||||||
import Header from "./Header";
|
import Header from "./Header";
|
||||||
import EmptyView from "../EmptyView";
|
import EmptyView from "../EmptyView";
|
||||||
import MessageView from "./MessageView";
|
import MessageView from "./MessageView";
|
||||||
import MessageTextarea from "./MessageTextarea";
|
import MessageTextarea from "./MessageTextarea";
|
||||||
import MessageLoader from "../MessageLoader";
|
import MessageLoader from "../MessageLoader";
|
||||||
|
|
||||||
|
// The maximum number of tokens that can be sent to the OpenAI API.
|
||||||
|
// reference: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
|
||||||
|
const MAX_TOKENS = 4000;
|
||||||
|
|
||||||
const ChatView = () => {
|
const ChatView = () => {
|
||||||
const connectionStore = useConnectionStore();
|
const connectionStore = useConnectionStore();
|
||||||
const chatStore = useChatStore();
|
const chatStore = useChatStore();
|
||||||
@ -86,24 +90,38 @@ const ChatView = () => {
|
|||||||
setIsRequesting(true);
|
setIsRequesting(true);
|
||||||
const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
|
const messageList = messageStore.getState().messageList.filter((message) => message.chatId === currentChat.id);
|
||||||
let prompt = "";
|
let prompt = "";
|
||||||
|
let tokens = 0;
|
||||||
if (connectionStore.currentConnectionCtx?.database) {
|
if (connectionStore.currentConnectionCtx?.database) {
|
||||||
const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
|
const tables = await connectionStore.getOrFetchDatabaseSchema(connectionStore.currentConnectionCtx?.database);
|
||||||
const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
|
const promptGenerator = getPromptGeneratorOfAssistant(getAssistantById(currentChat.assistantId)!);
|
||||||
prompt = promptGenerator(tables.map((table) => table.structure).join("/n"));
|
let schema = "";
|
||||||
|
for (const table of tables) {
|
||||||
|
if (tokens < MAX_TOKENS / 2) {
|
||||||
|
tokens += countTextTokens(schema + table.structure);
|
||||||
|
schema += table.structure;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
prompt = promptGenerator(schema);
|
||||||
|
}
|
||||||
|
let formatedMessageList = [];
|
||||||
|
for (let i = messageList.length - 1; i >= 0; i--) {
|
||||||
|
const message = messageList[i];
|
||||||
|
if (tokens < MAX_TOKENS) {
|
||||||
|
tokens += countTextTokens(message.content);
|
||||||
|
formatedMessageList.unshift({
|
||||||
|
role: message.creatorRole,
|
||||||
|
content: message.content,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
formatedMessageList.unshift({
|
||||||
|
role: CreatorRole.System,
|
||||||
|
content: prompt,
|
||||||
|
});
|
||||||
const rawRes = await fetch("/api/chat", {
|
const rawRes = await fetch("/api/chat", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
messages: [
|
messages: formatedMessageList,
|
||||||
{
|
|
||||||
role: CreatorRole.System,
|
|
||||||
content: prompt,
|
|
||||||
},
|
|
||||||
...messageList.map((message) => ({
|
|
||||||
role: message.creatorRole,
|
|
||||||
content: message.content,
|
|
||||||
})),
|
|
||||||
],
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
setIsRequesting(false);
|
setIsRequesting(false);
|
||||||
|
@ -35,6 +35,7 @@
|
|||||||
"zustand": "^4.3.6"
|
"zustand": "^4.3.6"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@nem035/gpt-3-encoder": "^1.1.7",
|
||||||
"@tailwindcss/typography": "^0.5.9",
|
"@tailwindcss/typography": "^0.5.9",
|
||||||
"@types/lodash-es": "^4.17.7",
|
"@types/lodash-es": "^4.17.7",
|
||||||
"@types/node": "^18.11.18",
|
"@types/node": "^18.11.18",
|
||||||
|
@ -17,7 +17,6 @@ const handler = async (req: NextRequest) => {
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo",
|
||||||
messages: reqBody.messages,
|
messages: reqBody.messages,
|
||||||
max_tokens: 1000,
|
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
frequency_penalty: 0.0,
|
frequency_penalty: 0.0,
|
||||||
presence_penalty: 0.0,
|
presence_penalty: 0.0,
|
||||||
|
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
@ -5,6 +5,7 @@ specifiers:
|
|||||||
'@emotion/styled': ^11.10.6
|
'@emotion/styled': ^11.10.6
|
||||||
'@mui/material': ^5.11.14
|
'@mui/material': ^5.11.14
|
||||||
'@mui/styled-engine-sc': ^5.11.11
|
'@mui/styled-engine-sc': ^5.11.11
|
||||||
|
'@nem035/gpt-3-encoder': ^1.1.7
|
||||||
'@tailwindcss/typography': ^0.5.9
|
'@tailwindcss/typography': ^0.5.9
|
||||||
'@types/lodash-es': ^4.17.7
|
'@types/lodash-es': ^4.17.7
|
||||||
'@types/node': ^18.11.18
|
'@types/node': ^18.11.18
|
||||||
@ -74,6 +75,7 @@ dependencies:
|
|||||||
zustand: 4.3.6_react@18.2.0
|
zustand: 4.3.6_react@18.2.0
|
||||||
|
|
||||||
devDependencies:
|
devDependencies:
|
||||||
|
'@nem035/gpt-3-encoder': 1.1.7
|
||||||
'@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
|
'@tailwindcss/typography': 0.5.9_tailwindcss@3.2.7
|
||||||
'@types/lodash-es': 4.17.7
|
'@types/lodash-es': 4.17.7
|
||||||
'@types/node': 18.15.3
|
'@types/node': 18.15.3
|
||||||
@ -580,6 +582,10 @@ packages:
|
|||||||
react-is: 18.2.0
|
react-is: 18.2.0
|
||||||
dev: false
|
dev: false
|
||||||
|
|
||||||
|
/@nem035/gpt-3-encoder/1.1.7:
|
||||||
|
resolution: {integrity: sha512-dtOenP4ZAmsKXkobTDUCcbkQvPJbuJ6Kp/LHqWDYLK//XNgGs3Re8ymcQzyVhtph8JckdI3K8FR5Q+6mX7HnpQ==}
|
||||||
|
dev: true
|
||||||
|
|
||||||
/@next/env/13.2.4:
|
/@next/env/13.2.4:
|
||||||
resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
|
resolution: {integrity: sha512-+Mq3TtpkeeKFZanPturjcXt+KHfKYnLlX6jMLyCrmpq6OOs4i1GqBOAauSkii9QeKCMTYzGppar21JU57b/GEA==}
|
||||||
dev: false
|
dev: false
|
||||||
|
@ -1 +1,7 @@
|
|||||||
|
import { encode } from "@nem035/gpt-3-encoder";
|
||||||
|
|
||||||
export const openAIApiKey = process.env.OPENAI_API_KEY;
|
export const openAIApiKey = process.env.OPENAI_API_KEY;
|
||||||
|
|
||||||
|
export const countTextTokens = (text: string) => {
|
||||||
|
return encode(text).length;
|
||||||
|
};
|
||||||
|
Reference in New Issue
Block a user