mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-09-20 19:35:46 +08:00
embedding update
This commit is contained in:
@ -55,8 +55,13 @@ import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.unfbx.chatgpt.OpenAiApi;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import com.unfbx.chatgpt.entity.embeddings.Embedding;
|
||||
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
|
||||
import io.reactivex.Single;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@ -101,6 +106,9 @@ public class ChatController {
|
||||
@Resource
|
||||
private GatewayClientService gatewayClientService;
|
||||
|
||||
@Getter
|
||||
private OpenAiApi openAiApi;
|
||||
|
||||
/**
|
||||
* chat的超时时间
|
||||
*/
|
||||
@ -553,7 +561,7 @@ public class ChatController {
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public FastChatEmbeddingResponse distributeAIEmbedding(String input) throws IOException {
|
||||
public FastChatEmbeddingResponse distributeAIEmbedding(String input) {
|
||||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
|
||||
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
|
||||
@ -561,19 +569,16 @@ public class ChatController {
|
||||
aiSqlSource = config.getContent();
|
||||
}
|
||||
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
|
||||
if (Objects.isNull(aiSqlSourceEnum)) {
|
||||
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
|
||||
}
|
||||
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
|
||||
case OPENAI :
|
||||
case AZUREAI :
|
||||
case OPENAI:
|
||||
case CHAT2DBAI:
|
||||
return embeddingWithOpenAi(input);
|
||||
case RESTAI :
|
||||
case FASTCHATAI:
|
||||
case AZUREAI :
|
||||
case CLAUDEAI:
|
||||
return distributeAIEmbedding(input);
|
||||
return embeddingWithFastChatAi(input);
|
||||
}
|
||||
return distributeAIEmbedding(input);
|
||||
return embeddingWithFastChatAi(input);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -583,120 +588,22 @@ public class ChatController {
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) throws IOException {
|
||||
private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) {
|
||||
FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input);
|
||||
return response;
|
||||
}
|
||||
|
||||
///**
|
||||
// * 问答对话模型
|
||||
// *
|
||||
// * @param msg
|
||||
// * @param headers
|
||||
// * @return
|
||||
// * @throws IOException
|
||||
// */
|
||||
//@GetMapping("/chat1")
|
||||
//@CrossOrigin
|
||||
//public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
|
||||
// throws IOException {
|
||||
// //默认30秒超时,设置为0L则永不超时
|
||||
// SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
|
||||
// String uid = headers.get("uid");
|
||||
// if (StrUtil.isBlank(uid)) {
|
||||
// throw new BaseException(CommonError.SYS_ERROR);
|
||||
// }
|
||||
// return distributeAI(msg, sseEmitter, uid);
|
||||
//}
|
||||
/**
|
||||
* embedding with open ai
|
||||
*
|
||||
* @param input
|
||||
* @return
|
||||
*/
|
||||
private FastChatEmbeddingResponse embeddingWithOpenAi(String input) {
|
||||
Embedding embedding = Embedding.builder().input(input).build();
|
||||
Single<EmbeddingResponse> embeddings = this.openAiApi.embeddings(embedding);
|
||||
EmbeddingResponse embeddingResponse = embeddings.blockingGet();
|
||||
return chatConverter.response2response(embeddingResponse);
|
||||
}
|
||||
|
||||
///**
|
||||
// * distribute with different AI
|
||||
// *
|
||||
// * @return
|
||||
// */
|
||||
//private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
|
||||
// ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
|
||||
// Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
|
||||
// String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
|
||||
// if (Objects.nonNull(config)) {
|
||||
// aiSqlSource = config.getContent();
|
||||
// }
|
||||
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
|
||||
// if (Objects.isNull(aiSqlSourceEnum)) {
|
||||
// aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
|
||||
// }
|
||||
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
|
||||
// case OPENAI :
|
||||
// return chatWithOpenAi(msg, sseEmitter, uid);
|
||||
// case CHAT2DBAI:
|
||||
// return chatWithOpenAi(msg, sseEmitter, uid);
|
||||
// case RESTAI :
|
||||
// return chatWithRestAi(msg, sseEmitter);
|
||||
// }
|
||||
// return chatWithOpenAi(msg, sseEmitter, uid);
|
||||
//}
|
||||
|
||||
///**
|
||||
// * 使用OPENAI聊天相关接口
|
||||
// *
|
||||
// * @param msg
|
||||
// * @param sseEmitter
|
||||
// * @param uid
|
||||
// * @return
|
||||
// * @throws IOException
|
||||
// */
|
||||
//private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
|
||||
// String messageContext = (String)LocalCache.CACHE.get(uid);
|
||||
// List<Message> messages = new ArrayList<>();
|
||||
// if (StrUtil.isNotBlank(messageContext)) {
|
||||
// messages = JSONUtil.toList(messageContext, Message.class);
|
||||
// if (messages.size() >= contextLength) {
|
||||
// messages = messages.subList(1, contextLength);
|
||||
// }
|
||||
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
|
||||
// messages.add(currentMessage);
|
||||
// } else {
|
||||
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
|
||||
// messages.add(currentMessage);
|
||||
// }
|
||||
//
|
||||
// return chatGpt35(messages, sseEmitter, uid);
|
||||
//}
|
||||
|
||||
///**
|
||||
// * 使用GPT3.0模型
|
||||
// *
|
||||
// * @param prompt
|
||||
// * @param sseEmitter
|
||||
// * @param uid
|
||||
// * @return
|
||||
// */
|
||||
//private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
|
||||
// sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功!!!!").data(LocalDateTime.now())
|
||||
// .reconnectTime(3000));
|
||||
// sseEmitter.onCompletion(() -> {
|
||||
// log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
|
||||
// });
|
||||
// sseEmitter.onTimeout(
|
||||
// () -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
|
||||
// sseEmitter.onError(
|
||||
// throwable -> {
|
||||
// try {
|
||||
// log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
|
||||
// sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
|
||||
// .data(throwable.getMessage())
|
||||
// .reconnectTime(3000));
|
||||
// } catch (IOException e) {
|
||||
// e.printStackTrace();
|
||||
// }
|
||||
// }
|
||||
// );
|
||||
//
|
||||
// // 获取返回结果
|
||||
// OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
|
||||
// Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
|
||||
// Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
|
||||
// OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
|
||||
// return sseEmitter;
|
||||
//}
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
|
||||
import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
|
||||
import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest;
|
||||
import ai.chat2db.server.web.api.controller.rdb.request.TableMilvusQueryRequest;
|
||||
import ai.chat2db.server.web.api.http.GatewayClientService;
|
||||
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
|
||||
import ai.chat2db.spi.model.Table;
|
||||
@ -61,7 +62,7 @@ public class EmbeddingController extends ChatController {
|
||||
*/
|
||||
@PostMapping("/datasource")
|
||||
@CrossOrigin
|
||||
public ActionResult embeddings(@Valid TableBriefQueryRequest request)
|
||||
public ActionResult embeddings(@Valid TableMilvusQueryRequest request)
|
||||
throws Exception {
|
||||
|
||||
// query tables
|
||||
@ -87,12 +88,10 @@ public class EmbeddingController extends ChatController {
|
||||
// save first table embedding
|
||||
TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
|
||||
tableSchemaRequest.setDataSourceId(request.getDataSourceId());
|
||||
tableSchemaRequest.setApiKey(request.getApikey());
|
||||
tableSchemaRequest.setDeleteBeforeInsert(true);
|
||||
String databaseName = StringUtils.isNotBlank(request.getDatabaseName()) ? request.getDatabaseName() : request.getSchemaName();
|
||||
if (Objects.isNull(databaseName)) {
|
||||
databaseName = "";
|
||||
}
|
||||
tableSchemaRequest.setDatabaseName(databaseName);
|
||||
tableSchemaRequest.setDataSourceSchema(request.getSchemaName());
|
||||
tableSchemaRequest.setDatabaseName(request.getDatabaseName());
|
||||
|
||||
saveTableEmbedding(tableSchema, tableSchemaRequest);
|
||||
|
||||
|
@ -1,8 +1,14 @@
|
||||
package ai.chat2db.server.web.api.controller.ai.converter;
|
||||
|
||||
import ai.chat2db.server.domain.api.param.TableQueryParam;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatItem;
|
||||
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsUsage;
|
||||
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
|
||||
|
||||
import com.unfbx.chatgpt.entity.common.Usage;
|
||||
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
|
||||
import com.unfbx.chatgpt.entity.embeddings.Item;
|
||||
import org.mapstruct.Mapper;
|
||||
|
||||
/**
|
||||
@ -20,4 +26,28 @@ public abstract class ChatConverter {
|
||||
* @return
|
||||
*/
|
||||
public abstract TableQueryParam chat2tableQuery(ChatQueryRequest request);
|
||||
|
||||
/**
|
||||
* chat convert
|
||||
*
|
||||
* @param item
|
||||
* @return
|
||||
*/
|
||||
public abstract FastChatItem item2ChatItem(Item item);
|
||||
|
||||
/**
|
||||
* usage convert
|
||||
*
|
||||
* @param usage
|
||||
* @return
|
||||
*/
|
||||
public abstract FastChatCompletionsUsage usage2usage(Usage usage);
|
||||
|
||||
/**
|
||||
* response convert
|
||||
*
|
||||
* @param embeddingResponse
|
||||
* @return
|
||||
*/
|
||||
public abstract FastChatEmbeddingResponse response2response(EmbeddingResponse embeddingResponse);
|
||||
}
|
||||
|
@ -6,12 +6,14 @@ package ai.chat2db.server.web.api.controller.ai.fastchat.model;
|
||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
/**
|
||||
* Representation of the token counts processed for a completions request. Counts consider all tokens across prompts,
|
||||
* choices, choice alternates, best_of generations, and other consumers.
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public final class FastChatCompletionsUsage {
|
||||
|
||||
/*
|
||||
|
@ -0,0 +1,10 @@
|
||||
package ai.chat2db.server.web.api.controller.rdb.request;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class TableMilvusQueryRequest extends TableBriefQueryRequest {
|
||||
|
||||
private String apikey;
|
||||
}
|
Reference in New Issue
Block a user