mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-09-25 16:13:24 +08:00
add azure openai support
This commit is contained in:
@ -34,10 +34,13 @@
|
||||
<groupId>com.unfbx</groupId>
|
||||
<artifactId>chatgpt-java</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.azure</groupId>
|
||||
<artifactId>azure-ai-openai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.theokanning.openai-gpt3-java</groupId>
|
||||
<artifactId>service</artifactId>
|
||||
<version>0.12.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
@ -0,0 +1,77 @@
|
||||
/**
|
||||
* alibaba.com Inc.
|
||||
* Copyright (c) 2004-2023 All Rights Reserved.
|
||||
*/
|
||||
package ai.chat2db.server.web.api.controller.ai.azure.client;
|
||||
|
||||
import ai.chat2db.server.domain.api.model.Config;
|
||||
import ai.chat2db.server.domain.api.service.ConfigService;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* @author jipengfei
|
||||
* @version : OpenAIClient.java
|
||||
*/
|
||||
@Slf4j
|
||||
public class AzureOpenAIClient {
|
||||
|
||||
public static final String AZURE_CHATGPT_API_KEY = "azure.chatgpt.apiKey";
|
||||
|
||||
/**
|
||||
* OPENAI接口域名
|
||||
*/
|
||||
public static final String AZURE_CHATGPT_ENDPOINT = "azure.chatgpt.endpoint";
|
||||
|
||||
private static AzureOpenAiStreamClient OPEN_AI_CLIENT;
|
||||
private static String apiKey;
|
||||
|
||||
public static AzureOpenAiStreamClient getInstance() {
|
||||
if (OPEN_AI_CLIENT != null) {
|
||||
return OPEN_AI_CLIENT;
|
||||
} else {
|
||||
return singleton();
|
||||
}
|
||||
}
|
||||
|
||||
private static AzureOpenAiStreamClient singleton() {
|
||||
if (OPEN_AI_CLIENT == null) {
|
||||
synchronized (AzureOpenAIClient.class) {
|
||||
if (OPEN_AI_CLIENT == null) {
|
||||
refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
return OPEN_AI_CLIENT;
|
||||
}
|
||||
|
||||
public static void refresh() {
|
||||
String apikey = "";
|
||||
String apiEndpoint = "";
|
||||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
Config apiHostConfig = configService.find(AZURE_CHATGPT_ENDPOINT).getData();
|
||||
if (apiHostConfig != null) {
|
||||
apiEndpoint = apiHostConfig.getContent();
|
||||
}
|
||||
Config config = configService.find(AZURE_CHATGPT_API_KEY).getData();
|
||||
if (config != null) {
|
||||
apikey = config.getContent();
|
||||
}
|
||||
log.info("refresh azure openai apikey:{}", maskApiKey(apikey));
|
||||
OPEN_AI_CLIENT = new AzureOpenAiStreamClient(apiKey, apiEndpoint);
|
||||
apiKey = apikey;
|
||||
}
|
||||
|
||||
private static String maskApiKey(String input) {
|
||||
if (input == null) {
|
||||
return input;
|
||||
}
|
||||
|
||||
StringBuilder maskedString = new StringBuilder(input);
|
||||
for (int i = input.length() / 2; i < input.length() / 2; i++) {
|
||||
maskedString.setCharAt(i, '*');
|
||||
}
|
||||
return maskedString.toString();
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,102 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.azure.client;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
|
||||
import com.azure.ai.openai.OpenAIClient;
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import com.azure.ai.openai.models.ChatChoice;
|
||||
import com.azure.ai.openai.models.ChatCompletions;
|
||||
import com.azure.ai.openai.models.ChatCompletionsOptions;
|
||||
import com.azure.ai.openai.models.ChatMessage;
|
||||
import com.azure.ai.openai.models.CompletionsUsage;
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
import com.azure.core.util.IterableStream;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* 自定义AI接口client
|
||||
*
|
||||
* @author moji
|
||||
*/
|
||||
@Slf4j
|
||||
public class AzureOpenAiStreamClient {
|
||||
|
||||
/**
|
||||
* client
|
||||
*/
|
||||
private OpenAIClient client;
|
||||
|
||||
/**
|
||||
* 构造实例对象
|
||||
*
|
||||
* @param apiKey
|
||||
* @param endpoint
|
||||
*/
|
||||
public AzureOpenAiStreamClient(String apiKey, String endpoint) {
|
||||
this.client = new OpenAIClientBuilder()
|
||||
.credential(new AzureKeyCredential(apiKey))
|
||||
.endpoint(endpoint)
|
||||
.buildClient();
|
||||
}
|
||||
|
||||
/**
|
||||
* 问答接口 stream 形式
|
||||
*
|
||||
* @param deployId
|
||||
* @param chatMessages
|
||||
* @param eventSourceListener
|
||||
*/
|
||||
public void streamCompletions(String deployId, List<ChatMessage> chatMessages, EventSourceListener eventSourceListener) {
|
||||
if (CollectionUtils.isEmpty(chatMessages)) {
|
||||
log.error("参数异常:Azure Prompt不能为空");
|
||||
throw new ParamBusinessException("prompt");
|
||||
}
|
||||
if (Objects.isNull(eventSourceListener)) {
|
||||
log.error("参数异常:AzureEventSourceListener不能为空");
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
log.info("开始调用Azure Open AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
|
||||
try {
|
||||
IterableStream<ChatCompletions> chatCompletionsStream = client.getChatCompletionsStream(deployId,
|
||||
new ChatCompletionsOptions(chatMessages));
|
||||
|
||||
chatCompletionsStream.forEach(chatCompletions -> {
|
||||
String text = "";
|
||||
System.out.printf("Model ID=%s is created at %d.%n", chatCompletions.getId(),
|
||||
chatCompletions.getCreated());
|
||||
for (ChatChoice choice : chatCompletions.getChoices()) {
|
||||
ChatMessage message = choice.getDelta();
|
||||
if (message != null) {
|
||||
log.info("Index: {}, Chat Role: {}.%n", choice.getIndex(), message.getRole());
|
||||
text = message.getContent();
|
||||
}
|
||||
}
|
||||
Message message = new Message();
|
||||
if (StringUtils.isNotBlank(text)) {
|
||||
message.setContent(text);
|
||||
eventSourceListener.onEvent(null, "[DATA]", null, text);
|
||||
}
|
||||
CompletionsUsage usage = chatCompletions.getUsage();
|
||||
if (usage != null) {
|
||||
log.info(
|
||||
"Usage: number of prompt token is {}, number of completion token is {}, and number of total "
|
||||
+ "tokens in request and response is {}.%n", usage.getPromptTokens(),
|
||||
usage.getCompletionTokens(), usage.getTotalTokens());
|
||||
}
|
||||
});
|
||||
eventSourceListener.onEvent(null, "[DONE]", null, "[DONE]");
|
||||
log.info("结束调用非流式输出自定义AI");
|
||||
} catch (Exception e) {
|
||||
log.error("请求参数解析异常", e);
|
||||
eventSourceListener.onFailure(null, e, null);
|
||||
throw new ParamBusinessException();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.listener;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import com.azure.ai.openai.models.Completions;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.ResponseBody;
|
||||
import okhttp3.sse.EventSource;
|
||||
import okhttp3.sse.EventSourceListener;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
/**
|
||||
* 描述:OpenAIEventSourceListener
|
||||
*
|
||||
* @author https:www.unfbx.com
|
||||
* @date 2023-02-22
|
||||
*/
|
||||
@Slf4j
|
||||
public class AzureOpenAIEventSourceListener extends EventSourceListener {
|
||||
|
||||
private SseEmitter sseEmitter;
|
||||
|
||||
public AzureOpenAIEventSourceListener(SseEmitter sseEmitter) {
|
||||
this.sseEmitter = sseEmitter;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@Override
|
||||
public void onOpen(EventSource eventSource, Response response) {
|
||||
log.info("AzureOpenAI建立sse连接...");
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*/
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public void onEvent(EventSource eventSource, String id, String type, String data) {
|
||||
log.info("AzureOpenAI返回数据:{}", data);
|
||||
if (data.equals("[DONE]")) {
|
||||
log.info("AzureOpenAI返回数据结束了");
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]")
|
||||
.reconnectTime(3000));
|
||||
sseEmitter.complete();
|
||||
return;
|
||||
}
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
// 读取Json
|
||||
Completions completionResponse = mapper.readValue(data, Completions.class);
|
||||
String text = completionResponse.getChoices().get(0).getText() ;
|
||||
Message message = new Message();
|
||||
if (text != null) {
|
||||
message.setContent(text);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id(completionResponse.getId())
|
||||
.data(message)
|
||||
.reconnectTime(3000));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(EventSource eventSource) {
|
||||
sseEmitter.complete();
|
||||
log.info("AzureOpenAI关闭sse连接...");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(EventSource eventSource, Throwable t, Response response) {
|
||||
try {
|
||||
if (Objects.isNull(response)) {
|
||||
String message = t.getMessage();
|
||||
if ("No route to host".equals(message)) {
|
||||
message = "网络连接超时,请检查网络连通性,参考文章<https://github.com/alibaba/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>";
|
||||
} else {
|
||||
message = "Azure AI无法正常访问,请参考文章<https://github.com/alibaba/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>进行配置";
|
||||
}
|
||||
Message sseMessage = new Message();
|
||||
sseMessage.setContent(message);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[ERROR]")
|
||||
.data(sseMessage));
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]"));
|
||||
sseEmitter.complete();
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
String bodyString = null;
|
||||
if (Objects.nonNull(body)) {
|
||||
bodyString = body.string();
|
||||
log.error("Azure OpenAI sse连接异常data:{},异常:{}", bodyString, t);
|
||||
} else {
|
||||
log.error("Azure OpenAI sse连接异常data:{},异常:{}", response, t);
|
||||
}
|
||||
eventSource.cancel();
|
||||
Message message = new Message();
|
||||
message.setContent("Azure OpenAI出现异常,请在帮助中查看详细日志:" + bodyString);
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[ERROR]")
|
||||
.data(message));
|
||||
sseEmitter.send(SseEmitter.event()
|
||||
.id("[DONE]")
|
||||
.data("[DONE]"));
|
||||
sseEmitter.complete();
|
||||
} catch (Exception exception) {
|
||||
log.error("Azure OpenAI发送数据异常:", exception);
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user