implement connection pool for psql

This commit is contained in:
Mohammad Azmi
2026-03-09 09:59:14 +07:00
parent 6ce8e67971
commit e5a33bebb6
8 changed files with 350 additions and 115 deletions

View File

@@ -13,6 +13,7 @@ import platformInfo from '@/common/platform_info';
import { LicenseKey } from '@/common/appdb/models/LicenseKey';
import { IdentifyResult } from 'sql-query-identifier/lib/defines';
import { Transcoder } from '../serialization/transcoders';
import { GenericConnectionPool } from './GenericConnectionPool';
const log = rawLog.scope('BasicDatabaseClient');
const logger = () => log;
@@ -75,6 +76,7 @@ export abstract class BasicDatabaseClient<RawResultType extends BaseQueryResult,
connErrHandler: (msg: string) => void = null;
reservedConnections: Map<number, Conn> = new Map<number, Conn>();
transcoders: Transcoder<any, any>[] = [];
abstract pool: GenericConnectionPool<Conn>;
constructor(knex: Knex | null, contextProvider: AppContextProvider, server: IDbConnectionServer, database: IDbConnectionDatabase) {
this.knex = knex;
@@ -129,34 +131,22 @@ export abstract class BasicDatabaseClient<RawResultType extends BaseQueryResult,
await this.disconnect();
}
// reuse existing tunnel
if (this.server.config.ssh && !this.server.sshTunnel) {
logger().debug('creating ssh tunnel');
this.server.sshTunnel = await connectTunnel(this.server.config);
this.server.config.localHost = this.server.sshTunnel.localHost
this.server.config.localPort = this.server.sshTunnel.localPort
}
await this.pool.start();
} catch (err) {
logger().error('Connection error %j', err);
// this.disconnect(this.server, this.database);
throw new Error('Database Connection Error: ' + err.message);
throw new Error('Database Connection Error: ' + err.message || String(err));
} finally {
this.database.connecting = false;
}
}
async disconnect(): Promise<void> {
this.database.connecting = false;
if (this.server.sshTunnel) {
await this.server.sshTunnel.connection.shutdown();
}
if (this.server.db[this.database.database]) {
// delete this.server.db[this.database.database]
}
await this.knex?.destroy();
await this.pool.end();
}
// ****************************************************************************

View File

@@ -0,0 +1,84 @@
import { IDbConnectionServer } from "@/lib/db/backendTypes";
import { IDbConnectionDatabase } from "@/lib/db/types";
import connectTunnel from '@/lib/db/tunnel';
import rawLog from "@bksLogger";
const log = rawLog.scope('BasicDatabaseClient');
const logger = () => log;
/**
* A class that uniforms connection pool logic including the ssh tunnel.
*
* @example
*
* class PostgresConnectionPool extends GenericConnectionPool<pg.Client> {
* // implement abstract methods here
* }
*
* const pool = new PostgresConnectionPool();
* const client = await pool.connect(); // connect to the database. no need to call `.start()`.
* await pool.end(); // close it manually when needed
*
* */
export abstract class GenericConnectionPool<ClientType> {
private started: boolean = false;
protected readonly server: IDbConnectionServer;
protected readonly database: IDbConnectionDatabase;
constructor(options: {
server: IDbConnectionServer,
database: IDbConnectionDatabase,
}) {
this.server = options.server;
this.database = options.database;
}
abstract doStart(): Promise<void>;
abstract doEnd(): Promise<void>;
abstract doConnect(): Promise<ClientType>;
async start() {
// reuse existing tunnel
if (this.server.config.ssh && !this.server.sshTunnel) {
logger().debug('creating ssh tunnel');
this.server.sshTunnel = await connectTunnel(this.server.config);
this.server.config.localHost = this.server.sshTunnel.localHost
this.server.config.localPort = this.server.sshTunnel.localPort
} else if (this.server.sshTunnel){
logger().debug('reusing ssh tunnel');
await this.server.sshTunnel.reconnect();
}
await this.doStart();
this.started = true;
}
async end() {
await this.doEnd();
if (this.server.sshTunnel) {
await this.server.sshTunnel.connection.shutdown();
this.server.sshTunnel = null;
}
this.started = false;
}
async connect(): Promise<ClientType> {
if (!this.started) {
await this.start();
}
return await this.doConnect();
}
protected async onConnectionTerminatedUnexpectedly() {
await this.end();
this.started = false;
}
}

