diff --git a/credentials/alts/alts.go b/credentials/alts/alts.go index c24c5faf..94005654 100644 --- a/credentials/alts/alts.go +++ b/credentials/alts/alts.go @@ -131,7 +131,6 @@ func DefaultServerOptions() *ServerOptions { // It implements credentials.TransportCredentials interface. type altsTC struct { info *credentials.ProtocolInfo - hsAddr string side core.Side accounts []string hsAddress string @@ -269,8 +268,16 @@ func (g *altsTC) Info() credentials.ProtocolInfo { func (g *altsTC) Clone() credentials.TransportCredentials { info := *g.info + var accounts []string + if g.accounts != nil { + accounts = make([]string, len(g.accounts)) + copy(accounts, g.accounts) + } return &altsTC{ - info: &info, + info: &info, + side: g.side, + hsAddress: g.hsAddress, + accounts: accounts, } } diff --git a/credentials/alts/alts_test.go b/credentials/alts/alts_test.go index 35fae87b..6b041d99 100644 --- a/credentials/alts/alts_test.go +++ b/credentials/alts/alts_test.go @@ -19,6 +19,7 @@ package alts import ( + "reflect" "testing" "github.com/golang/protobuf/proto" @@ -45,10 +46,40 @@ func TestOverrideServerName(t *testing.T) { } } -func TestClone(t *testing.T) { +func TestCloneClient(t *testing.T) { + wantServerName := "server.name" + opt := DefaultClientOptions() + opt.TargetServiceAccounts = []string{"not", "empty"} + c := NewClientCreds(opt) + c.OverrideServerName(wantServerName) + cc := c.Clone() + if got, want := cc.Info().ServerName, wantServerName; got != want { + t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) + } + cc.OverrideServerName("") + if got, want := c.Info().ServerName, wantServerName; got != want { + t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want) + } + if got, want := cc.Info().ServerName, ""; got != want { + t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) + } + + ct := c.(*altsTC) + cct := cc.(*altsTC) + + if ct.side != cct.side { + t.Errorf("cc.side = %q, want %q", cct.side, ct.side) + } + if ct.hsAddress != cct.hsAddress { + t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) + } + if !reflect.DeepEqual(ct.accounts, cct.accounts) { + t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) + } +} + +func TestCloneServer(t *testing.T) { wantServerName := "server.name" - // This is not testing any handshaker functionality, so it's fine to only - // use NewServerCreds and not NewClientCreds. c := NewServerCreds(DefaultServerOptions()) c.OverrideServerName(wantServerName) cc := c.Clone() @@ -62,6 +93,19 @@ func TestClone(t *testing.T) { if got, want := cc.Info().ServerName, ""; got != want { t.Fatalf("cc.Info().ServerName = %v, want %v", got, want) } + + ct := c.(*altsTC) + cct := cc.(*altsTC) + + if ct.side != cct.side { + t.Errorf("cc.side = %q, want %q", cct.side, ct.side) + } + if ct.hsAddress != cct.hsAddress { + t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress) + } + if !reflect.DeepEqual(ct.accounts, cct.accounts) { + t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts) + } } func TestInfo(t *testing.T) {