diff --git a/transport/transport_test.go b/transport/transport_test.go index 1a284072..39792e4a 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -97,7 +97,8 @@ func (h *testStreamHandler) handleStreamSuspension(s *Stream) { <-s.ctx.Done() } -func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) { +// start starts server. Other goroutines should block on s.readyChan for futher operations. +func (s *server) start(useTLS bool, port int, maxStreams uint32, suspend bool) { var err error if port == 0 { s.lis, err = net.Listen("tcp", ":0") @@ -119,10 +120,10 @@ func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) { log.Fatalf("failed to parse listener address: %v", err) } s.port = p + s.conns = make(map[ServerTransport]bool) if s.readyChan != nil { close(s.readyChan) } - s.conns = make(map[ServerTransport]bool) for { conn, err := s.lis.Accept() if err != nil { @@ -149,7 +150,7 @@ func (s *server) Start(useTLS bool, port int, maxStreams uint32, suspend bool) { } } -func (s *server) Wait(t *testing.T, timeout time.Duration) { +func (s *server) wait(t *testing.T, timeout time.Duration) { select { case <-s.readyChan: case <-time.After(timeout): @@ -157,20 +158,20 @@ func (s *server) Wait(t *testing.T, timeout time.Duration) { } } -func (s *server) Close() { - // Keep consistent with closeServer(). +func (s *server) stop() { s.lis.Close() s.mu.Lock() for c := range s.conns { c.Close() } + s.conns = nil s.mu.Unlock() } func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) (*server, ClientTransport) { server := &server{readyChan: make(chan bool)} - go server.Start(useTLS, port, maxStreams, suspend) - server.Wait(t, 2*time.Second) + go server.start(useTLS, port, maxStreams, suspend) + server.wait(t, 2*time.Second) addr := "localhost:" + server.port var ( ct ClientTransport @@ -231,24 +232,12 @@ func TestClientSendAndReceive(t *testing.T) { t.Fatalf("Error: %v; want ", recvErr) } ct.Close() - server.Close() + server.stop() } func TestClientErrorNotify(t *testing.T) { server, ct := setUp(t, true, 0, math.MaxUint32, false) - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo.Small", - } - s, err := ct.NewStream(context.Background(), callHdr) - if err != nil { - t.Fatalf("failed to open stream: %v", err) - } - if s.id != 1 { - t.Fatalf("wrong stream id: %d", s.id) - } - // Tear down the server. - go server.Close() + go server.stop() // ct.reader should detect the error and activate ct.Error(). <-ct.Error() ct.Close() @@ -284,7 +273,7 @@ func TestClientMix(t *testing.T) { s, ct := setUp(t, true, 0, math.MaxUint32, false) go func(s *server) { time.Sleep(5 * time.Second) - s.Close() + s.stop() }(s) go func(ct ClientTransport) { <-ct.Error() @@ -300,7 +289,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) { server, ct := setUp(t, true, 0, 1, false) defer func() { ct.Close() - server.Close() + server.stop() }() callHdr := &CallHdr{ Host: "localhost", @@ -375,7 +364,7 @@ func TestLargeMessage(t *testing.T) { } wg.Wait() ct.Close() - server.Close() + server.stop() } func TestLargeMessageSuspension(t *testing.T) { @@ -397,7 +386,7 @@ func TestLargeMessageSuspension(t *testing.T) { t.Fatalf("Write got %v, want %v", err, expectedErr) } ct.Close() - server.Close() + server.stop() } func TestStreamContext(t *testing.T) {