From 9d59a879e1a9fcec30404370ff77b86343bc123f Mon Sep 17 00:00:00 2001 From: iamqizhao Date: Fri, 1 May 2015 18:10:40 -0700 Subject: [PATCH] Add handshaker option to ClientConn --- clientconn.go | 6 ++++++ transport/http2_client.go | 11 +++++++++++ transport/transport.go | 1 + 3 files changed, 18 insertions(+) diff --git a/clientconn.go b/clientconn.go index e03cabdf..d2e8c94b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -104,6 +104,12 @@ func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) Di } } +func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption { + return func(o *dialOptions) { + o.copts.Handshaker = h + } +} + // Dial creates a client connection the given target. // TODO(zhaoq): Have an option to make Dial return immediately without waiting // for connection to complete. diff --git a/transport/http2_client.go b/transport/http2_client.go index 98cfb803..5de498fe 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -111,6 +111,17 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } + // Perform handshake if opts.Handshaker is set. + if opts.Handshaker != nil { + auth, err := opts.Handshaker(conn) + if err != nil { + return nil, ConnectionErrorf("transport: handshaking failed %v", err) + } + // Prepend the resulting authenticator to opts.AuthOptions. + if auth != nil { + opts.AuthOptions = append([]credentials.Credentials{auth}, opts.AuthOptions...) + } + } for _, c := range opts.AuthOptions { if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https" diff --git a/transport/transport.go b/transport/transport.go index 5dfd89f0..de2bdb41 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -316,6 +316,7 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv // ConnectOptions covers all relevant options for dialing a server. type ConnectOptions struct { Dialer func(string, time.Duration) (net.Conn, error) + Handshaker func(conn net.Conn) (credentials.TransportAuthenticator, error) AuthOptions []credentials.Credentials Timeout time.Duration }