add baichuan support

This commit is contained in:
robinji0
2023-10-31 17:25:02 +08:00
parent fbf4014d3b
commit 22e84567b6
11 changed files with 669 additions and 3 deletions

View File

@ -38,6 +38,8 @@ import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.tongyi.client.TongyiChatAIClient;
import ai.chat2db.server.web.api.controller.ai.tongyi.listener.TongyiChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient;
import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener;
import ai.chat2db.server.web.api.http.GatewayClientService;
@ -238,6 +240,10 @@ public class ChatController {
return chatWithClaudeAi(queryRequest, sseEmitter, uid);
case WENXINAI:
return chatWithWenxinAi(queryRequest, sseEmitter, uid);
case BAICHUANAI:
return chatWithBaichuanAi(queryRequest, sseEmitter, uid);
case TONGYIQIANWENAI:
return chatWithTongyiChatAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
@ -373,6 +379,27 @@ public class ChatController {
return sseEmitter;
}
/**
* chat with tongyi chat openai
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
List<FastChatMessage> messages = getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
TongyiChatAIEventSourceListener sourceListener = new TongyiChatAIEventSourceListener(sseEmitter);
TongyiChatAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* chat with baichuan chat openai
*

View File

@ -0,0 +1,80 @@
package ai.chat2db.server.web.api.controller.ai.tongyi.client;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author moji
* @date 23/09/26
*/
@Slf4j
public class TongyiChatAIClient {
/**
* TONGYI OPENAI KEY
*/
public static final String TONGYI_API_KEY = "tongyi.chatgpt.apiKey";
/**
* TONGYI OPENAI HOST
*/
public static final String TONGYI_HOST = "tongyi.host";
/**
* TONGYI OPENAI model
*/
public static final String TONGYI_MODEL= "tongyi.model";
/**
* TONGYI OPENAI embedding model
*/
public static final String TONGYI_EMBEDDING_MODEL = "tongyi.embedding.model";
private static TongyiChatAIStreamClient TONGYI_AI_CLIENT;
public static TongyiChatAIStreamClient getInstance() {
if (TONGYI_AI_CLIENT != null) {
return TONGYI_AI_CLIENT;
} else {
return singleton();
}
}
private static TongyiChatAIStreamClient singleton() {
if (TONGYI_AI_CLIENT == null) {
synchronized (TongyiChatAIClient.class) {
if (TONGYI_AI_CLIENT == null) {
refresh();
}
}
}
return TONGYI_AI_CLIENT;
}
public static void refresh() {
String apiKey = "";
String apiHost = "";
String model = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(TONGYI_HOST).getData();
if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) {
apiHost = apiHostConfig.getContent();
}
Config config = configService.find(TONGYI_API_KEY).getData();
if (config != null && StringUtils.isNotBlank(config.getContent())) {
apiKey = config.getContent();
}
Config deployConfig = configService.find(TONGYI_MODEL).getData();
if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) {
model = deployConfig.getContent();
}
TONGYI_AI_CLIENT = TongyiChatAIStreamClient.builder().apiKey(apiKey).apiHost(apiHost).model(model)
.build();
}
}

View File

