diff --git a/server.go b/server.go index 5c20dd4b..487a75c5 100644 --- a/server.go +++ b/server.go @@ -259,7 +259,8 @@ func (s *Server) Serve(lis net.Listener) error { s.mu.Unlock() go func() { - st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) { + var wg sync.WaitGroup + st.HandleStreams(func(stream *transport.Stream) { var trInfo *traceInfo if EnableTracing { trInfo = &traceInfo{ @@ -278,6 +279,7 @@ func (s *Server) Serve(lis net.Listener) error { wg.Done() }() }) + wg.Wait() s.mu.Lock() delete(s.conns, st) s.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index ed8fde08..52d0ee53 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -138,7 +138,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI // operateHeader takes action on the decoded headers. It returns the current // stream if there are remaining headers on the wire (in the following // Continuation frame). -func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) { +func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) { defer func() { if pendingStream == nil { hDec.state = decodeState{} @@ -202,13 +202,13 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header recv: s.buf, } s.method = hDec.state.method - handle(s, wg) + handle(s) return nil } // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. -func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) { +func (t *http2Server) HandleStreams(handle func(*Stream)) { // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { @@ -238,8 +238,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) { hDec := newHPACKDecoder() var curStream *Stream - var wg sync.WaitGroup - defer wg.Wait() for { frame, err := t.framer.readFrame() if err != nil { @@ -268,9 +266,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) { fc: fc, } endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) - curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg) + curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle) case *http2.ContinuationFrame: - curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg) + curStream = t.operateHeaders(hDec, curStream, frame, false, handle) case *http2.DataFrame: t.handleData(frame) case *http2.RSTStreamFrame: diff --git a/transport/transport.go b/transport/transport.go index c319a5fa..e1e7f576 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -391,7 +391,7 @@ type ServerTransport interface { // WriteHeader sends the header metedata for the given stream. WriteHeader(s *Stream, md metadata.MD) error // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream, *sync.WaitGroup)) + HandleStreams(func(*Stream)) // Close tears down the transport. Once it is called, the transport // should not be accessed any more. All the pending streams and their // handlers will be terminated asynchronously. diff --git a/transport/transport_test.go b/transport/transport_test.go index ba1d66a7..9bf3ed3c 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -77,8 +77,7 @@ const ( misbehaved ) -func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) { - defer wg.Done() +func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { req := expectedRequest resp := expectedResponse if s.Method() == "foo.Large" { @@ -100,16 +99,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitG } // handleStreamSuspension blocks until s.ctx is canceled. -func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) { - wg.Add(1) +func (h *testStreamHandler) handleStreamSuspension(s *Stream) { go func() { <-s.ctx.Done() - wg.Done() }() } -func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) { - defer wg.Done() +func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { conn, ok := s.ServerTransport().(*http2Server) if !ok { t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) @@ -173,14 +169,12 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { case suspended: go transport.HandleStreams(h.handleStreamSuspension) case misbehaved: - go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) { - wg.Add(1) - go h.handleStreamMisbehave(t, s, wg) + go transport.HandleStreams(func(s *Stream) { + go h.handleStreamMisbehave(t, s) }) default: - go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) { - wg.Add(1) - go h.handleStream(t, s, wg) + go transport.HandleStreams(func(s *Stream) { + go h.handleStream(t, s) }) } }