add wenxinyiyan support

This commit is contained in:
robinji0
2023-10-31 14:13:57 +08:00
parent 44ade9b85a
commit 2e627cbcb5
8 changed files with 599 additions and 0 deletions

View File

@ -37,6 +37,26 @@ public enum AiSqlSourceEnum implements BaseEnum<String> {
*/
CLAUDEAI("CLAUDE AI"),
/**
* WNEXIN AI
*/
WENXINAI("WENXIN AI"),
/**
* BAICHUAN AI
*/
BAICHUANAI("BAICHUAN AI"),
/**
* ZHIPU AI
*/
ZHIPUAI("ZHIPU AI"),
/**
* QIANWEN AI
*/
QIANWENAI("QIANWEN AI"),
/**
* FAST CHAT AI
*/

View File

@ -36,6 +36,8 @@ import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient;
import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.model.EsTableSchema;
import ai.chat2db.server.web.api.http.model.TableSchema;
@ -232,6 +234,8 @@ public class ChatController {
return chatWithAzureAi(queryRequest, sseEmitter, uid);
case CLAUDEAI:
return chatWithClaudeAi(queryRequest, sseEmitter, uid);
case WENXINAI:
return chatWithWenxinAi(queryRequest, sseEmitter, uid);
}
return chatWithOpenAi(queryRequest, sseEmitter, uid);
}
@ -376,6 +380,36 @@ public class ChatController {
return sseEmitter;
}
/**
* chat with fast chat openai
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
List<FastChatMessage> messages = (List<FastChatMessage>)LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
} else {
messages = Lists.newArrayList();
}
FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt);
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
WenxinAIEventSourceListener sourceListener = new WenxinAIEventSourceListener(sseEmitter);
WenxinAIClient.getInstance().streamCompletions(messages, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
/**
* chat with claude ai

View File

@ -0,0 +1,81 @@
package ai.chat2db.server.web.api.controller.ai.wenxin.client;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIStreamClient;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author moji
* @date 23/09/26
*/
@Slf4j
public class WenxinAIClient {
/**
* WENXIN_ACCESS_TOKEN
*/
public static final String WENXIN_ACCESS_TOKEN = "wenxin.access.token";
/**
* WENXIN_HOST
*/
public static final String WENXIN_HOST = "wenxin.host";
/**
* WENXIN_MODEL
*/
public static final String WENXIN_MODEL= "wenxin.model";
/**
* Wenxin embedding model
*/
public static final String WENXIN_EMBEDDING_MODEL = "wenxin.embedding.model";
private static WenxinAIStreamClient WENXIN_AI_CLIENT;
public static WenxinAIStreamClient getInstance() {
if (WENXIN_AI_CLIENT != null) {
return WENXIN_AI_CLIENT;
} else {
return singleton();
}
}
private static WenxinAIStreamClient singleton() {
if (WENXIN_AI_CLIENT == null) {
synchronized (WenxinAIClient.class) {
if (WENXIN_AI_CLIENT == null) {
refresh();
}
}
}
return WENXIN_AI_CLIENT;
}
public static void refresh() {
String apiHost = "";
String accessToken = "";
String model = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(WENXIN_HOST).getData();
if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) {
apiHost = apiHostConfig.getContent();
}
Config config = configService.find(WENXIN_ACCESS_TOKEN).getData();
if (config != null && StringUtils.isNotBlank(config.getContent())) {
accessToken = config.getContent();
}
Config deployConfig = configService.find(WENXIN_MODEL).getData();
if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) {
model = deployConfig.getContent();
}
WENXIN_AI_CLIENT = WenxinAIStreamClient.builder().accessToken(accessToken).apiHost(apiHost).model(model)
.build();
}
}

View File

