diff --git a/credentials/tls/certprovider/meshca/builder.go b/credentials/tls/certprovider/meshca/builder.go index 3544a164..4b8af7c9 100644 --- a/credentials/tls/certprovider/meshca/builder.go +++ b/credentials/tls/certprovider/meshca/builder.go @@ -65,7 +65,7 @@ type refCountedCC struct { } // pluginBuilder is an implementation of the certprovider.Builder interface, -// which build certificate provider instances which get certificates signed from +// which builds certificate provider instances to get certificates signed from // the MeshCA. type pluginBuilder struct { // A collection of ClientConns to the MeshCA server along with a reference @@ -75,22 +75,30 @@ type pluginBuilder struct { clients map[ccMapKey]*refCountedCC } -// Build returns a MeshCA certificate provider for the passed in configuration -// and options. +// ParseConfig parses the configuration to be passed to the MeshCA plugin +// implementation. Expects the config to be a json.RawMessage which contains a +// serialized JSON representation of the meshca_experimental.GoogleMeshCaConfig +// proto message. // -// This builder takes care of sharing the ClientConn to the MeshCA server among +// Takes care of sharing the ClientConn to the MeshCA server among // different plugin instantiations. -func (b *pluginBuilder) Build(c certprovider.StableConfig, opts certprovider.Options) certprovider.Provider { - cfg, ok := c.(*pluginConfig) +func (b *pluginBuilder) ParseConfig(c interface{}) (*certprovider.BuildableConfig, error) { + data, ok := c.(json.RawMessage) if !ok { - // This is not expected when passing config returned by ParseConfig(). - // This could indicate a bug in the certprovider.Store implementation or - // in cases where the user is directly using these APIs, could be a user - // error. - logger.Errorf("unsupported config type: %T", c) - return nil + return nil, fmt.Errorf("meshca: unsupported config type: %T", c) } + cfg, err := pluginConfigFromJSON(data) + if err != nil { + return nil, err + } + return certprovider.NewBuildableConfig(pluginName, cfg.canonical(), func(opts certprovider.BuildOptions) certprovider.Provider { + return b.buildFromConfig(cfg, opts) + }), nil +} +// buildFromConfig builds a certificate provider instance for the given config +// and options. Provider instances are shared wherever possible. +func (b *pluginBuilder) buildFromConfig(cfg *pluginConfig, opts certprovider.BuildOptions) certprovider.Provider { b.mu.Lock() defer b.mu.Unlock() @@ -151,18 +159,6 @@ func (b *pluginBuilder) Build(c certprovider.StableConfig, opts certprovider.Opt return p } -// ParseConfig parses the configuration to be passed to the MeshCA plugin -// implementation. Expects the config to be a json.RawMessage which contains a -// serialized JSON representation of the meshca_experimental.GoogleMeshCaConfig -// proto message. -func (b *pluginBuilder) ParseConfig(c interface{}) (certprovider.StableConfig, error) { - data, ok := c.(json.RawMessage) - if !ok { - return nil, fmt.Errorf("meshca: unsupported config type: %T", c) - } - return pluginConfigFromJSON(data) -} - // Name returns the MeshCA plugin name. func (b *pluginBuilder) Name() string { return pluginName diff --git a/credentials/tls/certprovider/meshca/builder_test.go b/credentials/tls/certprovider/meshca/builder_test.go index b395f4f4..d21307f4 100644 --- a/credentials/tls/certprovider/meshca/builder_test.go +++ b/credentials/tls/certprovider/meshca/builder_test.go @@ -71,7 +71,7 @@ func (s) TestBuildSameConfig(t *testing.T) { // invocations of Build(). inputConfig := makeJSONConfig(t, goodConfigFullySpecified) builder := newPluginBuilder() - stableConfig, err := builder.ParseConfig(inputConfig) + buildableConfig, err := builder.ParseConfig(inputConfig) if err != nil { t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) } @@ -80,9 +80,9 @@ func (s) TestBuildSameConfig(t *testing.T) { // end up sharing the same ClientConn. providers := []certprovider.Provider{} for i := 0; i < cnt; i++ { - p := builder.Build(stableConfig, certprovider.Options{}) - if p == nil { - t.Fatalf("builder.Build(%s) failed: %v", string(stableConfig.Canonical()), err) + p, err := buildableConfig.Build(certprovider.BuildOptions{}) + if err != nil { + t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) } providers = append(providers, p) } @@ -146,14 +146,14 @@ func (s) TestBuildDifferentConfig(t *testing.T) { cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig) cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = fmt.Sprintf("test-mesh-ca:%d", i) inputConfig := makeJSONConfig(t, cfg) - stableConfig, err := builder.ParseConfig(inputConfig) + buildableConfig, err := builder.ParseConfig(inputConfig) if err != nil { t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) } - p := builder.Build(stableConfig, certprovider.Options{}) - if p == nil { - t.Fatalf("builder.Build(%s) failed: %v", string(stableConfig.Canonical()), err) + p, err := buildableConfig.Build(certprovider.BuildOptions{}) + if err != nil { + t.Fatalf("Build(%+v) failed: %v", buildableConfig, err) } providers = append(providers, p) } diff --git a/credentials/tls/certprovider/meshca/config.go b/credentials/tls/certprovider/meshca/config.go index 38186fa8..2800becc 100644 --- a/credentials/tls/certprovider/meshca/config.go +++ b/credentials/tls/certprovider/meshca/config.go @@ -161,7 +161,7 @@ func pluginConfigFromJSON(data json.RawMessage) (*pluginConfig, error) { return pc, nil } -func (pc *pluginConfig) Canonical() []byte { +func (pc *pluginConfig) canonical() []byte { return []byte(fmt.Sprintf("%s:%s:%s:%s:%s:%s:%d:%s", pc.serverURI, pc.stsOpts, pc.callTimeout, pc.certLifetime, pc.certGraceTime, pc.keyType, pc.keySize, pc.location)) } diff --git a/credentials/tls/certprovider/meshca/config_test.go b/credentials/tls/certprovider/meshca/config_test.go index 34dd9f75..f5e9b417 100644 --- a/credentials/tls/certprovider/meshca/config_test.go +++ b/credentials/tls/certprovider/meshca/config_test.go @@ -160,13 +160,13 @@ func (s) TestParseConfigSuccessFullySpecified(t *testing.T) { inputConfig := makeJSONConfig(t, goodConfigFullySpecified) wantConfig := "test-meshca:http://test-sts:test-resource:test-audience:test-scope:test-requested-token-type:test-subject-token-path:test-subject-token-type:test-actor-token-path:test-actor-token-type:10s:24h0m0s:12h0m0s:RSA:2048:us-west1-b" - builder := newPluginBuilder() - gotConfig, err := builder.ParseConfig(inputConfig) + cfg, err := pluginConfigFromJSON(inputConfig) if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) + t.Fatalf("pluginConfigFromJSON(%q) failed: %v", inputConfig, err) } - if diff := cmp.Diff(wantConfig, string(gotConfig.Canonical())); diff != "" { - t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", inputConfig, diff) + gotConfig := cfg.canonical() + if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { + t.Errorf("pluginConfigFromJSON(%q) returned config does not match expected (-want +got):\n%s", inputConfig, diff) } } @@ -248,13 +248,12 @@ func (s) TestParseConfigSuccessWithDefaults(t *testing.T) { errCh <- nil }() - builder := newPluginBuilder() - gotConfig, err := builder.ParseConfig(inputConfig) + cfg, err := pluginConfigFromJSON(inputConfig) if err != nil { - t.Fatalf("builder.ParseConfig(%q) failed: %v", inputConfig, err) - + t.Fatalf("pluginConfigFromJSON(%q) failed: %v", inputConfig, err) } - if diff := cmp.Diff(wantConfig, string(gotConfig.Canonical())); diff != "" { + gotConfig := cfg.canonical() + if diff := cmp.Diff(wantConfig, string(gotConfig)); diff != "" { t.Errorf("builder.ParseConfig(%q) returned config does not match expected (-want +got):\n%s", inputConfig, diff) } @@ -268,14 +267,9 @@ func (s) TestParseConfigSuccessWithDefaults(t *testing.T) { func (s) TestParseConfigFailureCases(t *testing.T) { tests := []struct { desc string - inputConfig interface{} + inputConfig json.RawMessage wantErr string }{ - { - desc: "bad config type", - inputConfig: struct{ foo string }{foo: "bar"}, - wantErr: "unsupported config type", - }, { desc: "invalid JSON", inputConfig: json.RawMessage(`bad bad json`), @@ -396,10 +390,9 @@ func (s) TestParseConfigFailureCases(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - builder := newPluginBuilder() - sc, err := builder.ParseConfig(test.inputConfig) + cfg, err := pluginConfigFromJSON(test.inputConfig) if err == nil { - t.Fatalf("builder.ParseConfig(%q) = %v, expected to return error (%v)", test.inputConfig, string(sc.Canonical()), test.wantErr) + t.Fatalf("pluginConfigFromJSON(%q) = %v, expected to return error (%v)", test.inputConfig, string(cfg.canonical()), test.wantErr) } if !strings.Contains(err.Error(), test.wantErr) { diff --git a/credentials/tls/certprovider/meshca/plugin.go b/credentials/tls/certprovider/meshca/plugin.go index 5ff3e9cf..ab1958ac 100644 --- a/credentials/tls/certprovider/meshca/plugin.go +++ b/credentials/tls/certprovider/meshca/plugin.go @@ -65,12 +65,12 @@ type distributor interface { type providerPlugin struct { distributor // Holds the key material. cancel context.CancelFunc - cc *grpc.ClientConn // Connection to MeshCA server. - cfg *pluginConfig // Plugin configuration. - opts certprovider.Options // Key material options. - logger *grpclog.PrefixLogger // Plugin instance specific prefix. - backoff func(int) time.Duration // Exponential backoff. - doneFunc func() // Notify the builder when done. + cc *grpc.ClientConn // Connection to MeshCA server. + cfg *pluginConfig // Plugin configuration. + opts certprovider.BuildOptions // Key material options. + logger *grpclog.PrefixLogger // Plugin instance specific prefix. + backoff func(int) time.Duration // Exponential backoff. + doneFunc func() // Notify the builder when done. } // providerParams wraps params passed to the provider plugin at creation time. @@ -78,7 +78,7 @@ type providerParams struct { // This ClientConn to the MeshCA server is owned by the builder. cc *grpc.ClientConn cfg *pluginConfig - opts certprovider.Options + opts certprovider.BuildOptions backoff func(int) time.Duration doneFunc func() } diff --git a/credentials/tls/certprovider/meshca/plugin_test.go b/credentials/tls/certprovider/meshca/plugin_test.go index 48740c14..5b3f068d 100644 --- a/credentials/tls/certprovider/meshca/plugin_test.go +++ b/credentials/tls/certprovider/meshca/plugin_test.go @@ -297,14 +297,15 @@ func (s) TestCreateCertificate(t *testing.T) { e, addr, cancel := setup(t, opts{}) defer cancel() - // Set the MeshCA targetURI in the plugin configuration to point to our fake - // MeshCA. + // Set the MeshCA targetURI to point to our fake MeshCA. cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig) cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr inputConfig := makeJSONConfig(t, cfg) - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{}) + + // Lookup MeshCA plugin builder, parse config and start the plugin. + prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) if err != nil { - t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err) + t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) } defer prov.Close() @@ -339,14 +340,15 @@ func (s) TestCreateCertificateWithBackoff(t *testing.T) { e, addr, cancel := setup(t, opts{withbackoff: true}) defer cancel() - // Set the MeshCA targetURI in the plugin configuration to point to our fake - // MeshCA. + // Set the MeshCA targetURI to point to our fake MeshCA. cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig) cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr inputConfig := makeJSONConfig(t, cfg) - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{}) + + // Lookup MeshCA plugin builder, parse config and start the plugin. + prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) if err != nil { - t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err) + t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) } defer prov.Close() @@ -394,14 +396,15 @@ func (s) TestCreateCertificateWithRefresh(t *testing.T) { e, addr, cancel := setup(t, opts{withShortLife: true}) defer cancel() - // Set the MeshCA targetURI in the plugin configuration to point to our fake - // MeshCA. + // Set the MeshCA targetURI to point to our fake MeshCA. cfg := proto.Clone(goodConfigFullySpecified).(*configpb.GoogleMeshCaConfig) cfg.Server.GrpcServices[0].GetGoogleGrpc().TargetUri = addr inputConfig := makeJSONConfig(t, cfg) - prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.Options{}) + + // Lookup MeshCA plugin builder, parse config and start the plugin. + prov, err := certprovider.GetProvider(pluginName, inputConfig, certprovider.BuildOptions{}) if err != nil { - t.Fatalf("certprovider.GetProvider(%s, %s) failed: %v", pluginName, cfg, err) + t.Fatalf("GetProvider(%s, %s) failed: %v", pluginName, string(inputConfig), err) } defer prov.Close() diff --git a/credentials/tls/certprovider/provider.go b/credentials/tls/certprovider/provider.go index 8d8ae80a..275c176a 100644 --- a/credentials/tls/certprovider/provider.go +++ b/credentials/tls/certprovider/provider.go @@ -64,28 +64,14 @@ func getBuilder(name string) Builder { // Builder creates a Provider. type Builder interface { - // Build creates a new Provider and initializes it with the given config and - // options combination. - Build(StableConfig, Options) Provider - - // ParseConfig converts config input in a format specific to individual - // implementations and returns an implementation of the StableConfig - // interface. - // Equivalent configurations must return StableConfig types whose - // Canonical() method returns the same output. - ParseConfig(interface{}) (StableConfig, error) + // ParseConfig parses the given config, which is in a format specific to individual + // implementations, and returns a BuildableConfig on success. + ParseConfig(interface{}) (*BuildableConfig, error) // Name returns the name of providers built by this builder. Name() string } -// StableConfig wraps the method to return a stable Provider configuration. -type StableConfig interface { - // Canonical returns Provider config as an arbitrary byte slice. - // Equivalent configurations must return the same output. - Canonical() []byte -} - // Provider makes it possible to keep channel credential implementations up to // date with secrets that they rely on to secure communications on the // underlying channel. @@ -110,8 +96,8 @@ type KeyMaterial struct { Roots *x509.CertPool } -// Options contains configuration knobs passed to a Provider at creation time. -type Options struct { +// BuildOptions contains parameters passed to a Provider at build time. +type BuildOptions struct { // CertName holds the certificate name, whose key material is of interest to // the caller. CertName string diff --git a/credentials/tls/certprovider/store.go b/credentials/tls/certprovider/store.go index 7f41fddb..90f98b3c 100644 --- a/credentials/tls/certprovider/store.go +++ b/credentials/tls/certprovider/store.go @@ -39,7 +39,7 @@ type storeKey struct { // configuration of the certificate provider in string form. config string // opts contains the certificate name and other keyMaterial options. - opts Options + opts BuildOptions } // wrappedProvider wraps a provider instance with a reference count. @@ -59,66 +59,6 @@ type store struct { providers map[storeKey]*wrappedProvider } -// GetProvider returns a provider instance from which keyMaterial can be read. -// -// name is the registered name of the provider, config is the provider-specific -// configuration, opts contains extra information that controls the keyMaterial -// returned by the provider. -// -// Implementations of the Builder interface should clearly document the type of -// configuration accepted by them. -// -// If a provider exists for passed arguments, its reference count is incremented -// before returning. If no provider exists for the passed arguments, a new one -// is created using the registered builder. If no registered builder is found, -// or the provider configuration is rejected by it, a non-nil error is returned. -func GetProvider(name string, config interface{}, opts Options) (Provider, error) { - provStore.mu.Lock() - defer provStore.mu.Unlock() - - builder := getBuilder(name) - if builder == nil { - return nil, fmt.Errorf("no registered builder for provider name: %s", name) - } - - var ( - stableConfig StableConfig - err error - ) - if c, ok := config.(StableConfig); ok { - // The config passed to the store has already been parsed. - stableConfig = c - } else { - stableConfig, err = builder.ParseConfig(config) - if err != nil { - return nil, err - } - } - - sk := storeKey{ - name: name, - config: string(stableConfig.Canonical()), - opts: opts, - } - if wp, ok := provStore.providers[sk]; ok { - wp.refCount++ - return wp, nil - } - - provider := builder.Build(stableConfig, opts) - if provider == nil { - return nil, fmt.Errorf("certprovider.Build(%v) failed", sk) - } - wp := &wrappedProvider{ - Provider: provider, - refCount: 1, - storeKey: sk, - store: provStore, - } - provStore.providers[sk] = wp - return wp, nil -} - // Close overrides the Close method of the embedded provider. It releases the // reference held by the caller on the underlying provider and if the // provider's reference count reaches zero, it is removed from the store, and @@ -134,3 +74,83 @@ func (wp *wrappedProvider) Close() { delete(ps.providers, wp.storeKey) } } + +// BuildableConfig wraps parsed provider configuration and functionality to +// instantiate provider instances. +type BuildableConfig struct { + name string + config []byte + starter func(BuildOptions) Provider + pStore *store +} + +// NewBuildableConfig creates a new BuildableConfig with the given arguments. +// Provider implementations are expected to invoke this function after parsing +// the given configuration as part of their ParseConfig() method. +// Equivalent configurations are expected to invoke this function with the same +// config argument. +func NewBuildableConfig(name string, config []byte, starter func(BuildOptions) Provider) *BuildableConfig { + return &BuildableConfig{ + name: name, + config: config, + starter: starter, + pStore: provStore, + } +} + +// Build kicks off a provider instance with the wrapped configuration. Multiple +// invocations of this method with the same opts will result in provider +// instances being reused. +func (bc *BuildableConfig) Build(opts BuildOptions) (Provider, error) { + provStore.mu.Lock() + defer provStore.mu.Unlock() + + sk := storeKey{ + name: bc.name, + config: string(bc.config), + opts: opts, + } + if wp, ok := provStore.providers[sk]; ok { + wp.refCount++ + return wp, nil + } + + provider := bc.starter(opts) + if provider == nil { + return nil, fmt.Errorf("provider(%q, %q).Build(%v) failed", sk.name, sk.config, opts) + } + wp := &wrappedProvider{ + Provider: provider, + refCount: 1, + storeKey: sk, + store: provStore, + } + provStore.providers[sk] = wp + return wp, nil +} + +// String returns the provider name and config as a colon separated string. +func (bc *BuildableConfig) String() string { + return fmt.Sprintf("%s:%s", bc.name, string(bc.config)) +} + +// ParseConfig is a convenience function to create a BuildableConfig given a +// provider name and configuration. Returns an error if there is no registered +// builder for the given name or if the config parsing fails. +func ParseConfig(name string, config interface{}) (*BuildableConfig, error) { + parser := getBuilder(name) + if parser == nil { + return nil, fmt.Errorf("no certificate provider builder found for %q", name) + } + return parser.ParseConfig(config) +} + +// GetProvider is a convenience function to create a provider given the name, +// config and build options. +func GetProvider(name string, config interface{}, opts BuildOptions) (Provider, error) { + bc, err := ParseConfig(name, config) + if err != nil { + return nil, err + } + return bc.Build(opts) +} diff --git a/credentials/tls/certprovider/store_test.go b/credentials/tls/certprovider/store_test.go index 618c2d7d..00d33a2b 100644 --- a/credentials/tls/certprovider/store_test.go +++ b/credentials/tls/certprovider/store_test.go @@ -37,10 +37,11 @@ import ( ) const ( - fakeProvider1Name = "fake-certificate-provider-1" - fakeProvider2Name = "fake-certificate-provider-2" - fakeConfig = "my fake config" - defaultTestTimeout = 1 * time.Second + fakeProvider1Name = "fake-certificate-provider-1" + fakeProvider2Name = "fake-certificate-provider-2" + fakeConfig = "my fake config" + defaultTestTimeout = 5 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond ) var fpb1, fpb2 *fakeProviderBuilder @@ -73,36 +74,36 @@ type fakeProviderBuilder struct { providerChan *testutils.Channel } -func (b *fakeProviderBuilder) Build(StableConfig, Options) Provider { - p := &fakeProvider{Distributor: NewDistributor()} - b.providerChan.Send(p) - return p -} - -func (b *fakeProviderBuilder) ParseConfig(config interface{}) (StableConfig, error) { +func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*BuildableConfig, error) { s, ok := config.(string) if !ok { return nil, fmt.Errorf("providerBuilder %s received config of type %T, want string", b.name, config) } - return &fakeStableConfig{config: s}, nil + return NewBuildableConfig(b.name, []byte(s), func(BuildOptions) Provider { + fp := &fakeProvider{ + Distributor: NewDistributor(), + config: s, + } + b.providerChan.Send(fp) + return fp + }), nil } func (b *fakeProviderBuilder) Name() string { return b.name } -type fakeStableConfig struct { - config string -} - -func (c *fakeStableConfig) Canonical() []byte { - return []byte(c.config) -} - // fakeProvider is an implementation of the Provider interface which provides a // method for tests to invoke to push new key materials. type fakeProvider struct { *Distributor + config string +} + +func (p *fakeProvider) Start(BuildOptions) Provider { + // This is practically a no-op since this provider doesn't do any work which + // needs to be started at this point. + return p } // newKeyMaterial allows tests to push new key material to the fake provider @@ -166,15 +167,19 @@ func compareKeyMaterial(got, want *KeyMaterial) error { return nil } +func createProvider(t *testing.T, name, config string, opts BuildOptions) Provider { + t.Helper() + prov, err := GetProvider(name, config, opts) + if err != nil { + t.Fatalf("GetProvider(%s, %s, %v) failed: %v", name, config, opts, err) + } + return prov +} + // TestStoreSingleProvider creates a single provider through the store and calls // methods on them. func (s) TestStoreSingleProvider(t *testing.T) { - // Create a Provider through the store. - kmOpts := Options{CertName: "default"} - prov, err := GetProvider(fakeProvider1Name, fakeConfig, kmOpts) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, kmOpts, err) - } + prov := createProvider(t, fakeProvider1Name, fakeConfig, BuildOptions{CertName: "default"}) defer prov.Close() // Our fakeProviderBuilder pushes newly created providers on a channel. Grab @@ -190,7 +195,9 @@ func (s) TestStoreSingleProvider(t *testing.T) { // Attempt to read from key material from the Provider returned by the // store. This will fail because we have not pushed any key material into // our fake provider. - if err := readAndVerifyKeyMaterial(ctx, prov, nil); !errors.Is(err, context.DeadlineExceeded) { + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := readAndVerifyKeyMaterial(sCtx, prov, nil); !errors.Is(err, context.DeadlineExceeded) { t.Fatal(err) } @@ -198,8 +205,6 @@ func (s) TestStoreSingleProvider(t *testing.T) { // and attempt to read from the Provider returned by the store. testKM1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") fakeProv.newKeyMaterial(testKM1, nil) - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() if err := readAndVerifyKeyMaterial(ctx, prov, testKM1); err != nil { t.Fatal(err) } @@ -220,18 +225,14 @@ func (s) TestStoreSingleProvider(t *testing.T) { func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) { // Create three readers on the same fake provider. Two of these readers use // certName `foo`, while the third one uses certName `bar`. - optsFoo := Options{CertName: "foo"} - optsBar := Options{CertName: "bar"} - provFoo1, err := GetProvider(fakeProvider1Name, fakeConfig, optsFoo) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, optsFoo, err) - } - defer provFoo1.Close() - provFoo2, err := GetProvider(fakeProvider1Name, fakeConfig, optsFoo) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, optsFoo, err) - } - defer provFoo2.Close() + optsFoo := BuildOptions{CertName: "foo"} + provFoo1 := createProvider(t, fakeProvider1Name, fakeConfig, optsFoo) + provFoo2 := createProvider(t, fakeProvider1Name, fakeConfig, optsFoo) + defer func() { + provFoo1.Close() + provFoo2.Close() + }() + // Our fakeProviderBuilder pushes newly created providers on a channel. // Grab the fake provider for optsFoo. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) @@ -242,11 +243,18 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) { } fakeProvFoo := p.(*fakeProvider) - provBar1, err := GetProvider(fakeProvider1Name, fakeConfig, optsBar) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, optsBar, err) + // Make sure only provider was created by the builder so far. The store + // should be able to share the providers. + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := fpb1.providerChan.Receive(sCtx); !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("A second provider created when expected to be shared by the store") } + + optsBar := BuildOptions{CertName: "bar"} + provBar1 := createProvider(t, fakeProvider1Name, fakeConfig, optsBar) defer provBar1.Close() + // Grab the fake provider for optsBar. p, err = fpb1.providerChan.Receive(ctx) if err != nil { @@ -264,7 +272,9 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) { if err := readAndVerifyKeyMaterial(ctx, provFoo2, fooKM); err != nil { t.Fatal(err) } - if err := readAndVerifyKeyMaterial(ctx, provBar1, nil); !errors.Is(err, context.DeadlineExceeded) { + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := readAndVerifyKeyMaterial(sCtx, provBar1, nil); !errors.Is(err, context.DeadlineExceeded) { t.Fatal(err) } @@ -272,8 +282,6 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) { // appropriate key material. barKM := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem") fakeProvBar.newKeyMaterial(barKM, nil) - ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() if err := readAndVerifyKeyMaterial(ctx, provBar1, barKM); err != nil { t.Fatal(err) } @@ -290,13 +298,11 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) { // would take place. func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) { // Create two providers of the same type, but with different configs. - opts := Options{CertName: "foo"} + opts := BuildOptions{CertName: "foo"} cfg1 := fakeConfig + "1111" cfg2 := fakeConfig + "2222" - prov1, err := GetProvider(fakeProvider1Name, cfg1, opts) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, cfg1, opts, err) - } + + prov1 := createProvider(t, fakeProvider1Name, cfg1, opts) defer prov1.Close() // Our fakeProviderBuilder pushes newly created providers on a channel. Grab // the fake provider from that channel. @@ -308,10 +314,7 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) { } fakeProv1 := p1.(*fakeProvider) - prov2, err := GetProvider(fakeProvider1Name, cfg2, opts) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, cfg2, opts, err) - } + prov2 := createProvider(t, fakeProvider1Name, cfg2, opts) defer prov2.Close() // Grab the second provider from the channel. p2, err := fpb1.providerChan.Receive(ctx) @@ -354,11 +357,8 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) { // TestStoreMultipleProviders creates providers of different types and makes // sure closing of one does not affect the other. func (s) TestStoreMultipleProviders(t *testing.T) { - opts := Options{CertName: "foo"} - prov1, err := GetProvider(fakeProvider1Name, fakeConfig, opts) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, opts, err) - } + opts := BuildOptions{CertName: "foo"} + prov1 := createProvider(t, fakeProvider1Name, fakeConfig, opts) defer prov1.Close() // Our fakeProviderBuilder pushes newly created providers on a channel. Grab // the fake provider from that channel. @@ -370,10 +370,7 @@ func (s) TestStoreMultipleProviders(t *testing.T) { } fakeProv1 := p1.(*fakeProvider) - prov2, err := GetProvider(fakeProvider2Name, fakeConfig, opts) - if err != nil { - t.Fatalf("GetProvider(%s, %s, %v) failed: %v", fakeProvider1Name, fakeConfig, opts, err) - } + prov2 := createProvider(t, fakeProvider2Name, fakeConfig, opts) defer prov2.Close() // Grab the second provider from the channel. p2, err := fpb2.providerChan.Receive(ctx) diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index 55d6e8c9..ab4b78ea 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -35,7 +35,6 @@ import ( "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdsinternal "google.golang.org/grpc/xds/internal" xdsclient "google.golang.org/grpc/xds/internal/client" @@ -61,8 +60,7 @@ var ( // not deal with subConns. return builder.Build(cc, opts), nil } - - getProvider = certprovider.GetProvider + buildProvider = buildProviderFunc ) func init() { @@ -133,7 +131,7 @@ func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, // the cdsBalancer. This will be faked out in unittests. type xdsClientInterface interface { WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - CertProviderConfigs() map[string]bootstrap.CertProviderConfig + CertProviderConfigs() map[string]*certprovider.BuildableConfig Close() } @@ -252,48 +250,28 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err } // A root provider is required whether we are using TLS or mTLS. - rootCfg, ok := cpc[config.RootInstanceName] - if !ok { - return fmt.Errorf("certificate provider instance %q not found in bootstrap file", config.RootInstanceName) - } - rootProvider, err := getProvider(rootCfg.Name, rootCfg.Config, certprovider.Options{ - CertName: config.RootCertName, - WantRoot: true, - }) + rootProvider, err := buildProvider(cpc, config.RootInstanceName, config.RootCertName, false, true) if err != nil { - // This error is not expected since the bootstrap process parses the - // config and makes sure that it is acceptable to the plugin. Still, it - // is possible that the plugin parses the config successfully, but its - // Build() method errors out. - return fmt.Errorf("xds: failed to get security plugin instance (%+v): %v", rootCfg, err) - } - if b.cachedRoot != nil { - b.cachedRoot.Close() + return err } // The identity provider is only present when using mTLS. var identityProvider certprovider.Provider - if name := config.IdentityInstanceName; name != "" { - identityCfg := cpc[name] - if !ok { - return fmt.Errorf("certificate provider instance %q not found in bootstrap file", config.IdentityInstanceName) - } - identityProvider, err = getProvider(identityCfg.Name, identityCfg.Config, certprovider.Options{ - CertName: config.IdentityCertName, - WantIdentity: true, - }) + if name, cert := config.IdentityInstanceName, config.IdentityCertName; name != "" { + var err error + identityProvider, err = buildProvider(cpc, name, cert, true, false) if err != nil { - // This error is not expected since the bootstrap process parses the - // config and makes sure that it is acceptable to the plugin. Still, - // it is possible that the plugin parses the config successfully, - // but its Build() method errors out. - return fmt.Errorf("xds: failed to get security plugin instance (%+v): %v", identityCfg, err) + return err } } + + // Close the old providers and cache the new ones. + if b.cachedRoot != nil { + b.cachedRoot.Close() + } if b.cachedIdentity != nil { b.cachedIdentity.Close() } - b.cachedRoot = rootProvider b.cachedIdentity = identityProvider @@ -305,6 +283,26 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsclient.SecurityConfig) err return nil } +func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanceName, certName string, wantIdentity, wantRoot bool) (certprovider.Provider, error) { + cfg, ok := configs[instanceName] + if !ok { + return nil, fmt.Errorf("certificate provider instance %q not found in bootstrap file", instanceName) + } + provider, err := cfg.Build(certprovider.BuildOptions{ + CertName: certName, + WantIdentity: wantIdentity, + WantRoot: wantRoot, + }) + if err != nil { + // This error is not expected since the bootstrap process parses the + // config and makes sure that it is acceptable to the plugin. Still, it + // is possible that the plugin parses the config successfully, but its + // Build() method errors out. + return nil, fmt.Errorf("xds: failed to get security plugin instance (%+v): %v", cfg, err) + } + return provider, nil +} + // handleWatchUpdate handles a watch update from the xDS Client. Good updates // lead to clientConn updates being invoked on the underlying edsBalancer. func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index 15c097d3..a3f923df 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -31,7 +31,6 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" ) @@ -44,16 +43,7 @@ const ( var ( fpb1, fpb2 *fakeProviderBuilder - bootstrapCertProviderConfigs = map[string]bootstrap.CertProviderConfig{ - "default1": { - Name: fakeProvider1Name, - Config: &fakeStableConfig{config: fakeConfig + "1111"}, - }, - "default2": { - Name: fakeProvider2Name, - Config: &fakeStableConfig{config: fakeConfig + "2222"}, - }, - } + bootstrapCertProviderConfigs map[string]*certprovider.BuildableConfig cdsUpdateWithGoodSecurityCfg = xdsclient.ClusterUpdate{ ServiceName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ @@ -72,6 +62,12 @@ var ( func init() { fpb1 = &fakeProviderBuilder{name: fakeProvider1Name} fpb2 = &fakeProviderBuilder{name: fakeProvider2Name} + cfg1, _ := fpb1.ParseConfig(fakeConfig + "1111") + cfg2, _ := fpb2.ParseConfig(fakeConfig + "2222") + bootstrapCertProviderConfigs = map[string]*certprovider.BuildableConfig{ + "default1": cfg1, + "default2": cfg2, + } certprovider.Register(fpb1) certprovider.Register(fpb2) } @@ -82,40 +78,38 @@ type fakeProviderBuilder struct { name string } -func (b *fakeProviderBuilder) Build(certprovider.StableConfig, certprovider.Options) certprovider.Provider { - p := &fakeProvider{} - return p -} - -func (b *fakeProviderBuilder) ParseConfig(config interface{}) (certprovider.StableConfig, error) { +func (b *fakeProviderBuilder) ParseConfig(config interface{}) (*certprovider.BuildableConfig, error) { s, ok := config.(string) if !ok { return nil, fmt.Errorf("providerBuilder %s received config of type %T, want string", b.name, config) } - return &fakeStableConfig{config: s}, nil + return certprovider.NewBuildableConfig(b.name, []byte(s), func(certprovider.BuildOptions) certprovider.Provider { + return &fakeProvider{ + Distributor: certprovider.NewDistributor(), + config: s, + } + }), nil } func (b *fakeProviderBuilder) Name() string { return b.name } -type fakeStableConfig struct { - config string -} - -func (c *fakeStableConfig) Canonical() []byte { - return []byte(c.config) -} - // fakeProvider is an implementation of the Provider interface which provides a // method for tests to invoke to push new key materials. type fakeProvider struct { - certprovider.Provider + *certprovider.Distributor + config string +} + +// Close helps implement the Provider interface. +func (p *fakeProvider) Close() { + p.Distributor.Stop() } // setupWithXDSCreds performs all the setup steps required for tests which use // xDSCredentials. -func setupWithXDSCreds(t *testing.T, storeErr bool) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, *testutils.Channel, func()) { +func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() builder := balancer.Get(cdsName) @@ -162,25 +156,8 @@ func setupWithXDSCreds(t *testing.T, storeErr bool) (*fakeclient.Client, *cdsBal t.Fatalf("xdsClient.WatchCDS called for cluster: %v, want: %v", gotCluster, clusterName) } - // Override the certificate provider creation function to get notified about - // provider creation. Create a channel with size 2, so we have buffer to - // push notifications for both providers. - providerCh := testutils.NewChannelWithSize(2) - origGetProviderFunc := getProvider - if storeErr { - getProvider = func(string, interface{}, certprovider.Options) (certprovider.Provider, error) { - return nil, errors.New("certprovider.Store failed to created provider") - } - } else { - getProvider = func(name string, cfg interface{}, opts certprovider.Options) (certprovider.Provider, error) { - providerCh.Send(nil) - return origGetProviderFunc(name, cfg, opts) - } - } - - return xdsC, cdsB.(*cdsBalancer), edsB, tcc, providerCh, func() { + return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { newEDSBalancer = oldEDSBalancerBuilder - getProvider = origGetProviderFunc } } @@ -223,16 +200,6 @@ func makeNewSubConn(ctx context.Context, edsCC balancer.ClientConn, parentCC *xd // the address attributes added as part of the intercepted NewSubConn() method // indicate the use of fallback credentials. func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { - // Override the certificate provider creation function to get notified about - // provider creation. - providerCh := testutils.NewChannel() - origGetProviderFunc := getProvider - getProvider = func(name string, cfg interface{}, opts certprovider.Options) (certprovider.Provider, error) { - providerCh.Send(nil) - return origGetProviderFunc(name, cfg, opts) - } - defer func() { getProvider = origGetProviderFunc }() - // This creates a CDS balancer, pushes a ClientConnState update with a fake // xdsClient, and makes sure that the CDS balancer registers a watch on the // provided xdsClient. @@ -242,6 +209,17 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { cdsB.Close() }() + // Override the provider builder function to push on a channel. We do not + // expect this function to be called as part of this test. + providerCh := testutils.NewChannel() + origBuildProvider := buildProvider + buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { + p, err := origBuildProvider(c, id, cert, wi, wr) + providerCh.Send(nil) + return p, err + } + defer func() { buildProvider = origBuildProvider }() + // Here we invoke the watch callback registered on the fake xdsClient. This // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be @@ -255,7 +233,9 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { t.Fatal(err) } - // Make a NewSubConn and verify that attributes are not added. + // Make a NewSubConn and verify that the HandshakeInfo does not contain any + // certificate providers, forcing the credentials implementation to use + // fallback creds. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, true); err != nil { t.Fatal(err) } @@ -278,12 +258,23 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() }() + // Override the provider builder function to push on a channel. We do not + // expect this function to be called as part of this test. + providerCh := testutils.NewChannel() + origBuildProvider := buildProvider + buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { + p, err := origBuildProvider(c, id, cert, wi, wr) + providerCh.Send(nil) + return p, err + } + defer func() { buildProvider = origBuildProvider }() + // Here we invoke the watch callback registered on the fake xdsClient. This // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be @@ -298,7 +289,9 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { t.Fatal(err) } - // Make a NewSubConn and verify that attributes are not added. + // Make a NewSubConn and verify that the HandshakeInfo does not contain any + // certificate providers, forcing the credentials implementation to use + // fallback creds. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, true); err != nil { t.Fatal(err) } @@ -325,7 +318,7 @@ func (s) TestSecurityConfigNotFoundInBootstrap(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, _, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() @@ -366,12 +359,19 @@ func (s) TestCertproviderStoreError(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, _, cancel := setupWithXDSCreds(t, true) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() }() + // Override the provider builder function to return an error. + origBuildProvider := buildProvider + buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { + return nil, errors.New("certprovider store error") + } + defer func() { buildProvider = origBuildProvider }() + // Set the bootstrap config used by the fake client. xdsC.SetCertProviderConfigs(bootstrapCertProviderConfigs) @@ -402,7 +402,7 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() @@ -442,12 +442,6 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } - // Make sure two certificate providers are created. - for i := 0; i < 2; i++ { - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatalf("Failed to create certificate provider upon receipt of security config") - } - } // Make a NewSubConn and verify that attributes are added. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, false); err != nil { @@ -464,7 +458,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() @@ -484,12 +478,6 @@ func (s) TestGoodSecurityConfig(t *testing.T) { if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } - // Make sure two certificate providers are created. - for i := 0; i < 2; i++ { - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatalf("Failed to create certificate provider upon receipt of security config") - } - } // Make a NewSubConn and verify that attributes are added. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, false); err != nil { @@ -501,7 +489,7 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() @@ -521,12 +509,6 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } - // Make sure two certificate providers are created. - for i := 0; i < 2; i++ { - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatalf("Failed to create certificate provider upon receipt of security config") - } - } // Make a NewSubConn and verify that attributes are added. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, false); err != nil { @@ -557,7 +539,7 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() @@ -577,12 +559,6 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) } - // Make sure two certificate providers are created. - for i := 0; i < 2; i++ { - if _, err := providerCh.Receive(ctx); err != nil { - t.Fatalf("Failed to create certificate provider upon receipt of security config") - } - } // Make a NewSubConn and verify that attributes are added. if err := makeNewSubConn(ctx, edsB.parentCC, tcc, false); err != nil { @@ -624,12 +600,22 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // This creates a CDS balancer which uses xdsCredentials, pushes a // ClientConnState update with a fake xdsClient, and makes sure that the CDS // balancer registers a watch on the provided xdsClient. - xdsC, cdsB, edsB, tcc, providerCh, cancel := setupWithXDSCreds(t, false) + xdsC, cdsB, edsB, tcc, cancel := setupWithXDSCreds(t) defer func() { cancel() cdsB.Close() }() + // Override the provider builder function to push on a channel. + providerCh := testutils.NewChannel() + origBuildProvider := buildProvider + buildProvider = func(c map[string]*certprovider.BuildableConfig, id, cert string, wi, wr bool) (certprovider.Provider, error) { + p, err := origBuildProvider(c, id, cert, wi, wr) + providerCh.Send(nil) + return p, err + } + defer func() { buildProvider = origBuildProvider }() + // Set the bootstrap config used by the fake client. xdsC.SetCertProviderConfigs(bootstrapCertProviderConfigs) diff --git a/xds/internal/client/bootstrap/bootstrap.go b/xds/internal/client/bootstrap/bootstrap.go index 0ef38094..789d1d0e 100644 --- a/xds/internal/client/bootstrap/bootstrap.go +++ b/xds/internal/client/bootstrap/bootstrap.go @@ -75,18 +75,9 @@ type Config struct { // NodeProto contains the Node proto to be used in xDS requests. The actual // type depends on the transport protocol version used. NodeProto proto.Message - // CertProviderConfigs contain parsed configs for supported certificate - // provider plugins found in the bootstrap file. - CertProviderConfigs map[string]CertProviderConfig -} - -// CertProviderConfig wraps the certificate provider plugin name and config -// (corresponding to one plugin instance) found in the bootstrap file. -type CertProviderConfig struct { - // Name is the registered name of the certificate provider. - Name string - // Config is the parsed config to be passed to the certificate provider. - Config certprovider.StableConfig + // CertProviderConfigs contains a mapping from certificate provider plugin + // instance names to parsed buildable configs. + CertProviderConfigs map[string]*certprovider.BuildableConfig } type channelCreds struct { @@ -207,7 +198,7 @@ func NewConfig() (*Config, error) { if err := json.Unmarshal(v, &providerInstances); err != nil { return nil, fmt.Errorf("xds: json.Unmarshal(%v) for field %q failed during bootstrap: %v", string(v), k, err) } - configs := make(map[string]CertProviderConfig) + configs := make(map[string]*certprovider.BuildableConfig) getBuilder := internal.GetCertificateProviderBuilder.(func(string) certprovider.Builder) for instance, data := range providerInstances { var nameAndConfig struct { @@ -224,15 +215,11 @@ func NewConfig() (*Config, error) { // We ignore plugins that we do not know about. continue } - cfg := nameAndConfig.Config - c, err := parser.ParseConfig(cfg) + bc, err := parser.ParseConfig(nameAndConfig.Config) if err != nil { return nil, fmt.Errorf("xds: Config parsing for plugin %q failed: %v", name, err) } - configs[instance] = CertProviderConfig{ - Name: name, - Config: c, - } + configs[instance] = bc } config.CertProviderConfigs = configs } diff --git a/xds/internal/client/bootstrap/bootstrap_test.go b/xds/internal/client/bootstrap/bootstrap_test.go index 6ee68c34..1b9decac 100644 --- a/xds/internal/client/bootstrap/bootstrap_test.go +++ b/xds/internal/client/bootstrap/bootstrap_test.go @@ -251,10 +251,10 @@ func (c *Config) compare(want *Config) error { for instance, gotCfg := range gotCfgs { wantCfg, ok := wantCfgs[instance] if !ok { - return fmt.Errorf("config.CertProviderConfigs has unexpected plugin instance %q with config %q", instance, string(gotCfg.Config.Canonical())) + return fmt.Errorf("config.CertProviderConfigs has unexpected plugin instance %q with config %q", instance, gotCfg.String()) } - if gotCfg.Name != wantCfg.Name || !cmp.Equal(gotCfg.Config.Canonical(), wantCfg.Config.Canonical()) { - return fmt.Errorf("config.CertProviderConfigs for plugin instance %q has config {%s, %s, want {%s, %s}", instance, gotCfg.Name, string(gotCfg.Config.Canonical()), wantCfg.Name, string(wantCfg.Config.Canonical())) + if got, want := gotCfg.String(), wantCfg.String(); got != want { + return fmt.Errorf("config.CertProviderConfigs for plugin instance %q has config %q, want %q", instance, got, want) } } return nil @@ -489,13 +489,9 @@ const fakeCertProviderName = "fake-certificate-provider" // interprets the config provided to it as JSON with a single key and value. type fakeCertProviderBuilder struct{} -func (b *fakeCertProviderBuilder) Build(certprovider.StableConfig, certprovider.Options) certprovider.Provider { - return &fakeCertProvider{} -} - // ParseConfig expects input in JSON format containing a map from string to // string, with a single entry and mapKey being "configKey". -func (b *fakeCertProviderBuilder) ParseConfig(cfg interface{}) (certprovider.StableConfig, error) { +func (b *fakeCertProviderBuilder) ParseConfig(cfg interface{}) (*certprovider.BuildableConfig, error) { config, ok := cfg.(json.RawMessage) if !ok { return nil, fmt.Errorf("fakeCertProviderBuilder received config of type %T, want []byte", config) @@ -507,7 +503,10 @@ func (b *fakeCertProviderBuilder) ParseConfig(cfg interface{}) (certprovider.Sta if len(cfgData) != 1 || cfgData["configKey"] == "" { return nil, errors.New("fakeCertProviderBuilder received invalid config") } - return &fakeStableConfig{config: cfgData}, nil + fc := &fakeStableConfig{config: cfgData} + return certprovider.NewBuildableConfig(fakeCertProviderName, fc.canonical(), func(certprovider.BuildOptions) certprovider.Provider { + return &fakeCertProvider{} + }), nil } func (b *fakeCertProviderBuilder) Name() string { @@ -518,7 +517,7 @@ type fakeStableConfig struct { config map[string]string } -func (c *fakeStableConfig) Canonical() []byte { +func (c *fakeStableConfig) canonical() []byte { var cfg string for k, v := range c.config { cfg = fmt.Sprintf("%s:%s", k, v) @@ -652,11 +651,8 @@ func TestNewConfigWithCertificateProviders(t *testing.T) { Creds: grpc.WithCredentialsBundle(google.NewComputeEngineCredentials()), TransportAPI: version.TransportV3, NodeProto: v3NodeProto, - CertProviderConfigs: map[string]CertProviderConfig{ - "fakeProviderInstance": { - Name: fakeCertProviderName, - Config: wantCfg, - }, + CertProviderConfigs: map[string]*certprovider.BuildableConfig{ + "fakeProviderInstance": wantCfg, }, } tests := []struct { diff --git a/xds/internal/client/client.go b/xds/internal/client/client.go index 583f09ba..5497fbd9 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/client/client.go @@ -30,6 +30,7 @@ import ( v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" "github.com/golang/protobuf/proto" + "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc" @@ -422,10 +423,9 @@ func New(opts Options) (*Client, error) { } // CertProviderConfigs returns the certificate provider configuration from the -// "certificate_providers" field of the bootstrap file. The returned value is a -// map from plugin_instance_name to {plugin_name, plugin_config}. Callers must -// not modify the returned map. -func (c *Client) CertProviderConfigs() map[string]bootstrap.CertProviderConfig { +// "certificate_providers" field of the bootstrap file. The key in the returned +// map is the plugin_instance_name. Callers must not modify the returned map. +func (c *Client) CertProviderConfigs() map[string]*certprovider.BuildableConfig { return c.opts.Config.CertProviderConfigs } diff --git a/xds/internal/testutils/fakeclient/client.go b/xds/internal/testutils/fakeclient/client.go index 6496d132..c540b573 100644 --- a/xds/internal/testutils/fakeclient/client.go +++ b/xds/internal/testutils/fakeclient/client.go @@ -22,9 +22,9 @@ package fakeclient import ( "context" + "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/internal/testutils" xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/load" ) @@ -43,7 +43,7 @@ type Client struct { loadReportCh *testutils.Channel closeCh *testutils.Channel loadStore *load.Store - certConfigs map[string]bootstrap.CertProviderConfig + certConfigs map[string]*certprovider.BuildableConfig ldsCb func(xdsclient.ListenerUpdate, error) rdsCb func(xdsclient.RouteConfigUpdate, error) @@ -224,12 +224,12 @@ func (xdsC *Client) WaitForClose(ctx context.Context) error { } // CertProviderConfigs returns the configured certificate provider configs. -func (xdsC *Client) CertProviderConfigs() map[string]bootstrap.CertProviderConfig { +func (xdsC *Client) CertProviderConfigs() map[string]*certprovider.BuildableConfig { return xdsC.certConfigs } // SetCertProviderConfigs updates the certificate provider configs. -func (xdsC *Client) SetCertProviderConfigs(configs map[string]bootstrap.CertProviderConfig) { +func (xdsC *Client) SetCertProviderConfigs(configs map[string]*certprovider.BuildableConfig) { xdsC.certConfigs = configs }