From 1e383b0c1e98734ad4bf974f0435512d10c33246 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Zolt=C3=A1n=20Bedi?=
Date: Mon, 26 May 2025 08:54:18 +0200
Subject: [PATCH] Postgres: Switch the datasource plugin from lib/pq to pgx (#103961)

* Create libpqToPGX feature toggle

* Refactor PostgreSQL datasource to support PGX with feature toggle

- Updated `ProvideService` to accept feature toggles for enabling PGX.
- Modified integration tests to use the new PGX connection method.
- Introduced new functions for handling PGX connections and queries.
- Enhanced TLS configuration handling for PostgreSQL connections.
- Updated existing tests to ensure compatibility with PGX and new connection methods.

* Update PostgreSQL datasource to enhance connection pooling and error handling

- Increased `MaxOpenConns` to 10 in integration tests for improved connection management.
- Refactored connection handling in `newPostgresPGX` to return a connection pool instead of a single connection.
- Updated health check error handling to utilize context and feature toggles for better error reporting.
- Adjusted `DisposePGX` method to close the connection pool properly.
- Enhanced query execution to acquire connections from the pool, ensuring efficient resource usage.

* Cleanup

* Revert unnecessary postgres_test changes

* Rename feature toggle from `libpqToPGX` to `postgresDSUsePGX`

* Add null check to dispose method

* Fix lint issues

* Refactor connection string generation

* Address comment in health check file

* Rename `p` to `pool`

* Refactor executeQueryPGX and split it into multiple functions

* Fix lint issues

* The error message returned from PGX is enough; no need to separate out the error code.

* Move TLS handling to the newPostgresPGX function

* Disable SSL for integration tests

* Use MaxIdleConns option

* Remove old feature toggle

* Rename `generateConnectionConfigPGX` to `generateConnectionStringPGX`

* Add back part of the error messages

* Don't show the max idle connections option when PGX is enabled

* Address comments from Sriram

* Add back Sriram's changes

* PostgreSQL: Rework TLS manager to use temporary files instead (#105330)

* Rework TLS manager to use temporary files instead

* Lint and test fixes

* Update pkg/tsdb/grafana-postgresql-datasource/postgres.go

Co-authored-by: Ivana Huckova <30407135+ivanahuckova@users.noreply.github.com>

* Update betterer

---------

Co-authored-by: Ivana Huckova <30407135+ivanahuckova@users.noreply.github.com>

---------

Co-authored-by: Ivana Huckova <30407135+ivanahuckova@users.noreply.github.com>
---
 .betterer.results                             |  10 +-
 .../src/types/featureToggles.gen.ts           |   4 +
 .../configuration/ConnectionLimits.tsx        |  70 +--
 .../configuration/MaxLifetimeField.tsx        |  42 ++
 .../configuration/MaxOpenConnectionsField.tsx |  45 ++
 packages/grafana-sql/src/index.ts             |   2 +
 .../backendplugin/coreplugin/registry.go      |   2 +-
 pkg/services/featuremgmt/registry.go          |   6 +
 pkg/services/featuremgmt/toggles_gen.csv      |   1 +
 pkg/services/featuremgmt/toggles_gen.go       |   4 +
 pkg/services/featuremgmt/toggles_gen.json     |  12 +
 .../plugins_integration_test.go               |   2 +-
 .../grafana-postgresql-datasource/locker.go   |  85 ---
 .../locker_test.go                            |  63 --
 .../grafana-postgresql-datasource/postgres.go | 216 +++++--
 .../postgres_snapshot_test.go                 |  14 +-
 .../postgres_test.go                          | 129 ++--
 .../grafana-postgresql-datasource/proxy.go    |   8 +
 .../sqleng/handler_checkhealth.go             |  15 +-
 .../sqleng/sql_engine.go                      |   3 +
 .../sqleng/sql_engine_pgx.go                  | 554 ++++++++++++++++++
 .../table/types_datetime.golden.jsonc         |  64 +-
 .../tlsmanager.go                             | 262 +++------
 .../tlsmanager_test.go                        |
367 +++++------- .../configuration/ConfigurationEditor.tsx | 31 +- 25 files changed, 1233 insertions(+), 778 deletions(-) create mode 100644 packages/grafana-sql/src/components/configuration/MaxLifetimeField.tsx create mode 100644 packages/grafana-sql/src/components/configuration/MaxOpenConnectionsField.tsx delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/locker.go delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/locker_test.go create mode 100644 pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine_pgx.go diff --git a/.betterer.results b/.betterer.results index a52202f2626..00f89db007f 100644 --- a/.betterer.results +++ b/.betterer.results @@ -557,9 +557,13 @@ exports[`better eslint`] = { ], "packages/grafana-sql/src/components/configuration/ConnectionLimits.tsx:5381": [ [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "0"], - [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "1"], - [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "2"], - [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "3"] + [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "1"] + ], + "packages/grafana-sql/src/components/configuration/MaxLifetimeField.tsx:5381": [ + [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "0"] + ], + "packages/grafana-sql/src/components/configuration/MaxOpenConnectionsField.tsx:5381": [ + [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "0"] ], "packages/grafana-sql/src/components/configuration/TLSSecretsConfig.tsx:5381": [ [0, 0, 0, "Add noMargin prop to Field components to remove built-in margins. 
Use layout components like Stack or Grid with the gap prop instead for consistent spacing.", "0"], diff --git a/packages/grafana-data/src/types/featureToggles.gen.ts b/packages/grafana-data/src/types/featureToggles.gen.ts index e2a79ca824f..3993cab76bf 100644 --- a/packages/grafana-data/src/types/featureToggles.gen.ts +++ b/packages/grafana-data/src/types/featureToggles.gen.ts @@ -1003,6 +1003,10 @@ export interface FeatureToggles { */ metricsFromProfiles?: boolean; /** + * Enables using PGX instead of libpq for PostgreSQL datasource + */ + postgresDSUsePGX?: boolean; + /** * Enables auto-updating of users installed plugins */ pluginsAutoUpdate?: boolean; diff --git a/packages/grafana-sql/src/components/configuration/ConnectionLimits.tsx b/packages/grafana-sql/src/components/configuration/ConnectionLimits.tsx index 50bc7e82ae7..351523bd223 100644 --- a/packages/grafana-sql/src/components/configuration/ConnectionLimits.tsx +++ b/packages/grafana-sql/src/components/configuration/ConnectionLimits.tsx @@ -5,6 +5,8 @@ import { Field, Icon, InlineLabel, Label, Stack, Switch, Tooltip } from '@grafan import { SQLConnectionLimits, SQLOptions } from '../../types'; +import { MaxLifetimeField } from './MaxLifetimeField'; +import { MaxOpenConnectionsField } from './MaxOpenConnectionsField'; import { NumberInput } from './NumberInput'; interface Props { @@ -84,36 +86,11 @@ export const ConnectionLimits = (props: Props) return ( - - - Max open - - The maximum number of open connections to the database. If Max idle connections is greater - than 0 and the Max open connections is less than Max idle connections, then - Max idle connections will be reduced to match the Max open connections limit. If set - to 0, there is no limit on the number of open connections. - - } - > - - - - - } - > - { - onMaxConnectionsChanged(value); - }} - width={labelWidth} - /> - + (props: Props) )} - - - Max lifetime - - The maximum amount of time in seconds a connection may be reused. If set to 0, connections are - reused forever. - - } - > - - - - - } - > - { - onJSONDataNumberChanged('connMaxLifetime')(value); - }} - width={labelWidth} - /> - + ); }; diff --git a/packages/grafana-sql/src/components/configuration/MaxLifetimeField.tsx b/packages/grafana-sql/src/components/configuration/MaxLifetimeField.tsx new file mode 100644 index 00000000000..5a4e2c93359 --- /dev/null +++ b/packages/grafana-sql/src/components/configuration/MaxLifetimeField.tsx @@ -0,0 +1,42 @@ +import { config } from '@grafana/runtime'; +import { Field, Icon, Label, Stack, Tooltip } from '@grafana/ui'; + +import { SQLOptions } from '../../types'; + +import { NumberInput } from './NumberInput'; + +interface Props { + labelWidth: number; + onMaxLifetimeChanged: (number?: number) => void; + jsonData: SQLOptions; +} +export function MaxLifetimeField({ labelWidth, onMaxLifetimeChanged, jsonData }: Props) { + return ( + + + Max lifetime + + The maximum amount of time in seconds a connection may be reused. If set to 0, connections are reused + forever. 
+ + } + > + + + + + } + > + + + ); +} diff --git a/packages/grafana-sql/src/components/configuration/MaxOpenConnectionsField.tsx b/packages/grafana-sql/src/components/configuration/MaxOpenConnectionsField.tsx new file mode 100644 index 00000000000..7152312882c --- /dev/null +++ b/packages/grafana-sql/src/components/configuration/MaxOpenConnectionsField.tsx @@ -0,0 +1,45 @@ +import { config } from '@grafana/runtime'; +import { Field, Icon, Label, Stack, Tooltip } from '@grafana/ui'; + +import { SQLOptions } from '../../types'; + +import { NumberInput } from './NumberInput'; + +interface Props { + labelWidth: number; + onMaxConnectionsChanged: (number?: number) => void; + jsonData: SQLOptions; +} + +export function MaxOpenConnectionsField({ labelWidth, onMaxConnectionsChanged, jsonData }: Props) { + return ( + + + Max open + + The maximum number of open connections to the database. If Max idle connections is greater than + 0 and the Max open connections is less than Max idle connections, then + Max idle connections will be reduced to match the Max open connections limit. If set to + 0, there is no limit on the number of open connections. + + } + > + + + + + } + > + + + ); +} diff --git a/packages/grafana-sql/src/index.ts b/packages/grafana-sql/src/index.ts index 8b953f68d5a..d81a2bf9adf 100644 --- a/packages/grafana-sql/src/index.ts +++ b/packages/grafana-sql/src/index.ts @@ -14,6 +14,8 @@ export { COMMON_FNS, MACRO_FUNCTIONS } from './constants'; export { SqlDatasource } from './datasource/SqlDatasource'; export { formatSQL } from './utils/formatSQL'; export { ConnectionLimits } from './components/configuration/ConnectionLimits'; +export { MaxLifetimeField } from './components/configuration/MaxLifetimeField'; +export { MaxOpenConnectionsField } from './components/configuration/MaxOpenConnectionsField'; export { Divider } from './components/configuration/Divider'; export { TLSSecretsConfig } from './components/configuration/TLSSecretsConfig'; export { useMigrateDatabaseFields } from './components/configuration/useMigrateDatabaseFields'; diff --git a/pkg/plugins/backendplugin/coreplugin/registry.go b/pkg/plugins/backendplugin/coreplugin/registry.go index cb0dea1bd0a..1d29ae218f7 100644 --- a/pkg/plugins/backendplugin/coreplugin/registry.go +++ b/pkg/plugins/backendplugin/coreplugin/registry.go @@ -237,7 +237,7 @@ func NewPlugin(pluginID string, cfg *setting.Cfg, httpClientProvider *httpclient case Tempo: svc = tempo.ProvideService(httpClientProvider) case PostgreSQL: - svc = postgres.ProvideService(cfg) + svc = postgres.ProvideService(features) case MySQL: svc = mysql.ProvideService() case MSSQL: diff --git a/pkg/services/featuremgmt/registry.go b/pkg/services/featuremgmt/registry.go index 1bdb9bb20d3..e968818bbd2 100644 --- a/pkg/services/featuremgmt/registry.go +++ b/pkg/services/featuremgmt/registry.go @@ -1720,6 +1720,12 @@ var ( Owner: grafanaObservabilityTracesAndProfilingSquad, FrontendOnly: true, }, + { + Name: "postgresDSUsePGX", + Description: "Enables using PGX instead of libpq for PostgreSQL datasource", + Stage: FeatureStageExperimental, + Owner: grafanaOSSBigTent, + }, { Name: "pluginsAutoUpdate", Description: "Enables auto-updating of users installed plugins", diff --git a/pkg/services/featuremgmt/toggles_gen.csv b/pkg/services/featuremgmt/toggles_gen.csv index 5123d14872a..f1b62dd7b94 100644 --- a/pkg/services/featuremgmt/toggles_gen.csv +++ b/pkg/services/featuremgmt/toggles_gen.csv @@ -225,6 +225,7 @@ 
localizationForPlugins,experimental,@grafana/plugins-platform-backend,false,fals unifiedNavbars,GA,@grafana/plugins-platform-backend,false,false,true logsPanelControls,preview,@grafana/observability-logs,false,false,true metricsFromProfiles,experimental,@grafana/observability-traces-and-profiling,false,false,true +postgresDSUsePGX,experimental,@grafana/oss-big-tent,false,false,false pluginsAutoUpdate,experimental,@grafana/plugins-platform-backend,false,false,false multiTenantFrontend,experimental,@grafana/grafana-frontend-platform,false,false,false alertingListViewV2PreviewToggle,privatePreview,@grafana/alerting-squad,false,false,true diff --git a/pkg/services/featuremgmt/toggles_gen.go b/pkg/services/featuremgmt/toggles_gen.go index 83f95f37f1e..9ef9f4cfbeb 100644 --- a/pkg/services/featuremgmt/toggles_gen.go +++ b/pkg/services/featuremgmt/toggles_gen.go @@ -911,6 +911,10 @@ const ( // Enables creating metrics from profiles and storing them as recording rules FlagMetricsFromProfiles = "metricsFromProfiles" + // FlagPostgresDSUsePGX + // Enables using PGX instead of libpq for PostgreSQL datasource + FlagPostgresDSUsePGX = "postgresDSUsePGX" + // FlagPluginsAutoUpdate // Enables auto-updating of users installed plugins FlagPluginsAutoUpdate = "pluginsAutoUpdate" diff --git a/pkg/services/featuremgmt/toggles_gen.json b/pkg/services/featuremgmt/toggles_gen.json index 4702b756557..6cad9239307 100644 --- a/pkg/services/featuremgmt/toggles_gen.json +++ b/pkg/services/featuremgmt/toggles_gen.json @@ -2571,6 +2571,18 @@ "expression": "false" } }, + { + "metadata": { + "name": "postgresDSUsePGX", + "resourceVersion": "1745320933872", + "creationTimestamp": "2025-04-22T11:22:13Z" + }, + "spec": { + "description": "Enables using PGX instead of libpq for PostgreSQL datasource", + "stage": "experimental", + "codeowner": "@grafana/oss-big-tent" + } + }, { "metadata": { "name": "preinstallAutoUpdate", diff --git a/pkg/services/pluginsintegration/plugins_integration_test.go b/pkg/services/pluginsintegration/plugins_integration_test.go index 1dce4a244a2..b3baeef1533 100644 --- a/pkg/services/pluginsintegration/plugins_integration_test.go +++ b/pkg/services/pluginsintegration/plugins_integration_test.go @@ -85,7 +85,7 @@ func TestIntegrationPluginManager(t *testing.T) { pr := prometheus.ProvideService(hcp) tmpo := tempo.ProvideService(hcp) td := testdatasource.ProvideService() - pg := postgres.ProvideService(cfg) + pg := postgres.ProvideService(features) my := mysql.ProvideService() ms := mssql.ProvideService(cfg) db := db.InitTestDB(t, sqlstore.InitTestDBOpt{Cfg: cfg}) diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker.go b/pkg/tsdb/grafana-postgresql-datasource/locker.go deleted file mode 100644 index 796c37c7415..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/locker.go +++ /dev/null @@ -1,85 +0,0 @@ -package postgres - -import ( - "fmt" - "sync" -) - -// locker is a named reader/writer mutual exclusion lock. -// The lock for each particular key can be held by an arbitrary number of readers or a single writer. -type locker struct { - locks map[any]*sync.RWMutex - locksRW *sync.RWMutex -} - -func newLocker() *locker { - return &locker{ - locks: make(map[any]*sync.RWMutex), - locksRW: new(sync.RWMutex), - } -} - -// Lock locks named rw mutex with specified key for writing. -// If the lock with the same key is already locked for reading or writing, -// Lock blocks until the lock is available. 
-func (lkr *locker) Lock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - lk = lkr.newLock(key) - } - lk.Lock() -} - -// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is -// not locked for writing on entry to Unlock. -func (lkr *locker) Unlock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - panic(fmt.Errorf("lock for key '%s' not initialized", key)) - } - lk.Unlock() -} - -// RLock locks named rw mutex with specified key for reading. -// -// It should not be used for recursive read locking for the same key; a blocked Lock -// call excludes new readers from acquiring the lock. See the -// documentation on the golang RWMutex type. -func (lkr *locker) RLock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - lk = lkr.newLock(key) - } - lk.RLock() -} - -// RUnlock undoes a single RLock call for specified key; -// it does not affect other simultaneous readers of locker for specified key. -// It is a run-time error if locker for specified key is not locked for reading -func (lkr *locker) RUnlock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - panic(fmt.Errorf("lock for key '%s' not initialized", key)) - } - lk.RUnlock() -} - -func (lkr *locker) newLock(key any) *sync.RWMutex { - lkr.locksRW.Lock() - defer lkr.locksRW.Unlock() - - if lk, ok := lkr.locks[key]; ok { - return lk - } - lk := new(sync.RWMutex) - lkr.locks[key] = lk - return lk -} - -func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) { - lkr.locksRW.RLock() - defer lkr.locksRW.RUnlock() - - lock, ok := lkr.locks[key] - return lock, ok -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker_test.go b/pkg/tsdb/grafana-postgresql-datasource/locker_test.go deleted file mode 100644 index b1dc64f0351..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/locker_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package postgres - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestIntegrationLocker(t *testing.T) { - if testing.Short() { - t.Skip("Tests with Sleep") - } - const notUpdated = "not_updated" - const atThread1 = "at_thread_1" - const atThread2 = "at_thread_2" - t.Run("Should lock for same keys", func(t *testing.T) { - updated := notUpdated - locker := newLocker() - locker.Lock(1) - var wg sync.WaitGroup - wg.Add(1) - defer func() { - locker.Unlock(1) - wg.Wait() - }() - - go func() { - locker.RLock(1) - defer func() { - locker.RUnlock(1) - wg.Done() - }() - require.Equal(t, atThread1, updated, "Value should be updated in different thread") - updated = atThread2 - }() - time.Sleep(time.Millisecond * 10) - require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") - updated = atThread1 - }) - - t.Run("Should not lock for different keys", func(t *testing.T) { - updated := notUpdated - locker := newLocker() - locker.Lock(1) - defer locker.Unlock(1) - var wg sync.WaitGroup - wg.Add(1) - go func() { - locker.RLock(2) - defer func() { - locker.RUnlock(2) - wg.Done() - }() - require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") - updated = atThread2 - }() - wg.Wait() - require.Equal(t, atThread2, updated, "Value should be updated in different thread") - updated = atThread1 - }) -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres.go b/pkg/tsdb/grafana-postgresql-datasource/postgres.go index fde06c08172..cdeda5c7105 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres.go @@ -15,27 +15,30 @@ 
import ( "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/grafana/grafana/pkg/services/featuremgmt" + "github.com/jackc/pgx/v5/pgxpool" "github.com/lib/pq" "github.com/grafana/grafana-plugin-sdk-go/backend/log" - "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng" ) -func ProvideService(cfg *setting.Cfg) *Service { +func ProvideService(features featuremgmt.FeatureToggles) *Service { logger := backend.NewLoggerWith("logger", "tsdb.postgres") s := &Service{ - tlsManager: newTLSManager(logger, cfg.DataPath), + tlsManager: newTLSManager(logger), logger: logger, + features: features, } s.im = datasource.NewInstanceManager(s.newInstanceSettings()) return s } type Service struct { - tlsManager tlsSettingsProvider + tlsManager *tlsManager im instancemgmt.InstanceManager logger log.Logger + features featuremgmt.FeatureToggles } func (s *Service) getDSInfo(ctx context.Context, pluginCtx backend.PluginContext) (*sqleng.DataSourceHandler, error) { @@ -52,6 +55,11 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) if err != nil { return nil, err } + + if s.features.IsEnabled(ctx, featuremgmt.FlagPostgresDSUsePGX) { + return dsInfo.QueryDataPGX(ctx, req) + } + return dsInfo.QueryData(ctx, req) } @@ -93,6 +101,13 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns) db.SetConnMaxLifetime(time.Duration(config.DSInfo.JsonData.ConnMaxLifetime) * time.Second) + // We need to ping the database to ensure that the connection is valid and the temporary files are not deleted + // before the connection is used. + if err := db.Ping(); err != nil { + logger.Error("Failed to ping Postgres database", "error", err) + return nil, nil, backend.DownstreamError(fmt.Errorf("failed to ping Postgres database: %w", err)) + } + handler, err := sqleng.NewQueryDataHandler(userFacingDefaultError, db, config, &queryResultTransformer, newPostgresMacroEngine(dsInfo.JsonData.Timescaledb), logger) if err != nil { @@ -104,6 +119,62 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in return db, handler, nil } +func newPostgresPGX(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*pgxpool.Pool, *sqleng.DataSourceHandler, error) { + pgxConf, err := pgxpool.ParseConfig(cnnstr) + if err != nil { + logger.Error("postgres config creation failed", "error", err) + return nil, nil, fmt.Errorf("postgres config creation failed") + } + + proxyClient, err := settings.ProxyClient(ctx) + if err != nil { + logger.Error("postgres proxy creation failed", "error", err) + return nil, nil, fmt.Errorf("postgres proxy creation failed") + } + + if proxyClient.SecureSocksProxyEnabled() { + dialer, err := proxyClient.NewSecureSocksProxyContextDialer() + if err != nil { + logger.Error("postgres proxy creation failed", "error", err) + return nil, nil, fmt.Errorf("postgres proxy creation failed") + } + + pgxConf.ConnConfig.DialFunc = newPgxDialFunc(dialer) + } + + // by default pgx resolves hostnames to ip addresses. we must avoid this. 
+ // (certain socks-proxy related functionality relies on the hostname being preserved) + pgxConf.ConnConfig.LookupFunc = func(_ context.Context, host string) ([]string, error) { + return []string{host}, nil + } + + config := sqleng.DataPluginConfiguration{ + DSInfo: dsInfo, + MetricColumnTypes: []string{"unknown", "text", "varchar", "char", "bpchar"}, + RowLimit: rowLimit, + } + + queryResultTransformer := postgresQueryResultTransformer{} + pgxConf.MaxConnLifetime = time.Duration(config.DSInfo.JsonData.ConnMaxLifetime) * time.Second + pgxConf.MaxConns = int32(config.DSInfo.JsonData.MaxOpenConns) + + p, err := pgxpool.NewWithConfig(ctx, pgxConf) + if err != nil { + logger.Error("Failed connecting to Postgres", "err", err) + return nil, nil, err + } + + handler, err := sqleng.NewQueryDataHandlerPGX(userFacingDefaultError, p, config, &queryResultTransformer, newPostgresMacroEngine(dsInfo.JsonData.Timescaledb), + logger) + if err != nil { + logger.Error("Failed connecting to Postgres", "err", err) + return nil, nil, err + } + + logger.Debug("Successfully connected to Postgres") + return p, handler, nil +} + func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { logger := s.logger return func(ctx context.Context, settings backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { @@ -143,7 +214,16 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { DecryptedSecureJSONData: settings.DecryptedSecureJSONData, } - cnnstr, err := s.generateConnectionString(dsInfo) + tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo) + if err != nil { + return "", err + } + + // Ensure cleanupCertFiles is called after the connection is opened + defer s.tlsManager.cleanupCertFiles(tlsSettings) + + isPGX := s.features.IsEnabled(ctx, featuremgmt.FlagPostgresDSUsePGX) + cnnstr, err := s.generateConnectionString(dsInfo, tlsSettings, isPGX) if err != nil { return nil, err } @@ -153,7 +233,12 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { return nil, err } - _, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) + var handler instancemgmt.Instance + if isPGX { + _, handler, err = newPostgresPGX(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) + } else { + _, handler, err = newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) + } if err != nil { logger.Error("Failed connecting to Postgres", "err", err) @@ -170,65 +255,100 @@ func escape(input string) string { return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) } -func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) { - logger := s.logger - var host string - var port int +type connectionParams struct { + host string + port int + user string + password string + database string +} + +func parseConnectionParams(dsInfo sqleng.DataSourceInfo, logger log.Logger) (connectionParams, error) { + var params connectionParams + var err error + if strings.HasPrefix(dsInfo.URL, "/") { - host = dsInfo.URL + params.host = dsInfo.URL logger.Debug("Generating connection string with Unix socket specifier", "address", dsInfo.URL) } else { - index := strings.LastIndex(dsInfo.URL, ":") - v6Index := strings.Index(dsInfo.URL, "]") - sp := strings.SplitN(dsInfo.URL, ":", 2) - host = sp[0] - if v6Index == -1 { - if len(sp) > 1 { - var err error - port, err = strconv.Atoi(sp[1]) - if err != nil { - logger.Debug("Error parsing the IPv4 
address", "address", dsInfo.URL) - return "", sqleng.ErrParsingPostgresURL - } - logger.Debug("Generating IPv4 connection string with network host/port pair", "host", host, "port", port, "address", dsInfo.URL) - } else { - logger.Debug("Generating IPv4 connection string with network host", "host", host, "address", dsInfo.URL) - } - } else { - if index == v6Index+1 { - host = dsInfo.URL[1 : index-1] - var err error - port, err = strconv.Atoi(dsInfo.URL[index+1:]) - if err != nil { - logger.Debug("Error parsing the IPv6 address", "address", dsInfo.URL) - return "", sqleng.ErrParsingPostgresURL - } - logger.Debug("Generating IPv6 connection string with network host/port pair", "host", host, "port", port, "address", dsInfo.URL) - } else { - host = dsInfo.URL[1 : len(dsInfo.URL)-1] - logger.Debug("Generating IPv6 connection string with network host", "host", host, "address", dsInfo.URL) - } + params.host, params.port, err = parseNetworkAddress(dsInfo.URL, logger) + if err != nil { + return connectionParams{}, err } } - connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", - escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database)) - if port > 0 { - connStr += fmt.Sprintf(" port=%d", port) + params.user = dsInfo.User + params.password = dsInfo.DecryptedSecureJSONData["password"] + params.database = dsInfo.Database + + return params, nil +} + +func parseNetworkAddress(url string, logger log.Logger) (string, int, error) { + index := strings.LastIndex(url, ":") + v6Index := strings.Index(url, "]") + sp := strings.SplitN(url, ":", 2) + host := sp[0] + port := 0 + + if v6Index == -1 { + if len(sp) > 1 { + var err error + port, err = strconv.Atoi(sp[1]) + if err != nil { + logger.Debug("Error parsing the IPv4 address", "address", url) + return "", 0, sqleng.ErrParsingPostgresURL + } + logger.Debug("Generating IPv4 connection string with network host/port pair", "host", host, "port", port, "address", url) + } else { + logger.Debug("Generating IPv4 connection string with network host", "host", host, "address", url) + } + } else { + if index == v6Index+1 { + host = url[1 : index-1] + var err error + port, err = strconv.Atoi(url[index+1:]) + if err != nil { + logger.Debug("Error parsing the IPv6 address", "address", url) + return "", 0, sqleng.ErrParsingPostgresURL + } + logger.Debug("Generating IPv6 connection string with network host/port pair", "host", host, "port", port, "address", url) + } else { + host = url[1 : len(url)-1] + logger.Debug("Generating IPv6 connection string with network host", "host", host, "address", url) + } } - tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo) + return host, port, nil +} + +func buildBaseConnectionString(params connectionParams) string { + connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", + escape(params.user), escape(params.password), escape(params.host), escape(params.database)) + if params.port > 0 { + connStr += fmt.Sprintf(" port=%d", params.port) + } + return connStr +} + +func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo, tlsSettings tlsSettings, isPGX bool) (string, error) { + logger := s.logger + + params, err := parseConnectionParams(dsInfo, logger) if err != nil { return "", err } + connStr := buildBaseConnectionString(params) + connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode)) // there is an issue with the lib/pq module, the `verify-ca` tls mode // does not work correctly. 
( see https://github.com/lib/pq/issues/1106 ) // to workaround the problem, if the `verify-ca` mode is chosen, // we disable sslsni. - if tlsSettings.Mode == "verify-ca" { + if tlsSettings.Mode == "verify-ca" && !isPGX { + logger.Debug("Disabling sslsni for verify-ca mode") connStr += " sslsni=0" } @@ -262,7 +382,7 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque if err != nil { return sqleng.ErrToHealthCheckResult(err) } - return dsHandler.CheckHealth(ctx, req) + return dsHandler.CheckHealth(ctx, req, s.features) } func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter { diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go index e0d9e307ae0..d558f79f3a3 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go @@ -149,10 +149,11 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { } jsonData := sqleng.JsonData{ - MaxOpenConns: 0, + MaxOpenConns: 10, MaxIdleConns: 2, ConnMaxLifetime: 14400, Timescaledb: false, + Mode: "disable", ConfigurationMethod: "file-path", } @@ -165,13 +166,12 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { cnnstr := getCnnStr() - db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + p, handler, err := newPostgresPGX(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) t.Cleanup((func() { - _, err := db.Exec("DROP TABLE tbl") - require.NoError(t, err) - err = db.Close() + _, err := p.Exec(context.Background(), "DROP TABLE tbl") require.NoError(t, err) + p.Close() })) require.NoError(t, err) @@ -181,12 +181,12 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { rawSQL, sql := readSqlFile(sqlFilePath) - _, err = db.Exec(sql) + _, err = p.Exec(context.Background(), sql) require.NoError(t, err) query := makeQuery(rawSQL, test.format) - result, err := handler.QueryData(context.Background(), &query) + result, err := handler.QueryDataPGX(context.Background(), &query) require.Len(t, result.Responses, 1) response, found := result.Responses["A"] require.True(t, found) diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go index 47eee7b1a6b..ab81a2082d7 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng" _ "github.com/lib/pq" @@ -25,8 +24,6 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() testCases := []struct { desc string @@ -147,8 +144,7 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { for _, tt := range testCases { t.Run(tt.desc, func(t *testing.T) { svc := Service{ - tlsManager: &tlsTestManager{settings: tt.tlsSettings}, - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), + logger: backend.NewLoggerWith("logger", "tsdb.postgres"), } ds := sqleng.DataSourceInfo{ @@ -159,7 +155,7 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { UID: tt.uid, } - connStr, err := 
svc.generateConnectionString(ds) + connStr, err := svc.generateConnectionString(ds, tt.tlsSettings, false) if tt.expErr == "" { require.NoError(t, err, tt.desc) @@ -201,10 +197,11 @@ func TestIntegrationPostgres(t *testing.T) { } jsonData := sqleng.JsonData{ - MaxOpenConns: 0, + MaxOpenConns: 10, MaxIdleConns: 2, ConnMaxLifetime: 14400, Timescaledb: false, + Mode: "disable", ConfigurationMethod: "file-path", } @@ -217,7 +214,7 @@ func TestIntegrationPostgres(t *testing.T) { cnnstr := postgresTestDBConnString() - db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + p, exe, err := newPostgresPGX(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err) @@ -250,7 +247,7 @@ func TestIntegrationPostgres(t *testing.T) { c16_smallint smallint ); ` - _, err := db.Exec(sql) + _, err := p.Exec(context.Background(), sql) require.NoError(t, err) sql = ` @@ -263,7 +260,7 @@ func TestIntegrationPostgres(t *testing.T) { null ); ` - _, err = db.Exec(sql) + _, err = p.Exec(context.Background(), sql) require.NoError(t, err) t.Run("When doing a table query should map Postgres column types to Go types", func(t *testing.T) { @@ -278,7 +275,7 @@ func TestIntegrationPostgres(t *testing.T) { }, }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -306,9 +303,9 @@ func TestIntegrationPostgres(t *testing.T) { require.True(t, ok) _, ok = frames[0].Fields[12].At(0).(*time.Time) require.True(t, ok) - _, ok = frames[0].Fields[13].At(0).(*time.Time) + _, ok = frames[0].Fields[13].At(0).(*string) require.True(t, ok) - _, ok = frames[0].Fields[14].At(0).(*time.Time) + _, ok = frames[0].Fields[14].At(0).(*string) require.True(t, ok) _, ok = frames[0].Fields[15].At(0).(*time.Time) require.True(t, ok) @@ -326,7 +323,7 @@ func TestIntegrationPostgres(t *testing.T) { ) ` - _, err := db.Exec(sql) + _, err := p.Exec(context.Background(), sql) require.NoError(t, err) type metric struct { @@ -353,7 +350,7 @@ func TestIntegrationPostgres(t *testing.T) { } for _, m := range series { - _, err := db.Exec(`INSERT INTO metric ("time", value) VALUES ($1, $2)`, m.Time.UTC(), m.Value) + _, err := p.Exec(context.Background(), `INSERT INTO metric ("time", value) VALUES ($1, $2)`, m.Time.UTC(), m.Value) require.NoError(t, err) } @@ -370,7 +367,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -426,7 +423,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] frames := queryResult.Frames @@ -454,7 +451,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -508,7 +505,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := 
exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -534,7 +531,7 @@ func TestIntegrationPostgres(t *testing.T) { } for _, m := range series { - _, err := db.Exec(`INSERT INTO metric ("time", value) VALUES ($1, $2)`, m.Time.UTC(), m.Value) + _, err := p.Exec(context.Background(), `INSERT INTO metric ("time", value) VALUES ($1, $2)`, m.Time.UTC(), m.Value) require.NoError(t, err) } @@ -555,7 +552,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -590,7 +587,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -618,10 +615,10 @@ func TestIntegrationPostgres(t *testing.T) { ValueTwo int64 } - _, err := db.Exec("DROP TABLE IF EXISTS metric_values") + _, err := p.Exec(context.Background(), "DROP TABLE IF EXISTS metric_values") require.NoError(t, err) - _, err = db.Exec(`CREATE TABLE metric_values ( + _, err = p.Exec(context.Background(), `CREATE TABLE metric_values ( "time" TIMESTAMP NULL, "timeInt64" BIGINT NOT NULL, "timeInt64Nullable" BIGINT NULL, "timeFloat64" DOUBLE PRECISION NOT NULL, "timeFloat64Nullable" DOUBLE PRECISION NULL, @@ -674,7 +671,7 @@ func TestIntegrationPostgres(t *testing.T) { // _, err = session.InsertMulti(series) for _, m := range series { - _, err := db.Exec(`INSERT INTO "metric_values" ( + _, err := p.Exec(context.Background(), `INSERT INTO "metric_values" ( time, "timeInt64", "timeInt64Nullable", "timeFloat64", "timeFloat64Nullable", @@ -707,7 +704,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -731,7 +728,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -755,7 +752,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -779,7 +776,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -803,7 +800,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -827,7 +824,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) 
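These integration tests exercise the handler that `newPostgresPGX` builds on top of a `pgxpool.Pool`. As a minimal standalone sketch of that pool setup, assuming a local test database and placeholder credentials (the DSN values here are illustrative, not part of this patch):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

func main() {
	ctx := context.Background()

	// Keyword/value DSN in the same style generateConnectionString produces.
	// Host, credentials and database name are placeholders.
	cfg, err := pgxpool.ParseConfig(
		"user='grafana' password='secret' host='localhost' port=5432 dbname='grafanatest' sslmode='disable'")
	if err != nil {
		panic(err)
	}

	// Pool limits corresponding to the MaxOpenConns / ConnMaxLifetime JSON options.
	cfg.MaxConns = 10
	cfg.MaxConnLifetime = 14400 * time.Second

	pool, err := pgxpool.NewWithConfig(ctx, cfg)
	if err != nil {
		panic(err)
	}
	defer pool.Close()

	// PingPGX in this patch is a thin wrapper over exactly this call.
	if err := pool.Ping(ctx); err != nil {
		panic(err)
	}

	var now time.Time
	if err := pool.QueryRow(ctx, "SELECT now()").Scan(&now); err != nil {
		panic(err)
	}
	fmt.Println("connected, server time:", now)
}
```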
require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -851,7 +848,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -876,7 +873,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -900,7 +897,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -925,7 +922,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -957,7 +954,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -992,7 +989,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1011,9 +1008,9 @@ func TestIntegrationPostgres(t *testing.T) { Tags string } - _, err := db.Exec("DROP TABLE IF EXISTS event") + _, err := p.Exec(context.Background(), "DROP TABLE IF EXISTS event") require.NoError(t, err) - _, err = db.Exec(`CREATE TABLE event (time_sec BIGINT NULL, description VARCHAR(255) NULL, tags VARCHAR(255) NULL)`) + _, err = p.Exec(context.Background(), `CREATE TABLE event (time_sec BIGINT NULL, description VARCHAR(255) NULL, tags VARCHAR(255) NULL)`) require.NoError(t, err) events := []*event{} @@ -1031,7 +1028,7 @@ func TestIntegrationPostgres(t *testing.T) { } for _, e := range events { - _, err := db.Exec("INSERT INTO event (time_sec, description, tags) VALUES ($1, $2, $3)", e.TimeSec, e.Description, e.Tags) + _, err := p.Exec(context.Background(), "INSERT INTO event (time_sec, description, tags) VALUES ($1, $2, $3)", e.TimeSec, e.Description, e.Tags) require.NoError(t, err) } @@ -1052,7 +1049,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["Deploys"] @@ -1079,7 +1076,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["Tickets"] @@ -1102,7 +1099,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1127,7 +1124,7 
@@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1152,7 +1149,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1178,7 +1175,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1204,7 +1201,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1230,7 +1227,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1256,7 +1253,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1270,8 +1267,20 @@ func TestIntegrationPostgres(t *testing.T) { }) t.Run("When row limit set to 1", func(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{} - _, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + jsonData := sqleng.JsonData{ + MaxOpenConns: 10, + MaxIdleConns: 2, + ConnMaxLifetime: 14400, + Timescaledb: false, + Mode: "disable", + ConfigurationMethod: "file-path", + } + + dsInfo := sqleng.DataSourceInfo{ + JsonData: jsonData, + DecryptedSecureJSONData: map[string]string{}, + } + _, handler, err := newPostgresPGX(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err) @@ -1292,7 +1301,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := handler.QueryData(context.Background(), query) + resp, err := handler.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1322,7 +1331,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := handler.QueryData(context.Background(), query) + resp, err := handler.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] require.NoError(t, queryResult.Error) @@ -1338,9 +1347,9 @@ func TestIntegrationPostgres(t *testing.T) { }) t.Run("Given an empty table", func(t *testing.T) { - _, err := db.Exec("DROP TABLE IF EXISTS empty_obj") + _, err := p.Exec(context.Background(), "DROP TABLE IF EXISTS empty_obj") require.NoError(t, err) - _, err = db.Exec("CREATE TABLE empty_obj (empty_key VARCHAR(255) NULL, empty_val BIGINT NULL)") + _, err = p.Exec(context.Background(), "CREATE TABLE empty_obj (empty_key VARCHAR(255) NULL, empty_val BIGINT NULL)") require.NoError(t, err) 
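All of the test queries above drive the handler through the same minimal per-query JSON model, with only `rawSql` and `format` set. A hedged sketch of that request shape, assuming a `handler` obtained from `newPostgresPGX`; the helper name and SQL text are illustrative:

```go
package example

import (
	"context"
	"time"

	"github.com/grafana/grafana-plugin-sdk-go/backend"

	"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng"
)

// runTimeSeriesQuery mirrors the request shape the integration tests build:
// per-query settings travel as JSON carrying just rawSql and format.
func runTimeSeriesQuery(handler *sqleng.DataSourceHandler, from time.Time) (*backend.QueryDataResponse, error) {
	req := &backend.QueryDataRequest{
		Queries: []backend.DataQuery{{
			RefID: "A",
			JSON: []byte(`{
				"rawSql": "SELECT time, value FROM metric ORDER BY time",
				"format": "time_series"
			}`),
			TimeRange: backend.TimeRange{From: from, To: from.Add(30 * time.Minute)},
		}},
	}
	return handler.QueryDataPGX(context.Background(), req)
}
```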
t.Run("When no rows are returned, should return an empty frame", func(t *testing.T) { @@ -1360,7 +1369,7 @@ func TestIntegrationPostgres(t *testing.T) { }, } - resp, err := exe.QueryData(context.Background(), query) + resp, err := exe.QueryDataPGX(context.Background(), query) require.NoError(t, err) queryResult := resp.Responses["A"] @@ -1386,14 +1395,6 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim return timeRange } -type tlsTestManager struct { - settings tlsSettings -} - -func (m *tlsTestManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { - return m.settings, nil -} - func isTestDbPostgres() bool { if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { return db == "postgres" diff --git a/pkg/tsdb/grafana-postgresql-datasource/proxy.go b/pkg/tsdb/grafana-postgresql-datasource/proxy.go index d06d8b68152..dca66df38e9 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/proxy.go +++ b/pkg/tsdb/grafana-postgresql-datasource/proxy.go @@ -33,3 +33,11 @@ func (p *postgresProxyDialer) DialTimeout(network, address string, timeout time. return p.d.(proxy.ContextDialer).DialContext(ctx, network, address) } + +type PgxDialFunc = func(ctx context.Context, network string, address string) (net.Conn, error) + +func newPgxDialFunc(dialer proxy.Dialer) PgxDialFunc { + return func(ctx context.Context, network string, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/sqleng/handler_checkhealth.go b/pkg/tsdb/grafana-postgresql-datasource/sqleng/handler_checkhealth.go index 25b4ee0fdf8..99d36aaaa85 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/sqleng/handler_checkhealth.go +++ b/pkg/tsdb/grafana-postgresql-datasource/sqleng/handler_checkhealth.go @@ -10,11 +10,17 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/lib/pq" ) -func (e *DataSourceHandler) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { - err := e.Ping() +func (e *DataSourceHandler) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest, features featuremgmt.FeatureToggles) (*backend.CheckHealthResult, error) { + var err error + if features.IsEnabled(ctx, featuremgmt.FlagPostgresDSUsePGX) { + err = e.PingPGX(ctx) + } else { + err = e.Ping() + } if err != nil { logCheckHealthError(ctx, e.dsInfo, err) if strings.EqualFold(req.PluginContext.User.Role, "Admin") { @@ -63,6 +69,7 @@ func ErrToHealthCheckResult(err error) (*backend.CheckHealthResult, error) { res.Message += fmt.Sprintf(". 
Error message: %s", errMessage) } } + if errors.Is(err, pq.ErrSSLNotSupported) { res.Message = "SSL error: Failed to connect to the server" } @@ -125,10 +132,10 @@ func logCheckHealthError(ctx context.Context, dsInfo DataSourceInfo, err error) "config_tls_client_cert_length": len(dsInfo.DecryptedSecureJSONData["tlsClientCert"]), "config_tls_client_key_length": len(dsInfo.DecryptedSecureJSONData["tlsClientKey"]), } - configSummaryJson, marshalError := json.Marshal(configSummary) + configSummaryJSON, marshalError := json.Marshal(configSummary) if marshalError != nil { logger.Error("Check health failed", "error", err, "message_type", "ds_config_health_check_error") return } - logger.Error("Check health failed", "error", err, "message_type", "ds_config_health_check_error_detailed", "details", string(configSummaryJson)) + logger.Error("Check health failed", "error", err, "message_type", "ds_config_health_check_error_detailed", "details", string(configSummaryJSON)) } diff --git a/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine.go b/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine.go index 7712fe7c3d5..8f354497689 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine.go +++ b/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine.go @@ -19,6 +19,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/jackc/pgx/v5/pgxpool" ) // MetaKeyExecutedQueryString is the key where the executed query should get stored @@ -88,6 +89,7 @@ type DataSourceHandler struct { dsInfo DataSourceInfo rowLimit int64 userError string + pool *pgxpool.Pool } type QueryJson struct { @@ -489,6 +491,7 @@ type dataQueryModel struct { Interval time.Duration columnNames []string columnTypes []*sql.ColumnType + columnTypesPGX []string timeIndex int timeEndIndex int metricIndex int diff --git a/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine_pgx.go b/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine_pgx.go new file mode 100644 index 00000000000..df182a3119f --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/sqleng/sql_engine_pgx.go @@ -0,0 +1,554 @@ +package sqleng + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "runtime/debug" + "strings" + "sync" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana-plugin-sdk-go/data" + "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" +) + +func NewQueryDataHandlerPGX(userFacingDefaultError string, p *pgxpool.Pool, config DataPluginConfiguration, queryResultTransformer SqlQueryResultTransformer, + macroEngine SQLMacroEngine, log log.Logger) (*DataSourceHandler, error) { + queryDataHandler := DataSourceHandler{ + queryResultTransformer: queryResultTransformer, + macroEngine: macroEngine, + timeColumnNames: []string{"time"}, + log: log, + dsInfo: config.DSInfo, + rowLimit: config.RowLimit, + userError: userFacingDefaultError, + } + + if len(config.TimeColumnNames) > 0 { + queryDataHandler.timeColumnNames = config.TimeColumnNames + } + + if len(config.MetricColumnTypes) > 0 { + queryDataHandler.metricColumnTypes = config.MetricColumnTypes + } + + queryDataHandler.pool = p + return &queryDataHandler, nil +} + +func (e *DataSourceHandler) DisposePGX() { + e.log.Debug("Disposing DB...") + + if 
e.pool != nil { + e.pool.Close() + } + + e.log.Debug("DB disposed") +} + +func (e *DataSourceHandler) PingPGX(ctx context.Context) error { + return e.pool.Ping(ctx) +} + +func (e *DataSourceHandler) QueryDataPGX(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { + result := backend.NewQueryDataResponse() + ch := make(chan DBDataResponse, len(req.Queries)) + var wg sync.WaitGroup + // Execute each query in a goroutine and wait for them to finish afterwards + for _, query := range req.Queries { + queryjson := QueryJson{ + Fill: false, + Format: "time_series", + } + err := json.Unmarshal(query.JSON, &queryjson) + if err != nil { + return nil, fmt.Errorf("error unmarshal query json: %w", err) + } + + // the fill-params are only stored inside this function, during query-interpolation. we do not support + // sending them in "from the outside" + if queryjson.Fill || queryjson.FillInterval != 0.0 || queryjson.FillMode != "" || queryjson.FillValue != 0.0 { + return nil, fmt.Errorf("query fill-parameters not supported") + } + + if queryjson.RawSql == "" { + continue + } + + wg.Add(1) + go e.executeQueryPGX(ctx, query, &wg, ch, queryjson) + } + + wg.Wait() + + // Read results from channels + close(ch) + result.Responses = make(map[string]backend.DataResponse) + for queryResult := range ch { + result.Responses[queryResult.refID] = queryResult.dataResponse + } + + return result, nil +} + +func (e *DataSourceHandler) handleQueryError(frameErr string, err error, query string, source backend.ErrorSource, ch chan DBDataResponse, queryResult DBDataResponse) { + var emptyFrame data.Frame + emptyFrame.SetMeta(&data.FrameMeta{ExecutedQueryString: query}) + if backend.IsDownstreamError(err) { + source = backend.ErrorSourceDownstream + } + queryResult.dataResponse.Error = fmt.Errorf("%s: %w", frameErr, err) + queryResult.dataResponse.ErrorSource = source + queryResult.dataResponse.Frames = data.Frames{&emptyFrame} + ch <- queryResult +} + +func (e *DataSourceHandler) handlePanic(logger log.Logger, queryResult *DBDataResponse, ch chan DBDataResponse) { + if r := recover(); r != nil { + logger.Error("ExecuteQuery panic", "error", r, "stack", string(debug.Stack())) + if theErr, ok := r.(error); ok { + queryResult.dataResponse.Error = theErr + queryResult.dataResponse.ErrorSource = backend.ErrorSourcePlugin + } else if theErrString, ok := r.(string); ok { + queryResult.dataResponse.Error = errors.New(theErrString) + queryResult.dataResponse.ErrorSource = backend.ErrorSourcePlugin + } else { + queryResult.dataResponse.Error = fmt.Errorf("unexpected error - %s", e.userError) + queryResult.dataResponse.ErrorSource = backend.ErrorSourceDownstream + } + ch <- *queryResult + } +} + +func (e *DataSourceHandler) execQuery(ctx context.Context, query string, logger log.Logger) ([]*pgconn.Result, error) { + c, err := e.pool.Acquire(ctx) + if err != nil { + return nil, fmt.Errorf("failed to acquire connection: %w", err) + } + defer c.Release() + + mrr := c.Conn().PgConn().Exec(ctx, query) + defer func() { + if err := mrr.Close(); err != nil { + logger.Warn("Failed to close multi-result reader", "error", err) + } + }() + return mrr.ReadAll() +} + +func (e *DataSourceHandler) executeQueryPGX(queryContext context.Context, query backend.DataQuery, wg *sync.WaitGroup, + ch chan DBDataResponse, queryJSON QueryJson) { + defer wg.Done() + queryResult := DBDataResponse{ + dataResponse: backend.DataResponse{}, + refID: query.RefID, + } + + logger := e.log.FromContext(queryContext) + defer 
e.handlePanic(logger, &queryResult, ch) + + if queryJSON.RawSql == "" { + panic("Query model property rawSql should not be empty at this point") + } + + // global substitutions + interpolatedQuery := Interpolate(query, query.TimeRange, e.dsInfo.JsonData.TimeInterval, queryJSON.RawSql) + + // data source specific substitutions + interpolatedQuery, err := e.macroEngine.Interpolate(&query, query.TimeRange, interpolatedQuery) + if err != nil { + e.handleQueryError("interpolation failed", e.TransformQueryError(logger, err), interpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + results, err := e.execQuery(queryContext, interpolatedQuery, logger) + if err != nil { + e.handleQueryError("db query error", e.TransformQueryError(logger, err), interpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + qm, err := e.newProcessCfgPGX(queryContext, query, results, interpolatedQuery) + if err != nil { + e.handleQueryError("failed to get configurations", err, interpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + frame, err := convertResultsToFrame(results, e.rowLimit) + if err != nil { + e.handleQueryError("convert frame from rows error", err, interpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + e.processFrame(frame, qm, queryResult, ch, logger) +} + +func (e *DataSourceHandler) processFrame(frame *data.Frame, qm *dataQueryModel, queryResult DBDataResponse, ch chan DBDataResponse, logger log.Logger) { + if frame.Meta == nil { + frame.Meta = &data.FrameMeta{} + } + frame.Meta.ExecutedQueryString = qm.InterpolatedQuery + + // If no rows were returned, clear any previously set `Fields` with a single empty `data.Field` slice. + // Then assign `queryResult.dataResponse.Frames` the current single frame with that single empty Field. + // This ensures 1) our visualization doesn't display unwanted empty fields, and also that 2) + // additionally-needed frame data stays intact and is correctly passed to our visualization. + if frame.Rows() == 0 { + frame.Fields = []*data.Field{} + queryResult.dataResponse.Frames = data.Frames{frame} + ch <- queryResult + return + } + + if err := convertSQLTimeColumnsToEpochMS(frame, qm); err != nil { + e.handleQueryError("converting time columns failed", err, qm.InterpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + if qm.Format == dataQueryFormatSeries { + // time series has to have time column + if qm.timeIndex == -1 { + e.handleQueryError("db has no time column", errors.New("time column is missing; make sure your data includes a time column for time series format or switch to a table format that doesn't require it"), qm.InterpolatedQuery, backend.ErrorSourceDownstream, ch, queryResult) + return + } + + // Make sure to name the time field 'Time' to be backward compatible with Grafana pre-v8.
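+ // data.TimeSeriesTimeFieldName is the SDK constant "Time"; renaming the column here keeps
+ // series produced by the pgx path consistent with the lib/pq path and with pre-v8 dashboards.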
+ frame.Fields[qm.timeIndex].Name = data.TimeSeriesTimeFieldName + + for i := range qm.columnNames { + if i == qm.timeIndex || i == qm.metricIndex { + continue + } + + if t := frame.Fields[i].Type(); t == data.FieldTypeString || t == data.FieldTypeNullableString { + continue + } + + var err error + if frame, err = convertSQLValueColumnToFloat(frame, i); err != nil { + e.handleQueryError("convert value to float failed", err, qm.InterpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + } + + tsSchema := frame.TimeSeriesSchema() + if tsSchema.Type == data.TimeSeriesTypeLong { + var err error + originalData := frame + frame, err = data.LongToWide(frame, qm.FillMissing) + if err != nil { + e.handleQueryError("failed to convert long to wide series when converting from dataframe", err, qm.InterpolatedQuery, backend.ErrorSourcePlugin, ch, queryResult) + return + } + + // Before 8x, a special metric column was used to name time series. The LongToWide transforms that into a metric label on the value field. + // But that makes the series name contain both the value column name AND the metric name. So here we remove the metric label and move it to the + // field name to get the same naming for the series as pre-v8. + if len(originalData.Fields) == 3 { + for _, field := range frame.Fields { + if len(field.Labels) == 1 { // 7x only supported one label + name, ok := field.Labels["metric"] + if ok { + field.Name = name + field.Labels = nil + } + } + } + } + } + if qm.FillMissing != nil { + // we align the start-time + startUnixTime := qm.TimeRange.From.Unix() / int64(qm.Interval.Seconds()) * int64(qm.Interval.Seconds()) + alignedTimeRange := backend.TimeRange{ + From: time.Unix(startUnixTime, 0), + To: qm.TimeRange.To, + } + + var err error + frame, err = sqlutil.ResampleWideFrame(frame, qm.FillMissing, alignedTimeRange, qm.Interval) + if err != nil { + logger.Error("Failed to resample dataframe", "err", err) + frame.AppendNotices(data.Notice{Text: "Failed to resample dataframe", Severity: data.NoticeSeverityWarning}) + return + } + } + } + + queryResult.dataResponse.Frames = data.Frames{frame} + ch <- queryResult +} + +func (e *DataSourceHandler) newProcessCfgPGX(queryContext context.Context, query backend.DataQuery, + results []*pgconn.Result, interpolatedQuery string) (*dataQueryModel, error) { + columnNames := []string{} + columnTypesPGX := []string{} + + // The results will contain column information in the metadata + for _, result := range results { + // Get column names from the result metadata + for _, field := range result.FieldDescriptions { + columnNames = append(columnNames, field.Name) + pqtype, ok := pgtype.NewMap().TypeForOID(field.DataTypeOID) + if !ok { + // Handle special cases for field types + switch field.DataTypeOID { + case pgtype.TimetzOID: + columnTypesPGX = append(columnTypesPGX, "timetz") + case 790: + columnTypesPGX = append(columnTypesPGX, "money") + default: + return nil, fmt.Errorf("unknown data type oid: %d", field.DataTypeOID) + } + } else { + columnTypesPGX = append(columnTypesPGX, pqtype.Name) + } + } + } + + qm := &dataQueryModel{ + columnTypesPGX: columnTypesPGX, + columnNames: columnNames, + timeIndex: -1, + timeEndIndex: -1, + metricIndex: -1, + metricPrefix: false, + queryContext: queryContext, + } + + queryJSON := QueryJson{} + err := json.Unmarshal(query.JSON, &queryJSON) + if err != nil { + return nil, err + } + + if queryJSON.Fill { + qm.FillMissing = &data.FillMissing{} + qm.Interval = time.Duration(queryJSON.FillInterval * 
float64(time.Second)) + switch strings.ToLower(queryJSON.FillMode) { + case "null": + qm.FillMissing.Mode = data.FillModeNull + case "previous": + qm.FillMissing.Mode = data.FillModePrevious + case "value": + qm.FillMissing.Mode = data.FillModeValue + qm.FillMissing.Value = queryJSON.FillValue + default: + } + } + + qm.TimeRange.From = query.TimeRange.From.UTC() + qm.TimeRange.To = query.TimeRange.To.UTC() + + switch queryJSON.Format { + case "time_series": + qm.Format = dataQueryFormatSeries + case "table": + qm.Format = dataQueryFormatTable + default: + panic(fmt.Sprintf("Unrecognized query model format: %q", queryJSON.Format)) + } + + for i, col := range qm.columnNames { + for _, tc := range e.timeColumnNames { + if col == tc { + qm.timeIndex = i + break + } + } + + if qm.Format == dataQueryFormatTable && strings.EqualFold(col, "timeend") { + qm.timeEndIndex = i + continue + } + + switch col { + case "metric": + qm.metricIndex = i + default: + if qm.metricIndex == -1 { + columnType := qm.columnTypesPGX[i] + for _, mct := range e.metricColumnTypes { + if columnType == mct { + qm.metricIndex = i + continue + } + } + } + } + } + qm.InterpolatedQuery = interpolatedQuery + return qm, nil +} + +func convertResultsToFrame(results []*pgconn.Result, rowLimit int64) (*data.Frame, error) { + frame := data.Frame{} + m := pgtype.NewMap() + + for _, result := range results { + // Skip non-select statements + if !result.CommandTag.Select() { + continue + } + fields := make(data.Fields, len(result.FieldDescriptions)) + + fieldTypes, err := getFieldTypesFromDescriptions(result.FieldDescriptions, m) + if err != nil { + return nil, err + } + + for i, v := range result.FieldDescriptions { + fields[i] = data.NewFieldFromFieldType(fieldTypes[i], 0) + fields[i].Name = v.Name + } + // Create a new frame + frame = *data.NewFrame("", fields...) 
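+ // Note: when the statement batch contains several SELECTs, each loop iteration
+ // overwrites frame, so the field descriptors of the last SELECT define the final schema.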
+ } + + // Add rows to the frame + for _, result := range results { + // Skip non-select statements + if !result.CommandTag.Select() { + continue + } + fieldDescriptions := result.FieldDescriptions + for rowIdx := range result.Rows { + if rowIdx == int(rowLimit) { + frame.AppendNotices(data.Notice{ + Severity: data.NoticeSeverityWarning, + Text: fmt.Sprintf("Results have been limited to %v because the SQL row limit was reached", rowLimit), + }) + break + } + row := make([]interface{}, len(fieldDescriptions)) + for colIdx, fd := range fieldDescriptions { + rawValue := result.Rows[rowIdx][colIdx] + dataTypeOID := fd.DataTypeOID + format := fd.Format + + if rawValue == nil { + row[colIdx] = nil + continue + } + + // Convert based on type + switch fd.DataTypeOID { + case pgtype.Int2OID: + var d *int16 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.Int4OID: + var d *int32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.Int8OID: + var d *int64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.NumericOID, pgtype.Float8OID, pgtype.Float4OID: + var d *float64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.BoolOID: + var d *bool + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.ByteaOID: + d, err := pgtype.ByteaCodec.DecodeValue(pgtype.ByteaCodec{}, m, dataTypeOID, format, rawValue) + if err != nil { + return nil, err + } + str := string(d.([]byte)) + row[colIdx] = &str + case pgtype.TimestampOID, pgtype.TimestamptzOID, pgtype.DateOID: + var d *time.Time + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + case pgtype.TimeOID, pgtype.TimetzOID: + var d *string + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + default: + var d *string + scanPlan := m.PlanScan(dataTypeOID, format, &d) + err := scanPlan.Scan(rawValue, &d) + if err != nil { + return nil, err + } + row[colIdx] = d + } + } + frame.AppendRow(row...) 
+ } + } + + return &frame, nil +} + +func getFieldTypesFromDescriptions(fieldDescriptions []pgconn.FieldDescription, m *pgtype.Map) ([]data.FieldType, error) { + fieldTypes := make([]data.FieldType, len(fieldDescriptions)) + for i, v := range fieldDescriptions { + typeName, ok := m.TypeForOID(v.DataTypeOID) + if !ok { + // Handle special cases for field types + if v.DataTypeOID == pgtype.TimetzOID || v.DataTypeOID == 790 { + fieldTypes[i] = data.FieldTypeNullableString + } else { + return nil, fmt.Errorf("unknown data type oid: %d", v.DataTypeOID) + } + } else { + switch typeName.Name { + case "int2": + fieldTypes[i] = data.FieldTypeNullableInt16 + case "int4": + fieldTypes[i] = data.FieldTypeNullableInt32 + case "int8": + fieldTypes[i] = data.FieldTypeNullableInt64 + case "float4", "float8", "numeric": + fieldTypes[i] = data.FieldTypeNullableFloat64 + case "bool": + fieldTypes[i] = data.FieldTypeNullableBool + case "timestamptz", "timestamp", "date": + fieldTypes[i] = data.FieldTypeNullableTime + case "json", "jsonb": + fieldTypes[i] = data.FieldTypeNullableJSON + default: + fieldTypes[i] = data.FieldTypeNullableString + } + } + } + return fieldTypes, nil +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc index 48e55e7b077..5754a588b10 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc +++ b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc @@ -9,14 +9,14 @@ // } // Name: // Dimensions: 12 Fields by 2 Rows -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | -// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | -// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | 2023-11-15 05:06:07.123456 +0000 +0000 | 2023-11-15 05:06:08.123456 +0000 +0000 | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 +0000 | 2023-12-21 00:00:00 +0000 +0000 | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | -// | null | 2023-11-15 05:06:09.123456 +0000 +0000 | null | 2021-07-22 11:22:35.654321 +0000 UTC | 
null | 2023-12-22 00:00:00 +0000 +0000 | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// +--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-------------------------------+-------------------------------+-----------------+-----------------+--------------------+--------------------+-----------------+-----------------+ +// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | +// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | +// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | Type: []*string | Type: []*string | Type: []*string | Type: []*string | +// +--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-------------------------------+-------------------------------+-----------------+-----------------+--------------------+--------------------+-----------------+-----------------+ +// | 2023-11-15 05:06:07.123456 +0000 UTC | 2023-11-15 05:06:08.123456 +0000 UTC | 2021-07-22 11:22:33.654321 +0000 +0000 | 2021-07-22 11:22:34.654321 +0000 +0000 | 2023-12-20 00:00:00 +0000 UTC | 2023-12-21 00:00:00 +0000 UTC | 12:34:56.234567 | 12:34:57.234567 | 23:12:36.765432+01 | 23:12:37.765432+01 | 00:00:00.987654 | 00:00:00.887654 | +// | null | 2023-11-15 05:06:09.123456 +0000 UTC | null | 2021-07-22 11:22:35.654321 +0000 +0000 | null | 2023-12-22 00:00:00 +0000 UTC | null | 12:34:58.234567 | null | 23:12:38.765432+01 | null | 00:00:00.787654 | +// +--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-------------------------------+-------------------------------+-----------------+-----------------+--------------------+--------------------+-----------------+-----------------+ // // // 🌟 This was machine generated. Do not edit. 
🌟 @@ -83,33 +83,33 @@ }, { "name": "t", - "type": "time", + "type": "string", "typeInfo": { - "frame": "time.Time", + "frame": "string", "nullable": true } }, { "name": "tnn", - "type": "time", + "type": "string", "typeInfo": { - "frame": "time.Time", + "frame": "string", "nullable": true } }, { "name": "tz", - "type": "time", + "type": "string", "typeInfo": { - "frame": "time.Time", + "frame": "string", "nullable": true } }, { "name": "tznn", - "type": "time", + "type": "string", "typeInfo": { - "frame": "time.Time", + "frame": "string", "nullable": true } }, @@ -158,20 +158,20 @@ 1703203200000 ], [ - -62167173903766, + "12:34:56.234567", null ], [ - -62167173902766, - -62167173901766 + "12:34:57.234567", + "12:34:58.234567" ], [ - -62167139243235, + "23:12:36.765432+01", null ], [ - -62167139242235, - -62167139241235 + "23:12:37.765432+01", + "23:12:38.765432+01" ], [ "00:00:00.987654", @@ -201,22 +201,10 @@ ], null, null, - [ - 567000, - 0 - ], - [ - 567000, - 567000 - ], - [ - 432000, - 0 - ], - [ - 432000, - 432000 - ], + null, + null, + null, + null, null, null ] diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go index 306b1f43e81..2adadcf58ea 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go +++ b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go @@ -3,47 +3,20 @@ package postgres import ( "fmt" "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng" ) var validateCertFunc = validateCertFilePaths -var writeCertFileFunc = writeCertFile - -type certFileType int - -const ( - rootCert = iota - clientCert - clientKey -) - -type tlsSettingsProvider interface { - getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) -} - -type datasourceCacheManager struct { - locker *locker - cache sync.Map -} type tlsManager struct { - logger log.Logger - dsCacheInstance datasourceCacheManager - dataPath string + logger log.Logger } -func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider { +func newTLSManager(logger log.Logger) *tlsManager { return &tlsManager{ - logger: logger, - dataPath: dataPath, - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + logger: logger, } } @@ -55,178 +28,116 @@ type tlsSettings struct { CertKeyFile string } +// getTLSSettings retrieves TLS settings and handles certificate file creation if needed. 
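+// With the "file-content" method, PEM content from secure JSON data is written to temporary
+// files, since the libpq-style connection parameters sslrootcert, sslcert and sslkey expect
+// file paths; with "file-path", the user-supplied paths are only validated for existence.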
func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { - tlsconfig := tlsSettings{ + tlsConfig := tlsSettings{ Mode: dsInfo.JsonData.Mode, } - isTLSDisabled := (tlsconfig.Mode == "disable") - - if isTLSDisabled { + if tlsConfig.Mode == "disable" { m.logger.Debug("Postgres TLS/SSL is disabled") - return tlsconfig, nil + return tlsConfig, nil } - m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode) + tlsConfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod + tlsConfig.RootCertFile = dsInfo.JsonData.RootCertFile + tlsConfig.CertFile = dsInfo.JsonData.CertFile + tlsConfig.CertKeyFile = dsInfo.JsonData.CertKeyFile - tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod - tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile - tlsconfig.CertFile = dsInfo.JsonData.CertFile - tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile - - if tlsconfig.ConfigurationMethod == "file-content" { - if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil { - return tlsconfig, err + if tlsConfig.ConfigurationMethod == "file-content" { + if err := m.createCertFiles(dsInfo, &tlsConfig); err != nil { + return tlsConfig, fmt.Errorf("failed to create TLS certificate files: %w", err) } } else { - if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil { - return tlsconfig, err + if err := validateCertFunc(tlsConfig.RootCertFile, tlsConfig.CertFile, tlsConfig.CertKeyFile); err != nil { + return tlsConfig, fmt.Errorf("invalid TLS certificate file paths: %w", err) } } - return tlsconfig, nil + + return tlsConfig, nil } -func (t certFileType) String() string { - switch t { - case rootCert: - return "root certificate" - case clientCert: - return "client certificate" - case clientKey: - return "client key" - default: - panic(fmt.Sprintf("Unrecognized certFileType %d", t)) - } -} +// createCertFiles writes certificate files to temporary locations. +func (m *tlsManager) createCertFiles(dsInfo sqleng.DataSourceInfo, tlsConfig *tlsSettings) error { + m.logger.Debug("Writing TLS certificate files to temporary locations") -func getFileName(dataDir string, fileType certFileType) string { - var filename string - switch fileType { - case rootCert: - filename = "root.crt" - case clientCert: - filename = "client.crt" - case clientKey: - filename = "client.key" - default: - panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String())) - } - generatedFilePath := filepath.Join(dataDir, filename) - return generatedFilePath -} - -// writeCertFile writes a certificate file. 
-func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error { - fileContent = strings.TrimSpace(fileContent) - if fileContent != "" { - logger.Debug("Writing cert file", "path", generatedFilePath) - if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil { - return err - } - // Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail - if err := os.Chmod(generatedFilePath, 0600); err != nil { - return err - } - return nil - } - - logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath) - exists, err := fileExists(generatedFilePath) - if err != nil { + var err error + if tlsConfig.RootCertFile, err = m.writeCertFile("root-*.crt", dsInfo.DecryptedSecureJSONData["tlsCACert"]); err != nil { return err } - if exists { - if err := os.Remove(generatedFilePath); err != nil { - return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err) - } + if tlsConfig.CertFile, err = m.writeCertFile("client-*.crt", dsInfo.DecryptedSecureJSONData["tlsClientCert"]); err != nil { + return err } + if tlsConfig.CertKeyFile, err = m.writeCertFile("client-*.key", dsInfo.DecryptedSecureJSONData["tlsClientKey"]); err != nil { + return err + } + return nil } -func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error { - m.logger.Debug("Writing TLS certificate files to disk") - tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"] - tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"] - tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"] - if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" { - m.logger.Debug("No TLS/SSL certificates provided") +// writeCertFile writes a single certificate file to a temporary location. 
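+// The pattern is handed to os.CreateTemp, which replaces the "*" with a random string, so
+// concurrently configured datasources cannot collide on file names. Empty content yields an
+// empty path, letting the caller leave the corresponding setting unset.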
+func (m *tlsManager) writeCertFile(pattern, content string) (string, error) { + if content == "" { + return "", nil } - // Calculate all files path - workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts") - tlsconfig.RootCertFile = getFileName(workDir, rootCert) - tlsconfig.CertFile = getFileName(workDir, clientCert) - tlsconfig.CertKeyFile = getFileName(workDir, clientKey) - - // Find datasource in the cache, if found, skip writing files - cacheKey := strconv.Itoa(int(dsInfo.ID)) - m.dsCacheInstance.locker.RLock(cacheKey) - item, ok := m.dsCacheInstance.cache.Load(cacheKey) - m.dsCacheInstance.locker.RUnlock(cacheKey) - if ok { - if !item.(time.Time).Before(dsInfo.Updated) { - return nil - } - } - - m.dsCacheInstance.locker.Lock(cacheKey) - defer m.dsCacheInstance.locker.Unlock(cacheKey) - - item, ok = m.dsCacheInstance.cache.Load(cacheKey) - if ok { - if !item.(time.Time).Before(dsInfo.Updated) { - return nil - } - } - - // Write certification directory and files - exists, err := fileExists(workDir) + m.logger.Debug("Writing certificate file", "pattern", pattern) + file, err := os.CreateTemp("", pattern) if err != nil { - return err + return "", fmt.Errorf("failed to create temporary file: %w", err) } - if !exists { - if err := os.MkdirAll(workDir, 0700); err != nil { - return err + defer func() { + if err := file.Close(); err != nil { + m.logger.Error("Failed to close file", "error", err) } + }() + + if _, err := file.WriteString(content); err != nil { + return "", fmt.Errorf("failed to write to temporary file: %w", err) } - if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil { - return err - } - if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil { - return err - } - if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil { - return err - } - - // we do not want to point to cert-files that do not exist - if tlsRootCert == "" { - tlsconfig.RootCertFile = "" - } - - if tlsClientCert == "" { - tlsconfig.CertFile = "" - } - - if tlsClientKey == "" { - tlsconfig.CertKeyFile = "" - } - - // Update datasource cache - m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated) - return nil + return file.Name(), nil } -// validateCertFilePaths validates configured certificate file paths. -func validateCertFilePaths(rootCert, clientCert, clientKey string) error { - for _, fpath := range []string{rootCert, clientCert, clientKey} { - if fpath == "" { +// cleanupCertFiles removes temporary certificate files. +func (m *tlsManager) cleanupCertFiles(tlsConfig tlsSettings) { + // Only clean up if the configuration method is "file-content" + if tlsConfig.ConfigurationMethod != "file-content" { + m.logger.Debug("Skipping cleanup of TLS certificate files") + return + } + m.logger.Debug("Cleaning up TLS certificate files") + + files := []struct { + path string + name string + }{ + {tlsConfig.RootCertFile, "root certificate"}, + {tlsConfig.CertFile, "client certificate"}, + {tlsConfig.CertKeyFile, "client key"}, + } + + for _, file := range files { + if file.path == "" { continue } - exists, err := fileExists(fpath) + if err := os.Remove(file.path); err != nil { + m.logger.Error("Failed to remove file", "type", file.name, "path", file.path, "error", err) + } else { + m.logger.Debug("Successfully removed file", "type", file.name, "path", file.path) + } + } +} + +// validateCertFilePaths validates the existence of configured certificate file paths. 
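+// Empty paths are skipped; a path that points to a missing file returns
+// sqleng.ErrCertFileNotExist so a precise configuration error can be reported
+// instead of a generic connection failure.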
+func validateCertFilePaths(rootCert, clientCert, clientKey string) error { + for _, path := range []string{rootCert, clientCert, clientKey} { + if path == "" { + continue + } + exists, err := fileExists(path) if err != nil { - return err + return fmt.Errorf("error checking file existence: %w", err) } if !exists { return sqleng.ErrCertFileNotExist @@ -235,15 +146,14 @@ func validateCertFilePaths(rootCert, clientCert, clientKey string) error { return nil } -// Exists determines whether a file/directory exists or not. -func fileExists(fpath string) (bool, error) { - _, err := os.Stat(fpath) +// fileExists checks if a file exists at the given path. +func fileExists(path string) (bool, error) { + _, err := os.Stat(path) if err != nil { - if !os.IsNotExist(err) { - return false, err + if os.IsNotExist(err) { + return false, nil } - return false, nil + return false, err } - return true, nil } diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go index c8eda0a83a7..63d64b55b17 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go @@ -2,176 +2,21 @@ package postgres import ( "fmt" + "os" "path/filepath" - "strconv" "strings" - "sync" "testing" "time" "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/log" - "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - _ "github.com/lib/pq" ) -var writeCertFileCallNum int - -// TestDataSourceCacheManager is to test the Cache manager -func TestDataSourceCacheManager(t *testing.T) { - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() - mng := tlsManager{ - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, - dataPath: cfg.DataPath, - } - jsonData := sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - } - secureJSONData := map[string]string{ - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - "tlsCACert": "I am CA certification", - } - - updateTime := time.Now().Add(-5 * time.Minute) - - mockValidateCertFilePaths() - t.Cleanup(resetValidateCertFilePaths) - - t.Run("Check datasource cache creation", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(10) - for id := int64(1); id <= 10; id++ { - go func(id int64) { - ds := sqleng.DataSourceInfo{ - ID: id, - Updated: updateTime, - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(ds, &s) - require.NoError(t, err) - wg.Done() - }(id) - } - wg.Wait() - - t.Run("check cache creation is succeed", func(t *testing.T) { - for id := int64(1); id <= 10; id++ { - updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id))) - require.True(t, ok) - require.Equal(t, updateTime, updated) - } - }) - }) - - t.Run("Check datasource cache modification", func(t *testing.T) { - t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) { - mockWriteCertFile() - t.Cleanup(resetWriteCertFile) - var wg1 sync.WaitGroup - wg1.Add(5) - for id := int64(1); id <= 5; id++ { - go func(id int64) { - ds := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime, - Database: "database", - JsonData: jsonData, - 
DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(ds, &s) - require.NoError(t, err) - wg1.Done() - }(id) - } - wg1.Wait() - assert.Equal(t, writeCertFileCallNum, 0) - }) - - t.Run("cache is updated with the last datasource version", func(t *testing.T) { - dsV2 := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime.Add(time.Minute), - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - dsV3 := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime.Add(2 * time.Minute), - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(dsV2, &s) - require.NoError(t, err) - err = mng.writeCertFiles(dsV3, &s) - require.NoError(t, err) - version, ok := mng.dsCacheInstance.cache.Load("1") - require.True(t, ok) - require.Equal(t, updateTime.Add(2*time.Minute), version) - }) - }) -} - -// Test getFileName - -func TestGetFileName(t *testing.T) { - testCases := []struct { - desc string - datadir string - fileType certFileType - expErr string - expectedGeneratedPath string - }{ - { - desc: "Get File Name for root certification", - datadir: ".", - fileType: rootCert, - expectedGeneratedPath: "root.crt", - }, - { - desc: "Get File Name for client certification", - datadir: ".", - fileType: clientCert, - expectedGeneratedPath: "client.crt", - }, - { - desc: "Get File Name for client certification", - datadir: ".", - fileType: clientKey, - expectedGeneratedPath: "client.key", - }, - } - for _, tt := range testCases { - t.Run(tt.desc, func(t *testing.T) { - generatedPath := getFileName(tt.datadir, tt.fileType) - assert.Equal(t, tt.expectedGeneratedPath, generatedPath) - }) - } -} - // Test getTLSSettings. 
func TestGetTLSSettings(t *testing.T) { - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() - mockValidateCertFilePaths() t.Cleanup(resetValidateCertFilePaths) @@ -216,75 +61,13 @@ func TestGetTLSSettings(t *testing.T) { CertKeyFile: "i/am/coding/client.key", }, }, - { - desc: "Custom TLS mode verify-full with certificate files content", - updated: updatedTime.Add(2 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsCACert": "I am CA certification", - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - }, - tlsSettings: tlsSettings{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), - CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), - CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), - }, - }, - { - desc: "Custom TLS mode verify-ca with no client certificates with certificate files content", - updated: updatedTime.Add(3 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsCACert": "I am CA certification", - }, - tlsSettings: tlsSettings{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), - CertFile: "", - CertKeyFile: "", - }, - }, - { - desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content", - updated: updatedTime.Add(4 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - }, - tlsSettings: tlsSettings{ - Mode: "require", - ConfigurationMethod: "file-content", - RootCertFile: "", - CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), - CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), - }, - }, } for _, tt := range testCases { t.Run(tt.desc, func(t *testing.T) { var settings tlsSettings var err error mng := tlsManager{ - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, - dataPath: cfg.DataPath, + logger: backend.NewLoggerWith("logger", "tsdb.postgres"), } ds := sqleng.DataSourceInfo{ @@ -318,15 +101,145 @@ func resetValidateCertFilePaths() { validateCertFunc = validateCertFilePaths } -func mockWriteCertFile() { - writeCertFileCallNum = 0 - writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error { - writeCertFileCallNum++ - return nil +func TestTLSManager_GetTLSSettings(t *testing.T) { + logger := log.New() + tlsManager := newTLSManager(logger) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": "root-cert-content", + "tlsClientCert": "client-cert-content", + "tlsClientKey": "client-key-content", + }, } + + tlsConfig, err := tlsManager.getTLSSettings(dsInfo) + require.NoError(t, err) + assert.Equal(t, "require", tlsConfig.Mode) + assert.NotEmpty(t, tlsConfig.RootCertFile) + assert.NotEmpty(t, 
tlsConfig.CertFile) + assert.NotEmpty(t, tlsConfig.CertKeyFile) + + // Cleanup temporary files + tlsManager.cleanupCertFiles(tlsConfig) + assert.NoFileExists(t, tlsConfig.RootCertFile) + assert.NoFileExists(t, tlsConfig.CertFile) + assert.NoFileExists(t, tlsConfig.CertKeyFile) } -func resetWriteCertFile() { - writeCertFileCallNum = 0 - writeCertFileFunc = writeCertFile +func TestTLSManager_CleanupCertFiles_FilePath(t *testing.T) { + logger := log.New() + tlsManager := newTLSManager(logger) + + // Create temporary files for testing + rootCertFile, err := tlsManager.writeCertFile("root-*.crt", "root-cert-content") + require.NoError(t, err) + clientCertFile, err := tlsManager.writeCertFile("client-*.crt", "client-cert-content") + require.NoError(t, err) + clientKeyFile, err := tlsManager.writeCertFile("client-*.key", "client-key-content") + require.NoError(t, err) + + // Simulate a configuration where the method is "file-path" + tlsConfig := tlsSettings{ + ConfigurationMethod: "file-path", + RootCertFile: rootCertFile, + CertFile: clientCertFile, + CertKeyFile: clientKeyFile, + } + + // Call cleanupCertFiles + tlsManager.cleanupCertFiles(tlsConfig) + + // Verify the files are NOT deleted + assert.FileExists(t, rootCertFile, "Root certificate file should not be deleted") + assert.FileExists(t, clientCertFile, "Client certificate file should not be deleted") + assert.FileExists(t, clientKeyFile, "Client key file should not be deleted") + + // Cleanup the files manually + err = os.Remove(rootCertFile) + require.NoError(t, err) + err = os.Remove(clientCertFile) + require.NoError(t, err) + err = os.Remove(clientKeyFile) + require.NoError(t, err) +} + +func TestTLSManager_CreateCertFiles(t *testing.T) { + logger := log.New() + tlsManager := newTLSManager(logger) + + dsInfo := sqleng.DataSourceInfo{ + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": "root-cert-content", + "tlsClientCert": "client-cert-content", + "tlsClientKey": "client-key-content", + }, + } + + tlsConfig := tlsSettings{ + ConfigurationMethod: "file-content", + } + err := tlsManager.createCertFiles(dsInfo, &tlsConfig) + require.NoError(t, err) + + assert.FileExists(t, tlsConfig.RootCertFile) + assert.FileExists(t, tlsConfig.CertFile) + assert.FileExists(t, tlsConfig.CertKeyFile) + + // Cleanup temporary files + tlsManager.cleanupCertFiles(tlsConfig) + assert.NoFileExists(t, tlsConfig.RootCertFile) + assert.NoFileExists(t, tlsConfig.CertFile) + assert.NoFileExists(t, tlsConfig.CertKeyFile) +} + +func TestTLSManager_WriteCertFile(t *testing.T) { + logger := log.New() + tlsManager := newTLSManager(logger) + + // Test writing a valid certificate file + filePath, err := tlsManager.writeCertFile("test-*.crt", "test-cert-content") + require.NoError(t, err) + assert.FileExists(t, filePath) + + content, err := os.ReadFile(filepath.Clean(filePath)) + require.NoError(t, err) + assert.Equal(t, "test-cert-content", string(content)) + + // Cleanup the file + err = os.Remove(filePath) + require.NoError(t, err) + assert.NoFileExists(t, filePath) +} + +func TestTLSManager_CleanupCertFiles(t *testing.T) { + logger := log.New() + tlsManager := newTLSManager(logger) + + // Create temporary files for testing + rootCertFile, err := tlsManager.writeCertFile("root-*.crt", "root-cert-content") + require.NoError(t, err) + clientCertFile, err := tlsManager.writeCertFile("client-*.crt", "client-cert-content") + require.NoError(t, err) + clientKeyFile, err := tlsManager.writeCertFile("client-*.key", "client-key-content") + require.NoError(t, err) + + 
tlsConfig := tlsSettings{ + ConfigurationMethod: "file-content", + RootCertFile: rootCertFile, + CertFile: clientCertFile, + CertKeyFile: clientKeyFile, + } + + // Cleanup the files + tlsManager.cleanupCertFiles(tlsConfig) + + // Verify the files are deleted + assert.NoFileExists(t, rootCertFile) + assert.NoFileExists(t, clientCertFile) + assert.NoFileExists(t, clientKeyFile) } diff --git a/public/app/plugins/datasource/grafana-postgresql-datasource/configuration/ConfigurationEditor.tsx b/public/app/plugins/datasource/grafana-postgresql-datasource/configuration/ConfigurationEditor.tsx index 5d8cb9aae06..273c60df162 100644 --- a/public/app/plugins/datasource/grafana-postgresql-datasource/configuration/ConfigurationEditor.tsx +++ b/public/app/plugins/datasource/grafana-postgresql-datasource/configuration/ConfigurationEditor.tsx @@ -10,7 +10,14 @@ import { } from '@grafana/data'; import { ConfigSection, ConfigSubSection, DataSourceDescription, EditorStack } from '@grafana/plugin-ui'; import { config } from '@grafana/runtime'; -import { ConnectionLimits, Divider, TLSSecretsConfig, useMigrateDatabaseFields } from '@grafana/sql'; +import { + ConnectionLimits, + Divider, + MaxLifetimeField, + MaxOpenConnectionsField, + TLSSecretsConfig, + useMigrateDatabaseFields, +} from '@grafana/sql'; import { Input, Select, @@ -76,6 +83,14 @@ export const PostgresConfigEditor = (props: DataSourcePluginOptionsEditorProps

{ + updateDatasourcePluginJsonDataOption(props, 'maxOpenConns', number); + }; + + const onMaxLifetimeChanged = (number?: number) => { + updateDatasourcePluginJsonDataOption(props, 'connMaxLifetime', number); + }; + const onTimeScaleDBChanged = (event: SyntheticEvent) => { updateDatasourcePluginJsonDataOption(props, 'timescaledb', event.currentTarget.checked); }; @@ -397,8 +412,18 @@ export const PostgresConfigEditor = (props: DataSourcePluginOptionsEditorProps

- - + {config.featureToggles.postgresDSUsePGX ? ( + + + + + ) : ( + + )} {config.secureSocksDSProxyEnabled && ( )}