diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAIClient.java index 82e18745..96c3103d 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAIClient.java @@ -16,13 +16,21 @@ import lombok.extern.slf4j.Slf4j; @Slf4j public class AzureOpenAIClient { + /** + * AZURE OPENAI KEY + */ public static final String AZURE_CHATGPT_API_KEY = "azure.chatgpt.apiKey"; /** - * OPENAI接口域名 + * AZURE OPENAI ENDPOINT */ public static final String AZURE_CHATGPT_ENDPOINT = "azure.chatgpt.endpoint"; + /** + * AZURE OPENAI DEPLOYMENT ID + */ + public static final String AZURE_CHATGPT_DEPLOYMENT_ID = "azure.chatgpt.deployment.id"; + private static AzureOpenAiStreamClient OPEN_AI_CLIENT; private static String apiKey; @@ -48,6 +56,7 @@ public class AzureOpenAIClient { public static void refresh() { String apikey = ""; String apiEndpoint = ""; + String deployId = ""; ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); Config apiHostConfig = configService.find(AZURE_CHATGPT_ENDPOINT).getData(); if (apiHostConfig != null) { @@ -57,8 +66,12 @@ public class AzureOpenAIClient { if (config != null) { apikey = config.getContent(); } + Config deployConfig = configService.find(AZURE_CHATGPT_DEPLOYMENT_ID).getData(); + if (config != null) { + deployId = deployConfig.getContent(); + } log.info("refresh azure openai apikey:{}", maskApiKey(apikey)); - OPEN_AI_CLIENT = new AzureOpenAiStreamClient(apiKey, apiEndpoint); + OPEN_AI_CLIENT = new AzureOpenAiStreamClient(apiKey, apiEndpoint, deployId); apiKey = apikey; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java index 2327ce71..ce4ff7dd 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/client/AzureOpenAiStreamClient.java @@ -27,6 +27,11 @@ import org.apache.commons.lang3.StringUtils; @Slf4j public class AzureOpenAiStreamClient { + /** + * deployId + */ + private String deployId; + /** * client */ @@ -38,7 +43,8 @@ public class AzureOpenAiStreamClient { * @param apiKey * @param endpoint */ - public AzureOpenAiStreamClient(String apiKey, String endpoint) { + public AzureOpenAiStreamClient(String apiKey, String endpoint, String deployId) { + this.deployId = deployId; this.client = new OpenAIClientBuilder() .credential(new AzureKeyCredential(apiKey)) .endpoint(endpoint) @@ -48,11 +54,10 @@ public class AzureOpenAiStreamClient { /** * 问答接口 stream 形式 * - * @param deployId * @param chatMessages * @param eventSourceListener */ - public void streamCompletions(String deployId, List chatMessages, EventSourceListener eventSourceListener) { + public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { if (CollectionUtils.isEmpty(chatMessages)) { log.error("参数异常:Azure Prompt不能为空"); throw new ParamBusinessException("prompt");