diff --git a/credentials/alts/internal/handshaker/service/service.go b/credentials/alts/internal/handshaker/service/service.go index 0c7b5683..77d759cd 100644 --- a/credentials/alts/internal/handshaker/service/service.go +++ b/credentials/alts/internal/handshaker/service/service.go @@ -27,9 +27,12 @@ import ( ) var ( - // hsConn represents a connection to hypervisor handshaker service. - hsConn *grpc.ClientConn - mu sync.Mutex + // mu guards hsConnMap and hsDialer. + mu sync.Mutex + // hsConn represents a mapping from a hypervisor handshaker service address + // to a corresponding connection to a hypervisor handshaker service + // instance. + hsConnMap = make(map[string]*grpc.ClientConn) // hsDialer will be reassigned in tests. hsDialer = grpc.Dial ) @@ -41,7 +44,8 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) { mu.Lock() defer mu.Unlock() - if hsConn == nil { + hsConn, ok := hsConnMap[hsAddress] + if !ok { // Create a new connection to the handshaker service. Note that // this connection stays open until the application is closed. var err error @@ -49,6 +53,7 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) { if err != nil { return nil, err } + hsConnMap[hsAddress] = hsConn } return hsConn, nil } diff --git a/credentials/alts/internal/handshaker/service/service_test.go b/credentials/alts/internal/handshaker/service/service_test.go index 98160bf0..28b4af75 100644 --- a/credentials/alts/internal/handshaker/service/service_test.go +++ b/credentials/alts/internal/handshaker/service/service_test.go @@ -25,8 +25,8 @@ import ( ) const ( - // The address is irrelevant in this test. - testAddress = "some_address" + testAddress1 = "some_address_1" + testAddress2 = "some_address_2" ) func TestDial(t *testing.T) { @@ -40,30 +40,44 @@ func TestDial(t *testing.T) { } }() - // Ensure that hsConn is nil at first. - hsConn = nil - - // First call to Dial, it should create set hsConn. - conn1, err := Dial(testAddress) + // First call to Dial, it should create a connection to the server running + // at the given address. + conn1, err := Dial(testAddress1) if err != nil { - t.Fatalf("first call to Dial failed: %v", err) + t.Fatalf("first call to Dial(%v) failed: %v", testAddress1, err) } if conn1 == nil { - t.Fatal("first call to Dial(_)=(nil, _), want not nil") + t.Fatalf("first call to Dial(%v)=(nil, _), want not nil", testAddress1) } - if got, want := hsConn, conn1; got != want { - t.Fatalf("hsConn=%v, want %v", got, want) + if got, want := hsConnMap[testAddress1], conn1; got != want { + t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want) } // Second call to Dial should return conn1 above. - conn2, err := Dial(testAddress) + conn2, err := Dial(testAddress1) if err != nil { - t.Fatalf("second call to Dial(_) failed: %v", err) + t.Fatalf("second call to Dial(%v) failed: %v", testAddress1, err) } if got, want := conn2, conn1; got != want { - t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want) + t.Fatalf("second call to Dial(%v)=(%v, _), want (%v,. _)", testAddress1, got, want) } - if got, want := hsConn, conn1; got != want { - t.Fatalf("hsConn=%v, want %v", got, want) + if got, want := hsConnMap[testAddress1], conn1; got != want { + t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress1, got, want) + } + + // Third call to Dial using a different address should create a new + // connection. + conn3, err := Dial(testAddress2) + if err != nil { + t.Fatalf("third call to Dial(%v) failed: %v", testAddress2, err) + } + if conn3 == nil { + t.Fatalf("third call to Dial(%v)=(nil, _), want not nil", testAddress2) + } + if got, want := hsConnMap[testAddress2], conn3; got != want { + t.Fatalf("hsConnMap[%v]=%v, want %v", testAddress2, got, want) + } + if got, want := conn2 == conn3, false; got != want { + t.Fatalf("(conn2==conn3)=%v, want %v", got, want) } }