Revert "Postgres: Switch the datasource plugin from lib/pq to pgx" (#83760)

Revert "Postgres: Switch the datasource plugin from lib/pq to pgx (#81353)"

This reverts commit 8c18d06386c87f2786119ea9a6334e35e4181cbe.
This commit is contained in:
Gábor Farkas
2024-03-01 12:20:47 +01:00
committed by GitHub
parent 0aebb9ee39
commit 142ac22023
17 changed files with 909 additions and 968 deletions

5
go.mod
View File

@ -471,14 +471,9 @@ require (
github.com/grafana/grafana/pkg/apiserver v0.0.0-20240226124929-648abdbd0ea4 // @grafana/grafana-app-platform-squad github.com/grafana/grafana/pkg/apiserver v0.0.0-20240226124929-648abdbd0ea4 // @grafana/grafana-app-platform-squad
) )
require github.com/jackc/pgx/v5 v5.5.3 // @grafana/oss-big-tent
require ( require (
github.com/bufbuild/protocompile v0.4.0 // indirect github.com/bufbuild/protocompile v0.4.0 // indirect
github.com/grafana/sqlds/v3 v3.2.0 // indirect github.com/grafana/sqlds/v3 v3.2.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jhump/protoreflect v1.15.1 // indirect github.com/jhump/protoreflect v1.15.1 // indirect
github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect
github.com/krasun/gosqlparser v1.0.5 // @grafana/grafana-app-platform-squad github.com/krasun/gosqlparser v1.0.5 // @grafana/grafana-app-platform-squad

7
go.sum
View File

@ -2370,7 +2370,6 @@ github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bY
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
@ -2381,8 +2380,6 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
@ -2394,14 +2391,10 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw=
github.com/jackc/pgx/v5 v5.5.3 h1:Ces6/M3wbDXYpM8JyyPD57ivTtJACFZJd885pdIaV2s=
github.com/jackc/pgx/v5 v5.5.3/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=

View File

@ -0,0 +1,85 @@
package postgres
import (
"fmt"
"sync"
)
// locker is a named reader/writer mutual exclusion lock.
// The lock for each particular key can be held by an arbitrary number of readers or a single writer.
type locker struct {
locks map[any]*sync.RWMutex
locksRW *sync.RWMutex
}
func newLocker() *locker {
return &locker{
locks: make(map[any]*sync.RWMutex),
locksRW: new(sync.RWMutex),
}
}
// Lock locks named rw mutex with specified key for writing.
// If the lock with the same key is already locked for reading or writing,
// Lock blocks until the lock is available.
func (lkr *locker) Lock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
lk = lkr.newLock(key)
}
lk.Lock()
}
// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is
// not locked for writing on entry to Unlock.
func (lkr *locker) Unlock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
panic(fmt.Errorf("lock for key '%s' not initialized", key))
}
lk.Unlock()
}
// RLock locks named rw mutex with specified key for reading.
//
// It should not be used for recursive read locking for the same key; a blocked Lock
// call excludes new readers from acquiring the lock. See the
// documentation on the golang RWMutex type.
func (lkr *locker) RLock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
lk = lkr.newLock(key)
}
lk.RLock()
}
// RUnlock undoes a single RLock call for specified key;
// it does not affect other simultaneous readers of locker for specified key.
// It is a run-time error if locker for specified key is not locked for reading
func (lkr *locker) RUnlock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
panic(fmt.Errorf("lock for key '%s' not initialized", key))
}
lk.RUnlock()
}
func (lkr *locker) newLock(key any) *sync.RWMutex {
lkr.locksRW.Lock()
defer lkr.locksRW.Unlock()
if lk, ok := lkr.locks[key]; ok {
return lk
}
lk := new(sync.RWMutex)
lkr.locks[key] = lk
return lk
}
func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) {
lkr.locksRW.RLock()
defer lkr.locksRW.RUnlock()
lock, ok := lkr.locks[key]
return lock, ok
}

View File

@ -0,0 +1,63 @@
package postgres
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestIntegrationLocker(t *testing.T) {
if testing.Short() {
t.Skip("Tests with Sleep")
}
const notUpdated = "not_updated"
const atThread1 = "at_thread_1"
const atThread2 = "at_thread_2"
t.Run("Should lock for same keys", func(t *testing.T) {
updated := notUpdated
locker := newLocker()
locker.Lock(1)
var wg sync.WaitGroup
wg.Add(1)
defer func() {
locker.Unlock(1)
wg.Wait()
}()
go func() {
locker.RLock(1)
defer func() {
locker.RUnlock(1)
wg.Done()
}()
require.Equal(t, atThread1, updated, "Value should be updated in different thread")
updated = atThread2
}()
time.Sleep(time.Millisecond * 10)
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread")
updated = atThread1
})
t.Run("Should not lock for different keys", func(t *testing.T) {
updated := notUpdated
locker := newLocker()
locker.Lock(1)
defer locker.Unlock(1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
locker.RLock(2)
defer func() {
locker.RUnlock(2)
wg.Done()
}()
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread")
updated = atThread2
}()
wg.Wait()
require.Equal(t, atThread2, updated, "Value should be updated in different thread")
updated = atThread1
})
}

View File

