diff --git a/go.mod b/go.mod index 37e27eddb7f..061044a1a84 100644 --- a/go.mod +++ b/go.mod @@ -471,14 +471,9 @@ require ( github.com/grafana/grafana/pkg/apiserver v0.0.0-20240226124929-648abdbd0ea4 // @grafana/grafana-app-platform-squad ) -require github.com/jackc/pgx/v5 v5.5.3 // @grafana/oss-big-tent - require ( github.com/bufbuild/protocompile v0.4.0 // indirect github.com/grafana/sqlds/v3 v3.2.0 // indirect - github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jhump/protoreflect v1.15.1 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/krasun/gosqlparser v1.0.5 // @grafana/grafana-app-platform-squad diff --git a/go.sum b/go.sum index 30e3b806f88..a6aa605d076 100644 --- a/go.sum +++ b/go.sum @@ -2370,7 +2370,6 @@ github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bY github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= -github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= @@ -2381,8 +2380,6 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -2394,14 +2391,10 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= -github.com/jackc/pgx/v5 v5.5.3 h1:Ces6/M3wbDXYpM8JyyPD57ivTtJACFZJd885pdIaV2s= -github.com/jackc/pgx/v5 v5.5.3/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker.go b/pkg/tsdb/grafana-postgresql-datasource/locker.go new file mode 100644 index 00000000000..796c37c7415 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/locker.go @@ -0,0 +1,85 @@ +package postgres + +import ( + "fmt" + "sync" +) + +// locker is a named reader/writer mutual exclusion lock. +// The lock for each particular key can be held by an arbitrary number of readers or a single writer. +type locker struct { + locks map[any]*sync.RWMutex + locksRW *sync.RWMutex +} + +func newLocker() *locker { + return &locker{ + locks: make(map[any]*sync.RWMutex), + locksRW: new(sync.RWMutex), + } +} + +// Lock locks named rw mutex with specified key for writing. +// If the lock with the same key is already locked for reading or writing, +// Lock blocks until the lock is available. +func (lkr *locker) Lock(key any) { + lk, ok := lkr.getLock(key) + if !ok { + lk = lkr.newLock(key) + } + lk.Lock() +} + +// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is +// not locked for writing on entry to Unlock. +func (lkr *locker) Unlock(key any) { + lk, ok := lkr.getLock(key) + if !ok { + panic(fmt.Errorf("lock for key '%s' not initialized", key)) + } + lk.Unlock() +} + +// RLock locks named rw mutex with specified key for reading. +// +// It should not be used for recursive read locking for the same key; a blocked Lock +// call excludes new readers from acquiring the lock. See the +// documentation on the golang RWMutex type. +func (lkr *locker) RLock(key any) { + lk, ok := lkr.getLock(key) + if !ok { + lk = lkr.newLock(key) + } + lk.RLock() +} + +// RUnlock undoes a single RLock call for specified key; +// it does not affect other simultaneous readers of locker for specified key. +// It is a run-time error if locker for specified key is not locked for reading +func (lkr *locker) RUnlock(key any) { + lk, ok := lkr.getLock(key) + if !ok { + panic(fmt.Errorf("lock for key '%s' not initialized", key)) + } + lk.RUnlock() +} + +func (lkr *locker) newLock(key any) *sync.RWMutex { + lkr.locksRW.Lock() + defer lkr.locksRW.Unlock() + + if lk, ok := lkr.locks[key]; ok { + return lk + } + lk := new(sync.RWMutex) + lkr.locks[key] = lk + return lk +} + +func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) { + lkr.locksRW.RLock() + defer lkr.locksRW.RUnlock() + + lock, ok := lkr.locks[key] + return lock, ok +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/locker_test.go b/pkg/tsdb/grafana-postgresql-datasource/locker_test.go new file mode 100644 index 00000000000..b1dc64f0351 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/locker_test.go @@ -0,0 +1,63 @@ +package postgres + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestIntegrationLocker(t *testing.T) { + if testing.Short() { + t.Skip("Tests with Sleep") + } + const notUpdated = "not_updated" + const atThread1 = "at_thread_1" + const atThread2 = "at_thread_2" + t.Run("Should lock for same keys", func(t *testing.T) { + updated := notUpdated + locker := newLocker() + locker.Lock(1) + var wg sync.WaitGroup + wg.Add(1) + defer func() { + locker.Unlock(1) + wg.Wait() + }() + + go func() { + locker.RLock(1) + defer func() { + locker.RUnlock(1) + wg.Done() + }() + require.Equal(t, atThread1, updated, "Value should be updated in different thread") + updated = atThread2 + }() + time.Sleep(time.Millisecond * 10) + require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") + updated = atThread1 + }) + + t.Run("Should not lock for different keys", func(t *testing.T) { + updated := notUpdated + locker := newLocker() + locker.Lock(1) + defer locker.Unlock(1) + var wg sync.WaitGroup + wg.Add(1) + go func() { + locker.RLock(2) + defer func() { + locker.RUnlock(2) + wg.Done() + }() + require.Equal(t, notUpdated, updated, "Value should not be updated in different thread") + updated = atThread2 + }() + wg.Wait() + require.Equal(t, atThread2, updated, "Value should be updated in different thread") + updated = atThread1 + }) +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres.go b/pkg/tsdb/grafana-postgresql-datasource/postgres.go index cfcdd0ede00..53e46308857 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres.go @@ -4,42 +4,38 @@ import ( "context" "database/sql" "encoding/json" - "errors" "fmt" - "os" "reflect" "strconv" "strings" "time" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - pgxstdlib "github.com/jackc/pgx/v5/stdlib" - "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/datasource" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" + "github.com/lib/pq" "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana/pkg/setting" - "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls" "github.com/grafana/grafana/pkg/tsdb/sqleng" ) func ProvideService(cfg *setting.Cfg) *Service { logger := backend.NewLoggerWith("logger", "tsdb.postgres") s := &Service{ - logger: logger, + tlsManager: newTLSManager(logger, cfg.DataPath), + logger: logger, } s.im = datasource.NewInstanceManager(s.newInstanceSettings()) return s } type Service struct { - im instancemgmt.InstanceManager - logger log.Logger + tlsManager tlsSettingsProvider + im instancemgmt.InstanceManager + logger log.Logger } func (s *Service) getDSInfo(ctx context.Context, pluginCtx backend.PluginContext) (*sqleng.DataSourceHandler, error) { @@ -59,7 +55,13 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest) return dsInfo.QueryData(ctx, req) } -func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, pgxConf *pgx.ConnConfig, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) { +func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) { + connector, err := pq.NewConnector(cnnstr) + if err != nil { + logger.Error("postgres connector creation failed", "error", err) + return nil, nil, fmt.Errorf("postgres connector creation failed") + } + proxyClient, err := settings.ProxyClient(ctx) if err != nil { logger.Error("postgres proxy creation failed", "error", err) @@ -72,8 +74,9 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in logger.Error("postgres proxy creation failed", "error", err) return nil, nil, fmt.Errorf("postgres proxy creation failed") } - - pgxConf.DialFunc = newPgxDialFunc(dialer) + postgresDialer := newPostgresProxyDialer(dialer) + // update the postgres dialer with the proxy dialer + connector.Dialer(postgresDialer) } config := sqleng.DataPluginConfiguration{ @@ -84,7 +87,7 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in queryResultTransformer := postgresQueryResultTransformer{} - db := pgxstdlib.OpenDB(*pgxConf) + db := sql.OpenDB(connector) db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns) db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns) @@ -140,7 +143,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { DecryptedSecureJSONData: settings.DecryptedSecureJSONData, } - pgxConf, err := generateConnectionConfig(dsInfo) + cnnstr, err := s.generateConnectionString(dsInfo) if err != nil { return nil, err } @@ -150,7 +153,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc { return nil, err } - _, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, pgxConf, logger, settings) + _, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) if err != nil { logger.Error("Failed connecting to Postgres", "err", err) @@ -167,11 +170,13 @@ func escape(input string) string { return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) } -func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, error) { +func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) { + logger := s.logger var host string var port int if strings.HasPrefix(dsInfo.URL, "/") { host = dsInfo.URL + logger.Debug("Generating connection string with Unix socket specifier", "socket", host) } else { index := strings.LastIndex(dsInfo.URL, ":") v6Index := strings.Index(dsInfo.URL, "]") @@ -182,8 +187,12 @@ func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, er var err error port, err = strconv.Atoi(sp[1]) if err != nil { - return nil, fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) + return "", fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) } + + logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port) + } else { + logger.Debug("Generating connection string with network host", "host", host) } } else { if index == v6Index+1 { @@ -191,39 +200,46 @@ func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, er var err error port, err = strconv.Atoi(dsInfo.URL[index+1:]) if err != nil { - return nil, fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) + return "", fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) } + + logger.Debug("Generating ipv6 connection string with network host/port pair", "host", host, "port", port) } else { host = dsInfo.URL[1 : len(dsInfo.URL)-1] + logger.Debug("Generating ipv6 connection string with network host", "host", host) } } } - // NOTE: we always set sslmode=disable in the connection string, we handle TLS manually later - connStr := fmt.Sprintf("sslmode=disable user='%s' password='%s' host='%s' dbname='%s'", + connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database)) if port > 0 { connStr += fmt.Sprintf(" port=%d", port) } - conf, err := pgx.ParseConfig(connStr) + tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo) if err != nil { - return nil, err + return "", err } - tlsConf, err := tls.GetTLSConfig(dsInfo, os.ReadFile, host) - if err != nil { - return nil, err + connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode)) + + // Attach root certificate if provided + if tlsSettings.RootCertFile != "" { + logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile) + connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile)) } - // before we set the TLS config, we need to make sure the `.Fallbacks` attribute is unset, see: - // https://github.com/jackc/pgx/discussions/1903#discussioncomment-8430146 - if len(conf.Fallbacks) > 0 { - return nil, errors.New("tls: fallbacks configured, unable to set up TLS config") + // Attach client certificate and key if both are provided + if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" { + logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile) + connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile)) + } else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" { + return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified") } - conf.TLSConfig = tlsConf - return conf, nil + logger.Debug("Generated Postgres connection string successfully") + return connStr, nil } type postgresQueryResultTransformer struct{} @@ -251,44 +267,6 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter { return []sqlutil.StringConverter{ - { - Name: "handle TIME WITH TIME ZONE", - InputScanKind: reflect.Interface, - InputTypeName: strconv.Itoa(pgtype.TimetzOID), - ConversionFunc: func(in *string) (*string, error) { return in, nil }, - Replacer: &sqlutil.StringFieldReplacer{ - OutputFieldType: data.FieldTypeNullableTime, - ReplaceFunc: func(in *string) (any, error) { - if in == nil { - return nil, nil - } - v, err := time.Parse("15:04:05-07", *in) - if err != nil { - return nil, err - } - return &v, nil - }, - }, - }, - { - Name: "handle TIME", - InputScanKind: reflect.Interface, - InputTypeName: "TIME", - ConversionFunc: func(in *string) (*string, error) { return in, nil }, - Replacer: &sqlutil.StringFieldReplacer{ - OutputFieldType: data.FieldTypeNullableTime, - ReplaceFunc: func(in *string) (any, error) { - if in == nil { - return nil, nil - } - v, err := time.Parse("15:04:05", *in) - if err != nil { - return nil, err - } - return &v, nil - }, - }, - }, { Name: "handle FLOAT4", InputScanKind: reflect.Interface, diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go index dca97dc1dcb..0713e8cdacf 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_snapshot_test.go @@ -14,7 +14,6 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/experimental" - "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/tsdb/sqleng" @@ -52,7 +51,7 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { t.Skip() } - getCnn := func() (*pgx.ConnConfig, error) { + getCnnStr := func() string { host := os.Getenv("POSTGRES_HOST") if host == "" { host = "localhost" @@ -62,10 +61,8 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { port = "5432" } - cnnString := fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", + return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", host, port) - - return pgx.ParseConfig(cnnString) } sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`) @@ -160,10 +157,9 @@ func TestIntegrationPostgresSnapshots(t *testing.T) { logger := log.New() - cnn, err := getCnn() - require.NoError(t, err) + cnnstr := getCnnStr() - db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) + db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) t.Cleanup((func() { _, err := db.Exec("DROP TABLE tbl") diff --git a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go index 656b03b0a8d..2dec40dc837 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/postgres_test.go @@ -14,16 +14,19 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls" + "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/sqleng" - "github.com/jackc/pgx/v5" - _ "github.com/jackc/pgx/v5/stdlib" + _ "github.com/lib/pq" ) -func TestGenerateConnectionConfig(t *testing.T) { - rootCertBytes, err := tls.CreateRandomRootCertBytes() - require.NoError(t, err) +// Test generateConnectionString. +func TestIntegrationGenerateConnectionString(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() testCases := []struct { desc string @@ -31,15 +34,10 @@ func TestGenerateConnectionConfig(t *testing.T) { user string password string database string - tlsMode string - tlsRootCert []byte + tlsSettings tlsSettings + expConnStr string expErr string - expHost string - expPort uint16 - expUser string - expPassword string - expDatabase string - expTLS bool + uid string }{ { desc: "Unix socket host", @@ -47,11 +45,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "/var/run/postgresql", - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", }, { desc: "TCP host", @@ -59,12 +54,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "host", - expPort: 5432, - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", }, { desc: "TCP/port host", @@ -72,12 +63,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "host", - expPort: 1234, - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='host' dbname='database' port=1234 sslmode='verify-full'", }, { desc: "Ipv6 host", @@ -85,11 +72,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "::1", - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='::1' dbname='database' sslmode='verify-full'", }, { desc: "Ipv6/port host", @@ -97,20 +81,16 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "::1", - expPort: 1234, - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: "user='user' password='password' host='::1' dbname='database' port=1234 sslmode='verify-full'", }, { - desc: "Invalid port", - host: "host:invalid", - user: "user", - database: "database", - tlsMode: "disable", - expErr: "invalid port in host specifier", + desc: "Invalid port", + host: "host:invalid", + user: "user", + database: "database", + tlsSettings: tlsSettings{}, + expErr: "invalid port in host specifier", }, { desc: "Password with single quote and backslash", @@ -118,11 +98,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: `p'\assword`, database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: `p'\assword`, - expHost: "host", - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, }, { desc: "User/DB with single quote and backslash", @@ -130,11 +107,8 @@ func TestGenerateConnectionConfig(t *testing.T) { user: `u'\ser`, password: `password`, database: `d'\atabase`, - tlsMode: "disable", - expUser: `u'\ser`, - expPassword: "password", - expDatabase: `d'\atabase`, - expHost: "host", + tlsSettings: tlsSettings{Mode: "verify-full"}, + expConnStr: `user='u\'\\ser' password='password' host='host' dbname='d\'\\atabase' sslmode='verify-full'`, }, { desc: "Custom TLS mode disabled", @@ -142,55 +116,45 @@ func TestGenerateConnectionConfig(t *testing.T) { user: "user", password: "password", database: "database", - tlsMode: "disable", - expUser: "user", - expPassword: "password", - expHost: "host", - expDatabase: "database", + tlsSettings: tlsSettings{Mode: "disable"}, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", }, { - desc: "Custom TLS mode verify-full with certificate files", - host: "host", - user: "user", - password: "password", - database: "database", - tlsMode: "verify-full", - tlsRootCert: rootCertBytes, - expUser: "user", - expPassword: "password", - expDatabase: "database", - expHost: "host", - expTLS: true, + desc: "Custom TLS mode verify-full with certificate files", + host: "host", + user: "user", + password: "password", + database: "database", + tlsSettings: tlsSettings{ + Mode: "verify-full", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + }, + expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' " + + "sslrootcert='i/am/coding/ca.crt' sslcert='i/am/coding/client.crt' sslkey='i/am/coding/client.key'", }, } for _, tt := range testCases { t.Run(tt.desc, func(t *testing.T) { - ds := sqleng.DataSourceInfo{ - URL: tt.host, - User: tt.user, - DecryptedSecureJSONData: map[string]string{ - "password": tt.password, - "tlsCACert": string(tt.tlsRootCert), - }, - Database: tt.database, - JsonData: sqleng.JsonData{ - Mode: tt.tlsMode, - ConfigurationMethod: "file-content", - }, + svc := Service{ + tlsManager: &tlsTestManager{settings: tt.tlsSettings}, + logger: backend.NewLoggerWith("logger", "tsdb.postgres"), } - c, err := generateConnectionConfig(ds) + ds := sqleng.DataSourceInfo{ + URL: tt.host, + User: tt.user, + DecryptedSecureJSONData: map[string]string{"password": tt.password}, + Database: tt.database, + UID: tt.uid, + } + + connStr, err := svc.generateConnectionString(ds) if tt.expErr == "" { require.NoError(t, err, tt.desc) - assert.Equal(t, tt.expHost, c.Host) - if tt.expPort != 0 { - assert.Equal(t, tt.expPort, c.Port) - } - assert.Equal(t, tt.expUser, c.User) - assert.Equal(t, tt.expDatabase, c.Database) - assert.Equal(t, tt.expPassword, c.Password) - require.Equal(t, tt.expTLS, c.TLSConfig != nil) + assert.Equal(t, tt.expConnStr, connStr) } else { require.Error(t, err, tt.desc) assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), @@ -242,10 +206,9 @@ func TestIntegrationPostgres(t *testing.T) { logger := backend.NewLoggerWith("logger", "postgres.test") - cnn, err := postgresTestDBConn() - require.NoError(t, err) + cnnstr := postgresTestDBConnString() - db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) + db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err) @@ -1299,7 +1262,7 @@ func TestIntegrationPostgres(t *testing.T) { t.Run("When row limit set to 1", func(t *testing.T) { dsInfo := sqleng.DataSourceInfo{} - _, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) + _, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err) @@ -1414,6 +1377,14 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim return timeRange } +type tlsTestManager struct { + settings tlsSettings +} + +func (m *tlsTestManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { + return m.settings, nil +} + func isTestDbPostgres() bool { if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { return db == "postgres" @@ -1422,7 +1393,7 @@ func isTestDbPostgres() bool { return false } -func postgresTestDBConn() (*pgx.ConnConfig, error) { +func postgresTestDBConnString() string { host := os.Getenv("POSTGRES_HOST") if host == "" { host = "localhost" @@ -1431,8 +1402,6 @@ func postgresTestDBConn() (*pgx.ConnConfig, error) { if port == "" { port = "5432" } - connStr := fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", + return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", host, port) - - return pgx.ParseConfig(connStr) } diff --git a/pkg/tsdb/grafana-postgresql-datasource/proxy.go b/pkg/tsdb/grafana-postgresql-datasource/proxy.go index 0c836eb3459..d06d8b68152 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/proxy.go +++ b/pkg/tsdb/grafana-postgresql-datasource/proxy.go @@ -3,17 +3,33 @@ package postgres import ( "context" "net" + "time" + "github.com/lib/pq" "golang.org/x/net/proxy" ) -type PgxDialFunc = func(ctx context.Context, network string, address string) (net.Conn, error) - -func newPgxDialFunc(dialer proxy.Dialer) PgxDialFunc { - dialFunc := - func(ctx context.Context, network string, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - - return dialFunc +// we wrap the proxy.Dialer to become dialer that the postgres module accepts +func newPostgresProxyDialer(dialer proxy.Dialer) pq.Dialer { + return &postgresProxyDialer{d: dialer} +} + +var _ pq.Dialer = (&postgresProxyDialer{}) + +// postgresProxyDialer implements the postgres dialer using a proxy dialer, as their functions differ slightly +type postgresProxyDialer struct { + d proxy.Dialer +} + +// Dial uses the normal proxy dial function with the updated dialer +func (p *postgresProxyDialer) Dial(network, addr string) (c net.Conn, err error) { + return p.d.Dial(network, addr) +} + +// DialTimeout uses the normal postgres dial timeout function with the updated dialer +func (p *postgresProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + return p.d.(proxy.ContextDialer).DialContext(ctx, network, address) } diff --git a/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go b/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go index afd205bd372..ec36e1a1eae 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go +++ b/pkg/tsdb/grafana-postgresql-datasource/proxy_test.go @@ -1,12 +1,12 @@ package postgres import ( + "database/sql" "fmt" "net" "testing" - "github.com/jackc/pgx/v5" - pgxstdlib "github.com/jackc/pgx/v5/stdlib" + "github.com/lib/pq" "github.com/stretchr/testify/require" "golang.org/x/net/proxy" ) @@ -25,13 +25,13 @@ func TestPostgresProxyDriver(t *testing.T) { cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL) t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) { - pgxConf, err := pgx.ParseConfig(cnnstr) + connector, err := pq.NewConnector(cnnstr) require.NoError(t, err) + dialer := newPostgresProxyDialer(&testDialer{}) - pgxConf.DialFunc = newPgxDialFunc(&testDialer{}) - - db := pgxstdlib.OpenDB(*pgxConf) + connector.Dialer(dialer) + db := sql.OpenDB(connector) err = db.Ping() require.Contains(t, err.Error(), "test-dialer is not functional") diff --git a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc index e6d1dfd2382..af23cf8b5bd 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc +++ b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/timestamp_convert_real.golden.jsonc @@ -9,16 +9,16 @@ // } // Name: // Dimensions: 4 Fields by 4 Rows -// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ -// | Name: reallyt | Name: time | Name: n | Name: timeend | -// | Labels: | Labels: | Labels: | Labels: | -// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | -// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ -// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:22:24 +0000 UTC | 1.703161344e+09 | 2023-12-21 12:22:52 +0000 UTC | -// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:33.408 +0000 UTC | 1.703161233408e+12 | 2023-12-21 12:21:52.522 +0000 UTC | -// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:41.05 +0000 UTC | 1.703161241050022e+18 | 2023-12-21 12:21:52.522 +0000 UTC | -// | null | null | null | null | -// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ +// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ +// | Name: reallyt | Name: time | Name: n | Name: timeend | +// | Labels: | Labels: | Labels: | Labels: | +// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | +// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ +// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:21:40 +0000 UTC | 1.7031613e+09 | 2023-12-21 12:22:52 +0000 UTC | +// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+12 | 2023-12-21 12:21:52.522 +0000 UTC | +// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+18 | 2023-12-21 12:21:52.522 +0000 UTC | +// | null | null | null | null | +// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ // // // 🌟 This was machine generated. Do not edit. 🌟 @@ -78,15 +78,15 @@ null ], [ - 1703161344000, - 1703161233408, - 1703161241050, + 1703161300000, + 1703161200000, + 1703161200000, null ], [ - 1703161344, - 1703161233408, - 1703161241050022000, + 1703161300, + 1703161200000, + 1703161200000000000, null ], [ diff --git a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc index 09e9403459b..48e55e7b077 100644 --- a/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc +++ b/pkg/tsdb/grafana-postgresql-datasource/testdata/table/types_datetime.golden.jsonc @@ -9,14 +9,14 @@ // } // Name: // Dimensions: 12 Fields by 2 Rows -// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | -// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | -// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | -// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ -// | 2023-11-15 05:06:07.123456 +0000 UTC | 2023-11-15 05:06:08.123456 +0000 UTC | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 UTC | 2023-12-21 00:00:00 +0000 UTC | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | -// | null | 2023-11-15 05:06:09.123456 +0000 UTC | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 UTC | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | -// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | +// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | +// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | +// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ +// | 2023-11-15 05:06:07.123456 +0000 +0000 | 2023-11-15 05:06:08.123456 +0000 +0000 | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 +0000 | 2023-12-21 00:00:00 +0000 +0000 | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | +// | null | 2023-11-15 05:06:09.123456 +0000 +0000 | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 +0000 | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | +// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // // // 🌟 This was machine generated. Do not edit. 🌟 diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go deleted file mode 100644 index e5e409dc55b..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tls/tls.go +++ /dev/null @@ -1,147 +0,0 @@ -package tls - -import ( - "crypto/tls" - "crypto/x509" - "errors" - - "github.com/grafana/grafana/pkg/tsdb/sqleng" -) - -// we support 4 postgres tls modes: -// disable - no tls -// require - use tls -// verify-ca - use tls, verify root cert but not the hostname -// verify-full - use tls, verify root cert -// (for all the options except `disable`, you can optionally use client certificates) - -var errNoRootCert = errors.New("tls: missing root certificate") - -func getTLSConfigRequire(certs *Certs) (*tls.Config, error) { - // we may have a client-cert, we do not have a root-cert - - // see https://www.postgresql.org/docs/12/libpq-ssl.html , - // mode=require + provided root-cert should behave as mode=verify-ca - if certs.rootCerts != nil { - return getTLSConfigVerifyCA(certs) - } - - return &tls.Config{ - InsecureSkipVerify: true, // we do not verify the root cert - Certificates: certs.clientCerts, - }, nil -} - -// to implement the verify-ca mode, we need to do this: -// - for the root certificate -// - verify that the certificate we receive from the server is trusted, -// meaning it relates to our root certificate -// - we DO NOT verify that the hostname of the database matches -// the hostname in the certificate -// -// the problem is, `go“ does not offer such an option. -// by default, it will verify both things. -// -// so what we do is: -// - we turn off the default-verification with `InsecureSkipVerify` -// - we implement our own verification using `VerifyConnection` -// -// extra info about this: -// - there is a rejected feature-request about this at https://github.com/golang/go/issues/21971 -// - the recommended workaround is based on VerifyPeerCertificate -// - there is even example code at https://github.com/golang/go/commit/29cfb4d3c3a97b6f426d1b899234da905be699aa -// - but later the example code was changed to use VerifyConnection instead: -// https://github.com/golang/go/commit/7eb5941b95a588a23f18fa4c22fe42ff0119c311 -// -// a verifyConnection example is at https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection . -// -// this is how the `pgx` library handles verify-ca: -// -// https://github.com/jackc/pgx/blob/5c63f646f820ca9696fc3515c1caf2a557d562e5/pgconn/config.go#L657-L690 -// (unfortunately pgx only handles this for certificate-provided-as-path, so we cannot rely on it) -func getTLSConfigVerifyCA(certs *Certs) (*tls.Config, error) { - // we must have a root certificate - if certs.rootCerts == nil { - return nil, errNoRootCert - } - - conf := tls.Config{ - Certificates: certs.clientCerts, - InsecureSkipVerify: true, // we turn off the default-verification, we'll do VerifyConnection instead - VerifyConnection: func(state tls.ConnectionState) error { - // we add all the certificates to the pool, we skip the first cert. - intermediates := x509.NewCertPool() - for _, c := range state.PeerCertificates[1:] { - intermediates.AddCert(c) - } - - opts := x509.VerifyOptions{ - Roots: certs.rootCerts, - Intermediates: intermediates, - } - - // we call `Verify()` on the first cert (that we skipped previously) - _, err := state.PeerCertificates[0].Verify(opts) - return err - }, - RootCAs: certs.rootCerts, - } - - return &conf, nil -} - -func getTLSConfigVerifyFull(certs *Certs, serverName string) (*tls.Config, error) { - // we must have a root certificate - if certs.rootCerts == nil { - return nil, errNoRootCert - } - - conf := tls.Config{ - Certificates: certs.clientCerts, - ServerName: serverName, - RootCAs: certs.rootCerts, - } - - return &conf, nil -} - -func IsTLSEnabled(dsInfo sqleng.DataSourceInfo) bool { - mode := dsInfo.JsonData.Mode - return mode != "disable" -} - -// returns `nil` if tls is disabled -func GetTLSConfig(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc, serverName string) (*tls.Config, error) { - mode := dsInfo.JsonData.Mode - // we need to special-case the no-tls-mode - if mode == "disable" { - return nil, nil - } - - // for all the remaining cases we need to load - // both the root-cert if exists, and the client-cert if exists. - certBytes, err := loadCertificateBytes(dsInfo, readFile) - if err != nil { - return nil, err - } - - certs, err := createCertificates(certBytes) - if err != nil { - return nil, err - } - - switch mode { - // `disable` already handled - case "": - // for backward-compatibility reasons this is the same as `require` - return getTLSConfigRequire(certs) - case "require": - return getTLSConfigRequire(certs) - case "verify-ca": - return getTLSConfigVerifyCA(certs) - case "verify-full": - return getTLSConfigVerifyFull(certs, serverName) - default: - return nil, errors.New("tls: invalid mode " + mode) - } -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go deleted file mode 100644 index 6c19d3801d5..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_loader.go +++ /dev/null @@ -1,101 +0,0 @@ -package tls - -import ( - "crypto/tls" - "crypto/x509" - "errors" - - "github.com/grafana/grafana/pkg/tsdb/sqleng" -) - -// this file deals with locating and loading the certificates, -// from json-data or from disk. - -type CertBytes struct { - rootCert []byte - clientKey []byte - clientCert []byte -} - -type ReadFileFunc = func(name string) ([]byte, error) - -var errPartialClientCertNoKey = errors.New("tls: client cert provided but client key missing") -var errPartialClientCertNoCert = errors.New("tls: client key provided but client cert missing") - -// certificates can be stored either as encrypted-json-data, or as file-path -func loadCertificateBytes(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc) (*CertBytes, error) { - if dsInfo.JsonData.ConfigurationMethod == "file-content" { - return &CertBytes{ - rootCert: []byte(dsInfo.DecryptedSecureJSONData["tlsCACert"]), - clientKey: []byte(dsInfo.DecryptedSecureJSONData["tlsClientKey"]), - clientCert: []byte(dsInfo.DecryptedSecureJSONData["tlsClientCert"]), - }, nil - } else { - c := CertBytes{} - - if dsInfo.JsonData.RootCertFile != "" { - rootCert, err := readFile(dsInfo.JsonData.RootCertFile) - if err != nil { - return nil, err - } - c.rootCert = rootCert - } - - if dsInfo.JsonData.CertKeyFile != "" { - clientKey, err := readFile(dsInfo.JsonData.CertKeyFile) - if err != nil { - return nil, err - } - c.clientKey = clientKey - } - - if dsInfo.JsonData.CertFile != "" { - clientCert, err := readFile(dsInfo.JsonData.CertFile) - if err != nil { - return nil, err - } - c.clientCert = clientCert - } - - return &c, nil - } -} - -type Certs struct { - clientCerts []tls.Certificate - rootCerts *x509.CertPool -} - -func createCertificates(certBytes *CertBytes) (*Certs, error) { - certs := Certs{} - - if len(certBytes.rootCert) > 0 { - pool := x509.NewCertPool() - ok := pool.AppendCertsFromPEM(certBytes.rootCert) - if !ok { - return nil, errors.New("tls: failed to add root certificate") - } - certs.rootCerts = pool - } - - hasClientKey := len(certBytes.clientKey) > 0 - hasClientCert := len(certBytes.clientCert) > 0 - - if hasClientKey && hasClientCert { - cert, err := tls.X509KeyPair(certBytes.clientCert, certBytes.clientKey) - if err != nil { - return nil, err - } - certs.clientCerts = []tls.Certificate{cert} - } - - if hasClientKey && (!hasClientCert) { - return nil, errPartialClientCertNoCert - } - - if hasClientCert && (!hasClientKey) { - return nil, errPartialClientCertNoKey - } - - return &certs, nil -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go deleted file mode 100644 index bce63e1c43f..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test.go +++ /dev/null @@ -1,382 +0,0 @@ -package tls - -import ( - "errors" - "os" - "testing" - - "github.com/grafana/grafana/pkg/tsdb/sqleng" - "github.com/stretchr/testify/require" -) - -func noReadFile(path string) ([]byte, error) { - return nil, errors.New("not implemented") -} - -func TestTLSNoMode(t *testing.T) { - // for backward-compatibility reason, - // when mode is unset, it defaults to `require` - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - ConfigurationMethod: "", - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.True(t, c.InsecureSkipVerify) -} - -func TestTLSDisable(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "disable", - ConfigurationMethod: "", - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.Nil(t, c) -} - -func TestTLSRequire(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "", - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.True(t, c.InsecureSkipVerify) - require.Nil(t, c.RootCAs) -} - -func TestTLSRequireWithRootCert(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsCACert": string(rootCertBytes), - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.True(t, c.InsecureSkipVerify) - require.NotNil(t, c.VerifyConnection) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSVerifyCA(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsCACert": string(rootCertBytes), - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.True(t, c.InsecureSkipVerify) - require.NotNil(t, c.VerifyConnection) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSVerifyCAMisingRootCert(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-ca", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{}, - } - _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.ErrorIs(t, err, errNoRootCert) -} - -func TestTLSClientCert(t *testing.T) { - clientKey, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsClientCert": string(clientCert), - "tlsClientKey": string(clientKey), - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.Len(t, c.Certificates, 1) -} - -func TestTLSMethodFileContentClientCertMissingKey(t *testing.T) { - _, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsClientCert": string(clientCert), - }, - } - _, err = GetTLSConfig(dsInfo, noReadFile, "localhost") - require.ErrorIs(t, err, errPartialClientCertNoKey) -} - -func TestTLSMethodFileContentClientCertMissingCert(t *testing.T) { - clientKey, _, err := CreateRandomClientCert() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsClientKey": string(clientKey), - }, - } - _, err = GetTLSConfig(dsInfo, noReadFile, "localhost") - require.ErrorIs(t, err, errPartialClientCertNoCert) -} - -func TestTLSMethodFilePathClientCertMissingKey(t *testing.T) { - clientKey, _, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "path1": clientKey, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-path", - CertKeyFile: "path1", - }, - } - _, err = GetTLSConfig(dsInfo, readFile, "localhost") - require.ErrorIs(t, err, errPartialClientCertNoCert) -} - -func TestTLSMethodFilePathClientCertMissingCert(t *testing.T) { - _, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "path1": clientCert, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-path", - CertFile: "path1", - }, - } - _, err = GetTLSConfig(dsInfo, readFile, "localhost") - require.ErrorIs(t, err, errPartialClientCertNoKey) -} - -func TestTLSVerifyFull(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsCACert": string(rootCertBytes), - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.False(t, c.InsecureSkipVerify) - require.Nil(t, c.VerifyConnection) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSMethodFileContent(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - clientKey, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{ - "tlsCACert": string(rootCertBytes), - "tlsClientCert": string(clientCert), - "tlsClientKey": string(clientKey), - }, - } - c, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.Len(t, c.Certificates, 1) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSMethodFilePath(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - clientKey, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "root-cert-path": rootCertBytes, - "client-key-path": clientKey, - "client-cert-path": clientCert, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-path", - RootCertFile: "root-cert-path", - CertKeyFile: "client-key-path", - CertFile: "client-cert-path", - }, - } - c, err := GetTLSConfig(dsInfo, readFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.Len(t, c.Certificates, 1) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSMethodFilePathRootCertDoesNotExist(t *testing.T) { - readFile := newMockReadFile(map[string]([]byte){}) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-path", - RootCertFile: "path1", - }, - } - _, err := GetTLSConfig(dsInfo, readFile, "localhost") - require.ErrorIs(t, err, os.ErrNotExist) -} - -func TestTLSMethodFilePathClientCertKeyDoesNotExist(t *testing.T) { - _, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "cert-path": clientCert, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-path", - CertKeyFile: "key-path", - CertFile: "cert-path", - }, - } - _, err = GetTLSConfig(dsInfo, readFile, "localhost") - require.ErrorIs(t, err, os.ErrNotExist) -} - -func TestTLSMethodFilePathClientCertCertDoesNotExist(t *testing.T) { - clientKey, _, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "key-path": clientKey, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "require", - ConfigurationMethod: "file-path", - CertKeyFile: "key-path", - CertFile: "cert-path", - }, - } - _, err = GetTLSConfig(dsInfo, readFile, "localhost") - require.ErrorIs(t, err, os.ErrNotExist) -} - -// method="" equals to method="file-path" -func TestTLSMethodEmpty(t *testing.T) { - rootCertBytes, err := CreateRandomRootCertBytes() - require.NoError(t, err) - - clientKey, clientCert, err := CreateRandomClientCert() - require.NoError(t, err) - - readFile := newMockReadFile(map[string]([]byte){ - "root-cert-path": rootCertBytes, - "client-key-path": clientKey, - "client-cert-path": clientCert, - }) - - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "", - RootCertFile: "root-cert-path", - CertKeyFile: "client-key-path", - CertFile: "client-cert-path", - }, - } - c, err := GetTLSConfig(dsInfo, readFile, "localhost") - require.NoError(t, err) - require.NotNil(t, c) - require.Len(t, c.Certificates, 1) - require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available -} - -func TestTLSVerifyFullMisingRootCert(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "verify-full", - ConfigurationMethod: "file-content", - }, - DecryptedSecureJSONData: map[string]string{}, - } - _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.ErrorIs(t, err, errNoRootCert) -} - -func TestTLSInvalidMode(t *testing.T) { - dsInfo := sqleng.DataSourceInfo{ - JsonData: sqleng.JsonData{ - Mode: "not-a-valid-mode", - }, - } - - _, err := GetTLSConfig(dsInfo, noReadFile, "localhost") - require.Error(t, err) -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go b/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go deleted file mode 100644 index 1b62df63d09..00000000000 --- a/pkg/tsdb/grafana-postgresql-datasource/tls/tls_test_helpers.go +++ /dev/null @@ -1,105 +0,0 @@ -package tls - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "os" - "time" -) - -func CreateRandomRootCertBytes() ([]byte, error) { - cert := x509.Certificate{ - SerialNumber: big.NewInt(42), - Subject: pkix.Name{ - CommonName: "test1", - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - - bytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &key.PublicKey, key) - if err != nil { - return nil, err - } - - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: bytes, - }), nil -} - -func CreateRandomClientCert() ([]byte, []byte, error) { - caKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - keyBytes := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) - - caCert := x509.Certificate{ - SerialNumber: big.NewInt(42), - Subject: pkix.Name{ - CommonName: "test1", - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - cert := x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{ - CommonName: "test1", - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - - certData, err := x509.CreateCertificate(rand.Reader, &cert, &caCert, &key.PublicKey, caKey) - if err != nil { - return nil, nil, err - } - - certBytes := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certData, - }) - - return keyBytes, certBytes, nil -} - -func newMockReadFile(data map[string]([]byte)) ReadFileFunc { - return func(path string) ([]byte, error) { - bytes, ok := data[path] - if !ok { - return nil, os.ErrNotExist - } - return bytes, nil - } -} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go new file mode 100644 index 00000000000..116872d0613 --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager.go @@ -0,0 +1,249 @@ +package postgres + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana/pkg/tsdb/sqleng" +) + +var validateCertFunc = validateCertFilePaths +var writeCertFileFunc = writeCertFile + +type certFileType int + +const ( + rootCert = iota + clientCert + clientKey +) + +type tlsSettingsProvider interface { + getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) +} + +type datasourceCacheManager struct { + locker *locker + cache sync.Map +} + +type tlsManager struct { + logger log.Logger + dsCacheInstance datasourceCacheManager + dataPath string +} + +func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider { + return &tlsManager{ + logger: logger, + dataPath: dataPath, + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + } +} + +type tlsSettings struct { + Mode string + ConfigurationMethod string + RootCertFile string + CertFile string + CertKeyFile string +} + +func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) { + tlsconfig := tlsSettings{ + Mode: dsInfo.JsonData.Mode, + } + + isTLSDisabled := (tlsconfig.Mode == "disable") + + if isTLSDisabled { + m.logger.Debug("Postgres TLS/SSL is disabled") + return tlsconfig, nil + } + + m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode) + + tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod + tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile + tlsconfig.CertFile = dsInfo.JsonData.CertFile + tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile + + if tlsconfig.ConfigurationMethod == "file-content" { + if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil { + return tlsconfig, err + } + } else { + if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil { + return tlsconfig, err + } + } + return tlsconfig, nil +} + +func (t certFileType) String() string { + switch t { + case rootCert: + return "root certificate" + case clientCert: + return "client certificate" + case clientKey: + return "client key" + default: + panic(fmt.Sprintf("Unrecognized certFileType %d", t)) + } +} + +func getFileName(dataDir string, fileType certFileType) string { + var filename string + switch fileType { + case rootCert: + filename = "root.crt" + case clientCert: + filename = "client.crt" + case clientKey: + filename = "client.key" + default: + panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String())) + } + generatedFilePath := filepath.Join(dataDir, filename) + return generatedFilePath +} + +// writeCertFile writes a certificate file. +func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error { + fileContent = strings.TrimSpace(fileContent) + if fileContent != "" { + logger.Debug("Writing cert file", "path", generatedFilePath) + if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil { + return err + } + // Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail + if err := os.Chmod(generatedFilePath, 0600); err != nil { + return err + } + return nil + } + + logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath) + exists, err := fileExists(generatedFilePath) + if err != nil { + return err + } + if exists { + if err := os.Remove(generatedFilePath); err != nil { + return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err) + } + } + return nil +} + +func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error { + m.logger.Debug("Writing TLS certificate files to disk") + tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"] + tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"] + tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"] + if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" { + m.logger.Debug("No TLS/SSL certificates provided") + } + + // Calculate all files path + workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts") + tlsconfig.RootCertFile = getFileName(workDir, rootCert) + tlsconfig.CertFile = getFileName(workDir, clientCert) + tlsconfig.CertKeyFile = getFileName(workDir, clientKey) + + // Find datasource in the cache, if found, skip writing files + cacheKey := strconv.Itoa(int(dsInfo.ID)) + m.dsCacheInstance.locker.RLock(cacheKey) + item, ok := m.dsCacheInstance.cache.Load(cacheKey) + m.dsCacheInstance.locker.RUnlock(cacheKey) + if ok { + if !item.(time.Time).Before(dsInfo.Updated) { + return nil + } + } + + m.dsCacheInstance.locker.Lock(cacheKey) + defer m.dsCacheInstance.locker.Unlock(cacheKey) + + item, ok = m.dsCacheInstance.cache.Load(cacheKey) + if ok { + if !item.(time.Time).Before(dsInfo.Updated) { + return nil + } + } + + // Write certification directory and files + exists, err := fileExists(workDir) + if err != nil { + return err + } + if !exists { + if err := os.MkdirAll(workDir, 0700); err != nil { + return err + } + } + + if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil { + return err + } + if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil { + return err + } + if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil { + return err + } + + // we do not want to point to cert-files that do not exist + if tlsRootCert == "" { + tlsconfig.RootCertFile = "" + } + + if tlsClientCert == "" { + tlsconfig.CertFile = "" + } + + if tlsClientKey == "" { + tlsconfig.CertKeyFile = "" + } + + // Update datasource cache + m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated) + return nil +} + +// validateCertFilePaths validates configured certificate file paths. +func validateCertFilePaths(rootCert, clientCert, clientKey string) error { + for _, fpath := range []string{rootCert, clientCert, clientKey} { + if fpath == "" { + continue + } + exists, err := fileExists(fpath) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("certificate file %q doesn't exist", fpath) + } + } + return nil +} + +// Exists determines whether a file/directory exists or not. +func fileExists(fpath string) (bool, error) { + _, err := os.Stat(fpath) + if err != nil { + if !os.IsNotExist(err) { + return false, err + } + return false, nil + } + + return true, nil +} diff --git a/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go new file mode 100644 index 00000000000..8e60e841cab --- /dev/null +++ b/pkg/tsdb/grafana-postgresql-datasource/tlsmanager_test.go @@ -0,0 +1,332 @@ +package postgres + +import ( + "fmt" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" + "github.com/grafana/grafana/pkg/setting" + "github.com/grafana/grafana/pkg/tsdb/sqleng" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + _ "github.com/lib/pq" +) + +var writeCertFileCallNum int + +// TestDataSourceCacheManager is to test the Cache manager +func TestDataSourceCacheManager(t *testing.T) { + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() + mng := tlsManager{ + logger: backend.NewLoggerWith("logger", "tsdb.postgres"), + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + dataPath: cfg.DataPath, + } + jsonData := sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + } + secureJSONData := map[string]string{ + "tlsClientCert": "I am client certification", + "tlsClientKey": "I am client key", + "tlsCACert": "I am CA certification", + } + + updateTime := time.Now().Add(-5 * time.Minute) + + mockValidateCertFilePaths() + t.Cleanup(resetValidateCertFilePaths) + + t.Run("Check datasource cache creation", func(t *testing.T) { + var wg sync.WaitGroup + wg.Add(10) + for id := int64(1); id <= 10; id++ { + go func(id int64) { + ds := sqleng.DataSourceInfo{ + ID: id, + Updated: updateTime, + Database: "database", + JsonData: jsonData, + DecryptedSecureJSONData: secureJSONData, + UID: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(ds, &s) + require.NoError(t, err) + wg.Done() + }(id) + } + wg.Wait() + + t.Run("check cache creation is succeed", func(t *testing.T) { + for id := int64(1); id <= 10; id++ { + updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id))) + require.True(t, ok) + require.Equal(t, updateTime, updated) + } + }) + }) + + t.Run("Check datasource cache modification", func(t *testing.T) { + t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) { + mockWriteCertFile() + t.Cleanup(resetWriteCertFile) + var wg1 sync.WaitGroup + wg1.Add(5) + for id := int64(1); id <= 5; id++ { + go func(id int64) { + ds := sqleng.DataSourceInfo{ + ID: 1, + Updated: updateTime, + Database: "database", + JsonData: jsonData, + DecryptedSecureJSONData: secureJSONData, + UID: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(ds, &s) + require.NoError(t, err) + wg1.Done() + }(id) + } + wg1.Wait() + assert.Equal(t, writeCertFileCallNum, 0) + }) + + t.Run("cache is updated with the last datasource version", func(t *testing.T) { + dsV2 := sqleng.DataSourceInfo{ + ID: 1, + Updated: updateTime.Add(time.Minute), + Database: "database", + JsonData: jsonData, + DecryptedSecureJSONData: secureJSONData, + UID: "testData", + } + dsV3 := sqleng.DataSourceInfo{ + ID: 1, + Updated: updateTime.Add(2 * time.Minute), + Database: "database", + JsonData: jsonData, + DecryptedSecureJSONData: secureJSONData, + UID: "testData", + } + s := tlsSettings{} + err := mng.writeCertFiles(dsV2, &s) + require.NoError(t, err) + err = mng.writeCertFiles(dsV3, &s) + require.NoError(t, err) + version, ok := mng.dsCacheInstance.cache.Load("1") + require.True(t, ok) + require.Equal(t, updateTime.Add(2*time.Minute), version) + }) + }) +} + +// Test getFileName + +func TestGetFileName(t *testing.T) { + testCases := []struct { + desc string + datadir string + fileType certFileType + expErr string + expectedGeneratedPath string + }{ + { + desc: "Get File Name for root certification", + datadir: ".", + fileType: rootCert, + expectedGeneratedPath: "root.crt", + }, + { + desc: "Get File Name for client certification", + datadir: ".", + fileType: clientCert, + expectedGeneratedPath: "client.crt", + }, + { + desc: "Get File Name for client certification", + datadir: ".", + fileType: clientKey, + expectedGeneratedPath: "client.key", + }, + } + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + generatedPath := getFileName(tt.datadir, tt.fileType) + assert.Equal(t, tt.expectedGeneratedPath, generatedPath) + }) + } +} + +// Test getTLSSettings. +func TestGetTLSSettings(t *testing.T) { + cfg := setting.NewCfg() + cfg.DataPath = t.TempDir() + + mockValidateCertFilePaths() + t.Cleanup(resetValidateCertFilePaths) + + updatedTime := time.Now() + + testCases := []struct { + desc string + expErr string + jsonData sqleng.JsonData + secureJSONData map[string]string + uid string + tlsSettings tlsSettings + updated time.Time + }{ + { + desc: "Custom TLS authentication disabled", + updated: updatedTime, + jsonData: sqleng.JsonData{ + Mode: "disable", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + ConfigurationMethod: "file-path", + }, + tlsSettings: tlsSettings{Mode: "disable"}, + }, + { + desc: "Custom TLS authentication with file path", + updated: updatedTime.Add(time.Minute), + jsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-path", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + }, + tlsSettings: tlsSettings{ + Mode: "verify-full", + ConfigurationMethod: "file-path", + RootCertFile: "i/am/coding/ca.crt", + CertFile: "i/am/coding/client.crt", + CertKeyFile: "i/am/coding/client.key", + }, + }, + { + desc: "Custom TLS mode verify-full with certificate files content", + updated: updatedTime.Add(2 * time.Minute), + uid: "xxx", + jsonData: sqleng.JsonData{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + }, + secureJSONData: map[string]string{ + "tlsCACert": "I am CA certification", + "tlsClientCert": "I am client certification", + "tlsClientKey": "I am client key", + }, + tlsSettings: tlsSettings{ + Mode: "verify-full", + ConfigurationMethod: "file-content", + RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), + CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), + CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), + }, + }, + { + desc: "Custom TLS mode verify-ca with no client certificates with certificate files content", + updated: updatedTime.Add(3 * time.Minute), + uid: "xxx", + jsonData: sqleng.JsonData{ + Mode: "verify-ca", + ConfigurationMethod: "file-content", + }, + secureJSONData: map[string]string{ + "tlsCACert": "I am CA certification", + }, + tlsSettings: tlsSettings{ + Mode: "verify-ca", + ConfigurationMethod: "file-content", + RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"), + CertFile: "", + CertKeyFile: "", + }, + }, + { + desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content", + updated: updatedTime.Add(4 * time.Minute), + uid: "xxx", + jsonData: sqleng.JsonData{ + Mode: "require", + ConfigurationMethod: "file-content", + }, + secureJSONData: map[string]string{ + "tlsClientCert": "I am client certification", + "tlsClientKey": "I am client key", + }, + tlsSettings: tlsSettings{ + Mode: "require", + ConfigurationMethod: "file-content", + RootCertFile: "", + CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"), + CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"), + }, + }, + } + for _, tt := range testCases { + t.Run(tt.desc, func(t *testing.T) { + var settings tlsSettings + var err error + mng := tlsManager{ + logger: backend.NewLoggerWith("logger", "tsdb.postgres"), + dsCacheInstance: datasourceCacheManager{locker: newLocker()}, + dataPath: cfg.DataPath, + } + + ds := sqleng.DataSourceInfo{ + JsonData: tt.jsonData, + DecryptedSecureJSONData: tt.secureJSONData, + UID: tt.uid, + Updated: tt.updated, + } + + settings, err = mng.getTLSSettings(ds) + + if tt.expErr == "" { + require.NoError(t, err, tt.desc) + assert.Equal(t, tt.tlsSettings, settings) + } else { + require.Error(t, err, tt.desc) + assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), + fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr)) + } + }) + } +} + +func mockValidateCertFilePaths() { + validateCertFunc = func(rootCert, clientCert, clientKey string) error { + return nil + } +} + +func resetValidateCertFilePaths() { + validateCertFunc = validateCertFilePaths +} + +func mockWriteCertFile() { + writeCertFileCallNum = 0 + writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error { + writeCertFileCallNum++ + return nil + } +} + +func resetWriteCertFile() { + writeCertFileCallNum = 0 + writeCertFileFunc = writeCertFile +}