add prompt support

This commit is contained in:
robin
2023-10-04 12:09:31 +08:00
parent eed1111fc9
commit a533f303f9
3 changed files with 139 additions and 10 deletions

View File

@ -1,6 +1,7 @@
package ai.chat2db.server.web.api.controller.ai;
import java.io.IOException;
import java.math.BigDecimal;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.ArrayList;
@ -42,14 +43,19 @@ import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceLi
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
import ai.chat2db.server.web.api.http.GatewayClientService;
import ai.chat2db.server.web.api.http.model.TableSchema;
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
import ai.chat2db.spi.model.TableColumn;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson2.JSON;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.unfbx.chatgpt.entity.chat.Message;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@ -91,6 +97,9 @@ public class ChatController {
@Value("${chatgpt.version}")
private String gptVersion;
@Resource
private GatewayClientService gatewayClientService;
/**
* chat的超时时间
*/
@ -459,11 +468,7 @@ public class ChatController {
}
// 查询schema信息
DataResult<DataSource> dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
String dataSourceType = dataResult.getData().getType();
if (StringUtils.isBlank(dataSourceType)) {
dataSourceType = "MYSQL";
}
String dataSourceType = queryDatabaseType(queryRequest);
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
Map<String, List<TableColumn>> tableColumns = buildTableColumn(queryParam, queryRequest.getTableNames());
List<String> tableSchemas = tableColumns.entrySet().stream().map(
@ -492,6 +497,48 @@ public class ChatController {
return schemaProperty;
}
/**
* query database type
*
* @param queryRequest
* @return
*/
public String queryDatabaseType(ChatQueryRequest queryRequest) {
// 查询schema信息
DataResult<DataSource> dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
String dataSourceType = dataResult.getData().getType();
if (StringUtils.isBlank(dataSourceType)) {
dataSourceType = "MYSQL";
}
return dataSourceType;
}
/**
* query database schema
*
* @param queryRequest
* @return
* @throws IOException
*/
public String queryDatabaseSchema(ChatQueryRequest queryRequest) throws IOException {
// request embedding
FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
List<List<BigDecimal>> contentVector = new ArrayList<>();
contentVector.add(response.getData().get(0).getEmbedding());
// search embedding
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(contentVector);
List<String> schemas = Lists.newArrayList();
if (CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
for(TableSchema data: result.getData().getTableSchemas()){
schemas.add(data.getTableSchema());
}
}
return JSON.toJSONString(schemas);
}
/**
* distribute embedding with different AI
*

View File

@ -0,0 +1,85 @@
package ai.chat2db.server.web.api.controller.ai;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.http.GatewayClientService;
import cn.hutool.core.util.StrUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.time.Duration;
import java.util.Map;
/**
* @author moji
*/
@RestController
@ConnectionInfoAspect
@RequestMapping("/api/ai/text/generation")
@Slf4j
public class TextGenerationController extends ChatController {
/**
* chat的超时时间
*/
private static final Long CHAT_TIMEOUT = Duration.ofMinutes(50).toMillis();
@Resource
private GatewayClientService gatewayClientService;
/**
* sql auto complete
*
* @param queryRequest
* @return
* @throws IOException
*/
@GetMapping("/prompt")
@CrossOrigin
public SseEmitter prompt(ChatQueryRequest queryRequest, @RequestHeader Map<String, String> headers)
throws Exception {
queryRequest.setPromptType(PromptType.TEXT_GENERATION.getCode());
String promptTemplate = "### Instructions:\n" +
"Your task is generate a SQL query according to the prompt, given a %s database schema.\n" +
"Adhere to these rules:\n" +
"- **Deliberately go through the prompt and database schema word by word** to appropriately answer the question\n" +
"- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.\n" +
"\n" +
"### Input:\n" +
"Generate a SQL query according to the prompt `%s`.\n" +
"This query will run on a database whose schema is represented in this string:\n" +
"{%s}\n" +
"\n" +
"### Response:\n" +
"Based on your instructions, here is the SQL query I have generated to complete the prompt `{%s}`:\n" +
"```sql";
// query database schema info
String databaseType = queryDatabaseType(queryRequest);
String schemas = queryDatabaseSchema(queryRequest);
String prompt = String.format(promptTemplate, databaseType, queryRequest.getMessage(), schemas, queryRequest.getMessage());
queryRequest.setMessage(prompt);
// chat with AI
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
String uid = headers.get("uid");
if (StrUtil.isBlank(uid)) {
throw new ParamBusinessException("uid");
}
if (StringUtils.isBlank(queryRequest.getMessage())) {
throw new ParamBusinessException("message");
}
return distributeAISql(queryRequest, sseEmitter, uid);
}
}

View File

@ -4,10 +4,7 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.web.api.http.request.KnowledgeRequest;
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
import ai.chat2db.server.web.api.http.response.ApiKeyResponse;
import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse;
import ai.chat2db.server.web.api.http.response.KnowledgeResponse;
import ai.chat2db.server.web.api.http.response.QrCodeResponse;
import ai.chat2db.server.web.api.http.response.*;
import com.dtflys.forest.annotation.*;
import java.math.BigDecimal;
@ -93,5 +90,5 @@ public interface GatewayClientService {
* @return
*/
@Get("/api/milvus/schema/search")
DataResult<KnowledgeResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors);
DataResult<TableSchemaResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors);
}