@ -4,33 +4,28 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
pgxstdlib "github.com/jackc/pgx/v5/stdlib"
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource" "github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil"
"github.com/lib/pq"
"github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls"
"github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/grafana/grafana/pkg/tsdb/sqleng"
) )
func ProvideService(cfg *setting.Cfg) *Service { func ProvideService(cfg *setting.Cfg) *Service {
logger := backend.NewLoggerWith("logger", "tsdb.postgres") logger := backend.NewLoggerWith("logger", "tsdb.postgres")
s := &Service{ s := &Service{
tlsManager: newTLSManager(logger, cfg.DataPath),
logger: logger, logger: logger,
} }
s.im = datasource.NewInstanceManager(s.newInstanceSettings()) s.im = datasource.NewInstanceManager(s.newInstanceSettings())
@ -38,6 +33,7 @@ func ProvideService(cfg *setting.Cfg) *Service {
} }
type Service struct { type Service struct {
tlsManager tlsSettingsProvider
im instancemgmt.InstanceManager im instancemgmt.InstanceManager
logger log.Logger logger log.Logger
} }
@ -59,7 +55,13 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
return dsInfo.QueryData(ctx, req) return dsInfo.QueryData(ctx, req)
} }
func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, pgxConf *pgx.ConnConfig, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) { func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) {
connector, err := pq.NewConnector(cnnstr)
if err != nil {
logger.Error("postgres connector creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres connector creation failed")
}
proxyClient, err := settings.ProxyClient(ctx) proxyClient, err := settings.ProxyClient(ctx)
if err != nil { if err != nil {
logger.Error("postgres proxy creation failed", "error", err) logger.Error("postgres proxy creation failed", "error", err)
@ -72,8 +74,9 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
logger.Error("postgres proxy creation failed", "error", err) logger.Error("postgres proxy creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres proxy creation failed") return nil, nil, fmt.Errorf("postgres proxy creation failed")
} }
postgresDialer := newPostgresProxyDialer(dialer)
pgxConf.DialFunc = newPgxDialFunc(dialer) // update the postgres dialer with the proxy dialer
connector.Dialer(postgresDialer)
} }
config := sqleng.DataPluginConfiguration{ config := sqleng.DataPluginConfiguration{
@ -84,7 +87,7 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
queryResultTransformer := postgresQueryResultTransformer{} queryResultTransformer := postgresQueryResultTransformer{}
db := pgxstdlib.OpenDB(*pgxConf) db := sql.OpenDB(connector)
db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns) db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns)
db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns) db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns)
@ -140,7 +143,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc {
DecryptedSecureJSONData: settings.DecryptedSecureJSONData, DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
} }
pgxConf, err := generateConnectionConfig(dsInfo) cnnstr, err := s.generateConnectionString(dsInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -150,7 +153,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc {
return nil, err return nil, err
} }
_, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, pgxConf, logger, settings) _, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings)
if err != nil { if err != nil {
logger.Error("Failed connecting to Postgres", "err", err) logger.Error("Failed connecting to Postgres", "err", err)
@ -167,11 +170,13 @@ func escape(input string) string {
return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`)
} }
func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, error) { func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) {
logger := s.logger
var host string var host string
var port int var port int
if strings.HasPrefix(dsInfo.URL, "/") { if strings.HasPrefix(dsInfo.URL, "/") {
host = dsInfo.URL host = dsInfo.URL
logger.Debug("Generating connection string with Unix socket specifier", "socket", host)
} else { } else {
index := strings.LastIndex(dsInfo.URL, ":") index := strings.LastIndex(dsInfo.URL, ":")
v6Index := strings.Index(dsInfo.URL, "]") v6Index := strings.Index(dsInfo.URL, "]")
@ -182,8 +187,12 @@ func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, er
var err error var err error
port, err = strconv.Atoi(sp[1]) port, err = strconv.Atoi(sp[1])
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) return "", fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err)
} }
logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port)
} else {
logger.Debug("Generating connection string with network host", "host", host)
} }
} else { } else {
if index == v6Index+1 { if index == v6Index+1 {
@ -191,39 +200,46 @@ func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, er
var err error var err error
port, err = strconv.Atoi(dsInfo.URL[index+1:]) port, err = strconv.Atoi(dsInfo.URL[index+1:])
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) return "", fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err)
} }
logger.Debug("Generating ipv6 connection string with network host/port pair", "host", host, "port", port)
} else { } else {
host = dsInfo.URL[1 : len(dsInfo.URL)-1] host = dsInfo.URL[1 : len(dsInfo.URL)-1]
logger.Debug("Generating ipv6 connection string with network host", "host", host)
} }
} }
} }
// NOTE: we always set sslmode=disable in the connection string, we handle TLS manually later connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'",
connStr := fmt.Sprintf("sslmode=disable user='%s' password='%s' host='%s' dbname='%s'",
escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database)) escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database))
if port > 0 { if port > 0 {
connStr += fmt.Sprintf(" port=%d", port) connStr += fmt.Sprintf(" port=%d", port)
} }
conf, err := pgx.ParseConfig(connStr) tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo)
if err != nil { if err != nil {
return nil, err return "", err
} }
tlsConf, err := tls.GetTLSConfig(dsInfo, os.ReadFile, host) connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode))
if err != nil {
return nil, err // Attach root certificate if provided
if tlsSettings.RootCertFile != "" {
logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile)
connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile))
} }
// before we set the TLS config, we need to make sure the `.Fallbacks` attribute is unset, see: // Attach client certificate and key if both are provided
// https://github.com/jackc/pgx/discussions/1903#discussioncomment-8430146 if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" {
if len(conf.Fallbacks) > 0 { logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile)
return nil, errors.New("tls: fallbacks configured, unable to set up TLS config") connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile))
} else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" {
return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified")
} }
conf.TLSConfig = tlsConf
return conf, nil logger.Debug("Generated Postgres connection string successfully")
return connStr, nil
} }
type postgresQueryResultTransformer struct{} type postgresQueryResultTransformer struct{}
@ -251,44 +267,6 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque
func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter { func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter {
return []sqlutil.StringConverter{ return []sqlutil.StringConverter{
{
Name: "handle TIME WITH TIME ZONE",
InputScanKind: reflect.Interface,
InputTypeName: strconv.Itoa(pgtype.TimetzOID),
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05-07", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{
Name: "handle TIME",
InputScanKind: reflect.Interface,
InputTypeName: "TIME",
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{ {
Name: "handle FLOAT4", Name: "handle FLOAT4",
InputScanKind: reflect.Interface, InputScanKind: reflect.Interface,

View File

@ -14,7 +14,6 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana-plugin-sdk-go/experimental" "github.com/grafana/grafana-plugin-sdk-go/experimental"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/grafana/grafana/pkg/tsdb/sqleng"
@ -52,7 +51,7 @@ func TestIntegrationPostgresSnapshots(t *testing.T) {
t.Skip() t.Skip()
} }
getCnn := func() (*pgx.ConnConfig, error) { getCnnStr := func() string {
host := os.Getenv("POSTGRES_HOST") host := os.Getenv("POSTGRES_HOST")
if host == "" { if host == "" {
host = "localhost" host = "localhost"
@ -62,10 +61,8 @@ func TestIntegrationPostgresSnapshots(t *testing.T) {
port = "5432" port = "5432"
} }
cnnString := fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable",
host, port) host, port)
return pgx.ParseConfig(cnnString)
} }
sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`) sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`)
@ -160,10 +157,9 @@ func TestIntegrationPostgresSnapshots(t *testing.T) {
logger := log.New() logger := log.New()
cnn, err := getCnn() cnnstr := getCnnStr()
require.NoError(t, err)
db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{})
t.Cleanup((func() { t.Cleanup((func() {
_, err := db.Exec("DROP TABLE tbl") _, err := db.Exec("DROP TABLE tbl")

View File

@ -14,16 +14,19 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/grafana/grafana/pkg/tsdb/sqleng"
"github.com/jackc/pgx/v5" _ "github.com/lib/pq"
_ "github.com/jackc/pgx/v5/stdlib"
) )
func TestGenerateConnectionConfig(t *testing.T) { // Test generateConnectionString.
rootCertBytes, err := tls.CreateRandomRootCertBytes() func TestIntegrationGenerateConnectionString(t *testing.T) {
require.NoError(t, err) if testing.Short() {
t.Skip("skipping integration test")
}
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
testCases := []struct { testCases := []struct {
desc string desc string
@ -31,15 +34,10 @@ func TestGenerateConnectionConfig(t *testing.T) {
user string user string
password string password string
database string database string
tlsMode string tlsSettings tlsSettings
tlsRootCert []byte expConnStr string
expErr string expErr string
expHost string uid string
expPort uint16
expUser string
expPassword string
expDatabase string
expTLS bool
}{ }{
{ {
desc: "Unix socket host", desc: "Unix socket host",
@ -47,11 +45,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'",
expPassword: "password",
expHost: "/var/run/postgresql",
expDatabase: "database",
}, },
{ {
desc: "TCP host", desc: "TCP host",
@ -59,12 +54,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'",
expPassword: "password",
expHost: "host",
expPort: 5432,
expDatabase: "database",
}, },
{ {
desc: "TCP/port host", desc: "TCP/port host",
@ -72,12 +63,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: "user='user' password='password' host='host' dbname='database' port=1234 sslmode='verify-full'",
expPassword: "password",
expHost: "host",
expPort: 1234,
expDatabase: "database",
}, },
{ {
desc: "Ipv6 host", desc: "Ipv6 host",
@ -85,11 +72,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: "user='user' password='password' host='::1' dbname='database' sslmode='verify-full'",
expPassword: "password",
expHost: "::1",
expDatabase: "database",
}, },
{ {
desc: "Ipv6/port host", desc: "Ipv6/port host",
@ -97,19 +81,15 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: "user='user' password='password' host='::1' dbname='database' port=1234 sslmode='verify-full'",
expPassword: "password",
expHost: "::1",
expPort: 1234,
expDatabase: "database",
}, },
{ {
desc: "Invalid port", desc: "Invalid port",
host: "host:invalid", host: "host:invalid",
user: "user", user: "user",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{},
expErr: "invalid port in host specifier", expErr: "invalid port in host specifier",
}, },
{ {
@ -118,11 +98,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: `p'\assword`, password: `p'\assword`,
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: "user", expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`,
expPassword: `p'\assword`,
expHost: "host",
expDatabase: "database",
}, },
{ {
desc: "User/DB with single quote and backslash", desc: "User/DB with single quote and backslash",
@ -130,11 +107,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: `u'\ser`, user: `u'\ser`,
password: `password`, password: `password`,
database: `d'\atabase`, database: `d'\atabase`,
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "verify-full"},
expUser: `u'\ser`, expConnStr: `user='u\'\\ser' password='password' host='host' dbname='d\'\\atabase' sslmode='verify-full'`,
expPassword: "password",
expDatabase: `d'\atabase`,
expHost: "host",
}, },
{ {
desc: "Custom TLS mode disabled", desc: "Custom TLS mode disabled",
@ -142,11 +116,8 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "disable", tlsSettings: tlsSettings{Mode: "disable"},
expUser: "user", expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'",
expPassword: "password",
expHost: "host",
expDatabase: "database",
}, },
{ {
desc: "Custom TLS mode verify-full with certificate files", desc: "Custom TLS mode verify-full with certificate files",
@ -154,43 +125,36 @@ func TestGenerateConnectionConfig(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsMode: "verify-full", tlsSettings: tlsSettings{
tlsRootCert: rootCertBytes, Mode: "verify-full",
expUser: "user", RootCertFile: "i/am/coding/ca.crt",
expPassword: "password", CertFile: "i/am/coding/client.crt",
expDatabase: "database", CertKeyFile: "i/am/coding/client.key",
expHost: "host", },
expTLS: true, expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' " +
"sslrootcert='i/am/coding/ca.crt' sslcert='i/am/coding/client.crt' sslkey='i/am/coding/client.key'",
}, },
} }
for _, tt := range testCases { for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
svc := Service{
tlsManager: &tlsTestManager{settings: tt.tlsSettings},
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
}
ds := sqleng.DataSourceInfo{ ds := sqleng.DataSourceInfo{
URL: tt.host, URL: tt.host,
User: tt.user, User: tt.user,
DecryptedSecureJSONData: map[string]string{ DecryptedSecureJSONData: map[string]string{"password": tt.password},
"password": tt.password,
"tlsCACert": string(tt.tlsRootCert),
},
Database: tt.database, Database: tt.database,
JsonData: sqleng.JsonData{ UID: tt.uid,
Mode: tt.tlsMode,
ConfigurationMethod: "file-content",
},
} }
c, err := generateConnectionConfig(ds) connStr, err := svc.generateConnectionString(ds)
if tt.expErr == "" { if tt.expErr == "" {
require.NoError(t, err, tt.desc) require.NoError(t, err, tt.desc)
assert.Equal(t, tt.expHost, c.Host) assert.Equal(t, tt.expConnStr, connStr)
if tt.expPort != 0 {
assert.Equal(t, tt.expPort, c.Port)
}
assert.Equal(t, tt.expUser, c.User)
assert.Equal(t, tt.expDatabase, c.Database)
assert.Equal(t, tt.expPassword, c.Password)
require.Equal(t, tt.expTLS, c.TLSConfig != nil)
} else { } else {
require.Error(t, err, tt.desc) require.Error(t, err, tt.desc)
assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
@ -242,10 +206,9 @@ func TestIntegrationPostgres(t *testing.T) {
logger := backend.NewLoggerWith("logger", "postgres.test") logger := backend.NewLoggerWith("logger", "postgres.test")
cnn, err := postgresTestDBConn() cnnstr := postgresTestDBConnString()
require.NoError(t, err)
db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{})
require.NoError(t, err) require.NoError(t, err)
@ -1299,7 +1262,7 @@ func TestIntegrationPostgres(t *testing.T) {
t.Run("When row limit set to 1", func(t *testing.T) { t.Run("When row limit set to 1", func(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{} dsInfo := sqleng.DataSourceInfo{}
_, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnn, logger, backend.DataSourceInstanceSettings{}) _, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{})
require.NoError(t, err) require.NoError(t, err)
@ -1414,6 +1377,14 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim
return timeRange return timeRange
} }
type tlsTestManager struct {
settings tlsSettings
}
func (m *tlsTestManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) {
return m.settings, nil
}
func isTestDbPostgres() bool { func isTestDbPostgres() bool {
if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present {
return db == "postgres" return db == "postgres"
@ -1422,7 +1393,7 @@ func isTestDbPostgres() bool {
return false return false
} }
func postgresTestDBConn() (*pgx.ConnConfig, error) { func postgresTestDBConnString() string {
host := os.Getenv("POSTGRES_HOST") host := os.Getenv("POSTGRES_HOST")
if host == "" { if host == "" {
host = "localhost" host = "localhost"
@ -1431,8 +1402,6 @@ func postgresTestDBConn() (*pgx.ConnConfig, error) {
if port == "" { if port == "" {
port = "5432" port = "5432"
} }
connStr := fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable",
host, port) host, port)
return pgx.ParseConfig(connStr)
} }

View File

@ -3,17 +3,33 @@ package postgres
import ( import (
"context" "context"
"net" "net"
"time"
"github.com/lib/pq"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )
type PgxDialFunc = func(ctx context.Context, network string, address string) (net.Conn, error) // we wrap the proxy.Dialer to become dialer that the postgres module accepts
func newPostgresProxyDialer(dialer proxy.Dialer) pq.Dialer {
func newPgxDialFunc(dialer proxy.Dialer) PgxDialFunc { return &postgresProxyDialer{d: dialer}
dialFunc :=
func(ctx context.Context, network string, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
} }
return dialFunc var _ pq.Dialer = (&postgresProxyDialer{})
// postgresProxyDialer implements the postgres dialer using a proxy dialer, as their functions differ slightly
type postgresProxyDialer struct {
d proxy.Dialer
}
// Dial uses the normal proxy dial function with the updated dialer
func (p *postgresProxyDialer) Dial(network, addr string) (c net.Conn, err error) {
return p.d.Dial(network, addr)
}
// DialTimeout uses the normal postgres dial timeout function with the updated dialer
func (p *postgresProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return p.d.(proxy.ContextDialer).DialContext(ctx, network, address)
} }

View File

@ -1,12 +1,12 @@
package postgres package postgres
import ( import (
"database/sql"
"fmt" "fmt"
"net" "net"
"testing" "testing"
"github.com/jackc/pgx/v5" "github.com/lib/pq"
pgxstdlib "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )
@ -25,13 +25,13 @@ func TestPostgresProxyDriver(t *testing.T) {
cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL) cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL)
t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) { t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) {
pgxConf, err := pgx.ParseConfig(cnnstr) connector, err := pq.NewConnector(cnnstr)
require.NoError(t, err) require.NoError(t, err)
dialer := newPostgresProxyDialer(&testDialer{})
pgxConf.DialFunc = newPgxDialFunc(&testDialer{}) connector.Dialer(dialer)
db := pgxstdlib.OpenDB(*pgxConf)
db := sql.OpenDB(connector)
err = db.Ping() err = db.Ping()
require.Contains(t, err.Error(), "test-dialer is not functional") require.Contains(t, err.Error(), "test-dialer is not functional")

View File

@ -9,16 +9,16 @@
// } // }
// Name: // Name:
// Dimensions: 4 Fields by 4 Rows // Dimensions: 4 Fields by 4 Rows
// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ // +--------------------------------------+-------------------------------+------------------+-----------------------------------+
// | Name: reallyt | Name: time | Name: n | Name: timeend | // | Name: reallyt | Name: time | Name: n | Name: timeend |
// | Labels: | Labels: | Labels: | Labels: | // | Labels: | Labels: | Labels: | Labels: |
// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | // | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time |
// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ // +--------------------------------------+-------------------------------+------------------+-----------------------------------+
// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:22:24 +0000 UTC | 1.703161344e+09 | 2023-12-21 12:22:52 +0000 UTC | // | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:21:40 +0000 UTC | 1.7031613e+09 | 2023-12-21 12:22:52 +0000 UTC |
// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:33.408 +0000 UTC | 1.703161233408e+12 | 2023-12-21 12:21:52.522 +0000 UTC | // | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+12 | 2023-12-21 12:21:52.522 +0000 UTC |
// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:41.05 +0000 UTC | 1.703161241050022e+18 | 2023-12-21 12:21:52.522 +0000 UTC | // | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+18 | 2023-12-21 12:21:52.522 +0000 UTC |
// | null | null | null | null | // | null | null | null | null |
// +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+ // +--------------------------------------+-------------------------------+------------------+-----------------------------------+
// //
// //
// 🌟 This was machine generated. Do not edit. 🌟 // 🌟 This was machine generated. Do not edit. 🌟
@ -78,15 +78,15 @@
null null
], ],
[ [
1703161344000, 1703161300000,
1703161233408, 1703161200000,
1703161241050, 1703161200000,
null null
], ],
[ [
1703161344, 1703161300,
1703161233408, 1703161200000,
1703161241050022000, 1703161200000000000,
null null
], ],
[ [

View File

@ -9,14 +9,14 @@
// } // }
// Name: // Name:
// Dimensions: 12 Fields by 2 Rows // Dimensions: 12 Fields by 2 Rows
// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | // | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn |
// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | // | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: |
// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | // | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string |
// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// | 2023-11-15 05:06:07.123456 +0000 UTC | 2023-11-15 05:06:08.123456 +0000 UTC | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 UTC | 2023-12-21 00:00:00 +0000 UTC | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | // | 2023-11-15 05:06:07.123456 +0000 +0000 | 2023-11-15 05:06:08.123456 +0000 +0000 | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 +0000 | 2023-12-21 00:00:00 +0000 +0000 | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 |
// | null | 2023-11-15 05:06:09.123456 +0000 UTC | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 UTC | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | // | null | 2023-11-15 05:06:09.123456 +0000 +0000 | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 +0000 | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 |
// +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// //
// //
// 🌟 This was machine generated. Do not edit. 🌟 // 🌟 This was machine generated. Do not edit. 🌟

View File

@ -1,147 +0,0 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
// we support 4 postgres tls modes:
// disable - no tls
// require - use tls
// verify-ca - use tls, verify root cert but not the hostname
// verify-full - use tls, verify root cert
// (for all the options except `disable`, you can optionally use client certificates)
var errNoRootCert = errors.New("tls: missing root certificate")
func getTLSConfigRequire(certs *Certs) (*tls.Config, error) {
// we may have a client-cert, we do not have a root-cert
// see https://www.postgresql.org/docs/12/libpq-ssl.html ,
// mode=require + provided root-cert should behave as mode=verify-ca
if certs.rootCerts != nil {
return getTLSConfigVerifyCA(certs)
}
return &tls.Config{
InsecureSkipVerify: true, // we do not verify the root cert
Certificates: certs.clientCerts,
}, nil
}
// to implement the verify-ca mode, we need to do this:
// - for the root certificate
// - verify that the certificate we receive from the server is trusted,
// meaning it relates to our root certificate
// - we DO NOT verify that the hostname of the database matches
// the hostname in the certificate
//
// the problem is, `go“ does not offer such an option.
// by default, it will verify both things.
//
// so what we do is:
// - we turn off the default-verification with `InsecureSkipVerify`
// - we implement our own verification using `VerifyConnection`
//
// extra info about this:
// - there is a rejected feature-request about this at https://github.com/golang/go/issues/21971
// - the recommended workaround is based on VerifyPeerCertificate
// - there is even example code at https://github.com/golang/go/commit/29cfb4d3c3a97b6f426d1b899234da905be699aa
// - but later the example code was changed to use VerifyConnection instead:
// https://github.com/golang/go/commit/7eb5941b95a588a23f18fa4c22fe42ff0119c311
//
// a verifyConnection example is at https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection .
//
// this is how the `pgx` library handles verify-ca:
//
// https://github.com/jackc/pgx/blob/5c63f646f820ca9696fc3515c1caf2a557d562e5/pgconn/config.go#L657-L690
// (unfortunately pgx only handles this for certificate-provided-as-path, so we cannot rely on it)
func getTLSConfigVerifyCA(certs *Certs) (*tls.Config, error) {
// we must have a root certificate
if certs.rootCerts == nil {
return nil, errNoRootCert
}
conf := tls.Config{
Certificates: certs.clientCerts,
InsecureSkipVerify: true, // we turn off the default-verification, we'll do VerifyConnection instead
VerifyConnection: func(state tls.ConnectionState) error {
// we add all the certificates to the pool, we skip the first cert.
intermediates := x509.NewCertPool()
for _, c := range state.PeerCertificates[1:] {
intermediates.AddCert(c)
}
opts := x509.VerifyOptions{
Roots: certs.rootCerts,
Intermediates: intermediates,
}
// we call `Verify()` on the first cert (that we skipped previously)
_, err := state.PeerCertificates[0].Verify(opts)
return err
},
RootCAs: certs.rootCerts,
}
return &conf, nil
}
func getTLSConfigVerifyFull(certs *Certs, serverName string) (*tls.Config, error) {
// we must have a root certificate
if certs.rootCerts == nil {
return nil, errNoRootCert
}
conf := tls.Config{
Certificates: certs.clientCerts,
ServerName: serverName,
RootCAs: certs.rootCerts,
}
return &conf, nil
}
func IsTLSEnabled(dsInfo sqleng.DataSourceInfo) bool {
mode := dsInfo.JsonData.Mode
return mode != "disable"
}
// returns `nil` if tls is disabled
func GetTLSConfig(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc, serverName string) (*tls.Config, error) {
mode := dsInfo.JsonData.Mode
// we need to special-case the no-tls-mode
if mode == "disable" {
return nil, nil
}
// for all the remaining cases we need to load
// both the root-cert if exists, and the client-cert if exists.
certBytes, err := loadCertificateBytes(dsInfo, readFile)
if err != nil {
return nil, err
}
certs, err := createCertificates(certBytes)
if err != nil {
return nil, err
}
switch mode {
// `disable` already handled
case "":
// for backward-compatibility reasons this is the same as `require`
return getTLSConfigRequire(certs)
case "require":
return getTLSConfigRequire(certs)
case "verify-ca":
return getTLSConfigVerifyCA(certs)
case "verify-full":
return getTLSConfigVerifyFull(certs, serverName)
default:
return nil, errors.New("tls: invalid mode " + mode)
}
}

View File

@ -1,101 +0,0 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
// this file deals with locating and loading the certificates,
// from json-data or from disk.
type CertBytes struct {
rootCert []byte
clientKey []byte
clientCert []byte
}
type ReadFileFunc = func(name string) ([]byte, error)
var errPartialClientCertNoKey = errors.New("tls: client cert provided but client key missing")
var errPartialClientCertNoCert = errors.New("tls: client key provided but client cert missing")
// certificates can be stored either as encrypted-json-data, or as file-path
func loadCertificateBytes(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc) (*CertBytes, error) {
if dsInfo.JsonData.ConfigurationMethod == "file-content" {
return &CertBytes{
rootCert: []byte(dsInfo.DecryptedSecureJSONData["tlsCACert"]),
clientKey: []byte(dsInfo.DecryptedSecureJSONData["tlsClientKey"]),
clientCert: []byte(dsInfo.DecryptedSecureJSONData["tlsClientCert"]),
}, nil
} else {
c := CertBytes{}
if dsInfo.JsonData.RootCertFile != "" {
rootCert, err := readFile(dsInfo.JsonData.RootCertFile)
if err != nil {
return nil, err
}
c.rootCert = rootCert
}
if dsInfo.JsonData.CertKeyFile != "" {
clientKey, err := readFile(dsInfo.JsonData.CertKeyFile)
if err != nil {
return nil, err
}
c.clientKey = clientKey
}
if dsInfo.JsonData.CertFile != "" {
clientCert, err := readFile(dsInfo.JsonData.CertFile)
if err != nil {
return nil, err
}
c.clientCert = clientCert
}
return &c, nil
}
}
type Certs struct {
clientCerts []tls.Certificate
rootCerts *x509.CertPool
}
func createCertificates(certBytes *CertBytes) (*Certs, error) {
certs := Certs{}
if len(certBytes.rootCert) > 0 {
pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(certBytes.rootCert)
if !ok {
return nil, errors.New("tls: failed to add root certificate")
}
certs.rootCerts = pool
}
hasClientKey := len(certBytes.clientKey) > 0
hasClientCert := len(certBytes.clientCert) > 0
if hasClientKey && hasClientCert {
cert, err := tls.X509KeyPair(certBytes.clientCert, certBytes.clientKey)
if err != nil {
return nil, err
}
certs.clientCerts = []tls.Certificate{cert}
}
if hasClientKey && (!hasClientCert) {
return nil, errPartialClientCertNoCert
}
if hasClientCert && (!hasClientKey) {
return nil, errPartialClientCertNoKey
}
return &certs, nil
}

View File

@ -1,382 +0,0 @@
package tls
import (
"errors"
"os"
"testing"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
"github.com/stretchr/testify/require"
)
func noReadFile(path string) ([]byte, error) {
return nil, errors.New("not implemented")
}
func TestTLSNoMode(t *testing.T) {
// for backward-compatibility reason,
// when mode is unset, it defaults to `require`
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
}
func TestTLSDisable(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "disable",
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.Nil(t, c)
}
func TestTLSRequire(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.Nil(t, c.RootCAs)
}
func TestTLSRequireWithRootCert(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.NotNil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyCA(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.NotNil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyCAMisingRootCert(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errNoRootCert)
}
func TestTLSClientCert(t *testing.T) {
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientCert": string(clientCert),
"tlsClientKey": string(clientKey),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
}
func TestTLSMethodFileContentClientCertMissingKey(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientCert": string(clientCert),
},
}
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoKey)
}
func TestTLSMethodFileContentClientCertMissingCert(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientKey": string(clientKey),
},
}
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoCert)
}
func TestTLSMethodFilePathClientCertMissingKey(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"path1": clientKey,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "path1",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoCert)
}
func TestTLSMethodFilePathClientCertMissingCert(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"path1": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertFile: "path1",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoKey)
}
func TestTLSVerifyFull(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.False(t, c.InsecureSkipVerify)
require.Nil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFileContent(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
"tlsClientCert": string(clientCert),
"tlsClientKey": string(clientKey),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFilePath(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"root-cert-path": rootCertBytes,
"client-key-path": clientKey,
"client-cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "root-cert-path",
CertKeyFile: "client-key-path",
CertFile: "client-cert-path",
},
}
c, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFilePathRootCertDoesNotExist(t *testing.T) {
readFile := newMockReadFile(map[string]([]byte){})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "path1",
},
}
_, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
func TestTLSMethodFilePathClientCertKeyDoesNotExist(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "key-path",
CertFile: "cert-path",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
func TestTLSMethodFilePathClientCertCertDoesNotExist(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"key-path": clientKey,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "key-path",
CertFile: "cert-path",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
// method="" equals to method="file-path"
func TestTLSMethodEmpty(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"root-cert-path": rootCertBytes,
"client-key-path": clientKey,
"client-cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "",
RootCertFile: "root-cert-path",
CertKeyFile: "client-key-path",
CertFile: "client-cert-path",
},
}
c, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyFullMisingRootCert(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errNoRootCert)
}
func TestTLSInvalidMode(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "not-a-valid-mode",
},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.Error(t, err)
}

View File

@ -1,105 +0,0 @@
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"time"
)
func CreateRandomRootCertBytes() ([]byte, error) {
cert := x509.Certificate{
SerialNumber: big.NewInt(42),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
bytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &key.PublicKey, key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: bytes,
}), nil
}
func CreateRandomClientCert() ([]byte, []byte, error) {
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
keyBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
caCert := x509.Certificate{
SerialNumber: big.NewInt(42),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
cert := x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certData, err := x509.CreateCertificate(rand.Reader, &cert, &caCert, &key.PublicKey, caKey)
if err != nil {
return nil, nil, err
}
certBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certData,
})
return keyBytes, certBytes, nil
}
func newMockReadFile(data map[string]([]byte)) ReadFileFunc {
return func(path string) ([]byte, error) {
bytes, ok := data[path]
if !ok {
return nil, os.ErrNotExist
}
return bytes, nil
}
}

