pemfile: Implement certprovider config parsing API (#4023)

This commit is contained in:
Easwar Swaminathan
2020-11-17 15:36:28 -08:00
committed by GitHub
parent 3d14af97a5
commit fa59d20167
14 changed files with 311 additions and 888 deletions

View File

@ -69,7 +69,6 @@ var (
readAudienceFunc = readAudience
)
// Implements the certprovider.StableConfig interface.
type pluginConfig struct {
serverURI string
stsOpts sts.Options

View File

@ -0,0 +1,96 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pemfile
import (
"encoding/json"
"fmt"
"time"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
)
const (
pluginName = "file_watcher"
defaultRefreshInterval = 10 * time.Minute
)
func init() {
certprovider.Register(&pluginBuilder{})
}
type pluginBuilder struct{}
func (p *pluginBuilder) ParseConfig(c interface{}) (*certprovider.BuildableConfig, error) {
data, ok := c.(json.RawMessage)
if !ok {
return nil, fmt.Errorf("meshca: unsupported config type: %T", c)
}
opts, err := pluginConfigFromJSON(data)
if err != nil {
return nil, err
}
return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider {
return newProvider(opts)
}), nil
}
func (p *pluginBuilder) Name() string {
return pluginName
}
func pluginConfigFromJSON(jd json.RawMessage) (Options, error) {
// The only difference between this anonymous struct and the Options struct
// is that the refresh_interval is represented here as a duration proto,
// while in the latter a time.Duration is used.
cfg := &struct {
CertificateFile string `json:"certificate_file,omitempty"`
PrivateKeyFile string `json:"private_key_file,omitempty"`
CACertificateFile string `json:"ca_certificate_file,omitempty"`
RefreshInterval json.RawMessage `json:"refresh_interval,omitempty"`
}{}
if err := json.Unmarshal(jd, cfg); err != nil {
return Options{}, fmt.Errorf("pemfile: json.Unmarshal(%s) failed: %v", string(jd), err)
}
opts := Options{
CertFile: cfg.CertificateFile,
KeyFile: cfg.PrivateKeyFile,
RootFile: cfg.CACertificateFile,
// Refresh interval is the only field in the configuration for which we
// support a default value. We cannot possibly have valid defaults for
// file paths to watch. Also, it is valid to specify an empty path for
// some of those fields if the user does not want to watch them.
RefreshDuration: defaultRefreshInterval,
}
if cfg.RefreshInterval != nil {
dur := &durationpb.Duration{}
if err := protojson.Unmarshal(cfg.RefreshInterval, dur); err != nil {
return Options{}, fmt.Errorf("pemfile: protojson.Unmarshal(%+v) failed: %v", cfg.RefreshInterval, err)
}
opts.RefreshDuration = dur.AsDuration()
}
if err := opts.validate(); err != nil {
return Options{}, err
}
return opts, nil
}

View File

@ -0,0 +1,130 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package pemfile
import (
"encoding/json"
"testing"
)
func TestParseConfig(t *testing.T) {
tests := []struct {
desc string
input interface{}
wantOutput string
wantErr bool
}{
{
desc: "non JSON input",
input: new(int),
wantErr: true,
},
{
desc: "invalid JSON",
input: json.RawMessage(`bad bad json`),
wantErr: true,
},
{
desc: "JSON input does not match expected",
input: json.RawMessage(`["foo": "bar"]`),
wantErr: true,
},
{
desc: "no credential files",
input: json.RawMessage(`{}`),
wantErr: true,
},
{
desc: "only cert file",
input: json.RawMessage(`
{
"certificate_file": "/a/b/cert.pem"
}`),
wantErr: true,
},
{
desc: "only key file",
input: json.RawMessage(`
{
"private_key_file": "/a/b/key.pem"
}`),
wantErr: true,
},
{
desc: "cert and key in different directories",
input: json.RawMessage(`
{
"certificate_file": "/b/a/cert.pem",
"private_key_file": "/a/b/key.pem"
}`),
wantErr: true,
},
{
desc: "bad refresh duration",
input: json.RawMessage(`
{
"certificate_file": "/a/b/cert.pem",
"private_key_file": "/a/b/key.pem",
"ca_certificate_file": "/a/b/ca.pem",
"refresh_interval": "duration"
}`),
wantErr: true,
},
{
desc: "good config with default refresh interval",
input: json.RawMessage(`
{
"certificate_file": "/a/b/cert.pem",
"private_key_file": "/a/b/key.pem",
"ca_certificate_file": "/a/b/ca.pem"
}`),
wantOutput: "file_watcher:/a/b/cert.pem:/a/b/key.pem:/a/b/ca.pem:10m0s",
},
{
desc: "good config",
input: json.RawMessage(`
{
"certificate_file": "/a/b/cert.pem",
"private_key_file": "/a/b/key.pem",
"ca_certificate_file": "/a/b/ca.pem",
"refresh_interval": "200s"
}`),
wantOutput: "file_watcher:/a/b/cert.pem:/a/b/key.pem:/a/b/ca.pem:3m20s",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
builder := &pluginBuilder{}
bc, err := builder.ParseConfig(test.input)
if (err != nil) != test.wantErr {
t.Fatalf("ParseConfig(%+v) failed: %v", test.input, err)
}
if test.wantErr {
return
}
gotConfig := bc.String()
if gotConfig != test.wantOutput {
t.Fatalf("ParseConfig(%v) = %s, want %s", test.input, gotConfig, test.wantOutput)
}
})
}
}

View File

