This commit is contained in:
iamqizhao
2016-07-28 18:32:51 -07:00
parent 80572b2739
commit 32df3a68d0
4 changed files with 15 additions and 6 deletions

11
call.go
View File

@ -51,6 +51,13 @@ import (
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
var err error var err error
defer func() {
if err != nil {
if _, ok := err.(transport.ConnectionError); !ok {
t.CloseStream(stream, err)
}
}
}()
c.headerMD, err = stream.Header() c.headerMD, err = stream.Header()
if err != nil { if err != nil {
return err return err
@ -190,20 +197,18 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
} }
return toRPCErr(err) return toRPCErr(err)
} }
// Receive the response
err = recvResponse(cc.dopts, t, &c, stream, reply) err = recvResponse(cc.dopts, t, &c, stream, reply)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
put = nil put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok || err == transport.ErrStreamDrain {
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
continue continue
} }
t.CloseStream(stream, err)
return toRPCErr(err) return toRPCErr(err)
} }
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {

View File

@ -430,7 +430,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
delete(t.activeStreams, s.id) delete(t.activeStreams, s.id)
if t.state == draining && len(t.activeStreams) == 0 { if t.state == draining && len(t.activeStreams) == 0 {
// The transport is draining and s is the last live stream on t. // The transport is draining and s is the last live stream on t.
defer t.Close() t.mu.Unlock()
t.Close()
return
} }
t.mu.Unlock() t.mu.Unlock()
if updateStreams { if updateStreams {

View File

@ -240,6 +240,8 @@ func (s *Stream) Header() (metadata.MD, error) {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err()) return nil, ContextErr(s.ctx.Err())
case <-s.goAway:
return nil, ErrStreamDrain
case <-s.headerChan: case <-s.headerChan:
return s.header.Copy(), nil return s.header.Copy(), nil
} }

View File

@ -368,8 +368,8 @@ func TestGracefulClose(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if _, err := ct.NewStream(context.Background(), callHdr); err != ErrConnClosing { if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrConnClosing) t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain)
} }
}() }()
} }