@ -0,0 +1,202 @@
package ai.chat2db.server.web.api.controller.ai.wenxin.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.wenxin.interceptor.AccessTokenInterceptor;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
* Fast Chat Aligned Client
*
* @author moji
*/
@Slf4j
public class WenxinAIStreamClient {
/**
* apikey
*/
@Getter
@NotNull
private String accessToken;
/**
* apiHost
*/
@Getter
@NotNull
private String apiHost;
/**
* model
*/
@Getter
private String model;
/**
* embeddingModel
*/
@Getter
private String embeddingModel;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;
/**
* @param builder
*/
private WenxinAIStreamClient(Builder builder) {
this.accessToken = builder.accessToken;
this.apiHost = builder.apiHost;
this.model = builder.model;
this.embeddingModel = builder.embeddingModel;
if (Objects.isNull(builder.okHttpClient)) {
builder.okHttpClient = this.okHttpClient();
}
okHttpClient = builder.okHttpClient;
}
/**
* okhttpclient
*/
private OkHttpClient okHttpClient() {
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(new AccessTokenInterceptor(this.accessToken))
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
return okHttpClient;
}
/**
* 构造
*
* @return
*/
public static WenxinAIStreamClient.Builder builder() {
return new WenxinAIStreamClient.Builder();
}
/**
* builder
*/
public static final class Builder {
private String accessToken;
private String apiHost;
private String model;
private String embeddingModel;
/**
* OkhttpClient
*/
private OkHttpClient okHttpClient;
public Builder() {
}
public WenxinAIStreamClient.Builder accessToken(String accessToken) {
this.accessToken = accessToken;
return this;
}
/**
* @param apiHostValue
* @return
*/
public WenxinAIStreamClient.Builder apiHost(String apiHostValue) {
this.apiHost = apiHostValue;
return this;
}
/**
* @param modelValue
* @return
*/
public WenxinAIStreamClient.Builder model(String modelValue) {
this.model = modelValue;
return this;
}
public WenxinAIStreamClient.Builder embeddingModel(String embeddingModelValue) {
this.embeddingModel = embeddingModelValue;
return this;
}
public WenxinAIStreamClient.Builder okHttpClient(OkHttpClient val) {
this.okHttpClient = val;
return this;
}
public WenxinAIStreamClient build() {
return new WenxinAIStreamClient(this);
}
}
/**
* 问答接口 stream 形式
*
* @param chatMessages
* @param eventSourceListener
*/
public void streamCompletions(List<FastChatMessage> chatMessages, EventSourceListener eventSourceListener) {
if (CollectionUtils.isEmpty(chatMessages)) {
log.error("param errorWenxin Prompt cannot be empty");
throw new ParamBusinessException("prompt");
}
if (Objects.isNull(eventSourceListener)) {
log.error("param errorWenxinEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
log.info("Wenxin Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {
FastChatCompletionsOptions chatCompletionsOptions = new FastChatCompletionsOptions(chatMessages);
chatCompletionsOptions.setStream(true);
chatCompletionsOptions.setModel(this.model);
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
Request request = new Request.Builder()
.url(apiHost)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
log.info("finish invoking fast chat ai");
} catch (Exception e) {
log.error("fast chat ai error", e);
eventSourceListener.onFailure(null, e, null);
throw new ParamBusinessException();
}
}
}

View File

@ -0,0 +1,35 @@
package ai.chat2db.server.web.api.controller.ai.wenxin.interceptor;
import okhttp3.HttpUrl;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import java.io.IOException;
public class AccessTokenInterceptor implements Interceptor {
private final String accessToken;
public AccessTokenInterceptor(String accessToken) {
this.accessToken = accessToken;
}
@Override
public Response intercept(Chain chain) throws IOException {
Request originalRequest = chain.request();
HttpUrl originalHttpUrl = originalRequest.url();
// 使用 HttpUrl.Builder 来添加查询参数 access_token
HttpUrl urlWithAccessToken = originalHttpUrl.newBuilder()
.addQueryParameter("access_token", accessToken)
.build();
// 创建新的请求,将新的 URL 应用到它上面
Request newRequest = originalRequest.newBuilder()
.url(urlWithAccessToken)
.build();
return chain.proceed(newRequest);
}
}

View File

@ -0,0 +1,128 @@
package ai.chat2db.server.web.api.controller.ai.wenxin.listener;
import ai.chat2db.server.web.api.controller.ai.wenxin.model.WenxinChatCompletions;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Objects;
/**
* 描述OpenAIEventSourceListener
*
* @author https:www.unfbx.com
* @date 2023-02-22
*/
@Slf4j
public class WenxinAIEventSourceListener extends EventSourceListener {
private SseEmitter sseEmitter;
private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
public WenxinAIEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("Fast Chat Sse connecting...");
}
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("Wenxin AI response data{}", data);
if (data.equals("[DONE]")) {
log.info("Wenxin AI closed");
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
sseEmitter.complete();
return;
}
WenxinChatCompletions chatCompletions = mapper.readValue(data, WenxinChatCompletions.class);
String text = chatCompletions.getResult();
log.info("Model={} is created at {}. message:{}", chatCompletions.getObject(),
chatCompletions.getCreated(), text);
Message message = new Message();
message.setContent(text);
sseEmitter.send(SseEmitter.event()
.id(null)
.data(message)
.reconnectTime(3000));
}
@Override
public void onClosed(EventSource eventSource) {
try {
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
log.info("FastChatAI close sse connection...");
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(sseMessage));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
return;
}
ResponseBody body = response.body();
String bodyString = Objects.nonNull(t) ? t.getMessage() : "";
if (Objects.nonNull(body)) {
bodyString = body.string();
if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) {
bodyString = t.getMessage();
}
log.error("Fast Chat AI sse response{}", bodyString);
} else {
log.error("Fast Chat AI sse response{}error{}", response, t);
}
eventSource.cancel();
Message message = new Message();
message.setContent("Fast Chat AI error" + bodyString);
sseEmitter.send(SseEmitter.event()
.id("[ERROR]")
.data(message));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]"));
sseEmitter.complete();
} catch (Exception exception) {
log.error("Fast Chat AI send data error:", exception);
}
}
}

