embedding update

This commit is contained in:
robin
2023-10-13 19:19:36 +08:00
parent e6cb1067d4
commit 5827afb71e
5 changed files with 74 additions and 126 deletions

View File

@ -55,8 +55,13 @@ import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSON;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.unfbx.chatgpt.OpenAiApi;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.entity.embeddings.Embedding;
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
import io.reactivex.Single;
import jakarta.annotation.Resource;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@ -101,6 +106,9 @@ public class ChatController {
@Resource
private GatewayClientService gatewayClientService;
@Getter
private OpenAiApi openAiApi;
/**
* chat的超时时间
*/
@ -553,7 +561,7 @@ public class ChatController {
*
* @return
*/
public FastChatEmbeddingResponse distributeAIEmbedding(String input) throws IOException {
public FastChatEmbeddingResponse distributeAIEmbedding(String input) {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
@ -561,19 +569,16 @@ public class ChatController {
aiSqlSource = config.getContent();
}
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
}
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case AZUREAI :
case OPENAI:
case CHAT2DBAI:
return embeddingWithOpenAi(input);
case RESTAI :
case FASTCHATAI:
case AZUREAI :
case CLAUDEAI:
return distributeAIEmbedding(input);
return embeddingWithFastChatAi(input);
}
return distributeAIEmbedding(input);
return embeddingWithFastChatAi(input);
}
/**
@ -583,120 +588,22 @@ public class ChatController {
* @return
* @throws IOException
*/
private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) throws IOException {
private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) {
FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input);
return response;
}
///**
// * 问答对话模型
// *
// * @param msg
// * @param headers
// * @return
// * @throws IOException
// */
//@GetMapping("/chat1")
//@CrossOrigin
//public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<String, String> headers)
// throws IOException {
// //默认30秒超时,设置为0L则永不超时
// SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
// String uid = headers.get("uid");
// if (StrUtil.isBlank(uid)) {
// throw new BaseException(CommonError.SYS_ERROR);
// }
// return distributeAI(msg, sseEmitter, uid);
//}
/**
* embedding with open ai
*
* @param input
* @return
*/
private FastChatEmbeddingResponse embeddingWithOpenAi(String input) {
Embedding embedding = Embedding.builder().input(input).build();
Single<EmbeddingResponse> embeddings = this.openAiApi.embeddings(embedding);
EmbeddingResponse embeddingResponse = embeddings.blockingGet();
return chatConverter.response2response(embeddingResponse);
}
///**
// * distribute with different AI
// *
// * @return
// */
//private SseEmitter distributeAI(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
// Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
// String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
// if (Objects.nonNull(config)) {
// aiSqlSource = config.getContent();
// }
// AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
// if (Objects.isNull(aiSqlSourceEnum)) {
// aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
// }
// switch (Objects.requireNonNull(aiSqlSourceEnum)) {
// case OPENAI :
// return chatWithOpenAi(msg, sseEmitter, uid);
// case CHAT2DBAI:
// return chatWithOpenAi(msg, sseEmitter, uid);
// case RESTAI :
// return chatWithRestAi(msg, sseEmitter);
// }
// return chatWithOpenAi(msg, sseEmitter, uid);
//}
///**
// * 使用OPENAI聊天相关接口
// *
// * @param msg
// * @param sseEmitter
// * @param uid
// * @return
// * @throws IOException
// */
//private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
// String messageContext = (String)LocalCache.CACHE.get(uid);
// List<Message> messages = new ArrayList<>();
// if (StrUtil.isNotBlank(messageContext)) {
// messages = JSONUtil.toList(messageContext, Message.class);
// if (messages.size() >= contextLength) {
// messages = messages.subList(1, contextLength);
// }
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// } else {
// Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
// messages.add(currentMessage);
// }
//
// return chatGpt35(messages, sseEmitter, uid);
//}
///**
// * 使用GPT3.0模型
// *
// * @param prompt
// * @param sseEmitter
// * @param uid
// * @return
// */
//private SseEmitter chatGpt3(String prompt, SseEmitter sseEmitter, String uid) throws IOException {
// sseEmitter.send(SseEmitter.event().id(uid).name("chatGpt3连接成功").data(LocalDateTime.now())
// .reconnectTime(3000));
// sseEmitter.onCompletion(() -> {
// log.info(LocalDateTime.now() + ", uid#" + uid + ", on completion");
// });
// sseEmitter.onTimeout(
// () -> log.info(LocalDateTime.now() + ", uid#" + uid + ", chatGpt3 on timeout#" + sseEmitter.getTimeout()));
// sseEmitter.onError(
// throwable -> {
// try {
// log.info(LocalDateTime.now() + ", uid#" + "765431" + ", chatGpt3 on error#" + throwable.toString());
// sseEmitter.send(SseEmitter.event().id("765431").name("chatGpt3 发生异常!")
// .data(throwable.getMessage())
// .reconnectTime(3000));
// } catch (IOException e) {
// e.printStackTrace();
// }
// }
// );
//
// // 获取返回结果
// OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
// Completion completion = Completion.builder().maxTokens(RETURN_TOKEN_LENGTH).stream(true).stop(
// Lists.newArrayList("#", ";")).user(uid).prompt(prompt).build();
// OpenAIClient.getInstance().streamCompletions(completion, openAIEventSourceListener);
// return sseEmitter;
//}
}