@ -30,18 +30,17 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"path/filepath"
"time"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/grpclog"
)
const (
defaultCertRefreshDuration = 1 * time.Hour
defaultRootRefreshDuration = 2 * time.Hour
)
const defaultCertRefreshDuration = 1 * time.Hour
var (
// For overriding from unit tests.
@ -62,30 +61,48 @@ type Options struct {
// RootFile is the file that holds trusted root certificate(s).
// Optional.
RootFile string
// CertRefreshDuration is the amount of time the plugin waits before
// checking for updates in the specified identity certificate and key file.
// RefreshDuration is the amount of time the plugin waits before checking
// for updates in the specified files.
// Optional. If not set, a default value (1 hour) will be used.
CertRefreshDuration time.Duration
// RootRefreshDuration is the amount of time the plugin waits before
// checking for updates in the specified root file.
// Optional. If not set, a default value (2 hour) will be used.
RootRefreshDuration time.Duration
RefreshDuration time.Duration
}
func (o Options) canonical() []byte {
return []byte(fmt.Sprintf("%s:%s:%s:%s", o.CertFile, o.KeyFile, o.RootFile, o.RefreshDuration))
}
func (o Options) validate() error {
if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" {
return fmt.Errorf("pemfile: at least one credential file needs to be specified")
}
if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
return fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified")
}
// C-core has a limitation that they cannot verify that a certificate file
// matches a key file. So, the only way to get around this is to make sure
// that both files are in the same directory and that they do an atomic
// read. Even though Java/Go do not have this limitation, we want the
// overall plugin behavior to be consistent across languages.
if certDir, keyDir := filepath.Dir(o.CertFile), filepath.Dir(o.KeyFile); certDir != keyDir {
return errors.New("pemfile: certificate and key file must be in the same directory")
}
return nil
}
// NewProvider returns a new certificate provider plugin that is configured to
// watch the PEM files specified in the passed in options.
func NewProvider(o Options) (certprovider.Provider, error) {
if o.CertFile == "" && o.KeyFile == "" && o.RootFile == "" {
return nil, fmt.Errorf("pemfile: at least one credential file needs to be specified")
if err := o.validate(); err != nil {
return nil, err
}
if keySpecified, certSpecified := o.KeyFile != "", o.CertFile != ""; keySpecified != certSpecified {
return nil, fmt.Errorf("pemfile: private key file and identity cert file should be both specified or not specified")
}
if o.CertRefreshDuration == 0 {
o.CertRefreshDuration = defaultCertRefreshDuration
}
if o.RootRefreshDuration == 0 {
o.RootRefreshDuration = defaultRootRefreshDuration
return newProvider(o), nil
}
// newProvider is used to create a new certificate provider plugin after
// validating the options, and hence does not return an error.
func newProvider(o Options) certprovider.Provider {
if o.RefreshDuration == 0 {
o.RefreshDuration = defaultCertRefreshDuration
}
provider := &watcher{opts: o}
@ -99,8 +116,7 @@ func NewProvider(o Options) (certprovider.Provider, error) {
ctx, cancel := context.WithCancel(context.Background())
provider.cancel = cancel
go provider.run(ctx)
return provider, nil
return provider
}
// watcher is a certificate provider plugin that implements the
@ -203,13 +219,13 @@ func (w *watcher) run(ctx context.Context) {
w.updateIdentityDistributor()
w.updateRootDistributor()
identityTicker := time.NewTicker(w.opts.CertRefreshDuration)
rootTicker := time.NewTicker(w.opts.RootRefreshDuration)
ticker := time.NewTicker(w.opts.RefreshDuration)
for {
w.updateIdentityDistributor()
w.updateRootDistributor()
select {
case <-ctx.Done():
identityTicker.Stop()
rootTicker.Stop()
ticker.Stop()
if w.identityDistributor != nil {
w.identityDistributor.Stop()
}
@ -217,10 +233,7 @@ func (w *watcher) run(ctx context.Context) {
w.rootDistributor.Stop()
}
return
case <-identityTicker.C:
w.updateIdentityDistributor()
case <-rootTicker.C:
w.updateRootDistributor()
case <-ticker.C:
}
}
}

View File

@ -187,11 +187,10 @@ func initializeProvider(t *testing.T, testName string) (string, certprovider.Pro
// Create a new provider to watch the files in tmpdir.
dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
opts := Options{
CertFile: path.Join(dir, certFile),
KeyFile: path.Join(dir, keyFile),
RootFile: path.Join(dir, rootFile),
CertRefreshDuration: defaultTestRefreshDuration,
RootRefreshDuration: defaultTestRefreshDuration,
CertFile: path.Join(dir, certFile),
KeyFile: path.Join(dir, keyFile),
RootFile: path.Join(dir, rootFile),
RefreshDuration: defaultTestRefreshDuration,
}
prov, err := NewProvider(opts)
if err != nil {
@ -314,11 +313,10 @@ func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
// Create a provider which watches the files pointed to by the symlink.
opts := Options{
CertFile: path.Join(symLinkName, certFile),
KeyFile: path.Join(symLinkName, keyFile),
RootFile: path.Join(symLinkName, rootFile),
CertRefreshDuration: defaultTestRefreshDuration,
RootRefreshDuration: defaultTestRefreshDuration,
CertFile: path.Join(symLinkName, certFile),
KeyFile: path.Join(symLinkName, keyFile),
RootFile: path.Join(symLinkName, rootFile),
RefreshDuration: defaultTestRefreshDuration,
}
prov, err := NewProvider(opts)
if err != nil {