remove sync.WaitGroup param from ServerTransport.HandleStream
This commit is contained in:
@ -259,7 +259,8 @@ func (s *Server) Serve(lis net.Listener) error {
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) {
|
||||
var wg sync.WaitGroup
|
||||
st.HandleStreams(func(stream *transport.Stream) {
|
||||
var trInfo *traceInfo
|
||||
if EnableTracing {
|
||||
trInfo = &traceInfo{
|
||||
@ -278,6 +279,7 @@ func (s *Server) Serve(lis net.Listener) error {
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
wg.Wait()
|
||||
s.mu.Lock()
|
||||
delete(s.conns, st)
|
||||
s.mu.Unlock()
|
||||
|
@ -138,7 +138,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
|
||||
// operateHeader takes action on the decoded headers. It returns the current
|
||||
// stream if there are remaining headers on the wire (in the following
|
||||
// Continuation frame).
|
||||
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) {
|
||||
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) {
|
||||
defer func() {
|
||||
if pendingStream == nil {
|
||||
hDec.state = decodeState{}
|
||||
@ -202,13 +202,13 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
|
||||
recv: s.buf,
|
||||
}
|
||||
s.method = hDec.state.method
|
||||
handle(s, wg)
|
||||
handle(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleStreams receives incoming streams using the given handler. This is
|
||||
// typically run in a separate goroutine.
|
||||
func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
|
||||
func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
||||
// Check the validity of client preface.
|
||||
preface := make([]byte, len(clientPreface))
|
||||
if _, err := io.ReadFull(t.conn, preface); err != nil {
|
||||
@ -238,8 +238,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
|
||||
|
||||
hDec := newHPACKDecoder()
|
||||
var curStream *Stream
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
for {
|
||||
frame, err := t.framer.readFrame()
|
||||
if err != nil {
|
||||
@ -268,9 +266,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
|
||||
fc: fc,
|
||||
}
|
||||
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
|
||||
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
|
||||
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
|
||||
case *http2.ContinuationFrame:
|
||||
curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg)
|
||||
curStream = t.operateHeaders(hDec, curStream, frame, false, handle)
|
||||
case *http2.DataFrame:
|
||||
t.handleData(frame)
|
||||
case *http2.RSTStreamFrame:
|
||||
|
@ -391,7 +391,7 @@ type ServerTransport interface {
|
||||
// WriteHeader sends the header metedata for the given stream.
|
||||
WriteHeader(s *Stream, md metadata.MD) error
|
||||
// HandleStreams receives incoming streams using the given handler.
|
||||
HandleStreams(func(*Stream, *sync.WaitGroup))
|
||||
HandleStreams(func(*Stream))
|
||||
// Close tears down the transport. Once it is called, the transport
|
||||
// should not be accessed any more. All the pending streams and their
|
||||
// handlers will be terminated asynchronously.
|
||||
|
@ -77,8 +77,7 @@ const (
|
||||
misbehaved
|
||||
)
|
||||
|
||||
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
||||
req := expectedRequest
|
||||
resp := expectedResponse
|
||||
if s.Method() == "foo.Large" {
|
||||
@ -100,16 +99,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitG
|
||||
}
|
||||
|
||||
// handleStreamSuspension blocks until s.ctx is canceled.
|
||||
func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
|
||||
go func() {
|
||||
<-s.ctx.Done()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
||||
conn, ok := s.ServerTransport().(*http2Server)
|
||||
if !ok {
|
||||
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
|
||||
@ -173,14 +169,12 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
|
||||
case suspended:
|
||||
go transport.HandleStreams(h.handleStreamSuspension)
|
||||
case misbehaved:
|
||||
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
go h.handleStreamMisbehave(t, s, wg)
|
||||
go transport.HandleStreams(func(s *Stream) {
|
||||
go h.handleStreamMisbehave(t, s)
|
||||
})
|
||||
default:
|
||||
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
go h.handleStream(t, s, wg)
|
||||
go transport.HandleStreams(func(s *Stream) {
|
||||
go h.handleStream(t, s)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user