View File

@ -0,0 +1,74 @@
package ai.chat2db.server.web.api.controller.ai.wenxin.model;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsUsage;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class WenxinChatCompletions {
/*
* A unique identifier associated with this chat completions response.
*/
private String id;
/*
* The first timestamp associated with generation activity for this completions response,
* represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970.
*/
private int created;
/**
* model
*/
@JsonProperty(value = "is_truncated")
private String isTruncated;
@JsonProperty(value = "need_clear_history")
private String needClearHistory;
/**
* object
*/
private String object;
/*
* The collection of completions choices associated with this completions response.
* Generally, `n` choices are generated per provided prompt with a default value of 1.
* Token limits and other settings may limit the number of choices generated.
*/
private String result;
/*
* Usage information for tokens processed and generated as part of this completions operation.
*/
private FastChatCompletionsUsage usage;
/**
* Creates an instance of ChatCompletions class.
*
* @param id the id value to set.
* @param created the created value to set.
* @param result the result value to set.
* @param usage the usage value to set.
*/
@JsonCreator
private WenxinChatCompletions(
@JsonProperty(value = "id") String id,
@JsonProperty(value = "created") int created,
@JsonProperty(value = "is_truncated") String isTruncated,
@JsonProperty(value = "need_clear_history") String needClearHistory,
@JsonProperty(value = "object") String object,
@JsonProperty(value = "result") String result,
@JsonProperty(value = "usage") FastChatCompletionsUsage usage) {
this.id = id;
this.created = created;
this.isTruncated = isTruncated;
this.needClearHistory = needClearHistory;
this.object = object;
this.result = result;
this.usage = usage;
}
}

View File

@ -16,6 +16,7 @@ import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient;
import ai.chat2db.server.web.api.controller.config.request.AIConfigCreateRequest;
import ai.chat2db.server.web.api.controller.config.request.SystemConfigRequest;
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
@ -87,6 +88,9 @@ public class ConfigController {
case FASTCHATAI:
saveFastChatAIConfig(request);
break;
case WENXINAI:
saveWenxinAIConfig(request);
break;
}
return ActionResult.isSuccess();
}
@ -181,6 +185,21 @@ public class ConfigController {
FastChatAIClient.refresh();
}
/**
* save common fast chat ai config
*
* @param request
*/
private void saveWenxinAIConfig(AIConfigCreateRequest request) {
SystemConfigParam apikeyParam = SystemConfigParam.builder().code(WenxinAIClient.WENXIN_ACCESS_TOKEN)
.content(request.getApiKey()).build();
configService.createOrUpdate(apikeyParam);
SystemConfigParam apiHostParam = SystemConfigParam.builder().code(WenxinAIClient.WENXIN_HOST)
.content(request.getApiHost()).build();
configService.createOrUpdate(apiHostParam);
WenxinAIClient.refresh();
}
@GetMapping("/system_config/{code}")
public DataResult<Config> getSystemConfig(@PathVariable("code") String code) {
DataResult<Config> result = configService.find(code);
@ -251,6 +270,12 @@ public class ConfigController {
config.setApiHost(Objects.nonNull(fastChatApiHost.getData()) ? fastChatApiHost.getData().getContent() : "");
config.setModel(Objects.nonNull(fastChatModel.getData()) ? fastChatModel.getData().getContent() : "");
break;
case WENXINAI:
DataResult<Config> wenxinAccessToken = configService.find(WenxinAIClient.WENXIN_ACCESS_TOKEN);
DataResult<Config> wenxinApiHost = configService.find(WenxinAIClient.WENXIN_HOST);
config.setApiKey(Objects.nonNull(wenxinAccessToken.getData()) ? wenxinAccessToken.getData().getContent() : "");
config.setApiHost(Objects.nonNull(wenxinApiHost.getData()) ? wenxinApiHost.getData().getContent() : "");
break;
default:
break;
}