* Export changes to OSS.

* First commit.

* Cherry-pick.

* Documentation.

* Post review updates.
This commit is contained in:
mmukhi
2018-04-30 09:54:33 -07:00
committed by GitHub
parent fc37cf1364
commit 7a8c989507
15 changed files with 1080 additions and 1023 deletions

View File

@ -66,17 +66,16 @@ type testStreamHandler struct {
} }
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
p := &parser{r: s}
for { for {
pf, req, err := p.recvMsg(math.MaxInt32) isCompressed, req, err := recvMsg(s, math.MaxInt32)
if err == io.EOF { if err == io.EOF {
break break
} }
if err != nil { if err != nil {
return return
} }
if pf != compressionNone { if isCompressed {
t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone) t.Errorf("Received compressed message want non-compressed message")
return return
} }
var v string var v string
@ -105,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
} }
} }
// send a response back to end the stream. // send a response back to end the stream.
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
if err != nil { if err != nil {
t.Errorf("Failed to encode the response: %v", err) t.Errorf("Failed to encode the response: %v", err)
return return
} }
h.t.Write(s, hdr, data, &transport.Options{}) h.t.Write(s, data, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, "")) h.t.WriteStatus(s, status.New(codes.OK, ""))
} }

View File

@ -0,0 +1,203 @@
/*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package msgdecoder contains the logic to deconstruct a gRPC-message.
package msgdecoder
import (
"encoding/binary"
)
// RecvMsg is a message constructed from the incoming
// bytes on the transport for a stream.
// An instance of RecvMsg will contain only one of the
// following: message header related fields, data slice
// or error.
type RecvMsg struct {
// Following three are message header related
// fields.
// true if the message was compressed by the other
// side.
IsCompressed bool
// Length of the message.
Length int
// Overhead is the length of message header(5 bytes)
// plus padding.
Overhead int
// Data payload of the message.
Data []byte
// Err occurred while reading.
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
Err error
Next *RecvMsg
}
// RecvMsgList is a linked-list of RecvMsg.
type RecvMsgList struct {
head *RecvMsg
tail *RecvMsg
}
// IsEmpty returns true when l is empty.
func (l *RecvMsgList) IsEmpty() bool {
if l.tail == nil {
return true
}
return false
}
// Enqueue adds r to l at the back.
func (l *RecvMsgList) Enqueue(r *RecvMsg) {
if l.IsEmpty() {
l.head, l.tail = r, r
return
}
t := l.tail
l.tail = r
t.Next = r
}
// Dequeue removes a RcvMsg from the end of l.
func (l *RecvMsgList) Dequeue() *RecvMsg {
if l.head == nil {
// Note to developer: Instead of calling isEmpty() which
// checks the same condition on l.tail, we check it directly
// on l.head so that in non-nil cases, there aren't cache misses.
return nil
}
r := l.head
l.head = l.head.Next
if l.head == nil {
l.tail = nil
}
return r
}
// MessageDecoder decodes bytes from HTTP2 data frames
// and constructs a gRPC message which is then put in a
// buffer that application(RPCs) read from.
// gRPC Messages:
// First 5 bytes is the message header:
// First byte: Payload format.
// Next 4 bytes: Length of the message.
// Rest of the bytes is the message payload.
//
// TODO(mmukhi): Write unit tests.
type MessageDecoder struct {
// current message being read by the transport.
current *RecvMsg
dataOfst int
padding int
// hdr stores the message header as it is beind received by the transport.
hdr []byte
hdrOfst int
// Callback used to send decoded messages.
dispatch func(*RecvMsg)
}
// NewMessageDecoder creates an instance of MessageDecoder. It takes a callback
// which is called to dispatch finished headers and messages to the application.
func NewMessageDecoder(dispatch func(*RecvMsg)) *MessageDecoder {
return &MessageDecoder{
hdr: make([]byte, 5),
dispatch: dispatch,
}
}
// Decode consumes bytes from a HTTP2 data frame to create gRPC messages.
func (m *MessageDecoder) Decode(b []byte, padding int) {
m.padding += padding
for len(b) > 0 {
// Case 1: A complete message hdr was received earlier.
if m.current != nil {
n := copy(m.current.Data[m.dataOfst:], b)
m.dataOfst += n
b = b[n:]
if m.dataOfst == len(m.current.Data) { // Message is complete.
m.dispatch(m.current)
m.current = nil
m.dataOfst = 0
}
continue
}
// Case 2a: No message header has been received yet.
if m.hdrOfst == 0 {
// case 2a.1: b has the whole header
if len(b) >= 5 {
m.parseHeader(b[:5])
b = b[5:]
continue
}
// case 2a.2: b has partial header
n := copy(m.hdr, b)
m.hdrOfst = n
b = b[n:]
continue
}
// Case 2b: Partial message header was received earlier.
n := copy(m.hdr[m.hdrOfst:], b)
m.hdrOfst += n
b = b[n:]
if m.hdrOfst == 5 { // hdr is complete.
m.hdrOfst = 0
m.parseHeader(m.hdr)
}
}
}
func (m *MessageDecoder) parseHeader(b []byte) {
length := int(binary.BigEndian.Uint32(b[1:5]))
hdr := &RecvMsg{
IsCompressed: int(b[0]) == 1,
Length: length,
Overhead: m.padding + 5,
}
m.padding = 0
// Dispatch the information retreived from message header so
// that the RPC goroutine can send a proactive window update as we
// wait for the rest of it.
m.dispatch(hdr)
if length == 0 {
m.dispatch(&RecvMsg{})
return
}
m.current = &RecvMsg{
Data: getMem(length),
}
}
func getMem(l int) []byte {
// TODO(mmukhi): Reuse this memory.
return make([]byte, l)
}
// CreateMessageHeader creates a gRPC-specific message header.
func CreateMessageHeader(l int, isCompressed bool) []byte {
// TODO(mmukhi): Investigate if this memory is worth
// reusing.
hdr := make([]byte, 5)
if isCompressed {
hdr[0] = byte(1)
}
binary.BigEndian.PutUint32(hdr[1:], uint32(l))
return hdr
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package msgdecoder
import (
"encoding/binary"
"reflect"
"testing"
)
func TestMessageDecoder(t *testing.T) {
for _, test := range []struct {
numFrames int
data []string
}{
{1, []string{"abc"}}, // One message per frame.
{1, []string{"abc", "def", "ghi"}}, // Multiple messages per frame.
{3, []string{"a", "bcdef", "ghif"}}, // Multiple messages over multiple frames.
} {
var want []*RecvMsg
for _, d := range test.data {
want = append(want, &RecvMsg{Length: len(d), Overhead: 5})
want = append(want, &RecvMsg{Data: []byte(d)})
}
var got []*RecvMsg
dcdr := NewMessageDecoder(func(r *RecvMsg) { got = append(got, r) })
for _, fr := range createFrames(test.numFrames, test.data) {
dcdr.Decode(fr, 0)
}
if !match(got, want) {
t.Fatalf("got: %v, want: %v", got, want)
}
}
}
func match(got, want []*RecvMsg) bool {
for i, v := range got {
if !reflect.DeepEqual(v, want[i]) {
return false
}
}
return true
}
func createFrames(n int, msgs []string) [][]byte {
var b []byte
for _, m := range msgs {
payload := []byte(m)
hdr := make([]byte, 5)
binary.BigEndian.PutUint32(hdr[1:], uint32(len(payload)))
b = append(b, hdr...)
b = append(b, payload...)
}
// break b into n parts.
var result [][]byte
batch := len(b) / n
for len(b) != 0 {
sz := batch
if len(b) < sz {
sz = len(b)
}
result = append(result, b[:sz])
b = b[sz:]
}
return result
}

View File

@ -21,7 +21,6 @@ package grpc
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -415,85 +414,39 @@ func (o CustomCodecCallOption) before(c *callInfo) error {
} }
func (o CustomCodecCallOption) after(c *callInfo) {} func (o CustomCodecCallOption) after(c *callInfo) {}
// The format of the payload: compressed or not?
type payloadFormat uint8
const (
compressionNone payloadFormat = iota // no compression
compressionMade
)
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
// r is the underlying reader.
// See the comment on recvMsg for the permissible
// error types.
r io.Reader
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte
}
// recvMsg reads a complete gRPC message from the stream. // recvMsg reads a complete gRPC message from the stream.
// //
// It returns the message and its payload (compression/encoding) // It returns a flag set to true if message was compressed,
// format. The caller owns the returned msg memory. // the message as a byte slice or error if so.
// The caller owns the returned msg memory.
// //
// If there is an error, possible values are: // If there is an error, possible values are:
// * io.EOF, when no messages remain // * io.EOF, when no messages remain
// * io.ErrUnexpectedEOF // * io.ErrUnexpectedEOF
// * of type transport.ConnectionError // * of type transport.ConnectionError
// * of type transport.StreamError // * of type transport.StreamError
// No other error values or types must be returned, which also means // No other error values or types must be returned.
// that the underlying io.Reader must not return an incompatible func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) {
// error. isCompressed, msg, err := s.Read(maxRecvMsgSize)
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) { if err != nil {
if _, err := p.r.Read(p.header[:]); err != nil { return false, nil, err
return 0, nil, err
} }
return isCompressed, msg, nil
pf = payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 {
return pf, nil, nil
}
if int64(length) > int64(maxInt) {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
}
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, nil, err
}
return pf, msg, nil
} }
// encode serializes msg and returns a buffer of message header and a buffer of msg. // encode serializes msg and returns a buffer of msg.
// If msg is nil, it generates the message header and an empty msg buffer. // If msg is nil, it generates an empty buffer.
// TODO(ddyihai): eliminate extra Compressor parameter. // TODO(ddyihai): eliminate extra Compressor parameter.
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) {
var ( var (
b []byte b []byte
cbuf *bytes.Buffer cbuf *bytes.Buffer
) )
const (
payloadLen = 1
sizeLen = 4
)
if msg != nil { if msg != nil {
var err error var err error
b, err = c.Marshal(msg) b, err = c.Marshal(msg)
if err != nil { if err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
} }
if outPayload != nil { if outPayload != nil {
outPayload.Payload = msg outPayload.Payload = msg
@ -507,49 +460,36 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa
if compressor != nil { if compressor != nil {
z, _ := compressor.Compress(cbuf) z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil { if _, err := z.Write(b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
} }
z.Close() z.Close()
} else { } else {
// If Compressor is not set by UseCompressor, use default Compressor // If Compressor is not set by UseCompressor, use default Compressor
if err := cp.Do(cbuf, b); err != nil { if err := cp.Do(cbuf, b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error()) return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
} }
} }
b = cbuf.Bytes() b = cbuf.Bytes()
} }
} }
if uint(len(b)) > math.MaxUint32 { if uint(len(b)) > math.MaxUint32 {
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
} }
bufHeader := make([]byte, payloadLen+sizeLen)
if compressor != nil || cp != nil {
bufHeader[0] = byte(compressionMade)
} else {
bufHeader[0] = byte(compressionNone)
}
// Write length of b into buf
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
if outPayload != nil { if outPayload != nil {
outPayload.WireLength = payloadLen + sizeLen + len(b) // A 5 byte gRPC-specific message header will added to this message
// before it's put on wire.
outPayload.WireLength = 5 + len(b)
} }
return bufHeader, b, nil return b, nil
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status {
switch pf { if recvCompress == "" || recvCompress == encoding.Identity {
case compressionNone: return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
case compressionMade: }
if recvCompress == "" || recvCompress == encoding.Identity { if !haveCompressor {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding") return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
} }
return nil return nil
} }
@ -557,8 +497,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
// For the two compressor parameters, both should not be set, but if they are, // For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor. // dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API? // TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { func recv(c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize) isCompressed, d, err := recvMsg(s, maxReceiveMessageSize)
if err != nil { if err != nil {
return err return err
} }
@ -566,11 +506,10 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
inPayload.WireLength = len(d) inPayload.WireLength = len(d)
} }
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil { if isCompressed {
return st.Err() if st := checkRecvPayload(s.RecvCompress(), compressor != nil || dc != nil); st != nil {
} return st.Err()
}
if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor, // To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default. // use this decompressor as the default.
if dc != nil { if dc != nil {
@ -588,11 +527,11 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
} if len(d) > maxReceiveMessageSize {
if len(d) > maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with java
// TODO: Revisit the error code. Currently keep it consistent with java // implementation.
// implementation. return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize) }
} }
if err := c.Unmarshal(d, m); err != nil { if err := c.Unmarshal(d, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)

View File

@ -22,7 +22,6 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"io" "io"
"math"
"reflect" "reflect"
"testing" "testing"
@ -45,77 +44,20 @@ func (f fullReader) Read(p []byte) (int, error) {
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
func TestSimpleParsing(t *testing.T) {
bigMsg := bytes.Repeat([]byte{'x'}, 1<<24)
for _, test := range []struct {
// input
p []byte
// outputs
err error
b []byte
pt payloadFormat
}{
{nil, io.EOF, nil, compressionNone},
{[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone},
{[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone},
{[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone},
{[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone},
// Check that messages with length >= 2^24 are parsed.
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
}
}
}
func TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}
wantRecvs := []struct {
pt payloadFormat
data []byte
}{
{compressionNone, []byte("a")},
{compressionNone, []byte("bc")},
{compressionNone, []byte("d")},
}
for i, want := range wantRecvs {
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
i, p, pt, data, err, want.pt, want.data)
}
}
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != io.EOF {
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
len(wantRecvs), p, pt, data, err, io.EOF)
}
}
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
// input // input
msg proto.Message msg proto.Message
cp Compressor cp Compressor
// outputs // outputs
hdr []byte
data []byte data []byte
err error err error
}{ }{
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, {nil, nil, []byte{}, nil},
} { } {
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { if err != test.err || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err)
} }
} }
} }
@ -214,8 +156,11 @@ func TestParseDialTarget(t *testing.T) {
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {
cdc := encoding.GetCodec(protoenc.Name) cdc := encoding.GetCodec(protoenc.Name)
msg := &perfpb.Buffer{Body: make([]byte, mSize)} msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil) encodeData, _ := encode(cdc, msg, nil, nil, nil)
encodedSz := int64(len(encodeHdr) + len(encodeData)) // 5 bytes of gRPC-specific message header
// is added to the message before it is written
// to the wire.
encodedSz := int64(5 + len(encodeData))
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@ -831,7 +831,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if s.opts.statsHandler != nil { if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err) grpclog.Errorln("grpc: server failed to encode response: ", err)
return err return err
@ -839,7 +839,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if len(data) > s.opts.maxSendMessageSize { if len(data) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
} }
err = t.Write(stream, hdr, data, opts) opts.IsCompressed = cp != nil || comp != nil
err = t.Write(stream, data, opts)
if err == nil && outPayload != nil { if err == nil && outPayload != nil {
outPayload.SentTime = time.Now() outPayload.SentTime = time.Now()
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
@ -924,8 +925,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
} }
p := &parser{r: stream} isCompressed, req, err := recvMsg(stream, s.opts.maxReceiveMessageSize)
pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -955,12 +955,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if channelz.IsOn() { if channelz.IsOn() {
t.IncrMsgRecv() t.IncrMsgRecv()
} }
if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return st.Err()
}
var inPayload *stats.InPayload var inPayload *stats.InPayload
if sh != nil { if sh != nil {
inPayload = &stats.InPayload{ inPayload = &stats.InPayload{
@ -971,7 +965,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if inPayload != nil { if inPayload != nil {
inPayload.WireLength = len(req) inPayload.WireLength = len(req)
} }
if pf == compressionMade { if isCompressed {
if st := checkRecvPayload(stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
return st.Err()
}
var err error var err error
if dc != nil { if dc != nil {
req, err = dc.Do(bytes.NewReader(req)) req, err = dc.Do(bytes.NewReader(req))
@ -985,11 +982,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err) return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
} }
} }
} if len(req) > s.opts.maxReceiveMessageSize {
if len(req) > s.opts.maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with
// TODO: Revisit the error code. Currently keep it consistent with // java implementation.
// java implementation. return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) }
} }
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil { if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
@ -1100,7 +1097,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ctx: ctx, ctx: ctx,
t: t, t: t,
s: stream, s: stream,
p: &parser{r: stream},
codec: s.getCodec(stream.ContentSubtype()), codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize,

View File

@ -290,7 +290,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
attempt: &csAttempt{ attempt: &csAttempt{
t: t, t: t,
s: s, s: s,
p: &parser{r: s},
done: done, done: done,
dc: cc.dopts.dc, dc: cc.dopts.dc,
ctx: ctx, ctx: ctx,
@ -347,7 +346,6 @@ type csAttempt struct {
cs *clientStream cs *clientStream
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser
done func(balancer.DoneInfo) done func(balancer.DoneInfo)
dc Decompressor dc Decompressor
@ -472,7 +470,7 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
if err != nil { if err != nil {
return err return err
} }
@ -482,7 +480,11 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
if !cs.desc.ClientStreams { if !cs.desc.ClientStreams {
cs.sentLast = true cs.sentLast = true
} }
err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams}) opts := &transport.Options{
Last: !cs.desc.ClientStreams,
IsCompressed: cs.cp != nil || cs.comp != nil,
}
err = a.t.Write(a.s, data, opts)
if err == nil { if err == nil {
if outPayload != nil { if outPayload != nil {
outPayload.SentTime = time.Now() outPayload.SentTime = time.Now()
@ -526,7 +528,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
// Only initialize this state once per stream. // Only initialize this state once per stream.
a.decompSet = true a.decompSet = true
} }
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp) err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil { if statusErr := a.s.Status().Err(); statusErr != nil {
@ -556,7 +558,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
// Special handling for non-server-stream rpcs. // Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload. // This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp) err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
if err == nil { if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
} }
@ -572,7 +574,7 @@ func (a *csAttempt) closeSend() {
return return
} }
cs.sentLast = true cs.sentLast = true
cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true}) cs.attempt.t.Write(cs.attempt.s, nil, &transport.Options{Last: true})
// We ignore errors from Write. Any error it would return would also be // We ignore errors from Write. Any error it would return would also be
// returned by a subsequent RecvMsg call, and the user is supposed to always // returned by a subsequent RecvMsg call, and the user is supposed to always
// finish the stream by calling RecvMsg until it returns err != nil. // finish the stream by calling RecvMsg until it returns err != nil.
@ -635,7 +637,6 @@ type serverStream struct {
ctx context.Context ctx context.Context
t transport.ServerTransport t transport.ServerTransport
s *transport.Stream s *transport.Stream
p *parser
codec baseCodec codec baseCodec
cp Compressor cp Compressor
@ -700,14 +701,18 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
if err != nil { if err != nil {
return err return err
} }
if len(data) > ss.maxSendMessageSize { if len(data) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
} }
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { opts := &transport.Options{
Last: false,
IsCompressed: ss.cp != nil || ss.comp != nil,
}
if err := ss.t.Write(ss.s, data, opts); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
if outPayload != nil { if outPayload != nil {
@ -743,7 +748,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
inPayload = &stats.InPayload{} inPayload = &stats.InPayload{}
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil { if err := recv(ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
if err == io.EOF { if err == io.EOF {
return err return err
} }

View File

@ -21,7 +21,6 @@ package transport
import ( import (
"fmt" "fmt"
"math" "math"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -96,35 +95,39 @@ func (w *writeQuota) replenish(n int) {
} }
type trInFlow struct { type trInFlow struct {
limit uint32 limit uint32 // accessed by reader goroutine.
unacked uint32 unacked uint32 // accessed by reader goroutine.
effectiveWindowSize uint32 effectiveWindowSize uint32 // accessed by reader and channelz request goroutine.
// Callback used to schedule window update.
scheduleWU func(uint32)
} }
func (f *trInFlow) newLimit(n uint32) uint32 { // Sets the new limit.
d := n - f.limit func (f *trInFlow) newLimit(n uint32) {
if n > f.limit {
f.scheduleWU(n - f.limit)
}
f.limit = n f.limit = n
f.updateEffectiveWindowSize() f.updateEffectiveWindowSize()
return d
} }
func (f *trInFlow) onData(n uint32) uint32 { func (f *trInFlow) onData(n uint32) {
f.unacked += n f.unacked += n
if f.unacked >= f.limit/4 { if f.unacked >= f.limit/4 {
w := f.unacked w := f.unacked
f.unacked = 0 f.unacked = 0
f.updateEffectiveWindowSize() f.scheduleWU(w)
return w
} }
f.updateEffectiveWindowSize() f.updateEffectiveWindowSize()
return 0
} }
func (f *trInFlow) reset() uint32 { func (f *trInFlow) reset() {
w := f.unacked if f.unacked == 0 {
return
}
f.scheduleWU(f.unacked)
f.unacked = 0 f.unacked = 0
f.updateEffectiveWindowSize() f.updateEffectiveWindowSize()
return w
} }
func (f *trInFlow) updateEffectiveWindowSize() { func (f *trInFlow) updateEffectiveWindowSize() {
@ -135,102 +138,57 @@ func (f *trInFlow) getSize() uint32 {
return atomic.LoadUint32(&f.effectiveWindowSize) return atomic.LoadUint32(&f.effectiveWindowSize)
} }
// TODO(mmukhi): Simplify this code. // stInFlow deals with inbound flow control for stream.
// inFlow deals with inbound flow control // It can be simultaneously read by transport's reader
type inFlow struct { // goroutine and an RPC's goroutine.
mu sync.Mutex // It is protected by the lock in stream that owns it.
type stInFlow struct {
// rcvd is the bytes of data that this end-point has
// received from the perspective of other side.
// This can go negative. It must be Accessed atomically.
// Needs to be aligned because of golang bug with atomics:
// https://golang.org/pkg/sync/atomic/#pkg-note-BUG
rcvd int64
// The inbound flow control limit for pending data. // The inbound flow control limit for pending data.
limit uint32 limit uint32
// pendingData is the overall data which have been received but not been // number of bytes received so far, this should be accessed
// consumed by applications. // number of bytes that have been read by the RPC.
pendingData uint32 read uint32
// The amount of data the application has consumed but grpc has not sent // a window update should be sent when the RPC has
// window update for them. Used to reduce window update frequency. // read these many bytes.
pendingUpdate uint32 // TODO(mmukhi, dfawley): Does this have to be limit/4?
// delta is the extra window update given by receiver when an application // Keeping it a constant makes implementation easy.
// is reading data bigger in size than the inFlow limit. wuThreshold uint32
delta uint32 // Callback used to schedule window update.
scheduleWU func(uint32)
} }
// newLimit updates the inflow window to a new value n. // called by transport's reader goroutine to set a new limit on
// It assumes that n is always greater than the old limit. // incoming flow control based on BDP estimation.
func (f *inFlow) newLimit(n uint32) uint32 { func (s *stInFlow) newLimit(n uint32) {
f.mu.Lock() s.limit = n
d := n - f.limit
f.limit = n
f.mu.Unlock()
return d
} }
func (f *inFlow) maybeAdjust(n uint32) uint32 { // called by transport's reader goroutine when data is received by it.
if n > uint32(math.MaxInt32) { func (s *stInFlow) onData(n uint32) error {
n = uint32(math.MaxInt32) rcvd := atomic.AddInt64(&s.rcvd, int64(n))
if rcvd > int64(s.limit) { // Flow control violation.
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, s.limit)
} }
f.mu.Lock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
// estUntransmittedData is the maximum number of bytes the sends might not have put
// on the wire yet. A value of 0 or less means that we have already received all or
// more bytes than the application is requesting to read.
estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
// This implies that unless we send a window update, the sender won't be able to send all the bytes
// for this message. Therefore we must send an update over the limit since there's an active read
// request from the application.
if estUntransmittedData > estSenderQuota {
// Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec.
if f.limit+n > maxWindowSize {
f.delta = maxWindowSize - f.limit
} else {
// Send a window update for the whole message and not just the difference between
// estUntransmittedData and estSenderQuota. This will be helpful in case the message
// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n
}
f.mu.Unlock()
return f.delta
}
f.mu.Unlock()
return 0
}
// onData is invoked when some data frame is received. It updates pendingData.
func (f *inFlow) onData(n uint32) error {
f.mu.Lock()
f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit+f.delta {
limit := f.limit
rcvd := f.pendingData + f.pendingUpdate
f.mu.Unlock()
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit)
}
f.mu.Unlock()
return nil return nil
} }
// onRead is invoked when the application reads the data. It returns the window size // called by RPC's goroutine when data is read by it.
// to be sent to the peer. func (s *stInFlow) onRead(n uint32) {
func (f *inFlow) onRead(n uint32) uint32 { s.read += n
f.mu.Lock() if s.read >= s.wuThreshold {
if f.pendingData == 0 { val := atomic.AddInt64(&s.rcvd, ^int64(s.read-1))
f.mu.Unlock() // Check if threshold needs to go up since limit might have gone up.
return 0 val += int64(s.read)
if val > int64(4*s.wuThreshold) {
s.wuThreshold = uint32(val / 4)
}
s.scheduleWU(s.read)
s.read = 0
} }
f.pendingData -= n
if n > f.delta {
n -= f.delta
f.delta = 0
} else {
f.delta -= n
n = 0
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate
f.pendingUpdate = 0
f.mu.Unlock()
return wu
}
f.mu.Unlock()
return 0
} }

View File

@ -38,6 +38,7 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -269,10 +270,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
} }
} }
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
return ht.do(func() { return ht.do(func() {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)
ht.rw.Write(hdr) ht.rw.Write(msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed))
ht.rw.Write(data) ht.rw.Write(data)
if !opts.Delay { if !opts.Delay {
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
@ -337,16 +338,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req req := ht.req
s := &Stream{ s := newStream(ctx)
id: 0, // irrelevant s.cancel = cancel
requestRead: func(int) {}, s.st = ht
cancel: cancel, s.method = req.URL.Path
buf: newRecvBuffer(), s.recvCompress = req.Header.Get("grpc-encoding")
st: ht, s.contentSubtype = ht.contentSubtype
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{ pr := &peer.Peer{
Addr: ht.RemoteAddr(), Addr: ht.RemoteAddr(),
} }
@ -364,10 +362,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
} }
ht.stats.HandleRPC(s.ctx, inHeader) ht.stats.HandleRPC(s.ctx, inHeader)
} }
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
}
// readerDone is closed when the Body.Read-ing goroutine exits. // readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{}) readerDone := make(chan struct{})
@ -379,11 +373,11 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
for buf := make([]byte, readSize); ; { for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf) n, err := req.Body.Read(buf)
if n > 0 { if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]}) s.consume(buf[:n:n], 0)
buf = buf[n:] buf = buf[n:]
} }
if err != nil { if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)}) s.notifyErr(mapRecvMsgError(err))
return return
} }
if len(buf) == 0 { if len(buf) == 0 {

View File

@ -423,7 +423,7 @@ func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
st.bodyw.Close() // no body st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, "")) st.ht.WriteStatus(s, status.New(codes.OK, ""))
st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{}) st.ht.Write(s, []byte("data"), &Options{})
}) })
} }

View File

@ -34,6 +34,7 @@ import (
"google.golang.org/grpc/channelz" "google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@ -95,8 +96,9 @@ type http2Client struct {
waitingStreams uint32 waitingStreams uint32
nextID uint32 nextID uint32
mu sync.Mutex // guard the following variables mu sync.Mutex // guard the following variables
state transportState state transportState
// TODO(mmukhi): Make this a sharded map.
activeStreams map[uint32]*Stream activeStreams map[uint32]*Stream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame. // prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32 prevGoAwayID uint32
@ -218,7 +220,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
goAway: make(chan struct{}), goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1), awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn, writeBufSize, readBufSize), framer: newFramer(conn, writeBufSize, readBufSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme, scheme: scheme,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
isSecure: isSecure, isSecure: isSecure,
@ -233,6 +234,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
streamsQuotaAvailable: make(chan struct{}, 1), streamsQuotaAvailable: make(chan struct{}, 1),
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
t.fc = &trInFlow{
limit: uint32(icwz),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
},
}
if opts.InitialWindowSize >= defaultWindowSize { if opts.InitialWindowSize >= defaultWindowSize {
t.initialWindowSize = opts.InitialWindowSize t.initialWindowSize = opts.InitialWindowSize
dynamicWindow = false dynamicWindow = false
@ -306,33 +316,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
} }
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
headerChan: make(chan struct{}),
contentSubtype: callHdr.ContentSubtype,
}
s.wq = newWriteQuota(defaultWriteQuota, s.done)
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
// The client side stream context should have exactly the same life cycle with the user provided context. // The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy. // So we use the original context here instead of creating a copy.
s.ctx = ctx s := newStream(ctx)
s.trReader = &transportReader{ // Initialize stream with client-side specific fields.
reader: &recvBufferReader{ s.done = make(chan struct{})
ctx: s.ctx, s.method = callHdr.Method
ctxDone: s.ctx.Done(), s.sendCompress = callHdr.SendCompress
recv: s.buf, s.headerChan = make(chan struct{})
}, s.contentSubtype = callHdr.ContentSubtype
windowHandler: func(n int) { s.wq = newWriteQuota(defaultWriteQuota, s.done)
t.updateWindow(s, uint32(n))
},
}
return s return s
} }
@ -504,7 +498,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
} }
// The stream was unprocessed by the server. // The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1) atomic.StoreUint32(&s.unprocessed, 1)
s.write(recvMsg{err: err}) s.notifyErr(err)
close(s.done) close(s.done)
// If headerChan isn't closed, then close it. // If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 { if atomic.SwapUint32(&s.headerDone, 1) == 0 {
@ -572,7 +566,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
h.streamID = t.nextID h.streamID = t.nextID
t.nextID += 2 t.nextID += 2
s.id = h.streamID s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)} s.fc = &stInFlow{
limit: uint32(t.initialWindowSize),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
},
wuThreshold: uint32(t.initialWindowSize / 4),
}
if t.streamQuota > 0 && t.waitingStreams > 0 { if t.streamQuota > 0 && t.waitingStreams > 0 {
select { select {
case t.streamsQuotaAvailable <- struct{}{}: case t.streamsQuotaAvailable <- struct{}{}:
@ -642,7 +642,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
} }
if err != nil { if err != nil {
// This will unblock reads eventually. // This will unblock reads eventually.
s.write(recvMsg{err: err}) s.notifyErr(err)
} }
// This will unblock write. // This will unblock write.
close(s.done) close(s.done)
@ -740,7 +740,7 @@ func (t *http2Client) GracefulClose() error {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil. // should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
if opts.Last { if opts.Last {
// If it's the last message, update stream state. // If it's the last message, update stream state.
if !s.compareAndSwapState(streamActive, streamWriteDone) { if !s.compareAndSwapState(streamActive, streamWriteDone) {
@ -753,7 +753,9 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
streamID: s.id, streamID: s.id,
endStream: opts.Last, endStream: opts.Last,
} }
if hdr != nil || data != nil { // If it's not an empty data frame. if data != nil { // If it's not an empty data frame.
// Get a gRPC-specific header for this message.
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
// Add some data to grpc message header so that we can equally // Add some data to grpc message header so that we can equally
// distribute bytes across frames. // distribute bytes across frames.
emptyLen := http2MaxFrameLen - len(hdr) emptyLen := http2MaxFrameLen - len(hdr)
@ -778,39 +780,19 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
return s, ok return s, ok
} }
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateWindow adjusts the inbound quota for the stream.
// Window updates will be sent out when the cumulative quota
// exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateFlowControl updates the incoming flow control windows // updateFlowControl updates the incoming flow control windows
// for the transport and the stream based on the current bdp // for the transport and the stream based on the current bdp
// estimation. // estimation.
func (t *http2Client) updateFlowControl(n uint32) { func (t *http2Client) updateFlowControl(n uint32) {
t.mu.Lock() t.fc.newLimit(n) // Update transport's window.
for _, s := range t.activeStreams { updateIWS := func(interface{}) bool { // Update streams' windows.
s.fc.newLimit(n) // All future streams should see the
} // updated value.
t.mu.Unlock()
updateIWS := func(interface{}) bool {
t.initialWindowSize = int32(n) t.initialWindowSize = int32(n)
return true return true
} }
t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)}) // Notify the other side of updated window.
t.controlBuf.put(&outgoingSettings{ t.controlBuf.executeAndPut(updateIWS, &outgoingSettings{
ss: []http2.Setting{ ss: []http2.Setting{
{ {
ID: http2.SettingInitialWindowSize, ID: http2.SettingInitialWindowSize,
@ -818,13 +800,25 @@ func (t *http2Client) updateFlowControl(n uint32) {
}, },
}, },
}) })
t.mu.Lock()
// Update all the currently active streams.
for _, s := range t.activeStreams {
s.fc.newLimit(n)
}
t.mu.Unlock()
} }
func (t *http2Client) handleData(f *http2.DataFrame) { func (t *http2Client) handleData(f *http2.DataFrame) {
size := f.Header().Length size := f.Header().Length
var sendBDPPing bool if size == 0 {
if t.bdpEst != nil { if f.StreamEnded() {
sendBDPPing = t.bdpEst.add(size) // The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
if s, ok := t.getStream(f); ok {
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
}
}
return
} }
// Decouple connection's flow control from application's read. // Decouple connection's flow control from application's read.
// An update on connection's flow control should not depend on // An update on connection's flow control should not depend on
@ -835,53 +829,30 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// active(fast) streams from starving in presence of slow or // active(fast) streams from starving in presence of slow or
// inactive streams. // inactive streams.
// //
if w := t.fc.onData(size); w > 0 { t.fc.onData(size)
t.controlBuf.put(&outgoingWindowUpdate{ if t.bdpEst != nil && t.bdpEst.add(size) {
streamID: 0,
increment: w,
})
}
if sendBDPPing {
// Avoid excessive ping detection (e.g. in an L7 proxy) // Avoid excessive ping detection (e.g. in an L7 proxy)
// by sending a window update prior to the BDP ping. // by sending a window update prior to the BDP ping.
t.fc.reset()
if w := t.fc.reset(); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
s, ok := t.getStream(f) if s, ok := t.getStream(f); ok {
if !ok { d := f.Data()
return padding := 0
} if f.Header().Flags.Has(http2.FlagDataPadded) {
if size > 0 { padding = int(size) - len(d)
if err := s.fc.onData(size); err != nil { }
if err := s.consume(d, padding); err != nil {
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
return return
} }
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.StreamEnded() {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { // The server has closed the stream without sending trailers. Record that
t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) // the read direction is closed, and set the status appropriately.
} t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
} }
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
}
}
// The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
} }
} }
@ -890,6 +861,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if !ok { if !ok {
return return
} }
errorf("transport: client got RST_STREAM with error %v, for stream: %d", f.ErrCode, s.id)
if f.ErrCode == http2.ErrCodeRefusedStream { if f.ErrCode == http2.ErrCodeRefusedStream {
// The stream was unprocessed by the server. // The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1) atomic.StoreUint32(&s.unprocessed, 1)

View File

@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/channelz" "google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
@ -212,7 +213,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
maxStreams: maxStreams, maxStreams: maxStreams,
inTapHandle: config.InTapHandle, inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable, state: reachable,
activeStreams: make(map[uint32]*Stream), activeStreams: make(map[uint32]*Stream),
stats: config.StatsHandler, stats: config.StatsHandler,
@ -222,6 +222,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
initialWindowSize: iwz, initialWindowSize: iwz,
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
t.fc = &trInFlow{
limit: uint32(icwz),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
},
}
if dynamicWindow { if dynamicWindow {
t.bdpEst = &bdpEstimator{ t.bdpEst = &bdpEstimator{
bdp: initialWindowSize, bdp: initialWindowSize,
@ -298,25 +307,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return return
} }
} }
var (
buf := newRecvBuffer() ctx context.Context
s := &Stream{ cancel func()
id: streamID, )
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding,
method: state.method,
contentSubtype: state.contentSubtype,
}
if frame.StreamEnded() {
// s is just created by the caller. No lock needed.
s.state = streamReadDone
}
if state.timeoutSet { if state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) ctx, cancel = context.WithTimeout(t.ctx, state.timeout)
} else { } else {
s.ctx, s.cancel = context.WithCancel(t.ctx) ctx, cancel = context.WithCancel(t.ctx)
} }
pr := &peer.Peer{ pr := &peer.Peer{
Addr: t.remoteAddr, Addr: t.remoteAddr,
@ -325,34 +323,55 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if t.authInfo != nil { if t.authInfo != nil {
pr.AuthInfo = t.authInfo pr.AuthInfo = t.authInfo
} }
s.ctx = peer.NewContext(s.ctx, pr) ctx = peer.NewContext(ctx, pr)
// Attach the received metadata to the context. // Attach the received metadata to the context.
if len(state.mdata) > 0 { if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) ctx = metadata.NewIncomingContext(ctx, state.mdata)
} }
if state.statsTags != nil { if state.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags) ctx = stats.SetIncomingTags(ctx, state.statsTags)
} }
if state.statsTrace != nil { if state.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) ctx = stats.SetIncomingTrace(ctx, state.statsTrace)
} }
if t.inTapHandle != nil { if t.inTapHandle != nil {
var err error var err error
info := &tap.Info{ info := &tap.Info{
FullMethodName: state.method, FullMethodName: state.method,
} }
s.ctx, err = t.inTapHandle(s.ctx, info) ctx, err = t.inTapHandle(ctx, info)
if err != nil { if err != nil {
warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
t.controlBuf.put(&cleanupStream{ t.controlBuf.put(&cleanupStream{
streamID: s.id, streamID: streamID,
rst: true, rst: true,
rstCode: http2.ErrCodeRefusedStream, rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {}, onWrite: func() {},
}) })
cancel()
return return
} }
} }
ctx = traceCtx(ctx, state.method)
s := newStream(ctx)
// Initialize s with server-side specific fields.
s.cancel = cancel
s.id = streamID
s.st = t
s.fc = &stInFlow{
limit: uint32(t.initialWindowSize),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{streamID: streamID, increment: w})
},
wuThreshold: uint32(t.initialWindowSize / 4),
}
s.recvCompress = state.encoding
s.method = state.method
s.contentSubtype = state.contentSubtype
if frame.StreamEnded() {
s.state = streamReadDone
}
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
t.mu.Lock() t.mu.Lock()
if t.state != reachable { if t.state != reachable {
t.mu.Unlock() t.mu.Unlock()
@ -386,10 +405,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.lastStreamCreated = time.Now() t.lastStreamCreated = time.Now()
t.czmu.Unlock() t.czmu.Unlock()
} }
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
if t.stats != nil { if t.stats != nil {
s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{ inHeader := &stats.InHeader{
@ -401,18 +416,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
} }
t.stats.HandleRPC(s.ctx, inHeader) t.stats.HandleRPC(s.ctx, inHeader)
} }
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
handle(s) handle(s)
return return
} }
@ -490,41 +493,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
return s, true return s, true
} }
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
increment: w,
})
}
}
// updateFlowControl updates the incoming flow control windows // updateFlowControl updates the incoming flow control windows
// for the transport and the stream based on the current bdp // for the transport and the stream based on the current bdp
// estimation. // estimation.
func (t *http2Server) updateFlowControl(n uint32) { func (t *http2Server) updateFlowControl(n uint32) {
t.mu.Lock() t.mu.Lock()
// Update all the current streams' window.
for _, s := range t.activeStreams { for _, s := range t.activeStreams {
s.fc.newLimit(n) s.fc.newLimit(n)
} }
// Update all the future streams' window.
t.initialWindowSize = int32(n) t.initialWindowSize = int32(n)
t.mu.Unlock() t.mu.Unlock()
t.controlBuf.put(&outgoingWindowUpdate{ t.fc.newLimit(n) // Update transport's window.
streamID: 0, // Notify the other side of the updated value.
increment: t.fc.newLimit(n),
})
t.controlBuf.put(&outgoingSettings{ t.controlBuf.put(&outgoingSettings{
ss: []http2.Setting{ ss: []http2.Setting{
{ {
@ -538,9 +520,15 @@ func (t *http2Server) updateFlowControl(n uint32) {
func (t *http2Server) handleData(f *http2.DataFrame) { func (t *http2Server) handleData(f *http2.DataFrame) {
size := f.Header().Length size := f.Header().Length
var sendBDPPing bool if size == 0 {
if t.bdpEst != nil { if f.StreamEnded() {
sendBDPPing = t.bdpEst.add(size) if s, ok := t.getStream(f); ok {
// Received the end of stream from the client.
s.compareAndSwapState(streamActive, streamReadDone)
s.notifyErr(io.EOF)
}
}
return
} }
// Decouple connection's flow control from application's read. // Decouple connection's flow control from application's read.
// An update on connection's flow control should not depend on // An update on connection's flow control should not depend on
@ -550,51 +538,30 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Decoupling the connection flow control will prevent other // Decoupling the connection flow control will prevent other
// active(fast) streams from starving in presence of slow or // active(fast) streams from starving in presence of slow or
// inactive streams. // inactive streams.
if w := t.fc.onData(size); w > 0 { t.fc.onData(size)
t.controlBuf.put(&outgoingWindowUpdate{ if t.bdpEst != nil && t.bdpEst.add(size) {
streamID: 0,
increment: w,
})
}
if sendBDPPing {
// Avoid excessive ping detection (e.g. in an L7 proxy) // Avoid excessive ping detection (e.g. in an L7 proxy)
// by sending a window update prior to the BDP ping. // by sending a window update prior to the BDP ping.
if w := t.fc.reset(); w > 0 { t.fc.reset()
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
t.controlBuf.put(bdpPing) t.controlBuf.put(bdpPing)
} }
// Select the right stream to dispatch. // Select the right stream to dispatch.
s, ok := t.getStream(f) if s, ok := t.getStream(f); ok {
if !ok { d := f.Data()
return padding := 0
} if f.Header().Flags.Has(http2.FlagDataPadded) {
if size > 0 { padding = int(size) - len(d)
if err := s.fc.onData(size); err != nil { }
if err := s.consume(d, padding); err != nil {
errorf("transport: flow control error on server: %v", err)
t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false) t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
return return
} }
if f.Header().Flags.Has(http2.FlagDataPadded) { if f.StreamEnded() {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { // Received the end of stream from the client.
t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) s.compareAndSwapState(streamActive, streamReadDone)
} s.notifyErr(io.EOF)
} }
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
}
}
if f.Header().Flags.Has(http2.FlagDataEndStream) {
// Received the end of stream from the client.
s.compareAndSwapState(streamActive, streamReadDone)
s.write(recvMsg{err: io.EOF})
} }
} }
@ -792,7 +759,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error // Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error). // is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
if !s.headerOk { // Headers haven't been written yet. if !s.headerOk { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil { if err := t.WriteHeader(s, nil); err != nil {
// TODO(mmukhi, dfawley): Make sure this is the right code to return. // TODO(mmukhi, dfawley): Make sure this is the right code to return.
@ -811,6 +778,8 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
} }
} }
// Get a gRPC-specific header for this message.
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
// Add some data to header frame so that we can equally distribute bytes across frames. // Add some data to header frame so that we can equally distribute bytes across frames.
emptyLen := http2MaxFrameLen - len(hdr) emptyLen := http2MaxFrameLen - len(hdr)
if emptyLen > len(data) { if emptyLen > len(data) {

407
transport/stream.go Normal file
View File

@ -0,0 +1,407 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"fmt"
"io"
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const maxInt = int(^uint(0) >> 1)
type streamState uint32
const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // the entire stream is finished.
)
// transport's reader goroutine adds msgdecoder.RecvMsg to it which are later
// read by RPC's reading goroutine.
//
// It is protected by a lock in the Stream that owns it.
type recvBuffer struct {
ctx context.Context
ctxDone <-chan struct{}
c chan *msgdecoder.RecvMsg
mu sync.Mutex
waiting bool
list *msgdecoder.RecvMsgList
}
func newRecvBuffer(ctx context.Context, ctxDone <-chan struct{}) *recvBuffer {
return &recvBuffer{
ctx: ctx,
ctxDone: ctxDone,
c: make(chan *msgdecoder.RecvMsg, 1),
list: &msgdecoder.RecvMsgList{},
}
}
// put adds r to the underlying list if there's no consumer
// waiting, otherwise, it writes on the chan directly.
func (b *recvBuffer) put(r *msgdecoder.RecvMsg) {
b.mu.Lock()
if b.waiting {
b.waiting = false
b.mu.Unlock()
b.c <- r
return
}
b.list.Enqueue(r)
b.mu.Unlock()
}
// getNoBlock returns a msgdecoder.RecvMsg and true status, if there's
// any available.
// If the status is false, the caller must then call
// getWithBlock() before calling getNoBlock() again.
func (b *recvBuffer) getNoBlock() (*msgdecoder.RecvMsg, bool) {
b.mu.Lock()
r := b.list.Dequeue()
if r != nil {
b.mu.Unlock()
return r, true
}
b.waiting = true
b.mu.Unlock()
return nil, false
}
// getWithBlock() blocks until a complete message has been
// received, or an error has occurred or the underlying
// context has expired.
// It must only be called after having called GetNoBlock()
// once.
func (b *recvBuffer) getWithBlock() (*msgdecoder.RecvMsg, error) {
select {
case <-b.ctxDone:
return nil, ContextErr(b.ctx.Err())
case r := <-b.c:
return r, nil
}
}
func (b *recvBuffer) get() (*msgdecoder.RecvMsg, error) {
m, ok := b.getNoBlock()
if ok {
return m, nil
}
m, err := b.getWithBlock()
if err != nil {
return nil, err
}
return m, nil
}
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
wq *writeQuota
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
headerOk bool // becomes true from the first header is about to send
state streamState
status *status.Status // the status error received from the server
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
rbuf *recvBuffer
fc *stInFlow
msgDecoder *msgdecoder.MessageDecoder
readErr error
}
func newStream(ctx context.Context) *Stream {
// Cache the done chan since a Done() call is expensive.
ctxDone := ctx.Done()
s := &Stream{
ctx: ctx,
ctxDone: ctxDone,
rbuf: newRecvBuffer(ctx, ctxDone),
}
dispatch := func(r *msgdecoder.RecvMsg) {
s.rbuf.put(r)
}
s.msgDecoder = msgdecoder.NewMessageDecoder(dispatch)
return s
}
// notifyErr notifies RPC of an error seen by the transport.
//
// Note to developers: This call can unblock Read calls on RPC
// and lead to reading of unprotected fields on stream on the
// client-side. It should only be called from inside
// transport.closeStream() if the stream was initialized or from
// inside the cleanup callback if the stream was not initialized.
func (s *Stream) notifyErr(err error) {
s.rbuf.put(&msgdecoder.RecvMsg{Err: err})
}
// consume is called by transport's reader goroutine for parsing
// and decoding data received for this stream.
func (s *Stream) consume(b []byte, padding int) error {
// Flow control check.
if s.fc != nil { // HandlerServer doesn't use our flow control.
if err := s.fc.onData(uint32(len(b) + padding)); err != nil {
return err
}
}
s.msgDecoder.Decode(b, padding)
return nil
}
// Read reads one whole message from the transport.
// It is called by RPC's goroutine.
// It is not safe to be called concurrently by multiple goroutines.
//
// Returns:
// 1. received message's compression status(true if was compressed)
// 2. Message as a byte slice
// 3. Error, if any.
func (s *Stream) Read(maxRecvMsgSize int) (bool, []byte, error) {
if s.readErr != nil {
return false, nil, s.readErr
}
var (
m *msgdecoder.RecvMsg
err error
)
// First read the underlying message header
if m, err = s.rbuf.get(); err != nil {
s.readErr = err
return false, nil, err
}
if m.Err != nil {
s.readErr = m.Err
return false, nil, s.readErr
}
// Make sure the message being received isn't too large.
if int64(m.Length) > int64(maxInt) {
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", m.Length, maxInt)
return false, nil, s.readErr
}
if m.Length > maxRecvMsgSize {
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", m.Length, maxRecvMsgSize)
return false, nil, s.readErr
}
// Send a window update for the message this RPC is reading.
if s.fc != nil { // HanderServer doesn't use our flow control.
s.fc.onRead(uint32(m.Length + m.Overhead))
}
isCompressed := m.IsCompressed
// Read the message.
if m, err = s.rbuf.get(); err != nil {
s.readErr = err
return false, nil, err
}
if m.Err != nil {
if m.Err == io.EOF {
m.Err = io.ErrUnexpectedEOF
}
s.readErr = m.Err
return false, nil, s.readErr
}
return isCompressed, m.Data, nil
}
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a chanel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
return nil, err
}
// Trailer returns the cached trailer metedata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// It can be safely read only after stream has ended that is either read
// or write have returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ServerTransport returns the underlying ServerTransport for the stream.
// The client side stream always returns nil.
func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, read or write has returned io.EOF.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
return ErrIllegalHeaderWrite
}
s.header = metadata.Join(s.header, md)
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
t := s.ServerTransport()
return t.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.trailer = metadata.Join(s.trailer, md)
return nil
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}

View File

@ -24,10 +24,7 @@ package transport // externally used as import "google.golang.org/grpc/transport
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync"
"sync/atomic"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -39,359 +36,6 @@ import (
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
) )
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
data []byte
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
err error
}
// recvBuffer is an unbounded channel of recvMsg structs.
// Note recvBuffer differs from controlBuffer only in that recvBuffer
// holds a channel of only recvMsg structs instead of objects implementing "item" interface.
// recvBuffer is written to much more often than
// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put"
type recvBuffer struct {
c chan recvMsg
mu sync.Mutex
backlog []recvMsg
err error
}
func newRecvBuffer() *recvBuffer {
b := &recvBuffer{
c: make(chan recvMsg, 1),
}
return b
}
func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock()
if b.err != nil {
b.mu.Unlock()
// An error had occurred earlier, don't accept more
// data or errors.
return
}
b.err = r.err
if len(b.backlog) == 0 {
select {
case b.c <- r:
b.mu.Unlock()
return
default:
}
}
b.backlog = append(b.backlog, r)
b.mu.Unlock()
}
func (b *recvBuffer) load() {
b.mu.Lock()
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = recvMsg{}
b.backlog = b.backlog[1:]
default:
}
}
b.mu.Unlock()
}
// get returns the channel that receives a recvMsg in the buffer.
//
// Upon receipt of a recvMsg, the caller should call load to send another
// recvMsg onto the channel if there is any.
func (b *recvBuffer) get() <-chan recvMsg {
return b.c
}
//
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last []byte // Stores the remaining data in the previous calls.
err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
// read additional data from recv. It blocks if there no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, r.err = r.read(p)
return n, r.err
}
func (r *recvBufferReader) read(p []byte) (n int, err error) {
if r.last != nil && len(r.last) > 0 {
// Read remaining data left in last call.
copied := copy(p, r.last)
r.last = r.last[copied:]
return copied, nil
}
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
r.recv.load()
if m.err != nil {
return 0, m.err
}
copied := copy(p, m.data)
r.last = m.data[copied:]
return copied, nil
}
}
type streamState uint32
const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // the entire stream is finished.
)
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
trReader io.Reader
fc *inFlow
recvQuota uint32
wq *writeQuota
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
headerOk bool // becomes true from the first header is about to send
state streamState
status *status.Status // the status error received from the server
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It is empty string if there is no compression applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a chanel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
return nil, err
}
// Trailer returns the cached trailer metedata. Note that if it is not called
// after the entire stream is done, it could return an empty MD. Client
// side only.
// It can be safely read only after stream has ended that is either read
// or write have returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ServerTransport returns the underlying ServerTransport for the stream.
// The client side stream always returns nil.
func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, read or write has returned io.EOF.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
return ErrIllegalHeaderWrite
}
s.header = metadata.Join(s.header, md)
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
t := s.ServerTransport()
return t.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.trailer = metadata.Join(s.trailer, md)
return nil
}
func (s *Stream) write(m recvMsg) {
s.buf.put(m)
}
// Read reads all p bytes from the wire for this stream.
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil {
return 0, er
}
s.requestRead(len(p))
return io.ReadFull(s.trReader, p)
}
// tranportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
type transportReader struct {
reader io.Reader
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
er error
}
func (t *transportReader) Read(p []byte) (n int, err error) {
n, err = t.reader.Read(p)
if err != nil {
t.er = err
return
}
t.windowHandler(n)
return
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}
// state of transport // state of transport
type transportState int type transportState int
@ -476,7 +120,13 @@ type Options struct {
// Delay is a hint to the transport implementation for whether // Delay is a hint to the transport implementation for whether
// the data could be buffered for a batching write. The // the data could be buffered for a batching write. The
// transport implementation may ignore the hint. // transport implementation may ignore the hint.
// TODO(mmukhi, dfawley): Should this be deleted?
Delay bool Delay bool
// IsCompressed indicates weather the message being written
// was compressed or not. Transport relays this information
// to the API that generates gRPC-specific message header.
IsCompressed bool
} }
// CallHdr carries the information of a particular RPC. // CallHdr carries the information of a particular RPC.
@ -525,7 +175,7 @@ type ClientTransport interface {
// Write sends the data for the given stream. A nil stream indicates // Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole. // the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, data []byte, opts *Options) error
// NewStream creates a Stream for an RPC. // NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@ -573,7 +223,7 @@ type ServerTransport interface {
// Write sends the data for the given stream. // Write sends the data for the given stream.
// Write may not be called on all streams. // Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error Write(s *Stream, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is // WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs. // the final call made on a stream and always occurs.

View File

@ -100,8 +100,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
req = expectedRequestLarge req = expectedRequestLarge
resp = expectedResponseLarge resp = expectedResponseLarge
} }
p := make([]byte, len(req)) _, p, err := s.Read(math.MaxInt32)
_, err := s.Read(p)
if err != nil { if err != nil {
return return
} }
@ -109,31 +108,26 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
t.Fatalf("handleStream got %v, want %v", p, req) t.Fatalf("handleStream got %v, want %v", p, req)
} }
// send a response back to the client. // send a response back to the client.
h.t.Write(s, nil, resp, &Options{}) h.t.Write(s, resp, &Options{})
// send the trailer to end the stream. // send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, "")) h.t.WriteStatus(s, status.New(codes.OK, ""))
} }
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
header := make([]byte, 5)
for { for {
if _, err := s.Read(header); err != nil { _, msg, err := s.Read(math.MaxInt32)
if err != nil {
if err == io.EOF { if err == io.EOF {
h.t.WriteStatus(s, status.New(codes.OK, "")) h.t.WriteStatus(s, status.New(codes.OK, ""))
return return
} }
t.Fatalf("Error on server while reading data header: %v", err) t.Errorf("Error on server while reading data header: %v", err)
return
} }
sz := binary.BigEndian.Uint32(header[1:]) if err := h.t.Write(s, msg, &Options{}); err != nil {
msg := make([]byte, int(sz)) t.Errorf("Error on server while writing: %v", err)
if _, err := s.Read(msg); err != nil { return
t.Fatalf("Error on server while reading message: %v", err)
} }
buf := make([]byte, sz+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
copy(buf[5:], msg)
h.t.Write(s, nil, buf, &Options{})
} }
} }
@ -189,12 +183,10 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
req = expectedRequestLarge req = expectedRequestLarge
resp = expectedResponseLarge resp = expectedResponseLarge
} }
p := make([]byte, len(req))
// Wait before reading. Give time to client to start sending // Wait before reading. Give time to client to start sending
// before server starts reading. // before server starts reading.
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
_, err := s.Read(p) _, p, err := s.Read(math.MaxInt32)
if err != nil { if err != nil {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return return
@ -205,7 +197,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
return return
} }
// send a response back to the client. // send a response back to the client.
if err := h.t.Write(s, nil, resp, &Options{}); err != nil { if err := h.t.Write(s, resp, &Options{}); err != nil {
t.Errorf("server Write got %v, want <nil>", err) t.Errorf("server Write got %v, want <nil>", err)
return return
} }
@ -223,8 +215,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
req = expectedRequestLarge req = expectedRequestLarge
resp = expectedResponseLarge resp = expectedResponseLarge
} }
p := make([]byte, len(req)) _, p, err := s.Read(math.MaxInt32)
_, err := s.Read(p)
if err != nil { if err != nil {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return return
@ -237,7 +228,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
// Wait before sending. Give time to client to start reading // Wait before sending. Give time to client to start reading
// before server starts sending. // before server starts sending.
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
if err := h.t.Write(s, nil, resp, &Options{}); err != nil { if err := h.t.Write(s, resp, &Options{}); err != nil {
t.Errorf("server Write got %v, want <nil>", err) t.Errorf("server Write got %v, want <nil>", err)
return return
} }
@ -442,7 +433,7 @@ func TestInflightStreamClosing(t *testing.T) {
serr := StreamError{Desc: "client connection is closing"} serr := StreamError{Desc: "client connection is closing"}
go func() { go func() {
defer close(donec) defer close(donec)
if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { if _, _, err := stream.Read(math.MaxInt32); err != serr {
t.Errorf("unexpected Stream error %v, expected %v", err, serr) t.Errorf("unexpected Stream error %v, expected %v", err, serr)
} }
}() }()
@ -858,15 +849,14 @@ func TestClientSendAndReceive(t *testing.T) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("failed to send data: %v", err) t.Fatalf("failed to send data: %v", err)
} }
p := make([]byte, len(expectedResponse)) _, p, recvErr := s1.Read(math.MaxInt32)
_, recvErr := s1.Read(p)
if recvErr != nil || !bytes.Equal(p, expectedResponse) { if recvErr != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
} }
_, recvErr = s1.Read(p) _, _, recvErr = s1.Read(math.MaxInt32)
if recvErr != io.EOF { if recvErr != io.EOF {
t.Fatalf("Error: %v; want <EOF>", recvErr) t.Fatalf("Error: %v; want <EOF>", recvErr)
} }
@ -895,16 +885,15 @@ func performOneRPC(ct ClientTransport) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
// The following s.Recv()'s could error out because the // The following s.Recv()'s could error out because the
// underlying transport is gone. // underlying transport is gone.
// //
// Read response // Read response
p := make([]byte, len(expectedResponse)) s.Read(math.MaxInt32)
s.Read(p)
// Read io.EOF // Read io.EOF
s.Read(p) s.Read(math.MaxInt32)
} }
} }
@ -939,14 +928,13 @@ func TestLargeMessage(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
} }
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
} }
p := make([]byte, len(expectedResponseLarge)) if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { t.Errorf("s.Read(math.MaxInt32) = %v, %v, want %v, <nil>", p, err, expectedResponse)
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
} }
if _, err = s.Read(p); err != io.EOF { if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err) t.Errorf("Failed to complete the stream %v; want <EOF>", err)
} }
}() }()
@ -974,19 +962,18 @@ func TestLargeMessageWithDelayRead(t *testing.T) {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
return return
} }
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
return return
} }
p := make([]byte, len(expectedResponseLarge))
// Give time to server to begin sending before client starts reading. // Give time to server to begin sending before client starts reading.
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return return
} }
if _, err = s.Read(p); err != io.EOF { if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err) t.Errorf("Failed to complete the stream %v; want <EOF>", err)
} }
}() }()
@ -1017,16 +1004,15 @@ func TestLargeMessageDelayWrite(t *testing.T) {
// Give time to server to start reading before client starts sending. // Give time to server to start reading before client starts sending.
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
return return
} }
p := make([]byte, len(expectedResponseLarge)) if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
return return
} }
if _, err = s.Read(p); err != io.EOF { if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err) t.Errorf("Failed to complete the stream %v; want <EOF>", err)
} }
}() }()
@ -1047,19 +1033,10 @@ func TestGracefulClose(t *testing.T) {
t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err) t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
} }
msg := make([]byte, 1024) msg := make([]byte, 1024)
outgoingHeader := make([]byte, 5) if err := ct.Write(s, msg, &Options{}); err != nil {
outgoingHeader[0] = byte(0)
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
incomingHeader := make([]byte, 5)
if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil {
t.Fatalf("Error while writing: %v", err) t.Fatalf("Error while writing: %v", err)
} }
if _, err := s.Read(incomingHeader); err != nil { if _, _, err := s.Read(math.MaxInt32); err != nil {
t.Fatalf("Error while reading: %v", err)
}
sz := binary.BigEndian.Uint32(incomingHeader[1:])
recvMsg := make([]byte, int(sz))
if _, err := s.Read(recvMsg); err != nil {
t.Fatalf("Error while reading: %v", err) t.Fatalf("Error while reading: %v", err)
} }
if err = ct.GracefulClose(); err != nil { if err = ct.GracefulClose(); err != nil {
@ -1075,14 +1052,14 @@ func TestGracefulClose(t *testing.T) {
if err == errStreamDrain { if err == errStreamDrain {
return return
} }
ct.Write(str, nil, nil, &Options{Last: true}) ct.Write(str, nil, &Options{Last: true})
if _, err := str.Read(make([]byte, 8)); err != errStreamDrain { if _, _, err := str.Read(math.MaxInt32); err != errStreamDrain {
t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain) t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain)
} }
}() }()
} }
ct.Write(s, nil, nil, &Options{Last: true}) ct.Write(s, nil, &Options{Last: true})
if _, err := s.Read(incomingHeader); err != io.EOF { if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err) t.Fatalf("Client expected EOF from the server. Got: %v", err)
} }
// The stream which was created before graceful close can still proceed. // The stream which was created before graceful close can still proceed.
@ -1110,13 +1087,13 @@ func TestLargeMessageSuspension(t *testing.T) {
}() }()
// Write should not be done successfully due to flow control. // Write should not be done successfully due to flow control.
msg := make([]byte, initialWindowSize*8) msg := make([]byte, initialWindowSize*8)
ct.Write(s, nil, msg, &Options{}) ct.Write(s, msg, &Options{})
err = ct.Write(s, nil, msg, &Options{Last: true}) err = ct.Write(s, msg, &Options{Last: true})
if err != errStreamDone { if err != errStreamDone {
t.Fatalf("Write got %v, want io.EOF", err) t.Fatalf("Write got %v, want io.EOF", err)
} }
expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
if _, err := s.Read(make([]byte, 8)); err != expectedErr { if _, _, err := s.Read(math.MaxInt32); err != expectedErr {
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
} }
ct.Close() ct.Close()
@ -1305,7 +1282,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id) t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
} }
// Exhaust client's connection window. // Exhaust client's connection window.
if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { if err := st.Write(sstream1, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err) t.Fatalf("Server failed to write data. Err: %v", err)
} }
notifyChan = make(chan struct{}) notifyChan = make(chan struct{})
@ -1330,17 +1307,17 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id) t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
} }
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream. // Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil { if err := st.Write(sstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err) t.Fatalf("Server failed to write data. Err: %v", err)
} }
// Client should be able to read data on second stream. // Client should be able to read data on second stream.
if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil { if _, _, err := cstream2.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
} }
// Client should be able to read data on first stream. // Client should be able to read data on first stream.
if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil { if _, _, err := cstream1.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err) t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
} }
} }
@ -1373,7 +1350,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Failed to create 1st stream. Err: %v", err) t.Fatalf("Failed to create 1st stream. Err: %v", err)
} }
// Exhaust server's connection window. // Exhaust server's connection window.
if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil { if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err) t.Fatalf("Client failed to write data. Err: %v", err)
} }
//Client should be able to create another stream and send data on it. //Client should be able to create another stream and send data on it.
@ -1381,7 +1358,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to create 2nd stream. Err: %v", err) t.Fatalf("Failed to create 2nd stream. Err: %v", err)
} }
if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil { if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err) t.Fatalf("Client failed to write data. Err: %v", err)
} }
// Get the streams on server. // Get the streams on server.
@ -1403,11 +1380,11 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
} }
st.mu.Unlock() st.mu.Unlock()
// Reading from the stream on server should succeed. // Reading from the stream on server should succeed.
if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil { if _, _, err := sstream1.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = %v, want <nil>", err) t.Fatalf("_.Read(_) = %v, want <nil>", err)
} }
if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF { if _, _, err := sstream1.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("_.Read(_) = %v, want io.EOF", err) t.Fatalf("_.Read(_) = %v, want io.EOF", err)
} }
@ -1616,11 +1593,10 @@ func TestEncodingRequiredStatus(t *testing.T) {
Last: true, Last: true,
Delay: false, Delay: false,
} }
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone { if err := ct.Write(s, expectedRequest, &opts); err != nil && err != errStreamDone {
t.Fatalf("Failed to write the request: %v", err) t.Fatalf("Failed to write the request: %v", err)
} }
p := make([]byte, http2MaxFrameLen) if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF) t.Fatalf("Read got error %v, want %v", err, io.EOF)
} }
if !reflect.DeepEqual(s.Status(), encodingTestStatus) { if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
@ -1640,8 +1616,7 @@ func TestInvalidHeaderField(t *testing.T) {
if err != nil { if err != nil {
return return
} }
p := make([]byte, http2MaxFrameLen) _, _, err = s.Read(math.MaxInt32)
_, err = s.trReader.(*transportReader).Read(p)
if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) { if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField) t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
} }
@ -1764,26 +1739,17 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
t.Fatalf("Failed to create stream. Err: %v", err) t.Fatalf("Failed to create stream. Err: %v", err)
} }
msg := make([]byte, msgSize) msg := make([]byte, msgSize)
buf := make([]byte, msgSize+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
copy(buf[5:], msg)
opts := Options{} opts := Options{}
header := make([]byte, 5)
for i := 1; i <= 10; i++ { for i := 1; i <= 10; i++ {
if err := ct.Write(cstream, nil, buf, &opts); err != nil { if err := ct.Write(cstream, msg, &opts); err != nil {
t.Fatalf("Error on client while writing message: %v", err) t.Fatalf("Error on client while writing message: %v", err)
} }
if _, err := cstream.Read(header); err != nil { _, recvMsg, err := cstream.Read(math.MaxInt32)
t.Fatalf("Error on client while reading data frame header: %v", err) if err != nil {
}
sz := binary.BigEndian.Uint32(header[1:])
recvMsg := make([]byte, int(sz))
if _, err := cstream.Read(recvMsg); err != nil {
t.Fatalf("Error on client while reading data: %v", err) t.Fatalf("Error on client while reading data: %v", err)
} }
if len(recvMsg) != len(msg) { if !bytes.Equal(recvMsg, msg) {
t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg)) t.Fatalf("Message received by client(len: %d) not equal to what was expected(len: %d)", len(recvMsg), len(msg))
} }
} }
var sstream *Stream var sstream *Stream
@ -1794,8 +1760,8 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
st.mu.Unlock() st.mu.Unlock()
loopyServerStream := st.loopy.estdStreams[sstream.id] loopyServerStream := st.loopy.estdStreams[sstream.id]
loopyClientStream := ct.loopy.estdStreams[cstream.id] loopyClientStream := ct.loopy.estdStreams[cstream.id]
ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream. ct.Write(cstream, nil, &Options{Last: true}) // Close the stream.
if _, err := cstream.Read(header); err != io.EOF { if _, _, err := cstream.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected an EOF from the server. Got: %v", err) t.Fatalf("Client expected an EOF from the server. Got: %v", err)
} }
// Sleep for a little to make sure both sides flush out their buffers. // Sleep for a little to make sure both sides flush out their buffers.
@ -1816,11 +1782,11 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota) t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota)
} }
// Check stream flow control. // Check stream flow control.
if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding { if int(cstream.fc.limit)-int(cstream.fc.rcvd) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding) t.Fatalf("Account mismatch: client stream inflow limit(%d) - rcvd(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.rcvd, st.loopy.oiws, loopyServerStream.bytesOutStanding)
} }
if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding { if int(sstream.fc.limit)-int(sstream.fc.rcvd) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding {
t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding) t.Fatalf("Account mismatch: server stream inflow limit(%d) - rcvd(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.rcvd, ct.loopy.oiws, loopyClientStream.bytesOutStanding)
} }
} }
@ -2000,8 +1966,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh) stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
defer cleanUp() defer cleanUp()
want := httpStatusConvTab[httpStatus] want := httpStatusConvTab[httpStatus]
buf := make([]byte, 8) _, _, err := stream.Read(math.MaxInt32)
_, err := stream.Read(buf)
if err == nil { if err == nil {
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want) t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
} }
@ -2017,8 +1982,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader) stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
defer cleanUp() defer cleanUp()
buf := make([]byte, 8) _, _, err := stream.Read(math.MaxInt32)
_, err := stream.Read(buf)
if err != io.EOF { if err != io.EOF {
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
} }
@ -2035,45 +1999,25 @@ func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
// If any error occurs on a call to Stream.Read, future calls // If any error occurs on a call to Stream.Read, future calls
// should continue to return that same error. // should continue to return that same error.
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
testRecvBuffer := newRecvBuffer() s := newStream(context.Background())
s := &Stream{
ctx: context.Background(),
buf: testRecvBuffer,
requestRead: func(int) {},
}
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: s.buf,
},
windowHandler: func(int) {},
}
testData := make([]byte, 1)
testData[0] = 5
testErr := errors.New("test error") testErr := errors.New("test error")
s.write(recvMsg{data: testData, err: testErr}) s.notifyErr(testErr)
inBuf := make([]byte, 1) pf, inBuf, actualErr := s.Read(math.MaxInt32)
actualCount, actualErr := s.Read(inBuf) if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() {
if actualCount != 0 { t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
} }
s.write(recvMsg{data: testData, err: nil}) testData := make([]byte, 6)
s.write(recvMsg{data: testData, err: errors.New("different error from first")}) testData[0] = byte(1)
binary.BigEndian.PutUint32(testData[1:], uint32(1))
s.consume(testData, 0)
s.notifyErr(errors.New("different error from first"))
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
inBuf := make([]byte, 1) pf, inBuf, actualErr := s.Read(math.MaxInt32)
actualCount, actualErr := s.Read(inBuf) if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() {
if actualCount != 0 { t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
} }
} }
} }
@ -2113,11 +2057,7 @@ func runPingPongTest(t *testing.T, msgSize int) {
t.Fatalf("Failed to create stream. Err: %v", err) t.Fatalf("Failed to create stream. Err: %v", err)
} }
msg := make([]byte, msgSize) msg := make([]byte, msgSize)
outgoingHeader := make([]byte, 5)
outgoingHeader[0] = byte(0)
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
opts := &Options{} opts := &Options{}
incomingHeader := make([]byte, 5)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
timer := time.NewTimer(time.Second * 5) timer := time.NewTimer(time.Second * 5)
@ -2127,23 +2067,22 @@ func runPingPongTest(t *testing.T, msgSize int) {
for { for {
select { select {
case <-done: case <-done:
ct.Write(stream, nil, nil, &Options{Last: true}) ct.Write(stream, nil, &Options{Last: true})
if _, err := stream.Read(incomingHeader); err != io.EOF { if _, _, err := stream.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err) t.Fatalf("Client expected EOF from the server. Got: %v", err)
} }
return return
default: default:
if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil { if err := ct.Write(stream, msg, opts); err != nil {
t.Fatalf("Error on client while writing message. Err: %v", err) t.Fatalf("Error on client while writing message. Err: %v", err)
} }
if _, err := stream.Read(incomingHeader); err != nil { _, recvMsg, err := stream.Read(math.MaxInt32)
t.Fatalf("Error on client while reading data header. Err: %v", err) if err != nil {
}
sz := binary.BigEndian.Uint32(incomingHeader[1:])
recvMsg := make([]byte, int(sz))
if _, err := stream.Read(recvMsg); err != nil {
t.Fatalf("Error on client while reading data. Err: %v", err) t.Fatalf("Error on client while reading data. Err: %v", err)
} }
if !bytes.Equal(recvMsg, msg) {
t.Fatalf("%v != %v", recvMsg, msg)
}
} }
} }
} }