View File

@ -0,0 +1,249 @@
package postgres
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
var validateCertFunc = validateCertFilePaths
var writeCertFileFunc = writeCertFile
type certFileType int
const (
rootCert = iota
clientCert
clientKey
)
type tlsSettingsProvider interface {
getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error)
}
type datasourceCacheManager struct {
locker *locker
cache sync.Map
}
type tlsManager struct {
logger log.Logger
dsCacheInstance datasourceCacheManager
dataPath string
}
func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider {
return &tlsManager{
logger: logger,
dataPath: dataPath,
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
}
}
type tlsSettings struct {
Mode string
ConfigurationMethod string
RootCertFile string
CertFile string
CertKeyFile string
}
func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) {
tlsconfig := tlsSettings{
Mode: dsInfo.JsonData.Mode,
}
isTLSDisabled := (tlsconfig.Mode == "disable")
if isTLSDisabled {
m.logger.Debug("Postgres TLS/SSL is disabled")
return tlsconfig, nil
}
m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode)
tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod
tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile
tlsconfig.CertFile = dsInfo.JsonData.CertFile
tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile
if tlsconfig.ConfigurationMethod == "file-content" {
if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil {
return tlsconfig, err
}
} else {
if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil {
return tlsconfig, err
}
}
return tlsconfig, nil
}
func (t certFileType) String() string {
switch t {
case rootCert:
return "root certificate"
case clientCert:
return "client certificate"
case clientKey:
return "client key"
default:
panic(fmt.Sprintf("Unrecognized certFileType %d", t))
}
}
func getFileName(dataDir string, fileType certFileType) string {
var filename string
switch fileType {
case rootCert:
filename = "root.crt"
case clientCert:
filename = "client.crt"
case clientKey:
filename = "client.key"
default:
panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String()))
}
generatedFilePath := filepath.Join(dataDir, filename)
return generatedFilePath
}
// writeCertFile writes a certificate file.
func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error {
fileContent = strings.TrimSpace(fileContent)
if fileContent != "" {
logger.Debug("Writing cert file", "path", generatedFilePath)
if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil {
return err
}
// Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail
if err := os.Chmod(generatedFilePath, 0600); err != nil {
return err
}
return nil
}
logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath)
exists, err := fileExists(generatedFilePath)
if err != nil {
return err
}
if exists {
if err := os.Remove(generatedFilePath); err != nil {
return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err)
}
}
return nil
}
func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error {
m.logger.Debug("Writing TLS certificate files to disk")
tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"]
tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"]
tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"]
if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" {
m.logger.Debug("No TLS/SSL certificates provided")
}
// Calculate all files path
workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts")
tlsconfig.RootCertFile = getFileName(workDir, rootCert)
tlsconfig.CertFile = getFileName(workDir, clientCert)
tlsconfig.CertKeyFile = getFileName(workDir, clientKey)
// Find datasource in the cache, if found, skip writing files
cacheKey := strconv.Itoa(int(dsInfo.ID))
m.dsCacheInstance.locker.RLock(cacheKey)
item, ok := m.dsCacheInstance.cache.Load(cacheKey)
m.dsCacheInstance.locker.RUnlock(cacheKey)
if ok {
if !item.(time.Time).Before(dsInfo.Updated) {
return nil
}
}
m.dsCacheInstance.locker.Lock(cacheKey)
defer m.dsCacheInstance.locker.Unlock(cacheKey)
item, ok = m.dsCacheInstance.cache.Load(cacheKey)
if ok {
if !item.(time.Time).Before(dsInfo.Updated) {
return nil
}
}
// Write certification directory and files
exists, err := fileExists(workDir)
if err != nil {
return err
}
if !exists {
if err := os.MkdirAll(workDir, 0700); err != nil {
return err
}
}
if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil {
return err
}
if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil {
return err
}
if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil {
return err
}
// we do not want to point to cert-files that do not exist
if tlsRootCert == "" {
tlsconfig.RootCertFile = ""
}
if tlsClientCert == "" {
tlsconfig.CertFile = ""
}
if tlsClientKey == "" {
tlsconfig.CertKeyFile = ""
}
// Update datasource cache
m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated)
return nil
}
// validateCertFilePaths validates configured certificate file paths.
func validateCertFilePaths(rootCert, clientCert, clientKey string) error {
for _, fpath := range []string{rootCert, clientCert, clientKey} {
if fpath == "" {
continue
}
exists, err := fileExists(fpath)
if err != nil {
return err
}
if !exists {
return fmt.Errorf("certificate file %q doesn't exist", fpath)
}
}
return nil
}
// Exists determines whether a file/directory exists or not.
func fileExists(fpath string) (bool, error) {
_, err := os.Stat(fpath)
if err != nil {
if !os.IsNotExist(err) {
return false, err
}
return false, nil
}
return true, nil
}