@ -0,0 +1,210 @@
package ai.chat2db.server.web.api.controller.ai.tongyi.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.tongyi.model.TongyiChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.tongyi.model.TongyiChatMessage;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
* tongyi Chat Aligned Client
*
* @author moji
*/
@Slf4j
public class TongyiChatAIStreamClient {
/**
* apikey
*/
@Getter
@NotNull
private String apiKey;
/**
* apiHost
*/
@Getter
@NotNull
private String apiHost;
/**
* model
*/
@Getter
private String model;
/**
* embeddingModel
*/
@Getter
private String embeddingModel;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;
/**
* @param builder
*/
private TongyiChatAIStreamClient(Builder builder) {
this.apiKey = builder.apiKey;
this.apiHost = builder.apiHost;
this.model = builder.model;
this.embeddingModel = builder.embeddingModel;
if (Objects.isNull(builder.okHttpClient)) {
builder.okHttpClient = this.okHttpClient();
}
okHttpClient = builder.okHttpClient;
}
/**
* okhttpclient
*/
private OkHttpClient okHttpClient() {
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(new FastChatHeaderAuthorizationInterceptor(this.apiKey))
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
return okHttpClient;
}
/**
* 构造
*
* @return
*/
public static TongyiChatAIStreamClient.Builder builder() {
return new TongyiChatAIStreamClient.Builder();
}
/**
* builder
*/
public static final class Builder {
private String apiKey;
private String apiHost;
private String model;
private String embeddingModel;
/**
* OkhttpClient
*/
private OkHttpClient okHttpClient;
public Builder() {
}
public TongyiChatAIStreamClient.Builder apiKey(String apiKeyValue) {
this.apiKey = apiKeyValue;
return this;
}
/**
* @param apiHostValue
* @return
*/
public TongyiChatAIStreamClient.Builder apiHost(String apiHostValue) {
this.apiHost = apiHostValue;
return this;
}
/**
* @param modelValue
* @return
*/
public TongyiChatAIStreamClient.Builder model(String modelValue) {
this.model = modelValue;
return this;
}
public TongyiChatAIStreamClient.Builder embeddingModel(String embeddingModelValue) {
this.embeddingModel = embeddingModelValue;
return this;
}
public TongyiChatAIStreamClient.Builder okHttpClient(OkHttpClient val) {
this.okHttpClient = val;
return this;
}
public TongyiChatAIStreamClient build() {
return new TongyiChatAIStreamClient(this);
}
}
/**
* 问答接口 stream 形式
*
* @param chatMessages
* @param eventSourceListener
*/
public void streamCompletions(List<FastChatMessage> chatMessages, EventSourceListener eventSourceListener) {
if (CollectionUtils.isEmpty(chatMessages)) {
log.error("param errorTongyi Chat Prompt cannot be empty");
throw new ParamBusinessException("prompt");
}
if (Objects.isNull(eventSourceListener)) {
log.error("param errorTongyiChatEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
log.info("Tongyi Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {
TongyiChatCompletionsOptions chatCompletionsOptions = new TongyiChatCompletionsOptions();
chatCompletionsOptions.setStream(true);
chatCompletionsOptions.setModel(this.model);
Map<String, Object> parameters = new HashMap<>();
parameters.put("result_format", "text");
chatCompletionsOptions.setParameters(parameters);
TongyiChatMessage tongyiChatMessage = new TongyiChatMessage();
tongyiChatMessage.setMessages(chatMessages);
chatCompletionsOptions.setInput(tongyiChatMessage);
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
Request request = new Request.Builder()
.url(apiHost)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
log.info("finish invoking tongyi chat ai");
} catch (Exception e) {
log.error("tongyi chat ai error", e);
eventSourceListener.onFailure(null, e, null);
throw new ParamBusinessException();
}
}
}

View File

@ -0,0 +1,131 @@
package ai.chat2db.server.web.api.controller.ai.tongyi.listener;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatChoice;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsUsage;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.tongyi.model.TongyiChatCompletions;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Objects;
/**
* 描述OpenAIEventSourceListener
*
* @author https:www.unfbx.com
* @date 2023-02-22
*/
@Slf4j
public class TongyiChatAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
public TongyiChatAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("Tongyi Chat Sse connecting...");
}
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("Tongyi Chat AI response data{}", data);
if (data.equals("[DONE]")) {
log.info("Tongyi Chat AI closed");
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
sseEmitter.complete();
return;
}
TongyiChatCompletions chatCompletions = mapper.readValue(data, TongyiChatCompletions.class);
String text = chatCompletions.getOutput().getText();
log.info("id: {}, text: {}", chatCompletions.getId(), text);
Message message = new Message();
message.setContent(text);
sseEmitter.send(SseEmitter.event()
.id(null)
.data(message)
.reconnectTime(3000));
}
@Override
public void onClosed(EventSource eventSource) {
try {
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
log.info("TongyiChatAI close sse connection...");
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(sseMessage));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
return;
}
ResponseBody body = response.body();
String bodyString = Objects.nonNull(t) ? t.getMessage() : "";
if (Objects.nonNull(body)) {
bodyString = body.string();
if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) {
bodyString = t.getMessage();
}
log.error("Tongyi Chat AI sse response{}", bodyString);
} else {
log.error("Tongyi Chat AI sse response{}error{}", response, t);
}
eventSource.cancel();
Message message = new Message();
message.setContent("Tongyi Chat AI error" + bodyString);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(message));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
} catch (Exception exception) {
log.error("Tongyi Chat AI send data error:", exception);
}
}
}

