fix:1559 修复自定义AI不能使用的问题。

This commit is contained in:
tmlx1990
2024-12-16 20:04:24 +08:00
parent ecf2f99478
commit be4ecc2169
7 changed files with 338 additions and 188 deletions

View File

@ -235,7 +235,7 @@ public class ChatController {
case CHAT2DBAI:
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
case RESTAI :
return chatWithRestAi(queryRequest, sseEmitter);
return chatWithRestAi(queryRequest, sseEmitter, uid);
case FASTCHATAI:
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
case AZUREAI :
@ -261,9 +261,15 @@ public class ChatController {
* @param sseEmitter
* @return
*/
private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
private SseEmitter chatWithRestAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
String prompt = buildPrompt(queryRequest);
List<FastChatMessage> messages = getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
RestAIEventSourceListener restAIEventSourceListener = new RestAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().streamCompletions(messages, restAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}

View File

@ -6,6 +6,7 @@ import ai.chat2db.server.domain.api.service.ConfigService;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
/**
* @author moji
@ -19,6 +20,11 @@ public class RestAIClient {
*/
public static final String AI_SQL_SOURCE = "ai.sql.source";
/**
* Customized AI interface KEY
*/
public static final String REST_AI_API_KEY = "rest.ai.apiKey";
/**
* Customized AI interface address
*/
@ -29,9 +35,16 @@ public class RestAIClient {
*/
public static final String REST_AI_STREAM_OUT = "rest.ai.stream";
private static RestAiStreamClient REST_AI_STREAM_CLIENT;
/**
* Custom AI interface model
*/
public static final String REST_AI_MODEL = "rest.ai.model";
public static RestAiStreamClient getInstance() {
private static RestAIStreamClient REST_AI_STREAM_CLIENT;
public static RestAIStreamClient getInstance() {
if (REST_AI_STREAM_CLIENT != null) {
return REST_AI_STREAM_CLIENT;
} else {
@ -39,7 +52,7 @@ public class RestAIClient {
}
}
private static RestAiStreamClient singleton() {
private static RestAIStreamClient singleton() {
if (REST_AI_STREAM_CLIENT == null) {
synchronized (RestAIClient.class) {
if (REST_AI_STREAM_CLIENT == null) {
@ -55,17 +68,23 @@ public class RestAIClient {
*/
public static void refresh() {
String apiUrl = "";
Boolean stream = Boolean.TRUE;
String apiKey = "";
String model = "";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(REST_AI_URL).getData();
if (apiHostConfig != null) {
apiUrl = apiHostConfig.getContent();
}
Config config = configService.find(REST_AI_STREAM_OUT).getData();
Config config = configService.find(REST_AI_API_KEY).getData();
if (config != null) {
stream = Boolean.valueOf(config.getContent());
apiKey = config.getContent();
}
REST_AI_STREAM_CLIENT = new RestAiStreamClient(apiUrl, stream);
Config deployConfig = configService.find(REST_AI_MODEL).getData();
if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) {
model = deployConfig.getContent();
}
REST_AI_STREAM_CLIENT = RestAIStreamClient.builder().apiKey(apiKey).apiHost(apiUrl).model(model)
.build();
}
}

View File

@ -0,0 +1,180 @@
package ai.chat2db.server.web.api.controller.ai.rest.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
* Custom AI interface client
* @author moji
*/
@Slf4j
public class RestAIStreamClient {
/**
* apikey
*/
@Getter
@NotNull
private String apiKey;
/**
* apiHost
*/
@Getter
@NotNull
private String apiHost;
/**
* model
*/
@Getter
private String model;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;
/**
* Construct instance object
*
* @param builder
*/
public RestAIStreamClient(Builder builder) {
this.apiKey = builder.apiKey;
this.apiHost = builder.apiHost;
this.model = builder.model;
this.okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(new FastChatHeaderAuthorizationInterceptor(this.apiKey))
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
}
/**
* structure
*
* @return
*/
public static RestAIStreamClient.Builder builder() {
return new RestAIStreamClient.Builder();
}
/**
* builder
*/
public static final class Builder {
private String apiKey;
private String apiHost;
private String model;
/**
* OkhttpClient
*/
private OkHttpClient okHttpClient;
public Builder() {
}
public RestAIStreamClient.Builder apiKey(String apiKeyValue) {
this.apiKey = apiKeyValue;
return this;
}
/**
* @param apiHostValue
* @return
*/
public RestAIStreamClient.Builder apiHost(String apiHostValue) {
this.apiHost = apiHostValue;
return this;
}
/**
* @param modelValue
* @return
*/
public RestAIStreamClient.Builder model(String modelValue) {
this.model = modelValue;
return this;
}
public RestAIStreamClient.Builder okHttpClient(OkHttpClient val) {
this.okHttpClient = val;
return this;
}
public RestAIStreamClient build() {
return new RestAIStreamClient(this);
}
}
/**
* Q&A interface stream form
*
* @param chatMessages
* @param eventSourceListener
*/
public void streamCompletions(List<FastChatMessage> chatMessages, EventSourceListener eventSourceListener) {
if (CollectionUtils.isEmpty(chatMessages)) {
log.error("param errorRest AI Prompt cannot be empty");
throw new ParamBusinessException("prompt");
}
if (Objects.isNull(eventSourceListener)) {
log.error("param errorRestAIEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
log.info("Rest AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {
FastChatCompletionsOptions chatCompletionsOptions = new FastChatCompletionsOptions(chatMessages);
chatCompletionsOptions.setStream(true);
chatCompletionsOptions.setModel(this.model);
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
Request request = new Request.Builder()
.url(apiHost)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//Create event
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
log.info("finish invoking rest ai");
} catch (Exception e) {
log.error("rest ai error", e);
eventSourceListener.onFailure(null, e, null);
throw new ParamBusinessException();
}
}
}

View File

@ -1,166 +0,0 @@
package ai.chat2db.server.web.api.controller.ai.rest.client;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.rest.model.RestAiCompletion;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.sse.ConsoleEventSourceListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.lang3.StringUtils;
/**
* Custom AI interface client
* @author moji
*/
@Slf4j
public class RestAiStreamClient {
/**
* rest api url
*/
@Getter
private String apiUrl;
/**
* Whether to stream interface
*/
@Getter
private Boolean stream;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;
/**
* Construct instance object
*
* @param url
*/
public RestAiStreamClient(String url, Boolean stream) {
this.apiUrl = url;
this.stream = stream;
this.okHttpClient = new OkHttpClient
.Builder()
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
}
/**
* Request RESTAI interface
*
* @param prompt
* @param eventSourceListener
*/
public void restCompletions(String prompt,
EventSourceListener eventSourceListener) {
log.info("Start calling the custom AI, prompt:{}", prompt);
RestAiCompletion completion = new RestAiCompletion();
completion.setPrompt(prompt);
if (Objects.isNull(stream) || stream) {
streamCompletions(completion, eventSourceListener);
log.info("End calling streaming output custom AI");
return;
}
nonStreamCompletions(completion, eventSourceListener);
log.info("End calling non-streaming output custom AI");
}
/**
* Q&A interface stream form
*
* @param completion open ai parameter
* @param eventSourceListener sse listener
* @see ConsoleEventSourceListener
*/
public void streamCompletions(RestAiCompletion completion, EventSourceListener eventSourceListener) {
if (Objects.isNull(eventSourceListener)) {
log.error("Parameter exception: EventSourceListener cannot be empty");
throw new ParamBusinessException();
}
if (StringUtils.isBlank(completion.getPrompt())) {
log.error("Parameter exception: Prompt cannot be empty");
throw new ParamBusinessException();
}
try {
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String requestBody = mapper.writeValueAsString(completion);
Request request = new Request.Builder()
.url(this.apiUrl)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//Create event
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
} catch (Exception e) {
log.error("Request parameter parsing exception", e);
throw new ParamBusinessException();
}
}
/**
* Request non-streaming output interface
*
* @param completion
* @param eventSourceListener
*/
public void nonStreamCompletions(RestAiCompletion completion, EventSourceListener eventSourceListener) {
if (StringUtils.isBlank(completion.getPrompt())) {
log.error("Parameter exception: Prompt cannot be empty");
throw new ParamBusinessException();
}
try {
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String requestBody = mapper.writeValueAsString(completion);
Request request = new Request.Builder()
.url(this.apiUrl)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
this.okHttpClient.newCall(request).enqueue(new Callback() {
@Override
public void onFailure(Call call, IOException e) {
eventSourceListener.onFailure(null, e, null);
}
@Override
public void onResponse(Call call, Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
if (responseBody != null) {
String content = responseBody.string();
eventSourceListener.onEvent(null, "[DATA]", null, content);
eventSourceListener.onEvent(null, "[DONE]", null, "[DONE]");
}
} catch (IOException e) {
eventSourceListener.onFailure(null, e, response);
}
}
});
} catch (Exception e) {
log.error("Request parameter parsing exception", e);
throw new ParamBusinessException();
}
}
}

View File

@ -1,7 +1,12 @@
package ai.chat2db.server.web.api.controller.ai.rest.listener;
import java.io.IOException;
import java.util.Objects;
import ai.chat2db.server.web.api.controller.ai.rest.model.RestAIChatCompletions;
import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
@ -27,6 +32,7 @@ public class RestAIEventSourceListener extends EventSourceListener {
this.sseEmitter = sseEmitter;
}
private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
/**
* {@inheritDoc}
*/
@ -54,9 +60,11 @@ public class RestAIEventSourceListener extends EventSourceListener {
}
Message message = new Message();
if (StringUtils.isNotBlank(data)) {
data = data.replaceAll("^\"|\"$", "");
data = data.replaceAll("\\\\n", "\n");
message.setContent(data);
RestAIChatCompletions chatCompletions = mapper.readValue(data, RestAIChatCompletions.class);
String text = chatCompletions.getChoices().get(0).getDelta()==null?
chatCompletions.getChoices().get(0).getText()
:chatCompletions.getChoices().get(0).getDelta().getContent();
message.setContent(text);
sseEmitter.send(SseEmitter.event()
.id(id)
.data(message)
@ -68,10 +76,14 @@ public class RestAIEventSourceListener extends EventSourceListener {
@Override
public void onClosed(EventSource eventSource) {
log.info("REST AI close sse connection...");
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
try {
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
} catch (IOException e) {
throw new RuntimeException(e);
}
sseEmitter.complete();
}

View File

@ -0,0 +1,96 @@
package ai.chat2db.server.web.api.controller.ai.rest.model;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatChoice;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
@Data
public class RestAIChatCompletions {
/*
* A unique identifier associated with this chat completions response.
*/
private String id;
/*
* The first timestamp associated with generation activity for this completions response,
* represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970.
*/
private int created;
/**
* model
*/
private String model;
/**
* object
*/
private String object;
/*
* The collection of completions choices associated with this completions response.
* Generally, `n` choices are generated per provided prompt with a default value of 1.
* Token limits and other settings may limit the number of choices generated.
*/
@JsonProperty(value = "choices")
private List<FastChatChoice> choices;
/**
* Creates an instance of ChatCompletions class.
*
* @param id the id value to set.
* @param created the created value to set.
* @param choices the choices value to set.
*/
@JsonCreator
private RestAIChatCompletions(
@JsonProperty(value = "id") String id,
@JsonProperty(value = "created") int created,
@JsonProperty(value = "model") String model,
@JsonProperty(value = "object") String object,
@JsonProperty(value = "choices") List<FastChatChoice> choices) {
this.id = id;
this.created = created;
this.model = model;
this.object = object;
this.choices = choices;
}
/**
* Get the id property: A unique identifier associated with this chat completions response.
*
* @return the id value.
*/
public String getId() {
return this.id;
}
/**
* Get the created property: The first timestamp associated with generation activity for this completions response,
* represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970.
*
* @return the created value.
*/
public int getCreated() {
return this.created;
}
/**
* Get the choices property: The collection of completions choices associated with this completions response.
* Generally, `n` choices are generated per provided prompt with a default value of 1. Token limits and other
* settings may limit the number of choices generated.
*
* @return the choices value.
*/
public List<FastChatChoice> getChoices() {
return this.choices;
}
}

View File

@ -83,7 +83,7 @@ public class ConfigController {
saveChat2dbAIConfig(request);
break;
case RESTAI:
saveFastChatAIConfig(request);
saveRestAIConfig(request);
break;
case AZUREAI:
saveAzureAIConfig(request);
@ -152,12 +152,15 @@ public class ConfigController {
* @param request
*/
private void saveRestAIConfig(AIConfigCreateRequest request) {
SystemConfigParam param = SystemConfigParam.builder().code(RestAIClient.REST_AI_API_KEY).content(
request.getApiKey()).build();
configService.createOrUpdate(param);
SystemConfigParam restParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_URL).content(
request.getApiHost()).build();
configService.createOrUpdate(restParam);
SystemConfigParam methodParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_STREAM_OUT).content(
request.getStream().toString()).build();
configService.createOrUpdate(methodParam);
SystemConfigParam modelParam = SystemConfigParam.builder().code(RestAIClient.REST_AI_MODEL)
.content(request.getModel()).build();
configService.createOrUpdate(modelParam);
RestAIClient.refresh();
}