View File

@ -0,0 +1,332 @@
package postgres
import (
"fmt"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
var writeCertFileCallNum int
// TestDataSourceCacheManager is to test the Cache manager
func TestDataSourceCacheManager(t *testing.T) {
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
mng := tlsManager{
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
dataPath: cfg.DataPath,
}
jsonData := sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
}
secureJSONData := map[string]string{
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
"tlsCACert": "I am CA certification",
}
updateTime := time.Now().Add(-5 * time.Minute)
mockValidateCertFilePaths()
t.Cleanup(resetValidateCertFilePaths)
t.Run("Check datasource cache creation", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(10)
for id := int64(1); id <= 10; id++ {
go func(id int64) {
ds := sqleng.DataSourceInfo{
ID: id,
Updated: updateTime,
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(ds, &s)
require.NoError(t, err)
wg.Done()
}(id)
}
wg.Wait()
t.Run("check cache creation is succeed", func(t *testing.T) {
for id := int64(1); id <= 10; id++ {
updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id)))
require.True(t, ok)
require.Equal(t, updateTime, updated)
}
})
})
t.Run("Check datasource cache modification", func(t *testing.T) {
t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) {
mockWriteCertFile()
t.Cleanup(resetWriteCertFile)
var wg1 sync.WaitGroup
wg1.Add(5)
for id := int64(1); id <= 5; id++ {
go func(id int64) {
ds := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime,
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(ds, &s)
require.NoError(t, err)
wg1.Done()
}(id)
}
wg1.Wait()
assert.Equal(t, writeCertFileCallNum, 0)
})
t.Run("cache is updated with the last datasource version", func(t *testing.T) {
dsV2 := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime.Add(time.Minute),
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
dsV3 := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime.Add(2 * time.Minute),
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(dsV2, &s)
require.NoError(t, err)
err = mng.writeCertFiles(dsV3, &s)
require.NoError(t, err)
version, ok := mng.dsCacheInstance.cache.Load("1")
require.True(t, ok)
require.Equal(t, updateTime.Add(2*time.Minute), version)
})
})
}
// Test getFileName
func TestGetFileName(t *testing.T) {
testCases := []struct {
desc string
datadir string
fileType certFileType
expErr string
expectedGeneratedPath string
}{
{
desc: "Get File Name for root certification",
datadir: ".",
fileType: rootCert,
expectedGeneratedPath: "root.crt",
},
{
desc: "Get File Name for client certification",
datadir: ".",
fileType: clientCert,
expectedGeneratedPath: "client.crt",
},
{
desc: "Get File Name for client certification",
datadir: ".",
fileType: clientKey,
expectedGeneratedPath: "client.key",
},
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
generatedPath := getFileName(tt.datadir, tt.fileType)
assert.Equal(t, tt.expectedGeneratedPath, generatedPath)
})
}
}
// Test getTLSSettings.
func TestGetTLSSettings(t *testing.T) {
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
mockValidateCertFilePaths()
t.Cleanup(resetValidateCertFilePaths)
updatedTime := time.Now()
testCases := []struct {
desc string
expErr string
jsonData sqleng.JsonData
secureJSONData map[string]string
uid string
tlsSettings tlsSettings
updated time.Time
}{
{
desc: "Custom TLS authentication disabled",
updated: updatedTime,
jsonData: sqleng.JsonData{
Mode: "disable",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
ConfigurationMethod: "file-path",
},
tlsSettings: tlsSettings{Mode: "disable"},
},
{
desc: "Custom TLS authentication with file path",
updated: updatedTime.Add(time.Minute),
jsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
},
tlsSettings: tlsSettings{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
},
},
{
desc: "Custom TLS mode verify-full with certificate files content",
updated: updatedTime.Add(2 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsCACert": "I am CA certification",
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
},
tlsSettings: tlsSettings{
Mode: "verify-full",
ConfigurationMethod: "file-content",
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
},
},
{
desc: "Custom TLS mode verify-ca with no client certificates with certificate files content",
updated: updatedTime.Add(3 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsCACert": "I am CA certification",
},
tlsSettings: tlsSettings{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
CertFile: "",
CertKeyFile: "",
},
},
{
desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content",
updated: updatedTime.Add(4 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
},
tlsSettings: tlsSettings{
Mode: "require",
ConfigurationMethod: "file-content",
RootCertFile: "",
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
},
},
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
var settings tlsSettings
var err error
mng := tlsManager{
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
dataPath: cfg.DataPath,
}
ds := sqleng.DataSourceInfo{
JsonData: tt.jsonData,
DecryptedSecureJSONData: tt.secureJSONData,
UID: tt.uid,
Updated: tt.updated,
}
settings, err = mng.getTLSSettings(ds)
if tt.expErr == "" {
require.NoError(t, err, tt.desc)
assert.Equal(t, tt.tlsSettings, settings)
} else {
require.Error(t, err, tt.desc)
assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr))
}
})
}
}
func mockValidateCertFilePaths() {
validateCertFunc = func(rootCert, clientCert, clientKey string) error {
return nil
}
}
func resetValidateCertFilePaths() {
validateCertFunc = validateCertFilePaths
}
func mockWriteCertFile() {
writeCertFileCallNum = 0
writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error {
writeCertFileCallNum++
return nil
}
}
func resetWriteCertFile() {
writeCertFileCallNum = 0
writeCertFileFunc = writeCertFile
}