remove sync.WaitGroup param from ServerTransport.HandleStream

This commit is contained in:
iamqizhao
2015-10-06 16:44:30 -07:00
parent e63d714a45
commit 63a6c4155a
4 changed files with 16 additions and 22 deletions

View File

@ -259,7 +259,8 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Unlock() s.mu.Unlock()
go func() { go func() {
st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) { var wg sync.WaitGroup
st.HandleStreams(func(stream *transport.Stream) {
var trInfo *traceInfo var trInfo *traceInfo
if EnableTracing { if EnableTracing {
trInfo = &traceInfo{ trInfo = &traceInfo{
@ -278,6 +279,7 @@ func (s *Server) Serve(lis net.Listener) error {
wg.Done() wg.Done()
}() }()
}) })
wg.Wait()
s.mu.Lock() s.mu.Lock()
delete(s.conns, st) delete(s.conns, st)
s.mu.Unlock() s.mu.Unlock()

View File

@ -138,7 +138,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
// operateHeader takes action on the decoded headers. It returns the current // operateHeader takes action on the decoded headers. It returns the current
// stream if there are remaining headers on the wire (in the following // stream if there are remaining headers on the wire (in the following
// Continuation frame). // Continuation frame).
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) { func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) {
defer func() { defer func() {
if pendingStream == nil { if pendingStream == nil {
hDec.state = decodeState{} hDec.state = decodeState{}
@ -202,13 +202,13 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
recv: s.buf, recv: s.buf,
} }
s.method = hDec.state.method s.method = hDec.state.method
handle(s, wg) handle(s)
return nil return nil
} }
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine. // typically run in a separate goroutine.
func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) { func (t *http2Server) HandleStreams(handle func(*Stream)) {
// Check the validity of client preface. // Check the validity of client preface.
preface := make([]byte, len(clientPreface)) preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil { if _, err := io.ReadFull(t.conn, preface); err != nil {
@ -238,8 +238,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
hDec := newHPACKDecoder() hDec := newHPACKDecoder()
var curStream *Stream var curStream *Stream
var wg sync.WaitGroup
defer wg.Wait()
for { for {
frame, err := t.framer.readFrame() frame, err := t.framer.readFrame()
if err != nil { if err != nil {
@ -268,9 +266,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
fc: fc, fc: fc,
} }
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg) curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
case *http2.ContinuationFrame: case *http2.ContinuationFrame:
curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg) curStream = t.operateHeaders(hDec, curStream, frame, false, handle)
case *http2.DataFrame: case *http2.DataFrame:
t.handleData(frame) t.handleData(frame)
case *http2.RSTStreamFrame: case *http2.RSTStreamFrame:

View File

@ -391,7 +391,7 @@ type ServerTransport interface {
// WriteHeader sends the header metedata for the given stream. // WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler. // HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream, *sync.WaitGroup)) HandleStreams(func(*Stream))
// Close tears down the transport. Once it is called, the transport // Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their // should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously. // handlers will be terminated asynchronously.

View File

@ -77,8 +77,7 @@ const (
misbehaved misbehaved
) )
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) { func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
defer wg.Done()
req := expectedRequest req := expectedRequest
resp := expectedResponse resp := expectedResponse
if s.Method() == "foo.Large" { if s.Method() == "foo.Large" {
@ -100,16 +99,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitG
} }
// handleStreamSuspension blocks until s.ctx is canceled. // handleStreamSuspension blocks until s.ctx is canceled.
func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) { func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
wg.Add(1)
go func() { go func() {
<-s.ctx.Done() <-s.ctx.Done()
wg.Done()
}() }()
} }
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) { func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
defer wg.Done()
conn, ok := s.ServerTransport().(*http2Server) conn, ok := s.ServerTransport().(*http2Server)
if !ok { if !ok {
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
@ -173,14 +169,12 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
case suspended: case suspended:
go transport.HandleStreams(h.handleStreamSuspension) go transport.HandleStreams(h.handleStreamSuspension)
case misbehaved: case misbehaved:
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) { go transport.HandleStreams(func(s *Stream) {
wg.Add(1) go h.handleStreamMisbehave(t, s)
go h.handleStreamMisbehave(t, s, wg)
}) })
default: default:
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) { go transport.HandleStreams(func(s *Stream) {
wg.Add(1) go h.handleStream(t, s)
go h.handleStream(t, s, wg)
}) })
} }
} }