add azure openai support

This commit is contained in:
robin
2023-06-27 21:25:10 +08:00
parent c0899e0d9e
commit 8566907050
5 changed files with 306 additions and 1 deletions

View File

@ -34,10 +34,13 @@
<groupId>com.unfbx</groupId>
<artifactId>chatgpt-java</artifactId>
</dependency>
<dependency>
<groupId>com.azure</groupId>
<artifactId>azure-ai-openai</artifactId>
</dependency>
<dependency>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
<version>0.12.0</version>
</dependency>
</dependencies>

View File

@ -0,0 +1,77 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2023 All Rights Reserved.
*/
package ai.chat2db.server.web.api.controller.ai.azure.client;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
/**
* @author jipengfei
* @version : OpenAIClient.java
*/
@Slf4j
public class AzureOpenAIClient {
public static final String AZURE_CHATGPT_API_KEY = "azure.chatgpt.apiKey";
/**
* OPENAI接口域名
*/
public static final String AZURE_CHATGPT_ENDPOINT = "azure.chatgpt.endpoint";
private static AzureOpenAiStreamClient OPEN_AI_CLIENT;
private static String apiKey;
public static AzureOpenAiStreamClient getInstance() {
if (OPEN_AI_CLIENT != null) {
return OPEN_AI_CLIENT;
} else {
return singleton();
}
}
private static AzureOpenAiStreamClient singleton() {
if (OPEN_AI_CLIENT == null) {
synchronized (AzureOpenAIClient.class) {
if (OPEN_AI_CLIENT == null) {
refresh();
}
}
}
return OPEN_AI_CLIENT;
}
public static void refresh() {
String apikey = "";
String apiEndpoint = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(AZURE_CHATGPT_ENDPOINT).getData();
if (apiHostConfig != null) {
apiEndpoint = apiHostConfig.getContent();
}
Config config = configService.find(AZURE_CHATGPT_API_KEY).getData();
if (config != null) {
apikey = config.getContent();
}
log.info("refresh azure openai apikey:{}", maskApiKey(apikey));
OPEN_AI_CLIENT = new AzureOpenAiStreamClient(apiKey, apiEndpoint);
apiKey = apikey;
}
private static String maskApiKey(String input) {
if (input == null) {
return input;
}
StringBuilder maskedString = new StringBuilder(input);
for (int i = input.length() / 2; i < input.length() / 2; i++) {
maskedString.setCharAt(i, '*');
}
return maskedString.toString();
}
}

View File

@ -0,0 +1,102 @@
package ai.chat2db.server.web.api.controller.ai.azure.client;
import java.util.List;
import java.util.Objects;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatMessage;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.IterableStream;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.extern.slf4j.Slf4j;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
* 自定义AI接口client
*
* @author moji
*/
@Slf4j
public class AzureOpenAiStreamClient {
/**
* client
*/
private OpenAIClient client;
/**
* 构造实例对象
*
* @param apiKey
* @param endpoint
*/
public AzureOpenAiStreamClient(String apiKey, String endpoint) {
this.client = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(apiKey))
.endpoint(endpoint)
.buildClient();
}
/**
* 问答接口 stream 形式
*
* @param deployId
* @param chatMessages
* @param eventSourceListener
*/
public void streamCompletions(String deployId, List<ChatMessage> chatMessages, EventSourceListener eventSourceListener) {
if (CollectionUtils.isEmpty(chatMessages)) {
log.error("参数异常Azure Prompt不能为空");
throw new ParamBusinessException("prompt");
}
if (Objects.isNull(eventSourceListener)) {
log.error("参数异常AzureEventSourceListener不能为空");
throw new ParamBusinessException();
}
log.info("开始调用Azure Open AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {
IterableStream<ChatCompletions> chatCompletionsStream = client.getChatCompletionsStream(deployId,
new ChatCompletionsOptions(chatMessages));
chatCompletionsStream.forEach(chatCompletions -> {
String text = "";
System.out.printf("Model ID=%s is created at %d.%n", chatCompletions.getId(),
chatCompletions.getCreated());
for (ChatChoice choice : chatCompletions.getChoices()) {
ChatMessage message = choice.getDelta();
if (message != null) {
log.info("Index: {}, Chat Role: {}.%n", choice.getIndex(), message.getRole());
text = message.getContent();
}
}
Message message = new Message();
if (StringUtils.isNotBlank(text)) {
message.setContent(text);
eventSourceListener.onEvent(null, "[DATA]", null, text);
}
CompletionsUsage usage = chatCompletions.getUsage();
if (usage != null) {
log.info(
"Usage: number of prompt token is {}, number of completion token is {}, and number of total "
+ "tokens in request and response is {}.%n", usage.getPromptTokens(),
usage.getCompletionTokens(), usage.getTotalTokens());
}
});
eventSourceListener.onEvent(null, "[DONE]", null, "[DONE]");
log.info("结束调用非流式输出自定义AI");
} catch (Exception e) {
log.error("请求参数解析异常", e);
eventSourceListener.onFailure(null, e, null);
throw new ParamBusinessException();
}
}
}

View File

@ -0,0 +1,118 @@
package ai.chat2db.server.web.api.controller.ai.listener;
import java.util.Objects;
import com.azure.ai.openai.models.Completions;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
* 描述OpenAIEventSourceListener
*
* @author https:www.unfbx.com
* @date 2023-02-22
*/
@Slf4j
public class AzureOpenAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
public AzureOpenAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("AzureOpenAI建立sse连接...");
}
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("AzureOpenAI返回数据{}", data);
if (data.equals("[DONE]")) {
log.info("AzureOpenAI返回数据结束了");
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
sseEmitter.complete();
return;
}
ObjectMapper mapper = new ObjectMapper();
// 读取Json
Completions completionResponse = mapper.readValue(data, Completions.class);
String text = completionResponse.getChoices().get(0).getText() ;
Message message = new Message();
if (text != null) {
message.setContent(text);
sseEmitter.send(SseEmitter.event()
.id(completionResponse.getId())
.data(message)
.reconnectTime(3000));
}
}
@Override
public void onClosed(EventSource eventSource) {
sseEmitter.complete();
log.info("AzureOpenAI关闭sse连接...");
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
if ("No route to host".equals(message)) {
message = "网络连接超时,请检查网络连通性,参考文章<https://github.com/alibaba/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>";
} else {
message = "Azure AI无法正常访问请参考文章<https://github.com/alibaba/Chat2DB/blob/main/CHAT2DB_AI_SQL.md>进行配置";
}
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(sseMessage));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
return;
}
ResponseBody body = response.body();
String bodyString = null;
if (Objects.nonNull(body)) {
bodyString = body.string();
log.error("Azure OpenAI sse连接异常data{},异常:{}", bodyString, t);
} else {
log.error("Azure OpenAI sse连接异常data{},异常:{}", response, t);
}
eventSource.cancel();
Message message = new Message();
message.setContent("Azure OpenAI出现异常,请在帮助中查看详细日志:" + bodyString);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(message));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
} catch (Exception exception) {
log.error("Azure OpenAI发送数据异常:", exception);
}
}
}