Merge pull request #128 from iamqizhao/tp-test-fix

deflak a transport test
This commit is contained in:
Qi Zhao
2015-03-22 14:02:24 -07:00

View File

@ -97,7 +97,8 @@ func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
<-s.ctx.Done() <-s.ctx.Done()
} }
func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) { // start starts server. Other goroutines should block on s.readyChan for futher operations.
func (s *server) start(useTLS bool, port int, maxStreams uint32, suspend bool) {
var err error var err error
if port == 0 { if port == 0 {
s.lis, err = net.Listen("tcp", ":0") s.lis, err = net.Listen("tcp", ":0")
@ -119,10 +120,10 @@ func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) {
log.Fatalf("failed to parse listener address: %v", err) log.Fatalf("failed to parse listener address: %v", err)
} }
s.port = p s.port = p
s.conns = make(map[ServerTransport]bool)
if s.readyChan != nil { if s.readyChan != nil {
close(s.readyChan) close(s.readyChan)
} }
s.conns = make(map[ServerTransport]bool)
for { for {
conn, err := s.lis.Accept() conn, err := s.lis.Accept()
if err != nil { if err != nil {
@ -149,7 +150,7 @@ func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) {
} }
} }
func (s *server) Wait(t *testing.T, timeout time.Duration) { func (s *server) wait(t *testing.T, timeout time.Duration) {
select { select {
case <-s.readyChan: case <-s.readyChan:
case <-time.After(timeout): case <-time.After(timeout):
@ -157,20 +158,20 @@ func (s *server) Wait(t *testing.T, timeout time.Duration) {
} }
} }
func (s *server) Close() { func (s *server) stop() {
// Keep consistent with closeServer().
s.lis.Close() s.lis.Close()
s.mu.Lock() s.mu.Lock()
for c := range s.conns { for c := range s.conns {
c.Close() c.Close()
} }
s.conns = nil
s.mu.Unlock() s.mu.Unlock()
} }
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) (*server, ClientTransport) { func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) (*server, ClientTransport) {
server := &server{readyChan: make(chan bool)} server := &server{readyChan: make(chan bool)}
go server.Start(useTLS, port, maxStreams, suspend) go server.start(useTLS, port, maxStreams, suspend)
server.Wait(t, 2*time.Second) server.wait(t, 2*time.Second)
addr := "localhost:" + server.port addr := "localhost:" + server.port
var ( var (
ct ClientTransport ct ClientTransport
@ -231,24 +232,12 @@ func TestClientSendAndReceive(t *testing.T) {
t.Fatalf("Error: %v; want <EOF>", recvErr) t.Fatalf("Error: %v; want <EOF>", recvErr)
} }
ct.Close() ct.Close()
server.Close() server.stop()
} }
func TestClientErrorNotify(t *testing.T) { func TestClientErrorNotify(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, false) server, ct := setUp(t, true, 0, math.MaxUint32, false)
callHdr := &CallHdr{ go server.stop()
Host: "localhost",
Method: "foo.Small",
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("failed to open stream: %v", err)
}
if s.id != 1 {
t.Fatalf("wrong stream id: %d", s.id)
}
// Tear down the server.
go server.Close()
// ct.reader should detect the error and activate ct.Error(). // ct.reader should detect the error and activate ct.Error().
<-ct.Error() <-ct.Error()
ct.Close() ct.Close()
@ -284,7 +273,7 @@ func TestClientMix(t *testing.T) {
s, ct := setUp(t, true, 0, math.MaxUint32, false) s, ct := setUp(t, true, 0, math.MaxUint32, false)
go func(s *server) { go func(s *server) {
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
s.Close() s.stop()
}(s) }(s)
go func(ct ClientTransport) { go func(ct ClientTransport) {
<-ct.Error() <-ct.Error()
@ -300,7 +289,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
server, ct := setUp(t, true, 0, 1, false) server, ct := setUp(t, true, 0, 1, false)
defer func() { defer func() {
ct.Close() ct.Close()
server.Close() server.stop()
}() }()
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
@ -375,7 +364,7 @@ func TestLargeMessage(t *testing.T) {
} }
wg.Wait() wg.Wait()
ct.Close() ct.Close()
server.Close() server.stop()
} }
func TestLargeMessageSuspension(t *testing.T) { func TestLargeMessageSuspension(t *testing.T) {
@ -397,7 +386,7 @@ func TestLargeMessageSuspension(t *testing.T) {
t.Fatalf("Write got %v, want %v", err, expectedErr) t.Fatalf("Write got %v, want %v", err, expectedErr)
} }
ct.Close() ct.Close()
server.Close() server.stop()
} }
func TestStreamContext(t *testing.T) { func TestStreamContext(t *testing.T) {