From 8b2b10b13dbd5cb3c23ed4bfa501d6913617f89c Mon Sep 17 00:00:00 2001 From: Tianzhou Chen Date: Sat, 20 May 2023 23:26:44 +0800 Subject: [PATCH] feat: allow user to use gpt-4 --- src/components/ConversationView/index.tsx | 5 +- src/components/OpenAIApiConfigView.tsx | 66 ++++++++++++++++++++--- src/components/kit/Radio.tsx | 39 ++++++++++++++ src/locales/en.json | 10 +++- src/locales/es.json | 10 +++- src/locales/zh.json | 10 +++- src/pages/api/chat.ts | 11 ++-- src/pages/api/collect.ts | 4 +- src/store/setting.ts | 1 + src/types/setting.ts | 1 + src/utils/model.ts | 20 ++++++- 11 files changed, 156 insertions(+), 21 deletions(-) create mode 100644 src/components/kit/Radio.tsx diff --git a/src/components/ConversationView/index.tsx b/src/components/ConversationView/index.tsx index e2e5e9d..39e6979 100644 --- a/src/components/ConversationView/index.tsx +++ b/src/components/ConversationView/index.tsx @@ -217,6 +217,9 @@ const ConversationView = () => { if (settingStore.setting.openAIApiConfig?.endpoint) { requestHeaders["x-openai-endpoint"] = settingStore.setting.openAIApiConfig?.endpoint; } + if (settingStore.setting.openAIApiConfig?.model) { + requestHeaders["x-openai-model"] = settingStore.setting.openAIApiConfig?.model; + } const rawRes = await fetch("/api/chat", { method: "POST", body: JSON.stringify({ @@ -311,7 +314,7 @@ const ConversationView = () => { messages: usageMessageList, }, { - headers: session?.user.id ? { Authorization: `Bearer ${session?.user.id}` } : undefined, + headers: requestHeaders, } ) .catch(() => { diff --git a/src/components/OpenAIApiConfigView.tsx b/src/components/OpenAIApiConfigView.tsx index 9c875a9..a70ae9a 100644 --- a/src/components/OpenAIApiConfigView.tsx +++ b/src/components/OpenAIApiConfigView.tsx @@ -3,14 +3,31 @@ import { useTranslation } from "react-i18next"; import { useDebounce } from "react-use"; import { useSettingStore } from "@/store"; import { OpenAIApiConfig } from "@/types"; +import Radio from "./kit/Radio"; import TextField from "./kit/TextField"; - +import Tooltip from "./kit/Tooltip"; const OpenAIApiConfigView = () => { const { t } = useTranslation(); const settingStore = useSettingStore(); const [openAIApiConfig, setOpenAIApiConfig] = useState(settingStore.setting.openAIApiConfig); const [maskKey, setMaskKey] = useState(true); + const models = [ + { + id: "gpt-3.5-turbo", + title: `GPT-3.5 (${t("setting.openai-api-configuration.quota-per-ask", { count: 1 })})`, + disabled: false, + tooltip: "", + }, + // Disable GPT-4 if user doesn't provide key (because SQL Chat own key hasn't been whitelisted yet). + { + id: "gpt-4", + title: `GPT-4 (${t("setting.openai-api-configuration.quota-per-ask", { count: 10 })})`, + disabled: !settingStore.setting.openAIApiConfig.key, + tooltip: t("setting.openai-api-configuration.provide-gpt4-key"), + }, + ]; + const maskedKey = (str: string) => { if (str.length < 7) { return str; @@ -37,21 +54,58 @@ const OpenAIApiConfigView = () => { setMaskKey(false); }; + const modelRadio = (model: any) => { + return ( +
+ handleSetOpenAIApiConfig({ model: value })} + /> + +
+ ); + }; + return ( <>
-
- +
+ +

{t("setting.openai-api-configuration.model-description")}

+
+
+ {models.map((model) => + model.disabled ? ( + + {modelRadio(model)} + + ) : ( + modelRadio(model) + ) + )} +
+
+
+
+ +

{t("setting.openai-api-configuration.key-description")}

handleSetOpenAIApiConfig({ key: value })} />
-
- +
+ +

{t("setting.openai-api-configuration.endpoint-description")}

handleSetOpenAIApiConfig({ endpoint: value })} /> diff --git a/src/components/kit/Radio.tsx b/src/components/kit/Radio.tsx new file mode 100644 index 0000000..73a5f48 --- /dev/null +++ b/src/components/kit/Radio.tsx @@ -0,0 +1,39 @@ +import { HTMLInputTypeAttribute } from "react"; + +interface Props { + value: string; + onChange?: (value: string) => void; + type?: HTMLInputTypeAttribute; + className?: string; + disabled?: boolean; + checked?: boolean; +} + +const getDefaultProps = () => ({ + value: "", + onChange: () => {}, + type: "radio", + className: "", + disabled: false, + checked: false, +}); + +const Radio = (props: Props) => { + const { value, disabled, className, type, checked, onChange } = { + ...getDefaultProps(), + ...props, + }; + + return ( + onChange(e.target.value)} + /> + ); +}; + +export default Radio; diff --git a/src/locales/en.json b/src/locales/en.json index 8a5d334..41dd2b9 100644 --- a/src/locales/en.json +++ b/src/locales/en.json @@ -68,7 +68,7 @@ "upgrade": "Upgrade", "renew": "Renew", "expired": "Expired", - "n-question-per-month": "{{count}} questions / month", + "n-question-per-month": "{{count}} quota / month", "early-bird-checkout": "Early bird discount, 50% off for 1 year" }, "billing": { @@ -89,7 +89,13 @@ "dark": "Dark" }, "openai-api-configuration": { - "self": "OpenAI API configuration" + "self": "OpenAI API configuration", + "model": "Model", + "model-description": "Quota won't be consumed if you provide your own key below.", + "quota-per-ask": "{{count}} quata per ask", + "provide-gpt4-key": "Require your own GPT-4 enabled API key", + "key-description": "Bring your own key to waive quota requirement.", + "endpoint-description": "Optional endpoint pointing to your own compatible server or gateway." }, "data": { "self": "Data", diff --git a/src/locales/es.json b/src/locales/es.json index d15cd9e..0ba472b 100644 --- a/src/locales/es.json +++ b/src/locales/es.json @@ -66,7 +66,7 @@ "upgrade": "Mejora", "renew": "Renovar", "expired": "Expirado", - "n-question-per-month": "{{count}} preguntas / mes", + "n-question-per-month": "{{count}} Cuota / mes", "early-bird-checkout": "Descuento por reserva anticipada, 50 % de descuento durante 1 año" }, "billing": { @@ -87,7 +87,13 @@ "dark": "Oscuro" }, "openai-api-configuration": { - "self": "Configuración del API de OpenAI" + "self": "Configuración del API de OpenAI", + "model": "Modelo", + "model-description": "La cuota no se consumirá si proporciona su propia clave a continuación.", + "quota-per-ask": "{{count}} cuotas por pedido", + "provide-gpt4-key": "Requerir su propia clave API habilitada para GPT-4", + "key-description": "Traiga su propia llave para renunciar al requisito de cuota.", + "endpoint-description": "Punto final opcional que apunta a su propio servidor o puerta de enlace compatible." }, "data": { "self": "Datos", diff --git a/src/locales/zh.json b/src/locales/zh.json index df87cc5..51c182a 100644 --- a/src/locales/zh.json +++ b/src/locales/zh.json @@ -68,7 +68,7 @@ "upgrade": "升级", "renew": "续费", "expired": "已过期", - "n-question-per-month": "{{count}} 次提问 / 月", + "n-question-per-month": "{{count}} 点额度 / 月", "early-bird-checkout": "早鸟优惠,5 折购买 1 年" }, "billing": { @@ -89,7 +89,13 @@ "dark": "深色" }, "openai-api-configuration": { - "self": "OpenAI API 配置" + "self": "OpenAI API 配置", + "model": "模型", + "model-description": "如果您提供自己的 key,额度是不会消耗的。", + "quota-per-ask": "每一个提问消耗 {{count}} 点额度", + "provide-gpt4-key": "需提供您自己的,可以使用 GPT-4 的 key", + "key-description": "一旦您提供了自己的 key,额度就不受限制了。", + "endpoint-description": "可选的 endpoint 指向接口兼容的服务器或者网关。" }, "data": { "self": "数据", diff --git a/src/pages/api/chat.ts b/src/pages/api/chat.ts index 3decfd0..0b5cf43 100644 --- a/src/pages/api/chat.ts +++ b/src/pages/api/chat.ts @@ -1,6 +1,6 @@ import { createParser, ParsedEvent, ReconnectInterval } from "eventsource-parser"; import { NextRequest } from "next/server"; -import { openAIApiEndpoint, openAIApiKey, gpt35, hasFeature } from "@/utils"; +import { openAIApiEndpoint, openAIApiKey, hasFeature, getModel } from "@/utils"; // Needs Edge for streaming response. export const config = { @@ -90,6 +90,7 @@ const handler = async (req: NextRequest) => { } const apiEndpoint = getApiEndpoint(req.headers.get("x-openai-endpoint") || openAIApiEndpoint); + const model = getModel(req.headers.get("x-openai-model") || ""); const remoteRes = await fetch(apiEndpoint, { headers: { "Content-Type": "application/json", @@ -97,11 +98,11 @@ const handler = async (req: NextRequest) => { }, method: "POST", body: JSON.stringify({ - model: gpt35.name, + model: model.name, messages: reqBody.messages, - temperature: gpt35.temperature, - frequency_penalty: gpt35.frequency_penalty, - presence_penalty: gpt35.presence_penalty, + temperature: model.temperature, + frequency_penalty: model.frequency_penalty, + presence_penalty: model.presence_penalty, stream: true, // Send end-user ID to help OpenAI monitor and detect abuse. user: req.ip, diff --git a/src/pages/api/collect.ts b/src/pages/api/collect.ts index f0d5647..1e96795 100644 --- a/src/pages/api/collect.ts +++ b/src/pages/api/collect.ts @@ -1,7 +1,7 @@ import { PrismaClient } from "@prisma/client"; import { NextApiRequest, NextApiResponse } from "next"; import { Conversation, Message } from "@/types"; -import { gpt35 } from "@/utils"; +import { getModel, gpt35 } from "@/utils"; import { getEndUser } from "./auth/end-user"; const prisma = new PrismaClient(); @@ -36,7 +36,7 @@ export default async function handler(req: NextApiRequest, res: NextApiResponse) data: { id: conversation.id, createdAt: new Date(conversation.createdAt), - model: gpt35, + model: getModel((req.headers["x-openai-model"] as string) || ""), ctx: {}, messages: { create: messages.map((message) => ({ diff --git a/src/store/setting.ts b/src/store/setting.ts index 7086e4c..b51abaa 100644 --- a/src/store/setting.ts +++ b/src/store/setting.ts @@ -10,6 +10,7 @@ const getDefaultSetting = (): Setting => { openAIApiConfig: { key: "", endpoint: "", + model: "gpt-3.5-turbo", }, }; }; diff --git a/src/types/setting.ts b/src/types/setting.ts index 636a85f..886dbf2 100644 --- a/src/types/setting.ts +++ b/src/types/setting.ts @@ -5,6 +5,7 @@ export type Theme = "light" | "dark" | "system"; export interface OpenAIApiConfig { key: string; endpoint: string; + model: string; } export interface Setting { diff --git a/src/utils/model.ts b/src/utils/model.ts index af555a4..0764f11 100644 --- a/src/utils/model.ts +++ b/src/utils/model.ts @@ -1,6 +1,24 @@ -export const gpt35 = { +const gpt35 = { name: "gpt-3.5-turbo", temperature: 0, frequency_penalty: 0.0, presence_penalty: 0.0, }; + +const gpt4 = { + name: "gpt-4", + temperature: 0, + frequency_penalty: 0.0, + presence_penalty: 0.0, +}; + +export const models = [gpt35, gpt4]; + +export const getModel = (name: string) => { + for (const model of models) { + if (model.name === name) { + return model; + } + } + return gpt35; +};