View File

@@ -13,9 +13,7 @@ import { FilterOptions, OrderBy, TableFilter, TableUpdateResult, TableResult, Ro
import { buildDatabaseFilter, buildDeleteQueries, buildInsertQueries, buildSchemaFilter, buildSelectQueriesFromUpdates, buildUpdateQueries, escapeString, refreshTokenIfNeeded, joinQueries, errorMessages } from './utils';
import { createCancelablePromise, joinFilters } from '../../../common/utils';
import { errors } from '../../errors';
// FIXME (azmi): use BksConfig
import globals from '../../../common/globals';
import { HasPool, VersionInfo } from './postgresql/types'
import { VersionInfo } from './postgresql/types'
import { PsqlCursor } from './postgresql/PsqlCursor';
import { PostgresqlChangeBuilder } from '@shared/lib/sql/change_builder/PostgresqlChangeBuilder';
import { AlterPartitionsSpec, IndexColumn, TableKey } from '@shared/lib/dialects/models';
@@ -27,6 +25,7 @@ import BksConfig from '@/common/bksConfig';
import { IDbConnectionServer } from '../backendTypes';
import { GenericBinaryTranscoder } from "../serialization/transcoders";
import {AzureAuthService} from "@/lib/db/authentication/azure";
import { PsqlConnectionPool } from './postgresql/PsqlConnectionPool';
const PD = PostgresData
@@ -87,16 +86,16 @@ const postgresContext = {
export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient> {
version: VersionInfo;
conn: HasPool;
_defaultSchema: string;
dataTypes: any;
transcoders = [GenericBinaryTranscoder];
interval: NodeJS.Timeout;
pool: PsqlConnectionPool;
constructor(server: IDbConnectionServer, database: IDbConnectionDatabase) {
super(knex, postgresContext, server, database);
this.dialect = 'psql';
this.readOnlyMode = server?.config?.readOnlyMode || false;
this.pool = new PsqlConnectionPool({ server, database });
}
async versionString(): Promise<string> {
@@ -134,74 +133,12 @@ export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient>
return;
}
await super.connect();
const dbConfig = await this.configDatabase(this.server, this.database);
log.info("CONFIG: ", dbConfig)
this.conn = {
pool: new pg.Pool(dbConfig)
};
const test = await this.conn.pool.connect()
if (this.server.config.iamAuthOptions?.iamAuthenticationEnabled) {
this.interval = setInterval(async () => {
try {
const newPassword = await refreshTokenIfNeeded(this.server.config.iamAuthOptions, this.server, this.server.config.port || 5432);
const newPool = new pg.Pool({
...dbConfig,
password: newPassword,
});
const test = await newPool.connect();
test.release();
if (this.conn?.pool) {
await this.conn.pool.end();
}
this.conn = { pool: newPool };
log.info('Token refreshed successfully and connection pool updated.');
} catch (err) {
log.error('Could not refresh token or update connection pool!', err);
}
// FIXME (azmi): use BksConfig
}, globals.iamRefreshTime);
}
test.release();
this.conn.pool.on('acquire', (_client) => {
log.debug('Pool event: connection acquired')
})
this.conn.pool.on('error', (err, _client) => {
log.error("Pool event: connection error:", err.name, err.message)
})
// @ts-ignore
this.conn.pool.on('release', (err, client) => {
log.debug('Pool event: connection released')
})
log.debug('connected');
this._defaultSchema = await this.getSchema();
this.version = await this.getVersion();
this.dataTypes = await this.getTypes();
this.database.connected = true;
}
async disconnect(): Promise<void> {
if(this.interval){
clearInterval(this.interval);
}
await super.disconnect();
this.conn.pool.end();
}
async listTables(filter?: FilterOptions): Promise<TableOrView[]> {
const schemaFilter = buildSchemaFilter(filter, 'table_schema');
@@ -1100,7 +1037,7 @@ export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient>
const cursorOpts = {
query: qs.query,
params: qs.params,
conn: this.conn,
pool: this.pool,
chunkSize
}
@@ -1115,7 +1052,7 @@ export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient>
const cursorOpts = {
query: query,
params: [],
conn: this.conn,
pool: this.pool,
chunkSize
}
@@ -1291,7 +1228,7 @@ export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient>
throw new Error(errorMessages.maxReservedConnections)
}
const conn = await this.conn.pool.connect();
const conn = await this.pool.connect();
this.pushConnection(tabId, conn);
}
@@ -1338,7 +1275,7 @@ export class PostgresClient extends BasicDatabaseClient<QueryResult, PoolClient>
// this will manage the connection for you, but won't call rollback
// on an error, for that use `runWithTransaction`
async runWithConnection<T>(child: (c: PoolClient) => Promise<T>): Promise<T> {
const connection = await this.conn.pool.connect()
const connection = await this.pool.connect()
try {
return await child(connection)
} finally {

View File

@@ -0,0 +1,177 @@
import _ from "lodash";
import { readFileSync } from "fs";
import pg, { PoolClient, PoolConfig } from "pg";
import logRaw from "@bksLogger";
import { GenericConnectionPool } from "@/lib/db/clients/GenericConnectionPool";
import type { IDbConnectionServer } from "@/lib/db/backendTypes";
import { refreshTokenIfNeeded } from "@/lib/db/clients/utils";
import BksConfig from "@/common/bksConfig";
import { AzureAuthService } from "@/lib/db/authentication/azure";
// FIXME (azmi): use BksConfig
import globals from '@/common/globals';
import { HasPool } from "@/lib/db/clients/postgresql/types";
const log = logRaw.scope("postgresql");
export class PsqlConnectionPool extends GenericConnectionPool<PoolClient> {
private interval: NodeJS.Timeout | null = null;
private conn: HasPool;
async doStart(): Promise<void> {
const dbConfig = await this.configDatabase(this.server, this.database);
log.info("CONFIG: ", dbConfig)
this.conn = {
pool: new pg.Pool(dbConfig)
};
const test = await this.conn.pool.connect()
if (this.server.config.iamAuthOptions?.iamAuthenticationEnabled) {
this.interval = setInterval(async () => {
try {
const newPassword = await refreshTokenIfNeeded(this.server.config.iamAuthOptions, this.server, this.server.config.port || 5432);
const newPool = new pg.Pool({
...dbConfig,
password: newPassword,
});
const test = await newPool.connect();
test.release();
if (this.conn?.pool) {
await this.conn.pool.end();
}
this.conn = { pool: newPool };
log.info('Token refreshed successfully and connection pool updated.');
} catch (err) {
log.error('Could not refresh token or update connection pool!', err);
}
// FIXME (azmi): use BksConfig
}, globals.iamRefreshTime);
}
test.release();
this.conn.pool.on('acquire', (_client) => {
log.debug('Pool event: connection acquired')
})
this.conn.pool.on('error', (err, _client) => {
log.error("Pool event: connection error:", err.name, err.message)
if (err.message === "Connection terminated unexpectedly") {
this.onConnectionTerminatedUnexpectedly();
}
})
// @ts-ignore
this.conn.pool.on('release', (err, client) => {
log.debug('Pool event: connection released')
})
log.debug('connected');
}
async doEnd(): Promise<void> {
if (this.interval){
clearInterval(this.interval);
}
if (this.conn.pool.ended) {
return;
}
await this.conn.pool.end();
}
async doConnect(): Promise<PoolClient> {
return await this.conn.pool.connect();
}
protected async configDatabase(server: IDbConnectionServer, database: { database: string}) {
let iamToken = undefined;
if(server.config.iamAuthOptions?.iamAuthenticationEnabled){
iamToken = await refreshTokenIfNeeded(server.config?.iamAuthOptions, server, server.config.port || 5432)
}
const config: PoolConfig = {
host: server.config.host,
port: server.config.port || undefined,
password: iamToken || server.config.password || undefined,
database: database.database,
max: BksConfig.db.postgres.maxConnections, // max idle connections per time (30 secs)
connectionTimeoutMillis: BksConfig.db.postgres.connectionTimeout,
idleTimeoutMillis: BksConfig.db.postgres.idleTimeout,
};
if (server.config.azureAuthOptions?.azureAuthEnabled) {
const authService = new AzureAuthService();
config.user = server.config.user
return authService.configDB(server, config)
}
if (
server.config.client === "postgresql" &&
// fix https://github.com/beekeeper-studio/beekeeper-studio/issues/2630
// we only need SSL for iam authentication
server.config?.iamAuthOptions?.iamAuthenticationEnabled
){
server.config.ssl = true;
}
return this.configurePool(config, server, null);
}
protected configurePool(config: PoolConfig, server: IDbConnectionServer, tempUser: string) {
if (tempUser) {
config.user = tempUser
} else if (server.config.user) {
config.user = server.config.user
} else if (server.config.osUser) {
config.user = server.config.osUser
}
if (server.config.socketPathEnabled) {
config.host = server.config.socketPath;
config.port = server.config.port;
return config;
}
if (server.sshTunnel) {
config.host = server.config.localHost;
config.port = server.config.localPort;
}
if (server.config.ssl) {
config.ssl = {}
if (server.config.sslCaFile) {
config.ssl.ca = readFileSync(server.config.sslCaFile);
}
if (server.config.sslCertFile) {
config.ssl.cert = readFileSync(server.config.sslCertFile);
}
if (server.config.sslKeyFile) {
config.ssl.key = readFileSync(server.config.sslKeyFile);
}
if (!config.ssl.key && !config.ssl.ca && !config.ssl.cert) {
// TODO: provide this as an option in settings
// not per-connection
// How it works:
// if false, cert can be self-signed
// if true, has to be from a public CA
// Heroku certs are self-signed.
// if you provide ca/cert/key files, it overrides this
config.ssl.rejectUnauthorized = false
} else {
config.ssl.rejectUnauthorized = server.config.sslRejectUnauthorized
}
}
return config;
}
}

View File

@@ -2,14 +2,14 @@ import { PoolClient } from "pg"
import Cursor from "pg-cursor"
import { BeeCursor } from "../../models"
import rawlog from '@bksLogger'
import { HasPool } from './types'
import { PsqlConnectionPool } from "./PsqlConnectionPool"
const log = rawlog.scope('postgresql/cursor')
interface CursorOptions {
query: string,
params: (string | string[])[],
conn: HasPool,
pool: PsqlConnectionPool,
chunkSize: number
}
@@ -32,7 +32,7 @@ export class PsqlCursor extends BeeCursor {
async start() {
this.client = await this.options.conn.pool.connect()
this.client = await this.options.pool.connect()
this.client.on('error', this.handleError.bind(this))
const { query, params } = this.options
this.cursor = this.client.query(new Cursor(query, params, {rowMode: 'array'}))

View File

@@ -25,4 +25,4 @@ services:
- USER_NAME=beekeeper #optional
- SUDO_ACCESS=true
ports:
- 2222
- "7222:2222"

View File

@@ -1,4 +1,6 @@
import { DockerComposeEnvironment, Wait } from 'testcontainers'
import { BasicDatabaseClient } from '@/lib/db/clients/BasicDatabaseClient';
import { IDbConnectionPublicServer } from '@/lib/db/serverTypes';
import { SshEnvironment } from '@tests/integration/lib/db/clients/ssh/SshEnvironment';
import ConnectionProvider from '@commercial/backend/lib/connection-provider';
import { dbtimeout } from '../../../../lib/db'
import { TestOrmConnection } from '@tests/lib/TestOrmConnection';
@@ -8,29 +10,22 @@ import { TestOrmConnection } from '@tests/lib/TestOrmConnection';
describe("SSH Tunnel Tests", () => {
jest.setTimeout(dbtimeout)
let container;
let connection
let database
let environment
let environment: SshEnvironment;
let connection: IDbConnectionPublicServer;
let database: BasicDatabaseClient<any>;
beforeAll(async () => {
await TestOrmConnection.connect()
const timeoutDefault = 5000
environment = await new DockerComposeEnvironment("tests/docker", "ssh.yml")
.withWaitStrategy('test_ssh_postgres', Wait.forLogMessage("database system is ready to accept connections", 2))
.withWaitStrategy('test_ssh', Wait.forListeningPorts())
.up()
container = environment.getContainer('test_ssh')
const db = environment.getContainer('test_ssh_postgres')
environment = new SshEnvironment();
await environment.start();
jest.setTimeout(timeoutDefault)
const quickConfig = {
host: db.getHost(),
port: db.getMappedPort(5432),
host: environment.getDbHost(),
port: environment.getDbPort(),
username: 'postgres',
password: 'example',
connectionType: 'postgresql'
@@ -38,7 +33,6 @@ describe("SSH Tunnel Tests", () => {
// NB: If this fails it's due to ipv4 vs ipv6 mixup.
// as of Node 17+ DNS defaults to v6 instead of v4.
let host = container.getHost()
const config = {
connectionType: 'postgresql',
host: 'postgres',
@@ -46,8 +40,8 @@ describe("SSH Tunnel Tests", () => {
username: 'postgres',
password: 'example',
sshEnabled: true,
sshHost: container.getHost(),
sshPort: container.getMappedPort(2222),
sshHost: environment.getSshHost(),
sshPort: environment.getSshPort(),
sshUsername: 'beekeeper',
sshPassword: 'password'
}
@@ -59,7 +53,6 @@ describe("SSH Tunnel Tests", () => {
await query.execute()
await qdb.disconnect();
console.log("Starting SSH test with config", config)
connection = ConnectionProvider.for(config)
database = connection.createConnection('integration_test')
await database.connect()
@@ -67,18 +60,19 @@ describe("SSH Tunnel Tests", () => {
describe("Can SSH and run a query", () => {
it("should work", async () => {
const query = await database.query('select 1');
await query.execute()
} )
await database.executeQuery('select 1');
})
it("should re-estabilish lost connection", async () => {
await environment.restart();
await database.executeQuery('select 1');
});
})
afterAll(async () => {
if (database) {
await database.disconnect()
}
if (container) {
await container.stop()
}
if (environment) {
await environment.stop()
}

View File

@@ -0,0 +1,53 @@
import {
DockerComposeEnvironment,
StartedDockerComposeEnvironment,
Wait,
} from "testcontainers";
export class SshEnvironment {
private environment!: StartedDockerComposeEnvironment;
async start() {
this.environment = await new DockerComposeEnvironment(
"tests/docker",
"ssh.yml"
)
.withWaitStrategy(
"test_ssh_postgres",
Wait.forLogMessage("database system is ready to accept connections", 2)
)
.withWaitStrategy("test_ssh", Wait.forListeningPorts())
.up();
}
async restart() {
const container = this.environment.getContainer("test_ssh");
if (container) {
await container.restart();
// wait until it's fully restarted
await new Promise((resolve) => setTimeout(resolve, 500));
}
}
async stop() {
await this.environment?.stop();
}
getDbHost() {
return this.environment.getContainer("test_ssh_postgres").getHost();
}
getDbPort() {
return this.environment
.getContainer("test_ssh_postgres")
.getMappedPort(5432);
}
getSshHost() {
return this.environment.getContainer("test_ssh").getHost();
}
getSshPort() {
return 7222;
}
}