View File

@ -12,6 +12,7 @@ import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
import ai.chat2db.server.web.api.controller.rdb.request.TableBriefQueryRequest;
import ai.chat2db.server.web.api.controller.rdb.request.TableMilvusQueryRequest;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.spi.model.Table;
@ -61,7 +62,7 @@ public class EmbeddingController extends ChatController {
*/
@PostMapping("/datasource")
@CrossOrigin
public ActionResult embeddings(@Valid TableBriefQueryRequest request)
public ActionResult embeddings(@Valid TableMilvusQueryRequest request)
throws Exception {
// query tables
@ -87,12 +88,10 @@ public class EmbeddingController extends ChatController {
// save first table embedding
TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
tableSchemaRequest.setDataSourceId(request.getDataSourceId());
tableSchemaRequest.setApiKey(request.getApikey());
tableSchemaRequest.setDeleteBeforeInsert(true);
String databaseName = StringUtils.isNotBlank(request.getDatabaseName()) ? request.getDatabaseName() : request.getSchemaName();
if (Objects.isNull(databaseName)) {
databaseName = "";
}
tableSchemaRequest.setDatabaseName(databaseName);
tableSchemaRequest.setDataSourceSchema(request.getSchemaName());
tableSchemaRequest.setDatabaseName(request.getDatabaseName());
saveTableEmbedding(tableSchema, tableSchemaRequest);

View File

@ -1,8 +1,14 @@
package ai.chat2db.server.web.api.controller.ai.converter;
import ai.chat2db.server.domain.api.param.TableQueryParam;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatItem;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsUsage;
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import com.unfbx.chatgpt.entity.common.Usage;
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
import com.unfbx.chatgpt.entity.embeddings.Item;
import org.mapstruct.Mapper;
/**
@ -20,4 +26,28 @@ public abstract class ChatConverter {
* @return
*/
public abstract TableQueryParam chat2tableQuery(ChatQueryRequest request);
/**
* chat convert
*
* @param item
* @return
*/
public abstract FastChatItem item2ChatItem(Item item);
/**
* usage convert
*
* @param usage
* @return
*/
public abstract FastChatCompletionsUsage usage2usage(Usage usage);
/**
* response convert
*
* @param embeddingResponse
* @return
*/
public abstract FastChatEmbeddingResponse response2response(EmbeddingResponse embeddingResponse);
}

View File

@ -6,12 +6,14 @@ package ai.chat2db.server.web.api.controller.ai.fastchat.model;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
* Representation of the token counts processed for a completions request. Counts consider all tokens across prompts,
* choices, choice alternates, best_of generations, and other consumers.
*/
@Data
@NoArgsConstructor
public final class FastChatCompletionsUsage {
/*

View File

@ -0,0 +1,10 @@
package ai.chat2db.server.web.api.controller.rdb.request;
import lombok.Data;
@Data
public class TableMilvusQueryRequest extends TableBriefQueryRequest {
private String apikey;
}