From ecd6de826a315526ed5fac0ce9e70f9f50e217c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A1bor=20Farkas?= Date: Wed, 13 Mar 2024 09:52:39 +0100 Subject: [PATCH] Postgres: Switch the datasource plugin from lib/pq to pgx (#83768) postgres: switch from lib/pq to pgx --- go.mod | 11 +- go.sum | 7 + .../grafana-postgresql-datasource/locker.go | 85 ---- .../locker_test.go | 63 --- .../grafana-postgresql-datasource/postgres.go | 139 +++--- .../postgres_snapshot_test.go | 41 +- .../postgres_test.go | 265 ++++++++---- .../grafana-postgresql-datasource/proxy.go | 34 +- .../proxy_test.go | 12 +- .../table/timestamp_convert_real.golden.jsonc | 32 +- .../table/types_datetime.golden.jsonc | 16 +- .../grafana-postgresql-datasource/tls/tls.go | 135 ++++++ .../tls/tls_loader.go | 101 +++++ .../tls/tls_test.go | 402 ++++++++++++++++++ .../tls/tls_test_helpers.go | 105 +++++ .../tlsmanager.go | 249 ----------- .../tlsmanager_test.go | 332 --------------- 17 files changed, 1082 insertions(+), 947 deletions(-) delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/locker.go delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/locker_test.go create mode 100644 pkg/tsdb/grafana-postgresql-datasource/tls/tls.go create mode 100644 pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go create mode 100644 pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go create mode 100644 pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go delete mode 100644 pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go diff --git a/go.mod b/go.mod index b6e92d64d67..c570e2a216b 100644 --- a/go.mod +++ b/go.mod @@ -122,7 +122,7 @@ require ( gopkg.in/mail.v2 v2.3.1 // @grafana/backend-platform gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // @grafana/alerting-squad-backend - xorm.io/builder v0.3.6 // @grafana/backend-platform + xorm.io/builder v0.3.6 // indirect; @grafana/backend-platform xorm.io/core v0.7.3 // @grafana/backend-platform xorm.io/xorm v0.8.2 // @grafana/alerting-squad-backend ) @@ -174,7 +174,7 @@ require ( github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-msgpack v0.5.5 // indirect - github.com/hashicorp/go-multierror v1.1.1 // @grafana/alerting-squad + github.com/hashicorp/go-multierror v1.1.1 // indirect; @grafana/alerting-squad github.com/hashicorp/go-sockaddr v1.0.6 // indirect github.com/hashicorp/golang-lru v0.6.0 // indirect github.com/hashicorp/yamux v0.1.1 // indirect @@ -337,7 +337,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 // @grafana/alerting-squad-backend + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect; @grafana/alerting-squad-backend github.com/hashicorp/memberlist v0.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/invopop/yaml v0.2.0 // indirect @@ -493,10 +493,15 @@ require ( github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 // @grafana/grafana-app-platform-squad ) +require github.com/jackc/pgx/v5 v5.5.5 // @grafana/oss-big-tent + require ( github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/invopop/jsonschema v0.12.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect ) diff --git a/go.sum b/go.sum index 76f5048564b..27c86e622f9 100644 --- a/go.sum +++ b/go.sum @@ -2383,6 +2383,7 @@ github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bY github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= @@ -2393,6 +2394,8 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -2404,10 +2407,14 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker.go b/pkg/tsdb/grafana-postgresql-datasource/locker.go deleted file mode 100644 index 796c37c7415..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/locker.go +++ /dev/null @@ -1,85 +0,0 @@ -package postgres - -import ( - "fmt" - "sync" -) - -// locker is a named reader/writer mutual exclusion lock. -// The lock for each particular key can be held by an arbitrary number of readers or a single writer. -type locker struct { - locks map[any]*sync.RWMutex - locksRW *sync.RWMutex -} - -func newLocker() *locker { - return &locker{ - locks: make(map[any]*sync.RWMutex), - locksRW: new(sync.RWMutex), - } -} - -// Lock locks named rw mutex with specified key for writing. -// If the lock with the same key is already locked for reading or writing, -// Lock blocks until the lock is available. -func (lkr *locker) Lock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - lk = lkr.newLock(key) - } - lk.Lock() -} - -// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is -// not locked for writing on entry to Unlock. -func (lkr *locker) Unlock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - panic(fmt.Errorf("lock for key '%s' not initialized", key)) - } - lk.Unlock() -} - -// RLock locks named rw mutex with specified key for reading. -// -// It should not be used for recursive read locking for the same key; a blocked Lock -// call excludes new readers from acquiring the lock. See the -// documentation on the golang RWMutex type. -func (lkr *locker) RLock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - lk = lkr.newLock(key) - } - lk.RLock() -} - -// RUnlock undoes a single RLock call for specified key; -// it does not affect other simultaneous readers of locker for specified key. -// It is a run-time error if locker for specified key is not locked for reading -func (lkr *locker) RUnlock(key any) { - lk, ok := lkr.getLock(key) - if !ok { - panic(fmt.Errorf("lock for key '%s' not initialized", key)) - } - lk.RUnlock() -} - -func (lkr *locker) newLock(key any) *sync.RWMutex { - lkr.locksRW.Lock() - defer lkr.locksRW.Unlock() - - if lk, ok := lkr.locks[key]; ok { - return lk - } - lk := new(sync.RWMutex) - lkr.locks[key] = lk - return lk -} - -func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) { - lkr.locksRW.RLock() - defer lkr.locksRW.RUnlock() - - lock, ok := lkr.locks[key] - return lock, ok -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker_test.go b/pkg/tsdb/grafana-postgresql-datasource/locker_test.go deleted file mode 100644 index b1dc64f0351..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/locker_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package postgres - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestIntegrationLocker(t *testing.T) { - if testing.Short() { - t.Skip("Tests with Sleep") - } - const notUpdated = "not_updated" - const atThread1 = "at_thread_1" - const atThread2 = "at_thread_2" - t.Run("Should lock for same keys", func(t *testing.T) { - updated := notUpdated - locker := newLocker() - locker.Lock(1) - var wg sync.WaitGroup - wg.Add(1) - defer func() { - locker.Unlock(1) - wg.Wait() - }() - - go func() { - locker.RLock(1) - defer func() { - locker.RUnlock(1) - wg.Done() - }() - require.Equal(t, atThread1, updated, "Value should be updated in different thread") - updated = atThread2 - }() - time.Sleep(time.Millisecond * 10) - require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") - updated = atThread1 - }) - - t.Run("Should not lock for different keys", func(t *testing.T) { - updated := notUpdated - locker := newLocker() - locker.Lock(1) - defer locker.Unlock(1) - var wg sync.WaitGroup - wg.Add(1) - go func() { - locker.RLock(2) - defer func() { - locker.RUnlock(2) - wg.Done() - }() - require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") - updated = atThread2 - }() - wg.Wait() - require.Equal(t, atThread2, updated, "Value should be updated in different thread") - updated = atThread1 - }) -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres.go b/pkg/tsdb/grafana-postgresql-datasource/postgres.go index 53e46308857..e995e360912 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres.go @@ -4,38 +4,43 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" + "os" "reflect" "strconv" "strings" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/datasource" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" + "github.com/grafana/grafana-plugin-sdk-go/backend/proxy" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" - "github.com/lib/pq" "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls" "github.com/grafana/grafana/pkg/tsdb/sqleng" ) func ProvideService(cfg *setting.Cfg) *Service { logger := backend.NewLoggerWith("logger", "tsdb.postgres") s := &Service{ - tlsManager: newTLSManager(logger, cfg.DataPath), - logger: logger, + logger: logger, } s.im = datasource.NewInstanceManager(s.newInstanceSettings()) return s } type Service struct { - tlsManager tlsSettingsProvider - im instancemgmt.InstanceManager - logger log.Logger + im instancemgmt.InstanceManager + logger log.Logger } func (s *Service) getDSInfo(ctx context.Context, pluginCtx backend.PluginContext) (*sqleng.DataSourceHandler, error) { @@ -55,17 +60,11 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) return dsInfo.QueryData(ctx, req) } -func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) { - connector, err := pq.NewConnector(cnnstr) +func newPostgres(userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, logger log.Logger, proxyClient proxy.Client) (*sql.DB, *sqleng.DataSourceHandler, error) { + pgxConf, err := generateConnectionConfig(dsInfo) if err != nil { - logger.Error("postgres connector creation failed", "error", err) - return nil, nil, fmt.Errorf("postgres connector creation failed") - } - - proxyClient, err := settings.ProxyClient(ctx) - if err != nil { - logger.Error("postgres proxy creation failed", "error", err) - return nil, nil, fmt.Errorf("postgres proxy creation failed") + logger.Error("postgres config creation failed", "error", err) + return nil, nil, fmt.Errorf("postgres config creation failed") } if proxyClient.SecureSocksProxyEnabled() { @@ -74,9 +73,8 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in logger.Error("postgres proxy creation failed", "error", err) return nil, nil, fmt.Errorf("postgres proxy creation failed") } - postgresDialer := newPostgresProxyDialer(dialer) - // update the postgres dialer with the proxy dialer - connector.Dialer(postgresDialer) + + pgxConf.DialFunc = newPgxDialFunc(dialer) } config := sqleng.DataPluginConfiguration{ @@ -87,7 +85,7 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in queryResultTransformer := postgresQueryResultTransformer{} - db := sql.OpenDB(connector) + db := pgxstdlib.OpenDB(*pgxConf) db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns) db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns) @@ -143,17 +141,17 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { DecryptedSecureJSONData: settings.DecryptedSecureJSONData, } - cnnstr, err := s.generateConnectionString(dsInfo) - if err != nil { - return nil, err - } - userFacingDefaultError, err := cfg.UserFacingDefaultError() if err != nil { return nil, err } - _, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) + proxyClient, err := settings.ProxyClient(ctx) + if err != nil { + return nil, err + } + + _, handler, err := newPostgres(userFacingDefaultError, sqlCfg.RowLimit, dsInfo, logger, proxyClient) if err != nil { logger.Error("Failed connecting to Postgres", "err", err) @@ -170,13 +168,11 @@ func escape(input string) string { return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) } -func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) { - logger := s.logger +func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, error) { var host string var port int if strings.HasPrefix(dsInfo.URL, "/") { host = dsInfo.URL - logger.Debug("Generating connection string with Unix socket specifier", "socket", host) } else { index := strings.LastIndex(dsInfo.URL, ":") v6Index := strings.Index(dsInfo.URL, "]") @@ -187,12 +183,8 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string var err error port, err = strconv.Atoi(sp[1]) if err != nil { - return "", fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) + return nil, fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) } - - logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port) - } else { - logger.Debug("Generating connection string with network host", "host", host) } } else { if index == v6Index+1 { @@ -200,46 +192,45 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string var err error port, err = strconv.Atoi(dsInfo.URL[index+1:]) if err != nil { - return "", fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) + return nil, fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) } - - logger.Debug("Generating ipv6 connection string with network host/port pair", "host", host, "port", port) } else { host = dsInfo.URL[1 : len(dsInfo.URL)-1] - logger.Debug("Generating ipv6 connection string with network host", "host", host) } } } - connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", + // NOTE: we always set sslmode=disable in the connection string, we handle TLS manually later + connStr := fmt.Sprintf("sslmode=disable user='%s' password='%s' host='%s' dbname='%s'", escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database)) if port > 0 { connStr += fmt.Sprintf(" port=%d", port) } - tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo) + conf, err := pgx.ParseConfig(connStr) if err != nil { - return "", err + return nil, err } - connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode)) - - // Attach root certificate if provided - if tlsSettings.RootCertFile != "" { - logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile) - connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile)) + tlsConf, err := tls.GetTLSConfig(dsInfo, os.ReadFile, host) + if err != nil { + return nil, err } - // Attach client certificate and key if both are provided - if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" { - logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile) - connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile)) - } else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" { - return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified") + // before we set the TLS config, we need to make sure the `.Fallbacks` attribute is unset, see: + // https://github.com/jackc/pgx/discussions/1903#discussioncomment-8430146 + if len(conf.Fallbacks) > 0 { + return nil, errors.New("tls: fallbacks configured, unable to set up TLS config") + } + conf.TLSConfig = tlsConf + + // by default pgx resolves hostnames to ip addresses. we must avoid this. + // (certain socks-proxy related functionality relies on the hostname being preserved) + conf.LookupFunc = func(ctx context.Context, host string) ([]string, error) { + return []string{host}, nil } - logger.Debug("Generated Postgres connection string successfully") - return connStr, nil + return conf, nil } type postgresQueryResultTransformer struct{} @@ -267,6 +258,44 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter { return []sqlutil.StringConverter{ + { + Name: "handle TIME WITH TIME ZONE", + InputScanKind: reflect.Interface, + InputTypeName: strconv.Itoa(pgtype.TimetzOID), + ConversionFunc: func(in *string) (*string, error) { return in, nil }, + Replacer: &sqlutil.StringFieldReplacer{ + OutputFieldType: data.FieldTypeNullableTime, + ReplaceFunc: func(in *string) (any, error) { + if in == nil { + return nil, nil + } + v, err := time.Parse("15:04:05-07", *in) + if err != nil { + return nil, err + } + return &v, nil + }, + }, + }, + { + Name: "handle TIME", + InputScanKind: reflect.Interface, + InputTypeName: "TIME", + ConversionFunc: func(in *string) (*string, error) { return in, nil }, + Replacer: &sqlutil.StringFieldReplacer{ + OutputFieldType: data.FieldTypeNullableTime, + ReplaceFunc: func(in *string) (any, error) { + if in == nil { + return nil, nil + } + v, err := time.Parse("15:04:05", *in) + if err != nil { + return nil, err + } + return &v, nil + }, + }, + }, { Name: "handle FLOAT4", InputScanKind: reflect.Interface, diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go index 0713e8cdacf..fad67a88115 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go @@ -3,7 +3,6 @@ package postgres import ( "context" "encoding/json" - "fmt" "os" "path/filepath" "regexp" @@ -51,20 +50,6 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { t.Skip() } - getCnnStr := func() string { - host := os.Getenv("POSTGRES_HOST") - if host == "" { - host = "localhost" - } - port := os.Getenv("POSTGRES_PORT") - if port == "" { - port = "5432" - } - - return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", - host, port) - } - sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`) readSqlFile := func(path string) (string, string) { @@ -148,18 +133,34 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { ConnMaxLifetime: 14400, Timescaledb: false, ConfigurationMethod: "file-path", + Mode: "disable", + } + + host := os.Getenv("POSTGRES_HOST") + if host == "" { + host = "localhost" + } + port := os.Getenv("POSTGRES_PORT") + if port == "" { + port = "5432" } dsInfo := sqleng.DataSourceInfo{ - JsonData: jsonData, - DecryptedSecureJSONData: map[string]string{}, + JsonData: jsonData, + DecryptedSecureJSONData: map[string]string{ + "password": "grafanatest", + }, + URL: host + ":" + port, + Database: "grafanadstest", + User: "grafanatest", } logger := log.New() - cnnstr := getCnnStr() - - db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + settings := backend.DataSourceInstanceSettings{} + proxyClient, err := settings.ProxyClient(context.Background()) + require.NoError(t, err) + db, handler, err := newPostgres("error", 10000, dsInfo, logger, proxyClient) t.Cleanup((func() { _, err := db.Exec("DROP TABLE tbl") diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go index 2dec40dc837..473d876f3de 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go @@ -2,31 +2,33 @@ package postgres import ( "context" + "errors" "fmt" "math/rand" + "net" + "net/http" "os" "strings" "testing" "time" "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + backendproxy "github.com/grafana/grafana-plugin-sdk-go/backend/proxy" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/proxy" - "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls" "github.com/grafana/grafana/pkg/tsdb/sqleng" - _ "github.com/lib/pq" + _ "github.com/jackc/pgx/v5/stdlib" ) -// Test generateConnectionString. -func TestIntegrationGenerateConnectionString(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test") - } - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() +func TestGenerateConnectionConfig(t *testing.T) { + rootCertBytes, err := tls.CreateRandomRootCertBytes() + require.NoError(t, err) testCases := []struct { desc string @@ -34,10 +36,15 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user string password string database string - tlsSettings tlsSettings - expConnStr string + tlsMode string + tlsRootCert []byte expErr string - uid string + expHost string + expPort uint16 + expUser string + expPassword string + expDatabase string + expTLS bool }{ { desc: "Unix socket host", @@ -45,8 +52,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "/var/run/postgresql", + expDatabase: "database", }, { desc: "TCP host", @@ -54,8 +64,12 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "host", + expPort: 5432, + expDatabase: "database", }, { desc: "TCP/port host", @@ -63,8 +77,12 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: "user='user' password='password' host='host' dbname='database' port=1234 sslmode='verify-full'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "host", + expPort: 1234, + expDatabase: "database", }, { desc: "Ipv6 host", @@ -72,8 +90,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: "user='user' password='password' host='::1' dbname='database' sslmode='verify-full'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "::1", + expDatabase: "database", }, { desc: "Ipv6/port host", @@ -81,16 +102,20 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: "user='user' password='password' host='::1' dbname='database' port=1234 sslmode='verify-full'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "::1", + expPort: 1234, + expDatabase: "database", }, { - desc: "Invalid port", - host: "host:invalid", - user: "user", - database: "database", - tlsSettings: tlsSettings{}, - expErr: "invalid port in host specifier", + desc: "Invalid port", + host: "host:invalid", + user: "user", + database: "database", + tlsMode: "disable", + expErr: "invalid port in host specifier", }, { desc: "Password with single quote and backslash", @@ -98,8 +123,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: `p'\assword`, database: "database", - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, + tlsMode: "disable", + expUser: "user", + expPassword: `p'\assword`, + expHost: "host", + expDatabase: "database", }, { desc: "User/DB with single quote and backslash", @@ -107,8 +135,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: `u'\ser`, password: `password`, database: `d'\atabase`, - tlsSettings: tlsSettings{Mode: "verify-full"}, - expConnStr: `user='u\'\\ser' password='password' host='host' dbname='d\'\\atabase' sslmode='verify-full'`, + tlsMode: "disable", + expUser: `u'\ser`, + expPassword: "password", + expDatabase: `d'\atabase`, + expHost: "host", }, { desc: "Custom TLS mode disabled", @@ -116,45 +147,55 @@ func TestIntegrationGenerateConnectionString(t *testing.T) { user: "user", password: "password", database: "database", - tlsSettings: tlsSettings{Mode: "disable"}, - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", + tlsMode: "disable", + expUser: "user", + expPassword: "password", + expHost: "host", + expDatabase: "database", }, { - desc: "Custom TLS mode verify-full with certificate files", - host: "host", - user: "user", - password: "password", - database: "database", - tlsSettings: tlsSettings{ - Mode: "verify-full", - RootCertFile: "i/am/coding/ca.crt", - CertFile: "i/am/coding/client.crt", - CertKeyFile: "i/am/coding/client.key", - }, - expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' " + - "sslrootcert='i/am/coding/ca.crt' sslcert='i/am/coding/client.crt' sslkey='i/am/coding/client.key'", + desc: "Custom TLS mode verify-full with certificate files", + host: "host", + user: "user", + password: "password", + database: "database", + tlsMode: "verify-full", + tlsRootCert: rootCertBytes, + expUser: "user", + expPassword: "password", + expDatabase: "database", + expHost: "host", + expTLS: true, }, } for _, tt := range testCases { t.Run(tt.desc, func(t *testing.T) { - svc := Service{ - tlsManager: &tlsTestManager{settings: tt.tlsSettings}, - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), - } - ds := sqleng.DataSourceInfo{ - URL: tt.host, - User: tt.user, - DecryptedSecureJSONData: map[string]string{"password": tt.password}, - Database: tt.database, - UID: tt.uid, + URL: tt.host, + User: tt.user, + DecryptedSecureJSONData: map[string]string{ + "password": tt.password, + "tlsCACert": string(tt.tlsRootCert), + }, + Database: tt.database, + JsonData: sqleng.JsonData{ + Mode: tt.tlsMode, + ConfigurationMethod: "file-content", + }, } - connStr, err := svc.generateConnectionString(ds) + c, err := generateConnectionConfig(ds) if tt.expErr == "" { require.NoError(t, err, tt.desc) - assert.Equal(t, tt.expConnStr, connStr) + assert.Equal(t, tt.expHost, c.Host) + if tt.expPort != 0 { + assert.Equal(t, tt.expPort, c.Port) + } + assert.Equal(t, tt.expUser, c.User) + assert.Equal(t, tt.expDatabase, c.Database) + assert.Equal(t, tt.expPassword, c.Password) + require.Equal(t, tt.expTLS, c.TLSConfig != nil) } else { require.Error(t, err, tt.desc) assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), @@ -191,24 +232,40 @@ func TestIntegrationPostgres(t *testing.T) { return sql } + host := os.Getenv("POSTGRES_HOST") + if host == "" { + host = "localhost" + } + port := os.Getenv("POSTGRES_PORT") + if port == "" { + port = "5432" + } + jsonData := sqleng.JsonData{ MaxOpenConns: 0, MaxIdleConns: 2, ConnMaxLifetime: 14400, Timescaledb: false, ConfigurationMethod: "file-path", + Mode: "disable", } dsInfo := sqleng.DataSourceInfo{ - JsonData: jsonData, - DecryptedSecureJSONData: map[string]string{}, + JsonData: jsonData, + DecryptedSecureJSONData: map[string]string{ + "password": "grafanatest", + }, + URL: host + ":" + port, + Database: "grafanadstest", + User: "grafanatest", } logger := backend.NewLoggerWith("logger", "postgres.test") - cnnstr := postgresTestDBConnString() - - db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + settings := backend.DataSourceInstanceSettings{} + proxyClient, err := settings.ProxyClient(context.Background()) + require.NoError(t, err) + db, exe, err := newPostgres("error", 10000, dsInfo, logger, proxyClient) require.NoError(t, err) @@ -1261,8 +1318,10 @@ func TestIntegrationPostgres(t *testing.T) { }) t.Run("When row limit set to 1", func(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{} - _, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) + settings := backend.DataSourceInstanceSettings{} + proxyClient, err := settings.ProxyClient(context.Background()) + require.NoError(t, err) + _, handler, err := newPostgres("error", 1, dsInfo, logger, proxyClient) require.NoError(t, err) @@ -1377,14 +1436,6 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim return timeRange } -type tlsTestManager struct { - settings tlsSettings -} - -func (m *tlsTestManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { - return m.settings, nil -} - func isTestDbPostgres() bool { if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { return db == "postgres" @@ -1393,15 +1444,59 @@ func isTestDbPostgres() bool { return false } -func postgresTestDBConnString() string { - host := os.Getenv("POSTGRES_HOST") - if host == "" { - host = "localhost" - } - port := os.Getenv("POSTGRES_PORT") - if port == "" { - port = "5432" - } - return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", - host, port) +type testNoResolveDialer struct { +} + +func (d *testNoResolveDialer) Dial(network, addr string) (c net.Conn, err error) { + return nil, fmt.Errorf("test-dialer: dialing to '%s'. not implemented", addr) +} + +var _ proxy.Dialer = (&testNoResolveDialer{}) + +type testNoResolveProxyClient struct { +} + +var _ backendproxy.Client = (&testNoResolveProxyClient{}) + +func (p *testNoResolveProxyClient) SecureSocksProxyEnabled() bool { + return true +} + +func (p *testNoResolveProxyClient) ConfigureSecureSocksHTTPProxy(transport *http.Transport) error { + return errors.New("testNoResolveProxyClient.ConfigureSecureSocksHTTPProxy not implemented") +} + +func (p *testNoResolveProxyClient) NewSecureSocksProxyContextDialer() (proxy.Dialer, error) { + return &testNoResolveDialer{}, nil +} + +// we must make sure that pgx does not resolve hostnames: +// if we say the hostname is `localhost`, then it should +// instruct the socks-proxy to connect to `localhost`, not to `127.0.0.1` +// this is important, becase some other socks-proxy-code relies on this behavior. +func TestNoResolve(t *testing.T) { + jsonData := sqleng.JsonData{ + MaxOpenConns: 0, + MaxIdleConns: 2, + ConnMaxLifetime: 14400, + Timescaledb: false, + ConfigurationMethod: "file-path", + } + + dsInfo := sqleng.DataSourceInfo{ + JsonData: jsonData, + DecryptedSecureJSONData: map[string]string{ + "password": "password", + }, + URL: "localhost:5432", + Database: "db", + User: "user", + } + + db, _, err := newPostgres("error", 10000, dsInfo, log.New(), &testNoResolveProxyClient{}) + require.NoError(t, err) + require.NotNil(t, db) + err = db.Ping() + require.Error(t, err) + require.Contains(t, err.Error(), "test-dialer: dialing to 'localhost:5432'. not implemented") } diff --git a/pkg/tsdb/grafana-postgresql-datasource/proxy.go b/pkg/tsdb/grafana-postgresql-datasource/proxy.go index d06d8b68152..0c836eb3459 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/proxy.go +++ b/pkg/tsdb/grafana-postgresql-datasource/proxy.go @@ -3,33 +3,17 @@ package postgres import ( "context" "net" - "time" - "github.com/lib/pq" "golang.org/x/net/proxy" ) -// we wrap the proxy.Dialer to become dialer that the postgres module accepts -func newPostgresProxyDialer(dialer proxy.Dialer) pq.Dialer { - return &postgresProxyDialer{d: dialer} -} - -var _ pq.Dialer = (&postgresProxyDialer{}) - -// postgresProxyDialer implements the postgres dialer using a proxy dialer, as their functions differ slightly -type postgresProxyDialer struct { - d proxy.Dialer -} - -// Dial uses the normal proxy dial function with the updated dialer -func (p *postgresProxyDialer) Dial(network, addr string) (c net.Conn, err error) { - return p.d.Dial(network, addr) -} - -// DialTimeout uses the normal postgres dial timeout function with the updated dialer -func (p *postgresProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - return p.d.(proxy.ContextDialer).DialContext(ctx, network, address) +type PgxDialFunc = func(ctx context.Context, network string, address string) (net.Conn, error) + +func newPgxDialFunc(dialer proxy.Dialer) PgxDialFunc { + dialFunc := + func(ctx context.Context, network string, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + + return dialFunc } diff --git a/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go b/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go index ec36e1a1eae..afd205bd372 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go @@ -1,12 +1,12 @@ package postgres import ( - "database/sql" "fmt" "net" "testing" - "github.com/lib/pq" + "github.com/jackc/pgx/v5" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/require" "golang.org/x/net/proxy" ) @@ -25,13 +25,13 @@ func TestPostgresProxyDriver(t *testing.T) { cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL) t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) { - connector, err := pq.NewConnector(cnnstr) + pgxConf, err := pgx.ParseConfig(cnnstr) require.NoError(t, err) - dialer := newPostgresProxyDialer(&testDialer{}) - connector.Dialer(dialer) + pgxConf.DialFunc = newPgxDialFunc(&testDialer{}) + + db := pgxstdlib.OpenDB(*pgxConf) - db := sql.OpenDB(connector) err = db.Ping() require.Contains(t, err.Error(), "test-dialer is not functional") diff --git a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc index af23cf8b5bd..e6d1dfd2382 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc +++ b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc @@ -9,16 +9,16 @@ // } // Name: // Dimensions: 4 Fields by 4 Rows -// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ -// | Name: reallyt | Name: time | Name: n | Name: timeend | -// | Labels: | Labels: | Labels: | Labels: | -// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | -// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ -// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:21:40 +0000 UTC | 1.7031613e+09 | 2023-12-21 12:22:52 +0000 UTC | -// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+12 | 2023-12-21 12:21:52.522 +0000 UTC | -// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+18 | 2023-12-21 12:21:52.522 +0000 UTC | -// | null | null | null | null | -// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ +// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ +// | Name: reallyt | Name: time | Name: n | Name: timeend | +// | Labels: | Labels: | Labels: | Labels: | +// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | +// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ +// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:22:24 +0000 UTC | 1.703161344e+09 | 2023-12-21 12:22:52 +0000 UTC | +// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:33.408 +0000 UTC | 1.703161233408e+12 | 2023-12-21 12:21:52.522 +0000 UTC | +// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:41.05 +0000 UTC | 1.703161241050022e+18 | 2023-12-21 12:21:52.522 +0000 UTC | +// | null | null | null | null | +// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ // // // 🌟 This was machine generated. Do not edit. 🌟 @@ -78,15 +78,15 @@ null ], [ - 1703161300000, - 1703161200000, - 1703161200000, + 1703161344000, + 1703161233408, + 1703161241050, null ], [ - 1703161300, - 1703161200000, - 1703161200000000000, + 1703161344, + 1703161233408, + 1703161241050022000, null ], [ diff --git a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc index 48e55e7b077..09e9403459b 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc +++ b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc @@ -9,14 +9,14 @@ // } // Name: // Dimensions: 12 Fields by 2 Rows -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | -// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | -// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | 2023-11-15 05:06:07.123456 +0000 +0000 | 2023-11-15 05:06:08.123456 +0000 +0000 | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 +0000 | 2023-12-21 00:00:00 +0000 +0000 | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | -// | null | 2023-11-15 05:06:09.123456 +0000 +0000 | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 +0000 | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | -// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | +// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | +// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | +// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// | 2023-11-15 05:06:07.123456 +0000 UTC | 2023-11-15 05:06:08.123456 +0000 UTC | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 UTC | 2023-12-21 00:00:00 +0000 UTC | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | +// | null | 2023-11-15 05:06:09.123456 +0000 UTC | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 UTC | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | +// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // // // 🌟 This was machine generated. Do not edit. 🌟 diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go new file mode 100644 index 00000000000..d550d616962 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go @@ -0,0 +1,135 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + + "github.com/grafana/grafana/pkg/tsdb/sqleng" +) + +// we support 4 postgres tls modes: +// disable - no tls +// require - use tls +// verify-ca - use tls, verify root cert but not the hostname +// verify-full - use tls, verify root cert +// (for all the options except `disable`, you can optionally use client certificates) + +func getTLSConfigRequire(certs *Certs, serverName string) (*tls.Config, error) { + // see https://www.postgresql.org/docs/12/libpq-ssl.html , + // mode=require + provided root-cert should behave as mode=verify-ca + if certs.rootCerts != nil { + return getTLSConfigVerifyCA(certs, serverName) + } + + return &tls.Config{ + InsecureSkipVerify: true, // we do not verify the root cert + Certificates: certs.clientCerts, + ServerName: serverName, + }, nil +} + +// to implement the verify-ca mode, we need to do this: +// - for the root certificate +// - verify that the certificate we receive from the server is trusted, +// meaning it relates to our root certificate +// - we DO NOT verify that the hostname of the database matches +// the hostname in the certificate +// +// the problem is, `go“ does not offer such an option. +// by default, it will verify both things. +// +// so what we do is: +// - we turn off the default-verification with `InsecureSkipVerify` +// - we implement our own verification using `VerifyConnection` +// +// extra info about this: +// - there is a rejected feature-request about this at https://github.com/golang/go/issues/21971 +// - the recommended workaround is based on VerifyPeerCertificate +// - there is even example code at https://github.com/golang/go/commit/29cfb4d3c3a97b6f426d1b899234da905be699aa +// - but later the example code was changed to use VerifyConnection instead: +// https://github.com/golang/go/commit/7eb5941b95a588a23f18fa4c22fe42ff0119c311 +// +// a verifyConnection example is at https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection . +// +// this is how the `pgx` library handles verify-ca: +// +// https://github.com/jackc/pgx/blob/5c63f646f820ca9696fc3515c1caf2a557d562e5/pgconn/config.go#L657-L690 +// (unfortunately pgx only handles this for certificate-provided-as-path, so we cannot rely on it) +func getTLSConfigVerifyCA(certs *Certs, serverName string) (*tls.Config, error) { + conf := tls.Config{ + ServerName: serverName, + Certificates: certs.clientCerts, + InsecureSkipVerify: true, // we turn off the default-verification, we'll do VerifyConnection instead + VerifyConnection: func(state tls.ConnectionState) error { + // we add all the certificates to the pool, we skip the first cert. + intermediates := x509.NewCertPool() + for _, c := range state.PeerCertificates[1:] { + intermediates.AddCert(c) + } + + opts := x509.VerifyOptions{ + Roots: certs.rootCerts, + Intermediates: intermediates, + } + + // we call `Verify()` on the first cert (that we skipped previously) + _, err := state.PeerCertificates[0].Verify(opts) + return err + }, + RootCAs: certs.rootCerts, + } + + return &conf, nil +} + +func getTLSConfigVerifyFull(certs *Certs, serverName string) (*tls.Config, error) { + conf := tls.Config{ + Certificates: certs.clientCerts, + ServerName: serverName, + RootCAs: certs.rootCerts, + } + + return &conf, nil +} + +func IsTLSEnabled(dsInfo sqleng.DataSourceInfo) bool { + mode := dsInfo.JsonData.Mode + return mode != "disable" +} + +// returns `nil` if tls is disabled +func GetTLSConfig(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc, serverName string) (*tls.Config, error) { + mode := dsInfo.JsonData.Mode + // we need to special-case the no-tls-mode + if mode == "disable" { + return nil, nil + } + + // for all the remaining cases we need to load + // both the root-cert if exists, and the client-cert if exists. + certBytes, err := loadCertificateBytes(dsInfo, readFile) + if err != nil { + return nil, err + } + + certs, err := createCertificates(certBytes) + if err != nil { + return nil, err + } + + switch mode { + // `disable` already handled + case "": + // for backward-compatibility reasons this is the same as `require` + return getTLSConfigRequire(certs, serverName) + case "require": + return getTLSConfigRequire(certs, serverName) + case "verify-ca": + return getTLSConfigVerifyCA(certs, serverName) + case "verify-full": + return getTLSConfigVerifyFull(certs, serverName) + default: + return nil, errors.New("tls: invalid mode " + mode) + } +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go new file mode 100644 index 00000000000..6c19d3801d5 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go @@ -0,0 +1,101 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "errors" + + "github.com/grafana/grafana/pkg/tsdb/sqleng" +) + +// this file deals with locating and loading the certificates, +// from json-data or from disk. + +type CertBytes struct { + rootCert []byte + clientKey []byte + clientCert []byte +} + +type ReadFileFunc = func(name string) ([]byte, error) + +var errPartialClientCertNoKey = errors.New("tls: client cert provided but client key missing") +var errPartialClientCertNoCert = errors.New("tls: client key provided but client cert missing") + +// certificates can be stored either as encrypted-json-data, or as file-path +func loadCertificateBytes(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc) (*CertBytes, error) { + if dsInfo.JsonData.ConfigurationMethod == "file-content" { + return &CertBytes{ + rootCert: []byte(dsInfo.DecryptedSecureJSONData["tlsCACert"]), + clientKey: []byte(dsInfo.DecryptedSecureJSONData["tlsClientKey"]), + clientCert: []byte(dsInfo.DecryptedSecureJSONData["tlsClientCert"]), + }, nil + } else { + c := CertBytes{} + + if dsInfo.JsonData.RootCertFile != "" { + rootCert, err := readFile(dsInfo.JsonData.RootCertFile) + if err != nil { + return nil, err + } + c.rootCert = rootCert + } + + if dsInfo.JsonData.CertKeyFile != "" { + clientKey, err := readFile(dsInfo.JsonData.CertKeyFile) + if err != nil { + return nil, err + } + c.clientKey = clientKey + } + + if dsInfo.JsonData.CertFile != "" { + clientCert, err := readFile(dsInfo.JsonData.CertFile) + if err != nil { + return nil, err + } + c.clientCert = clientCert + } + + return &c, nil + } +} + +type Certs struct { + clientCerts []tls.Certificate + rootCerts *x509.CertPool +} + +func createCertificates(certBytes *CertBytes) (*Certs, error) { + certs := Certs{} + + if len(certBytes.rootCert) > 0 { + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(certBytes.rootCert) + if !ok { + return nil, errors.New("tls: failed to add root certificate") + } + certs.rootCerts = pool + } + + hasClientKey := len(certBytes.clientKey) > 0 + hasClientCert := len(certBytes.clientCert) > 0 + + if hasClientKey && hasClientCert { + cert, err := tls.X509KeyPair(certBytes.clientCert, certBytes.clientKey) + if err != nil { + return nil, err + } + certs.clientCerts = []tls.Certificate{cert} + } + + if hasClientKey && (!hasClientCert) { + return nil, errPartialClientCertNoCert + } + + if hasClientCert && (!hasClientKey) { + return nil, errPartialClientCertNoKey + } + + return &certs, nil +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go new file mode 100644 index 00000000000..9d1c9729c0b --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go @@ -0,0 +1,402 @@ +package tls + +import ( + "errors" + "os" + "testing" + + "github.com/grafana/grafana/pkg/tsdb/sqleng" + "github.com/stretchr/testify/require" +) + +func noReadFile(path string) ([]byte, error) { + return nil, errors.New("not implemented") +} + +func TestTLSNoMode(t *testing.T) { + // for backward-compatibility reason, + // when mode is unset, it defaults to `require` + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + ConfigurationMethod: "", + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.True(t, c.InsecureSkipVerify) +} + +func TestTLSDisable(t *testing.T) { + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "disable", + ConfigurationMethod: "", + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.Nil(t, c) +} + +func TestTLSRequire(t *testing.T) { + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "", + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.True(t, c.InsecureSkipVerify) + require.Nil(t, c.RootCAs) +} + +func TestTLSRequireWithRootCert(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": string(rootCertBytes), + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.True(t, c.InsecureSkipVerify) + require.NotNil(t, c.VerifyConnection) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSVerifyCA(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-ca", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": string(rootCertBytes), + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.True(t, c.InsecureSkipVerify) + require.NotNil(t, c.VerifyConnection) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSVerifyCANoRootCertProvided(t *testing.T) { + // this is ok. go will use the default system certs + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-ca", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{}, + } + _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) +} + +func TestTLSClientCert(t *testing.T) { + clientKey, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsClientCert": string(clientCert), + "tlsClientKey": string(clientKey), + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.Len(t, c.Certificates, 1) +} + +func TestTLSMethodFileContentClientCertMissingKey(t *testing.T) { + _, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsClientCert": string(clientCert), + }, + } + _, err = GetTLSConfig(dsInfo, noReadFile, "localhost") + require.ErrorIs(t, err, errPartialClientCertNoKey) +} + +func TestTLSMethodFileContentClientCertMissingCert(t *testing.T) { + clientKey, _, err := CreateRandomClientCert() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsClientKey": string(clientKey), + }, + } + _, err = GetTLSConfig(dsInfo, noReadFile, "localhost") + require.ErrorIs(t, err, errPartialClientCertNoCert) +} + +func TestTLSMethodFilePathClientCertMissingKey(t *testing.T) { + clientKey, _, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "path1": clientKey, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-path", + CertKeyFile: "path1", + }, + } + _, err = GetTLSConfig(dsInfo, readFile, "localhost") + require.ErrorIs(t, err, errPartialClientCertNoCert) +} + +func TestTLSMethodFilePathClientCertMissingCert(t *testing.T) { + _, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "path1": clientCert, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-path", + CertFile: "path1", + }, + } + _, err = GetTLSConfig(dsInfo, readFile, "localhost") + require.ErrorIs(t, err, errPartialClientCertNoKey) +} + +func TestTLSVerifyFull(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": string(rootCertBytes), + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.False(t, c.InsecureSkipVerify) + require.Nil(t, c.VerifyConnection) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSMethodFileContent(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + clientKey, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{ + "tlsCACert": string(rootCertBytes), + "tlsClientCert": string(clientCert), + "tlsClientKey": string(clientKey), + }, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.Len(t, c.Certificates, 1) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSMethodFilePath(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + clientKey, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "root-cert-path": rootCertBytes, + "client-key-path": clientKey, + "client-cert-path": clientCert, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-path", + RootCertFile: "root-cert-path", + CertKeyFile: "client-key-path", + CertFile: "client-cert-path", + }, + } + c, err := GetTLSConfig(dsInfo, readFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.Len(t, c.Certificates, 1) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSMethodFilePathRootCertDoesNotExist(t *testing.T) { + readFile := newMockReadFile(map[string]([]byte){}) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-path", + RootCertFile: "path1", + }, + } + _, err := GetTLSConfig(dsInfo, readFile, "localhost") + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestTLSMethodFilePathClientCertKeyDoesNotExist(t *testing.T) { + _, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "cert-path": clientCert, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-path", + CertKeyFile: "key-path", + CertFile: "cert-path", + }, + } + _, err = GetTLSConfig(dsInfo, readFile, "localhost") + require.ErrorIs(t, err, os.ErrNotExist) +} + +func TestTLSMethodFilePathClientCertCertDoesNotExist(t *testing.T) { + clientKey, _, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "key-path": clientKey, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-path", + CertKeyFile: "key-path", + CertFile: "cert-path", + }, + } + _, err = GetTLSConfig(dsInfo, readFile, "localhost") + require.ErrorIs(t, err, os.ErrNotExist) +} + +// method="" equals to method="file-path" +func TestTLSMethodEmpty(t *testing.T) { + rootCertBytes, err := CreateRandomRootCertBytes() + require.NoError(t, err) + + clientKey, clientCert, err := CreateRandomClientCert() + require.NoError(t, err) + + readFile := newMockReadFile(map[string]([]byte){ + "root-cert-path": rootCertBytes, + "client-key-path": clientKey, + "client-cert-path": clientCert, + }) + + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "", + RootCertFile: "root-cert-path", + CertKeyFile: "client-key-path", + CertFile: "client-cert-path", + }, + } + c, err := GetTLSConfig(dsInfo, readFile, "localhost") + require.NoError(t, err) + require.NotNil(t, c) + require.Len(t, c.Certificates, 1) + require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available +} + +func TestTLSVerifyFullNoRootCertProvided(t *testing.T) { + // this is ok. go will use the default system certs + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + }, + DecryptedSecureJSONData: map[string]string{}, + } + _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.NoError(t, err) +} + +func TestTLSInvalidMode(t *testing.T) { + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: "not-a-valid-mode", + }, + } + + _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") + require.Error(t, err) +} + +func TestTLSServerNameSetInEveryMode(t *testing.T) { + modes := []string{"require", "verify-ca", "verify-full"} + + for _, mode := range modes { + t.Run(mode, func(t *testing.T) { + dsInfo := sqleng.DataSourceInfo{ + JsonData: sqleng.JsonData{ + Mode: mode, + }, + DecryptedSecureJSONData: map[string]string{}, + } + c, err := GetTLSConfig(dsInfo, noReadFile, "example.com") + require.NoError(t, err) + require.Equal(t, "example.com", c.ServerName) + }) + } +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go new file mode 100644 index 00000000000..1b62df63d09 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go @@ -0,0 +1,105 @@ +package tls + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "time" +) + +func CreateRandomRootCertBytes() ([]byte, error) { + cert := x509.Certificate{ + SerialNumber: big.NewInt(42), + Subject: pkix.Name{ + CommonName: "test1", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + bytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &key.PublicKey, key) + if err != nil { + return nil, err + } + + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: bytes, + }), nil +} + +func CreateRandomClientCert() ([]byte, []byte, error) { + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + keyBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + caCert := x509.Certificate{ + SerialNumber: big.NewInt(42), + Subject: pkix.Name{ + CommonName: "test1", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + cert := x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{ + CommonName: "test1", + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certData, err := x509.CreateCertificate(rand.Reader, &cert, &caCert, &key.PublicKey, caKey) + if err != nil { + return nil, nil, err + } + + certBytes := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certData, + }) + + return keyBytes, certBytes, nil +} + +func newMockReadFile(data map[string]([]byte)) ReadFileFunc { + return func(path string) ([]byte, error) { + bytes, ok := data[path] + if !ok { + return nil, os.ErrNotExist + } + return bytes, nil + } +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go deleted file mode 100644 index 116872d0613..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go +++ /dev/null @@ -1,249 +0,0 @@ -package postgres - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - - "github.com/grafana/grafana-plugin-sdk-go/backend/log" - "github.com/grafana/grafana/pkg/tsdb/sqleng" -) - -var validateCertFunc = validateCertFilePaths -var writeCertFileFunc = writeCertFile - -type certFileType int - -const ( - rootCert = iota - clientCert - clientKey -) - -type tlsSettingsProvider interface { - getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) -} - -type datasourceCacheManager struct { - locker *locker - cache sync.Map -} - -type tlsManager struct { - logger log.Logger - dsCacheInstance datasourceCacheManager - dataPath string -} - -func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider { - return &tlsManager{ - logger: logger, - dataPath: dataPath, - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, - } -} - -type tlsSettings struct { - Mode string - ConfigurationMethod string - RootCertFile string - CertFile string - CertKeyFile string -} - -func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { - tlsconfig := tlsSettings{ - Mode: dsInfo.JsonData.Mode, - } - - isTLSDisabled := (tlsconfig.Mode == "disable") - - if isTLSDisabled { - m.logger.Debug("Postgres TLS/SSL is disabled") - return tlsconfig, nil - } - - m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode) - - tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod - tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile - tlsconfig.CertFile = dsInfo.JsonData.CertFile - tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile - - if tlsconfig.ConfigurationMethod == "file-content" { - if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil { - return tlsconfig, err - } - } else { - if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil { - return tlsconfig, err - } - } - return tlsconfig, nil -} - -func (t certFileType) String() string { - switch t { - case rootCert: - return "root certificate" - case clientCert: - return "client certificate" - case clientKey: - return "client key" - default: - panic(fmt.Sprintf("Unrecognized certFileType %d", t)) - } -} - -func getFileName(dataDir string, fileType certFileType) string { - var filename string - switch fileType { - case rootCert: - filename = "root.crt" - case clientCert: - filename = "client.crt" - case clientKey: - filename = "client.key" - default: - panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String())) - } - generatedFilePath := filepath.Join(dataDir, filename) - return generatedFilePath -} - -// writeCertFile writes a certificate file. -func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error { - fileContent = strings.TrimSpace(fileContent) - if fileContent != "" { - logger.Debug("Writing cert file", "path", generatedFilePath) - if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil { - return err - } - // Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail - if err := os.Chmod(generatedFilePath, 0600); err != nil { - return err - } - return nil - } - - logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath) - exists, err := fileExists(generatedFilePath) - if err != nil { - return err - } - if exists { - if err := os.Remove(generatedFilePath); err != nil { - return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err) - } - } - return nil -} - -func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error { - m.logger.Debug("Writing TLS certificate files to disk") - tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"] - tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"] - tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"] - if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" { - m.logger.Debug("No TLS/SSL certificates provided") - } - - // Calculate all files path - workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts") - tlsconfig.RootCertFile = getFileName(workDir, rootCert) - tlsconfig.CertFile = getFileName(workDir, clientCert) - tlsconfig.CertKeyFile = getFileName(workDir, clientKey) - - // Find datasource in the cache, if found, skip writing files - cacheKey := strconv.Itoa(int(dsInfo.ID)) - m.dsCacheInstance.locker.RLock(cacheKey) - item, ok := m.dsCacheInstance.cache.Load(cacheKey) - m.dsCacheInstance.locker.RUnlock(cacheKey) - if ok { - if !item.(time.Time).Before(dsInfo.Updated) { - return nil - } - } - - m.dsCacheInstance.locker.Lock(cacheKey) - defer m.dsCacheInstance.locker.Unlock(cacheKey) - - item, ok = m.dsCacheInstance.cache.Load(cacheKey) - if ok { - if !item.(time.Time).Before(dsInfo.Updated) { - return nil - } - } - - // Write certification directory and files - exists, err := fileExists(workDir) - if err != nil { - return err - } - if !exists { - if err := os.MkdirAll(workDir, 0700); err != nil { - return err - } - } - - if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil { - return err - } - if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil { - return err - } - if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil { - return err - } - - // we do not want to point to cert-files that do not exist - if tlsRootCert == "" { - tlsconfig.RootCertFile = "" - } - - if tlsClientCert == "" { - tlsconfig.CertFile = "" - } - - if tlsClientKey == "" { - tlsconfig.CertKeyFile = "" - } - - // Update datasource cache - m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated) - return nil -} - -// validateCertFilePaths validates configured certificate file paths. -func validateCertFilePaths(rootCert, clientCert, clientKey string) error { - for _, fpath := range []string{rootCert, clientCert, clientKey} { - if fpath == "" { - continue - } - exists, err := fileExists(fpath) - if err != nil { - return err - } - if !exists { - return fmt.Errorf("certificate file %q doesn't exist", fpath) - } - } - return nil -} - -// Exists determines whether a file/directory exists or not. -func fileExists(fpath string) (bool, error) { - _, err := os.Stat(fpath) - if err != nil { - if !os.IsNotExist(err) { - return false, err - } - return false, nil - } - - return true, nil -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go deleted file mode 100644 index 8e60e841cab..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go +++ /dev/null @@ -1,332 +0,0 @@ -package postgres - -import ( - "fmt" - "path/filepath" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/grafana/grafana-plugin-sdk-go/backend" - "github.com/grafana/grafana-plugin-sdk-go/backend/log" - "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/tsdb/sqleng" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - _ "github.com/lib/pq" -) - -var writeCertFileCallNum int - -// TestDataSourceCacheManager is to test the Cache manager -func TestDataSourceCacheManager(t *testing.T) { - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() - mng := tlsManager{ - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, - dataPath: cfg.DataPath, - } - jsonData := sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - } - secureJSONData := map[string]string{ - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - "tlsCACert": "I am CA certification", - } - - updateTime := time.Now().Add(-5 * time.Minute) - - mockValidateCertFilePaths() - t.Cleanup(resetValidateCertFilePaths) - - t.Run("Check datasource cache creation", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(10) - for id := int64(1); id <= 10; id++ { - go func(id int64) { - ds := sqleng.DataSourceInfo{ - ID: id, - Updated: updateTime, - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(ds, &s) - require.NoError(t, err) - wg.Done() - }(id) - } - wg.Wait() - - t.Run("check cache creation is succeed", func(t *testing.T) { - for id := int64(1); id <= 10; id++ { - updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id))) - require.True(t, ok) - require.Equal(t, updateTime, updated) - } - }) - }) - - t.Run("Check datasource cache modification", func(t *testing.T) { - t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) { - mockWriteCertFile() - t.Cleanup(resetWriteCertFile) - var wg1 sync.WaitGroup - wg1.Add(5) - for id := int64(1); id <= 5; id++ { - go func(id int64) { - ds := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime, - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(ds, &s) - require.NoError(t, err) - wg1.Done() - }(id) - } - wg1.Wait() - assert.Equal(t, writeCertFileCallNum, 0) - }) - - t.Run("cache is updated with the last datasource version", func(t *testing.T) { - dsV2 := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime.Add(time.Minute), - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - dsV3 := sqleng.DataSourceInfo{ - ID: 1, - Updated: updateTime.Add(2 * time.Minute), - Database: "database", - JsonData: jsonData, - DecryptedSecureJSONData: secureJSONData, - UID: "testData", - } - s := tlsSettings{} - err := mng.writeCertFiles(dsV2, &s) - require.NoError(t, err) - err = mng.writeCertFiles(dsV3, &s) - require.NoError(t, err) - version, ok := mng.dsCacheInstance.cache.Load("1") - require.True(t, ok) - require.Equal(t, updateTime.Add(2*time.Minute), version) - }) - }) -} - -// Test getFileName - -func TestGetFileName(t *testing.T) { - testCases := []struct { - desc string - datadir string - fileType certFileType - expErr string - expectedGeneratedPath string - }{ - { - desc: "Get File Name for root certification", - datadir: ".", - fileType: rootCert, - expectedGeneratedPath: "root.crt", - }, - { - desc: "Get File Name for client certification", - datadir: ".", - fileType: clientCert, - expectedGeneratedPath: "client.crt", - }, - { - desc: "Get File Name for client certification", - datadir: ".", - fileType: clientKey, - expectedGeneratedPath: "client.key", - }, - } - for _, tt := range testCases { - t.Run(tt.desc, func(t *testing.T) { - generatedPath := getFileName(tt.datadir, tt.fileType) - assert.Equal(t, tt.expectedGeneratedPath, generatedPath) - }) - } -} - -// Test getTLSSettings. -func TestGetTLSSettings(t *testing.T) { - cfg := setting.NewCfg() - cfg.DataPath = t.TempDir() - - mockValidateCertFilePaths() - t.Cleanup(resetValidateCertFilePaths) - - updatedTime := time.Now() - - testCases := []struct { - desc string - expErr string - jsonData sqleng.JsonData - secureJSONData map[string]string - uid string - tlsSettings tlsSettings - updated time.Time - }{ - { - desc: "Custom TLS authentication disabled", - updated: updatedTime, - jsonData: sqleng.JsonData{ - Mode: "disable", - RootCertFile: "i/am/coding/ca.crt", - CertFile: "i/am/coding/client.crt", - CertKeyFile: "i/am/coding/client.key", - ConfigurationMethod: "file-path", - }, - tlsSettings: tlsSettings{Mode: "disable"}, - }, - { - desc: "Custom TLS authentication with file path", - updated: updatedTime.Add(time.Minute), - jsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-path", - RootCertFile: "i/am/coding/ca.crt", - CertFile: "i/am/coding/client.crt", - CertKeyFile: "i/am/coding/client.key", - }, - tlsSettings: tlsSettings{ - Mode: "verify-full", - ConfigurationMethod: "file-path", - RootCertFile: "i/am/coding/ca.crt", - CertFile: "i/am/coding/client.crt", - CertKeyFile: "i/am/coding/client.key", - }, - }, - { - desc: "Custom TLS mode verify-full with certificate files content", - updated: updatedTime.Add(2 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsCACert": "I am CA certification", - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - }, - tlsSettings: tlsSettings{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), - CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), - CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), - }, - }, - { - desc: "Custom TLS mode verify-ca with no client certificates with certificate files content", - updated: updatedTime.Add(3 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsCACert": "I am CA certification", - }, - tlsSettings: tlsSettings{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), - CertFile: "", - CertKeyFile: "", - }, - }, - { - desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content", - updated: updatedTime.Add(4 * time.Minute), - uid: "xxx", - jsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - secureJSONData: map[string]string{ - "tlsClientCert": "I am client certification", - "tlsClientKey": "I am client key", - }, - tlsSettings: tlsSettings{ - Mode: "require", - ConfigurationMethod: "file-content", - RootCertFile: "", - CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), - CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), - }, - }, - } - for _, tt := range testCases { - t.Run(tt.desc, func(t *testing.T) { - var settings tlsSettings - var err error - mng := tlsManager{ - logger: backend.NewLoggerWith("logger", "tsdb.postgres"), - dsCacheInstance: datasourceCacheManager{locker: newLocker()}, - dataPath: cfg.DataPath, - } - - ds := sqleng.DataSourceInfo{ - JsonData: tt.jsonData, - DecryptedSecureJSONData: tt.secureJSONData, - UID: tt.uid, - Updated: tt.updated, - } - - settings, err = mng.getTLSSettings(ds) - - if tt.expErr == "" { - require.NoError(t, err, tt.desc) - assert.Equal(t, tt.tlsSettings, settings) - } else { - require.Error(t, err, tt.desc) - assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), - fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr)) - } - }) - } -} - -func mockValidateCertFilePaths() { - validateCertFunc = func(rootCert, clientCert, clientKey string) error { - return nil - } -} - -func resetValidateCertFilePaths() { - validateCertFunc = validateCertFilePaths -} - -func mockWriteCertFile() { - writeCertFileCallNum = 0 - writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error { - writeCertFileCallNum++ - return nil - } -} - -func resetWriteCertFile() { - writeCertFileCallNum = 0 - writeCertFileFunc = writeCertFile -}