From e9e6ae6215eed864a46b66148f33ae4dfd7cc724 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Thu, 14 Jul 2016 15:19:57 -0700 Subject: [PATCH] Make Dial() withblock error on bad certificates --- call.go | 18 +++++++++++++++--- clientconn.go | 3 +++ stream.go | 9 ++++++++- test/end2end_test.go | 20 ++++++++++++++++++++ transport/http2_client.go | 19 ++++++++++--------- transport/http2_server.go | 10 +++++----- transport/transport.go | 11 +++++++++-- 7 files changed, 70 insertions(+), 20 deletions(-) diff --git a/call.go b/call.go index a8b6dcfd..f1a40a37 100644 --- a/call.go +++ b/call.go @@ -84,7 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } defer func() { if err != nil { - if _, ok := err.(transport.ConnectionError); !ok { + if e, ok := err.(transport.ConnectionError); !ok || !e.Temporary() { t.CloseStream(stream, err) } } @@ -190,10 +190,13 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli // Retry a non-failfast RPC when // i) there is a connection error; or // ii) the server started to drain before this RPC was initiated. - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } + if ok && !e.Temporary() { + return toRPCErr(err) + } continue } return toRPCErr(err) @@ -204,7 +207,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok { + if c.failFast { + return toRPCErr(err) + } + if !e.Temporary() { + return toRPCErr(err) + } + continue + } + if err == transport.ErrStreamDrain { if c.failFast { return toRPCErr(err) } diff --git a/clientconn.go b/clientconn.go index 5c2bf644..f2e568ce 100644 --- a/clientconn.go +++ b/clientconn.go @@ -605,6 +605,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error { if err != nil { cancel() + if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() { + return fmt.Errorf("failed to create client transport: %v", err) + } ac.mu.Lock() if ac.state == Shutdown { // ac.tearDown(...) has been invoked. diff --git a/stream.go b/stream.go index 66bfad81..1d5104f6 100644 --- a/stream.go +++ b/stream.go @@ -166,7 +166,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth put() put = nil } - if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain { + if c.failFast || e.Temporary() { + cs.finish(err) + return nil, toRPCErr(err) + } + continue + } + if err == transport.ErrStreamDrain { if c.failFast { cs.finish(err) return nil, toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index e57ff919..ffa4fa35 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2267,6 +2267,26 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { }) } +func TestDialWithBlockErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"}) + te.startServer() + defer te.tearDown() + + var ( + err error + opts []grpc.DialOption + ) + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com") + if err != nil { + te.t.Fatalf("Failed to load credentials: %v", err) + } + opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock()) + te.cc, err = grpc.Dial(te.srvAddr, opts...) + if err == nil { + te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err) + } +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { diff --git a/transport/http2_client.go b/transport/http2_client.go index 6dc48787..a3709d2d 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -121,7 +121,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl scheme := "http" conn, connErr := dial(opts.Dialer, ctx, addr) if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + return nil, ConnectionErrorf(true, "transport: %v", connErr) } var authInfo credentials.AuthInfo if creds := opts.TransportCredentials; creds != nil { @@ -129,7 +129,8 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn) } if connErr != nil { - return nil, ConnectionErrorf("transport: %v", connErr) + // Credentials handshake error is not a temporary error. + return nil, ConnectionErrorf(false, "transport: %v", connErr) } defer func() { if err != nil { @@ -173,11 +174,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl n, err := t.conn.Write(clientPreface) if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } if n != len(clientPreface) { t.Close() - return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + return nil, ConnectionErrorf(true, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } if initialWindowSize != defaultWindowSize { err = t.framer.writeSettings(true, http2.Setting{ @@ -189,13 +190,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl } if err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil { t.Close() - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } go t.controller() @@ -405,7 +406,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } if err != nil { t.notifyError(err) - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } t.writableChan <- 0 @@ -619,7 +620,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { // invoked. if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { t.notifyError(err) - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() @@ -667,7 +668,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) { func (t *http2Client) handleData(f *http2.DataFrame) { size := len(f.Data()) if err := t.fc.onData(uint32(size)); err != nil { - t.notifyError(ConnectionErrorf("%v", err)) + t.notifyError(ConnectionErrorf(true, "%v", err)) return } // Select the right stream to dispatch. diff --git a/transport/http2_server.go b/transport/http2_server.go index 8ed0cd59..7fd6aeba 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI Val: uint32(initialWindowSize)}) } if err := framer.writeSettings(true, settings...); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } // Adjust the connection flow control window if needed. if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { if err := framer.writeWindowUpdate(true, 0, delta); err != nil { - return nil, ConnectionErrorf("transport: %v", err) + return nil, ConnectionErrorf(true, "transport: %v", err) } } var buf bytes.Buffer @@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e } if err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } } return nil @@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeHeaders(false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } t.writableChan <- 0 } @@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { t.Close() - return ConnectionErrorf("transport: %v", err) + return ConnectionErrorf(true, "transport: %v", err) } if t.framer.adjustNumWriters(-1) == 0 { t.framer.flushWrite() diff --git a/transport/transport.go b/transport/transport.go index f739090c..e2b67184 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -485,9 +485,10 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } // ConnectionErrorf creates an ConnectionError with the specified error description. -func ConnectionErrorf(format string, a ...interface{}) ConnectionError { +func ConnectionErrorf(temp bool, format string, a ...interface{}) ConnectionError { return ConnectionError{ Desc: fmt.Sprintf(format, a...), + temp: temp, } } @@ -495,15 +496,21 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError { // entire connection and the retry of all the active streams. type ConnectionError struct { Desc string + temp bool } func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: desc = %q", e.Desc) } +// Temporary indicates if this connection error is temporary or fatal. +func (e ConnectionError) Temporary() bool { + return e.temp +} + var ( // ErrConnClosing indicates that the transport is closing. - ErrConnClosing = ConnectionError{Desc: "transport is closing"} + ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true} // ErrStreamDrain indicates that the stream is rejected by the server because // the server stops accepting new RPCs. ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")