cache connection

This commit is contained in:
SwallowGG
2024-05-14 15:07:19 +08:00
parent a52d0c6982
commit 3a6ee931ca
6 changed files with 237 additions and 57 deletions

View File

@ -61,6 +61,7 @@ public enum OracleColumnTypeEnum implements ColumnBuilder {
NATIONAL_CHAR_VARYING("NATIONAL CHAR VARYING", true, false, true, false, false, false, true, true, false, true),
NATIONAL_CHARACTER("NATIONAL CHARACTER", true, false, true, false, false, false, true, true, false, true),

View File

@ -25,15 +25,7 @@ import java.util.concurrent.ConcurrentHashMap;
public class Chat2DBContext {
private static final ThreadLocal<ConnectInfo> CONNECT_INFO_THREAD_LOCAL = new ThreadLocal<>();
// private static final Cache<String, ConnectInfo> CONNECT_INFO_CACHE = CacheBuilder.newBuilder()
// .maximumSize(1000)
// .expireAfterAccess(5, TimeUnit.MINUTES)
// .removalListener((RemovalListener<String, ConnectInfo>) notification -> {
// if (notification.getValue() != null) {
// System.out.println("remove connect info " + notification.getKey());
// notification.getValue().close();
// }
// }).build();
public static Map<String, Plugin> PLUGIN_MAP = new ConcurrentHashMap<>();
@ -87,36 +79,37 @@ public class Chat2DBContext {
}
public static Connection getConnection() {
ConnectInfo connectInfo = getConnectInfo();
Connection connection = connectInfo.getConnection();
try {
if (connection == null || connection.isClosed()) {
synchronized (connectInfo) {
connection = connectInfo.getConnection();
try {
if (connection != null && !connection.isClosed()) {
log.info("get connection from cache");
return connection;
} else {
log.info("get connection from db begin");
connection = getDBManage().getConnection(connectInfo);
log.info("get connection from db end");
}
} catch (SQLException e) {
log.error("get connection error", e);
log.info("get connection from db begin2");
connection = getDBManage().getConnection(connectInfo);
log.info("get connection from db end2");
}
connectInfo.setConnection(connection);
}
}
} catch (SQLException e) {
log.error("get connection error", e);
}
return connection;
// ConnectInfo connectInfo = getConnectInfo();
// Connection connection = connectInfo.getConnection();
// try {
// if (connection == null || connection.isClosed()) {
// synchronized (connectInfo) {
// connection = connectInfo.getConnection();
// try {
// if (connection != null && !connection.isClosed()) {
// log.info("get connection from cache");
// return connection;
// } else {
// log.info("get connection from db begin");
// connection = getDBManage().getConnection(connectInfo);
// log.info("get connection from db end");
// }
// } catch (SQLException e) {
// log.error("get connection error", e);
// log.info("get connection from db begin2");
// connection = getDBManage().getConnection(connectInfo);
// log.info("get connection from db end2");
// }
// connectInfo.setConnection(connection);
// }
// }
// } catch (SQLException e) {
// log.error("get connection error", e);
// }
return ConnectionPool.getConnection(getConnectInfo());
}
public static String getDbVersion() {
ConnectInfo connectInfo = getConnectInfo();
String dbVersion = connectInfo.getDbVersion();
@ -157,8 +150,9 @@ public class Chat2DBContext {
public static void removeContext() {
ConnectInfo connectInfo = CONNECT_INFO_THREAD_LOCAL.get();
if (connectInfo != null) {
connectInfo.close();
// connectInfo.close();
CONNECT_INFO_THREAD_LOCAL.remove();
ConnectionPool.close(connectInfo);
}
}

View File

@ -4,6 +4,7 @@ package ai.chat2db.spi.sql;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
@ -142,6 +143,9 @@ public class ConnectInfo {
private DriverConfig driverConfig;
private Date lastAccessTime;
public String getDbVersion() {
return dbVersion;
}
@ -170,6 +174,7 @@ public class ConnectInfo {
public Session session;
public LinkedHashMap<String, Object> getExtendMap() {
if (ObjectUtils.isEmpty(extendInfo)) {
@ -591,4 +596,12 @@ public class ConnectInfo {
public void setLoginUser(String loginUser) {
this.loginUser = loginUser;
}
public Date getLastAccessTime() {
return lastAccessTime;
}
public void setLastAccessTime(Date lastAccessTime) {
this.lastAccessTime = lastAccessTime;
}
}

View File

@ -0,0 +1,112 @@
package ai.chat2db.spi.sql;
import lombok.extern.slf4j.Slf4j;
import org.h2.engine.ConnectionInfo;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class ConnectionPool {
private static ConcurrentHashMap<String, ConnectInfo> CONNECTION_MAP = new ConcurrentHashMap<>();
static {
new Thread(() -> {
while (true) {
try {
Thread.sleep(1000 * 60 * 10);
CONNECTION_MAP.forEach((k, v) -> {
if (v.getLastAccessTime().getTime() + 1000 * 60 * 60 < System.currentTimeMillis()) {
try {
Connection connection = v.getConnection();
if (connection != null ) {
connection.close();
CONNECTION_MAP.remove(k);
}
} catch (SQLException e) {
log.error("close connection error", e);
}
}
});
} catch (InterruptedException e) {
log.error("close connection error", e);
}
}
}).start();
}
public static ConnectInfo getAndRemove(String key) {
return CONNECTION_MAP.computeIfPresent(key, (k, v) -> {
CONNECTION_MAP.remove(k); // 从 Map 中移除
return v; // 返回值
});
}
public static Connection getConnection(ConnectInfo connectInfo) {
try {
Connection connection = connectInfo.getConnection();
if (connection != null && !connection.isClosed()) {
log.info("get connection from loacl");
return connection;
}
ConnectInfo cache = getAndRemove(connectInfo.key());
if (cache != null) {
connection = cache.getConnection();
if (connection != null && !connection.isClosed()) {
log.info("get connection from cache");
connectInfo.setConnection(connection);
return connection;
}
}
synchronized (connectInfo) {
connection = connectInfo.getConnection();
try {
if (connection != null && !connection.isClosed()) {
log.info("get connection from cache");
return connection;
} else {
log.info("get connection from db begin");
connection = Chat2DBContext.getDBManage().getConnection(connectInfo);
log.info("get connection from db end");
}
} catch (SQLException e) {
log.error("get connection error", e);
log.info("get connection from db begin2");
connection = Chat2DBContext.getDBManage().getConnection(connectInfo);
log.info("get connection from db end2");
}
connectInfo.setConnection(connection);
return connection;
}
} catch (SQLException e) {
log.error("get connection error", e);
}
return null;
}
public static void close(ConnectInfo connectInfo) {
String key = connectInfo.key();
synchronized (key) {
ConnectInfo cache = getAndRemove(key);
if (cache != null) {
Connection connection = cache.getConnection();
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
log.error("close connection error", e);
}
}
}
connectInfo.setLastAccessTime(new Date());
CONNECTION_MAP.put(key, connectInfo);
}
}
}

View File

@ -25,6 +25,8 @@ import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
@ -117,9 +119,21 @@ public class SqlUtils {
return null;
}
private static final String DELIMITER_AFTER_REGEX = "^\\s*(?i)delimiter\\s+(\\S+)";
private static final String DELIMITER_REGEX = "(?mi)^\\s*delimiter\\s*;?";
private static final String EVENT_REGEX = "(?i)\\bcreate\\s+event\\b.*?\\bend\\b";
public static List<String> parse(String sql, DbType dbType) {
List<String> list = new ArrayList<>();
try {
if (StringUtils.isBlank(sql)) {
return list;
}
sql = removeDelimiter(sql);
if (StringUtils.isBlank(sql)) {
return list;
}
Statements statements = CCJSqlParserUtil.parseStatements(sql);
// Iterate through each statement
for (Statement stmt : statements.getStatements()) {
@ -131,11 +145,57 @@ public class SqlUtils {
list.add(sql);
}
} catch (Exception e) {
list = SQLParserUtils.splitAndRemoveComment(sql, dbType);
try {
return splitWithCreateEvent(sql, dbType);
} catch (Exception e1) {
return SQLParserUtils.splitAndRemoveComment(sql, dbType);
}
}
return list;
}
private static String removeDelimiter(String str) {
try {
if (str.toUpperCase().contains("DELIMITER")) {
Pattern pattern = Pattern.compile(DELIMITER_AFTER_REGEX, Pattern.MULTILINE);
Matcher matcher = pattern.matcher(str);
while (matcher.find()) {
// 获取并打印 "DELIMITER" 后的第一个字符串
String mm = matcher.group(1);
if (!";".equals(mm)) {
str = str.replace(mm, "");
}
}
}
return str.replaceAll(DELIMITER_REGEX, "");
}catch (Exception e){
return str;
}
}
private static List<String> splitWithCreateEvent(String str, DbType dbType) {
List<String> list = new ArrayList<>();
String sql = SQLParserUtils.removeComment(str, dbType).trim();
Pattern pattern = Pattern.compile(EVENT_REGEX, Pattern.DOTALL);
Matcher matcher = pattern.matcher(sql);
StringBuilder stringBuilder = new StringBuilder();
int lastEnd = 0; // 用于跟踪上一个匹配的结束位置
while (matcher.find()) {
if (matcher.start() > lastEnd) {
List<String> l = SQLParserUtils.split(sql.substring(lastEnd, matcher.start()), dbType);
list.addAll(l);
}
list.add(matcher.group());
lastEnd = matcher.end(); // 更新上一个匹配的结束位置
}
if (lastEnd < sql.length()) {
List<String> l = SQLParserUtils.split(sql.substring(lastEnd), dbType);
list.addAll(l);
}
return list;
}
private static String updateNow(String sql, DbType dbType) {
if (StringUtils.isBlank(sql) || !DbType.mysql.equals(dbType)) {
return sql;
@ -162,6 +222,7 @@ public class SqlUtils {
return dataTypeEnum.getSqlValue(value);
}
public static boolean hasPageLimit(String sql, DbType dbType) {
try {
Statement statement = CCJSqlParserUtil.parse(sql);
@ -185,5 +246,4 @@ public class SqlUtils {
}
return false;
}
}