mirror of
https://github.com/CodePhiliaX/Chat2DB.git
synced 2025-08-01 08:52:11 +08:00
add prompt support
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package ai.chat2db.server.web.api.controller.ai;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.time.Duration;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.ArrayList;
|
||||
@ -42,14 +43,19 @@ import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceLi
|
||||
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
|
||||
import ai.chat2db.server.web.api.controller.ai.request.ChatRequest;
|
||||
import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
|
||||
import ai.chat2db.server.web.api.http.GatewayClientService;
|
||||
import ai.chat2db.server.web.api.http.model.TableSchema;
|
||||
import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
|
||||
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
|
||||
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
|
||||
import ai.chat2db.spi.model.TableColumn;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.unfbx.chatgpt.entity.chat.Message;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@ -91,6 +97,9 @@ public class ChatController {
|
||||
@Value("${chatgpt.version}")
|
||||
private String gptVersion;
|
||||
|
||||
@Resource
|
||||
private GatewayClientService gatewayClientService;
|
||||
|
||||
/**
|
||||
* chat的超时时间
|
||||
*/
|
||||
@ -459,11 +468,7 @@ public class ChatController {
|
||||
}
|
||||
|
||||
// 查询schema信息
|
||||
DataResult<DataSource> dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
|
||||
String dataSourceType = dataResult.getData().getType();
|
||||
if (StringUtils.isBlank(dataSourceType)) {
|
||||
dataSourceType = "MYSQL";
|
||||
}
|
||||
String dataSourceType = queryDatabaseType(queryRequest);
|
||||
TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
|
||||
Map<String, List<TableColumn>> tableColumns = buildTableColumn(queryParam, queryRequest.getTableNames());
|
||||
List<String> tableSchemas = tableColumns.entrySet().stream().map(
|
||||
@ -492,6 +497,48 @@ public class ChatController {
|
||||
return schemaProperty;
|
||||
}
|
||||
|
||||
/**
|
||||
* query database type
|
||||
*
|
||||
* @param queryRequest
|
||||
* @return
|
||||
*/
|
||||
public String queryDatabaseType(ChatQueryRequest queryRequest) {
|
||||
// 查询schema信息
|
||||
DataResult<DataSource> dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
|
||||
String dataSourceType = dataResult.getData().getType();
|
||||
if (StringUtils.isBlank(dataSourceType)) {
|
||||
dataSourceType = "MYSQL";
|
||||
}
|
||||
return dataSourceType;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* query database schema
|
||||
*
|
||||
* @param queryRequest
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
public String queryDatabaseSchema(ChatQueryRequest queryRequest) throws IOException {
|
||||
// request embedding
|
||||
FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
|
||||
List<List<BigDecimal>> contentVector = new ArrayList<>();
|
||||
contentVector.add(response.getData().get(0).getEmbedding());
|
||||
|
||||
// search embedding
|
||||
DataResult<TableSchemaResponse> result = gatewayClientService.schemaVectorSearch(contentVector);
|
||||
|
||||
List<String> schemas = Lists.newArrayList();
|
||||
if (CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
|
||||
for(TableSchema data: result.getData().getTableSchemas()){
|
||||
schemas.add(data.getTableSchema());
|
||||
}
|
||||
}
|
||||
return JSON.toJSONString(schemas);
|
||||
}
|
||||
|
||||
/**
|
||||
* distribute embedding with different AI
|
||||
*
|
||||
|
@ -0,0 +1,85 @@
|
||||
package ai.chat2db.server.web.api.controller.ai;
|
||||
|
||||
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
|
||||
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
|
||||
import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
|
||||
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
|
||||
import ai.chat2db.server.web.api.http.GatewayClientService;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* @author moji
|
||||
*/
|
||||
@RestController
|
||||
@ConnectionInfoAspect
|
||||
@RequestMapping("/api/ai/text/generation")
|
||||
@Slf4j
|
||||
public class TextGenerationController extends ChatController {
|
||||
|
||||
|
||||
/**
|
||||
* chat的超时时间
|
||||
*/
|
||||
private static final Long CHAT_TIMEOUT = Duration.ofMinutes(50).toMillis();
|
||||
|
||||
|
||||
@Resource
|
||||
private GatewayClientService gatewayClientService;
|
||||
|
||||
/**
|
||||
* sql auto complete
|
||||
*
|
||||
* @param queryRequest
|
||||
* @return
|
||||
* @throws IOException
|
||||
*/
|
||||
@GetMapping("/prompt")
|
||||
@CrossOrigin
|
||||
public SseEmitter prompt(ChatQueryRequest queryRequest, @RequestHeader Map<String, String> headers)
|
||||
throws Exception {
|
||||
queryRequest.setPromptType(PromptType.TEXT_GENERATION.getCode());
|
||||
String promptTemplate = "### Instructions:\n" +
|
||||
"Your task is generate a SQL query according to the prompt, given a %s database schema.\n" +
|
||||
"Adhere to these rules:\n" +
|
||||
"- **Deliberately go through the prompt and database schema word by word** to appropriately answer the question\n" +
|
||||
"- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.\n" +
|
||||
"\n" +
|
||||
"### Input:\n" +
|
||||
"Generate a SQL query according to the prompt `%s`.\n" +
|
||||
"This query will run on a database whose schema is represented in this string:\n" +
|
||||
"{%s}\n" +
|
||||
"\n" +
|
||||
"### Response:\n" +
|
||||
"Based on your instructions, here is the SQL query I have generated to complete the prompt `{%s}`:\n" +
|
||||
"```sql";
|
||||
|
||||
// query database schema info
|
||||
String databaseType = queryDatabaseType(queryRequest);
|
||||
String schemas = queryDatabaseSchema(queryRequest);
|
||||
String prompt = String.format(promptTemplate, databaseType, queryRequest.getMessage(), schemas, queryRequest.getMessage());
|
||||
queryRequest.setMessage(prompt);
|
||||
|
||||
// chat with AI
|
||||
SseEmitter sseEmitter = new SseEmitter(CHAT_TIMEOUT);
|
||||
String uid = headers.get("uid");
|
||||
if (StrUtil.isBlank(uid)) {
|
||||
throw new ParamBusinessException("uid");
|
||||
}
|
||||
|
||||
if (StringUtils.isBlank(queryRequest.getMessage())) {
|
||||
throw new ParamBusinessException("message");
|
||||
}
|
||||
|
||||
return distributeAISql(queryRequest, sseEmitter, uid);
|
||||
}
|
||||
|
||||
}
|
@ -4,10 +4,7 @@ import ai.chat2db.server.tools.base.wrapper.result.ActionResult;
|
||||
import ai.chat2db.server.tools.base.wrapper.result.DataResult;
|
||||
import ai.chat2db.server.web.api.http.request.KnowledgeRequest;
|
||||
import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
|
||||
import ai.chat2db.server.web.api.http.response.ApiKeyResponse;
|
||||
import ai.chat2db.server.web.api.http.response.InviteQrCodeResponse;
|
||||
import ai.chat2db.server.web.api.http.response.KnowledgeResponse;
|
||||
import ai.chat2db.server.web.api.http.response.QrCodeResponse;
|
||||
import ai.chat2db.server.web.api.http.response.*;
|
||||
import com.dtflys.forest.annotation.*;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
@ -93,5 +90,5 @@ public interface GatewayClientService {
|
||||
* @return
|
||||
*/
|
||||
@Get("/api/milvus/schema/search")
|
||||
DataResult<KnowledgeResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors);
|
||||
DataResult<TableSchemaResponse> schemaVectorSearch(List<List<BigDecimal>> searchVectors);
|
||||
}
|
||||
|
Reference in New Issue
Block a user