Postgres: Switch the datasource plugin from lib/pq to pgx (#83768)

postgres: switch from lib/pq to pgx
This commit is contained in:
Gábor Farkas
2024-03-13 09:52:39 +01:00
committed by GitHub
parent 2acd48d1c2
commit ecd6de826a
17 changed files with 1082 additions and 947 deletions

11
go.mod
View File

@ -122,7 +122,7 @@ require (
gopkg.in/mail.v2 v2.3.1 // @grafana/backend-platform gopkg.in/mail.v2 v2.3.1 // @grafana/backend-platform
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // @grafana/alerting-squad-backend gopkg.in/yaml.v3 v3.0.1 // @grafana/alerting-squad-backend
xorm.io/builder v0.3.6 // @grafana/backend-platform xorm.io/builder v0.3.6 // indirect; @grafana/backend-platform
xorm.io/core v0.7.3 // @grafana/backend-platform xorm.io/core v0.7.3 // @grafana/backend-platform
xorm.io/xorm v0.8.2 // @grafana/alerting-squad-backend xorm.io/xorm v0.8.2 // @grafana/alerting-squad-backend
) )
@ -174,7 +174,7 @@ require (
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-msgpack v0.5.5 // indirect github.com/hashicorp/go-msgpack v0.5.5 // indirect
github.com/hashicorp/go-multierror v1.1.1 // @grafana/alerting-squad github.com/hashicorp/go-multierror v1.1.1 // indirect; @grafana/alerting-squad
github.com/hashicorp/go-sockaddr v1.0.6 // indirect github.com/hashicorp/go-sockaddr v1.0.6 // indirect
github.com/hashicorp/golang-lru v0.6.0 // indirect github.com/hashicorp/golang-lru v0.6.0 // indirect
github.com/hashicorp/yamux v0.1.1 // indirect github.com/hashicorp/yamux v0.1.1 // indirect
@ -337,7 +337,7 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db // indirect github.com/grafana/regexp v0.0.0-20221123153739-15dc172cd2db // indirect
github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // @grafana/alerting-squad-backend github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect; @grafana/alerting-squad-backend
github.com/hashicorp/memberlist v0.5.0 // indirect github.com/hashicorp/memberlist v0.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/invopop/yaml v0.2.0 // indirect github.com/invopop/yaml v0.2.0 // indirect
@ -493,10 +493,15 @@ require (
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 // @grafana/grafana-app-platform-squad github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 // @grafana/grafana-app-platform-squad
) )
require github.com/jackc/pgx/v5 v5.5.5 // @grafana/oss-big-tent
require ( require (
github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect
github.com/invopop/jsonschema v0.12.0 // indirect github.com/invopop/jsonschema v0.12.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
) )

7
go.sum
View File

@ -2383,6 +2383,7 @@ github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bY
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
@ -2393,6 +2394,8 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.2.0/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
@ -2404,10 +2407,14 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw= github.com/jackc/pgx/v4 v4.15.0/go.mod h1:D/zyOyXiaM1TmVWnOM18p0xdDtdakRBa0RsVGI3U3bw=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=

View File

@ -1,85 +0,0 @@
package postgres
import (
"fmt"
"sync"
)
// locker is a named reader/writer mutual exclusion lock.
// The lock for each particular key can be held by an arbitrary number of readers or a single writer.
type locker struct {
locks map[any]*sync.RWMutex
locksRW *sync.RWMutex
}
func newLocker() *locker {
return &locker{
locks: make(map[any]*sync.RWMutex),
locksRW: new(sync.RWMutex),
}
}
// Lock locks named rw mutex with specified key for writing.
// If the lock with the same key is already locked for reading or writing,
// Lock blocks until the lock is available.
func (lkr *locker) Lock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
lk = lkr.newLock(key)
}
lk.Lock()
}
// Unlock unlocks named rw mutex with specified key for writing. It is a run-time error if rw is
// not locked for writing on entry to Unlock.
func (lkr *locker) Unlock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
panic(fmt.Errorf("lock for key '%s' not initialized", key))
}
lk.Unlock()
}
// RLock locks named rw mutex with specified key for reading.
//
// It should not be used for recursive read locking for the same key; a blocked Lock
// call excludes new readers from acquiring the lock. See the
// documentation on the golang RWMutex type.
func (lkr *locker) RLock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
lk = lkr.newLock(key)
}
lk.RLock()
}
// RUnlock undoes a single RLock call for specified key;
// it does not affect other simultaneous readers of locker for specified key.
// It is a run-time error if locker for specified key is not locked for reading
func (lkr *locker) RUnlock(key any) {
lk, ok := lkr.getLock(key)
if !ok {
panic(fmt.Errorf("lock for key '%s' not initialized", key))
}
lk.RUnlock()
}
func (lkr *locker) newLock(key any) *sync.RWMutex {
lkr.locksRW.Lock()
defer lkr.locksRW.Unlock()
if lk, ok := lkr.locks[key]; ok {
return lk
}
lk := new(sync.RWMutex)
lkr.locks[key] = lk
return lk
}
func (lkr *locker) getLock(key any) (*sync.RWMutex, bool) {
lkr.locksRW.RLock()
defer lkr.locksRW.RUnlock()
lock, ok := lkr.locks[key]
return lock, ok
}

View File

@ -1,63 +0,0 @@
package postgres
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestIntegrationLocker(t *testing.T) {
if testing.Short() {
t.Skip("Tests with Sleep")
}
const notUpdated = "not_updated"
const atThread1 = "at_thread_1"
const atThread2 = "at_thread_2"
t.Run("Should lock for same keys", func(t *testing.T) {
updated := notUpdated
locker := newLocker()
locker.Lock(1)
var wg sync.WaitGroup
wg.Add(1)
defer func() {
locker.Unlock(1)
wg.Wait()
}()
go func() {
locker.RLock(1)
defer func() {
locker.RUnlock(1)
wg.Done()
}()
require.Equal(t, atThread1, updated, "Value should be updated in different thread")
updated = atThread2
}()
time.Sleep(time.Millisecond * 10)
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread")
updated = atThread1
})
t.Run("Should not lock for different keys", func(t *testing.T) {
updated := notUpdated
locker := newLocker()
locker.Lock(1)
defer locker.Unlock(1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
locker.RLock(2)
defer func() {
locker.RUnlock(2)
wg.Done()
}()
require.Equal(t, notUpdated, updated, "Value should not be updated in different thread")
updated = atThread2
}()
wg.Wait()
require.Equal(t, atThread2, updated, "Value should be updated in different thread")
updated = atThread1
})
}

View File