View File

@ -0,0 +1,46 @@
package ai.chat2db.server.web.api.controller.ai.tongyi.model;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class TongyiChatCompletions {
/*
* A unique identifier associated with this chat completions response.
*/
@JsonProperty(value = "request_id")
private String id;
/*
* The collection of completions choices associated with this completions response.
* Generally, `n` choices are generated per provided prompt with a default value of 1.
* Token limits and other settings may limit the number of choices generated.
*/
private TongyiChatOutput output;
/*
* Usage information for tokens processed and generated as part of this completions operation.
*/
private TongyiChatCompletionsUsage usage;
/**
* Creates an instance of ChatCompletions class.
*
* @param id the id value to set.
* @param choices the choices value to set.
* @param usage the usage value to set.
*/
@JsonCreator
private TongyiChatCompletions(
@JsonProperty(value = "id") String id,
@JsonProperty(value = "output") TongyiChatOutput choices,
@JsonProperty(value = "usage") TongyiChatCompletionsUsage usage) {
this.id = id;
this.output = choices;
this.usage = usage;
}
}

View File

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Code generated by Microsoft (R) AutoRest Code Generator.
package ai.chat2db.server.web.api.controller.ai.tongyi.model;
import lombok.Data;
import java.util.Map;
/**
* The configuration information for a chat completions request. Completions support a wide variety of tasks and
* generate text that continues from or "completes" provided prompt data.
*/
@Data
public final class TongyiChatCompletionsOptions {
/*
* The collection of context messages associated with this chat completions request.
* Typical usage begins with a chat message for the System role that provides instructions for
* the behavior of the assistant, followed by alternating messages between the User and
* Assistant roles.
*/
private TongyiChatMessage input;
/*
* A value indicating whether chat completions should be streamed for this request.
*/
private Boolean stream;
/*
* The model name to provide as part of this completions request.
* Not applicable to Fast Chat AI, where deployment information should be included in the Fast Chat
* resource URI that's connected to.
*/
private String model;
/**
* parameters
*/
private Map<String, Object> parameters;
}

View File

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Code generated by Microsoft (R) AutoRest Code Generator.
package ai.chat2db.server.web.api.controller.ai.tongyi.model;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Representation of the token counts processed for a completions request. Counts consider all tokens across prompts,
* choices, choice alternates, best_of generations, and other consumers.
*/
@Data
@NoArgsConstructor
public final class TongyiChatCompletionsUsage {
/*
* The number of tokens generated across all completions emissions.
*/
@JsonProperty(value = "output_tokens")
private int outputTokens;
/*
* The number of tokens in the provided prompts for the completions request.
*/
@JsonProperty(value = "input_tokens")
private int inputTokens;
/**
* Creates an instance of CompletionsUsage class.
*
* @param completionTokens the completionTokens value to set.
* @param promptTokens the promptTokens value to set.
*/
@JsonCreator
private TongyiChatCompletionsUsage(
@JsonProperty(value = "output_tokens") int completionTokens,
@JsonProperty(value = "input_tokens") int promptTokens) {
this.outputTokens = completionTokens;
this.inputTokens = promptTokens;
}
}

View File

@ -0,0 +1,13 @@
package ai.chat2db.server.web.api.controller.ai.tongyi.model;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import lombok.Data;
import java.util.List;
@Data
public class TongyiChatMessage {
private List<FastChatMessage> messages;
}

