Postgres: Switch the datasource plugin from lib/pq to pgx (#81353)

* postgres: switch from lib/pq to pgx

* postgres: improved tls handling
This commit is contained in:
Gábor Farkas
2024-02-28 07:52:45 +01:00
committed by GitHub
parent e8df62941b
commit 8c18d06386
17 changed files with 967 additions and 908 deletions

View File

@ -4,38 +4,42 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
pgxstdlib "github.com/jackc/pgx/v5/stdlib"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/datasource"
"github.com/grafana/grafana-plugin-sdk-go/backend/instancemgmt"
"github.com/grafana/grafana-plugin-sdk-go/data"
"github.com/grafana/grafana-plugin-sdk-go/data/sqlutil"
"github.com/lib/pq"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/tls"
"github.com/grafana/grafana/pkg/tsdb/sqleng"
)
func ProvideService(cfg *setting.Cfg) *Service {
logger := backend.NewLoggerWith("logger", "tsdb.postgres")
s := &Service{
tlsManager: newTLSManager(logger, cfg.DataPath),
logger: logger,
logger: logger,
}
s.im = datasource.NewInstanceManager(s.newInstanceSettings())
return s
}
type Service struct {
tlsManager tlsSettingsProvider
im instancemgmt.InstanceManager
logger log.Logger
im instancemgmt.InstanceManager
logger log.Logger
}
func (s *Service) getDSInfo(ctx context.Context, pluginCtx backend.PluginContext) (*sqleng.DataSourceHandler, error) {
@ -55,13 +59,7 @@ func (s *Service) QueryData(ctx context.Context, req *backend.QueryDataRequest)
return dsInfo.QueryData(ctx, req)
}
func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, cnnstr string, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) {
connector, err := pq.NewConnector(cnnstr)
if err != nil {
logger.Error("postgres connector creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres connector creation failed")
}
func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit int64, dsInfo sqleng.DataSourceInfo, pgxConf *pgx.ConnConfig, logger log.Logger, settings backend.DataSourceInstanceSettings) (*sql.DB, *sqleng.DataSourceHandler, error) {
proxyClient, err := settings.ProxyClient(ctx)
if err != nil {
logger.Error("postgres proxy creation failed", "error", err)
@ -74,9 +72,8 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
logger.Error("postgres proxy creation failed", "error", err)
return nil, nil, fmt.Errorf("postgres proxy creation failed")
}
postgresDialer := newPostgresProxyDialer(dialer)
// update the postgres dialer with the proxy dialer
connector.Dialer(postgresDialer)
pgxConf.DialFunc = newPgxDialFunc(dialer)
}
config := sqleng.DataPluginConfiguration{
@ -87,7 +84,7 @@ func newPostgres(ctx context.Context, userFacingDefaultError string, rowLimit in
queryResultTransformer := postgresQueryResultTransformer{}
db := sql.OpenDB(connector)
db := pgxstdlib.OpenDB(*pgxConf)
db.SetMaxOpenConns(config.DSInfo.JsonData.MaxOpenConns)
db.SetMaxIdleConns(config.DSInfo.JsonData.MaxIdleConns)
@ -143,7 +140,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc {
DecryptedSecureJSONData: settings.DecryptedSecureJSONData,
}
cnnstr, err := s.generateConnectionString(dsInfo)
pgxConf, err := generateConnectionConfig(dsInfo)
if err != nil {
return nil, err
}
@ -153,7 +150,7 @@ func (s *Service) newInstanceSettings() datasource.InstanceFactoryFunc {
return nil, err
}
_, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, cnnstr, logger, settings)
_, handler, err := newPostgres(ctx, userFacingDefaultError, sqlCfg.RowLimit, dsInfo, pgxConf, logger, settings)
if err != nil {
logger.Error("Failed connecting to Postgres", "err", err)
@ -170,13 +167,11 @@ func escape(input string) string {
return strings.ReplaceAll(strings.ReplaceAll(input, `\`, `\\`), "'", `\'`)
}
func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string, error) {
logger := s.logger
func generateConnectionConfig(dsInfo sqleng.DataSourceInfo) (*pgx.ConnConfig, error) {
var host string
var port int
if strings.HasPrefix(dsInfo.URL, "/") {
host = dsInfo.URL
logger.Debug("Generating connection string with Unix socket specifier", "socket", host)
} else {
index := strings.LastIndex(dsInfo.URL, ":")
v6Index := strings.Index(dsInfo.URL, "]")
@ -187,12 +182,8 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string
var err error
port, err = strconv.Atoi(sp[1])
if err != nil {
return "", fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err)
return nil, fmt.Errorf("invalid port in host specifier %q: %w", sp[1], err)
}
logger.Debug("Generating connection string with network host/port pair", "host", host, "port", port)
} else {
logger.Debug("Generating connection string with network host", "host", host)
}
} else {
if index == v6Index+1 {
@ -200,46 +191,39 @@ func (s *Service) generateConnectionString(dsInfo sqleng.DataSourceInfo) (string
var err error
port, err = strconv.Atoi(dsInfo.URL[index+1:])
if err != nil {
return "", fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err)
return nil, fmt.Errorf("invalid port in host specifier %q: %w", dsInfo.URL[index+1:], err)
}
logger.Debug("Generating ipv6 connection string with network host/port pair", "host", host, "port", port)
} else {
host = dsInfo.URL[1 : len(dsInfo.URL)-1]
logger.Debug("Generating ipv6 connection string with network host", "host", host)
}
}
}
connStr := fmt.Sprintf("user='%s' password='%s' host='%s' dbname='%s'",
// NOTE: we always set sslmode=disable in the connection string, we handle TLS manually later
connStr := fmt.Sprintf("sslmode=disable user='%s' password='%s' host='%s' dbname='%s'",
escape(dsInfo.User), escape(dsInfo.DecryptedSecureJSONData["password"]), escape(host), escape(dsInfo.Database))
if port > 0 {
connStr += fmt.Sprintf(" port=%d", port)
}
tlsSettings, err := s.tlsManager.getTLSSettings(dsInfo)
conf, err := pgx.ParseConfig(connStr)
if err != nil {
return "", err
return nil, err
}
connStr += fmt.Sprintf(" sslmode='%s'", escape(tlsSettings.Mode))
// Attach root certificate if provided
if tlsSettings.RootCertFile != "" {
logger.Debug("Setting server root certificate", "tlsRootCert", tlsSettings.RootCertFile)
connStr += fmt.Sprintf(" sslrootcert='%s'", escape(tlsSettings.RootCertFile))
tlsConf, err := tls.GetTLSConfig(dsInfo, os.ReadFile, host)
if err != nil {
return nil, err
}
// Attach client certificate and key if both are provided
if tlsSettings.CertFile != "" && tlsSettings.CertKeyFile != "" {
logger.Debug("Setting TLS/SSL client auth", "tlsCert", tlsSettings.CertFile, "tlsKey", tlsSettings.CertKeyFile)
connStr += fmt.Sprintf(" sslcert='%s' sslkey='%s'", escape(tlsSettings.CertFile), escape(tlsSettings.CertKeyFile))
} else if tlsSettings.CertFile != "" || tlsSettings.CertKeyFile != "" {
return "", fmt.Errorf("TLS/SSL client certificate and key must both be specified")
// before we set the TLS config, we need to make sure the `.Fallbacks` attribute is unset, see:
// https://github.com/jackc/pgx/discussions/1903#discussioncomment-8430146
if len(conf.Fallbacks) > 0 {
return nil, errors.New("tls: fallbacks configured, unable to set up TLS config")
}
conf.TLSConfig = tlsConf
logger.Debug("Generated Postgres connection string successfully")
return connStr, nil
return conf, nil
}
type postgresQueryResultTransformer struct{}
@ -267,6 +251,44 @@ func (s *Service) CheckHealth(ctx context.Context, req *backend.CheckHealthReque
func (t *postgresQueryResultTransformer) GetConverterList() []sqlutil.StringConverter {
return []sqlutil.StringConverter{
{
Name: "handle TIME WITH TIME ZONE",
InputScanKind: reflect.Interface,
InputTypeName: strconv.Itoa(pgtype.TimetzOID),
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05-07", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{
Name: "handle TIME",
InputScanKind: reflect.Interface,
InputTypeName: "TIME",
ConversionFunc: func(in *string) (*string, error) { return in, nil },
Replacer: &sqlutil.StringFieldReplacer{
OutputFieldType: data.FieldTypeNullableTime,
ReplaceFunc: func(in *string) (any, error) {
if in == nil {
return nil, nil
}
v, err := time.Parse("15:04:05", *in)
if err != nil {
return nil, err
}
return &v, nil
},
},
},
{
Name: "handle FLOAT4",
InputScanKind: reflect.Interface,