Merge pull request #389 from iamqizhao/master

remove sync.WaitGroup param from ServerTransport.HandleStream
This commit is contained in:
Qi Zhao
2015-10-06 16:58:19 -07:00
4 changed files with 16 additions and 22 deletions

View File

@ -259,7 +259,8 @@ func (s *Server) Serve(lis net.Listener) error {
s.mu.Unlock()
go func() {
st.HandleStreams(func(stream *transport.Stream, wg *sync.WaitGroup) {
var wg sync.WaitGroup
st.HandleStreams(func(stream *transport.Stream) {
var trInfo *traceInfo
if EnableTracing {
trInfo = &traceInfo{
@ -278,6 +279,7 @@ func (s *Server) Serve(lis net.Listener) error {
wg.Done()
}()
})
wg.Wait()
s.mu.Lock()
delete(s.conns, st)
s.mu.Unlock()

View File

@ -138,7 +138,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthI
// operateHeader takes action on the decoded headers. It returns the current
// stream if there are remaining headers on the wire (in the following
// Continuation frame).
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream, *sync.WaitGroup), wg *sync.WaitGroup) (pendingStream *Stream) {
func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream)) (pendingStream *Stream) {
defer func() {
if pendingStream == nil {
hDec.state = decodeState{}
@ -202,13 +202,13 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
recv: s.buf,
}
s.method = hDec.state.method
handle(s, wg)
handle(s)
return nil
}
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
func (t *http2Server) HandleStreams(handle func(*Stream)) {
// Check the validity of client preface.
preface := make([]byte, len(clientPreface))
if _, err := io.ReadFull(t.conn, preface); err != nil {
@ -238,8 +238,6 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
hDec := newHPACKDecoder()
var curStream *Stream
var wg sync.WaitGroup
defer wg.Wait()
for {
frame, err := t.framer.readFrame()
if err != nil {
@ -268,9 +266,9 @@ func (t *http2Server) HandleStreams(handle func(*Stream, *sync.WaitGroup)) {
fc: fc,
}
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle)
case *http2.ContinuationFrame:
curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg)
curStream = t.operateHeaders(hDec, curStream, frame, false, handle)
case *http2.DataFrame:
t.handleData(frame)
case *http2.RSTStreamFrame:

View File

@ -391,7 +391,7 @@ type ServerTransport interface {
// WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream, *sync.WaitGroup))
HandleStreams(func(*Stream))
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.

View File

@ -77,8 +77,7 @@ const (
misbehaved
)
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitGroup) {
defer wg.Done()
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
req := expectedRequest
resp := expectedResponse
if s.Method() == "foo.Large" {
@ -100,16 +99,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream, wg *sync.WaitG
}
// handleStreamSuspension blocks until s.ctx is canceled.
func (h *testStreamHandler) handleStreamSuspension(s *Stream, wg *sync.WaitGroup) {
wg.Add(1)
func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
go func() {
<-s.ctx.Done()
wg.Done()
}()
}
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream, wg *sync.WaitGroup) {
defer wg.Done()
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
conn, ok := s.ServerTransport().(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport())
@ -173,14 +169,12 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
case suspended:
go transport.HandleStreams(h.handleStreamSuspension)
case misbehaved:
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
wg.Add(1)
go h.handleStreamMisbehave(t, s, wg)
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMisbehave(t, s)
})
default:
go transport.HandleStreams(func(s *Stream, wg *sync.WaitGroup) {
wg.Add(1)
go h.handleStream(t, s, wg)
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
})
}
}