add knowledge support

This commit is contained in:
robin
2023-10-04 11:28:57 +08:00
parent 2dabb8eb39
commit 9b16b3c35e
20 changed files with 662 additions and 8 deletions

View File

@ -64,18 +64,19 @@
<dependency>
<groupId>com.deepoove</groupId>
<artifactId>poi-tl</artifactId>
<version>1.10.5</version>
</dependency>
<!--pdf-->
<dependency>
<groupId>com.itextpdf</groupId>
<artifactId>itext-asian</artifactId>
<version>5.2.0</version>
</dependency>
<dependency>
<groupId>com.itextpdf</groupId>
<artifactId>itextpdf</artifactId>
<version>5.5.13</version>
</dependency>
<dependency>
<groupId>org.apache.pdfbox</groupId>
<artifactId>pdfbox</artifactId>
</dependency>
</dependencies>

View File

@ -33,6 +33,7 @@ import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.claude.listener.ClaudeAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.ai.fastchat.listener.FastChatAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole;
@ -199,7 +200,7 @@ public class ChatController {
*
* @return
*/
private SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
@ -453,6 +454,10 @@ public class ChatController {
* @return
*/
private String buildPrompt(ChatQueryRequest queryRequest) {
if (PromptType.QUESTION_ANSWERING.getCode().equals(queryRequest.getPromptType())) {
return queryRequest.getMessage();
}
// 查询schema信息
DataResult<DataSource> dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
String dataSourceType = dataResult.getData().getType();
@ -487,6 +492,46 @@ public class ChatController {
return schemaProperty;
}
/**
* distribute embedding with different AI
*
* @return
*/
public FastChatEmbeddingResponse distributeAIEmbedding(String input) throws IOException {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
if (Objects.nonNull(config)) {
aiSqlSource = config.getContent();
}
AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
if (Objects.isNull(aiSqlSourceEnum)) {
aiSqlSourceEnum = AiSqlSourceEnum.OPENAI;
}
switch (Objects.requireNonNull(aiSqlSourceEnum)) {
case OPENAI :
case CHAT2DBAI:
case RESTAI :
case FASTCHATAI:
case AZUREAI :
case CLAUDEAI:
return distributeAIEmbedding(input);
}
return distributeAIEmbedding(input);
}
/**
* embedding with fast chat openai
*
* @param input
* @return
* @throws IOException
*/
private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) throws IOException {
FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input);
return response;
}
///**
// * 问答对话模型
// *

View File

@ -0,0 +1,13 @@
package ai.chat2db.server.web.api.controller.ai.DocParser;
import java.io.InputStream;
import java.util.List;
/**
* @author CYY
* @date 2023年03月20日 上午8:13
* @description
*/
public abstract class AbstractParser {
public abstract List<String> parse(InputStream inputStream) throws Exception;
}

View File

@ -0,0 +1,46 @@
package ai.chat2db.server.web.api.controller.ai.DocParser;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.text.PDFTextStripper;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
/**
* @author CYY
* @date 2023年03月11日 下午3:23
* @description
*/
public class PdfParse extends AbstractParser {
private static final int MAX_LENGTH = 200;
@Override
public List<String> parse(InputStream inputStream) throws IOException {
// 打开 PDF 文件
PDDocument document = PDDocument.load(inputStream);
// 创建 PDFTextStripper 对象
PDFTextStripper stripper = new PDFTextStripper();
// 获取文本内容
String text = stripper.getText(document);
//过滤字符
text = text.replaceAll("\\s", " ").replaceAll("(\\r\\n|\\r|\\n|\\n\\r)"," ");
String[] sentence = text.split("");
List<String> ans = new ArrayList<>();
for (String s : sentence) {
if (s.length() > MAX_LENGTH) {
for (int index = 0; index < sentence.length; index = (index + 1) * MAX_LENGTH) {
String substring = s.substring(index, MAX_LENGTH);
if(substring.length() < 5) continue;
ans.add(substring);
}
} else {
ans.add(s);
}
}
// 关闭文档
document.close();
return ans;
}
}

View File

@ -0,0 +1,131 @@
package ai.chat2db.server.web.api.controller.ai;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.DocParser.AbstractParser;
import ai.chat2db.server.web.api.controller.ai.DocParser.PdfParse;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.model.Knowledge;
import ai.chat2db.server.web.api.http.request.KnowledgeRequest;
import ai.chat2db.server.web.api.http.response.KnowledgeResponse;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson2.JSON;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.math.BigDecimal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* @author moji
*/
@RestController
@ConnectionInfoAspect
@RequestMapping("/api/ai/knowledge")
@Slf4j
public class KnowledgeController extends ChatController {
/**
* chat的超时时间
*/
private static final Long CHAT_TIMEOUT = Duration.ofMinutes(50).toMillis();
@Resource
private GatewayClientService gatewayClientService;
/**
* save knowledge embeddings from pdf file
*
* @param file
* @return
* @throws IOException
*/
@PostMapping("/embeddings")
@CrossOrigin
public ActionResult embeddings(MultipartFile file, HttpServletRequest request)
throws Exception {
AbstractParser pdfParse = new PdfParse();
List<String> sentenceList = pdfParse.parse(file.getInputStream());
List<Integer> contentWordCount = new ArrayList<>();
List<List<BigDecimal>> contentVector = new ArrayList<>();
for(String str : sentenceList){
contentWordCount.add(str.length());
// request embedding
FastChatEmbeddingResponse response = distributeAIEmbedding(str);
if(response == null){
continue;
}
contentVector.add(response.getData().get(0).getEmbedding());
}
KnowledgeRequest knowledgeRequest = new KnowledgeRequest();
knowledgeRequest.setContentVector(contentVector);
knowledgeRequest.setSentenceList(sentenceList);
// save knowledge embedding
ActionResult actionResult = gatewayClientService.knowledgeVectorSave(knowledgeRequest);
return actionResult;
}
/**
* search knowledge embeddings
*
* @param queryRequest
* @return
* @throws IOException
*/
@GetMapping("/search")
@CrossOrigin
public SseEmitter search(ChatQueryRequest queryRequest, @RequestHeader Map<String, String> headers)
throws Exception {
// request embedding
FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
List<List<BigDecimal>> contentVector = new ArrayList<>();
contentVector.add(response.getData().get(0).getEmbedding());
// search embedding
DataResult<KnowledgeResponse> result = gatewayClientService.knowledgeVectorSearch(contentVector);
String prompt = queryRequest.getMessage();
if (CollectionUtils.isNotEmpty(result.getData().getKnowledgeList())) {
List<String> contents = new ArrayList<>();
for(Knowledge data: result.getData().getKnowledgeList()){
contents.add(data.getContent());
}
prompt = String.format("基于%s。请回答%s。", JSON.toJSONString(contents), prompt);
queryRequest.setMessage(prompt);
}
// chat with AI
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
String uid = headers.get("uid");
if (StrUtil.isBlank(uid)) {
throw new ParamBusinessException("uid");
}
if (StringUtils.isBlank(queryRequest.getMessage())) {
throw new ParamBusinessException("message");
}
return distributeAISql(queryRequest, sseEmitter, uid);
}
}

View File

@ -33,6 +33,11 @@ public enum PromptType implements BaseEnum<String> {
* SQL转换
*/
SQL_2_SQL("进行SQL转换"),
/**
* knowledge qa
*/
QUESTION_ANSWERING("问答"),
;
final String description;

View File

@ -29,6 +29,11 @@ public class FastChatAIClient {
*/
public static final String FASTCHAT_MODEL= "fastchat.model";
/**
* FASTCHAT OPENAI embedding model
*/
public static final String FASTCHAT_embedding_MODEL= "fastchat.embedding.model";
private static FastChatAIStreamClient FASTCHAT_AI_CLIENT;

View File

@ -1,11 +1,14 @@
package ai.chat2db.server.web.api.controller.ai.fastchat.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.Single;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
@ -16,6 +19,7 @@ import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import java.util.List;
@ -50,12 +54,21 @@ public class FastChatAIStreamClient {
@Getter
private String model;
/**
* embeddingModel
*/
@Getter
private String embeddingModel;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;
@Getter
private FastChatOpenAiApi fastChatOpenAiApi;
/**
* @param builder
@ -64,6 +77,7 @@ public class FastChatAIStreamClient {
this.apiKey = builder.apiKey;
this.apiHost = builder.apiHost;
this.model = builder.model;
this.embeddingModel = builder.embeddingModel;
if (Objects.isNull(builder.okHttpClient)) {
builder.okHttpClient = this.okHttpClient();
}
@ -103,6 +117,8 @@ public class FastChatAIStreamClient {
private String model;
private String embeddingModel;
/**
* OkhttpClient
*/
@ -134,6 +150,11 @@ public class FastChatAIStreamClient {
return this;
}
public FastChatAIStreamClient.Builder embeddingModel(String embeddingModelValue) {
this.embeddingModel = embeddingModelValue;
return this;
}
public FastChatAIStreamClient.Builder okHttpClient(OkHttpClient val) {
this.okHttpClient = val;
return this;
@ -184,4 +205,28 @@ public class FastChatAIStreamClient {
}
}
/**
* Creates an embedding vector representing the input text.
*
* @param input
* @return EmbeddingResponse
*/
public FastChatEmbeddingResponse embeddings(String input) {
FastChatEmbedding embedding = FastChatEmbedding.builder().input(input).build();
if (StringUtils.isNotBlank(this.embeddingModel)) {
embedding.setModel(this.embeddingModel);
}
return this.embeddings(embedding);
}
/**
* Creates an embedding vector representing the input text.
*
* @param embedding
* @return EmbeddingResponse
*/
public FastChatEmbeddingResponse embeddings(FastChatEmbedding embedding) {
Single<FastChatEmbeddingResponse> embeddings = this.fastChatOpenAiApi.embeddings(embedding);
return embeddings.blockingGet();
}
}

View File

@ -0,0 +1,54 @@
package ai.chat2db.server.web.api.controller.ai.fastchat.client;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import com.unfbx.chatgpt.entity.billing.CreditGrantsResponse;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse;
import com.unfbx.chatgpt.entity.common.DeleteResponse;
import com.unfbx.chatgpt.entity.common.OpenAiResponse;
import com.unfbx.chatgpt.entity.completions.Completion;
import com.unfbx.chatgpt.entity.completions.CompletionResponse;
import com.unfbx.chatgpt.entity.edits.Edit;
import com.unfbx.chatgpt.entity.edits.EditResponse;
import com.unfbx.chatgpt.entity.embeddings.Embedding;
import com.unfbx.chatgpt.entity.embeddings.EmbeddingResponse;
import com.unfbx.chatgpt.entity.engines.Engine;
import com.unfbx.chatgpt.entity.files.File;
import com.unfbx.chatgpt.entity.files.UploadFileResponse;
import com.unfbx.chatgpt.entity.fineTune.Event;
import com.unfbx.chatgpt.entity.fineTune.FineTune;
import com.unfbx.chatgpt.entity.fineTune.FineTuneResponse;
import com.unfbx.chatgpt.entity.images.Image;
import com.unfbx.chatgpt.entity.images.ImageResponse;
import com.unfbx.chatgpt.entity.models.Model;
import com.unfbx.chatgpt.entity.models.ModelResponse;
import com.unfbx.chatgpt.entity.moderations.Moderation;
import com.unfbx.chatgpt.entity.moderations.ModerationResponse;
import com.unfbx.chatgpt.entity.whisper.WhisperResponse;
import io.reactivex.Single;
import okhttp3.MultipartBody;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.http.*;
import java.util.Map;
/**
* 描述: open ai官方api接口
*
* @author https:www.unfbx.com
* 2023-02-15
*/
public interface FastChatOpenAiApi {
/**
* Creates an embedding vector representing the input text.
*
* @param embedding
* @return Single EmbeddingResponse
*/
@POST("v1/embeddings")
Single<FastChatEmbeddingResponse> embeddings(@Body FastChatEmbedding embedding);
}

View File

@ -0,0 +1,73 @@
package ai.chat2db.server.web.api.controller.ai.fastchat.embeddings;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.unfbx.chatgpt.exception.BaseException;
import com.unfbx.chatgpt.exception.CommonError;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import java.io.Serializable;
import java.util.Objects;
/**
* 描述:
*
* @author https:www.unfbx.com
* 2023-02-15
*/
@Getter
@Slf4j
@Builder
@JsonInclude(JsonInclude.Include.NON_NULL)
@NoArgsConstructor
@AllArgsConstructor
public class FastChatEmbedding implements Serializable {
@NonNull
@Builder.Default
private String model = Model.TEXT_EMBEDDING_ADA_002.getName();
/**
* 必选项长度不能超过8192
*/
@NonNull
private String input;
private String user;
public void setModel(Model model) {
if (Objects.isNull(model)) {
model = Model.TEXT_EMBEDDING_ADA_002;
}
this.model = model.getName();
}
public void setModel(String model) {
if (Objects.isNull(model)) {
model = Model.TEXT_EMBEDDING_ADA_002.getName();
}
this.model = model;
}
public void setInput(String input) {
if (input == null || "".equals(input)) {
log.error("input不能为空");
throw new BaseException(CommonError.PARAM_ERROR);
}
if (input.length() > 8192) {
log.error("input超长");
throw new BaseException(CommonError.PARAM_ERROR);
}
this.input = input;
}
public void setUser(String user) {
this.user = user;
}
@Getter
@AllArgsConstructor
public enum Model {
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"),
;
private String name;
}
}

View File

@ -0,0 +1,22 @@
package ai.chat2db.server.web.api.controller.ai.fastchat.embeddings;
import com.unfbx.chatgpt.entity.common.Usage;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
/**
* 描述:
*
* @author https:www.unfbx.com
* 2023-02-15
*/
@Data
public class FastChatEmbeddingResponse implements Serializable {
private String object;
private List<FastChatItem> data;
private String model;
private Usage usage;
}

View File

@ -0,0 +1,14 @@
package ai.chat2db.server.web.api.controller.ai.fastchat.embeddings;
import lombok.Data;
import java.io.Serializable;
import java.math.BigDecimal;
import java.util.List;
@Data
public class FastChatItem implements Serializable {
private String object;
private List<BigDecimal> embedding;
private Integer index;
}

View File

@ -1,13 +1,17 @@
package ai.chat2db.server.web.api.http;
import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.http.request.KnowledgeRequest;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.response.ApiKeyResponse;
import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse;
import ai.chat2db.server.web.api.http.response.KnowledgeResponse;
import ai.chat2db.server.web.api.http.response.QrCodeResponse;
import com.dtflys.forest.annotation.BaseRequest;
import com.dtflys.forest.annotation.Get;
import com.dtflys.forest.annotation.Query;
import com.dtflys.forest.annotation.Var;
import com.dtflys.forest.annotation.*;
import java.math.BigDecimal;
import java.util.List;
/**
* Gateway 的http 服务
@ -54,4 +58,40 @@ public interface GatewayClientService {
@Get("/api/client/inviteQrCode")
DataResult<InviteQrCodeResponse> getInviteQrCode(@Query("apiKey") String apiKey);
/**
* save knowledge vector
*
* @param request
* @return
*/
@Post("/api/milvus/knowledge/save")
ActionResult knowledgeVectorSave(KnowledgeRequest request);
/**
* save table schema vector
*
* @param request
* @return
*/
@Post("/api/milvus/schema/save")
ActionResult schemaVectorSave(TableSchemaRequest request);
/**
* save knowledge vector
*
* @param searchVectors
* @return
*/
@Get("/api/milvus/knowledge/search")
DataResult<KnowledgeResponse> knowledgeVectorSearch(List<List<BigDecimal>> searchVectors);
/**
* save table schema vector
*
* @param searchVectors
* @return
*/
@Get("/api/milvus/schema/search")
DataResult<KnowledgeResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors);
}

View File

@ -0,0 +1,27 @@
package ai.chat2db.server.web.api.http.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class Knowledge {
private Long id;
private String content;
private String contentVector;
private Integer wordCount;
public Knowledge(Long id, String content, Integer wordCount) {
this.id = id;
this.content = content;
this.wordCount = wordCount;
}
}

View File

@ -0,0 +1,30 @@
package ai.chat2db.server.web.api.http.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class TableSchema {
private Long id;
private Long dataSourceId;
private String tableSchema;
private String tableSchemaVector;
private Integer wordCount;
public TableSchema(Long id, Long dataSourceId, String tableSchema, Integer wordCount) {
this.id = id;
this.dataSourceId = dataSourceId;
this.tableSchema = tableSchema;
this.wordCount = wordCount;
}
}

View File

@ -0,0 +1,20 @@
package ai.chat2db.server.web.api.http.request;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import java.math.BigDecimal;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeRequest {
private List<List<BigDecimal>> contentVector;
private List<String> sentenceList;
}

View File

@ -0,0 +1,22 @@
package ai.chat2db.server.web.api.http.request;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import java.math.BigDecimal;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class TableSchemaRequest {
private Long dataSourceId;
private List<List<BigDecimal>> contentVector;
private List<String> sentenceList;
}

View File

@ -0,0 +1,18 @@
package ai.chat2db.server.web.api.http.response;
import ai.chat2db.server.web.api.http.model.Knowledge;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class KnowledgeResponse {
private List<Knowledge> knowledgeList;
}

View File

@ -0,0 +1,18 @@
package ai.chat2db.server.web.api.http.response;
import ai.chat2db.server.web.api.http.model.TableSchema;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class TableSchemaResponse {
private List<TableSchema> tableSchemas;
}