@ -4,28 +4,34 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
pgxstdlib "github.com/jackc/pgx/v5/stdlib"
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource" "github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt" "github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana-plugin-sdk-go/backend/proxy"
"github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana-plugin-sdk-go/data/sqlutil" "github.com/grafana/grafana-plugin-sdk-go/data/sqlutil"
"github.com/lib/pq"
"github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls"
"github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/grafana/grafana/pkg/tsdb/sqleng"
) )
func ProvideService(cfg *setting.Cfg) *Service { func ProvideService(cfg *setting.Cfg) *Service {
logger := backend.NewLoggerWith("logger", "tsdb.postgres") logger := backend.NewLoggerWith("logger", "tsdb.postgres")
s := &Service{ s := &Service{
tlsManager: newTLSManager(logger, cfg.DataPath),
logger: logger, logger: logger,
} }
s.im = datasource.NewInstanceManager(s.newInstanceSettings()) s.im = datasource.NewInstanceManager(s.newInstanceSettings())
@ -33,7 +39,6 @@ func ProvideService(cfg *setting.Cfg) *Service {
} }
type Service struct { type Service struct {
tlsManager tlsSettingsProvider
im instancemgmt.InstanceManager im instancemgmt.InstanceManager
logger log.Logger logger log.Logger
} }
@ -55,17 +60,11 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
return dsInfo.QueryData(ctx, req) return dsInfo.QueryData(ctx, req)
} }
func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) { func newPostgres(userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, logger log.Logger, proxyClient proxy.Client) (*sql.DB, *sqleng.DataSourceHandler, error) {
connector, err := pq.NewConnector(cnnstr) pgxConf, err := generateConnectionConfig(dsInfo)
if err != nil { if err != nil {
logger.Error("postgres connector creation failed", "error", err) logger.Error("postgres config creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres connector creation failed") return nil, nil, fmt.Errorf("postgres config creation failed")
}
proxyClient, err := settings.ProxyClient(ctx)
if err != nil {
logger.Error("postgres proxy creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres proxy creation failed")
} }
if proxyClient.SecureSocksProxyEnabled() { if proxyClient.SecureSocksProxyEnabled() {
@ -74,9 +73,8 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
logger.Error("postgres proxy creation failed", "error", err) logger.Error("postgres proxy creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres proxy creation failed") return nil, nil, fmt.Errorf("postgres proxy creation failed")
} }
postgresDialer := newPostgresProxyDialer(dialer)
// update the postgres dialer with the proxy dialer pgxConf.DialFunc = newPgxDialFunc(dialer)
connector.Dialer(postgresDialer)
} }
config := sqleng.DataPluginConfiguration{ config := sqleng.DataPluginConfiguration{
@ -87,7 +85,7 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
queryResultTransformer := postgresQueryResultTransformer{} queryResultTransformer := postgresQueryResultTransformer{}
db := sql.OpenDB(connector) db := pgxstdlib.OpenDB(*pgxConf)
db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns) db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns)
db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns) db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns)
@ -143,17 +141,17 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc {
DecryptedSecureJSONData: settings.DecryptedSecureJSONData, DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
} }
cnnstr, err := s.generateConnectionString(dsInfo)
if err != nil {
return nil, err
}
userFacingDefaultError, err := cfg.UserFacingDefaultError() userFacingDefaultError, err := cfg.UserFacingDefaultError()
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings) proxyClient, err := settings.ProxyClient(ctx)
if err != nil {
return nil, err
}
_, handler, err := newPostgres(userFacingDefaultError, sqlCfg.RowLimit, dsInfo, logger, proxyClient)
if err != nil { if err != nil {
logger.Error("Failed connecting to Postgres", "err", err) logger.Error("Failed connecting to Postgres", "err", err)
@ -170,13 +168,11 @@ func escape(input string) string {
return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`) return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`)
} }
func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) { func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, error) {
logger := s.logger
var host string var host string
var port int var port int
if strings.HasPrefix(dsInfo.URL, "/") { if strings.HasPrefix(dsInfo.URL, "/") {
host = dsInfo.URL host = dsInfo.URL
logger.Debug("Generating connection string with Unix socket specifier", "socket", host)
} else { } else {
index := strings.LastIndex(dsInfo.URL, ":") index := strings.LastIndex(dsInfo.URL, ":")
v6Index := strings.Index(dsInfo.URL, "]") v6Index := strings.Index(dsInfo.URL, "]")
@ -187,12 +183,8 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string
var err error var err error
port, err = strconv.Atoi(sp[1]) port, err = strconv.Atoi(sp[1])
if err != nil { if err != nil {
return "", fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err) return nil, fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err)
} }
logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port)
} else {
logger.Debug("Generating connection string with network host", "host", host)
} }
} else { } else {
if index == v6Index+1 { if index == v6Index+1 {
@ -200,46 +192,45 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string
var err error var err error
port, err = strconv.Atoi(dsInfo.URL[index+1:]) port, err = strconv.Atoi(dsInfo.URL[index+1:])
if err != nil { if err != nil {
return "", fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err) return nil, fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err)
} }
logger.Debug("Generating ipv6 connection string with network host/port pair", "host", host, "port", port)
} else { } else {
host = dsInfo.URL[1 : len(dsInfo.URL)-1] host = dsInfo.URL[1 : len(dsInfo.URL)-1]
logger.Debug("Generating ipv6 connection string with network host", "host", host)
} }
} }
} }
connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'", // NOTE: we always set sslmode=disable in the connection string, we handle TLS manually later
connStr := fmt.Sprintf("sslmode=disable user='%s' password='%s' host='%s' dbname='%s'",
escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database)) escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database))
if port > 0 { if port > 0 {
connStr += fmt.Sprintf(" port=%d", port) connStr += fmt.Sprintf(" port=%d", port)
} }
tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo) conf, err := pgx.ParseConfig(connStr)
if err != nil { if err != nil {
return "", err return nil, err
} }
connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode)) tlsConf, err := tls.GetTLSConfig(dsInfo, os.ReadFile, host)
if err != nil {
// Attach root certificate if provided return nil, err
if tlsSettings.RootCertFile != "" {
logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile)
connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile))
} }
// Attach client certificate and key if both are provided // before we set the TLS config, we need to make sure the `.Fallbacks` attribute is unset, see:
if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" { // https://github.com/jackc/pgx/discussions/1903#discussioncomment-8430146
logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile) if len(conf.Fallbacks) > 0 {
connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile)) return nil, errors.New("tls: fallbacks configured, unable to set up TLS config")
} else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" { }
return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified") conf.TLSConfig = tlsConf
// by default pgx resolves hostnames to ip addresses. we must avoid this.
// (certain socks-proxy related functionality relies on the hostname being preserved)
conf.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
return []string{host}, nil
} }
logger.Debug("Generated Postgres connection string successfully") return conf, nil
return connStr, nil
} }
type postgresQueryResultTransformer struct{} type postgresQueryResultTransformer struct{}
@ -267,6 +258,44 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque
func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter { func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter {
return []sqlutil.StringConverter{ return []sqlutil.StringConverter{
{
Name: "handle TIME WITH TIME ZONE",
InputScanKind: reflect.Interface,
InputTypeName: strconv.Itoa(pgtype.TimetzOID),
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05-07", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{
Name: "handle TIME",
InputScanKind: reflect.Interface,
InputTypeName: "TIME",
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{ {
Name: "handle FLOAT4", Name: "handle FLOAT4",
InputScanKind: reflect.Interface, InputScanKind: reflect.Interface,

View File

@ -3,7 +3,6 @@ package postgres
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
@ -51,20 +50,6 @@ func TestIntegrationPostgresSnapshots(t *testing.T) {
t.Skip() t.Skip()
} }
getCnnStr := func() string {
host := os.Getenv("POSTGRES_HOST")
if host == "" {
host = "localhost"
}
port := os.Getenv("POSTGRES_PORT")
if port == "" {
port = "5432"
}
return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable",
host, port)
}
sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`) sqlQueryCommentRe := regexp.MustCompile(`^-- (.+)\n`)
readSqlFile := func(path string) (string, string) { readSqlFile := func(path string) (string, string) {
@ -148,18 +133,34 @@ func TestIntegrationPostgresSnapshots(t *testing.T) {
ConnMaxLifetime: 14400, ConnMaxLifetime: 14400,
Timescaledb: false, Timescaledb: false,
ConfigurationMethod: "file-path", ConfigurationMethod: "file-path",
Mode: "disable",
}
host := os.Getenv("POSTGRES_HOST")
if host == "" {
host = "localhost"
}
port := os.Getenv("POSTGRES_PORT")
if port == "" {
port = "5432"
} }
dsInfo := sqleng.DataSourceInfo{ dsInfo := sqleng.DataSourceInfo{
JsonData: jsonData, JsonData: jsonData,
DecryptedSecureJSONData: map[string]string{}, DecryptedSecureJSONData: map[string]string{
"password": "grafanatest",
},
URL: host + ":" + port,
Database: "grafanadstest",
User: "grafanatest",
} }
logger := log.New() logger := log.New()
cnnstr := getCnnStr() settings := backend.DataSourceInstanceSettings{}
proxyClient, err := settings.ProxyClient(context.Background())
db, handler, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err)
db, handler, err := newPostgres("error", 10000, dsInfo, logger, proxyClient)
t.Cleanup((func() { t.Cleanup((func() {
_, err := db.Exec("DROP TABLE tbl") _, err := db.Exec("DROP TABLE tbl")

View File

@ -2,31 +2,33 @@ package postgres
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"net"
"net/http"
"os" "os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
backendproxy "github.com/grafana/grafana-plugin-sdk-go/backend/proxy"
"github.com/grafana/grafana-plugin-sdk-go/data" "github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/proxy"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls"
"github.com/grafana/grafana/pkg/tsdb/sqleng" "github.com/grafana/grafana/pkg/tsdb/sqleng"
_ "github.com/lib/pq" _ "github.com/jackc/pgx/v5/stdlib"
) )
// Test generateConnectionString. func TestGenerateConnectionConfig(t *testing.T) {
func TestIntegrationGenerateConnectionString(t *testing.T) { rootCertBytes, err := tls.CreateRandomRootCertBytes()
if testing.Short() { require.NoError(t, err)
t.Skip("skipping integration test")
}
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
testCases := []struct { testCases := []struct {
desc string desc string
@ -34,10 +36,15 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user string user string
password string password string
database string database string
tlsSettings tlsSettings tlsMode string
expConnStr string tlsRootCert []byte
expErr string expErr string
uid string expHost string
expPort uint16
expUser string
expPassword string
expDatabase string
expTLS bool
}{ }{
{ {
desc: "Unix socket host", desc: "Unix socket host",
@ -45,8 +52,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='/var/run/postgresql' dbname='database' sslmode='verify-full'", expUser: "user",
expPassword: "password",
expHost: "/var/run/postgresql",
expDatabase: "database",
}, },
{ {
desc: "TCP host", desc: "TCP host",
@ -54,8 +64,12 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full'", expUser: "user",
expPassword: "password",
expHost: "host",
expPort: 5432,
expDatabase: "database",
}, },
{ {
desc: "TCP/port host", desc: "TCP/port host",
@ -63,8 +77,12 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='host' dbname='database' port=1234 sslmode='verify-full'", expUser: "user",
expPassword: "password",
expHost: "host",
expPort: 1234,
expDatabase: "database",
}, },
{ {
desc: "Ipv6 host", desc: "Ipv6 host",
@ -72,8 +90,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='::1' dbname='database' sslmode='verify-full'", expUser: "user",
expPassword: "password",
expHost: "::1",
expDatabase: "database",
}, },
{ {
desc: "Ipv6/port host", desc: "Ipv6/port host",
@ -81,15 +102,19 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='::1' dbname='database' port=1234 sslmode='verify-full'", expUser: "user",
expPassword: "password",
expHost: "::1",
expPort: 1234,
expDatabase: "database",
}, },
{ {
desc: "Invalid port", desc: "Invalid port",
host: "host:invalid", host: "host:invalid",
user: "user", user: "user",
database: "database", database: "database",
tlsSettings: tlsSettings{}, tlsMode: "disable",
expErr: "invalid port in host specifier", expErr: "invalid port in host specifier",
}, },
{ {
@ -98,8 +123,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: `p'\assword`, password: `p'\assword`,
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: `user='user' password='p\'\\assword' host='host' dbname='database' sslmode='verify-full'`, expUser: "user",
expPassword: `p'\assword`,
expHost: "host",
expDatabase: "database",
}, },
{ {
desc: "User/DB with single quote and backslash", desc: "User/DB with single quote and backslash",
@ -107,8 +135,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: `u'\ser`, user: `u'\ser`,
password: `password`, password: `password`,
database: `d'\atabase`, database: `d'\atabase`,
tlsSettings: tlsSettings{Mode: "verify-full"}, tlsMode: "disable",
expConnStr: `user='u\'\\ser' password='password' host='host' dbname='d\'\\atabase' sslmode='verify-full'`, expUser: `u'\ser`,
expPassword: "password",
expDatabase: `d'\atabase`,
expHost: "host",
}, },
{ {
desc: "Custom TLS mode disabled", desc: "Custom TLS mode disabled",
@ -116,8 +147,11 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{Mode: "disable"}, tlsMode: "disable",
expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='disable'", expUser: "user",
expPassword: "password",
expHost: "host",
expDatabase: "database",
}, },
{ {
desc: "Custom TLS mode verify-full with certificate files", desc: "Custom TLS mode verify-full with certificate files",
@ -125,36 +159,43 @@ func TestIntegrationGenerateConnectionString(t *testing.T) {
user: "user", user: "user",
password: "password", password: "password",
database: "database", database: "database",
tlsSettings: tlsSettings{ tlsMode: "verify-full",
Mode: "verify-full", tlsRootCert: rootCertBytes,
RootCertFile: "i/am/coding/ca.crt", expUser: "user",
CertFile: "i/am/coding/client.crt", expPassword: "password",
CertKeyFile: "i/am/coding/client.key", expDatabase: "database",
}, expHost: "host",
expConnStr: "user='user' password='password' host='host' dbname='database' sslmode='verify-full' " + expTLS: true,
"sslrootcert='i/am/coding/ca.crt' sslcert='i/am/coding/client.crt' sslkey='i/am/coding/client.key'",
}, },
} }
for _, tt := range testCases { for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
svc := Service{
tlsManager: &tlsTestManager{settings: tt.tlsSettings},
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
}
ds := sqleng.DataSourceInfo{ ds := sqleng.DataSourceInfo{
URL: tt.host, URL: tt.host,
User: tt.user, User: tt.user,
DecryptedSecureJSONData: map[string]string{"password": tt.password}, DecryptedSecureJSONData: map[string]string{
"password": tt.password,
"tlsCACert": string(tt.tlsRootCert),
},
Database: tt.database, Database: tt.database,
UID: tt.uid, JsonData: sqleng.JsonData{
Mode: tt.tlsMode,
ConfigurationMethod: "file-content",
},
} }
connStr, err := svc.generateConnectionString(ds) c, err := generateConnectionConfig(ds)
if tt.expErr == "" { if tt.expErr == "" {
require.NoError(t, err, tt.desc) require.NoError(t, err, tt.desc)
assert.Equal(t, tt.expConnStr, connStr) assert.Equal(t, tt.expHost, c.Host)
if tt.expPort != 0 {
assert.Equal(t, tt.expPort, c.Port)
}
assert.Equal(t, tt.expUser, c.User)
assert.Equal(t, tt.expDatabase, c.Database)
assert.Equal(t, tt.expPassword, c.Password)
require.Equal(t, tt.expTLS, c.TLSConfig != nil)
} else { } else {
require.Error(t, err, tt.desc) require.Error(t, err, tt.desc)
assert.True(t, strings.HasPrefix(err.Error(), tt.expErr), assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
@ -191,24 +232,40 @@ func TestIntegrationPostgres(t *testing.T) {
return sql return sql
} }
host := os.Getenv("POSTGRES_HOST")
if host == "" {
host = "localhost"
}
port := os.Getenv("POSTGRES_PORT")
if port == "" {
port = "5432"
}
jsonData := sqleng.JsonData{ jsonData := sqleng.JsonData{
MaxOpenConns: 0, MaxOpenConns: 0,
MaxIdleConns: 2, MaxIdleConns: 2,
ConnMaxLifetime: 14400, ConnMaxLifetime: 14400,
Timescaledb: false, Timescaledb: false,
ConfigurationMethod: "file-path", ConfigurationMethod: "file-path",
Mode: "disable",
} }
dsInfo := sqleng.DataSourceInfo{ dsInfo := sqleng.DataSourceInfo{
JsonData: jsonData, JsonData: jsonData,
DecryptedSecureJSONData: map[string]string{}, DecryptedSecureJSONData: map[string]string{
"password": "grafanatest",
},
URL: host + ":" + port,
Database: "grafanadstest",
User: "grafanatest",
} }
logger := backend.NewLoggerWith("logger", "postgres.test") logger := backend.NewLoggerWith("logger", "postgres.test")
cnnstr := postgresTestDBConnString() settings := backend.DataSourceInstanceSettings{}
proxyClient, err := settings.ProxyClient(context.Background())
db, exe, err := newPostgres(context.Background(), "error", 10000, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) require.NoError(t, err)
db, exe, err := newPostgres("error", 10000, dsInfo, logger, proxyClient)
require.NoError(t, err) require.NoError(t, err)
@ -1261,8 +1318,10 @@ func TestIntegrationPostgres(t *testing.T) {
}) })
t.Run("When row limit set to 1", func(t *testing.T) { t.Run("When row limit set to 1", func(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{} settings := backend.DataSourceInstanceSettings{}
_, handler, err := newPostgres(context.Background(), "error", 1, dsInfo, cnnstr, logger, backend.DataSourceInstanceSettings{}) proxyClient, err := settings.ProxyClient(context.Background())
require.NoError(t, err)
_, handler, err := newPostgres("error", 1, dsInfo, logger, proxyClient)
require.NoError(t, err) require.NoError(t, err)
@ -1377,14 +1436,6 @@ func genTimeRangeByInterval(from time.Time, duration time.Duration, interval tim
return timeRange return timeRange
} }
type tlsTestManager struct {
settings tlsSettings
}
func (m *tlsTestManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) {
return m.settings, nil
}
func isTestDbPostgres() bool { func isTestDbPostgres() bool {
if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present { if db, present := os.LookupEnv("GRAFANA_TEST_DB"); present {
return db == "postgres" return db == "postgres"
@ -1393,15 +1444,59 @@ func isTestDbPostgres() bool {
return false return false
} }
func postgresTestDBConnString() string { type testNoResolveDialer struct {
host := os.Getenv("POSTGRES_HOST") }
if host == "" {
host = "localhost" func (d *testNoResolveDialer) Dial(network, addr string) (c net.Conn, err error) {
} return nil, fmt.Errorf("test-dialer: dialing to '%s'. not implemented", addr)
port := os.Getenv("POSTGRES_PORT") }
if port == "" {
port = "5432" var _ proxy.Dialer = (&testNoResolveDialer{})
}
return fmt.Sprintf("user=grafanatest password=grafanatest host=%s port=%s dbname=grafanadstest sslmode=disable", type testNoResolveProxyClient struct {
host, port) }
var _ backendproxy.Client = (&testNoResolveProxyClient{})
func (p *testNoResolveProxyClient) SecureSocksProxyEnabled() bool {
return true
}
func (p *testNoResolveProxyClient) ConfigureSecureSocksHTTPProxy(transport *http.Transport) error {
return errors.New("testNoResolveProxyClient.ConfigureSecureSocksHTTPProxy not implemented")
}
func (p *testNoResolveProxyClient) NewSecureSocksProxyContextDialer() (proxy.Dialer, error) {
return &testNoResolveDialer{}, nil
}
// we must make sure that pgx does not resolve hostnames:
// if we say the hostname is `localhost`, then it should
// instruct the socks-proxy to connect to `localhost`, not to `127.0.0.1`
// this is important, becase some other socks-proxy-code relies on this behavior.
func TestNoResolve(t *testing.T) {
jsonData := sqleng.JsonData{
MaxOpenConns: 0,
MaxIdleConns: 2,
ConnMaxLifetime: 14400,
Timescaledb: false,
ConfigurationMethod: "file-path",
}
dsInfo := sqleng.DataSourceInfo{
JsonData: jsonData,
DecryptedSecureJSONData: map[string]string{
"password": "password",
},
URL: "localhost:5432",
Database: "db",
User: "user",
}
db, _, err := newPostgres("error", 10000, dsInfo, log.New(), &testNoResolveProxyClient{})
require.NoError(t, err)
require.NotNil(t, db)
err = db.Ping()
require.Error(t, err)
require.Contains(t, err.Error(), "test-dialer: dialing to 'localhost:5432'. not implemented")
} }

View File

@ -3,33 +3,17 @@ package postgres
import ( import (
"context" "context"
"net" "net"
"time"
"github.com/lib/pq"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )
// we wrap the proxy.Dialer to become dialer that the postgres module accepts type PgxDialFunc = func(ctx context.Context, network string, address string) (net.Conn, error)
func newPostgresProxyDialer(dialer proxy.Dialer) pq.Dialer {
return &postgresProxyDialer{d: dialer} func newPgxDialFunc(dialer proxy.Dialer) PgxDialFunc {
} dialFunc :=
func(ctx context.Context, network string, addr string) (net.Conn, error) {
var _ pq.Dialer = (&postgresProxyDialer{}) return dialer.Dial(network, addr)
}
// postgresProxyDialer implements the postgres dialer using a proxy dialer, as their functions differ slightly
type postgresProxyDialer struct { return dialFunc
d proxy.Dialer
}
// Dial uses the normal proxy dial function with the updated dialer
func (p *postgresProxyDialer) Dial(network, addr string) (c net.Conn, err error) {
return p.d.Dial(network, addr)
}
// DialTimeout uses the normal postgres dial timeout function with the updated dialer
func (p *postgresProxyDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return p.d.(proxy.ContextDialer).DialContext(ctx, network, address)
} }

View File

@ -1,12 +1,12 @@
package postgres package postgres
import ( import (
"database/sql"
"fmt" "fmt"
"net" "net"
"testing" "testing"
"github.com/lib/pq" "github.com/jackc/pgx/v5"
pgxstdlib "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
) )
@ -25,13 +25,13 @@ func TestPostgresProxyDriver(t *testing.T) {
cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL) cnnstr := fmt.Sprintf("postgres://auser:password@%s/db?sslmode=disable", dbURL)
t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) { t.Run("Connector should use dialer context that routes through the socks proxy to db", func(t *testing.T) {
connector, err := pq.NewConnector(cnnstr) pgxConf, err := pgx.ParseConfig(cnnstr)
require.NoError(t, err) require.NoError(t, err)
dialer := newPostgresProxyDialer(&testDialer{})
connector.Dialer(dialer) pgxConf.DialFunc = newPgxDialFunc(&testDialer{})
db := pgxstdlib.OpenDB(*pgxConf)
db := sql.OpenDB(connector)
err = db.Ping() err = db.Ping()
require.Contains(t, err.Error(), "test-dialer is not functional") require.Contains(t, err.Error(), "test-dialer is not functional")

View File

@ -9,16 +9,16 @@
// } // }
// Name: // Name:
// Dimensions: 4 Fields by 4 Rows // Dimensions: 4 Fields by 4 Rows
// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ // +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+
// | Name: reallyt | Name: time | Name: n | Name: timeend | // | Name: reallyt | Name: time | Name: n | Name: timeend |
// | Labels: | Labels: | Labels: | Labels: | // | Labels: | Labels: | Labels: | Labels: |
// | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time | // | Type: []*time.Time | Type: []*time.Time | Type: []*float64 | Type: []*time.Time |
// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ // +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+
// | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:21:40 +0000 UTC | 1.7031613e+09 | 2023-12-21 12:22:52 +0000 UTC | // | 2023-12-21 12:22:24 +0000 UTC | 2023-12-21 12:22:24 +0000 UTC | 1.703161344e+09 | 2023-12-21 12:22:52 +0000 UTC |
// | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+12 | 2023-12-21 12:21:52.522 +0000 UTC | // | 2023-12-21 12:20:33.408 +0000 UTC | 2023-12-21 12:20:33.408 +0000 UTC | 1.703161233408e+12 | 2023-12-21 12:21:52.522 +0000 UTC |
// | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:00 +0000 UTC | 1.7031612e+18 | 2023-12-21 12:21:52.522 +0000 UTC | // | 2023-12-21 12:20:41.050022 +0000 UTC | 2023-12-21 12:20:41.05 +0000 UTC | 1.703161241050022e+18 | 2023-12-21 12:21:52.522 +0000 UTC |
// | null | null | null | null | // | null | null | null | null |
// +--------------------------------------+-------------------------------+------------------+-----------------------------------+ // +--------------------------------------+-----------------------------------+-----------------------+-----------------------------------+
// //
// //
// 🌟 This was machine generated. Do not edit. 🌟 // 🌟 This was machine generated. Do not edit. 🌟
@ -78,15 +78,15 @@
null null
], ],
[ [
1703161300000, 1703161344000,
1703161200000, 1703161233408,
1703161200000, 1703161241050,
null null
], ],
[ [
1703161300, 1703161344,
1703161200000, 1703161233408,
1703161200000000000, 1703161241050022000,
null null
], ],
[ [

View File

@ -9,14 +9,14 @@
// } // }
// Name: // Name:
// Dimensions: 12 Fields by 2 Rows // Dimensions: 12 Fields by 2 Rows
// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn | // | Name: ts | Name: tsnn | Name: tsz | Name: tsznn | Name: d | Name: dnn | Name: t | Name: tnn | Name: tz | Name: tznn | Name: i | Name: inn |
// | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | // | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: | Labels: |
// | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string | // | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*time.Time | Type: []*string | Type: []*string |
// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// | 2023-11-15 05:06:07.123456 +0000 +0000 | 2023-11-15 05:06:08.123456 +0000 +0000 | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 +0000 | 2023-12-21 00:00:00 +0000 +0000 | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 | // | 2023-11-15 05:06:07.123456 +0000 UTC | 2023-11-15 05:06:08.123456 +0000 UTC | 2021-07-22 11:22:33.654321 +0000 UTC | 2021-07-22 11:22:34.654321 +0000 UTC | 2023-12-20 00:00:00 +0000 UTC | 2023-12-21 00:00:00 +0000 UTC | 0000-01-01 12:34:56.234567 +0000 UTC | 0000-01-01 12:34:57.234567 +0000 UTC | 0000-01-01 23:12:36.765432 +0100 +0100 | 0000-01-01 23:12:37.765432 +0100 +0100 | 00:00:00.987654 | 00:00:00.887654 |
// | null | 2023-11-15 05:06:09.123456 +0000 +0000 | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 +0000 | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 | // | null | 2023-11-15 05:06:09.123456 +0000 UTC | null | 2021-07-22 11:22:35.654321 +0000 UTC | null | 2023-12-22 00:00:00 +0000 UTC | null | 0000-01-01 12:34:58.234567 +0000 UTC | null | 0000-01-01 23:12:38.765432 +0100 +0100 | null | 00:00:00.787654 |
// +----------------------------------------+----------------------------------------+--------------------------------------+--------------------------------------+---------------------------------+---------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+ // +--------------------------------------+--------------------------------------+--------------------------------------+--------------------------------------+-------------------------------+-------------------------------+--------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------+-----------------+-----------------+
// //
// //
// 🌟 This was machine generated. Do not edit. 🌟 // 🌟 This was machine generated. Do not edit. 🌟

View File

@ -0,0 +1,135 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
// we support 4 postgres tls modes:
// disable - no tls
// require - use tls
// verify-ca - use tls, verify root cert but not the hostname
// verify-full - use tls, verify root cert
// (for all the options except `disable`, you can optionally use client certificates)
func getTLSConfigRequire(certs *Certs, serverName string) (*tls.Config, error) {
// see https://www.postgresql.org/docs/12/libpq-ssl.html ,
// mode=require + provided root-cert should behave as mode=verify-ca
if certs.rootCerts != nil {
return getTLSConfigVerifyCA(certs, serverName)
}
return &tls.Config{
InsecureSkipVerify: true, // we do not verify the root cert
Certificates: certs.clientCerts,
ServerName: serverName,
}, nil
}
// to implement the verify-ca mode, we need to do this:
// - for the root certificate
// - verify that the certificate we receive from the server is trusted,
// meaning it relates to our root certificate
// - we DO NOT verify that the hostname of the database matches
// the hostname in the certificate
//
// the problem is, `go“ does not offer such an option.
// by default, it will verify both things.
//
// so what we do is:
// - we turn off the default-verification with `InsecureSkipVerify`
// - we implement our own verification using `VerifyConnection`
//
// extra info about this:
// - there is a rejected feature-request about this at https://github.com/golang/go/issues/21971
// - the recommended workaround is based on VerifyPeerCertificate
// - there is even example code at https://github.com/golang/go/commit/29cfb4d3c3a97b6f426d1b899234da905be699aa
// - but later the example code was changed to use VerifyConnection instead:
// https://github.com/golang/go/commit/7eb5941b95a588a23f18fa4c22fe42ff0119c311
//
// a verifyConnection example is at https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection .
//
// this is how the `pgx` library handles verify-ca:
//
// https://github.com/jackc/pgx/blob/5c63f646f820ca9696fc3515c1caf2a557d562e5/pgconn/config.go#L657-L690
// (unfortunately pgx only handles this for certificate-provided-as-path, so we cannot rely on it)
func getTLSConfigVerifyCA(certs *Certs, serverName string) (*tls.Config, error) {
conf := tls.Config{
ServerName: serverName,
Certificates: certs.clientCerts,
InsecureSkipVerify: true, // we turn off the default-verification, we'll do VerifyConnection instead
VerifyConnection: func(state tls.ConnectionState) error {
// we add all the certificates to the pool, we skip the first cert.
intermediates := x509.NewCertPool()
for _, c := range state.PeerCertificates[1:] {
intermediates.AddCert(c)
}
opts := x509.VerifyOptions{
Roots: certs.rootCerts,
Intermediates: intermediates,
}
// we call `Verify()` on the first cert (that we skipped previously)
_, err := state.PeerCertificates[0].Verify(opts)
return err
},
RootCAs: certs.rootCerts,
}
return &conf, nil
}
func getTLSConfigVerifyFull(certs *Certs, serverName string) (*tls.Config, error) {
conf := tls.Config{
Certificates: certs.clientCerts,
ServerName: serverName,
RootCAs: certs.rootCerts,
}
return &conf, nil
}
func IsTLSEnabled(dsInfo sqleng.DataSourceInfo) bool {
mode := dsInfo.JsonData.Mode
return mode != "disable"
}
// returns `nil` if tls is disabled
func GetTLSConfig(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc, serverName string) (*tls.Config, error) {
mode := dsInfo.JsonData.Mode
// we need to special-case the no-tls-mode
if mode == "disable" {
return nil, nil
}
// for all the remaining cases we need to load
// both the root-cert if exists, and the client-cert if exists.
certBytes, err := loadCertificateBytes(dsInfo, readFile)
if err != nil {
return nil, err
}
certs, err := createCertificates(certBytes)
if err != nil {
return nil, err
}
switch mode {
// `disable` already handled
case "":
// for backward-compatibility reasons this is the same as `require`
return getTLSConfigRequire(certs, serverName)
case "require":
return getTLSConfigRequire(certs, serverName)
case "verify-ca":
return getTLSConfigVerifyCA(certs, serverName)
case "verify-full":
return getTLSConfigVerifyFull(certs, serverName)
default:
return nil, errors.New("tls: invalid mode " + mode)
}
}

View File

@ -0,0 +1,101 @@
package tls
import (
"crypto/tls"
"crypto/x509"
"errors"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
// this file deals with locating and loading the certificates,
// from json-data or from disk.
type CertBytes struct {
rootCert []byte
clientKey []byte
clientCert []byte
}
type ReadFileFunc = func(name string) ([]byte, error)
var errPartialClientCertNoKey = errors.New("tls: client cert provided but client key missing")
var errPartialClientCertNoCert = errors.New("tls: client key provided but client cert missing")
// certificates can be stored either as encrypted-json-data, or as file-path
func loadCertificateBytes(dsInfo sqleng.DataSourceInfo, readFile ReadFileFunc) (*CertBytes, error) {
if dsInfo.JsonData.ConfigurationMethod == "file-content" {
return &CertBytes{
rootCert: []byte(dsInfo.DecryptedSecureJSONData["tlsCACert"]),
clientKey: []byte(dsInfo.DecryptedSecureJSONData["tlsClientKey"]),
clientCert: []byte(dsInfo.DecryptedSecureJSONData["tlsClientCert"]),
}, nil
} else {
c := CertBytes{}
if dsInfo.JsonData.RootCertFile != "" {
rootCert, err := readFile(dsInfo.JsonData.RootCertFile)
if err != nil {
return nil, err
}
c.rootCert = rootCert
}
if dsInfo.JsonData.CertKeyFile != "" {
clientKey, err := readFile(dsInfo.JsonData.CertKeyFile)
if err != nil {
return nil, err
}
c.clientKey = clientKey
}
if dsInfo.JsonData.CertFile != "" {
clientCert, err := readFile(dsInfo.JsonData.CertFile)
if err != nil {
return nil, err
}
c.clientCert = clientCert
}
return &c, nil
}
}
type Certs struct {
clientCerts []tls.Certificate
rootCerts *x509.CertPool
}
func createCertificates(certBytes *CertBytes) (*Certs, error) {
certs := Certs{}
if len(certBytes.rootCert) > 0 {
pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(certBytes.rootCert)
if !ok {
return nil, errors.New("tls: failed to add root certificate")
}
certs.rootCerts = pool
}
hasClientKey := len(certBytes.clientKey) > 0
hasClientCert := len(certBytes.clientCert) > 0
if hasClientKey && hasClientCert {
cert, err := tls.X509KeyPair(certBytes.clientCert, certBytes.clientKey)
if err != nil {
return nil, err
}
certs.clientCerts = []tls.Certificate{cert}
}
if hasClientKey && (!hasClientCert) {
return nil, errPartialClientCertNoCert
}
if hasClientCert && (!hasClientKey) {
return nil, errPartialClientCertNoKey
}
return &certs, nil
}

View File

@ -0,0 +1,402 @@
package tls
import (
"errors"
"os"
"testing"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
"github.com/stretchr/testify/require"
)
func noReadFile(path string) ([]byte, error) {
return nil, errors.New("not implemented")
}
func TestTLSNoMode(t *testing.T) {
// for backward-compatibility reason,
// when mode is unset, it defaults to `require`
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
}
func TestTLSDisable(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "disable",
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.Nil(t, c)
}
func TestTLSRequire(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "",
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.Nil(t, c.RootCAs)
}
func TestTLSRequireWithRootCert(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.NotNil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyCA(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.True(t, c.InsecureSkipVerify)
require.NotNil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyCANoRootCertProvided(t *testing.T) {
// this is ok. go will use the default system certs
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
}
func TestTLSClientCert(t *testing.T) {
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientCert": string(clientCert),
"tlsClientKey": string(clientKey),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
}
func TestTLSMethodFileContentClientCertMissingKey(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientCert": string(clientCert),
},
}
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoKey)
}
func TestTLSMethodFileContentClientCertMissingCert(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsClientKey": string(clientKey),
},
}
_, err = GetTLSConfig(dsInfo, noReadFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoCert)
}
func TestTLSMethodFilePathClientCertMissingKey(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"path1": clientKey,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "path1",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoCert)
}
func TestTLSMethodFilePathClientCertMissingCert(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"path1": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertFile: "path1",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, errPartialClientCertNoKey)
}
func TestTLSVerifyFull(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.False(t, c.InsecureSkipVerify)
require.Nil(t, c.VerifyConnection)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFileContent(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{
"tlsCACert": string(rootCertBytes),
"tlsClientCert": string(clientCert),
"tlsClientKey": string(clientKey),
},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFilePath(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"root-cert-path": rootCertBytes,
"client-key-path": clientKey,
"client-cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "root-cert-path",
CertKeyFile: "client-key-path",
CertFile: "client-cert-path",
},
}
c, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSMethodFilePathRootCertDoesNotExist(t *testing.T) {
readFile := newMockReadFile(map[string]([]byte){})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "path1",
},
}
_, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
func TestTLSMethodFilePathClientCertKeyDoesNotExist(t *testing.T) {
_, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "key-path",
CertFile: "cert-path",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
func TestTLSMethodFilePathClientCertCertDoesNotExist(t *testing.T) {
clientKey, _, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"key-path": clientKey,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-path",
CertKeyFile: "key-path",
CertFile: "cert-path",
},
}
_, err = GetTLSConfig(dsInfo, readFile, "localhost")
require.ErrorIs(t, err, os.ErrNotExist)
}
// method="" equals to method="file-path"
func TestTLSMethodEmpty(t *testing.T) {
rootCertBytes, err := CreateRandomRootCertBytes()
require.NoError(t, err)
clientKey, clientCert, err := CreateRandomClientCert()
require.NoError(t, err)
readFile := newMockReadFile(map[string]([]byte){
"root-cert-path": rootCertBytes,
"client-key-path": clientKey,
"client-cert-path": clientCert,
})
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "",
RootCertFile: "root-cert-path",
CertKeyFile: "client-key-path",
CertFile: "client-cert-path",
},
}
c, err := GetTLSConfig(dsInfo, readFile, "localhost")
require.NoError(t, err)
require.NotNil(t, c)
require.Len(t, c.Certificates, 1)
require.NotNil(t, c.RootCAs) // TODO: not the best, but nothing better available
}
func TestTLSVerifyFullNoRootCertProvided(t *testing.T) {
// this is ok. go will use the default system certs
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
DecryptedSecureJSONData: map[string]string{},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.NoError(t, err)
}
func TestTLSInvalidMode(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: "not-a-valid-mode",
},
}
_, err := GetTLSConfig(dsInfo, noReadFile, "localhost")
require.Error(t, err)
}
func TestTLSServerNameSetInEveryMode(t *testing.T) {
modes := []string{"require", "verify-ca", "verify-full"}
for _, mode := range modes {
t.Run(mode, func(t *testing.T) {
dsInfo := sqleng.DataSourceInfo{
JsonData: sqleng.JsonData{
Mode: mode,
},
DecryptedSecureJSONData: map[string]string{},
}
c, err := GetTLSConfig(dsInfo, noReadFile, "example.com")
require.NoError(t, err)
require.Equal(t, "example.com", c.ServerName)
})
}
}

View File

@ -0,0 +1,105 @@
package tls
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"time"
)
func CreateRandomRootCertBytes() ([]byte, error) {
cert := x509.Certificate{
SerialNumber: big.NewInt(42),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
bytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &key.PublicKey, key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: bytes,
}), nil
}
func CreateRandomClientCert() ([]byte, []byte, error) {
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
keyBytes := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
caCert := x509.Certificate{
SerialNumber: big.NewInt(42),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
cert := x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
CommonName: "test1",
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certData, err := x509.CreateCertificate(rand.Reader, &cert, &caCert, &key.PublicKey, caKey)
if err != nil {
return nil, nil, err
}
certBytes := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certData,
})
return keyBytes, certBytes, nil
}
func newMockReadFile(data map[string]([]byte)) ReadFileFunc {
return func(path string) ([]byte, error) {
bytes, ok := data[path]
if !ok {
return nil, os.ErrNotExist
}
return bytes, nil
}
}

