credentials: Use net.SplitHostPort safely parse IPv6 authorities in ClientHandshake (#3082)
This commit is contained in:

committed by
Easwar Swaminathan

parent
ff0c603b9b
commit
f07f2cffa0
@ -30,9 +30,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
|
||||||
"google.golang.org/grpc/credentials/internal"
|
"google.golang.org/grpc/credentials/internal"
|
||||||
ginternal "google.golang.org/grpc/internal"
|
ginternal "google.golang.org/grpc/internal"
|
||||||
)
|
)
|
||||||
@ -168,11 +168,12 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
|
|||||||
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||||
cfg := cloneTLSConfig(c.config)
|
cfg := cloneTLSConfig(c.config)
|
||||||
if cfg.ServerName == "" {
|
if cfg.ServerName == "" {
|
||||||
colonPos := strings.LastIndex(authority, ":")
|
serverName, _, err := net.SplitHostPort(authority)
|
||||||
if colonPos == -1 {
|
if err != nil {
|
||||||
colonPos = len(authority)
|
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
|
||||||
|
serverName = authority
|
||||||
}
|
}
|
||||||
cfg.ServerName = authority[:colonPos]
|
cfg.ServerName = serverName
|
||||||
}
|
}
|
||||||
conn := tls.Client(rawConn, cfg)
|
conn := tls.Client(rawConn, cfg)
|
||||||
errChannel := make(chan error, 1)
|
errChannel := make(chan error, 1)
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"google.golang.org/grpc/testdata"
|
"google.golang.org/grpc/testdata"
|
||||||
@ -55,18 +56,40 @@ func TestTLSClone(t *testing.T) {
|
|||||||
type serverHandshake func(net.Conn) (AuthInfo, error)
|
type serverHandshake func(net.Conn) (AuthInfo, error)
|
||||||
|
|
||||||
func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
|
func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
|
||||||
done := make(chan AuthInfo, 1)
|
tcs := []struct {
|
||||||
lis := launchServer(t, tlsServerHandshake, done)
|
name string
|
||||||
defer lis.Close()
|
address string
|
||||||
lisAddr := lis.Addr().String()
|
}{
|
||||||
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
|
{
|
||||||
// wait until server sends serverAuthInfo or fails.
|
name: "localhost",
|
||||||
serverAuthInfo, ok := <-done
|
address: "localhost:0",
|
||||||
if !ok {
|
},
|
||||||
t.Fatalf("Error at server-side")
|
{
|
||||||
|
name: "ipv4",
|
||||||
|
address: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
address: "[::1]:0",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if !compare(clientAuthInfo, serverAuthInfo) {
|
|
||||||
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
done := make(chan AuthInfo, 1)
|
||||||
|
lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
|
||||||
|
defer lis.Close()
|
||||||
|
lisAddr := lis.Addr().String()
|
||||||
|
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -121,8 +144,15 @@ func compare(a1, a2 AuthInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
|
func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
return launchServerOnListenAddress(t, hs, done, "localhost:0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener {
|
||||||
|
lis, err := net.Listen("tcp", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "bind: cannot assign requested address") {
|
||||||
|
t.Skip("missing IPv6 support")
|
||||||
|
}
|
||||||
t.Fatalf("Failed to listen: %v", err)
|
t.Fatalf("Failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
go serverHandle(t, hs, done, lis)
|
go serverHandle(t, hs, done, lis)
|
||||||
|
Reference in New Issue
Block a user