View File

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Code generated by Microsoft (R) AutoRest Code Generator.
package ai.chat2db.server.web.api.controller.ai.tongyi.model;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
/**
* The representation of a single prompt completion as part of an overall completions request. Generally, `n` choices
* are generated per provided prompt with a default value of 1. Token limits and other settings may limit the number of
* choices generated.
*/
@Data
public final class TongyiChatOutput {
/*
* The generated text for a given completions prompt.
*/
@JsonProperty(value = "text")
private String text;
/*
* Reason for finishing
*/
@JsonProperty(value = "finish_reason")
private String finishReason;
/**
* Creates an instance of Choice class.
*
* @param text the text value to set.
* @param finishReason the finishReason value to set.
*/
@JsonCreator
private TongyiChatOutput(
@JsonProperty(value = "text") String text,
@JsonProperty(value = "finish_reason") String finishReason) {
this.text = text;
this.finishReason = finishReason;
}
}

View File

@ -17,6 +17,7 @@ import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.controller.ai.tongyi.client.TongyiChatAIClient;
import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient;
import ai.chat2db.server.web.api.controller.config.request.AIConfigCreateRequest;
import ai.chat2db.server.web.api.controller.config.request.SystemConfigRequest;
@ -89,6 +90,9 @@ public class ConfigController {
case FASTCHATAI:
saveFastChatAIConfig(request);
break;
case TONGYIQIANWENAI:
saveTongyiChatAIConfig(request);
break;
case WENXINAI:
saveWenxinAIConfig(request);
break;
@ -188,6 +192,24 @@ public class ConfigController {
FastChatAIClient.refresh();
}
/**
* save common tongyi chat ai config
*
* @param request
*/
private void saveTongyiChatAIConfig(AIConfigCreateRequest request) {
SystemConfigParam apikeyParam = SystemConfigParam.builder().code(TongyiChatAIClient.TONGYI_API_KEY)
.content(request.getApiKey()).build();
configService.createOrUpdate(apikeyParam);
SystemConfigParam apiHostParam = SystemConfigParam.builder().code(TongyiChatAIClient.TONGYI_HOST)
.content(request.getApiHost()).build();
configService.createOrUpdate(apiHostParam);
SystemConfigParam modelParam = SystemConfigParam.builder().code(TongyiChatAIClient.TONGYI_MODEL)
.content(request.getModel()).build();
configService.createOrUpdate(modelParam);
TongyiChatAIClient.refresh();
}
/**
* save common wenxin chat ai config
*
@ -221,7 +243,7 @@ public class ConfigController {
SystemConfigParam modelParam = SystemConfigParam.builder().code(BaichuanAIClient.BAICHUAN_MODEL)
.content(request.getModel()).build();
configService.createOrUpdate(modelParam);
FastChatAIClient.refresh();
BaichuanAIClient.refresh();
}
@GetMapping("/system_config/{code}")
@ -310,6 +332,14 @@ public class ConfigController {
config.setApiHost(Objects.nonNull(baichuanApiHost.getData()) ? baichuanApiHost.getData().getContent() : "");
config.setModel(Objects.nonNull(baichuanModel.getData()) ? baichuanModel.getData().getContent() : "");
break;
case TONGYIQIANWENAI:
DataResult<Config> tongyiApiKey = configService.find(TongyiChatAIClient.TONGYI_API_KEY);
DataResult<Config> tongyiApiHost = configService.find(TongyiChatAIClient.TONGYI_HOST);
DataResult<Config> tongyiModel = configService.find(TongyiChatAIClient.TONGYI_MODEL);
config.setApiKey(Objects.nonNull(tongyiApiKey.getData()) ? tongyiApiKey.getData().getContent() : "");
config.setApiHost(Objects.nonNull(tongyiApiHost.getData()) ? tongyiApiHost.getData().getContent() : "");
config.setModel(Objects.nonNull(tongyiModel.getData()) ? tongyiModel.getData().getContent() : "");
break;
default:
break;
}