Make Dial() withblock error on bad certificates
This commit is contained in:
18
call.go
18
call.go
@ -84,7 +84,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(transport.ConnectionError); !ok {
|
if e, ok := err.(transport.ConnectionError); !ok || !e.Temporary() {
|
||||||
t.CloseStream(stream, err)
|
t.CloseStream(stream, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,10 +190,13 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
// Retry a non-failfast RPC when
|
// Retry a non-failfast RPC when
|
||||||
// i) there is a connection error; or
|
// i) there is a connection error; or
|
||||||
// ii) the server started to drain before this RPC was initiated.
|
// ii) the server started to drain before this RPC was initiated.
|
||||||
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
|
if ok && !e.Temporary() {
|
||||||
|
return toRPCErr(err)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
@ -204,7 +207,16 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
put()
|
put()
|
||||||
put = nil
|
put = nil
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
if e, ok := err.(transport.ConnectionError); ok {
|
||||||
|
if c.failFast {
|
||||||
|
return toRPCErr(err)
|
||||||
|
}
|
||||||
|
if !e.Temporary() {
|
||||||
|
return toRPCErr(err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err == transport.ErrStreamDrain {
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
|
@ -605,6 +605,9 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
|
if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
|
||||||
|
return fmt.Errorf("failed to create client transport: %v", err)
|
||||||
|
}
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
if ac.state == Shutdown {
|
if ac.state == Shutdown {
|
||||||
// ac.tearDown(...) has been invoked.
|
// ac.tearDown(...) has been invoked.
|
||||||
|
@ -166,7 +166,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
put()
|
put()
|
||||||
put = nil
|
put = nil
|
||||||
}
|
}
|
||||||
if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
if e, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
|
||||||
|
if c.failFast || e.Temporary() {
|
||||||
|
cs.finish(err)
|
||||||
|
return nil, toRPCErr(err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err == transport.ErrStreamDrain {
|
||||||
if c.failFast {
|
if c.failFast {
|
||||||
cs.finish(err)
|
cs.finish(err)
|
||||||
return nil, toRPCErr(err)
|
return nil, toRPCErr(err)
|
||||||
|
@ -2267,6 +2267,26 @@ func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialWithBlockErrorOnBadCertificates(t *testing.T) {
|
||||||
|
te := newTest(t, env{name: "bad-tls", network: "tcp", security: "bad-tls"})
|
||||||
|
te.startServer()
|
||||||
|
defer te.tearDown()
|
||||||
|
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
opts []grpc.DialOption
|
||||||
|
)
|
||||||
|
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "wrong-server.com")
|
||||||
|
if err != nil {
|
||||||
|
te.t.Fatalf("Failed to load credentials: %v", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, grpc.WithTransportCredentials(creds), grpc.WithBlock())
|
||||||
|
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
||||||
|
if err == nil {
|
||||||
|
te.t.Fatalf("Dial(%q) = %v, want ConnectionError: credentials handshake failed", te.srvAddr, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// interestingGoroutines returns all goroutines we care about for the purpose
|
// interestingGoroutines returns all goroutines we care about for the purpose
|
||||||
// of leak checking. It excludes testing or runtime ones.
|
// of leak checking. It excludes testing or runtime ones.
|
||||||
func interestingGoroutines() (gs []string) {
|
func interestingGoroutines() (gs []string) {
|
||||||
|
@ -121,7 +121,7 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
|||||||
scheme := "http"
|
scheme := "http"
|
||||||
conn, connErr := dial(opts.Dialer, ctx, addr)
|
conn, connErr := dial(opts.Dialer, ctx, addr)
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf(true, "transport: %v", connErr)
|
||||||
}
|
}
|
||||||
var authInfo credentials.AuthInfo
|
var authInfo credentials.AuthInfo
|
||||||
if creds := opts.TransportCredentials; creds != nil {
|
if creds := opts.TransportCredentials; creds != nil {
|
||||||
@ -129,7 +129,8 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
|||||||
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
|
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
|
||||||
}
|
}
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
// Credentials handshake error is not a temporary error.
|
||||||
|
return nil, ConnectionErrorf(false, "transport: %v", connErr)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -173,11 +174,11 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
|||||||
n, err := t.conn.Write(clientPreface)
|
n, err := t.conn.Write(clientPreface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if n != len(clientPreface) {
|
if n != len(clientPreface) {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
return nil, ConnectionErrorf(true, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
|
||||||
}
|
}
|
||||||
if initialWindowSize != defaultWindowSize {
|
if initialWindowSize != defaultWindowSize {
|
||||||
err = t.framer.writeSettings(true, http2.Setting{
|
err = t.framer.writeSettings(true, http2.Setting{
|
||||||
@ -189,13 +190,13 @@ func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ Cl
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
// Adjust the connection flow control window if needed.
|
// Adjust the connection flow control window if needed.
|
||||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||||
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go t.controller()
|
go t.controller()
|
||||||
@ -405,7 +406,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.notifyError(err)
|
t.notifyError(err)
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
@ -619,7 +620,7 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
// invoked.
|
// invoked.
|
||||||
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
|
if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil {
|
||||||
t.notifyError(err)
|
t.notifyError(err)
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if t.framer.adjustNumWriters(-1) == 0 {
|
if t.framer.adjustNumWriters(-1) == 0 {
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
@ -667,7 +668,7 @@ func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
|||||||
func (t *http2Client) handleData(f *http2.DataFrame) {
|
func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
size := len(f.Data())
|
size := len(f.Data())
|
||||||
if err := t.fc.onData(uint32(size)); err != nil {
|
if err := t.fc.onData(uint32(size)); err != nil {
|
||||||
t.notifyError(ConnectionErrorf("%v", err))
|
t.notifyError(ConnectionErrorf(true, "%v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Select the right stream to dispatch.
|
// Select the right stream to dispatch.
|
||||||
|
@ -111,12 +111,12 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
|
|||||||
Val: uint32(initialWindowSize)})
|
Val: uint32(initialWindowSize)})
|
||||||
}
|
}
|
||||||
if err := framer.writeSettings(true, settings...); err != nil {
|
if err := framer.writeSettings(true, settings...); err != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
// Adjust the connection flow control window if needed.
|
// Adjust the connection flow control window if needed.
|
||||||
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 {
|
||||||
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
|
if err := framer.writeWindowUpdate(true, 0, delta); err != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", err)
|
return nil, ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
@ -448,7 +448,7 @@ func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) e
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -568,7 +568,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
if err := t.framer.writeHeaders(false, p); err != nil {
|
if err := t.framer.writeHeaders(false, p); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
}
|
}
|
||||||
@ -642,7 +642,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
|||||||
}
|
}
|
||||||
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
|
if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil {
|
||||||
t.Close()
|
t.Close()
|
||||||
return ConnectionErrorf("transport: %v", err)
|
return ConnectionErrorf(true, "transport: %v", err)
|
||||||
}
|
}
|
||||||
if t.framer.adjustNumWriters(-1) == 0 {
|
if t.framer.adjustNumWriters(-1) == 0 {
|
||||||
t.framer.flushWrite()
|
t.framer.flushWrite()
|
||||||
|
@ -485,9 +485,10 @@ func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionErrorf creates an ConnectionError with the specified error description.
|
// ConnectionErrorf creates an ConnectionError with the specified error description.
|
||||||
func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
|
func ConnectionErrorf(temp bool, format string, a ...interface{}) ConnectionError {
|
||||||
return ConnectionError{
|
return ConnectionError{
|
||||||
Desc: fmt.Sprintf(format, a...),
|
Desc: fmt.Sprintf(format, a...),
|
||||||
|
temp: temp,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -495,15 +496,21 @@ func ConnectionErrorf(format string, a ...interface{}) ConnectionError {
|
|||||||
// entire connection and the retry of all the active streams.
|
// entire connection and the retry of all the active streams.
|
||||||
type ConnectionError struct {
|
type ConnectionError struct {
|
||||||
Desc string
|
Desc string
|
||||||
|
temp bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e ConnectionError) Error() string {
|
func (e ConnectionError) Error() string {
|
||||||
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
return fmt.Sprintf("connection error: desc = %q", e.Desc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Temporary indicates if this connection error is temporary or fatal.
|
||||||
|
func (e ConnectionError) Temporary() bool {
|
||||||
|
return e.temp
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrConnClosing indicates that the transport is closing.
|
// ErrConnClosing indicates that the transport is closing.
|
||||||
ErrConnClosing = ConnectionError{Desc: "transport is closing"}
|
ErrConnClosing = ConnectionError{Desc: "transport is closing", temp: true}
|
||||||
// ErrStreamDrain indicates that the stream is rejected by the server because
|
// ErrStreamDrain indicates that the stream is rejected by the server because
|
||||||
// the server stops accepting new RPCs.
|
// the server stops accepting new RPCs.
|
||||||
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
|
ErrStreamDrain = StreamErrorf(codes.Unavailable, "the server stops accepting new RPCs")
|
||||||
|
Reference in New Issue
Block a user