* Export changes to OSS.

* First commit.

* Cherry-pick.

* Documentation.

* Post review updates.
Author: mmukhi
Date: 2018-04-30 09:54:33 -07:00
Committed by: GitHub
parent fc37cf1364
commit 7a8c989507
15 changed files with 1080 additions and 1023 deletions


@ -66,17 +66,16 @@ type testStreamHandler struct {
}
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
p := &parser{r: s}
for {
pf, req, err := p.recvMsg(math.MaxInt32)
isCompressed, req, err := recvMsg(s, math.MaxInt32)
if err == io.EOF {
break
}
if err != nil {
return
}
if pf != compressionNone {
t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
if isCompressed {
t.Errorf("Received compressed message want non-compressed message")
return
}
var v string
@ -105,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
h.t.Write(s, hdr, data, &transport.Options{})
h.t.Write(s, data, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}


@ -0,0 +1,203 @@
/*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package msgdecoder contains the logic to deconstruct a gRPC message.
package msgdecoder
import (
"encoding/binary"
)
// RecvMsg is a message constructed from the incoming
// bytes on the transport for a stream.
// An instance of RecvMsg will contain only one of the
// following: message-header fields, a data slice,
// or an error.
type RecvMsg struct {
// The following three are message-header fields.
// true if the message was compressed by the other
// side.
IsCompressed bool
// Length of the message.
Length int
// Overhead is the length of the message header (5 bytes)
// plus padding.
Overhead int
// Data payload of the message.
Data []byte
// Err is the error, if any, that occurred while reading:
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
Err error
Next *RecvMsg
}
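
Because exactly one of those three groups is populated, a consumer of a RecvMsg typically branches on Err first, then Data, then the header fields. A minimal illustrative sketch (handleRecv is a hypothetical name, an fmt import is assumed, and the zero-length-message corner case, whose body notification carries a nil Data, is glossed over):

// handleRecv is a hypothetical consumer of RecvMsg values. It branches on
// whichever of the mutually exclusive payloads is populated.
func handleRecv(r *RecvMsg) {
	switch {
	case r.Err != nil:
		// io.EOF: the stream completed; anything else: transport failure.
		fmt.Printf("stream error: %v\n", r.Err)
	case r.Data != nil:
		// A complete message body.
		fmt.Printf("message of %d bytes\n", len(r.Data))
	default:
		// Header notification: Length more bytes are on their way.
		fmt.Printf("expecting %d bytes (compressed=%v, overhead=%d)\n",
			r.Length, r.IsCompressed, r.Overhead)
	}
}
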
// RecvMsgList is a linked-list of RecvMsg.
type RecvMsgList struct {
head *RecvMsg
tail *RecvMsg
}
// IsEmpty returns true when l is empty.
func (l *RecvMsgList) IsEmpty() bool {
return l.tail == nil
}
// Enqueue adds r to l at the back.
func (l *RecvMsgList) Enqueue(r *RecvMsg) {
if l.IsEmpty() {
l.head, l.tail = r, r
return
}
t := l.tail
l.tail = r
t.Next = r
}
// Dequeue removes a RecvMsg from the front of l.
func (l *RecvMsgList) Dequeue() *RecvMsg {
if l.head == nil {
// Note to developer: Instead of calling IsEmpty(), which
// checks the same condition on l.tail, we check l.head directly
// so that the common non-empty case doesn't incur an extra cache miss.
return nil
}
r := l.head
l.head = l.head.Next
if l.head == nil {
l.tail = nil
}
return r
}
// MessageDecoder decodes bytes from HTTP2 data frames
// and constructs gRPC messages which are then put in a
// buffer that applications (RPCs) read from.
// gRPC message format:
// The first 5 bytes are the message header:
//   First byte: payload format (compressed or not).
//   Next 4 bytes: length of the message, big-endian.
// The rest of the bytes are the message payload.
//
// TODO(mmukhi): Write unit tests.
type MessageDecoder struct {
// current message being read by the transport.
current *RecvMsg
dataOfst int
padding int
// hdr stores the message header as it is being received by the transport.
hdr []byte
hdrOfst int
// Callback used to send decoded messages.
dispatch func(*RecvMsg)
}
// NewMessageDecoder creates an instance of MessageDecoder. It takes a callback
// which is called to dispatch finished headers and messages to the application.
func NewMessageDecoder(dispatch func(*RecvMsg)) *MessageDecoder {
return &MessageDecoder{
hdr: make([]byte, 5),
dispatch: dispatch,
}
}
// Decode consumes bytes from a HTTP2 data frame to create gRPC messages.
func (m *MessageDecoder) Decode(b []byte, padding int) {
m.padding += padding
for len(b) > 0 {
// Case 1: A complete message hdr was received earlier.
if m.current != nil {
n := copy(m.current.Data[m.dataOfst:], b)
m.dataOfst += n
b = b[n:]
if m.dataOfst == len(m.current.Data) { // Message is complete.
m.dispatch(m.current)
m.current = nil
m.dataOfst = 0
}
continue
}
// Case 2a: No message header has been received yet.
if m.hdrOfst == 0 {
// case 2a.1: b has the whole header
if len(b) >= 5 {
m.parseHeader(b[:5])
b = b[5:]
continue
}
// case 2a.2: b has partial header
n := copy(m.hdr, b)
m.hdrOfst = n
b = b[n:]
continue
}
// Case 2b: Partial message header was received earlier.
n := copy(m.hdr[m.hdrOfst:], b)
m.hdrOfst += n
b = b[n:]
if m.hdrOfst == 5 { // hdr is complete.
m.hdrOfst = 0
m.parseHeader(m.hdr)
}
}
}
func (m *MessageDecoder) parseHeader(b []byte) {
length := int(binary.BigEndian.Uint32(b[1:5]))
hdr := &RecvMsg{
IsCompressed: int(b[0]) == 1,
Length: length,
Overhead: m.padding + 5,
}
m.padding = 0
// Dispatch the information retrieved from the message header so
// that the RPC goroutine can send a proactive window update as we
// wait for the rest of the message.
m.dispatch(hdr)
if length == 0 {
m.dispatch(&RecvMsg{})
return
}
m.current = &RecvMsg{
Data: getMem(length),
}
}
func getMem(l int) []byte {
// TODO(mmukhi): Reuse this memory.
return make([]byte, l)
}
// CreateMessageHeader creates a gRPC-specific message header.
func CreateMessageHeader(l int, isCompressed bool) []byte {
// TODO(mmukhi): Investigate if this memory is worth
// reusing.
hdr := make([]byte, 5)
if isCompressed {
hdr[0] = byte(1)
}
binary.BigEndian.PutUint32(hdr[1:], uint32(l))
return hdr
}
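
CreateMessageHeader and MessageDecoder round-trip cleanly. A minimal sketch, written as if it lived in this package (internal packages are not importable from outside the grpc module), that frames one message and replays it in two chunks to exercise the partial-header paths (cases 2a.2 and 2b of Decode):

package msgdecoder

import "fmt"

func ExampleMessageDecoder() {
	payload := []byte("hello")
	frame := append(CreateMessageHeader(len(payload), false), payload...)

	var got []*RecvMsg
	dec := NewMessageDecoder(func(r *RecvMsg) { got = append(got, r) })

	// Split mid-header so the first call buffers 3 header bytes and the
	// second completes the header and then the body.
	dec.Decode(frame[:3], 0)
	dec.Decode(frame[3:], 0)

	fmt.Println(got[0].Length, got[0].Overhead) // header notification
	fmt.Println(string(got[1].Data))            // message body
	// Output:
	// 5 5
	// hello
}
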


@ -0,0 +1,81 @@
/*
* Copyright 2018 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package msgdecoder
import (
"encoding/binary"
"reflect"
"testing"
)
func TestMessageDecoder(t *testing.T) {
for _, test := range []struct {
numFrames int
data []string
}{
{1, []string{"abc"}}, // One message per frame.
{1, []string{"abc", "def", "ghi"}}, // Multiple messages per frame.
{3, []string{"a", "bcdef", "ghif"}}, // Multiple messages over multiple frames.
} {
var want []*RecvMsg
for _, d := range test.data {
want = append(want, &RecvMsg{Length: len(d), Overhead: 5})
want = append(want, &RecvMsg{Data: []byte(d)})
}
var got []*RecvMsg
dcdr := NewMessageDecoder(func(r *RecvMsg) { got = append(got, r) })
for _, fr := range createFrames(test.numFrames, test.data) {
dcdr.Decode(fr, 0)
}
if !match(got, want) {
t.Fatalf("got: %v, want: %v", got, want)
}
}
}
func match(got, want []*RecvMsg) bool {
for i, v := range got {
if !reflect.DeepEqual(v, want[i]) {
return false
}
}
return true
}
func createFrames(n int, msgs []string) [][]byte {
var b []byte
for _, m := range msgs {
payload := []byte(m)
hdr := make([]byte, 5)
binary.BigEndian.PutUint32(hdr[1:], uint32(len(payload)))
b = append(b, hdr...)
b = append(b, payload...)
}
// break b into n parts.
var result [][]byte
batch := len(b) / n
for len(b) != 0 {
sz := batch
if len(b) < sz {
sz = len(b)
}
result = append(result, b[:sz])
b = b[sz:]
}
return result
}


@ -21,7 +21,6 @@ package grpc
import (
"bytes"
"compress/gzip"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -415,85 +414,39 @@ func (o CustomCodecCallOption) before(c *callInfo) error {
}
func (o CustomCodecCallOption) after(c *callInfo) {}
// The format of the payload: compressed or not?
type payloadFormat uint8
const (
compressionNone payloadFormat = iota // no compression
compressionMade
)
// parser reads complete gRPC messages from the underlying reader.
type parser struct {
// r is the underlying reader.
// See the comment on recvMsg for the permissible
// error types.
r io.Reader
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte
}
// recvMsg reads a complete gRPC message from the stream.
//
// It returns the message and its payload (compression/encoding)
// format. The caller owns the returned msg memory.
// It returns a flag that is true if the message was compressed,
// the message as a byte slice, and an error, if any.
// The caller owns the returned msg memory.
//
// If there is an error, possible values are:
// * io.EOF, when no messages remain
// * io.ErrUnexpectedEOF
// * of type transport.ConnectionError
// * of type transport.StreamError
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := p.r.Read(p.header[:]); err != nil {
return 0, nil, err
// No other error values or types must be returned.
func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) {
isCompressed, msg, err := s.Read(maxRecvMsgSize)
if err != nil {
return false, nil, err
}
pf = payloadFormat(p.header[0])
length := binary.BigEndian.Uint32(p.header[1:])
if length == 0 {
return pf, nil, nil
}
if int64(length) > int64(maxInt) {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
}
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, nil, err
}
return pf, msg, nil
return isCompressed, msg, nil
}
// encode serializes msg and returns a buffer of message header and a buffer of msg.
// If msg is nil, it generates the message header and an empty msg buffer.
// encode serializes msg and returns a buffer of msg.
// If msg is nil, it generates an empty buffer.
// TODO(ddyihai): eliminate extra Compressor parameter.
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) {
var (
b []byte
cbuf *bytes.Buffer
)
const (
payloadLen = 1
sizeLen = 4
)
if msg != nil {
var err error
b, err = c.Marshal(msg)
if err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
if outPayload != nil {
outPayload.Payload = msg
@ -507,49 +460,36 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
z.Close()
} else {
// If Compressor is not set by UseCompressor, use default Compressor
if err := cp.Do(cbuf, b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
}
b = cbuf.Bytes()
}
}
if uint(len(b)) > math.MaxUint32 {
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
}
bufHeader := make([]byte, payloadLen+sizeLen)
if compressor != nil || cp != nil {
bufHeader[0] = byte(compressionMade)
} else {
bufHeader[0] = byte(compressionNone)
}
// Write length of b into buf
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
if outPayload != nil {
outPayload.WireLength = payloadLen + sizeLen + len(b)
// A 5-byte gRPC-specific message header will be added to this message
// before it is put on the wire.
outPayload.WireLength = 5 + len(b)
}
return bufHeader, b, nil
return b, nil
}
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" || recvCompress == encoding.Identity {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status {
if recvCompress == "" || recvCompress == encoding.Identity {
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
}
if !haveCompressor {
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
return nil
}
@ -557,8 +497,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
// For the two compressor parameters, both should not be set, but if they are,
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
pf, d, err := p.recvMsg(maxReceiveMessageSize)
func recv(c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
isCompressed, d, err := recvMsg(s, maxReceiveMessageSize)
if err != nil {
return err
}
@ -566,11 +506,10 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
inPayload.WireLength = len(d)
}
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return st.Err()
}
if pf == compressionMade {
if isCompressed {
if st := checkRecvPayload(s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return st.Err()
}
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
@ -588,11 +527,11 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
}
if len(d) > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
if len(d) > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
}
}
if err := c.Unmarshal(d, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
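
Although encode no longer returns the 5-byte header (the transport now builds it at write time via msgdecoder.CreateMessageHeader), the wire layout is unchanged: one compressed-flag byte followed by a big-endian uint32 length. A worked sketch of the header for a 1024-byte compressed message:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	hdr := make([]byte, 5)
	hdr[0] = 1 // payload is compressed
	binary.BigEndian.PutUint32(hdr[1:], 1024)
	fmt.Printf("% x\n", hdr)
	// Output: 01 00 00 04 00
}
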


@ -22,7 +22,6 @@ import (
"bytes"
"compress/gzip"
"io"
"math"
"reflect"
"testing"
@ -45,77 +44,20 @@ func (f fullReader) Read(p []byte) (int, error) {
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
func TestSimpleParsing(t *testing.T) {
bigMsg := bytes.Repeat([]byte{'x'}, 1<<24)
for _, test := range []struct {
// input
p []byte
// outputs
err error
b []byte
pt payloadFormat
}{
{nil, io.EOF, nil, compressionNone},
{[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone},
{[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone},
{[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone},
{[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone},
// Check that messages with length >= 2^24 are parsed.
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
}
}
}
func TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}
wantRecvs := []struct {
pt payloadFormat
data []byte
}{
{compressionNone, []byte("a")},
{compressionNone, []byte("bc")},
{compressionNone, []byte("d")},
}
for i, want := range wantRecvs {
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
i, p, pt, data, err, want.pt, want.data)
}
}
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != io.EOF {
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
len(wantRecvs), p, pt, data, err, io.EOF)
}
}
func TestEncode(t *testing.T) {
for _, test := range []struct {
// input
msg proto.Message
cp Compressor
// outputs
hdr []byte
data []byte
err error
}{
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
{nil, nil, []byte{}, nil},
} {
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err)
}
}
}
@ -214,8 +156,11 @@ func TestParseDialTarget(t *testing.T) {
func bmEncode(b *testing.B, mSize int) {
cdc := encoding.GetCodec(protoenc.Name)
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil)
encodedSz := int64(len(encodeHdr) + len(encodeData))
encodeData, _ := encode(cdc, msg, nil, nil, nil)
// A 5-byte gRPC-specific message header is added
// to the message before it is written to the wire.
encodedSz := int64(5 + len(encodeData))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {


@ -831,7 +831,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err)
return err
@ -839,7 +839,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if len(data) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
}
err = t.Write(stream, hdr, data, opts)
opts.IsCompressed = cp != nil || comp != nil
err = t.Write(stream, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
@ -924,8 +925,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
}
p := &parser{r: stream}
pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
isCompressed, req, err := recvMsg(stream, s.opts.maxReceiveMessageSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
@ -955,12 +955,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if channelz.IsOn() {
t.IncrMsgRecv()
}
if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
if e := t.WriteStatus(stream, st); e != nil {
grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
}
return st.Err()
}
var inPayload *stats.InPayload
if sh != nil {
inPayload = &stats.InPayload{
@ -971,7 +965,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if inPayload != nil {
inPayload.WireLength = len(req)
}
if pf == compressionMade {
if isCompressed {
if st := checkRecvPayload(stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
return st.Err()
}
var err error
if dc != nil {
req, err = dc.Do(bytes.NewReader(req))
@ -985,11 +982,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
}
if len(req) > s.opts.maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
if len(req) > s.opts.maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
}
}
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
@ -1100,7 +1097,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ctx: ctx,
t: t,
s: stream,
p: &parser{r: stream},
codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize,


@ -290,7 +290,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
attempt: &csAttempt{
t: t,
s: s,
p: &parser{r: s},
done: done,
dc: cc.dopts.dc,
ctx: ctx,
@ -347,7 +346,6 @@ type csAttempt struct {
cs *clientStream
t transport.ClientTransport
s *transport.Stream
p *parser
done func(balancer.DoneInfo)
dc Decompressor
@ -472,7 +470,7 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
Client: true,
}
}
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
if err != nil {
return err
}
@ -482,7 +480,11 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
if !cs.desc.ClientStreams {
cs.sentLast = true
}
err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
opts := &transport.Options{
Last: !cs.desc.ClientStreams,
IsCompressed: cs.cp != nil || cs.comp != nil,
}
err = a.t.Write(a.s, data, opts)
if err == nil {
if outPayload != nil {
outPayload.SentTime = time.Now()
@ -526,7 +528,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
// Only initialize this state once per stream.
a.decompSet = true
}
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
if err != nil {
if err == io.EOF {
if statusErr := a.s.Status().Err(); statusErr != nil {
@ -556,7 +558,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
// Special handling for non-server-stream rpcs.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
}
@ -572,7 +574,7 @@ func (a *csAttempt) closeSend() {
return
}
cs.sentLast = true
cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true})
cs.attempt.t.Write(cs.attempt.s, nil, &transport.Options{Last: true})
// We ignore errors from Write. Any error it would return would also be
// returned by a subsequent RecvMsg call, and the user is supposed to always
// finish the stream by calling RecvMsg until it returns err != nil.
@ -635,7 +637,6 @@ type serverStream struct {
ctx context.Context
t transport.ServerTransport
s *transport.Stream
p *parser
codec baseCodec
cp Compressor
@ -700,14 +701,18 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
if err != nil {
return err
}
if len(data) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
}
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
opts := &transport.Options{
Last: false,
IsCompressed: ss.cp != nil || ss.comp != nil,
}
if err := ss.t.Write(ss.s, data, opts); err != nil {
return toRPCErr(err)
}
if outPayload != nil {
@ -743,7 +748,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
if ss.statsHandler != nil {
inPayload = &stats.InPayload{}
}
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
if err := recv(ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
if err == io.EOF {
return err
}


@ -21,7 +21,6 @@ package transport
import (
"fmt"
"math"
"sync"
"sync/atomic"
"time"
)
@ -96,35 +95,39 @@ func (w *writeQuota) replenish(n int) {
}
type trInFlow struct {
limit uint32
unacked uint32
effectiveWindowSize uint32
limit uint32 // accessed by reader goroutine.
unacked uint32 // accessed by reader goroutine.
effectiveWindowSize uint32 // accessed by reader and channelz request goroutine.
// Callback used to schedule window update.
scheduleWU func(uint32)
}
func (f *trInFlow) newLimit(n uint32) uint32 {
d := n - f.limit
// Sets the new limit.
func (f *trInFlow) newLimit(n uint32) {
if n > f.limit {
f.scheduleWU(n - f.limit)
}
f.limit = n
f.updateEffectiveWindowSize()
return d
}
func (f *trInFlow) onData(n uint32) uint32 {
func (f *trInFlow) onData(n uint32) {
f.unacked += n
if f.unacked >= f.limit/4 {
w := f.unacked
f.unacked = 0
f.updateEffectiveWindowSize()
return w
f.scheduleWU(w)
}
f.updateEffectiveWindowSize()
return 0
}
func (f *trInFlow) reset() uint32 {
w := f.unacked
func (f *trInFlow) reset() {
if f.unacked == 0 {
return
}
f.scheduleWU(f.unacked)
f.unacked = 0
f.updateEffectiveWindowSize()
return w
}
func (f *trInFlow) updateEffectiveWindowSize() {
@ -135,102 +138,57 @@ func (f *trInFlow) getSize() uint32 {
return atomic.LoadUint32(&f.effectiveWindowSize)
}
// TODO(mmukhi): Simplify this code.
// inFlow deals with inbound flow control
type inFlow struct {
mu sync.Mutex
// stInFlow deals with inbound flow control for stream.
// It can be simultaneously read by transport's reader
// goroutine and an RPC's goroutine.
// It is protected by the lock in stream that owns it.
type stInFlow struct {
// rcvd is the number of bytes of data that this end-point has
// received, from the perspective of the other side.
// It can go negative. It must be accessed atomically.
// Needs to be aligned because of golang bug with atomics:
// https://golang.org/pkg/sync/atomic/#pkg-note-BUG
rcvd int64
// The inbound flow control limit for pending data.
limit uint32
// pendingData is the overall data which have been received but not been
// consumed by applications.
pendingData uint32
// The amount of data the application has consumed but grpc has not sent
// window update for them. Used to reduce window update frequency.
pendingUpdate uint32
// delta is the extra window update given by receiver when an application
// is reading data bigger in size than the inFlow limit.
delta uint32
// number of bytes received so far, this should be accessed
// number of bytes that have been read by the RPC.
read uint32
// a window update should be sent when the RPC has
// read this many bytes.
// TODO(mmukhi, dfawley): Does this have to be limit/4?
// Keeping it a constant makes implementation easy.
wuThreshold uint32
// Callback used to schedule window update.
scheduleWU func(uint32)
}
// newLimit updates the inflow window to a new value n.
// It assumes that n is always greater than the old limit.
func (f *inFlow) newLimit(n uint32) uint32 {
f.mu.Lock()
d := n - f.limit
f.limit = n
f.mu.Unlock()
return d
// called by transport's reader goroutine to set a new limit on
// incoming flow control based on BDP estimation.
func (s *stInFlow) newLimit(n uint32) {
s.limit = n
}
func (f *inFlow) maybeAdjust(n uint32) uint32 {
if n > uint32(math.MaxInt32) {
n = uint32(math.MaxInt32)
// called by transport's reader goroutine when data is received by it.
func (s *stInFlow) onData(n uint32) error {
rcvd := atomic.AddInt64(&s.rcvd, int64(n))
if rcvd > int64(s.limit) { // Flow control violation.
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, s.limit)
}
f.mu.Lock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
// estUntransmittedData is the maximum number of bytes the sender might not have put
// on the wire yet. A value of 0 or less means that we have already received all or
// more bytes than the application is requesting to read.
estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
// This implies that unless we send a window update, the sender won't be able to send all the bytes
// for this message. Therefore we must send an update over the limit since there's an active read
// request from the application.
if estUntransmittedData > estSenderQuota {
// Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec.
if f.limit+n > maxWindowSize {
f.delta = maxWindowSize - f.limit
} else {
// Send a window update for the whole message and not just the difference between
// estUntransmittedData and estSenderQuota. This will be helpful in case the message
// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n
}
f.mu.Unlock()
return f.delta
}
f.mu.Unlock()
return 0
}
// onData is invoked when some data frame is received. It updates pendingData.
func (f *inFlow) onData(n uint32) error {
f.mu.Lock()
f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit+f.delta {
limit := f.limit
rcvd := f.pendingData + f.pendingUpdate
f.mu.Unlock()
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit)
}
f.mu.Unlock()
return nil
}
// onRead is invoked when the application reads the data. It returns the window size
// to be sent to the peer.
func (f *inFlow) onRead(n uint32) uint32 {
f.mu.Lock()
if f.pendingData == 0 {
f.mu.Unlock()
return 0
// called by RPC's goroutine when data is read by it.
func (s *stInFlow) onRead(n uint32) {
s.read += n
if s.read >= s.wuThreshold {
val := atomic.AddInt64(&s.rcvd, ^int64(s.read-1))
// Check if threshold needs to go up since limit might have gone up.
val += int64(s.read)
if val > int64(4*s.wuThreshold) {
s.wuThreshold = uint32(val / 4)
}
s.scheduleWU(s.read)
s.read = 0
}
f.pendingData -= n
if n > f.delta {
n -= f.delta
f.delta = 0
} else {
f.delta -= n
n = 0
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate
f.pendingUpdate = 0
f.mu.Unlock()
return wu
}
f.mu.Unlock()
return 0
}
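
The threshold-based bookkeeping above can be exercised in isolation. A self-contained toy mirroring stInFlow's onData/onRead accounting (names are illustrative since the real type is unexported, and the threshold-growth adjustment is omitted):

package main

import (
	"fmt"
	"sync/atomic"
)

// toyInFlow mirrors stInFlow: rcvd tracks bytes in flight from the peer's
// perspective; read accumulates bytes consumed by the RPC until a threshold,
// at which point a window update is scheduled and rcvd is debited.
type toyInFlow struct {
	rcvd        int64
	limit       uint32
	read        uint32
	wuThreshold uint32
	scheduleWU  func(uint32)
}

func (s *toyInFlow) onData(n uint32) error {
	if rcvd := atomic.AddInt64(&s.rcvd, int64(n)); rcvd > int64(s.limit) {
		return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, s.limit)
	}
	return nil
}

func (s *toyInFlow) onRead(n uint32) {
	s.read += n
	if s.read >= s.wuThreshold {
		atomic.AddInt64(&s.rcvd, -int64(s.read))
		s.scheduleWU(s.read)
		s.read = 0
	}
}

func main() {
	f := &toyInFlow{
		limit:       64,
		wuThreshold: 16,
		scheduleWU:  func(w uint32) { fmt.Println("window update:", w) },
	}
	f.onData(20)
	f.onRead(10) // below threshold: no update yet
	f.onRead(10) // crosses 16: schedules a 20-byte window update
}
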


@ -38,6 +38,7 @@ import (
"golang.org/x/net/http2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/stats"
@ -269,10 +270,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
}
}
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
return ht.do(func() {
ht.writeCommonHeaders(s)
ht.rw.Write(hdr)
ht.rw.Write(msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed))
ht.rw.Write(data)
if !opts.Delay {
ht.rw.(http.Flusher).Flush()
@ -337,16 +338,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req
s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
s := newStream(ctx)
s.cancel = cancel
s.st = ht
s.method = req.URL.Path
s.recvCompress = req.Header.Get("grpc-encoding")
s.contentSubtype = ht.contentSubtype
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
}
@ -364,10 +362,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
}
ht.stats.HandleRPC(s.ctx, inHeader)
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
windowHandler: func(int) {},
}
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
@ -379,11 +373,11 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf)
if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]})
s.consume(buf[:n:n], 0)
buf = buf[n:]
}
if err != nil {
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
s.notifyErr(mapRecvMsgError(err))
return
}
if len(buf) == 0 {


@ -423,7 +423,7 @@ func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
st.ht.Write(s, []byte("data"), &Options{})
})
}


@ -34,6 +34,7 @@ import (
"google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@ -95,8 +96,9 @@ type http2Client struct {
waitingStreams uint32
nextID uint32
mu sync.Mutex // guard the following variables
state transportState
mu sync.Mutex // guard the following variables
state transportState
// TODO(mmukhi): Make this a sharded map.
activeStreams map[uint32]*Stream
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
prevGoAwayID uint32
@ -218,7 +220,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
goAway: make(chan struct{}),
awakenKeepalive: make(chan struct{}, 1),
framer: newFramer(conn, writeBufSize, readBufSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
activeStreams: make(map[uint32]*Stream),
isSecure: isSecure,
@ -233,6 +234,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
streamsQuotaAvailable: make(chan struct{}, 1),
}
t.controlBuf = newControlBuffer(t.ctxDone)
t.fc = &trInFlow{
limit: uint32(icwz),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
},
}
if opts.InitialWindowSize >= defaultWindowSize {
t.initialWindowSize = opts.InitialWindowSize
dynamicWindow = false
@ -306,33 +316,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
done: make(chan struct{}),
method: callHdr.Method,
sendCompress: callHdr.SendCompress,
buf: newRecvBuffer(),
headerChan: make(chan struct{}),
contentSubtype: callHdr.ContentSubtype,
}
s.wq = newWriteQuota(defaultWriteQuota, s.done)
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
// The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
s := newStream(ctx)
// Initialize stream with client-side specific fields.
s.done = make(chan struct{})
s.method = callHdr.Method
s.sendCompress = callHdr.SendCompress
s.headerChan = make(chan struct{})
s.contentSubtype = callHdr.ContentSubtype
s.wq = newWriteQuota(defaultWriteQuota, s.done)
return s
}
@ -504,7 +498,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
}
// The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1)
s.write(recvMsg{err: err})
s.notifyErr(err)
close(s.done)
// If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
@ -572,7 +566,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
h.streamID = t.nextID
t.nextID += 2
s.id = h.streamID
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
s.fc = &stInFlow{
limit: uint32(t.initialWindowSize),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
},
wuThreshold: uint32(t.initialWindowSize / 4),
}
if t.streamQuota > 0 && t.waitingStreams > 0 {
select {
case t.streamsQuotaAvailable <- struct{}{}:
@ -642,7 +642,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
}
if err != nil {
// This will unblock reads eventually.
s.write(recvMsg{err: err})
s.notifyErr(err)
}
// This will unblock write.
close(s.done)
@ -740,7 +740,7 @@ func (t *http2Client) GracefulClose() error {
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
// should proceed only if Write returns nil.
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
if opts.Last {
// If it's the last message, update stream state.
if !s.compareAndSwapState(streamActive, streamWriteDone) {
@ -753,7 +753,9 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
streamID: s.id,
endStream: opts.Last,
}
if hdr != nil || data != nil { // If it's not an empty data frame.
if data != nil { // If it's not an empty data frame.
// Get a gRPC-specific header for this message.
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
// Add some data to grpc message header so that we can equally
// distribute bytes across frames.
emptyLen := http2MaxFrameLen - len(hdr)
@ -778,39 +780,19 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
return s, ok
}
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateWindow adjusts the inbound quota for the stream.
// Window updates will be sent out when the cumulative quota
// exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateFlowControl updates the incoming flow control windows
// for the transport and the stream based on the current bdp
// estimation.
func (t *http2Client) updateFlowControl(n uint32) {
t.mu.Lock()
for _, s := range t.activeStreams {
s.fc.newLimit(n)
}
t.mu.Unlock()
updateIWS := func(interface{}) bool {
t.fc.newLimit(n) // Update transport's window.
updateIWS := func(interface{}) bool { // Update streams' windows.
// All future streams should see the
// updated value.
t.initialWindowSize = int32(n)
return true
}
t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)})
t.controlBuf.put(&outgoingSettings{
// Notify the other side of updated window.
t.controlBuf.executeAndPut(updateIWS, &outgoingSettings{
ss: []http2.Setting{
{
ID: http2.SettingInitialWindowSize,
@ -818,13 +800,25 @@ func (t *http2Client) updateFlowControl(n uint32) {
},
},
})
t.mu.Lock()
// Update all the currently active streams.
for _, s := range t.activeStreams {
s.fc.newLimit(n)
}
t.mu.Unlock()
}
func (t *http2Client) handleData(f *http2.DataFrame) {
size := f.Header().Length
var sendBDPPing bool
if t.bdpEst != nil {
sendBDPPing = t.bdpEst.add(size)
if size == 0 {
if f.StreamEnded() {
// The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
if s, ok := t.getStream(f); ok {
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
}
}
return
}
// Decouple connection's flow control from application's read.
// An update on connection's flow control should not depend on
@ -835,53 +829,30 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// active(fast) streams from starving in presence of slow or
// inactive streams.
//
if w := t.fc.onData(size); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
if sendBDPPing {
t.fc.onData(size)
if t.bdpEst != nil && t.bdpEst.add(size) {
// Avoid excessive ping detection (e.g. in an L7 proxy)
// by sending a window update prior to the BDP ping.
if w := t.fc.reset(); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
t.fc.reset()
t.controlBuf.put(bdpPing)
}
// Select the right stream to dispatch.
s, ok := t.getStream(f)
if !ok {
return
}
if size > 0 {
if err := s.fc.onData(size); err != nil {
if s, ok := t.getStream(f); ok {
d := f.Data()
padding := 0
if f.Header().Flags.Has(http2.FlagDataPadded) {
padding = int(size) - len(d)
}
if err := s.consume(d, padding); err != nil {
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
return
}
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
}
if f.StreamEnded() {
// The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
}
}
// The server has closed the stream without sending trailers. Record that
// the read direction is closed, and set the status appropriately.
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
}
}
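
A note on the padding argument: f.Header().Length includes any frame padding, while f.Data() excludes it, so the difference is charged to flow control (and surfaces in RecvMsg.Overhead) even though the application never sees those bytes. A small sketch of the arithmetic with hypothetical numbers:

package main

import "fmt"

func main() {
	// Hypothetical padded DATA frame: the frame-header length counts the
	// pad-length octet and padding, while Data() is only the payload.
	frameLen := uint32(100)  // f.Header().Length
	data := make([]byte, 90) // len(f.Data())
	padding := int(frameLen) - len(data)
	fmt.Println(padding) // 10 bytes charged to flow control, invisible to the app
}
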
@ -890,6 +861,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
if !ok {
return
}
errorf("transport: client got RST_STREAM with error %v, for stream: %d", f.ErrCode, s.id)
if f.ErrCode == http2.ErrCodeRefusedStream {
// The stream was unprocessed by the server.
atomic.StoreUint32(&s.unprocessed, 1)


@ -39,6 +39,7 @@ import (
"google.golang.org/grpc/channelz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@ -212,7 +213,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
writerDone: make(chan struct{}),
maxStreams: maxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
activeStreams: make(map[uint32]*Stream),
stats: config.StatsHandler,
@ -222,6 +222,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
initialWindowSize: iwz,
}
t.controlBuf = newControlBuffer(t.ctxDone)
t.fc = &trInFlow{
limit: uint32(icwz),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
},
}
if dynamicWindow {
t.bdpEst = &bdpEstimator{
bdp: initialWindowSize,
@ -298,25 +307,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return
}
}
buf := newRecvBuffer()
s := &Stream{
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
recvCompress: state.encoding,
method: state.method,
contentSubtype: state.contentSubtype,
}
if frame.StreamEnded() {
// s is just created by the caller. No lock needed.
s.state = streamReadDone
}
var (
ctx context.Context
cancel func()
)
if state.timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
ctx, cancel = context.WithTimeout(t.ctx, state.timeout)
} else {
s.ctx, s.cancel = context.WithCancel(t.ctx)
ctx, cancel = context.WithCancel(t.ctx)
}
pr := &peer.Peer{
Addr: t.remoteAddr,
@ -325,34 +323,55 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if t.authInfo != nil {
pr.AuthInfo = t.authInfo
}
s.ctx = peer.NewContext(s.ctx, pr)
ctx = peer.NewContext(ctx, pr)
// Attach the received metadata to the context.
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
ctx = metadata.NewIncomingContext(ctx, state.mdata)
}
if state.statsTags != nil {
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags)
ctx = stats.SetIncomingTags(ctx, state.statsTags)
}
if state.statsTrace != nil {
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace)
ctx = stats.SetIncomingTrace(ctx, state.statsTrace)
}
if t.inTapHandle != nil {
var err error
info := &tap.Info{
FullMethodName: state.method,
}
s.ctx, err = t.inTapHandle(s.ctx, info)
ctx, err = t.inTapHandle(ctx, info)
if err != nil {
warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
t.controlBuf.put(&cleanupStream{
streamID: s.id,
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeRefusedStream,
onWrite: func() {},
})
cancel()
return
}
}
ctx = traceCtx(ctx, state.method)
s := newStream(ctx)
// Initialize s with server-side specific fields.
s.cancel = cancel
s.id = streamID
s.st = t
s.fc = &stInFlow{
limit: uint32(t.initialWindowSize),
scheduleWU: func(w uint32) {
t.controlBuf.put(&outgoingWindowUpdate{streamID: streamID, increment: w})
},
wuThreshold: uint32(t.initialWindowSize / 4),
}
s.recvCompress = state.encoding
s.method = state.method
s.contentSubtype = state.contentSubtype
if frame.StreamEnded() {
s.state = streamReadDone
}
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
t.mu.Lock()
if t.state != reachable {
t.mu.Unlock()
@ -386,10 +405,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.lastStreamCreated = time.Now()
t.czmu.Unlock()
}
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
if t.stats != nil {
s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
@ -401,18 +416,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
}
t.stats.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctxDone,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
handle(s)
return
}
@ -490,41 +493,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
return s, true
}
// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
if w := s.fc.maybeAdjust(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
}
}
// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) {
if w := s.fc.onRead(n); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
increment: w,
})
}
}
// updateFlowControl updates the incoming flow control windows
// for the transport and the stream based on the current bdp
// estimation.
func (t *http2Server) updateFlowControl(n uint32) {
t.mu.Lock()
// Update all the current streams' window.
for _, s := range t.activeStreams {
s.fc.newLimit(n)
}
// Update all the future streams' window.
t.initialWindowSize = int32(n)
t.mu.Unlock()
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: t.fc.newLimit(n),
})
t.fc.newLimit(n) // Update transport's window.
// Notify the other side of the updated value.
t.controlBuf.put(&outgoingSettings{
ss: []http2.Setting{
{
@ -538,9 +520,15 @@ func (t *http2Server) updateFlowControl(n uint32) {
func (t *http2Server) handleData(f *http2.DataFrame) {
size := f.Header().Length
var sendBDPPing bool
if t.bdpEst != nil {
sendBDPPing = t.bdpEst.add(size)
if size == 0 {
if f.StreamEnded() {
if s, ok := t.getStream(f); ok {
// Received the end of stream from the client.
s.compareAndSwapState(streamActive, streamReadDone)
s.notifyErr(io.EOF)
}
}
return
}
// Decouple connection's flow control from application's read.
// An update on connection's flow control should not depend on
@ -550,51 +538,30 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// Decoupling the connection flow control will prevent other
// active(fast) streams from starving in presence of slow or
// inactive streams.
if w := t.fc.onData(size); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
if sendBDPPing {
t.fc.onData(size)
if t.bdpEst != nil && t.bdpEst.add(size) {
// Avoid excessive ping detection (e.g. in an L7 proxy)
// by sending a window update prior to the BDP ping.
if w := t.fc.reset(); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{
streamID: 0,
increment: w,
})
}
t.fc.reset()
t.controlBuf.put(bdpPing)
}
// Select the right stream to dispatch.
s, ok := t.getStream(f)
if !ok {
return
}
if size > 0 {
if err := s.fc.onData(size); err != nil {
if s, ok := t.getStream(f); ok {
d := f.Data()
padding := 0
if f.Header().Flags.Has(http2.FlagDataPadded) {
padding = int(size) - len(d)
}
if err := s.consume(d, padding); err != nil {
errorf("transport: flow control error on server: %v", err)
t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
return
}
if f.Header().Flags.Has(http2.FlagDataPadded) {
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
}
if f.StreamEnded() {
// Received the end of stream from the client.
s.compareAndSwapState(streamActive, streamReadDone)
s.notifyErr(io.EOF)
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
if len(f.Data()) > 0 {
data := make([]byte, len(f.Data()))
copy(data, f.Data())
s.write(recvMsg{data: data})
}
}
if f.Header().Flags.Has(http2.FlagDataEndStream) {
// Received the end of stream from the client.
s.compareAndSwapState(streamActive, streamReadDone)
s.write(recvMsg{err: io.EOF})
}
}
@ -792,7 +759,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
// Write converts the data into an HTTP2 data frame and sends it out. A non-nil
// error is returned if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
if !s.headerOk { // Headers haven't been written yet.
if err := t.WriteHeader(s, nil); err != nil {
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
@ -811,6 +778,8 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
return ContextErr(s.ctx.Err())
}
}
// Get a gRPC-specific header for this message.
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
// Add some data to header frame so that we can equally distribute bytes across frames.
emptyLen := http2MaxFrameLen - len(hdr)
if emptyLen > len(data) {
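
The hunk is cut off here, but the comment states the intent: fold the leading bytes of data into the slice that carries the 5-byte header so the header is never written as a near-empty frame of its own. A hedged, self-contained sketch of that coalescing (illustrative only; the real continuation is elided above):

package main

import "fmt"

const http2MaxFrameLen = 16384

// coalesce moves up to http2MaxFrameLen-len(hdr) bytes of data into the
// header slice, filling the frame that carries the message header.
func coalesce(hdr, data []byte) ([]byte, []byte) {
	emptyLen := http2MaxFrameLen - len(hdr)
	if emptyLen > len(data) {
		emptyLen = len(data)
	}
	hdr = append(hdr, data[:emptyLen]...)
	return hdr, data[emptyLen:]
}

func main() {
	hdr, data := coalesce(make([]byte, 5), make([]byte, 20000))
	fmt.Println(len(hdr), len(data)) // 16384 3621
}
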

transport/stream.go (new file, 407 lines)

@ -0,0 +1,407 @@
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"fmt"
"io"
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/msgdecoder"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const maxInt = int(^uint(0) >> 1)
type streamState uint32
const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // the entire stream is finished.
)
// recvBuffer is a buffer to which the transport's reader goroutine adds
// msgdecoder.RecvMsg values that are later read by the RPC's goroutine.
//
// It is protected by a lock in the Stream that owns it.
type recvBuffer struct {
ctx context.Context
ctxDone <-chan struct{}
c chan *msgdecoder.RecvMsg
mu sync.Mutex
waiting bool
list *msgdecoder.RecvMsgList
}
func newRecvBuffer(ctx context.Context, ctxDone <-chan struct{}) *recvBuffer {
return &recvBuffer{
ctx: ctx,
ctxDone: ctxDone,
c: make(chan *msgdecoder.RecvMsg, 1),
list: &msgdecoder.RecvMsgList{},
}
}
// put adds r to the underlying list if there's no consumer
// waiting; otherwise, it writes to the chan directly.
func (b *recvBuffer) put(r *msgdecoder.RecvMsg) {
b.mu.Lock()
if b.waiting {
b.waiting = false
b.mu.Unlock()
b.c <- r
return
}
b.list.Enqueue(r)
b.mu.Unlock()
}
// getNoBlock returns a msgdecoder.RecvMsg and true, if one
// is available.
// If it returns false, the caller must then call
// getWithBlock() before calling getNoBlock() again.
func (b *recvBuffer) getNoBlock() (*msgdecoder.RecvMsg, bool) {
b.mu.Lock()
r := b.list.Dequeue()
if r != nil {
b.mu.Unlock()
return r, true
}
b.waiting = true
b.mu.Unlock()
return nil, false
}
// getWithBlock() blocks until a complete message has been
// received, an error has occurred, or the underlying
// context has expired.
// It must only be called after having called getNoBlock()
// once.
func (b *recvBuffer) getWithBlock() (*msgdecoder.RecvMsg, error) {
select {
case <-b.ctxDone:
return nil, ContextErr(b.ctx.Err())
case r := <-b.c:
return r, nil
}
}
func (b *recvBuffer) get() (*msgdecoder.RecvMsg, error) {
m, ok := b.getNoBlock()
if ok {
return m, nil
}
m, err := b.getWithBlock()
if err != nil {
return nil, err
}
return m, nil
}
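
put and getNoBlock/getWithBlock implement a handoff pattern: a mutex-guarded list serves the common case, and the one-slot channel is written only when the consumer has announced it is waiting. A stripped-down, self-contained sketch of the same idea:

package main

import (
	"fmt"
	"sync"
)

// buffer hands items from producers to a single consumer. The consumer
// drains the list first; when empty it flags waiting and blocks on c, and
// the next put writes to c directly instead of appending to the list.
type buffer struct {
	mu      sync.Mutex
	waiting bool
	list    []string
	c       chan string
}

func (b *buffer) put(s string) {
	b.mu.Lock()
	if b.waiting {
		b.waiting = false
		b.mu.Unlock()
		b.c <- s
		return
	}
	b.list = append(b.list, s)
	b.mu.Unlock()
}

func (b *buffer) get() string {
	b.mu.Lock()
	if len(b.list) > 0 {
		s := b.list[0]
		b.list = b.list[1:]
		b.mu.Unlock()
		return s
	}
	b.waiting = true
	b.mu.Unlock()
	return <-b.c
}

func main() {
	b := &buffer{c: make(chan string, 1)}
	go b.put("hello")
	fmt.Println(b.get())
}
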
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
wq *writeQuota
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
headerOk bool // becomes true when the first header is about to be sent
state streamState
status *status.Status // the status error received from the server
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
rbuf *recvBuffer
fc *stInFlow
msgDecoder *msgdecoder.MessageDecoder
readErr error
}
func newStream(ctx context.Context) *Stream {
// Cache the done chan since a Done() call is expensive.
ctxDone := ctx.Done()
s := &Stream{
ctx: ctx,
ctxDone: ctxDone,
rbuf: newRecvBuffer(ctx, ctxDone),
}
dispatch := func(r *msgdecoder.RecvMsg) {
s.rbuf.put(r)
}
s.msgDecoder = msgdecoder.NewMessageDecoder(dispatch)
return s
}
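// exampleDecode is an editor's sketch (not part of this change). For
// each complete gRPC message on the wire, the MessageDecoder dispatches
// two RecvMsg values into rbuf: first one describing the parsed 5-byte
// message header (Length, IsCompressed, Overhead), then one carrying the
// data payload. This is why Read below calls rbuf.get() twice per
// message. The wire bytes here assume the standard gRPC framing: a
// 1-byte compression flag followed by a 4-byte big-endian length.
func exampleDecode(s *Stream) {
wire := []byte{0, 0, 0, 0, 2, 'h', 'i'} // uncompressed, length 2, payload "hi"
s.consume(wire, 0) // charges flow control (if set), then decodes
}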
// notifyErr notifies RPC of an error seen by the transport.
//
// Note to developers: This call can unblock Read calls on RPC
// and lead to reading of unprotected fields on stream on the
// client-side. It should only be called from inside
// transport.closeStream() if the stream was initialized or from
// inside the cleanup callback if the stream was not initialized.
func (s *Stream) notifyErr(err error) {
s.rbuf.put(&msgdecoder.RecvMsg{Err: err})
}
// consume is called by transport's reader goroutine for parsing
// and decoding data received for this stream.
func (s *Stream) consume(b []byte, padding int) error {
// Flow control check.
if s.fc != nil { // HandlerServer doesn't use our flow control.
if err := s.fc.onData(uint32(len(b) + padding)); err != nil {
return err
}
}
s.msgDecoder.Decode(b, padding)
return nil
}
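// exampleHandleData is an editor's sketch (not part of this change) of
// the transport side: the reader goroutine hands every HTTP/2 DATA
// frame to consume, which charges flow control for data plus padding up
// front; Read later refunds it per message via fc.onRead.
func exampleHandleData(s *Stream, frameData []byte, padding int) {
if err := s.consume(frameData, padding); err != nil {
// A flow-control violation means the stream cannot continue.
s.notifyErr(err)
}
}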
// Read reads one whole message from the transport.
// It is called by RPC's goroutine.
// It is not safe to be called concurrently by multiple goroutines.
//
// Returns:
// 1. The received message's compression status (true if it was compressed).
// 2. The message as a byte slice.
// 3. An error, if any.
func (s *Stream) Read(maxRecvMsgSize int) (bool, []byte, error) {
if s.readErr != nil {
return false, nil, s.readErr
}
var (
m *msgdecoder.RecvMsg
err error
)
// First read the underlying message header
if m, err = s.rbuf.get(); err != nil {
s.readErr = err
return false, nil, err
}
if m.Err != nil {
s.readErr = m.Err
return false, nil, s.readErr
}
// Make sure the message being received isn't too large.
if int64(m.Length) > int64(maxInt) {
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", m.Length, maxInt)
return false, nil, s.readErr
}
if m.Length > maxRecvMsgSize {
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", m.Length, maxRecvMsgSize)
return false, nil, s.readErr
}
// Send a window update for the message this RPC is reading.
if s.fc != nil { // HandlerServer doesn't use our flow control.
s.fc.onRead(uint32(m.Length + m.Overhead))
}
isCompressed := m.IsCompressed
// Read the message.
if m, err = s.rbuf.get(); err != nil {
s.readErr = err
return false, nil, err
}
if m.Err != nil {
if m.Err == io.EOF {
m.Err = io.ErrUnexpectedEOF
}
s.readErr = m.Err
return false, nil, s.readErr
}
return isCompressed, m.Data, nil
}
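// exampleReadLoop is an editor's sketch (not part of this change)
// showing the RPC-side loop: each call to Read yields one whole
// message, and io.EOF signals that the stream ended cleanly. The 4MB
// limit is an assumed value, not one defined by this package.
func exampleReadLoop(s *Stream) error {
for {
isCompressed, msg, err := s.Read(4 * 1024 * 1024)
if err == io.EOF {
return nil
}
if err != nil {
return err
}
_, _ = isCompressed, msg // hand off to the codec / decompressor
}
}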
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
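// exampleMarkWriteDone is an editor's sketch (not part of this change)
// of how the atomic helpers above compose: sending END_STREAM upgrades
// streamActive to streamWriteDone, and if the read side has already
// finished, the stream as a whole becomes streamDone.
func exampleMarkWriteDone(s *Stream) {
if !s.compareAndSwapState(streamReadDone, streamDone) {
s.compareAndSwapState(streamActive, streamWriteDone)
}
}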
func (s *Stream) waitOnHeader() error {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It returns an empty string if no compression is applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
return nil, err
}
// Trailer returns the cached trailer metadata. Note that if it is not
// called after the entire stream is done, it could return an empty MD.
// Client side only.
// It can be read safely only after the stream has ended, that is, after
// either Read or Write has returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ServerTransport returns the underlying ServerTransport for the stream.
// The client side stream always returns nil.
func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, read or write has returned io.EOF.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
return ErrIllegalHeaderWrite
}
s.header = metadata.Join(s.header, md)
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
t := s.ServerTransport()
return t.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.trailer = metadata.Join(s.trailer, md)
return nil
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}

View File

@ -24,10 +24,7 @@ package transport // externally used as import "google.golang.org/grpc/transport
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
@ -39,359 +36,6 @@ import (
"google.golang.org/grpc/tap"
)
// recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed.
type recvMsg struct {
data []byte
// nil: received some data
// io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil.
err error
}
// recvBuffer is an unbounded channel of recvMsg structs.
// Note recvBuffer differs from controlBuffer only in that recvBuffer
// holds a channel of only recvMsg structs instead of objects implementing "item" interface.
// recvBuffer is written to much more often than
// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put"
type recvBuffer struct {
c chan recvMsg
mu sync.Mutex
backlog []recvMsg
err error
}
func newRecvBuffer() *recvBuffer {
b := &recvBuffer{
c: make(chan recvMsg, 1),
}
return b
}
func (b *recvBuffer) put(r recvMsg) {
b.mu.Lock()
if b.err != nil {
b.mu.Unlock()
// An error had occurred earlier, don't accept more
// data or errors.
return
}
b.err = r.err
if len(b.backlog) == 0 {
select {
case b.c <- r:
b.mu.Unlock()
return
default:
}
}
b.backlog = append(b.backlog, r)
b.mu.Unlock()
}
func (b *recvBuffer) load() {
b.mu.Lock()
if len(b.backlog) > 0 {
select {
case b.c <- b.backlog[0]:
b.backlog[0] = recvMsg{}
b.backlog = b.backlog[1:]
default:
}
}
b.mu.Unlock()
}
// get returns the channel that receives a recvMsg in the buffer.
//
// Upon receipt of a recvMsg, the caller should call load to send another
// recvMsg onto the channel if there is any.
func (b *recvBuffer) get() <-chan recvMsg {
return b.c
}
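// Editor's sketch (not part of this change): the get/load pair above
// was consumed like this -- receive one recvMsg from the channel, then
// call load so that a backlogged message, if any, is staged for the
// next receive:
//
//	select {
//	case m := <-b.get():
//		b.load()
//		// use m.data / m.err
//	case <-ctx.Done():
//	}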
//
// recvBufferReader implements io.Reader interface to read the data from
// recvBuffer.
type recvBufferReader struct {
ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer
last []byte // Stores the remaining data in the previous calls.
err error
}
// Read reads the next len(p) bytes from last. If last is drained, it tries to
// read additional data from recv. It blocks if there is no additional data available
// in recv. If Read returns any non-nil error, it will continue to return that error.
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}
n, r.err = r.read(p)
return n, r.err
}
func (r *recvBufferReader) read(p []byte) (n int, err error) {
if r.last != nil && len(r.last) > 0 {
// Read remaining data left in last call.
copied := copy(p, r.last)
r.last = r.last[copied:]
return copied, nil
}
select {
case <-r.ctxDone:
return 0, ContextErr(r.ctx.Err())
case m := <-r.recv.get():
r.recv.load()
if m.err != nil {
return 0, m.err
}
copied := copy(p, m.data)
r.last = m.data[copied:]
return copied, nil
}
}
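// Editor's sketch (not part of this change): since recvBufferReader is
// a plain io.Reader, the old read path reassembled a gRPC message by
// reading the 5-byte header first and then io.ReadFull-ing the payload:
//
//	hdr := make([]byte, 5)
//	io.ReadFull(r, hdr)
//	msg := make([]byte, binary.BigEndian.Uint32(hdr[1:]))
//	io.ReadFull(r, msg)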
type streamState uint32
const (
streamActive streamState = iota
streamWriteDone // EndStream sent
streamReadDone // EndStream received
streamDone // the entire stream is finished.
)
// Stream represents an RPC in the transport layer.
type Stream struct {
id uint32
st ServerTransport // nil for client side Stream
ctx context.Context // the associated context of the stream
cancel context.CancelFunc // always nil for client side Stream
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
method string // the associated RPC method of the stream
recvCompress string
sendCompress string
buf *recvBuffer
trReader io.Reader
fc *inFlow
recvQuota uint32
wq *writeQuota
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
header metadata.MD // the received header metadata.
trailer metadata.MD // the key-value map of trailer metadata.
headerOk bool // becomes true from the first header is about to send
state streamState
status *status.Status // the status error received from the server
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
contentSubtype string
}
func (s *Stream) swapState(st streamState) streamState {
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
}
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
}
func (s *Stream) getState() streamState {
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
}
func (s *Stream) waitOnHeader() error {
if s.headerChan == nil {
// On the server headerChan is always nil since a stream originates
// only after having received headers.
return nil
}
select {
case <-s.ctx.Done():
return ContextErr(s.ctx.Err())
case <-s.headerChan:
return nil
}
}
// RecvCompress returns the compression algorithm applied to the inbound
// message. It returns an empty string if no compression is applied.
func (s *Stream) RecvCompress() string {
if err := s.waitOnHeader(); err != nil {
return ""
}
return s.recvCompress
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
}
// Done returns a channel which is closed when it receives the final status
// from the server.
func (s *Stream) Done() <-chan struct{} {
return s.done
}
// Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is canceled/expired.
func (s *Stream) Header() (metadata.MD, error) {
err := s.waitOnHeader()
// Even if the stream is closed, header is returned if available.
select {
case <-s.headerChan:
if s.header == nil {
return nil, nil
}
return s.header.Copy(), nil
default:
}
return nil, err
}
// Trailer returns the cached trailer metadata. Note that if it is not
// called after the entire stream is done, it could return an empty MD.
// Client side only.
// It can be read safely only after the stream has ended, that is, after
// either Read or Write has returned io.EOF.
func (s *Stream) Trailer() metadata.MD {
c := s.trailer.Copy()
return c
}
// ServerTransport returns the underlying ServerTransport for the stream.
// The client side stream always returns nil.
func (s *Stream) ServerTransport() ServerTransport {
return s.st
}
// ContentSubtype returns the content-subtype for a request. For example, a
// content-subtype of "proto" will result in a content-type of
// "application/grpc+proto". This will always be lowercase. See
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
// more details.
func (s *Stream) ContentSubtype() string {
return s.contentSubtype
}
// Context returns the context of the stream.
func (s *Stream) Context() context.Context {
return s.ctx
}
// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
}
// Status returns the status received from the server.
// Status can be read safely only after the stream has ended,
// that is, read or write has returned io.EOF.
func (s *Stream) Status() *status.Status {
return s.status
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
func (s *Stream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
return ErrIllegalHeaderWrite
}
s.header = metadata.Join(s.header, md)
return nil
}
// SendHeader sends the given header metadata. The given metadata is
// combined with any metadata set by previous calls to SetHeader and
// then written to the transport stream.
func (s *Stream) SendHeader(md metadata.MD) error {
t := s.ServerTransport()
return t.WriteHeader(s, md)
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
// This should not be called parallel to other data writes.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.trailer = metadata.Join(s.trailer, md)
return nil
}
func (s *Stream) write(m recvMsg) {
s.buf.put(m)
}
// Read reads all p bytes from the wire for this stream.
func (s *Stream) Read(p []byte) (n int, err error) {
// Don't request a read if there was an error earlier
if er := s.trReader.(*transportReader).er; er != nil {
return 0, er
}
s.requestRead(len(p))
return io.ReadFull(s.trReader, p)
}
// transportReader reads all the data available for this Stream from the transport and
// passes them into the decoder, which converts them into a gRPC message stream.
// The error is io.EOF when the stream is done or another non-nil error if
// the stream broke.
type transportReader struct {
reader io.Reader
// The handler to control the window update procedure for both this
// particular stream and the associated transport.
windowHandler func(int)
er error
}
func (t *transportReader) Read(p []byte) (n int, err error) {
n, err = t.reader.Read(p)
if err != nil {
t.er = err
return
}
t.windowHandler(n)
return
}
// BytesReceived indicates whether any bytes have been received on this stream.
func (s *Stream) BytesReceived() bool {
return atomic.LoadUint32(&s.bytesReceived) == 1
}
// Unprocessed indicates whether the server did not process this stream --
// i.e. it sent a refused stream or GOAWAY including this stream ID.
func (s *Stream) Unprocessed() bool {
return atomic.LoadUint32(&s.unprocessed) == 1
}
// GoString is implemented by Stream so context.String() won't
// race when printing %#v.
func (s *Stream) GoString() string {
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
}
// state of transport
type transportState int
@ -476,7 +120,13 @@ type Options struct {
// Delay is a hint to the transport implementation for whether
// the data could be buffered for a batching write. The
// transport implementation may ignore the hint.
// TODO(mmukhi, dfawley): Should this be deleted?
Delay bool
// IsCompressed indicates whether the message being written
// was compressed or not. The transport relays this information
// to the API that generates the gRPC-specific message header.
IsCompressed bool
}
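// Editor's sketch (not part of this change): with the signature change
// below, callers pass only the encoded payload and flag its compression
// through Options; the transport now builds the 5-byte gRPC message
// header itself. Assuming t is a ClientTransport or ServerTransport and
// s is its Stream:
//
//	err := t.Write(s, data, &Options{Last: true, IsCompressed: compressed})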
// CallHdr carries the information of a particular RPC.
@ -525,7 +175,7 @@ type ClientTransport interface {
// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
Write(s *Stream, data []byte, opts *Options) error
// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
@ -573,7 +223,7 @@ type ServerTransport interface {
// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
Write(s *Stream, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.

View File

@ -100,8 +100,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
req = expectedRequestLarge
resp = expectedResponseLarge
}
p := make([]byte, len(req))
_, err := s.Read(p)
_, p, err := s.Read(math.MaxInt32)
if err != nil {
return
}
@ -109,31 +108,26 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
t.Fatalf("handleStream got %v, want %v", p, req)
}
// send a response back to the client.
h.t.Write(s, nil, resp, &Options{})
h.t.Write(s, resp, &Options{})
// send the trailer to end the stream.
h.t.WriteStatus(s, status.New(codes.OK, ""))
}
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
header := make([]byte, 5)
for {
if _, err := s.Read(header); err != nil {
_, msg, err := s.Read(math.MaxInt32)
if err != nil {
if err == io.EOF {
h.t.WriteStatus(s, status.New(codes.OK, ""))
return
}
t.Fatalf("Error on server while reading data header: %v", err)
t.Errorf("Error on server while reading data header: %v", err)
return
}
sz := binary.BigEndian.Uint32(header[1:])
msg := make([]byte, int(sz))
if _, err := s.Read(msg); err != nil {
t.Fatalf("Error on server while reading message: %v", err)
if err := h.t.Write(s, msg, &Options{}); err != nil {
t.Errorf("Error on server while writing: %v", err)
return
}
buf := make([]byte, sz+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
copy(buf[5:], msg)
h.t.Write(s, nil, buf, &Options{})
}
}
@ -189,12 +183,10 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
req = expectedRequestLarge
resp = expectedResponseLarge
}
p := make([]byte, len(req))
// Wait before reading. Give time to client to start sending
// before server starts reading.
time.Sleep(2 * time.Second)
_, err := s.Read(p)
_, p, err := s.Read(math.MaxInt32)
if err != nil {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return
@ -205,7 +197,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
return
}
// send a response back to the client.
if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
if err := h.t.Write(s, resp, &Options{}); err != nil {
t.Errorf("server Write got %v, want <nil>", err)
return
}
@ -223,8 +215,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
req = expectedRequestLarge
resp = expectedResponseLarge
}
p := make([]byte, len(req))
_, err := s.Read(p)
_, p, err := s.Read(math.MaxInt32)
if err != nil {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return
@ -237,7 +228,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
// Wait before sending. Give time to client to start reading
// before server starts sending.
time.Sleep(2 * time.Second)
if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
if err := h.t.Write(s, resp, &Options{}); err != nil {
t.Errorf("server Write got %v, want <nil>", err)
return
}
@ -442,7 +433,7 @@ func TestInflightStreamClosing(t *testing.T) {
serr := StreamError{Desc: "client connection is closing"}
go func() {
defer close(donec)
if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr {
if _, _, err := stream.Read(math.MaxInt32); err != serr {
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
}
}()
@ -858,15 +849,14 @@ func TestClientSendAndReceive(t *testing.T) {
Last: true,
Delay: false,
}
if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF {
if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
t.Fatalf("failed to send data: %v", err)
}
p := make([]byte, len(expectedResponse))
_, recvErr := s1.Read(p)
_, p, recvErr := s1.Read(math.MaxInt32)
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
}
_, recvErr = s1.Read(p)
_, _, recvErr = s1.Read(math.MaxInt32)
if recvErr != io.EOF {
t.Fatalf("Error: %v; want <EOF>", recvErr)
}
@ -895,16 +885,15 @@ func performOneRPC(ct ClientTransport) {
Last: true,
Delay: false,
}
if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF {
if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
time.Sleep(5 * time.Millisecond)
// The following s.Recv()'s could error out because the
// underlying transport is gone.
//
// Read response
p := make([]byte, len(expectedResponse))
s.Read(p)
s.Read(math.MaxInt32)
// Read io.EOF
s.Read(p)
s.Read(math.MaxInt32)
}
}
@ -939,14 +928,13 @@ func TestLargeMessage(t *testing.T) {
if err != nil {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
}
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
}
p := make([]byte, len(expectedResponseLarge))
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(math.MaxInt32) = %v, %v, want %v, <nil>", p, err, expectedResponse)
}
if _, err = s.Read(p); err != io.EOF {
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
@ -974,19 +962,18 @@ func TestLargeMessageWithDelayRead(t *testing.T) {
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
return
}
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
return
}
p := make([]byte, len(expectedResponseLarge))
// Give time to server to begin sending before client starts reading.
time.Sleep(2 * time.Second)
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
return
}
if _, err = s.Read(p); err != io.EOF {
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
@ -1017,16 +1004,15 @@ func TestLargeMessageDelayWrite(t *testing.T) {
// Give time to server to start reading before client starts sending.
time.Sleep(2 * time.Second)
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
return
}
p := make([]byte, len(expectedResponseLarge))
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
return
}
if _, err = s.Read(p); err != io.EOF {
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
}
}()
@ -1047,19 +1033,10 @@ func TestGracefulClose(t *testing.T) {
t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
}
msg := make([]byte, 1024)
outgoingHeader := make([]byte, 5)
outgoingHeader[0] = byte(0)
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
incomingHeader := make([]byte, 5)
if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil {
if err := ct.Write(s, msg, &Options{}); err != nil {
t.Fatalf("Error while writing: %v", err)
}
if _, err := s.Read(incomingHeader); err != nil {
t.Fatalf("Error while reading: %v", err)
}
sz := binary.BigEndian.Uint32(incomingHeader[1:])
recvMsg := make([]byte, int(sz))
if _, err := s.Read(recvMsg); err != nil {
if _, _, err := s.Read(math.MaxInt32); err != nil {
t.Fatalf("Error while reading: %v", err)
}
if err = ct.GracefulClose(); err != nil {
@ -1075,14 +1052,14 @@ func TestGracefulClose(t *testing.T) {
if err == errStreamDrain {
return
}
ct.Write(str, nil, nil, &Options{Last: true})
if _, err := str.Read(make([]byte, 8)); err != errStreamDrain {
ct.Write(str, nil, &Options{Last: true})
if _, _, err := str.Read(math.MaxInt32); err != errStreamDrain {
t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain)
}
}()
}
ct.Write(s, nil, nil, &Options{Last: true})
if _, err := s.Read(incomingHeader); err != io.EOF {
ct.Write(s, nil, &Options{Last: true})
if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err)
}
// The stream which was created before graceful close can still proceed.
@ -1110,13 +1087,13 @@ func TestLargeMessageSuspension(t *testing.T) {
}()
// Write should not be done successfully due to flow control.
msg := make([]byte, initialWindowSize*8)
ct.Write(s, nil, msg, &Options{})
err = ct.Write(s, nil, msg, &Options{Last: true})
ct.Write(s, msg, &Options{})
err = ct.Write(s, msg, &Options{Last: true})
if err != errStreamDone {
t.Fatalf("Write got %v, want io.EOF", err)
}
expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
if _, err := s.Read(make([]byte, 8)); err != expectedErr {
if _, _, err := s.Read(math.MaxInt32); err != expectedErr {
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
}
ct.Close()
@ -1305,7 +1282,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
}
// Exhaust client's connection window.
if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
if err := st.Write(sstream1, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err)
}
notifyChan = make(chan struct{})
@ -1330,17 +1307,17 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
}
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
if err := st.Write(sstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err)
}
// Client should be able to read data on second stream.
if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
if _, _, err := cstream2.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
}
// Client should be able to read data on first stream.
if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
if _, _, err := cstream1.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
}
}
@ -1373,7 +1350,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
t.Fatalf("Failed to create 1st stream. Err: %v", err)
}
// Exhaust server's connection window.
if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err)
}
//Client should be able to create another stream and send data on it.
@ -1381,7 +1358,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
}
if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Client failed to write data. Err: %v", err)
}
// Get the streams on server.
@ -1403,11 +1380,11 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
}
st.mu.Unlock()
// Reading from the stream on server should succeed.
if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
if _, _, err := sstream1.Read(math.MaxInt32); err != nil {
t.Fatalf("_.Read(_) = %v, want <nil>", err)
}
if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
if _, _, err := sstream1.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("_.Read(_) = %v, want io.EOF", err)
}
@ -1616,11 +1593,10 @@ func TestEncodingRequiredStatus(t *testing.T) {
Last: true,
Delay: false,
}
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != errStreamDone {
t.Fatalf("Failed to write the request: %v", err)
}
p := make([]byte, http2MaxFrameLen)
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
@ -1640,8 +1616,7 @@ func TestInvalidHeaderField(t *testing.T) {
if err != nil {
return
}
p := make([]byte, http2MaxFrameLen)
_, err = s.trReader.(*transportReader).Read(p)
_, _, err = s.Read(math.MaxInt32)
if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
}
@ -1764,26 +1739,17 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
t.Fatalf("Failed to create stream. Err: %v", err)
}
msg := make([]byte, msgSize)
buf := make([]byte, msgSize+5)
buf[0] = byte(0)
binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
copy(buf[5:], msg)
opts := Options{}
header := make([]byte, 5)
for i := 1; i <= 10; i++ {
if err := ct.Write(cstream, nil, buf, &opts); err != nil {
if err := ct.Write(cstream, msg, &opts); err != nil {
t.Fatalf("Error on client while writing message: %v", err)
}
if _, err := cstream.Read(header); err != nil {
t.Fatalf("Error on client while reading data frame header: %v", err)
}
sz := binary.BigEndian.Uint32(header[1:])
recvMsg := make([]byte, int(sz))
if _, err := cstream.Read(recvMsg); err != nil {
_, recvMsg, err := cstream.Read(math.MaxInt32)
if err != nil {
t.Fatalf("Error on client while reading data: %v", err)
}
if len(recvMsg) != len(msg) {
t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
if !bytes.Equal(recvMsg, msg) {
t.Fatalf("Message received by client(len: %d) not equal to what was expected(len: %d)", len(recvMsg), len(msg))
}
}
var sstream *Stream
@ -1794,8 +1760,8 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
st.mu.Unlock()
loopyServerStream := st.loopy.estdStreams[sstream.id]
loopyClientStream := ct.loopy.estdStreams[cstream.id]
ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream.
if _, err := cstream.Read(header); err != io.EOF {
ct.Write(cstream, nil, &Options{Last: true}) // Close the stream.
if _, _, err := cstream.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
}
// Sleep for a little to make sure both sides flush out their buffers.
@ -1816,11 +1782,11 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota)
}
// Check stream flow control.
if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding)
if int(cstream.fc.limit)-int(cstream.fc.rcvd) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
t.Fatalf("Account mismatch: client stream inflow limit(%d) - rcvd(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.rcvd, st.loopy.oiws, loopyServerStream.bytesOutStanding)
}
if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding {
t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding)
if int(sstream.fc.limit)-int(sstream.fc.rcvd) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding {
t.Fatalf("Account mismatch: server stream inflow limit(%d) - rcvd(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.rcvd, ct.loopy.oiws, loopyClientStream.bytesOutStanding)
}
}
@ -2000,8 +1966,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
defer cleanUp()
want := httpStatusConvTab[httpStatus]
buf := make([]byte, 8)
_, err := stream.Read(buf)
_, _, err := stream.Read(math.MaxInt32)
if err == nil {
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
}
@ -2017,8 +1982,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
defer cleanUp()
buf := make([]byte, 8)
_, err := stream.Read(buf)
_, _, err := stream.Read(math.MaxInt32)
if err != io.EOF {
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
}
@ -2035,45 +1999,25 @@ func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
// If any error occurs on a call to Stream.Read, future calls
// should continue to return that same error.
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
testRecvBuffer := newRecvBuffer()
s := &Stream{
ctx: context.Background(),
buf: testRecvBuffer,
requestRead: func(int) {},
}
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
ctxDone: s.ctx.Done(),
recv: s.buf,
},
windowHandler: func(int) {},
}
testData := make([]byte, 1)
testData[0] = 5
s := newStream(context.Background())
testErr := errors.New("test error")
s.write(recvMsg{data: testData, err: testErr})
s.notifyErr(testErr)
inBuf := make([]byte, 1)
actualCount, actualErr := s.Read(inBuf)
if actualCount != 0 {
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
pf, inBuf, actualErr := s.Read(math.MaxInt32)
if pf || inBuf != nil || actualErr.Error() != testErr.Error() {
t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
}
s.write(recvMsg{data: testData, err: nil})
s.write(recvMsg{data: testData, err: errors.New("different error from first")})
testData := make([]byte, 6)
testData[0] = byte(1)
binary.BigEndian.PutUint32(testData[1:], uint32(1))
s.consume(testData, 0)
s.notifyErr(errors.New("different error from first"))
for i := 0; i < 2; i++ {
inBuf := make([]byte, 1)
actualCount, actualErr := s.Read(inBuf)
if actualCount != 0 {
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
}
if actualErr.Error() != testErr.Error() {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
pf, inBuf, actualErr := s.Read(math.MaxInt32)
if pf || inBuf != nil || actualErr.Error() != testErr.Error() {
t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
}
}
}
@ -2113,11 +2057,7 @@ func runPingPongTest(t *testing.T, msgSize int) {
t.Fatalf("Failed to create stream. Err: %v", err)
}
msg := make([]byte, msgSize)
outgoingHeader := make([]byte, 5)
outgoingHeader[0] = byte(0)
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
opts := &Options{}
incomingHeader := make([]byte, 5)
done := make(chan struct{})
go func() {
timer := time.NewTimer(time.Second * 5)
@ -2127,23 +2067,22 @@ func runPingPongTest(t *testing.T, msgSize int) {
for {
select {
case <-done:
ct.Write(stream, nil, nil, &Options{Last: true})
if _, err := stream.Read(incomingHeader); err != io.EOF {
ct.Write(stream, nil, &Options{Last: true})
if _, _, err := stream.Read(math.MaxInt32); err != io.EOF {
t.Fatalf("Client expected EOF from the server. Got: %v", err)
}
return
default:
if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil {
if err := ct.Write(stream, msg, opts); err != nil {
t.Fatalf("Error on client while writing message. Err: %v", err)
}
if _, err := stream.Read(incomingHeader); err != nil {
t.Fatalf("Error on client while reading data header. Err: %v", err)
}
sz := binary.BigEndian.Uint32(incomingHeader[1:])
recvMsg := make([]byte, int(sz))
if _, err := stream.Read(recvMsg); err != nil {
_, recvMsg, err := stream.Read(math.MaxInt32)
if err != nil {
t.Fatalf("Error on client while reading data. Err: %v", err)
}
if !bytes.Equal(recvMsg, msg) {
t.Fatalf("%v != %v", recvMsg, msg)
}
}
}
}