cache connection

This commit is contained in:
SwallowGG
2024-05-14 15:07:19 +08:00
parent a52d0c6982
commit 3a6ee931ca
6 changed files with 237 additions and 57 deletions

View File

@ -61,6 +61,7 @@ public enum OracleColumnTypeEnum implements ColumnBuilder {
NATIONAL_CHAR_VARYING("NATIONAL CHAR VARYING", true, false, true, false, false, false, true, true, false, true), NATIONAL_CHAR_VARYING("NATIONAL CHAR VARYING", true, false, true, false, false, false, true, true, false, true),
NATIONAL_CHARACTER("NATIONAL CHARACTER", true, false, true, false, false, false, true, true, false, true), NATIONAL_CHARACTER("NATIONAL CHARACTER", true, false, true, false, false, false, true, true, false, true),

View File

@ -84,17 +84,17 @@ public abstract class DataSourceConverter {
* @param param * @param param
* @return * @return
*/ */
protected String encryptString(DataSourceUpdateParam param) { protected String encryptString(DataSourceUpdateParam param) {
String encryptStr = param.getPassword(); String encryptStr = param.getPassword();
try { try {
DesUtil desUtil = new DesUtil(DesUtil.DES_KEY); DesUtil desUtil = new DesUtil(DesUtil.DES_KEY);
encryptStr = desUtil.encrypt(param.getPassword(), "CBC"); encryptStr = desUtil.encrypt(param.getPassword(), "CBC");
} catch (Exception exception) { } catch (Exception exception) {
// do nothing // do nothing
log.error("encrypt error", exception); log.error("encrypt error", exception);
}
return encryptStr;
} }
return encryptStr;
}
/** /**
* decrypt * decrypt

View File

@ -25,15 +25,7 @@ import java.util.concurrent.ConcurrentHashMap;
public class Chat2DBContext { public class Chat2DBContext {
private static final ThreadLocal<ConnectInfo> CONNECT_INFO_THREAD_LOCAL = new ThreadLocal<>(); private static final ThreadLocal<ConnectInfo> CONNECT_INFO_THREAD_LOCAL = new ThreadLocal<>();
// private static final Cache<String, ConnectInfo> CONNECT_INFO_CACHE = CacheBuilder.newBuilder()
// .maximumSize(1000)
// .expireAfterAccess(5, TimeUnit.MINUTES)
// .removalListener((RemovalListener<String, ConnectInfo>) notification -> {
// if (notification.getValue() != null) {
// System.out.println("remove connect info " + notification.getKey());
// notification.getValue().close();
// }
// }).build();
public static Map<String, Plugin> PLUGIN_MAP = new ConcurrentHashMap<>(); public static Map<String, Plugin> PLUGIN_MAP = new ConcurrentHashMap<>();
@ -87,36 +79,37 @@ public class Chat2DBContext {
} }
public static Connection getConnection() { public static Connection getConnection() {
ConnectInfo connectInfo = getConnectInfo(); // ConnectInfo connectInfo = getConnectInfo();
Connection connection = connectInfo.getConnection(); // Connection connection = connectInfo.getConnection();
try { // try {
if (connection == null || connection.isClosed()) { // if (connection == null || connection.isClosed()) {
synchronized (connectInfo) { // synchronized (connectInfo) {
connection = connectInfo.getConnection(); // connection = connectInfo.getConnection();
try { // try {
if (connection != null && !connection.isClosed()) { // if (connection != null && !connection.isClosed()) {
log.info("get connection from cache"); // log.info("get connection from cache");
return connection; // return connection;
} else { // } else {
log.info("get connection from db begin"); // log.info("get connection from db begin");
connection = getDBManage().getConnection(connectInfo); // connection = getDBManage().getConnection(connectInfo);
log.info("get connection from db end"); // log.info("get connection from db end");
} // }
} catch (SQLException e) { // } catch (SQLException e) {
log.error("get connection error", e); // log.error("get connection error", e);
log.info("get connection from db begin2"); // log.info("get connection from db begin2");
connection = getDBManage().getConnection(connectInfo); // connection = getDBManage().getConnection(connectInfo);
log.info("get connection from db end2"); // log.info("get connection from db end2");
} // }
connectInfo.setConnection(connection); // connectInfo.setConnection(connection);
} // }
} // }
} catch (SQLException e) { // } catch (SQLException e) {
log.error("get connection error", e); // log.error("get connection error", e);
} // }
return connection; return ConnectionPool.getConnection(getConnectInfo());
} }
public static String getDbVersion() { public static String getDbVersion() {
ConnectInfo connectInfo = getConnectInfo(); ConnectInfo connectInfo = getConnectInfo();
String dbVersion = connectInfo.getDbVersion(); String dbVersion = connectInfo.getDbVersion();
@ -157,8 +150,9 @@ public class Chat2DBContext {
public static void removeContext() { public static void removeContext() {
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get(); ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
if (connectInfo != null) { if (connectInfo != null) {
connectInfo.close(); // connectInfo.close();
CONNECT_INFO_THREAD_LOCAL.remove(); CONNECT_INFO_THREAD_LOCAL.remove();
ConnectionPool.close(connectInfo);
} }
} }

View File

@ -4,6 +4,7 @@ package ai.chat2db.spi.sql;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.Date;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -142,6 +143,9 @@ public class ConnectInfo {
private DriverConfig driverConfig; private DriverConfig driverConfig;
private Date lastAccessTime;
public String getDbVersion() { public String getDbVersion() {
return dbVersion; return dbVersion;
} }
@ -170,6 +174,7 @@ public class ConnectInfo {
public Session session; public Session session;
public LinkedHashMap<String, Object> getExtendMap() { public LinkedHashMap<String, Object> getExtendMap() {
if (ObjectUtils.isEmpty(extendInfo)) { if (ObjectUtils.isEmpty(extendInfo)) {
@ -591,4 +596,12 @@ public class ConnectInfo {
public void setLoginUser(String loginUser) { public void setLoginUser(String loginUser) {
this.loginUser = loginUser; this.loginUser = loginUser;
} }
public Date getLastAccessTime() {
return lastAccessTime;
}
public void setLastAccessTime(Date lastAccessTime) {
this.lastAccessTime = lastAccessTime;
}
} }

View File

@ -0,0 +1,112 @@
package ai.chat2db.spi.sql;
import lombok.extern.slf4j.Slf4j;
import org.h2.engine.ConnectionInfo;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class ConnectionPool {
private static ConcurrentHashMap<String, ConnectInfo> CONNECTION_MAP = new ConcurrentHashMap<>();
static {
new Thread(() -> {
while (true) {
try {
Thread.sleep(1000 * 60 * 10);
CONNECTION_MAP.forEach((k, v) -> {
if (v.getLastAccessTime().getTime() + 1000 * 60 * 60 < System.currentTimeMillis()) {
try {
Connection connection = v.getConnection();
if (connection != null ) {
connection.close();
CONNECTION_MAP.remove(k);
}
} catch (SQLException e) {
log.error("close connection error", e);
}
}
});
} catch (InterruptedException e) {
log.error("close connection error", e);
}
}
}).start();
}
public static ConnectInfo getAndRemove(String key) {
return CONNECTION_MAP.computeIfPresent(key, (k, v) -> {
CONNECTION_MAP.remove(k); // 从 Map 中移除
return v; // 返回值
});
}
public static Connection getConnection(ConnectInfo connectInfo) {
try {
Connection connection = connectInfo.getConnection();
if (connection != null && !connection.isClosed()) {
log.info("get connection from loacl");
return connection;
}
ConnectInfo cache = getAndRemove(connectInfo.key());
if (cache != null) {
connection = cache.getConnection();
if (connection != null && !connection.isClosed()) {
log.info("get connection from cache");
connectInfo.setConnection(connection);
return connection;
}
}
synchronized (connectInfo) {
connection = connectInfo.getConnection();
try {
if (connection != null && !connection.isClosed()) {
log.info("get connection from cache");
return connection;
} else {
log.info("get connection from db begin");
connection = Chat2DBContext.getDBManage().getConnection(connectInfo);
log.info("get connection from db end");
}
} catch (SQLException e) {
log.error("get connection error", e);
log.info("get connection from db begin2");
connection = Chat2DBContext.getDBManage().getConnection(connectInfo);
log.info("get connection from db end2");
}
connectInfo.setConnection(connection);
return connection;
}
} catch (SQLException e) {
log.error("get connection error", e);
}
return null;
}
public static void close(ConnectInfo connectInfo) {
String key = connectInfo.key();
synchronized (key) {
ConnectInfo cache = getAndRemove(key);
if (cache != null) {
Connection connection = cache.getConnection();
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
log.error("close connection error", e);
}
}
}
connectInfo.setLastAccessTime(new Date());
CONNECTION_MAP.put(key, connectInfo);
}
}
}

View File

@ -25,6 +25,8 @@ import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@ -117,34 +119,92 @@ public class SqlUtils {
return null; return null;
} }
private static final String DELIMITER_AFTER_REGEX = "^\\s*(?i)delimiter\\s+(\\S+)";
private static final String DELIMITER_REGEX = "(?mi)^\\s*delimiter\\s*;?";
private static final String EVENT_REGEX = "(?i)\\bcreate\\s+event\\b.*?\\bend\\b";
public static List<String> parse(String sql, DbType dbType) { public static List<String> parse(String sql, DbType dbType) {
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
try { try {
if (StringUtils.isBlank(sql)) {
return list;
}
sql = removeDelimiter(sql);
if (StringUtils.isBlank(sql)) {
return list;
}
Statements statements = CCJSqlParserUtil.parseStatements(sql); Statements statements = CCJSqlParserUtil.parseStatements(sql);
// Iterate through each statement // Iterate through each statement
for (Statement stmt : statements.getStatements()) { for (Statement stmt : statements.getStatements()) {
if (!(stmt instanceof CreateProcedure)) { if (!(stmt instanceof CreateProcedure)) {
list.add(updateNow(stmt.toString(),dbType)); list.add(updateNow(stmt.toString(), dbType));
} }
} }
if (CollectionUtils.isEmpty(list)) { if (CollectionUtils.isEmpty(list)) {
list.add(sql); list.add(sql);
} }
} catch (Exception e) { } catch (Exception e) {
list = SQLParserUtils.splitAndRemoveComment(sql, dbType); try {
return splitWithCreateEvent(sql, dbType);
} catch (Exception e1) {
return SQLParserUtils.splitAndRemoveComment(sql, dbType);
}
} }
return list; return list;
} }
private static String updateNow(String sql,DbType dbType) { private static String removeDelimiter(String str) {
if(StringUtils.isBlank(sql) || !DbType.mysql.equals(dbType)){ try {
if (str.toUpperCase().contains("DELIMITER")) {
Pattern pattern = Pattern.compile(DELIMITER_AFTER_REGEX, Pattern.MULTILINE);
Matcher matcher = pattern.matcher(str);
while (matcher.find()) {
// 获取并打印 "DELIMITER" 后的第一个字符串
String mm = matcher.group(1);
if (!";".equals(mm)) {
str = str.replace(mm, "");
}
}
}
return str.replaceAll(DELIMITER_REGEX, "");
}catch (Exception e){
return str;
}
}
private static List<String> splitWithCreateEvent(String str, DbType dbType) {
List<String> list = new ArrayList<>();
String sql = SQLParserUtils.removeComment(str, dbType).trim();
Pattern pattern = Pattern.compile(EVENT_REGEX, Pattern.DOTALL);
Matcher matcher = pattern.matcher(sql);
StringBuilder stringBuilder = new StringBuilder();
int lastEnd = 0; // 用于跟踪上一个匹配的结束位置
while (matcher.find()) {
if (matcher.start() > lastEnd) {
List<String> l = SQLParserUtils.split(sql.substring(lastEnd, matcher.start()), dbType);
list.addAll(l);
}
list.add(matcher.group());
lastEnd = matcher.end(); // 更新上一个匹配的结束位置
}
if (lastEnd < sql.length()) {
List<String> l = SQLParserUtils.split(sql.substring(lastEnd), dbType);
list.addAll(l);
}
return list;
}
private static String updateNow(String sql, DbType dbType) {
if (StringUtils.isBlank(sql) || !DbType.mysql.equals(dbType)) {
return sql; return sql;
} }
if(sql.contains("default now ()")){ if (sql.contains("default now ()")) {
return sql.replace("default now ()","default CURRENT_TIMESTAMP"); return sql.replace("default now ()", "default CURRENT_TIMESTAMP");
} }
if(sql.contains("DEFAULT now ()")){ if (sql.contains("DEFAULT now ()")) {
return sql.replace("DEFAULT now ()","DEFAULT CURRENT_TIMESTAMP"); return sql.replace("DEFAULT now ()", "DEFAULT CURRENT_TIMESTAMP");
} }
return sql; return sql;
} }
@ -162,6 +222,7 @@ public class SqlUtils {
return dataTypeEnum.getSqlValue(value); return dataTypeEnum.getSqlValue(value);
} }
public static boolean hasPageLimit(String sql, DbType dbType) { public static boolean hasPageLimit(String sql, DbType dbType) {
try { try {
Statement statement = CCJSqlParserUtil.parse(sql); Statement statement = CCJSqlParserUtil.parse(sql);
@ -185,5 +246,4 @@ public class SqlUtils {
} }
return false; return false;
} }
} }