View File

@ -1,249 +0,0 @@
package postgres
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
var validateCertFunc = validateCertFilePaths
var writeCertFileFunc = writeCertFile
type certFileType int
const (
rootCert = iota
clientCert
clientKey
)
type tlsSettingsProvider interface {
getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error)
}
type datasourceCacheManager struct {
locker *locker
cache sync.Map
}
type tlsManager struct {
logger log.Logger
dsCacheInstance datasourceCacheManager
dataPath string
}
func newTLSManager(logger log.Logger, dataPath string) tlsSettingsProvider {
return &tlsManager{
logger: logger,
dataPath: dataPath,
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
}
}
type tlsSettings struct {
Mode string
ConfigurationMethod string
RootCertFile string
CertFile string
CertKeyFile string
}
func (m *tlsManager) getTLSSettings(dsInfo sqleng.DataSourceInfo) (tlsSettings, error) {
tlsconfig := tlsSettings{
Mode: dsInfo.JsonData.Mode,
}
isTLSDisabled := (tlsconfig.Mode == "disable")
if isTLSDisabled {
m.logger.Debug("Postgres TLS/SSL is disabled")
return tlsconfig, nil
}
m.logger.Debug("Postgres TLS/SSL is enabled", "tlsMode", tlsconfig.Mode)
tlsconfig.ConfigurationMethod = dsInfo.JsonData.ConfigurationMethod
tlsconfig.RootCertFile = dsInfo.JsonData.RootCertFile
tlsconfig.CertFile = dsInfo.JsonData.CertFile
tlsconfig.CertKeyFile = dsInfo.JsonData.CertKeyFile
if tlsconfig.ConfigurationMethod == "file-content" {
if err := m.writeCertFiles(dsInfo, &tlsconfig); err != nil {
return tlsconfig, err
}
} else {
if err := validateCertFunc(tlsconfig.RootCertFile, tlsconfig.CertFile, tlsconfig.CertKeyFile); err != nil {
return tlsconfig, err
}
}
return tlsconfig, nil
}
func (t certFileType) String() string {
switch t {
case rootCert:
return "root certificate"
case clientCert:
return "client certificate"
case clientKey:
return "client key"
default:
panic(fmt.Sprintf("Unrecognized certFileType %d", t))
}
}
func getFileName(dataDir string, fileType certFileType) string {
var filename string
switch fileType {
case rootCert:
filename = "root.crt"
case clientCert:
filename = "client.crt"
case clientKey:
filename = "client.key"
default:
panic(fmt.Sprintf("unrecognized certFileType %s", fileType.String()))
}
generatedFilePath := filepath.Join(dataDir, filename)
return generatedFilePath
}
// writeCertFile writes a certificate file.
func writeCertFile(logger log.Logger, fileContent string, generatedFilePath string) error {
fileContent = strings.TrimSpace(fileContent)
if fileContent != "" {
logger.Debug("Writing cert file", "path", generatedFilePath)
if err := os.WriteFile(generatedFilePath, []byte(fileContent), 0600); err != nil {
return err
}
// Make sure the file has the permissions expected by the Postgresql driver, otherwise it will bail
if err := os.Chmod(generatedFilePath, 0600); err != nil {
return err
}
return nil
}
logger.Debug("Deleting cert file since no content is provided", "path", generatedFilePath)
exists, err := fileExists(generatedFilePath)
if err != nil {
return err
}
if exists {
if err := os.Remove(generatedFilePath); err != nil {
return fmt.Errorf("failed to remove %q: %w", generatedFilePath, err)
}
}
return nil
}
func (m *tlsManager) writeCertFiles(dsInfo sqleng.DataSourceInfo, tlsconfig *tlsSettings) error {
m.logger.Debug("Writing TLS certificate files to disk")
tlsRootCert := dsInfo.DecryptedSecureJSONData["tlsCACert"]
tlsClientCert := dsInfo.DecryptedSecureJSONData["tlsClientCert"]
tlsClientKey := dsInfo.DecryptedSecureJSONData["tlsClientKey"]
if tlsRootCert == "" && tlsClientCert == "" && tlsClientKey == "" {
m.logger.Debug("No TLS/SSL certificates provided")
}
// Calculate all files path
workDir := filepath.Join(m.dataPath, "tls", dsInfo.UID+"generatedTLSCerts")
tlsconfig.RootCertFile = getFileName(workDir, rootCert)
tlsconfig.CertFile = getFileName(workDir, clientCert)
tlsconfig.CertKeyFile = getFileName(workDir, clientKey)
// Find datasource in the cache, if found, skip writing files
cacheKey := strconv.Itoa(int(dsInfo.ID))
m.dsCacheInstance.locker.RLock(cacheKey)
item, ok := m.dsCacheInstance.cache.Load(cacheKey)
m.dsCacheInstance.locker.RUnlock(cacheKey)
if ok {
if !item.(time.Time).Before(dsInfo.Updated) {
return nil
}
}
m.dsCacheInstance.locker.Lock(cacheKey)
defer m.dsCacheInstance.locker.Unlock(cacheKey)
item, ok = m.dsCacheInstance.cache.Load(cacheKey)
if ok {
if !item.(time.Time).Before(dsInfo.Updated) {
return nil
}
}
// Write certification directory and files
exists, err := fileExists(workDir)
if err != nil {
return err
}
if !exists {
if err := os.MkdirAll(workDir, 0700); err != nil {
return err
}
}
if err = writeCertFileFunc(m.logger, tlsRootCert, tlsconfig.RootCertFile); err != nil {
return err
}
if err = writeCertFileFunc(m.logger, tlsClientCert, tlsconfig.CertFile); err != nil {
return err
}
if err = writeCertFileFunc(m.logger, tlsClientKey, tlsconfig.CertKeyFile); err != nil {
return err
}
// we do not want to point to cert-files that do not exist
if tlsRootCert == "" {
tlsconfig.RootCertFile = ""
}
if tlsClientCert == "" {
tlsconfig.CertFile = ""
}
if tlsClientKey == "" {
tlsconfig.CertKeyFile = ""
}
// Update datasource cache
m.dsCacheInstance.cache.Store(cacheKey, dsInfo.Updated)
return nil
}
// validateCertFilePaths validates configured certificate file paths.
func validateCertFilePaths(rootCert, clientCert, clientKey string) error {
for _, fpath := range []string{rootCert, clientCert, clientKey} {
if fpath == "" {
continue
}
exists, err := fileExists(fpath)
if err != nil {
return err
}
if !exists {
return fmt.Errorf("certificate file %q doesn't exist", fpath)
}
}
return nil
}
// Exists determines whether a file/directory exists or not.
func fileExists(fpath string) (bool, error) {
_, err := os.Stat(fpath)
if err != nil {
if !os.IsNotExist(err) {
return false, err
}
return false, nil
}
return true, nil
}

