internal: fix client send preface problems (#2380)
internal: fix client send preface problems This CL fixes three problems: - In clientconn_state_transitions_test.go, sometimes tests would flake because there's not enough buffer to send client side settings, causing the connection to unpredictably enter TRANSIENT FAILURE. Each time we set up a server to send SETTINGS, we should also set up the server to read. This allows the client to successfully send its SETTINGS, unflaking the test. - In clientconn.go, we incorrectly transitioned into TRANSIENT FAILURE when creating an http2client returned an error. This should be handled in the outer resetTransport main reset loop. The reason this became a problem is that the outer resetTransport has very specific conditions around when to transition into TRANSIENT FAILURE that the egregious transition did not have. So, it could transition into TRANSIENT FAILURE after failing to dial, even if it was trying to connect to a non-final address in the list of addresses. - In clientconn.go, we incorrectly stay in CONNECTING after `createTransport` when a server sends its connection preface but the client is not able to send its connection preface. This CL causes the addrconn to correctly enter TRANSIENT FAILURE when `createTransport` fails, even if a server preface was received. It does so by making ac.successfulHandshake to consider both server preface received as well as client preface sent.
This commit is contained in:
@ -1059,6 +1059,10 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
|
|||||||
prefaceReceived := make(chan struct{})
|
prefaceReceived := make(chan struct{})
|
||||||
onCloseCalled := make(chan struct{})
|
onCloseCalled := make(chan struct{})
|
||||||
|
|
||||||
|
var prefaceMu sync.Mutex
|
||||||
|
var serverPrefaceReceived bool
|
||||||
|
var clientPrefaceWrote bool
|
||||||
|
|
||||||
onGoAway := func(r transport.GoAwayReason) {
|
onGoAway := func(r transport.GoAwayReason) {
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
ac.adjustParams(r)
|
ac.adjustParams(r)
|
||||||
@ -1100,11 +1104,18 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
|
|||||||
|
|
||||||
// TODO(deklerk): optimization; does anyone else actually use this lock? maybe we can just remove it for this scope
|
// TODO(deklerk): optimization; does anyone else actually use this lock? maybe we can just remove it for this scope
|
||||||
ac.mu.Lock()
|
ac.mu.Lock()
|
||||||
ac.successfulHandshake = true
|
|
||||||
ac.backoffDeadline = time.Time{}
|
prefaceMu.Lock()
|
||||||
ac.connectDeadline = time.Time{}
|
serverPrefaceReceived = true
|
||||||
ac.addrIdx = 0
|
if clientPrefaceWrote {
|
||||||
ac.backoffIdx = 0
|
ac.successfulHandshake = true
|
||||||
|
ac.backoffDeadline = time.Time{}
|
||||||
|
ac.connectDeadline = time.Time{}
|
||||||
|
ac.addrIdx = 0
|
||||||
|
ac.backoffIdx = 0
|
||||||
|
}
|
||||||
|
prefaceMu.Unlock()
|
||||||
|
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1117,6 +1128,13 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
|
|||||||
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt, onGoAway, onClose)
|
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt, onGoAway, onClose)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
prefaceMu.Lock()
|
||||||
|
clientPrefaceWrote = true
|
||||||
|
if serverPrefaceReceived {
|
||||||
|
ac.successfulHandshake = true
|
||||||
|
}
|
||||||
|
prefaceMu.Unlock()
|
||||||
|
|
||||||
if ac.dopts.waitForHandshake {
|
if ac.dopts.waitForHandshake {
|
||||||
select {
|
select {
|
||||||
case <-prefaceTimer.C:
|
case <-prefaceTimer.C:
|
||||||
@ -1160,8 +1178,6 @@ func (ac *addrConn) createTransport(backoffNum int, addr resolver.Address, copts
|
|||||||
|
|
||||||
return errConnClosing
|
return errConnClosing
|
||||||
}
|
}
|
||||||
ac.updateConnectivityState(connectivity.TransientFailure)
|
|
||||||
ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
|
|
||||||
ac.mu.Unlock()
|
ac.mu.Unlock()
|
||||||
grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v. Err :%v. Reconnecting...", addr, err)
|
grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v. Err :%v. Reconnecting...", addr, err)
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ import (
|
|||||||
"google.golang.org/grpc/balancer"
|
"google.golang.org/grpc/balancer"
|
||||||
"google.golang.org/grpc/connectivity"
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/internal/leakcheck"
|
"google.golang.org/grpc/internal/leakcheck"
|
||||||
|
"google.golang.org/grpc/internal/testutils"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
"google.golang.org/grpc/resolver/manual"
|
"google.golang.org/grpc/resolver/manual"
|
||||||
)
|
)
|
||||||
@ -41,76 +43,135 @@ func init() {
|
|||||||
balancer.Register(testBalancer)
|
balancer.Register(testBalancer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// These tests use a pipeListener. This listener is similar to net.Listener except that it is unbuffered, so each read
|
||||||
|
// and write will wait for the other side's corresponding write or read.
|
||||||
func TestStateTransitions_SingleAddress(t *testing.T) {
|
func TestStateTransitions_SingleAddress(t *testing.T) {
|
||||||
|
defer leakcheck.Check(t)
|
||||||
|
|
||||||
|
mctBkp := getMinConnectTimeout()
|
||||||
|
defer func() {
|
||||||
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
|
||||||
|
}()
|
||||||
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*100)
|
||||||
|
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
name string
|
desc string
|
||||||
want []connectivity.State
|
want []connectivity.State
|
||||||
server func(net.Listener)
|
server func(net.Listener) net.Conn
|
||||||
}{
|
}{
|
||||||
// When the server returns server preface, the client enters READY.
|
|
||||||
{
|
{
|
||||||
name: "ServerEntersReadyOnPrefaceReceipt",
|
desc: "When the server returns server preface, the client enters READY.",
|
||||||
want: []connectivity.State{
|
want: []connectivity.State{
|
||||||
connectivity.Connecting,
|
connectivity.Connecting,
|
||||||
connectivity.Ready,
|
connectivity.Ready,
|
||||||
},
|
},
|
||||||
server: func(lis net.Listener) {
|
server: func(lis net.Listener) net.Conn {
|
||||||
conn, err := lis.Accept()
|
conn, err := lis.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
go keepReading(conn)
|
||||||
|
|
||||||
|
framer := http2.NewFramer(conn, conn)
|
||||||
|
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||||
|
t.Errorf("Error while writing settings frame. %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "When the connection is closed, the client enters TRANSIENT FAILURE.",
|
||||||
|
want: []connectivity.State{
|
||||||
|
connectivity.Connecting,
|
||||||
|
connectivity.TransientFailure,
|
||||||
|
},
|
||||||
|
server: func(lis net.Listener) net.Conn {
|
||||||
|
conn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: `When the server sends its connection preface, but the connection dies before the client can write its
|
||||||
|
connection preface, the client enters TRANSIENT FAILURE.`,
|
||||||
|
want: []connectivity.State{
|
||||||
|
connectivity.Connecting,
|
||||||
|
connectivity.TransientFailure,
|
||||||
|
},
|
||||||
|
server: func(lis net.Listener) net.Conn {
|
||||||
|
conn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
framer := http2.NewFramer(conn, conn)
|
framer := http2.NewFramer(conn, conn)
|
||||||
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||||
t.Errorf("Error while writing settings frame. %v", err)
|
t.Errorf("Error while writing settings frame. %v", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.Close()
|
||||||
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
// When the connection is closed, the client enters TRANSIENT FAILURE.
|
|
||||||
{
|
{
|
||||||
name: "ServerEntersTransientFailureOnClose",
|
desc: `When the server reads the client connection preface but does not send its connection preface, the
|
||||||
|
client enters TRANSIENT FAILURE.`,
|
||||||
want: []connectivity.State{
|
want: []connectivity.State{
|
||||||
connectivity.Connecting,
|
connectivity.Connecting,
|
||||||
connectivity.TransientFailure,
|
connectivity.TransientFailure,
|
||||||
},
|
},
|
||||||
server: func(lis net.Listener) {
|
server: func(lis net.Listener) net.Conn {
|
||||||
conn, err := lis.Accept()
|
conn, err := lis.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.Close()
|
go keepReading(conn)
|
||||||
|
|
||||||
|
return conn
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
t.Logf("Test %s", test.name)
|
t.Log(test.desc)
|
||||||
testStateTransitionSingleAddress(t, test.want, test.server)
|
testStateTransitionSingleAddress(t, test.want, test.server)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener)) {
|
func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, server func(net.Listener) net.Conn) {
|
||||||
defer leakcheck.Check(t)
|
defer leakcheck.Check(t)
|
||||||
|
|
||||||
stateNotifications := make(chan connectivity.State, len(want))
|
stateNotifications := make(chan connectivity.State, len(want))
|
||||||
testBalancer.ResetNotifier(stateNotifications)
|
testBalancer.ResetNotifier(stateNotifications)
|
||||||
defer close(stateNotifications)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", "localhost:0")
|
pl := testutils.NewPipeListener()
|
||||||
if err != nil {
|
defer pl.Close()
|
||||||
t.Fatalf("Error while listening. Err: %v", err)
|
|
||||||
}
|
|
||||||
defer lis.Close()
|
|
||||||
|
|
||||||
// Launch the server.
|
// Launch the server.
|
||||||
go server(lis)
|
var conn net.Conn
|
||||||
|
var connMu sync.Mutex
|
||||||
|
go func() {
|
||||||
|
connMu.Lock()
|
||||||
|
conn = server(pl)
|
||||||
|
connMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
client, err := DialContext(ctx, lis.Addr().String(), WithWaitForHandshake(), WithInsecure(), WithBalancerName(stateRecordingBalancerName))
|
client, err := DialContext(ctx, "", WithWaitForHandshake(), WithInsecure(),
|
||||||
|
WithBalancerName(stateRecordingBalancerName), WithDialer(pl.Dialer()), withBackoff(noBackoff{}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -128,6 +189,15 @@ func testStateTransitionSingleAddress(t *testing.T, want []connectivity.State, s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
connMu.Lock()
|
||||||
|
defer connMu.Unlock()
|
||||||
|
if conn != nil {
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING.
|
// When a READY connection is closed, the client enters TRANSIENT FAILURE before CONNECTING.
|
||||||
@ -143,7 +213,6 @@ func TestStateTransition_ReadyToTransientFailure(t *testing.T) {
|
|||||||
|
|
||||||
stateNotifications := make(chan connectivity.State, len(want))
|
stateNotifications := make(chan connectivity.State, len(want))
|
||||||
testBalancer.ResetNotifier(stateNotifications)
|
testBalancer.ResetNotifier(stateNotifications)
|
||||||
defer close(stateNotifications)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -164,6 +233,8 @@ func TestStateTransition_ReadyToTransientFailure(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go keepReading(conn)
|
||||||
|
|
||||||
framer := http2.NewFramer(conn, conn)
|
framer := http2.NewFramer(conn, conn)
|
||||||
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||||
t.Errorf("Error while writing settings frame. %v", err)
|
t.Errorf("Error while writing settings frame. %v", err)
|
||||||
@ -211,7 +282,6 @@ func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {
|
|||||||
|
|
||||||
stateNotifications := make(chan connectivity.State, len(want))
|
stateNotifications := make(chan connectivity.State, len(want))
|
||||||
testBalancer.ResetNotifier(stateNotifications)
|
testBalancer.ResetNotifier(stateNotifications)
|
||||||
defer close(stateNotifications)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -250,11 +320,14 @@ func TestStateTransitions_TriesAllAddrsBeforeTransientFailure(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go keepReading(conn)
|
||||||
|
|
||||||
framer := http2.NewFramer(conn, conn)
|
framer := http2.NewFramer(conn, conn)
|
||||||
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||||
t.Errorf("Error while writing settings frame. %v", err)
|
t.Errorf("Error while writing settings frame. %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
close(server2Done)
|
close(server2Done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -307,7 +380,6 @@ func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
|
|||||||
|
|
||||||
stateNotifications := make(chan connectivity.State, len(want))
|
stateNotifications := make(chan connectivity.State, len(want))
|
||||||
testBalancer.ResetNotifier(stateNotifications)
|
testBalancer.ResetNotifier(stateNotifications)
|
||||||
defer close(stateNotifications)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -336,6 +408,8 @@ func TestStateTransitions_MultipleAddrsEntersReady(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go keepReading(conn)
|
||||||
|
|
||||||
framer := http2.NewFramer(conn, conn)
|
framer := http2.NewFramer(conn, conn)
|
||||||
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
if err := framer.WriteSettings(http2.Setting{}); err != nil {
|
||||||
t.Errorf("Error while writing settings frame. %v", err)
|
t.Errorf("Error while writing settings frame. %v", err)
|
||||||
@ -426,3 +500,16 @@ func (b *stateRecordingBalancer) Build(cc balancer.ClientConn, opts balancer.Bui
|
|||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type noBackoff struct{}
|
||||||
|
|
||||||
|
func (b noBackoff) Backoff(int) time.Duration { return time.Duration(0) }
|
||||||
|
|
||||||
|
// Keep reading until something causes the connection to die (EOF, server closed, etc). Useful
|
||||||
|
// as a tool for mindlessly keeping the connection healthy, since the client will error if
|
||||||
|
// things like client prefaces are not accepted in a timely fashion.
|
||||||
|
func keepReading(conn net.Conn) {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for _, err := conn.Read(buf); err == nil; _, err = conn.Read(buf) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -225,10 +225,8 @@ func TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
|
|||||||
// 3. The new server sends its preface.
|
// 3. The new server sends its preface.
|
||||||
// 4. Client doesn't kill the connection this time.
|
// 4. Client doesn't kill the connection this time.
|
||||||
mctBkp := getMinConnectTimeout()
|
mctBkp := getMinConnectTimeout()
|
||||||
// Call this only after transportMonitor goroutine has ended.
|
|
||||||
defer func() {
|
defer func() {
|
||||||
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(mctBkp))
|
||||||
|
|
||||||
}()
|
}()
|
||||||
defer leakcheck.Check(t)
|
defer leakcheck.Check(t)
|
||||||
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500)
|
atomic.StoreInt64((*int64)(&mutableMinConnectTimeout), int64(time.Millisecond)*500)
|
||||||
|
95
internal/testutils/pipe_listener.go
Normal file
95
internal/testutils/pipe_listener.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2018 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package testutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errClosed = errors.New("closed")
|
||||||
|
|
||||||
|
type pipeAddr struct{}
|
||||||
|
|
||||||
|
func (p pipeAddr) Network() string { return "pipe" }
|
||||||
|
func (p pipeAddr) String() string { return "pipe" }
|
||||||
|
|
||||||
|
// PipeListener is a listener with an unbuffered pipe. Each write will complete only once the other side reads. It
|
||||||
|
// should only be created using NewPipeListener.
|
||||||
|
type PipeListener struct {
|
||||||
|
c chan chan<- net.Conn
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPipeListener creates a new pipe listener.
|
||||||
|
func NewPipeListener() *PipeListener {
|
||||||
|
return &PipeListener{
|
||||||
|
c: make(chan chan<- net.Conn),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept accepts a connection.
|
||||||
|
func (p *PipeListener) Accept() (net.Conn, error) {
|
||||||
|
var connChan chan<- net.Conn
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return nil, errClosed
|
||||||
|
case connChan = <-p.c:
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
close(connChan)
|
||||||
|
return nil, errClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c1, c2 := net.Pipe()
|
||||||
|
connChan <- c1
|
||||||
|
close(connChan)
|
||||||
|
return c2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the listener.
|
||||||
|
func (p *PipeListener) Close() error {
|
||||||
|
close(p.done)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns a pipe addr.
|
||||||
|
func (p *PipeListener) Addr() net.Addr {
|
||||||
|
return pipeAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dialer dials a connection.
|
||||||
|
func (p *PipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
|
||||||
|
return func(string, time.Duration) (net.Conn, error) {
|
||||||
|
connChan := make(chan net.Conn)
|
||||||
|
select {
|
||||||
|
case p.c <- connChan:
|
||||||
|
case <-p.done:
|
||||||
|
return nil, errClosed
|
||||||
|
}
|
||||||
|
conn, ok := <-connChan
|
||||||
|
if !ok {
|
||||||
|
return nil, errClosed
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
163
internal/testutils/pipe_listener_test.go
Normal file
163
internal/testutils/pipe_listener_test.go
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2018 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package testutils_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/internal/testutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPipeListener(t *testing.T) {
|
||||||
|
pl := testutils.NewPipeListener()
|
||||||
|
recvdBytes := make(chan []byte)
|
||||||
|
const want = "hello world"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
c, err := pl.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
read := make([]byte, len(want))
|
||||||
|
_, err = c.Read(read)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
recvdBytes <- read
|
||||||
|
}()
|
||||||
|
|
||||||
|
dl := pl.Dialer()
|
||||||
|
conn, err := dl("", time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte(want))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case gotBytes := <-recvdBytes:
|
||||||
|
got := string(gotBytes)
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("expected to get %s, got %s", got, want)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("timed out waiting for server to receive bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnblocking(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
desc string
|
||||||
|
blockFuncShouldError bool
|
||||||
|
blockFunc func(*testutils.PipeListener, chan struct{}) error
|
||||||
|
unblockFunc func(*testutils.PipeListener) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Accept unblocks Dial",
|
||||||
|
blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error {
|
||||||
|
dl := pl.Dialer()
|
||||||
|
_, err := dl("", time.Duration(0))
|
||||||
|
close(done)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
unblockFunc: func(pl *testutils.PipeListener) error {
|
||||||
|
_, err := pl.Accept()
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Close unblocks Dial",
|
||||||
|
blockFuncShouldError: true, // because pl.Close will be called
|
||||||
|
blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error {
|
||||||
|
dl := pl.Dialer()
|
||||||
|
_, err := dl("", time.Duration(0))
|
||||||
|
close(done)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
unblockFunc: func(pl *testutils.PipeListener) error {
|
||||||
|
return pl.Close()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Dial unblocks Accept",
|
||||||
|
blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error {
|
||||||
|
_, err := pl.Accept()
|
||||||
|
close(done)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
unblockFunc: func(pl *testutils.PipeListener) error {
|
||||||
|
dl := pl.Dialer()
|
||||||
|
_, err := dl("", time.Duration(0))
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Close unblocks Accept",
|
||||||
|
blockFuncShouldError: true, // because pl.Close will be called
|
||||||
|
blockFunc: func(pl *testutils.PipeListener, done chan struct{}) error {
|
||||||
|
_, err := pl.Accept()
|
||||||
|
close(done)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
unblockFunc: func(pl *testutils.PipeListener) error {
|
||||||
|
return pl.Close()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Log(test.desc)
|
||||||
|
testUnblocking(t, test.blockFunc, test.unblockFunc, test.blockFuncShouldError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUnblocking(t *testing.T, blockFunc func(*testutils.PipeListener, chan struct{}) error, unblockFunc func(*testutils.PipeListener) error, blockFuncShouldError bool) {
|
||||||
|
pl := testutils.NewPipeListener()
|
||||||
|
dialFinished := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := blockFunc(pl, dialFinished)
|
||||||
|
if blockFuncShouldError && err == nil {
|
||||||
|
t.Error("expected blocking func to return error because pl.Close was called, but got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !blockFuncShouldError && err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-dialFinished:
|
||||||
|
t.Fatal("expected Dial to block until pl.Close or pl.Accept")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unblockFunc(pl); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-dialFinished:
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatal("expected Accept to unblock after pl.Accept was called")
|
||||||
|
}
|
||||||
|
}
|
@ -58,6 +58,7 @@ import (
|
|||||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||||
"google.golang.org/grpc/internal/channelz"
|
"google.golang.org/grpc/internal/channelz"
|
||||||
"google.golang.org/grpc/internal/leakcheck"
|
"google.golang.org/grpc/internal/leakcheck"
|
||||||
|
"google.golang.org/grpc/internal/testutils"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
@ -6926,49 +6927,11 @@ func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type pipeAddr struct{}
|
|
||||||
|
|
||||||
func (p pipeAddr) Network() string { return "pipe" }
|
|
||||||
func (p pipeAddr) String() string { return "pipe" }
|
|
||||||
|
|
||||||
type pipeListener struct {
|
|
||||||
c chan chan<- net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pipeListener) Accept() (net.Conn, error) {
|
|
||||||
connChan, ok := <-p.c
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("closed")
|
|
||||||
}
|
|
||||||
c1, c2 := net.Pipe()
|
|
||||||
connChan <- c1
|
|
||||||
close(connChan)
|
|
||||||
return c2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pipeListener) Close() error {
|
|
||||||
close(p.c)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pipeListener) Addr() net.Addr {
|
|
||||||
return pipeAddr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pipeListener) Dialer() func(string, time.Duration) (net.Conn, error) {
|
|
||||||
return func(string, time.Duration) (net.Conn, error) {
|
|
||||||
connChan := make(chan net.Conn)
|
|
||||||
p.c <- connChan
|
|
||||||
conn := <-connChan
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetPipeConn(t *testing.T) {
|
func TestNetPipeConn(t *testing.T) {
|
||||||
// This test will block indefinitely if grpc writes both client and server
|
// This test will block indefinitely if grpc writes both client and server
|
||||||
// prefaces without either reading from the Conn.
|
// prefaces without either reading from the Conn.
|
||||||
defer leakcheck.Check(t)
|
defer leakcheck.Check(t)
|
||||||
pl := &pipeListener{c: make(chan chan<- net.Conn)}
|
pl := testutils.NewPipeListener()
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
defer s.Stop()
|
defer s.Stop()
|
||||||
ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
ts := &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||||
|
Reference in New Issue
Block a user