diff --git a/clientconn.go b/clientconn.go index a3170320..34af35e6 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1059,6 +1059,10 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts prefaceReceived := make(chan struct{}) onCloseCalled := make(chan struct{}) + var prefaceMu sync.Mutex + var serverPrefaceReceived bool + var clientPrefaceWrote bool + onGoAway := func(r transport.GoAwayReason) { ac.mu.Lock() ac.adjustParams(r) @@ -1100,11 +1104,18 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts // TODO(deklerk): optimization; does anyone else actually use this lock? maybe we can just remove it for this scope ac.mu.Lock() - ac.successfulHandshake = true - ac.backoffDeadline = time.Time{} - ac.connectDeadline = time.Time{} - ac.addrIdx = 0 - ac.backoffIdx = 0 + + prefaceMu.Lock() + serverPrefaceReceived = true + if clientPrefaceWrote { + ac.successfulHandshake = true + ac.backoffDeadline = time.Time{} + ac.connectDeadline = time.Time{} + ac.addrIdx = 0 + ac.backoffIdx = 0 + } + prefaceMu.Unlock() + ac.mu.Unlock() } @@ -1117,6 +1128,13 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt, onGoAway, onClose) if err == nil { + prefaceMu.Lock() + clientPrefaceWrote = true + if serverPrefaceReceived { + ac.successfulHandshake = true + } + prefaceMu.Unlock() + if ac.dopts.waitForHandshake { select { case <-prefaceTimer.C: @@ -1160,8 +1178,6 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts return errConnClosing } - ac.updateConnectivityState(connectivity.TransientFailure) - ac.cc.handleSubConnStateChange(ac.acbw, ac.state) ac.mu.Unlock() grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v. Err :%v. Reconnecting...", addr, err) diff --git a/clientconn_state_transition_test.go b/clientconn_state_transition_test.go index e16d606f..b480af9e 100644 --- a/clientconn_state_transition_test.go +++ b/clientconn_state_transition_test.go @@ -21,6 +21,7 @@ package grpc import ( "net" "sync" + "sync/atomic" "testing" "time" @@ -29,6 +30,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/leakcheck" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" ) @@ -41,76 +43,135 @@ func init() { balancer.Register(testBalancer) } +// These tests use a pipeListener. This listener is similar to net.Listener except that it is unbuffered, so each read +// and write will wait for the other side's corresponding write or read. func TestStateTransitions_SingleAddress(t *testing.T) { + defer leakcheck.Check(t) + + mctBkp := getMinConnectTimeout() + defer func() { + atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp)) + }() + atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*100) + for _, test := range []struct { - name string + desc string want []connectivity.State - server func(net.Listener) + server func(net.Listener) net.Conn }{ - // When the server returns server preface, the client enters READY. { - name: "ServerEntersReadyOnPrefaceReceipt", + desc: "When the server returns server preface, the client enters READY.", want: []connectivity.State{ connectivity.Connecting, connectivity.Ready, }, - server: func(lis net.Listener) { + server: func(lis net.Listener) net.Conn { conn, err := lis.Accept() if err != nil { t.Error(err) - return + return nil + } + + go keepReading(conn) + + framer := http2.NewFramer(conn, conn) + if err := framer.WriteSettings(http2.Setting{}); err != nil { + t.Errorf("Error while writing settings frame. %v", err) + return nil + } + + return conn + }, + }, + { + desc: "When the connection is closed, the client enters TRANSIENT FAILURE.", + want: []connectivity.State{ + connectivity.Connecting, + connectivity.TransientFailure, + }, + server: func(lis net.Listener) net.Conn { + conn, err := lis.Accept() + if err != nil { + t.Error(err) + return nil + } + + conn.Close() + return nil + }, + }, + { + desc: `When the server sends its connection preface, but the connection dies before the client can write its +connection preface, the client enters TRANSIENT FAILURE.`, + want: []connectivity.State{ + connectivity.Connecting, + connectivity.TransientFailure, + }, + server: func(lis net.Listener) net.Conn { + conn, err := lis.Accept() + if err != nil { + t.Error(err) + return nil } framer := http2.NewFramer(conn, conn) if err := framer.WriteSettings(http2.Setting{}); err != nil { t.Errorf("Error while writing settings frame. %v", err) - return + return nil } + + conn.Close() + return nil }, }, - // When the connection is closed, the client enters TRANSIENT FAILURE. { - name: "ServerEntersTransientFailureOnClose", + desc: `When the server reads the client connection preface but does not send its connection preface, the +client enters TRANSIENT FAILURE.`, want: []connectivity.State{ connectivity.Connecting, connectivity.TransientFailure, }, - server: func(lis net.Listener) { + server: func(lis net.Listener) net.Conn { conn, err := lis.Accept() if err != nil { t.Error(err) - return + return nil } - conn.Close() + go keepReading(conn) + + return conn }, }, } { - t.Logf("Test %s", test.name) + t.Log(test.desc) testStateTransitionSingleAddress(t, test.want, test.server) } } -func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener)) { +func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener) net.Conn) { defer leakcheck.Check(t) stateNotifications := make(chan connectivity.State, len(want)) testBalancer.ResetNotifier(stateNotifications) - defer close(stateNotifications) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("Error while listening. Err: %v", err) - } - defer lis.Close() + pl := testutils.NewPipeListener() + defer pl.Close() // Launch the server. - go server(lis) + var conn net.Conn + var connMu sync.Mutex + go func() { + connMu.Lock() + conn = server(pl) + connMu.Unlock() + }() - client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName)) + client, err := DialContext(ctx, "", WithWaitForHandshake(), WithInsecure(), + WithBalancerName(stateRecordingBalancerName), WithDialer(pl.Dialer()), withBackoff(noBackoff{})) if err != nil { t.Fatal(err) } @@ -128,6 +189,15 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s } } } + + connMu.Lock() + defer connMu.Unlock() + if conn != nil { + err = conn.Close() + if err != nil { + t.Fatal(err) + } + } } // When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING. @@ -143,7 +213,6 @@ func TestStateTransition_ReadyToTransientFailure(t *testing.T) { stateNotifications := make(chan connectivity.State, len(want)) testBalancer.ResetNotifier(stateNotifications) - defer close(stateNotifications) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -164,6 +233,8 @@ func TestStateTransition_ReadyToTransientFailure(t *testing.T) { return } + go keepReading(conn) + framer := http2.NewFramer(conn, conn) if err := framer.WriteSettings(http2.Setting{}); err != nil { t.Errorf("Error while writing settings frame. %v", err) @@ -211,7 +282,6 @@ func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) { stateNotifications := make(chan connectivity.State, len(want)) testBalancer.ResetNotifier(stateNotifications) - defer close(stateNotifications) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -250,11 +320,14 @@ func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) { return } + go keepReading(conn) + framer := http2.NewFramer(conn, conn) if err := framer.WriteSettings(http2.Setting{}); err != nil { t.Errorf("Error while writing settings frame. %v", err) return } + close(server2Done) }() @@ -307,7 +380,6 @@ func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { stateNotifications := make(chan connectivity.State, len(want)) testBalancer.ResetNotifier(stateNotifications) - defer close(stateNotifications) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -336,6 +408,8 @@ func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) { return } + go keepReading(conn) + framer := http2.NewFramer(conn, conn) if err := framer.WriteSettings(http2.Setting{}); err != nil { t.Errorf("Error while writing settings frame. %v", err) @@ -426,3 +500,16 @@ func (b *stateRecordingBalancer) Build(cc balancer.ClientConn, opts balancer.Bui b.mu.Unlock() return b } + +type noBackoff struct{} + +func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) } + +// Keep reading until something causes the connection to die (EOF, server closed, etc). Useful +// as a tool for mindlessly keeping the connection healthy, since the client will error if +// things like client prefaces are not accepted in a timely fashion. +func keepReading(conn net.Conn) { + buf := make([]byte, 1024) + for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) { + } +} diff --git a/clientconn_test.go b/clientconn_test.go index c856897f..56707b39 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -225,10 +225,8 @@ func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) { // 3. The new server sends its preface. // 4. Client doesn't kill the connection this time. mctBkp := getMinConnectTimeout() - // Call this only after transportMonitor goroutine has ended. defer func() { atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp)) - }() defer leakcheck.Check(t) atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500) diff --git a/internal/testutils/pipe_listener.go b/internal/testutils/pipe_listener.go new file mode 100644 index 00000000..77def117 --- /dev/null +++ b/internal/testutils/pipe_listener.go @@ -0,0 +1,95 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils + +import ( + "errors" + "net" + "time" +) + +var errClosed = errors.New("closed") + +type pipeAddr struct{} + +func (p pipeAddr) Network() string { return "pipe" } +func (p pipeAddr) String() string { return "pipe" } + +// PipeListener is a listener with an unbuffered pipe. Each write will complete only once the other side reads. It +// should only be created using NewPipeListener. +type PipeListener struct { + c chan chan<- net.Conn + done chan struct{} +} + +// NewPipeListener creates a new pipe listener. +func NewPipeListener() *PipeListener { + return &PipeListener{ + c: make(chan chan<- net.Conn), + done: make(chan struct{}), + } +} + +// Accept accepts a connection. +func (p *PipeListener) Accept() (net.Conn, error) { + var connChan chan<- net.Conn + select { + case <-p.done: + return nil, errClosed + case connChan = <-p.c: + select { + case <-p.done: + close(connChan) + return nil, errClosed + default: + } + } + c1, c2 := net.Pipe() + connChan <- c1 + close(connChan) + return c2, nil +} + +// Close closes the listener. +func (p *PipeListener) Close() error { + close(p.done) + return nil +} + +// Addr returns a pipe addr. +func (p *PipeListener) Addr() net.Addr { + return pipeAddr{} +} + +// Dialer dials a connection. +func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) { + return func(string, time.Duration) (net.Conn, error) { + connChan := make(chan net.Conn) + select { + case p.c <- connChan: + case <-p.done: + return nil, errClosed + } + conn, ok := <-connChan + if !ok { + return nil, errClosed + } + return conn, nil + } +} diff --git a/internal/testutils/pipe_listener_test.go b/internal/testutils/pipe_listener_test.go new file mode 100644 index 00000000..9bd399cb --- /dev/null +++ b/internal/testutils/pipe_listener_test.go @@ -0,0 +1,163 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package testutils_test + +import ( + "testing" + "time" + + "google.golang.org/grpc/internal/testutils" +) + +func TestPipeListener(t *testing.T) { + pl := testutils.NewPipeListener() + recvdBytes := make(chan []byte) + const want = "hello world" + + go func() { + c, err := pl.Accept() + if err != nil { + t.Error(err) + } + + read := make([]byte, len(want)) + _, err = c.Read(read) + if err != nil { + t.Error(err) + } + recvdBytes <- read + }() + + dl := pl.Dialer() + conn, err := dl("", time.Duration(0)) + if err != nil { + t.Fatal(err) + } + + _, err = conn.Write([]byte(want)) + if err != nil { + t.Fatal(err) + } + + select { + case gotBytes := <-recvdBytes: + got := string(gotBytes) + if got != want { + t.Fatalf("expected to get %s, got %s", got, want) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for server to receive bytes") + } +} + +func TestUnblocking(t *testing.T) { + for _, test := range []struct { + desc string + blockFuncShouldError bool + blockFunc func(*testutils.PipeListener, chan struct{}) error + unblockFunc func(*testutils.PipeListener) error + }{ + { + desc: "Accept unblocks Dial", + blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { + dl := pl.Dialer() + _, err := dl("", time.Duration(0)) + close(done) + return err + }, + unblockFunc: func(pl *testutils.PipeListener) error { + _, err := pl.Accept() + return err + }, + }, + { + desc: "Close unblocks Dial", + blockFuncShouldError: true, // because pl.Close will be called + blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { + dl := pl.Dialer() + _, err := dl("", time.Duration(0)) + close(done) + return err + }, + unblockFunc: func(pl *testutils.PipeListener) error { + return pl.Close() + }, + }, + { + desc: "Dial unblocks Accept", + blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { + _, err := pl.Accept() + close(done) + return err + }, + unblockFunc: func(pl *testutils.PipeListener) error { + dl := pl.Dialer() + _, err := dl("", time.Duration(0)) + return err + }, + }, + { + desc: "Close unblocks Accept", + blockFuncShouldError: true, // because pl.Close will be called + blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error { + _, err := pl.Accept() + close(done) + return err + }, + unblockFunc: func(pl *testutils.PipeListener) error { + return pl.Close() + }, + }, + } { + t.Log(test.desc) + testUnblocking(t, test.blockFunc, test.unblockFunc, test.blockFuncShouldError) + } +} + +func testUnblocking(t *testing.T, blockFunc func(*testutils.PipeListener, chan struct{}) error, unblockFunc func(*testutils.PipeListener) error, blockFuncShouldError bool) { + pl := testutils.NewPipeListener() + dialFinished := make(chan struct{}) + + go func() { + err := blockFunc(pl, dialFinished) + if blockFuncShouldError && err == nil { + t.Error("expected blocking func to return error because pl.Close was called, but got nil") + } + + if !blockFuncShouldError && err != nil { + t.Error(err) + } + }() + + select { + case <-dialFinished: + t.Fatal("expected Dial to block until pl.Close or pl.Accept") + default: + } + + if err := unblockFunc(pl); err != nil { + t.Fatal(err) + } + + select { + case <-dialFinished: + case <-time.After(100 * time.Millisecond): + t.Fatal("expected Accept to unblock after pl.Accept was called") + } +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 11b3bc8b..cbbeca3c 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -58,6 +58,7 @@ import ( healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/leakcheck" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -6926,49 +6927,11 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env) } } -type pipeAddr struct{} - -func (p pipeAddr) Network() string { return "pipe" } -func (p pipeAddr) String() string { return "pipe" } - -type pipeListener struct { - c chan chan<- net.Conn -} - -func (p *pipeListener) Accept() (net.Conn, error) { - connChan, ok := <-p.c - if !ok { - return nil, errors.New("closed") - } - c1, c2 := net.Pipe() - connChan <- c1 - close(connChan) - return c2, nil -} - -func (p *pipeListener) Close() error { - close(p.c) - return nil -} - -func (p *pipeListener) Addr() net.Addr { - return pipeAddr{} -} - -func (p *pipeListener) Dialer() func(string, time.Duration) (net.Conn, error) { - return func(string, time.Duration) (net.Conn, error) { - connChan := make(chan net.Conn) - p.c <- connChan - conn := <-connChan - return conn, nil - } -} - func TestNetPipeConn(t *testing.T) { // This test will block indefinitely if grpc writes both client and server // prefaces without either reading from the Conn. defer leakcheck.Check(t) - pl := &pipeListener{c: make(chan chan<- net.Conn)} + pl := testutils.NewPipeListener() s := grpc.NewServer() defer s.Stop() ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {