alts: copy handshake address in Clone() (#2119)
This commit is contained in:
@ -131,7 +131,6 @@ func DefaultServerOptions() *ServerOptions {
|
|||||||
// It implements credentials.TransportCredentials interface.
|
// It implements credentials.TransportCredentials interface.
|
||||||
type altsTC struct {
|
type altsTC struct {
|
||||||
info *credentials.ProtocolInfo
|
info *credentials.ProtocolInfo
|
||||||
hsAddr string
|
|
||||||
side core.Side
|
side core.Side
|
||||||
accounts []string
|
accounts []string
|
||||||
hsAddress string
|
hsAddress string
|
||||||
@ -269,8 +268,16 @@ func (g *altsTC) Info() credentials.ProtocolInfo {
|
|||||||
|
|
||||||
func (g *altsTC) Clone() credentials.TransportCredentials {
|
func (g *altsTC) Clone() credentials.TransportCredentials {
|
||||||
info := *g.info
|
info := *g.info
|
||||||
|
var accounts []string
|
||||||
|
if g.accounts != nil {
|
||||||
|
accounts = make([]string, len(g.accounts))
|
||||||
|
copy(accounts, g.accounts)
|
||||||
|
}
|
||||||
return &altsTC{
|
return &altsTC{
|
||||||
info: &info,
|
info: &info,
|
||||||
|
side: g.side,
|
||||||
|
hsAddress: g.hsAddress,
|
||||||
|
accounts: accounts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
package alts
|
package alts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
@ -45,10 +46,40 @@ func TestOverrideServerName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClone(t *testing.T) {
|
func TestCloneClient(t *testing.T) {
|
||||||
|
wantServerName := "server.name"
|
||||||
|
opt := DefaultClientOptions()
|
||||||
|
opt.TargetServiceAccounts = []string{"not", "empty"}
|
||||||
|
c := NewClientCreds(opt)
|
||||||
|
c.OverrideServerName(wantServerName)
|
||||||
|
cc := c.Clone()
|
||||||
|
if got, want := cc.Info().ServerName, wantServerName; got != want {
|
||||||
|
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
cc.OverrideServerName("")
|
||||||
|
if got, want := c.Info().ServerName, wantServerName; got != want {
|
||||||
|
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cc.Info().ServerName, ""; got != want {
|
||||||
|
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := c.(*altsTC)
|
||||||
|
cct := cc.(*altsTC)
|
||||||
|
|
||||||
|
if ct.side != cct.side {
|
||||||
|
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
|
||||||
|
}
|
||||||
|
if ct.hsAddress != cct.hsAddress {
|
||||||
|
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
|
||||||
|
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneServer(t *testing.T) {
|
||||||
wantServerName := "server.name"
|
wantServerName := "server.name"
|
||||||
// This is not testing any handshaker functionality, so it's fine to only
|
|
||||||
// use NewServerCreds and not NewClientCreds.
|
|
||||||
c := NewServerCreds(DefaultServerOptions())
|
c := NewServerCreds(DefaultServerOptions())
|
||||||
c.OverrideServerName(wantServerName)
|
c.OverrideServerName(wantServerName)
|
||||||
cc := c.Clone()
|
cc := c.Clone()
|
||||||
@ -62,6 +93,19 @@ func TestClone(t *testing.T) {
|
|||||||
if got, want := cc.Info().ServerName, ""; got != want {
|
if got, want := cc.Info().ServerName, ""; got != want {
|
||||||
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
|
t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ct := c.(*altsTC)
|
||||||
|
cct := cc.(*altsTC)
|
||||||
|
|
||||||
|
if ct.side != cct.side {
|
||||||
|
t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
|
||||||
|
}
|
||||||
|
if ct.hsAddress != cct.hsAddress {
|
||||||
|
t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ct.accounts, cct.accounts) {
|
||||||
|
t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInfo(t *testing.T) {
|
func TestInfo(t *testing.T) {
|
||||||
|
Reference in New Issue
Block a user