refactor the NewStream a bit

This commit is contained in:
iamqizhao
2015-03-13 16:37:05 -07:00
parent 51496073b8
commit 4a56e1fdd9

View File

@ -166,11 +166,10 @@ func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err erro
return t, nil return t, nil
} }
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { func (t *http2Client) newStream(ctx context.Context, streamID uint32, callHdr *CallHdr) *Stream {
t.mu.Lock()
// TODO(zhaoq): Handle uint32 overflow. // TODO(zhaoq): Handle uint32 overflow.
s := &Stream{ s := &Stream{
id: t.nextID, id: streamID,
method: callHdr.Method, method: callHdr.Method,
buf: newRecvBuffer(), buf: newRecvBuffer(),
sendQuotaPool: newQuotaPool(initialWindowSize), sendQuotaPool: newQuotaPool(initialWindowSize),
@ -185,22 +184,12 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
ctx: s.ctx, ctx: s.ctx,
recv: s.buf, recv: s.buf,
} }
t.nextID += 2
t.mu.Unlock()
return s return s
} }
// NewStream creates a stream and register it into the transport as "active" // NewStream creates a stream and register it into the transport as "active"
// streams. // streams.
func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
return nil, err
}
defer func() {
if _, ok := err.(ConnectionError); !ok {
t.writableChan <- 0
}
}()
// Record the timeout value on the context. // Record the timeout value on the context.
var timeout time.Duration var timeout time.Duration
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
@ -221,6 +210,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err) return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
} }
} }
if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
return nil, err
}
// HPACK encodes various headers. Note that once WriteField(...) is // HPACK encodes various headers. Note that once WriteField(...) is
// called, the corresponding headers/continuation frame has to be sent // called, the corresponding headers/continuation frame has to be sent
// because hpack.Encoder is stateful. // because hpack.Encoder is stateful.
@ -244,6 +236,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
first := true first := true
endHeaders := false endHeaders := false
streamID := t.nextID
t.nextID += 2
// Sends the headers in a single batch even when they span multiple frames. // Sends the headers in a single batch even when they span multiple frames.
for !endHeaders { for !endHeaders {
size := t.hBuf.Len() size := t.hBuf.Len()
@ -255,7 +249,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
if first { if first {
// Sends a HeadersFrame to server to start a new stream. // Sends a HeadersFrame to server to start a new stream.
p := http2.HeadersFrameParam{ p := http2.HeadersFrameParam{
StreamID: t.nextID, StreamID: streamID,
BlockFragment: t.hBuf.Next(size), BlockFragment: t.hBuf.Next(size),
EndStream: false, EndStream: false,
EndHeaders: endHeaders, EndHeaders: endHeaders,
@ -271,7 +265,8 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
return nil, ConnectionErrorf("transport: %v", err) return nil, ConnectionErrorf("transport: %v", err)
} }
} }
s := t.newStream(ctx, callHdr) t.writableChan <- 0
s := t.newStream(ctx, streamID, callHdr)
t.mu.Lock() t.mu.Lock()
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()