View File

@ -1,332 +0,0 @@
package postgres
import (
"fmt"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
var writeCertFileCallNum int
// TestDataSourceCacheManager is to test the Cache manager
func TestDataSourceCacheManager(t *testing.T) {
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
mng := tlsManager{
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
dataPath: cfg.DataPath,
}
jsonData := sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
}
secureJSONData := map[string]string{
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
"tlsCACert": "I am CA certification",
}
updateTime := time.Now().Add(-5 * time.Minute)
mockValidateCertFilePaths()
t.Cleanup(resetValidateCertFilePaths)
t.Run("Check datasource cache creation", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(10)
for id := int64(1); id <= 10; id++ {
go func(id int64) {
ds := sqleng.DataSourceInfo{
ID: id,
Updated: updateTime,
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(ds, &s)
require.NoError(t, err)
wg.Done()
}(id)
}
wg.Wait()
t.Run("check cache creation is succeed", func(t *testing.T) {
for id := int64(1); id <= 10; id++ {
updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id)))
require.True(t, ok)
require.Equal(t, updateTime, updated)
}
})
})
t.Run("Check datasource cache modification", func(t *testing.T) {
t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) {
mockWriteCertFile()
t.Cleanup(resetWriteCertFile)
var wg1 sync.WaitGroup
wg1.Add(5)
for id := int64(1); id <= 5; id++ {
go func(id int64) {
ds := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime,
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(ds, &s)
require.NoError(t, err)
wg1.Done()
}(id)
}
wg1.Wait()
assert.Equal(t, writeCertFileCallNum, 0)
})
t.Run("cache is updated with the last datasource version", func(t *testing.T) {
dsV2 := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime.Add(time.Minute),
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
dsV3 := sqleng.DataSourceInfo{
ID: 1,
Updated: updateTime.Add(2 * time.Minute),
Database: "database",
JsonData: jsonData,
DecryptedSecureJSONData: secureJSONData,
UID: "testData",
}
s := tlsSettings{}
err := mng.writeCertFiles(dsV2, &s)
require.NoError(t, err)
err = mng.writeCertFiles(dsV3, &s)
require.NoError(t, err)
version, ok := mng.dsCacheInstance.cache.Load("1")
require.True(t, ok)
require.Equal(t, updateTime.Add(2*time.Minute), version)
})
})
}
// Test getFileName
func TestGetFileName(t *testing.T) {
testCases := []struct {
desc string
datadir string
fileType certFileType
expErr string
expectedGeneratedPath string
}{
{
desc: "Get File Name for root certification",
datadir: ".",
fileType: rootCert,
expectedGeneratedPath: "root.crt",
},
{
desc: "Get File Name for client certification",
datadir: ".",
fileType: clientCert,
expectedGeneratedPath: "client.crt",
},
{
desc: "Get File Name for client certification",
datadir: ".",
fileType: clientKey,
expectedGeneratedPath: "client.key",
},
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
generatedPath := getFileName(tt.datadir, tt.fileType)
assert.Equal(t, tt.expectedGeneratedPath, generatedPath)
})
}
}
// Test getTLSSettings.
func TestGetTLSSettings(t *testing.T) {
cfg := setting.NewCfg()
cfg.DataPath = t.TempDir()
mockValidateCertFilePaths()
t.Cleanup(resetValidateCertFilePaths)
updatedTime := time.Now()
testCases := []struct {
desc string
expErr string
jsonData sqleng.JsonData
secureJSONData map[string]string
uid string
tlsSettings tlsSettings
updated time.Time
}{
{
desc: "Custom TLS authentication disabled",
updated: updatedTime,
jsonData: sqleng.JsonData{
Mode: "disable",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
ConfigurationMethod: "file-path",
},
tlsSettings: tlsSettings{Mode: "disable"},
},
{
desc: "Custom TLS authentication with file path",
updated: updatedTime.Add(time.Minute),
jsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
},
tlsSettings: tlsSettings{
Mode: "verify-full",
ConfigurationMethod: "file-path",
RootCertFile: "i/am/coding/ca.crt",
CertFile: "i/am/coding/client.crt",
CertKeyFile: "i/am/coding/client.key",
},
},
{
desc: "Custom TLS mode verify-full with certificate files content",
updated: updatedTime.Add(2 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "verify-full",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsCACert": "I am CA certification",
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
},
tlsSettings: tlsSettings{
Mode: "verify-full",
ConfigurationMethod: "file-content",
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
},
},
{
desc: "Custom TLS mode verify-ca with no client certificates with certificate files content",
updated: updatedTime.Add(3 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsCACert": "I am CA certification",
},
tlsSettings: tlsSettings{
Mode: "verify-ca",
ConfigurationMethod: "file-content",
RootCertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
CertFile: "",
CertKeyFile: "",
},
},
{
desc: "Custom TLS mode require with client certificates and no root certificate with certificate files content",
updated: updatedTime.Add(4 * time.Minute),
uid: "xxx",
jsonData: sqleng.JsonData{
Mode: "require",
ConfigurationMethod: "file-content",
},
secureJSONData: map[string]string{
"tlsClientCert": "I am client certification",
"tlsClientKey": "I am client key",
},
tlsSettings: tlsSettings{
Mode: "require",
ConfigurationMethod: "file-content",
RootCertFile: "",
CertFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
CertKeyFile: filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
},
},
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
var settings tlsSettings
var err error
mng := tlsManager{
logger: backend.NewLoggerWith("logger", "tsdb.postgres"),
dsCacheInstance: datasourceCacheManager{locker: newLocker()},
dataPath: cfg.DataPath,
}
ds := sqleng.DataSourceInfo{
JsonData: tt.jsonData,
DecryptedSecureJSONData: tt.secureJSONData,
UID: tt.uid,
Updated: tt.updated,
}
settings, err = mng.getTLSSettings(ds)
if tt.expErr == "" {
require.NoError(t, err, tt.desc)
assert.Equal(t, tt.tlsSettings, settings)
} else {
require.Error(t, err, tt.desc)
assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr))
}
})
}
}
func mockValidateCertFilePaths() {
validateCertFunc = func(rootCert, clientCert, clientKey string) error {
return nil
}
}
func resetValidateCertFilePaths() {
validateCertFunc = validateCertFilePaths
}
func mockWriteCertFile() {
writeCertFileCallNum = 0
writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error {
writeCertFileCallNum++
return nil
}
}
func resetWriteCertFile() {
writeCertFileCallNum = 0
writeCertFileFunc = writeCertFile
}