Less mem (#1987)
* Export changes to OSS. * First commit. * Cherry-pick. * Documentation. * Post review updates.
This commit is contained in:
11
call_test.go
11
call_test.go
@ -66,17 +66,16 @@ type testStreamHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
||||||
p := &parser{r: s}
|
|
||||||
for {
|
for {
|
||||||
pf, req, err := p.recvMsg(math.MaxInt32)
|
isCompressed, req, err := recvMsg(s, math.MaxInt32)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if pf != compressionNone {
|
if isCompressed {
|
||||||
t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone)
|
t.Errorf("Received compressed message want non-compressed message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var v string
|
var v string
|
||||||
@ -105,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// send a response back to end the stream.
|
// send a response back to end the stream.
|
||||||
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
|
data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to encode the response: %v", err)
|
t.Errorf("Failed to encode the response: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.t.Write(s, hdr, data, &transport.Options{})
|
h.t.Write(s, data, &transport.Options{})
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
203
internal/msgdecoder/msgdecoder.go
Normal file
203
internal/msgdecoder/msgdecoder.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package msgdecoder contains the logic to deconstruct a gRPC-message.
|
||||||
|
package msgdecoder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecvMsg is a message constructed from the incoming
|
||||||
|
// bytes on the transport for a stream.
|
||||||
|
// An instance of RecvMsg will contain only one of the
|
||||||
|
// following: message header related fields, data slice
|
||||||
|
// or error.
|
||||||
|
type RecvMsg struct {
|
||||||
|
// Following three are message header related
|
||||||
|
// fields.
|
||||||
|
// true if the message was compressed by the other
|
||||||
|
// side.
|
||||||
|
IsCompressed bool
|
||||||
|
// Length of the message.
|
||||||
|
Length int
|
||||||
|
// Overhead is the length of message header(5 bytes)
|
||||||
|
// plus padding.
|
||||||
|
Overhead int
|
||||||
|
|
||||||
|
// Data payload of the message.
|
||||||
|
Data []byte
|
||||||
|
|
||||||
|
// Err occurred while reading.
|
||||||
|
// nil: received some data
|
||||||
|
// io.EOF: stream is completed. data is nil.
|
||||||
|
// other non-nil error: transport failure. data is nil.
|
||||||
|
Err error
|
||||||
|
|
||||||
|
Next *RecvMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecvMsgList is a linked-list of RecvMsg.
|
||||||
|
type RecvMsgList struct {
|
||||||
|
head *RecvMsg
|
||||||
|
tail *RecvMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEmpty returns true when l is empty.
|
||||||
|
func (l *RecvMsgList) IsEmpty() bool {
|
||||||
|
if l.tail == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enqueue adds r to l at the back.
|
||||||
|
func (l *RecvMsgList) Enqueue(r *RecvMsg) {
|
||||||
|
if l.IsEmpty() {
|
||||||
|
l.head, l.tail = r, r
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t := l.tail
|
||||||
|
l.tail = r
|
||||||
|
t.Next = r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequeue removes a RcvMsg from the end of l.
|
||||||
|
func (l *RecvMsgList) Dequeue() *RecvMsg {
|
||||||
|
if l.head == nil {
|
||||||
|
// Note to developer: Instead of calling isEmpty() which
|
||||||
|
// checks the same condition on l.tail, we check it directly
|
||||||
|
// on l.head so that in non-nil cases, there aren't cache misses.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r := l.head
|
||||||
|
l.head = l.head.Next
|
||||||
|
if l.head == nil {
|
||||||
|
l.tail = nil
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageDecoder decodes bytes from HTTP2 data frames
|
||||||
|
// and constructs a gRPC message which is then put in a
|
||||||
|
// buffer that application(RPCs) read from.
|
||||||
|
// gRPC Messages:
|
||||||
|
// First 5 bytes is the message header:
|
||||||
|
// First byte: Payload format.
|
||||||
|
// Next 4 bytes: Length of the message.
|
||||||
|
// Rest of the bytes is the message payload.
|
||||||
|
//
|
||||||
|
// TODO(mmukhi): Write unit tests.
|
||||||
|
type MessageDecoder struct {
|
||||||
|
// current message being read by the transport.
|
||||||
|
current *RecvMsg
|
||||||
|
dataOfst int
|
||||||
|
padding int
|
||||||
|
// hdr stores the message header as it is beind received by the transport.
|
||||||
|
hdr []byte
|
||||||
|
hdrOfst int
|
||||||
|
// Callback used to send decoded messages.
|
||||||
|
dispatch func(*RecvMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessageDecoder creates an instance of MessageDecoder. It takes a callback
|
||||||
|
// which is called to dispatch finished headers and messages to the application.
|
||||||
|
func NewMessageDecoder(dispatch func(*RecvMsg)) *MessageDecoder {
|
||||||
|
return &MessageDecoder{
|
||||||
|
hdr: make([]byte, 5),
|
||||||
|
dispatch: dispatch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode consumes bytes from a HTTP2 data frame to create gRPC messages.
|
||||||
|
func (m *MessageDecoder) Decode(b []byte, padding int) {
|
||||||
|
m.padding += padding
|
||||||
|
for len(b) > 0 {
|
||||||
|
// Case 1: A complete message hdr was received earlier.
|
||||||
|
if m.current != nil {
|
||||||
|
n := copy(m.current.Data[m.dataOfst:], b)
|
||||||
|
m.dataOfst += n
|
||||||
|
b = b[n:]
|
||||||
|
if m.dataOfst == len(m.current.Data) { // Message is complete.
|
||||||
|
m.dispatch(m.current)
|
||||||
|
m.current = nil
|
||||||
|
m.dataOfst = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Case 2a: No message header has been received yet.
|
||||||
|
if m.hdrOfst == 0 {
|
||||||
|
// case 2a.1: b has the whole header
|
||||||
|
if len(b) >= 5 {
|
||||||
|
m.parseHeader(b[:5])
|
||||||
|
b = b[5:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// case 2a.2: b has partial header
|
||||||
|
n := copy(m.hdr, b)
|
||||||
|
m.hdrOfst = n
|
||||||
|
b = b[n:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Case 2b: Partial message header was received earlier.
|
||||||
|
n := copy(m.hdr[m.hdrOfst:], b)
|
||||||
|
m.hdrOfst += n
|
||||||
|
b = b[n:]
|
||||||
|
if m.hdrOfst == 5 { // hdr is complete.
|
||||||
|
m.hdrOfst = 0
|
||||||
|
m.parseHeader(m.hdr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageDecoder) parseHeader(b []byte) {
|
||||||
|
length := int(binary.BigEndian.Uint32(b[1:5]))
|
||||||
|
hdr := &RecvMsg{
|
||||||
|
IsCompressed: int(b[0]) == 1,
|
||||||
|
Length: length,
|
||||||
|
Overhead: m.padding + 5,
|
||||||
|
}
|
||||||
|
m.padding = 0
|
||||||
|
// Dispatch the information retreived from message header so
|
||||||
|
// that the RPC goroutine can send a proactive window update as we
|
||||||
|
// wait for the rest of it.
|
||||||
|
m.dispatch(hdr)
|
||||||
|
if length == 0 {
|
||||||
|
m.dispatch(&RecvMsg{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.current = &RecvMsg{
|
||||||
|
Data: getMem(length),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMem(l int) []byte {
|
||||||
|
// TODO(mmukhi): Reuse this memory.
|
||||||
|
return make([]byte, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessageHeader creates a gRPC-specific message header.
|
||||||
|
func CreateMessageHeader(l int, isCompressed bool) []byte {
|
||||||
|
// TODO(mmukhi): Investigate if this memory is worth
|
||||||
|
// reusing.
|
||||||
|
hdr := make([]byte, 5)
|
||||||
|
if isCompressed {
|
||||||
|
hdr[0] = byte(1)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint32(hdr[1:], uint32(l))
|
||||||
|
return hdr
|
||||||
|
}
|
81
internal/msgdecoder/msgdecoder_test.go
Normal file
81
internal/msgdecoder/msgdecoder_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package msgdecoder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMessageDecoder(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
numFrames int
|
||||||
|
data []string
|
||||||
|
}{
|
||||||
|
{1, []string{"abc"}}, // One message per frame.
|
||||||
|
{1, []string{"abc", "def", "ghi"}}, // Multiple messages per frame.
|
||||||
|
{3, []string{"a", "bcdef", "ghif"}}, // Multiple messages over multiple frames.
|
||||||
|
} {
|
||||||
|
var want []*RecvMsg
|
||||||
|
for _, d := range test.data {
|
||||||
|
want = append(want, &RecvMsg{Length: len(d), Overhead: 5})
|
||||||
|
want = append(want, &RecvMsg{Data: []byte(d)})
|
||||||
|
}
|
||||||
|
var got []*RecvMsg
|
||||||
|
dcdr := NewMessageDecoder(func(r *RecvMsg) { got = append(got, r) })
|
||||||
|
for _, fr := range createFrames(test.numFrames, test.data) {
|
||||||
|
dcdr.Decode(fr, 0)
|
||||||
|
}
|
||||||
|
if !match(got, want) {
|
||||||
|
t.Fatalf("got: %v, want: %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func match(got, want []*RecvMsg) bool {
|
||||||
|
for i, v := range got {
|
||||||
|
if !reflect.DeepEqual(v, want[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFrames(n int, msgs []string) [][]byte {
|
||||||
|
var b []byte
|
||||||
|
for _, m := range msgs {
|
||||||
|
payload := []byte(m)
|
||||||
|
hdr := make([]byte, 5)
|
||||||
|
binary.BigEndian.PutUint32(hdr[1:], uint32(len(payload)))
|
||||||
|
b = append(b, hdr...)
|
||||||
|
b = append(b, payload...)
|
||||||
|
}
|
||||||
|
// break b into n parts.
|
||||||
|
var result [][]byte
|
||||||
|
batch := len(b) / n
|
||||||
|
for len(b) != 0 {
|
||||||
|
sz := batch
|
||||||
|
if len(b) < sz {
|
||||||
|
sz = len(b)
|
||||||
|
}
|
||||||
|
result = append(result, b[:sz])
|
||||||
|
b = b[sz:]
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
135
rpc_util.go
135
rpc_util.go
@ -21,7 +21,6 @@ package grpc
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -415,85 +414,39 @@ func (o CustomCodecCallOption) before(c *callInfo) error {
|
|||||||
}
|
}
|
||||||
func (o CustomCodecCallOption) after(c *callInfo) {}
|
func (o CustomCodecCallOption) after(c *callInfo) {}
|
||||||
|
|
||||||
// The format of the payload: compressed or not?
|
|
||||||
type payloadFormat uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
compressionNone payloadFormat = iota // no compression
|
|
||||||
compressionMade
|
|
||||||
)
|
|
||||||
|
|
||||||
// parser reads complete gRPC messages from the underlying reader.
|
|
||||||
type parser struct {
|
|
||||||
// r is the underlying reader.
|
|
||||||
// See the comment on recvMsg for the permissible
|
|
||||||
// error types.
|
|
||||||
r io.Reader
|
|
||||||
|
|
||||||
// The header of a gRPC message. Find more detail at
|
|
||||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
|
|
||||||
header [5]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// recvMsg reads a complete gRPC message from the stream.
|
// recvMsg reads a complete gRPC message from the stream.
|
||||||
//
|
//
|
||||||
// It returns the message and its payload (compression/encoding)
|
// It returns a flag set to true if message was compressed,
|
||||||
// format. The caller owns the returned msg memory.
|
// the message as a byte slice or error if so.
|
||||||
|
// The caller owns the returned msg memory.
|
||||||
//
|
//
|
||||||
// If there is an error, possible values are:
|
// If there is an error, possible values are:
|
||||||
// * io.EOF, when no messages remain
|
// * io.EOF, when no messages remain
|
||||||
// * io.ErrUnexpectedEOF
|
// * io.ErrUnexpectedEOF
|
||||||
// * of type transport.ConnectionError
|
// * of type transport.ConnectionError
|
||||||
// * of type transport.StreamError
|
// * of type transport.StreamError
|
||||||
// No other error values or types must be returned, which also means
|
// No other error values or types must be returned.
|
||||||
// that the underlying io.Reader must not return an incompatible
|
func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) {
|
||||||
// error.
|
isCompressed, msg, err := s.Read(maxRecvMsgSize)
|
||||||
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
|
if err != nil {
|
||||||
if _, err := p.r.Read(p.header[:]); err != nil {
|
return false, nil, err
|
||||||
return 0, nil, err
|
|
||||||
}
|
}
|
||||||
|
return isCompressed, msg, nil
|
||||||
pf = payloadFormat(p.header[0])
|
|
||||||
length := binary.BigEndian.Uint32(p.header[1:])
|
|
||||||
|
|
||||||
if length == 0 {
|
|
||||||
return pf, nil, nil
|
|
||||||
}
|
|
||||||
if int64(length) > int64(maxInt) {
|
|
||||||
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
|
|
||||||
}
|
|
||||||
if int(length) > maxReceiveMessageSize {
|
|
||||||
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
|
|
||||||
}
|
|
||||||
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
|
|
||||||
// of making it for each message:
|
|
||||||
msg = make([]byte, int(length))
|
|
||||||
if _, err := p.r.Read(msg); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
err = io.ErrUnexpectedEOF
|
|
||||||
}
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
return pf, msg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// encode serializes msg and returns a buffer of message header and a buffer of msg.
|
// encode serializes msg and returns a buffer of msg.
|
||||||
// If msg is nil, it generates the message header and an empty msg buffer.
|
// If msg is nil, it generates an empty buffer.
|
||||||
// TODO(ddyihai): eliminate extra Compressor parameter.
|
// TODO(ddyihai): eliminate extra Compressor parameter.
|
||||||
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
|
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) {
|
||||||
var (
|
var (
|
||||||
b []byte
|
b []byte
|
||||||
cbuf *bytes.Buffer
|
cbuf *bytes.Buffer
|
||||||
)
|
)
|
||||||
const (
|
|
||||||
payloadLen = 1
|
|
||||||
sizeLen = 4
|
|
||||||
)
|
|
||||||
if msg != nil {
|
if msg != nil {
|
||||||
var err error
|
var err error
|
||||||
b, err = c.Marshal(msg)
|
b, err = c.Marshal(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
||||||
}
|
}
|
||||||
if outPayload != nil {
|
if outPayload != nil {
|
||||||
outPayload.Payload = msg
|
outPayload.Payload = msg
|
||||||
@ -507,49 +460,36 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa
|
|||||||
if compressor != nil {
|
if compressor != nil {
|
||||||
z, _ := compressor.Compress(cbuf)
|
z, _ := compressor.Compress(cbuf)
|
||||||
if _, err := z.Write(b); err != nil {
|
if _, err := z.Write(b); err != nil {
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||||
}
|
}
|
||||||
z.Close()
|
z.Close()
|
||||||
} else {
|
} else {
|
||||||
// If Compressor is not set by UseCompressor, use default Compressor
|
// If Compressor is not set by UseCompressor, use default Compressor
|
||||||
if err := cp.Do(cbuf, b); err != nil {
|
if err := cp.Do(cbuf, b); err != nil {
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b = cbuf.Bytes()
|
b = cbuf.Bytes()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if uint(len(b)) > math.MaxUint32 {
|
if uint(len(b)) > math.MaxUint32 {
|
||||||
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
bufHeader := make([]byte, payloadLen+sizeLen)
|
|
||||||
if compressor != nil || cp != nil {
|
|
||||||
bufHeader[0] = byte(compressionMade)
|
|
||||||
} else {
|
|
||||||
bufHeader[0] = byte(compressionNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write length of b into buf
|
|
||||||
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
|
|
||||||
if outPayload != nil {
|
if outPayload != nil {
|
||||||
outPayload.WireLength = payloadLen + sizeLen + len(b)
|
// A 5 byte gRPC-specific message header will added to this message
|
||||||
|
// before it's put on wire.
|
||||||
|
outPayload.WireLength = 5 + len(b)
|
||||||
}
|
}
|
||||||
return bufHeader, b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
|
func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status {
|
||||||
switch pf {
|
if recvCompress == "" || recvCompress == encoding.Identity {
|
||||||
case compressionNone:
|
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
|
||||||
case compressionMade:
|
}
|
||||||
if recvCompress == "" || recvCompress == encoding.Identity {
|
if !haveCompressor {
|
||||||
return status.New(codes.Internal, "grpc: compressed flag set with identity or empty encoding")
|
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
||||||
}
|
|
||||||
if !haveCompressor {
|
|
||||||
return status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return status.Newf(codes.Internal, "grpc: received unexpected payload format %d", pf)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -557,8 +497,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool
|
|||||||
// For the two compressor parameters, both should not be set, but if they are,
|
// For the two compressor parameters, both should not be set, but if they are,
|
||||||
// dc takes precedence over compressor.
|
// dc takes precedence over compressor.
|
||||||
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
|
||||||
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
|
func recv(c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error {
|
||||||
pf, d, err := p.recvMsg(maxReceiveMessageSize)
|
isCompressed, d, err := recvMsg(s, maxReceiveMessageSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -566,11 +506,10 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
|
|||||||
inPayload.WireLength = len(d)
|
inPayload.WireLength = len(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
|
if isCompressed {
|
||||||
return st.Err()
|
if st := checkRecvPayload(s.RecvCompress(), compressor != nil || dc != nil); st != nil {
|
||||||
}
|
return st.Err()
|
||||||
|
}
|
||||||
if pf == compressionMade {
|
|
||||||
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
||||||
// use this decompressor as the default.
|
// use this decompressor as the default.
|
||||||
if dc != nil {
|
if dc != nil {
|
||||||
@ -588,11 +527,11 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
|
|||||||
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if len(d) > maxReceiveMessageSize {
|
||||||
if len(d) > maxReceiveMessageSize {
|
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||||
// TODO: Revisit the error code. Currently keep it consistent with java
|
// implementation.
|
||||||
// implementation.
|
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
|
||||||
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
|
}
|
||||||
}
|
}
|
||||||
if err := c.Unmarshal(d, m); err != nil {
|
if err := c.Unmarshal(d, m); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
|
||||||
|
@ -22,7 +22,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -45,77 +44,20 @@ func (f fullReader) Read(p []byte) (int, error) {
|
|||||||
|
|
||||||
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
|
var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface
|
||||||
|
|
||||||
func TestSimpleParsing(t *testing.T) {
|
|
||||||
bigMsg := bytes.Repeat([]byte{'x'}, 1<<24)
|
|
||||||
for _, test := range []struct {
|
|
||||||
// input
|
|
||||||
p []byte
|
|
||||||
// outputs
|
|
||||||
err error
|
|
||||||
b []byte
|
|
||||||
pt payloadFormat
|
|
||||||
}{
|
|
||||||
{nil, io.EOF, nil, compressionNone},
|
|
||||||
{[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone},
|
|
||||||
{[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone},
|
|
||||||
{[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone},
|
|
||||||
{[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone},
|
|
||||||
// Check that messages with length >= 2^24 are parsed.
|
|
||||||
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
|
|
||||||
} {
|
|
||||||
buf := fullReader{bytes.NewReader(test.p)}
|
|
||||||
parser := &parser{r: buf}
|
|
||||||
pt, b, err := parser.recvMsg(math.MaxInt32)
|
|
||||||
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
|
|
||||||
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleParsing(t *testing.T) {
|
|
||||||
// Set a byte stream consists of 3 messages with their headers.
|
|
||||||
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
|
|
||||||
b := fullReader{bytes.NewReader(p)}
|
|
||||||
parser := &parser{r: b}
|
|
||||||
|
|
||||||
wantRecvs := []struct {
|
|
||||||
pt payloadFormat
|
|
||||||
data []byte
|
|
||||||
}{
|
|
||||||
{compressionNone, []byte("a")},
|
|
||||||
{compressionNone, []byte("bc")},
|
|
||||||
{compressionNone, []byte("d")},
|
|
||||||
}
|
|
||||||
for i, want := range wantRecvs {
|
|
||||||
pt, data, err := parser.recvMsg(math.MaxInt32)
|
|
||||||
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
|
|
||||||
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
|
|
||||||
i, p, pt, data, err, want.pt, want.data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pt, data, err := parser.recvMsg(math.MaxInt32)
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
|
|
||||||
len(wantRecvs), p, pt, data, err, io.EOF)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncode(t *testing.T) {
|
func TestEncode(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
// input
|
// input
|
||||||
msg proto.Message
|
msg proto.Message
|
||||||
cp Compressor
|
cp Compressor
|
||||||
// outputs
|
// outputs
|
||||||
hdr []byte
|
|
||||||
data []byte
|
data []byte
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
|
{nil, nil, []byte{}, nil},
|
||||||
} {
|
} {
|
||||||
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
|
data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
|
||||||
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
|
if err != test.err || !bytes.Equal(data, test.data) {
|
||||||
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
|
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,8 +156,11 @@ func TestParseDialTarget(t *testing.T) {
|
|||||||
func bmEncode(b *testing.B, mSize int) {
|
func bmEncode(b *testing.B, mSize int) {
|
||||||
cdc := encoding.GetCodec(protoenc.Name)
|
cdc := encoding.GetCodec(protoenc.Name)
|
||||||
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
||||||
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil)
|
encodeData, _ := encode(cdc, msg, nil, nil, nil)
|
||||||
encodedSz := int64(len(encodeHdr) + len(encodeData))
|
// 5 bytes of gRPC-specific message header
|
||||||
|
// is added to the message before it is written
|
||||||
|
// to the wire.
|
||||||
|
encodedSz := int64(5 + len(encodeData))
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
30
server.go
30
server.go
@ -831,7 +831,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
|
|||||||
if s.opts.statsHandler != nil {
|
if s.opts.statsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{}
|
outPayload = &stats.OutPayload{}
|
||||||
}
|
}
|
||||||
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
|
data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Errorln("grpc: server failed to encode response: ", err)
|
grpclog.Errorln("grpc: server failed to encode response: ", err)
|
||||||
return err
|
return err
|
||||||
@ -839,7 +839,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
|
|||||||
if len(data) > s.opts.maxSendMessageSize {
|
if len(data) > s.opts.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
err = t.Write(stream, hdr, data, opts)
|
opts.IsCompressed = cp != nil || comp != nil
|
||||||
|
err = t.Write(stream, data, opts)
|
||||||
if err == nil && outPayload != nil {
|
if err == nil && outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
|
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
|
||||||
@ -924,8 +925,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &parser{r: stream}
|
isCompressed, req, err := recvMsg(stream, s.opts.maxReceiveMessageSize)
|
||||||
pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize)
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// The entire stream is done (for unary RPC only).
|
// The entire stream is done (for unary RPC only).
|
||||||
return err
|
return err
|
||||||
@ -955,12 +955,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
if channelz.IsOn() {
|
if channelz.IsOn() {
|
||||||
t.IncrMsgRecv()
|
t.IncrMsgRecv()
|
||||||
}
|
}
|
||||||
if st := checkRecvPayload(pf, stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
|
|
||||||
if e := t.WriteStatus(stream, st); e != nil {
|
|
||||||
grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status %v", e)
|
|
||||||
}
|
|
||||||
return st.Err()
|
|
||||||
}
|
|
||||||
var inPayload *stats.InPayload
|
var inPayload *stats.InPayload
|
||||||
if sh != nil {
|
if sh != nil {
|
||||||
inPayload = &stats.InPayload{
|
inPayload = &stats.InPayload{
|
||||||
@ -971,7 +965,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
if inPayload != nil {
|
if inPayload != nil {
|
||||||
inPayload.WireLength = len(req)
|
inPayload.WireLength = len(req)
|
||||||
}
|
}
|
||||||
if pf == compressionMade {
|
if isCompressed {
|
||||||
|
if st := checkRecvPayload(stream.RecvCompress(), dc != nil || decomp != nil); st != nil {
|
||||||
|
return st.Err()
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
if dc != nil {
|
if dc != nil {
|
||||||
req, err = dc.Do(bytes.NewReader(req))
|
req, err = dc.Do(bytes.NewReader(req))
|
||||||
@ -985,11 +982,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
return status.Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if len(req) > s.opts.maxReceiveMessageSize {
|
||||||
if len(req) > s.opts.maxReceiveMessageSize {
|
// TODO: Revisit the error code. Currently keep it consistent with
|
||||||
// TODO: Revisit the error code. Currently keep it consistent with
|
// java implementation.
|
||||||
// java implementation.
|
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
|
||||||
return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize)
|
}
|
||||||
}
|
}
|
||||||
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
|
if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
|
||||||
@ -1100,7 +1097,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
t: t,
|
t: t,
|
||||||
s: stream,
|
s: stream,
|
||||||
p: &parser{r: stream},
|
|
||||||
codec: s.getCodec(stream.ContentSubtype()),
|
codec: s.getCodec(stream.ContentSubtype()),
|
||||||
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
|
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
|
||||||
maxSendMessageSize: s.opts.maxSendMessageSize,
|
maxSendMessageSize: s.opts.maxSendMessageSize,
|
||||||
|
27
stream.go
27
stream.go
@ -290,7 +290,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
attempt: &csAttempt{
|
attempt: &csAttempt{
|
||||||
t: t,
|
t: t,
|
||||||
s: s,
|
s: s,
|
||||||
p: &parser{r: s},
|
|
||||||
done: done,
|
done: done,
|
||||||
dc: cc.dopts.dc,
|
dc: cc.dopts.dc,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -347,7 +346,6 @@ type csAttempt struct {
|
|||||||
cs *clientStream
|
cs *clientStream
|
||||||
t transport.ClientTransport
|
t transport.ClientTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
|
||||||
done func(balancer.DoneInfo)
|
done func(balancer.DoneInfo)
|
||||||
|
|
||||||
dc Decompressor
|
dc Decompressor
|
||||||
@ -472,7 +470,7 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
|
|||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
|
data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -482,7 +480,11 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
|
|||||||
if !cs.desc.ClientStreams {
|
if !cs.desc.ClientStreams {
|
||||||
cs.sentLast = true
|
cs.sentLast = true
|
||||||
}
|
}
|
||||||
err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
|
opts := &transport.Options{
|
||||||
|
Last: !cs.desc.ClientStreams,
|
||||||
|
IsCompressed: cs.cp != nil || cs.comp != nil,
|
||||||
|
}
|
||||||
|
err = a.t.Write(a.s, data, opts)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if outPayload != nil {
|
if outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
@ -526,7 +528,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
|
|||||||
// Only initialize this state once per stream.
|
// Only initialize this state once per stream.
|
||||||
a.decompSet = true
|
a.decompSet = true
|
||||||
}
|
}
|
||||||
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
|
err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if statusErr := a.s.Status().Err(); statusErr != nil {
|
if statusErr := a.s.Status().Err(); statusErr != nil {
|
||||||
@ -556,7 +558,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) {
|
|||||||
|
|
||||||
// Special handling for non-server-stream rpcs.
|
// Special handling for non-server-stream rpcs.
|
||||||
// This recv expects EOF or errors, so we don't collect inPayload.
|
// This recv expects EOF or errors, so we don't collect inPayload.
|
||||||
err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
|
err = recv(cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||||
}
|
}
|
||||||
@ -572,7 +574,7 @@ func (a *csAttempt) closeSend() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
cs.sentLast = true
|
cs.sentLast = true
|
||||||
cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true})
|
cs.attempt.t.Write(cs.attempt.s, nil, &transport.Options{Last: true})
|
||||||
// We ignore errors from Write. Any error it would return would also be
|
// We ignore errors from Write. Any error it would return would also be
|
||||||
// returned by a subsequent RecvMsg call, and the user is supposed to always
|
// returned by a subsequent RecvMsg call, and the user is supposed to always
|
||||||
// finish the stream by calling RecvMsg until it returns err != nil.
|
// finish the stream by calling RecvMsg until it returns err != nil.
|
||||||
@ -635,7 +637,6 @@ type serverStream struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
t transport.ServerTransport
|
t transport.ServerTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
|
||||||
codec baseCodec
|
codec baseCodec
|
||||||
|
|
||||||
cp Compressor
|
cp Compressor
|
||||||
@ -700,14 +701,18 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||||||
if ss.statsHandler != nil {
|
if ss.statsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{}
|
outPayload = &stats.OutPayload{}
|
||||||
}
|
}
|
||||||
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
|
data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(data) > ss.maxSendMessageSize {
|
if len(data) > ss.maxSendMessageSize {
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
|
||||||
}
|
}
|
||||||
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
|
opts := &transport.Options{
|
||||||
|
Last: false,
|
||||||
|
IsCompressed: ss.cp != nil || ss.comp != nil,
|
||||||
|
}
|
||||||
|
if err := ss.t.Write(ss.s, data, opts); err != nil {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
if outPayload != nil {
|
if outPayload != nil {
|
||||||
@ -743,7 +748,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
if ss.statsHandler != nil {
|
if ss.statsHandler != nil {
|
||||||
inPayload = &stats.InPayload{}
|
inPayload = &stats.InPayload{}
|
||||||
}
|
}
|
||||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
|
if err := recv(ss.codec, ss.s, ss.dc, m, ss.maxReceiveMessageSize, inPayload, ss.decomp); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ package transport
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -96,35 +95,39 @@ func (w *writeQuota) replenish(n int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type trInFlow struct {
|
type trInFlow struct {
|
||||||
limit uint32
|
limit uint32 // accessed by reader goroutine.
|
||||||
unacked uint32
|
unacked uint32 // accessed by reader goroutine.
|
||||||
effectiveWindowSize uint32
|
effectiveWindowSize uint32 // accessed by reader and channelz request goroutine.
|
||||||
|
// Callback used to schedule window update.
|
||||||
|
scheduleWU func(uint32)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *trInFlow) newLimit(n uint32) uint32 {
|
// Sets the new limit.
|
||||||
d := n - f.limit
|
func (f *trInFlow) newLimit(n uint32) {
|
||||||
|
if n > f.limit {
|
||||||
|
f.scheduleWU(n - f.limit)
|
||||||
|
}
|
||||||
f.limit = n
|
f.limit = n
|
||||||
f.updateEffectiveWindowSize()
|
f.updateEffectiveWindowSize()
|
||||||
return d
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *trInFlow) onData(n uint32) uint32 {
|
func (f *trInFlow) onData(n uint32) {
|
||||||
f.unacked += n
|
f.unacked += n
|
||||||
if f.unacked >= f.limit/4 {
|
if f.unacked >= f.limit/4 {
|
||||||
w := f.unacked
|
w := f.unacked
|
||||||
f.unacked = 0
|
f.unacked = 0
|
||||||
f.updateEffectiveWindowSize()
|
f.scheduleWU(w)
|
||||||
return w
|
|
||||||
}
|
}
|
||||||
f.updateEffectiveWindowSize()
|
f.updateEffectiveWindowSize()
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *trInFlow) reset() uint32 {
|
func (f *trInFlow) reset() {
|
||||||
w := f.unacked
|
if f.unacked == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.scheduleWU(f.unacked)
|
||||||
f.unacked = 0
|
f.unacked = 0
|
||||||
f.updateEffectiveWindowSize()
|
f.updateEffectiveWindowSize()
|
||||||
return w
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *trInFlow) updateEffectiveWindowSize() {
|
func (f *trInFlow) updateEffectiveWindowSize() {
|
||||||
@ -135,102 +138,57 @@ func (f *trInFlow) getSize() uint32 {
|
|||||||
return atomic.LoadUint32(&f.effectiveWindowSize)
|
return atomic.LoadUint32(&f.effectiveWindowSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(mmukhi): Simplify this code.
|
// stInFlow deals with inbound flow control for stream.
|
||||||
// inFlow deals with inbound flow control
|
// It can be simultaneously read by transport's reader
|
||||||
type inFlow struct {
|
// goroutine and an RPC's goroutine.
|
||||||
mu sync.Mutex
|
// It is protected by the lock in stream that owns it.
|
||||||
|
type stInFlow struct {
|
||||||
|
// rcvd is the bytes of data that this end-point has
|
||||||
|
// received from the perspective of other side.
|
||||||
|
// This can go negative. It must be Accessed atomically.
|
||||||
|
// Needs to be aligned because of golang bug with atomics:
|
||||||
|
// https://golang.org/pkg/sync/atomic/#pkg-note-BUG
|
||||||
|
rcvd int64
|
||||||
// The inbound flow control limit for pending data.
|
// The inbound flow control limit for pending data.
|
||||||
limit uint32
|
limit uint32
|
||||||
// pendingData is the overall data which have been received but not been
|
// number of bytes received so far, this should be accessed
|
||||||
// consumed by applications.
|
// number of bytes that have been read by the RPC.
|
||||||
pendingData uint32
|
read uint32
|
||||||
// The amount of data the application has consumed but grpc has not sent
|
// a window update should be sent when the RPC has
|
||||||
// window update for them. Used to reduce window update frequency.
|
// read these many bytes.
|
||||||
pendingUpdate uint32
|
// TODO(mmukhi, dfawley): Does this have to be limit/4?
|
||||||
// delta is the extra window update given by receiver when an application
|
// Keeping it a constant makes implementation easy.
|
||||||
// is reading data bigger in size than the inFlow limit.
|
wuThreshold uint32
|
||||||
delta uint32
|
// Callback used to schedule window update.
|
||||||
|
scheduleWU func(uint32)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newLimit updates the inflow window to a new value n.
|
// called by transport's reader goroutine to set a new limit on
|
||||||
// It assumes that n is always greater than the old limit.
|
// incoming flow control based on BDP estimation.
|
||||||
func (f *inFlow) newLimit(n uint32) uint32 {
|
func (s *stInFlow) newLimit(n uint32) {
|
||||||
f.mu.Lock()
|
s.limit = n
|
||||||
d := n - f.limit
|
|
||||||
f.limit = n
|
|
||||||
f.mu.Unlock()
|
|
||||||
return d
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *inFlow) maybeAdjust(n uint32) uint32 {
|
// called by transport's reader goroutine when data is received by it.
|
||||||
if n > uint32(math.MaxInt32) {
|
func (s *stInFlow) onData(n uint32) error {
|
||||||
n = uint32(math.MaxInt32)
|
rcvd := atomic.AddInt64(&s.rcvd, int64(n))
|
||||||
|
if rcvd > int64(s.limit) { // Flow control violation.
|
||||||
|
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, s.limit)
|
||||||
}
|
}
|
||||||
f.mu.Lock()
|
|
||||||
// estSenderQuota is the receiver's view of the maximum number of bytes the sender
|
|
||||||
// can send without a window update.
|
|
||||||
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
|
|
||||||
// estUntransmittedData is the maximum number of bytes the sends might not have put
|
|
||||||
// on the wire yet. A value of 0 or less means that we have already received all or
|
|
||||||
// more bytes than the application is requesting to read.
|
|
||||||
estUntransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
|
|
||||||
// This implies that unless we send a window update, the sender won't be able to send all the bytes
|
|
||||||
// for this message. Therefore we must send an update over the limit since there's an active read
|
|
||||||
// request from the application.
|
|
||||||
if estUntransmittedData > estSenderQuota {
|
|
||||||
// Sender's window shouldn't go more than 2^31 - 1 as specified in the HTTP spec.
|
|
||||||
if f.limit+n > maxWindowSize {
|
|
||||||
f.delta = maxWindowSize - f.limit
|
|
||||||
} else {
|
|
||||||
// Send a window update for the whole message and not just the difference between
|
|
||||||
// estUntransmittedData and estSenderQuota. This will be helpful in case the message
|
|
||||||
// is padded; We will fallback on the current available window(at least a 1/4th of the limit).
|
|
||||||
f.delta = n
|
|
||||||
}
|
|
||||||
f.mu.Unlock()
|
|
||||||
return f.delta
|
|
||||||
}
|
|
||||||
f.mu.Unlock()
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// onData is invoked when some data frame is received. It updates pendingData.
|
|
||||||
func (f *inFlow) onData(n uint32) error {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.pendingData += n
|
|
||||||
if f.pendingData+f.pendingUpdate > f.limit+f.delta {
|
|
||||||
limit := f.limit
|
|
||||||
rcvd := f.pendingData + f.pendingUpdate
|
|
||||||
f.mu.Unlock()
|
|
||||||
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", rcvd, limit)
|
|
||||||
}
|
|
||||||
f.mu.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// onRead is invoked when the application reads the data. It returns the window size
|
// called by RPC's goroutine when data is read by it.
|
||||||
// to be sent to the peer.
|
func (s *stInFlow) onRead(n uint32) {
|
||||||
func (f *inFlow) onRead(n uint32) uint32 {
|
s.read += n
|
||||||
f.mu.Lock()
|
if s.read >= s.wuThreshold {
|
||||||
if f.pendingData == 0 {
|
val := atomic.AddInt64(&s.rcvd, ^int64(s.read-1))
|
||||||
f.mu.Unlock()
|
// Check if threshold needs to go up since limit might have gone up.
|
||||||
return 0
|
val += int64(s.read)
|
||||||
|
if val > int64(4*s.wuThreshold) {
|
||||||
|
s.wuThreshold = uint32(val / 4)
|
||||||
|
}
|
||||||
|
s.scheduleWU(s.read)
|
||||||
|
s.read = 0
|
||||||
}
|
}
|
||||||
f.pendingData -= n
|
|
||||||
if n > f.delta {
|
|
||||||
n -= f.delta
|
|
||||||
f.delta = 0
|
|
||||||
} else {
|
|
||||||
f.delta -= n
|
|
||||||
n = 0
|
|
||||||
}
|
|
||||||
f.pendingUpdate += n
|
|
||||||
if f.pendingUpdate >= f.limit/4 {
|
|
||||||
wu := f.pendingUpdate
|
|
||||||
f.pendingUpdate = 0
|
|
||||||
f.mu.Unlock()
|
|
||||||
return wu
|
|
||||||
}
|
|
||||||
f.mu.Unlock()
|
|
||||||
return 0
|
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,7 @@ import (
|
|||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/internal/msgdecoder"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
@ -269,10 +270,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
return ht.do(func() {
|
return ht.do(func() {
|
||||||
ht.writeCommonHeaders(s)
|
ht.writeCommonHeaders(s)
|
||||||
ht.rw.Write(hdr)
|
ht.rw.Write(msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed))
|
||||||
ht.rw.Write(data)
|
ht.rw.Write(data)
|
||||||
if !opts.Delay {
|
if !opts.Delay {
|
||||||
ht.rw.(http.Flusher).Flush()
|
ht.rw.(http.Flusher).Flush()
|
||||||
@ -337,16 +338,13 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||||||
|
|
||||||
req := ht.req
|
req := ht.req
|
||||||
|
|
||||||
s := &Stream{
|
s := newStream(ctx)
|
||||||
id: 0, // irrelevant
|
s.cancel = cancel
|
||||||
requestRead: func(int) {},
|
s.st = ht
|
||||||
cancel: cancel,
|
s.method = req.URL.Path
|
||||||
buf: newRecvBuffer(),
|
s.recvCompress = req.Header.Get("grpc-encoding")
|
||||||
st: ht,
|
s.contentSubtype = ht.contentSubtype
|
||||||
method: req.URL.Path,
|
|
||||||
recvCompress: req.Header.Get("grpc-encoding"),
|
|
||||||
contentSubtype: ht.contentSubtype,
|
|
||||||
}
|
|
||||||
pr := &peer.Peer{
|
pr := &peer.Peer{
|
||||||
Addr: ht.RemoteAddr(),
|
Addr: ht.RemoteAddr(),
|
||||||
}
|
}
|
||||||
@ -364,10 +362,6 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||||||
}
|
}
|
||||||
ht.stats.HandleRPC(s.ctx, inHeader)
|
ht.stats.HandleRPC(s.ctx, inHeader)
|
||||||
}
|
}
|
||||||
s.trReader = &transportReader{
|
|
||||||
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
|
|
||||||
windowHandler: func(int) {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// readerDone is closed when the Body.Read-ing goroutine exits.
|
// readerDone is closed when the Body.Read-ing goroutine exits.
|
||||||
readerDone := make(chan struct{})
|
readerDone := make(chan struct{})
|
||||||
@ -379,11 +373,11 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||||||
for buf := make([]byte, readSize); ; {
|
for buf := make([]byte, readSize); ; {
|
||||||
n, err := req.Body.Read(buf)
|
n, err := req.Body.Read(buf)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
s.buf.put(recvMsg{data: buf[:n:n]})
|
s.consume(buf[:n:n], 0)
|
||||||
buf = buf[n:]
|
buf = buf[n:]
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.buf.put(recvMsg{err: mapRecvMsgError(err)})
|
s.notifyErr(mapRecvMsgError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(buf) == 0 {
|
if len(buf) == 0 {
|
||||||
|
@ -423,7 +423,7 @@ func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
|
|||||||
st.bodyw.Close() // no body
|
st.bodyw.Close() // no body
|
||||||
|
|
||||||
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
|
st.ht.Write(s, []byte("data"), &Options{})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ import (
|
|||||||
"google.golang.org/grpc/channelz"
|
"google.golang.org/grpc/channelz"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/internal/msgdecoder"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
@ -95,8 +96,9 @@ type http2Client struct {
|
|||||||
waitingStreams uint32
|
waitingStreams uint32
|
||||||
nextID uint32
|
nextID uint32
|
||||||
|
|
||||||
mu sync.Mutex // guard the following variables
|
mu sync.Mutex // guard the following variables
|
||||||
state transportState
|
state transportState
|
||||||
|
// TODO(mmukhi): Make this a sharded map.
|
||||||
activeStreams map[uint32]*Stream
|
activeStreams map[uint32]*Stream
|
||||||
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
|
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
|
||||||
prevGoAwayID uint32
|
prevGoAwayID uint32
|
||||||
@ -218,7 +220,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
|
|||||||
goAway: make(chan struct{}),
|
goAway: make(chan struct{}),
|
||||||
awakenKeepalive: make(chan struct{}, 1),
|
awakenKeepalive: make(chan struct{}, 1),
|
||||||
framer: newFramer(conn, writeBufSize, readBufSize),
|
framer: newFramer(conn, writeBufSize, readBufSize),
|
||||||
fc: &trInFlow{limit: uint32(icwz)},
|
|
||||||
scheme: scheme,
|
scheme: scheme,
|
||||||
activeStreams: make(map[uint32]*Stream),
|
activeStreams: make(map[uint32]*Stream),
|
||||||
isSecure: isSecure,
|
isSecure: isSecure,
|
||||||
@ -233,6 +234,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
|
|||||||
streamsQuotaAvailable: make(chan struct{}, 1),
|
streamsQuotaAvailable: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
t.controlBuf = newControlBuffer(t.ctxDone)
|
t.controlBuf = newControlBuffer(t.ctxDone)
|
||||||
|
t.fc = &trInFlow{
|
||||||
|
limit: uint32(icwz),
|
||||||
|
scheduleWU: func(w uint32) {
|
||||||
|
t.controlBuf.put(&outgoingWindowUpdate{
|
||||||
|
streamID: 0,
|
||||||
|
increment: w,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
if opts.InitialWindowSize >= defaultWindowSize {
|
if opts.InitialWindowSize >= defaultWindowSize {
|
||||||
t.initialWindowSize = opts.InitialWindowSize
|
t.initialWindowSize = opts.InitialWindowSize
|
||||||
dynamicWindow = false
|
dynamicWindow = false
|
||||||
@ -306,33 +316,17 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
||||||
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
|
|
||||||
s := &Stream{
|
|
||||||
done: make(chan struct{}),
|
|
||||||
method: callHdr.Method,
|
|
||||||
sendCompress: callHdr.SendCompress,
|
|
||||||
buf: newRecvBuffer(),
|
|
||||||
headerChan: make(chan struct{}),
|
|
||||||
contentSubtype: callHdr.ContentSubtype,
|
|
||||||
}
|
|
||||||
s.wq = newWriteQuota(defaultWriteQuota, s.done)
|
|
||||||
s.requestRead = func(n int) {
|
|
||||||
t.adjustWindow(s, uint32(n))
|
|
||||||
}
|
|
||||||
// The client side stream context should have exactly the same life cycle with the user provided context.
|
// The client side stream context should have exactly the same life cycle with the user provided context.
|
||||||
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
|
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
|
||||||
// So we use the original context here instead of creating a copy.
|
// So we use the original context here instead of creating a copy.
|
||||||
s.ctx = ctx
|
s := newStream(ctx)
|
||||||
s.trReader = &transportReader{
|
// Initialize stream with client-side specific fields.
|
||||||
reader: &recvBufferReader{
|
s.done = make(chan struct{})
|
||||||
ctx: s.ctx,
|
s.method = callHdr.Method
|
||||||
ctxDone: s.ctx.Done(),
|
s.sendCompress = callHdr.SendCompress
|
||||||
recv: s.buf,
|
s.headerChan = make(chan struct{})
|
||||||
},
|
s.contentSubtype = callHdr.ContentSubtype
|
||||||
windowHandler: func(n int) {
|
s.wq = newWriteQuota(defaultWriteQuota, s.done)
|
||||||
t.updateWindow(s, uint32(n))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -504,7 +498,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
}
|
}
|
||||||
// The stream was unprocessed by the server.
|
// The stream was unprocessed by the server.
|
||||||
atomic.StoreUint32(&s.unprocessed, 1)
|
atomic.StoreUint32(&s.unprocessed, 1)
|
||||||
s.write(recvMsg{err: err})
|
s.notifyErr(err)
|
||||||
close(s.done)
|
close(s.done)
|
||||||
// If headerChan isn't closed, then close it.
|
// If headerChan isn't closed, then close it.
|
||||||
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
|
if atomic.SwapUint32(&s.headerDone, 1) == 0 {
|
||||||
@ -572,7 +566,13 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
h.streamID = t.nextID
|
h.streamID = t.nextID
|
||||||
t.nextID += 2
|
t.nextID += 2
|
||||||
s.id = h.streamID
|
s.id = h.streamID
|
||||||
s.fc = &inFlow{limit: uint32(t.initialWindowSize)}
|
s.fc = &stInFlow{
|
||||||
|
limit: uint32(t.initialWindowSize),
|
||||||
|
scheduleWU: func(w uint32) {
|
||||||
|
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
|
||||||
|
},
|
||||||
|
wuThreshold: uint32(t.initialWindowSize / 4),
|
||||||
|
}
|
||||||
if t.streamQuota > 0 && t.waitingStreams > 0 {
|
if t.streamQuota > 0 && t.waitingStreams > 0 {
|
||||||
select {
|
select {
|
||||||
case t.streamsQuotaAvailable <- struct{}{}:
|
case t.streamsQuotaAvailable <- struct{}{}:
|
||||||
@ -642,7 +642,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This will unblock reads eventually.
|
// This will unblock reads eventually.
|
||||||
s.write(recvMsg{err: err})
|
s.notifyErr(err)
|
||||||
}
|
}
|
||||||
// This will unblock write.
|
// This will unblock write.
|
||||||
close(s.done)
|
close(s.done)
|
||||||
@ -740,7 +740,7 @@ func (t *http2Client) GracefulClose() error {
|
|||||||
|
|
||||||
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
|
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller
|
||||||
// should proceed only if Write returns nil.
|
// should proceed only if Write returns nil.
|
||||||
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
if opts.Last {
|
if opts.Last {
|
||||||
// If it's the last message, update stream state.
|
// If it's the last message, update stream state.
|
||||||
if !s.compareAndSwapState(streamActive, streamWriteDone) {
|
if !s.compareAndSwapState(streamActive, streamWriteDone) {
|
||||||
@ -753,7 +753,9 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
|
|||||||
streamID: s.id,
|
streamID: s.id,
|
||||||
endStream: opts.Last,
|
endStream: opts.Last,
|
||||||
}
|
}
|
||||||
if hdr != nil || data != nil { // If it's not an empty data frame.
|
if data != nil { // If it's not an empty data frame.
|
||||||
|
// Get a gRPC-specific header for this message.
|
||||||
|
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
|
||||||
// Add some data to grpc message header so that we can equally
|
// Add some data to grpc message header so that we can equally
|
||||||
// distribute bytes across frames.
|
// distribute bytes across frames.
|
||||||
emptyLen := http2MaxFrameLen - len(hdr)
|
emptyLen := http2MaxFrameLen - len(hdr)
|
||||||
@ -778,39 +780,19 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
|
|||||||
return s, ok
|
return s, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// adjustWindow sends out extra window update over the initial window size
|
|
||||||
// of stream if the application is requesting data larger in size than
|
|
||||||
// the window.
|
|
||||||
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
|
|
||||||
if w := s.fc.maybeAdjust(n); w > 0 {
|
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateWindow adjusts the inbound quota for the stream.
|
|
||||||
// Window updates will be sent out when the cumulative quota
|
|
||||||
// exceeds the corresponding threshold.
|
|
||||||
func (t *http2Client) updateWindow(s *Stream, n uint32) {
|
|
||||||
if w := s.fc.onRead(n); w > 0 {
|
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateFlowControl updates the incoming flow control windows
|
// updateFlowControl updates the incoming flow control windows
|
||||||
// for the transport and the stream based on the current bdp
|
// for the transport and the stream based on the current bdp
|
||||||
// estimation.
|
// estimation.
|
||||||
func (t *http2Client) updateFlowControl(n uint32) {
|
func (t *http2Client) updateFlowControl(n uint32) {
|
||||||
t.mu.Lock()
|
t.fc.newLimit(n) // Update transport's window.
|
||||||
for _, s := range t.activeStreams {
|
updateIWS := func(interface{}) bool { // Update streams' windows.
|
||||||
s.fc.newLimit(n)
|
// All future streams should see the
|
||||||
}
|
// updated value.
|
||||||
t.mu.Unlock()
|
|
||||||
updateIWS := func(interface{}) bool {
|
|
||||||
t.initialWindowSize = int32(n)
|
t.initialWindowSize = int32(n)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
t.controlBuf.executeAndPut(updateIWS, &outgoingWindowUpdate{streamID: 0, increment: t.fc.newLimit(n)})
|
// Notify the other side of updated window.
|
||||||
t.controlBuf.put(&outgoingSettings{
|
t.controlBuf.executeAndPut(updateIWS, &outgoingSettings{
|
||||||
ss: []http2.Setting{
|
ss: []http2.Setting{
|
||||||
{
|
{
|
||||||
ID: http2.SettingInitialWindowSize,
|
ID: http2.SettingInitialWindowSize,
|
||||||
@ -818,13 +800,25 @@ func (t *http2Client) updateFlowControl(n uint32) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
t.mu.Lock()
|
||||||
|
// Update all the currently active streams.
|
||||||
|
for _, s := range t.activeStreams {
|
||||||
|
s.fc.newLimit(n)
|
||||||
|
}
|
||||||
|
t.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Client) handleData(f *http2.DataFrame) {
|
func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
size := f.Header().Length
|
size := f.Header().Length
|
||||||
var sendBDPPing bool
|
if size == 0 {
|
||||||
if t.bdpEst != nil {
|
if f.StreamEnded() {
|
||||||
sendBDPPing = t.bdpEst.add(size)
|
// The server has closed the stream without sending trailers. Record that
|
||||||
|
// the read direction is closed, and set the status appropriately.
|
||||||
|
if s, ok := t.getStream(f); ok {
|
||||||
|
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
// Decouple connection's flow control from application's read.
|
// Decouple connection's flow control from application's read.
|
||||||
// An update on connection's flow control should not depend on
|
// An update on connection's flow control should not depend on
|
||||||
@ -835,53 +829,30 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
|||||||
// active(fast) streams from starving in presence of slow or
|
// active(fast) streams from starving in presence of slow or
|
||||||
// inactive streams.
|
// inactive streams.
|
||||||
//
|
//
|
||||||
if w := t.fc.onData(size); w > 0 {
|
t.fc.onData(size)
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{
|
if t.bdpEst != nil && t.bdpEst.add(size) {
|
||||||
streamID: 0,
|
|
||||||
increment: w,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if sendBDPPing {
|
|
||||||
// Avoid excessive ping detection (e.g. in an L7 proxy)
|
// Avoid excessive ping detection (e.g. in an L7 proxy)
|
||||||
// by sending a window update prior to the BDP ping.
|
// by sending a window update prior to the BDP ping.
|
||||||
|
t.fc.reset()
|
||||||
if w := t.fc.reset(); w > 0 {
|
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{
|
|
||||||
streamID: 0,
|
|
||||||
increment: w,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.controlBuf.put(bdpPing)
|
t.controlBuf.put(bdpPing)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select the right stream to dispatch.
|
// Select the right stream to dispatch.
|
||||||
s, ok := t.getStream(f)
|
if s, ok := t.getStream(f); ok {
|
||||||
if !ok {
|
d := f.Data()
|
||||||
return
|
padding := 0
|
||||||
}
|
if f.Header().Flags.Has(http2.FlagDataPadded) {
|
||||||
if size > 0 {
|
padding = int(size) - len(d)
|
||||||
if err := s.fc.onData(size); err != nil {
|
}
|
||||||
|
if err := s.consume(d, padding); err != nil {
|
||||||
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
|
t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if f.Header().Flags.Has(http2.FlagDataPadded) {
|
if f.StreamEnded() {
|
||||||
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
|
// The server has closed the stream without sending trailers. Record that
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
|
// the read direction is closed, and set the status appropriately.
|
||||||
}
|
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
|
||||||
}
|
}
|
||||||
// TODO(bradfitz, zhaoq): A copy is required here because there is no
|
|
||||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
|
||||||
// Can this copy be eliminated?
|
|
||||||
if len(f.Data()) > 0 {
|
|
||||||
data := make([]byte, len(f.Data()))
|
|
||||||
copy(data, f.Data())
|
|
||||||
s.write(recvMsg{data: data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// The server has closed the stream without sending trailers. Record that
|
|
||||||
// the read direction is closed, and set the status appropriately.
|
|
||||||
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
|
||||||
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -890,6 +861,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
errorf("transport: client got RST_STREAM with error %v, for stream: %d", f.ErrCode, s.id)
|
||||||
if f.ErrCode == http2.ErrCodeRefusedStream {
|
if f.ErrCode == http2.ErrCodeRefusedStream {
|
||||||
// The stream was unprocessed by the server.
|
// The stream was unprocessed by the server.
|
||||||
atomic.StoreUint32(&s.unprocessed, 1)
|
atomic.StoreUint32(&s.unprocessed, 1)
|
||||||
|
@ -39,6 +39,7 @@ import (
|
|||||||
"google.golang.org/grpc/channelz"
|
"google.golang.org/grpc/channelz"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/internal/msgdecoder"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
@ -212,7 +213,6 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||||||
writerDone: make(chan struct{}),
|
writerDone: make(chan struct{}),
|
||||||
maxStreams: maxStreams,
|
maxStreams: maxStreams,
|
||||||
inTapHandle: config.InTapHandle,
|
inTapHandle: config.InTapHandle,
|
||||||
fc: &trInFlow{limit: uint32(icwz)},
|
|
||||||
state: reachable,
|
state: reachable,
|
||||||
activeStreams: make(map[uint32]*Stream),
|
activeStreams: make(map[uint32]*Stream),
|
||||||
stats: config.StatsHandler,
|
stats: config.StatsHandler,
|
||||||
@ -222,6 +222,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||||||
initialWindowSize: iwz,
|
initialWindowSize: iwz,
|
||||||
}
|
}
|
||||||
t.controlBuf = newControlBuffer(t.ctxDone)
|
t.controlBuf = newControlBuffer(t.ctxDone)
|
||||||
|
t.fc = &trInFlow{
|
||||||
|
limit: uint32(icwz),
|
||||||
|
scheduleWU: func(w uint32) {
|
||||||
|
t.controlBuf.put(&outgoingWindowUpdate{
|
||||||
|
streamID: 0,
|
||||||
|
increment: w,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
if dynamicWindow {
|
if dynamicWindow {
|
||||||
t.bdpEst = &bdpEstimator{
|
t.bdpEst = &bdpEstimator{
|
||||||
bdp: initialWindowSize,
|
bdp: initialWindowSize,
|
||||||
@ -298,25 +307,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var (
|
||||||
buf := newRecvBuffer()
|
ctx context.Context
|
||||||
s := &Stream{
|
cancel func()
|
||||||
id: streamID,
|
)
|
||||||
st: t,
|
|
||||||
buf: buf,
|
|
||||||
fc: &inFlow{limit: uint32(t.initialWindowSize)},
|
|
||||||
recvCompress: state.encoding,
|
|
||||||
method: state.method,
|
|
||||||
contentSubtype: state.contentSubtype,
|
|
||||||
}
|
|
||||||
if frame.StreamEnded() {
|
|
||||||
// s is just created by the caller. No lock needed.
|
|
||||||
s.state = streamReadDone
|
|
||||||
}
|
|
||||||
if state.timeoutSet {
|
if state.timeoutSet {
|
||||||
s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout)
|
ctx, cancel = context.WithTimeout(t.ctx, state.timeout)
|
||||||
} else {
|
} else {
|
||||||
s.ctx, s.cancel = context.WithCancel(t.ctx)
|
ctx, cancel = context.WithCancel(t.ctx)
|
||||||
}
|
}
|
||||||
pr := &peer.Peer{
|
pr := &peer.Peer{
|
||||||
Addr: t.remoteAddr,
|
Addr: t.remoteAddr,
|
||||||
@ -325,34 +323,55 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
if t.authInfo != nil {
|
if t.authInfo != nil {
|
||||||
pr.AuthInfo = t.authInfo
|
pr.AuthInfo = t.authInfo
|
||||||
}
|
}
|
||||||
s.ctx = peer.NewContext(s.ctx, pr)
|
ctx = peer.NewContext(ctx, pr)
|
||||||
// Attach the received metadata to the context.
|
// Attach the received metadata to the context.
|
||||||
if len(state.mdata) > 0 {
|
if len(state.mdata) > 0 {
|
||||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
ctx = metadata.NewIncomingContext(ctx, state.mdata)
|
||||||
}
|
}
|
||||||
if state.statsTags != nil {
|
if state.statsTags != nil {
|
||||||
s.ctx = stats.SetIncomingTags(s.ctx, state.statsTags)
|
ctx = stats.SetIncomingTags(ctx, state.statsTags)
|
||||||
}
|
}
|
||||||
if state.statsTrace != nil {
|
if state.statsTrace != nil {
|
||||||
s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace)
|
ctx = stats.SetIncomingTrace(ctx, state.statsTrace)
|
||||||
}
|
}
|
||||||
if t.inTapHandle != nil {
|
if t.inTapHandle != nil {
|
||||||
var err error
|
var err error
|
||||||
info := &tap.Info{
|
info := &tap.Info{
|
||||||
FullMethodName: state.method,
|
FullMethodName: state.method,
|
||||||
}
|
}
|
||||||
s.ctx, err = t.inTapHandle(s.ctx, info)
|
ctx, err = t.inTapHandle(ctx, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
||||||
t.controlBuf.put(&cleanupStream{
|
t.controlBuf.put(&cleanupStream{
|
||||||
streamID: s.id,
|
streamID: streamID,
|
||||||
rst: true,
|
rst: true,
|
||||||
rstCode: http2.ErrCodeRefusedStream,
|
rstCode: http2.ErrCodeRefusedStream,
|
||||||
onWrite: func() {},
|
onWrite: func() {},
|
||||||
})
|
})
|
||||||
|
cancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ctx = traceCtx(ctx, state.method)
|
||||||
|
s := newStream(ctx)
|
||||||
|
// Initialize s with server-side specific fields.
|
||||||
|
s.cancel = cancel
|
||||||
|
s.id = streamID
|
||||||
|
s.st = t
|
||||||
|
s.fc = &stInFlow{
|
||||||
|
limit: uint32(t.initialWindowSize),
|
||||||
|
scheduleWU: func(w uint32) {
|
||||||
|
t.controlBuf.put(&outgoingWindowUpdate{streamID: streamID, increment: w})
|
||||||
|
},
|
||||||
|
wuThreshold: uint32(t.initialWindowSize / 4),
|
||||||
|
}
|
||||||
|
s.recvCompress = state.encoding
|
||||||
|
s.method = state.method
|
||||||
|
s.contentSubtype = state.contentSubtype
|
||||||
|
if frame.StreamEnded() {
|
||||||
|
s.state = streamReadDone
|
||||||
|
}
|
||||||
|
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
if t.state != reachable {
|
if t.state != reachable {
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
@ -386,10 +405,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
t.lastStreamCreated = time.Now()
|
t.lastStreamCreated = time.Now()
|
||||||
t.czmu.Unlock()
|
t.czmu.Unlock()
|
||||||
}
|
}
|
||||||
s.requestRead = func(n int) {
|
|
||||||
t.adjustWindow(s, uint32(n))
|
|
||||||
}
|
|
||||||
s.ctx = traceCtx(s.ctx, s.method)
|
|
||||||
if t.stats != nil {
|
if t.stats != nil {
|
||||||
s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||||
inHeader := &stats.InHeader{
|
inHeader := &stats.InHeader{
|
||||||
@ -401,18 +416,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
}
|
}
|
||||||
t.stats.HandleRPC(s.ctx, inHeader)
|
t.stats.HandleRPC(s.ctx, inHeader)
|
||||||
}
|
}
|
||||||
s.ctxDone = s.ctx.Done()
|
|
||||||
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
|
||||||
s.trReader = &transportReader{
|
|
||||||
reader: &recvBufferReader{
|
|
||||||
ctx: s.ctx,
|
|
||||||
ctxDone: s.ctxDone,
|
|
||||||
recv: s.buf,
|
|
||||||
},
|
|
||||||
windowHandler: func(n int) {
|
|
||||||
t.updateWindow(s, uint32(n))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
handle(s)
|
handle(s)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -490,41 +493,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
|
|||||||
return s, true
|
return s, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// adjustWindow sends out extra window update over the initial window size
|
|
||||||
// of stream if the application is requesting data larger in size than
|
|
||||||
// the window.
|
|
||||||
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
|
|
||||||
if w := s.fc.maybeAdjust(n); w > 0 {
|
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateWindow adjusts the inbound quota for the stream and the transport.
|
|
||||||
// Window updates will deliver to the controller for sending when
|
|
||||||
// the cumulative quota exceeds the corresponding threshold.
|
|
||||||
func (t *http2Server) updateWindow(s *Stream, n uint32) {
|
|
||||||
if w := s.fc.onRead(n); w > 0 {
|
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id,
|
|
||||||
increment: w,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateFlowControl updates the incoming flow control windows
|
// updateFlowControl updates the incoming flow control windows
|
||||||
// for the transport and the stream based on the current bdp
|
// for the transport and the stream based on the current bdp
|
||||||
// estimation.
|
// estimation.
|
||||||
func (t *http2Server) updateFlowControl(n uint32) {
|
func (t *http2Server) updateFlowControl(n uint32) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
|
// Update all the current streams' window.
|
||||||
for _, s := range t.activeStreams {
|
for _, s := range t.activeStreams {
|
||||||
s.fc.newLimit(n)
|
s.fc.newLimit(n)
|
||||||
}
|
}
|
||||||
|
// Update all the future streams' window.
|
||||||
t.initialWindowSize = int32(n)
|
t.initialWindowSize = int32(n)
|
||||||
t.mu.Unlock()
|
t.mu.Unlock()
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{
|
t.fc.newLimit(n) // Update transport's window.
|
||||||
streamID: 0,
|
// Notify the other side of the updated value.
|
||||||
increment: t.fc.newLimit(n),
|
|
||||||
})
|
|
||||||
t.controlBuf.put(&outgoingSettings{
|
t.controlBuf.put(&outgoingSettings{
|
||||||
ss: []http2.Setting{
|
ss: []http2.Setting{
|
||||||
{
|
{
|
||||||
@ -538,9 +520,15 @@ func (t *http2Server) updateFlowControl(n uint32) {
|
|||||||
|
|
||||||
func (t *http2Server) handleData(f *http2.DataFrame) {
|
func (t *http2Server) handleData(f *http2.DataFrame) {
|
||||||
size := f.Header().Length
|
size := f.Header().Length
|
||||||
var sendBDPPing bool
|
if size == 0 {
|
||||||
if t.bdpEst != nil {
|
if f.StreamEnded() {
|
||||||
sendBDPPing = t.bdpEst.add(size)
|
if s, ok := t.getStream(f); ok {
|
||||||
|
// Received the end of stream from the client.
|
||||||
|
s.compareAndSwapState(streamActive, streamReadDone)
|
||||||
|
s.notifyErr(io.EOF)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
// Decouple connection's flow control from application's read.
|
// Decouple connection's flow control from application's read.
|
||||||
// An update on connection's flow control should not depend on
|
// An update on connection's flow control should not depend on
|
||||||
@ -550,51 +538,30 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
|
|||||||
// Decoupling the connection flow control will prevent other
|
// Decoupling the connection flow control will prevent other
|
||||||
// active(fast) streams from starving in presence of slow or
|
// active(fast) streams from starving in presence of slow or
|
||||||
// inactive streams.
|
// inactive streams.
|
||||||
if w := t.fc.onData(size); w > 0 {
|
t.fc.onData(size)
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{
|
if t.bdpEst != nil && t.bdpEst.add(size) {
|
||||||
streamID: 0,
|
|
||||||
increment: w,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if sendBDPPing {
|
|
||||||
// Avoid excessive ping detection (e.g. in an L7 proxy)
|
// Avoid excessive ping detection (e.g. in an L7 proxy)
|
||||||
// by sending a window update prior to the BDP ping.
|
// by sending a window update prior to the BDP ping.
|
||||||
if w := t.fc.reset(); w > 0 {
|
t.fc.reset()
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{
|
|
||||||
streamID: 0,
|
|
||||||
increment: w,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
t.controlBuf.put(bdpPing)
|
t.controlBuf.put(bdpPing)
|
||||||
}
|
}
|
||||||
// Select the right stream to dispatch.
|
// Select the right stream to dispatch.
|
||||||
s, ok := t.getStream(f)
|
if s, ok := t.getStream(f); ok {
|
||||||
if !ok {
|
d := f.Data()
|
||||||
return
|
padding := 0
|
||||||
}
|
if f.Header().Flags.Has(http2.FlagDataPadded) {
|
||||||
if size > 0 {
|
padding = int(size) - len(d)
|
||||||
if err := s.fc.onData(size); err != nil {
|
}
|
||||||
|
if err := s.consume(d, padding); err != nil {
|
||||||
|
errorf("transport: flow control error on server: %v", err)
|
||||||
t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
|
t.closeStream(s, true, http2.ErrCodeFlowControl, nil, false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if f.Header().Flags.Has(http2.FlagDataPadded) {
|
if f.StreamEnded() {
|
||||||
if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 {
|
// Received the end of stream from the client.
|
||||||
t.controlBuf.put(&outgoingWindowUpdate{s.id, w})
|
s.compareAndSwapState(streamActive, streamReadDone)
|
||||||
}
|
s.notifyErr(io.EOF)
|
||||||
}
|
}
|
||||||
// TODO(bradfitz, zhaoq): A copy is required here because there is no
|
|
||||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
|
||||||
// Can this copy be eliminated?
|
|
||||||
if len(f.Data()) > 0 {
|
|
||||||
data := make([]byte, len(f.Data()))
|
|
||||||
copy(data, f.Data())
|
|
||||||
s.write(recvMsg{data: data})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if f.Header().Flags.Has(http2.FlagDataEndStream) {
|
|
||||||
// Received the end of stream from the client.
|
|
||||||
s.compareAndSwapState(streamActive, streamReadDone)
|
|
||||||
s.write(recvMsg{err: io.EOF})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -792,7 +759,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
|
|||||||
|
|
||||||
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
|
// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
|
||||||
// is returns if it fails (e.g., framing error, transport error).
|
// is returns if it fails (e.g., framing error, transport error).
|
||||||
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
|
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
if !s.headerOk { // Headers haven't been written yet.
|
if !s.headerOk { // Headers haven't been written yet.
|
||||||
if err := t.WriteHeader(s, nil); err != nil {
|
if err := t.WriteHeader(s, nil); err != nil {
|
||||||
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
|
// TODO(mmukhi, dfawley): Make sure this is the right code to return.
|
||||||
@ -811,6 +778,8 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
|
|||||||
return ContextErr(s.ctx.Err())
|
return ContextErr(s.ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Get a gRPC-specific header for this message.
|
||||||
|
hdr := msgdecoder.CreateMessageHeader(len(data), opts.IsCompressed)
|
||||||
// Add some data to header frame so that we can equally distribute bytes across frames.
|
// Add some data to header frame so that we can equally distribute bytes across frames.
|
||||||
emptyLen := http2MaxFrameLen - len(hdr)
|
emptyLen := http2MaxFrameLen - len(hdr)
|
||||||
if emptyLen > len(data) {
|
if emptyLen > len(data) {
|
||||||
|
407
transport/stream.go
Normal file
407
transport/stream.go
Normal file
@ -0,0 +1,407 @@
|
|||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2014 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/internal/msgdecoder"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxInt = int(^uint(0) >> 1)
|
||||||
|
|
||||||
|
type streamState uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
streamActive streamState = iota
|
||||||
|
streamWriteDone // EndStream sent
|
||||||
|
streamReadDone // EndStream received
|
||||||
|
streamDone // the entire stream is finished.
|
||||||
|
)
|
||||||
|
|
||||||
|
// transport's reader goroutine adds msgdecoder.RecvMsg to it which are later
|
||||||
|
// read by RPC's reading goroutine.
|
||||||
|
//
|
||||||
|
// It is protected by a lock in the Stream that owns it.
|
||||||
|
type recvBuffer struct {
|
||||||
|
ctx context.Context
|
||||||
|
ctxDone <-chan struct{}
|
||||||
|
c chan *msgdecoder.RecvMsg
|
||||||
|
mu sync.Mutex
|
||||||
|
waiting bool
|
||||||
|
list *msgdecoder.RecvMsgList
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRecvBuffer(ctx context.Context, ctxDone <-chan struct{}) *recvBuffer {
|
||||||
|
return &recvBuffer{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxDone: ctxDone,
|
||||||
|
c: make(chan *msgdecoder.RecvMsg, 1),
|
||||||
|
list: &msgdecoder.RecvMsgList{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// put adds r to the underlying list if there's no consumer
|
||||||
|
// waiting, otherwise, it writes on the chan directly.
|
||||||
|
func (b *recvBuffer) put(r *msgdecoder.RecvMsg) {
|
||||||
|
b.mu.Lock()
|
||||||
|
if b.waiting {
|
||||||
|
b.waiting = false
|
||||||
|
b.mu.Unlock()
|
||||||
|
b.c <- r
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.list.Enqueue(r)
|
||||||
|
b.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNoBlock returns a msgdecoder.RecvMsg and true status, if there's
|
||||||
|
// any available.
|
||||||
|
// If the status is false, the caller must then call
|
||||||
|
// getWithBlock() before calling getNoBlock() again.
|
||||||
|
func (b *recvBuffer) getNoBlock() (*msgdecoder.RecvMsg, bool) {
|
||||||
|
b.mu.Lock()
|
||||||
|
r := b.list.Dequeue()
|
||||||
|
if r != nil {
|
||||||
|
b.mu.Unlock()
|
||||||
|
return r, true
|
||||||
|
}
|
||||||
|
b.waiting = true
|
||||||
|
b.mu.Unlock()
|
||||||
|
return nil, false
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWithBlock() blocks until a complete message has been
|
||||||
|
// received, or an error has occurred or the underlying
|
||||||
|
// context has expired.
|
||||||
|
// It must only be called after having called GetNoBlock()
|
||||||
|
// once.
|
||||||
|
func (b *recvBuffer) getWithBlock() (*msgdecoder.RecvMsg, error) {
|
||||||
|
select {
|
||||||
|
case <-b.ctxDone:
|
||||||
|
return nil, ContextErr(b.ctx.Err())
|
||||||
|
case r := <-b.c:
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *recvBuffer) get() (*msgdecoder.RecvMsg, error) {
|
||||||
|
m, ok := b.getNoBlock()
|
||||||
|
if ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
m, err := b.getWithBlock()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream represents an RPC in the transport layer.
|
||||||
|
type Stream struct {
|
||||||
|
id uint32
|
||||||
|
st ServerTransport // nil for client side Stream
|
||||||
|
ctx context.Context // the associated context of the stream
|
||||||
|
cancel context.CancelFunc // always nil for client side Stream
|
||||||
|
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
|
||||||
|
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
|
||||||
|
method string // the associated RPC method of the stream
|
||||||
|
recvCompress string
|
||||||
|
sendCompress string
|
||||||
|
buf *recvBuffer
|
||||||
|
wq *writeQuota
|
||||||
|
|
||||||
|
headerChan chan struct{} // closed to indicate the end of header metadata.
|
||||||
|
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
|
||||||
|
header metadata.MD // the received header metadata.
|
||||||
|
trailer metadata.MD // the key-value map of trailer metadata.
|
||||||
|
|
||||||
|
headerOk bool // becomes true from the first header is about to send
|
||||||
|
state streamState
|
||||||
|
|
||||||
|
status *status.Status // the status error received from the server
|
||||||
|
|
||||||
|
bytesReceived uint32 // indicates whether any bytes have been received on this stream
|
||||||
|
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
|
||||||
|
|
||||||
|
// contentSubtype is the content-subtype for requests.
|
||||||
|
// this must be lowercase or the behavior is undefined.
|
||||||
|
contentSubtype string
|
||||||
|
|
||||||
|
rbuf *recvBuffer
|
||||||
|
fc *stInFlow
|
||||||
|
msgDecoder *msgdecoder.MessageDecoder
|
||||||
|
readErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStream(ctx context.Context) *Stream {
|
||||||
|
// Cache the done chan since a Done() call is expensive.
|
||||||
|
ctxDone := ctx.Done()
|
||||||
|
s := &Stream{
|
||||||
|
ctx: ctx,
|
||||||
|
ctxDone: ctxDone,
|
||||||
|
rbuf: newRecvBuffer(ctx, ctxDone),
|
||||||
|
}
|
||||||
|
dispatch := func(r *msgdecoder.RecvMsg) {
|
||||||
|
s.rbuf.put(r)
|
||||||
|
}
|
||||||
|
s.msgDecoder = msgdecoder.NewMessageDecoder(dispatch)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// notifyErr notifies RPC of an error seen by the transport.
|
||||||
|
//
|
||||||
|
// Note to developers: This call can unblock Read calls on RPC
|
||||||
|
// and lead to reading of unprotected fields on stream on the
|
||||||
|
// client-side. It should only be called from inside
|
||||||
|
// transport.closeStream() if the stream was initialized or from
|
||||||
|
// inside the cleanup callback if the stream was not initialized.
|
||||||
|
func (s *Stream) notifyErr(err error) {
|
||||||
|
s.rbuf.put(&msgdecoder.RecvMsg{Err: err})
|
||||||
|
}
|
||||||
|
|
||||||
|
// consume is called by transport's reader goroutine for parsing
|
||||||
|
// and decoding data received for this stream.
|
||||||
|
func (s *Stream) consume(b []byte, padding int) error {
|
||||||
|
// Flow control check.
|
||||||
|
if s.fc != nil { // HandlerServer doesn't use our flow control.
|
||||||
|
if err := s.fc.onData(uint32(len(b) + padding)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.msgDecoder.Decode(b, padding)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads one whole message from the transport.
|
||||||
|
// It is called by RPC's goroutine.
|
||||||
|
// It is not safe to be called concurrently by multiple goroutines.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// 1. received message's compression status(true if was compressed)
|
||||||
|
// 2. Message as a byte slice
|
||||||
|
// 3. Error, if any.
|
||||||
|
func (s *Stream) Read(maxRecvMsgSize int) (bool, []byte, error) {
|
||||||
|
if s.readErr != nil {
|
||||||
|
return false, nil, s.readErr
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
m *msgdecoder.RecvMsg
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
// First read the underlying message header
|
||||||
|
if m, err = s.rbuf.get(); err != nil {
|
||||||
|
s.readErr = err
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
if m.Err != nil {
|
||||||
|
s.readErr = m.Err
|
||||||
|
return false, nil, s.readErr
|
||||||
|
}
|
||||||
|
// Make sure the message being received isn't too large.
|
||||||
|
if int64(m.Length) > int64(maxInt) {
|
||||||
|
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", m.Length, maxInt)
|
||||||
|
return false, nil, s.readErr
|
||||||
|
}
|
||||||
|
if m.Length > maxRecvMsgSize {
|
||||||
|
s.readErr = status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", m.Length, maxRecvMsgSize)
|
||||||
|
return false, nil, s.readErr
|
||||||
|
}
|
||||||
|
// Send a window update for the message this RPC is reading.
|
||||||
|
if s.fc != nil { // HanderServer doesn't use our flow control.
|
||||||
|
s.fc.onRead(uint32(m.Length + m.Overhead))
|
||||||
|
}
|
||||||
|
isCompressed := m.IsCompressed
|
||||||
|
// Read the message.
|
||||||
|
if m, err = s.rbuf.get(); err != nil {
|
||||||
|
s.readErr = err
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
if m.Err != nil {
|
||||||
|
if m.Err == io.EOF {
|
||||||
|
m.Err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
s.readErr = m.Err
|
||||||
|
return false, nil, s.readErr
|
||||||
|
}
|
||||||
|
return isCompressed, m.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) swapState(st streamState) streamState {
|
||||||
|
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
|
||||||
|
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) getState() streamState {
|
||||||
|
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) waitOnHeader() error {
|
||||||
|
if s.headerChan == nil {
|
||||||
|
// On the server headerChan is always nil since a stream originates
|
||||||
|
// only after having received headers.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return ContextErr(s.ctx.Err())
|
||||||
|
case <-s.headerChan:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecvCompress returns the compression algorithm applied to the inbound
|
||||||
|
// message. It is empty string if there is no compression applied.
|
||||||
|
func (s *Stream) RecvCompress() string {
|
||||||
|
if err := s.waitOnHeader(); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.recvCompress
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSendCompress sets the compression algorithm to the stream.
|
||||||
|
func (s *Stream) SetSendCompress(str string) {
|
||||||
|
s.sendCompress = str
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done returns a chanel which is closed when it receives the final status
|
||||||
|
// from the server.
|
||||||
|
func (s *Stream) Done() <-chan struct{} {
|
||||||
|
return s.done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header acquires the key-value pairs of header metadata once it
|
||||||
|
// is available. It blocks until i) the metadata is ready or ii) there is no
|
||||||
|
// header metadata or iii) the stream is canceled/expired.
|
||||||
|
func (s *Stream) Header() (metadata.MD, error) {
|
||||||
|
err := s.waitOnHeader()
|
||||||
|
// Even if the stream is closed, header is returned if available.
|
||||||
|
select {
|
||||||
|
case <-s.headerChan:
|
||||||
|
if s.header == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return s.header.Copy(), nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trailer returns the cached trailer metedata. Note that if it is not called
|
||||||
|
// after the entire stream is done, it could return an empty MD. Client
|
||||||
|
// side only.
|
||||||
|
// It can be safely read only after stream has ended that is either read
|
||||||
|
// or write have returned io.EOF.
|
||||||
|
func (s *Stream) Trailer() metadata.MD {
|
||||||
|
c := s.trailer.Copy()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerTransport returns the underlying ServerTransport for the stream.
|
||||||
|
// The client side stream always returns nil.
|
||||||
|
func (s *Stream) ServerTransport() ServerTransport {
|
||||||
|
return s.st
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentSubtype returns the content-subtype for a request. For example, a
|
||||||
|
// content-subtype of "proto" will result in a content-type of
|
||||||
|
// "application/grpc+proto". This will always be lowercase. See
|
||||||
|
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
||||||
|
// more details.
|
||||||
|
func (s *Stream) ContentSubtype() string {
|
||||||
|
return s.contentSubtype
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context returns the context of the stream.
|
||||||
|
func (s *Stream) Context() context.Context {
|
||||||
|
return s.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method returns the method for the stream.
|
||||||
|
func (s *Stream) Method() string {
|
||||||
|
return s.method
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the status received from the server.
|
||||||
|
// Status can be read safely only after the stream has ended,
|
||||||
|
// that is, read or write has returned io.EOF.
|
||||||
|
func (s *Stream) Status() *status.Status {
|
||||||
|
return s.status
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeader sets the header metadata. This can be called multiple times.
|
||||||
|
// Server side only.
|
||||||
|
// This should not be called in parallel to other data writes.
|
||||||
|
func (s *Stream) SetHeader(md metadata.MD) error {
|
||||||
|
if md.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
|
||||||
|
return ErrIllegalHeaderWrite
|
||||||
|
}
|
||||||
|
s.header = metadata.Join(s.header, md)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendHeader sends the given header metadata. The given metadata is
|
||||||
|
// combined with any metadata set by previous calls to SetHeader and
|
||||||
|
// then written to the transport stream.
|
||||||
|
func (s *Stream) SendHeader(md metadata.MD) error {
|
||||||
|
t := s.ServerTransport()
|
||||||
|
return t.WriteHeader(s, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
||||||
|
// by the server. This can be called multiple times. Server side only.
|
||||||
|
// This should not be called parallel to other data writes.
|
||||||
|
func (s *Stream) SetTrailer(md metadata.MD) error {
|
||||||
|
if md.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.trailer = metadata.Join(s.trailer, md)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BytesReceived indicates whether any bytes have been received on this stream.
|
||||||
|
func (s *Stream) BytesReceived() bool {
|
||||||
|
return atomic.LoadUint32(&s.bytesReceived) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unprocessed indicates whether the server did not process this stream --
|
||||||
|
// i.e. it sent a refused stream or GOAWAY including this stream ID.
|
||||||
|
func (s *Stream) Unprocessed() bool {
|
||||||
|
return atomic.LoadUint32(&s.unprocessed) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoString is implemented by Stream so context.String() won't
|
||||||
|
// race when printing %#v.
|
||||||
|
func (s *Stream) GoString() string {
|
||||||
|
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
||||||
|
}
|
@ -24,10 +24,7 @@ package transport // externally used as import "google.golang.org/grpc/transport
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -39,359 +36,6 @@ import (
|
|||||||
"google.golang.org/grpc/tap"
|
"google.golang.org/grpc/tap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// recvMsg represents the received msg from the transport. All transport
|
|
||||||
// protocol specific info has been removed.
|
|
||||||
type recvMsg struct {
|
|
||||||
data []byte
|
|
||||||
// nil: received some data
|
|
||||||
// io.EOF: stream is completed. data is nil.
|
|
||||||
// other non-nil error: transport failure. data is nil.
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// recvBuffer is an unbounded channel of recvMsg structs.
|
|
||||||
// Note recvBuffer differs from controlBuffer only in that recvBuffer
|
|
||||||
// holds a channel of only recvMsg structs instead of objects implementing "item" interface.
|
|
||||||
// recvBuffer is written to much more often than
|
|
||||||
// controlBuffer and using strict recvMsg structs helps avoid allocation in "recvBuffer.put"
|
|
||||||
type recvBuffer struct {
|
|
||||||
c chan recvMsg
|
|
||||||
mu sync.Mutex
|
|
||||||
backlog []recvMsg
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRecvBuffer() *recvBuffer {
|
|
||||||
b := &recvBuffer{
|
|
||||||
c: make(chan recvMsg, 1),
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *recvBuffer) put(r recvMsg) {
|
|
||||||
b.mu.Lock()
|
|
||||||
if b.err != nil {
|
|
||||||
b.mu.Unlock()
|
|
||||||
// An error had occurred earlier, don't accept more
|
|
||||||
// data or errors.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.err = r.err
|
|
||||||
if len(b.backlog) == 0 {
|
|
||||||
select {
|
|
||||||
case b.c <- r:
|
|
||||||
b.mu.Unlock()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.backlog = append(b.backlog, r)
|
|
||||||
b.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *recvBuffer) load() {
|
|
||||||
b.mu.Lock()
|
|
||||||
if len(b.backlog) > 0 {
|
|
||||||
select {
|
|
||||||
case b.c <- b.backlog[0]:
|
|
||||||
b.backlog[0] = recvMsg{}
|
|
||||||
b.backlog = b.backlog[1:]
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// get returns the channel that receives a recvMsg in the buffer.
|
|
||||||
//
|
|
||||||
// Upon receipt of a recvMsg, the caller should call load to send another
|
|
||||||
// recvMsg onto the channel if there is any.
|
|
||||||
func (b *recvBuffer) get() <-chan recvMsg {
|
|
||||||
return b.c
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// recvBufferReader implements io.Reader interface to read the data from
|
|
||||||
// recvBuffer.
|
|
||||||
type recvBufferReader struct {
|
|
||||||
ctx context.Context
|
|
||||||
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
|
|
||||||
recv *recvBuffer
|
|
||||||
last []byte // Stores the remaining data in the previous calls.
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads the next len(p) bytes from last. If last is drained, it tries to
|
|
||||||
// read additional data from recv. It blocks if there no additional data available
|
|
||||||
// in recv. If Read returns any non-nil error, it will continue to return that error.
|
|
||||||
func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
|
||||||
if r.err != nil {
|
|
||||||
return 0, r.err
|
|
||||||
}
|
|
||||||
n, r.err = r.read(p)
|
|
||||||
return n, r.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *recvBufferReader) read(p []byte) (n int, err error) {
|
|
||||||
if r.last != nil && len(r.last) > 0 {
|
|
||||||
// Read remaining data left in last call.
|
|
||||||
copied := copy(p, r.last)
|
|
||||||
r.last = r.last[copied:]
|
|
||||||
return copied, nil
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-r.ctxDone:
|
|
||||||
return 0, ContextErr(r.ctx.Err())
|
|
||||||
case m := <-r.recv.get():
|
|
||||||
r.recv.load()
|
|
||||||
if m.err != nil {
|
|
||||||
return 0, m.err
|
|
||||||
}
|
|
||||||
copied := copy(p, m.data)
|
|
||||||
r.last = m.data[copied:]
|
|
||||||
return copied, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type streamState uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
streamActive streamState = iota
|
|
||||||
streamWriteDone // EndStream sent
|
|
||||||
streamReadDone // EndStream received
|
|
||||||
streamDone // the entire stream is finished.
|
|
||||||
)
|
|
||||||
|
|
||||||
// Stream represents an RPC in the transport layer.
|
|
||||||
type Stream struct {
|
|
||||||
id uint32
|
|
||||||
st ServerTransport // nil for client side Stream
|
|
||||||
ctx context.Context // the associated context of the stream
|
|
||||||
cancel context.CancelFunc // always nil for client side Stream
|
|
||||||
done chan struct{} // closed at the end of stream to unblock writers. On the client side.
|
|
||||||
ctxDone <-chan struct{} // same as done chan but for server side. Cache of ctx.Done() (for performance)
|
|
||||||
method string // the associated RPC method of the stream
|
|
||||||
recvCompress string
|
|
||||||
sendCompress string
|
|
||||||
buf *recvBuffer
|
|
||||||
trReader io.Reader
|
|
||||||
fc *inFlow
|
|
||||||
recvQuota uint32
|
|
||||||
wq *writeQuota
|
|
||||||
|
|
||||||
// Callback to state application's intentions to read data. This
|
|
||||||
// is used to adjust flow control, if needed.
|
|
||||||
requestRead func(int)
|
|
||||||
|
|
||||||
headerChan chan struct{} // closed to indicate the end of header metadata.
|
|
||||||
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
|
|
||||||
header metadata.MD // the received header metadata.
|
|
||||||
trailer metadata.MD // the key-value map of trailer metadata.
|
|
||||||
|
|
||||||
headerOk bool // becomes true from the first header is about to send
|
|
||||||
state streamState
|
|
||||||
|
|
||||||
status *status.Status // the status error received from the server
|
|
||||||
|
|
||||||
bytesReceived uint32 // indicates whether any bytes have been received on this stream
|
|
||||||
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
|
|
||||||
|
|
||||||
// contentSubtype is the content-subtype for requests.
|
|
||||||
// this must be lowercase or the behavior is undefined.
|
|
||||||
contentSubtype string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stream) swapState(st streamState) streamState {
|
|
||||||
return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stream) compareAndSwapState(oldState, newState streamState) bool {
|
|
||||||
return atomic.CompareAndSwapUint32((*uint32)(&s.state), uint32(oldState), uint32(newState))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stream) getState() streamState {
|
|
||||||
return streamState(atomic.LoadUint32((*uint32)(&s.state)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stream) waitOnHeader() error {
|
|
||||||
if s.headerChan == nil {
|
|
||||||
// On the server headerChan is always nil since a stream originates
|
|
||||||
// only after having received headers.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-s.ctx.Done():
|
|
||||||
return ContextErr(s.ctx.Err())
|
|
||||||
case <-s.headerChan:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RecvCompress returns the compression algorithm applied to the inbound
|
|
||||||
// message. It is empty string if there is no compression applied.
|
|
||||||
func (s *Stream) RecvCompress() string {
|
|
||||||
if err := s.waitOnHeader(); err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s.recvCompress
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSendCompress sets the compression algorithm to the stream.
|
|
||||||
func (s *Stream) SetSendCompress(str string) {
|
|
||||||
s.sendCompress = str
|
|
||||||
}
|
|
||||||
|
|
||||||
// Done returns a chanel which is closed when it receives the final status
|
|
||||||
// from the server.
|
|
||||||
func (s *Stream) Done() <-chan struct{} {
|
|
||||||
return s.done
|
|
||||||
}
|
|
||||||
|
|
||||||
// Header acquires the key-value pairs of header metadata once it
|
|
||||||
// is available. It blocks until i) the metadata is ready or ii) there is no
|
|
||||||
// header metadata or iii) the stream is canceled/expired.
|
|
||||||
func (s *Stream) Header() (metadata.MD, error) {
|
|
||||||
err := s.waitOnHeader()
|
|
||||||
// Even if the stream is closed, header is returned if available.
|
|
||||||
select {
|
|
||||||
case <-s.headerChan:
|
|
||||||
if s.header == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return s.header.Copy(), nil
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trailer returns the cached trailer metedata. Note that if it is not called
|
|
||||||
// after the entire stream is done, it could return an empty MD. Client
|
|
||||||
// side only.
|
|
||||||
// It can be safely read only after stream has ended that is either read
|
|
||||||
// or write have returned io.EOF.
|
|
||||||
func (s *Stream) Trailer() metadata.MD {
|
|
||||||
c := s.trailer.Copy()
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerTransport returns the underlying ServerTransport for the stream.
|
|
||||||
// The client side stream always returns nil.
|
|
||||||
func (s *Stream) ServerTransport() ServerTransport {
|
|
||||||
return s.st
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContentSubtype returns the content-subtype for a request. For example, a
|
|
||||||
// content-subtype of "proto" will result in a content-type of
|
|
||||||
// "application/grpc+proto". This will always be lowercase. See
|
|
||||||
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for
|
|
||||||
// more details.
|
|
||||||
func (s *Stream) ContentSubtype() string {
|
|
||||||
return s.contentSubtype
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context returns the context of the stream.
|
|
||||||
func (s *Stream) Context() context.Context {
|
|
||||||
return s.ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
// Method returns the method for the stream.
|
|
||||||
func (s *Stream) Method() string {
|
|
||||||
return s.method
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status returns the status received from the server.
|
|
||||||
// Status can be read safely only after the stream has ended,
|
|
||||||
// that is, read or write has returned io.EOF.
|
|
||||||
func (s *Stream) Status() *status.Status {
|
|
||||||
return s.status
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHeader sets the header metadata. This can be called multiple times.
|
|
||||||
// Server side only.
|
|
||||||
// This should not be called in parallel to other data writes.
|
|
||||||
func (s *Stream) SetHeader(md metadata.MD) error {
|
|
||||||
if md.Len() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) {
|
|
||||||
return ErrIllegalHeaderWrite
|
|
||||||
}
|
|
||||||
s.header = metadata.Join(s.header, md)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendHeader sends the given header metadata. The given metadata is
|
|
||||||
// combined with any metadata set by previous calls to SetHeader and
|
|
||||||
// then written to the transport stream.
|
|
||||||
func (s *Stream) SendHeader(md metadata.MD) error {
|
|
||||||
t := s.ServerTransport()
|
|
||||||
return t.WriteHeader(s, md)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
|
||||||
// by the server. This can be called multiple times. Server side only.
|
|
||||||
// This should not be called parallel to other data writes.
|
|
||||||
func (s *Stream) SetTrailer(md metadata.MD) error {
|
|
||||||
if md.Len() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s.trailer = metadata.Join(s.trailer, md)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stream) write(m recvMsg) {
|
|
||||||
s.buf.put(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads all p bytes from the wire for this stream.
|
|
||||||
func (s *Stream) Read(p []byte) (n int, err error) {
|
|
||||||
// Don't request a read if there was an error earlier
|
|
||||||
if er := s.trReader.(*transportReader).er; er != nil {
|
|
||||||
return 0, er
|
|
||||||
}
|
|
||||||
s.requestRead(len(p))
|
|
||||||
return io.ReadFull(s.trReader, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// tranportReader reads all the data available for this Stream from the transport and
|
|
||||||
// passes them into the decoder, which converts them into a gRPC message stream.
|
|
||||||
// The error is io.EOF when the stream is done or another non-nil error if
|
|
||||||
// the stream broke.
|
|
||||||
type transportReader struct {
|
|
||||||
reader io.Reader
|
|
||||||
// The handler to control the window update procedure for both this
|
|
||||||
// particular stream and the associated transport.
|
|
||||||
windowHandler func(int)
|
|
||||||
er error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transportReader) Read(p []byte) (n int, err error) {
|
|
||||||
n, err = t.reader.Read(p)
|
|
||||||
if err != nil {
|
|
||||||
t.er = err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.windowHandler(n)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// BytesReceived indicates whether any bytes have been received on this stream.
|
|
||||||
func (s *Stream) BytesReceived() bool {
|
|
||||||
return atomic.LoadUint32(&s.bytesReceived) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unprocessed indicates whether the server did not process this stream --
|
|
||||||
// i.e. it sent a refused stream or GOAWAY including this stream ID.
|
|
||||||
func (s *Stream) Unprocessed() bool {
|
|
||||||
return atomic.LoadUint32(&s.unprocessed) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// GoString is implemented by Stream so context.String() won't
|
|
||||||
// race when printing %#v.
|
|
||||||
func (s *Stream) GoString() string {
|
|
||||||
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
|
||||||
}
|
|
||||||
|
|
||||||
// state of transport
|
// state of transport
|
||||||
type transportState int
|
type transportState int
|
||||||
|
|
||||||
@ -476,7 +120,13 @@ type Options struct {
|
|||||||
// Delay is a hint to the transport implementation for whether
|
// Delay is a hint to the transport implementation for whether
|
||||||
// the data could be buffered for a batching write. The
|
// the data could be buffered for a batching write. The
|
||||||
// transport implementation may ignore the hint.
|
// transport implementation may ignore the hint.
|
||||||
|
// TODO(mmukhi, dfawley): Should this be deleted?
|
||||||
Delay bool
|
Delay bool
|
||||||
|
|
||||||
|
// IsCompressed indicates weather the message being written
|
||||||
|
// was compressed or not. Transport relays this information
|
||||||
|
// to the API that generates gRPC-specific message header.
|
||||||
|
IsCompressed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallHdr carries the information of a particular RPC.
|
// CallHdr carries the information of a particular RPC.
|
||||||
@ -525,7 +175,7 @@ type ClientTransport interface {
|
|||||||
|
|
||||||
// Write sends the data for the given stream. A nil stream indicates
|
// Write sends the data for the given stream. A nil stream indicates
|
||||||
// the write is to be performed on the transport as a whole.
|
// the write is to be performed on the transport as a whole.
|
||||||
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
|
Write(s *Stream, data []byte, opts *Options) error
|
||||||
|
|
||||||
// NewStream creates a Stream for an RPC.
|
// NewStream creates a Stream for an RPC.
|
||||||
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
|
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
|
||||||
@ -573,7 +223,7 @@ type ServerTransport interface {
|
|||||||
|
|
||||||
// Write sends the data for the given stream.
|
// Write sends the data for the given stream.
|
||||||
// Write may not be called on all streams.
|
// Write may not be called on all streams.
|
||||||
Write(s *Stream, hdr []byte, data []byte, opts *Options) error
|
Write(s *Stream, data []byte, opts *Options) error
|
||||||
|
|
||||||
// WriteStatus sends the status of a stream to the client. WriteStatus is
|
// WriteStatus sends the status of a stream to the client. WriteStatus is
|
||||||
// the final call made on a stream and always occurs.
|
// the final call made on a stream and always occurs.
|
||||||
|
@ -100,8 +100,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
|||||||
req = expectedRequestLarge
|
req = expectedRequestLarge
|
||||||
resp = expectedResponseLarge
|
resp = expectedResponseLarge
|
||||||
}
|
}
|
||||||
p := make([]byte, len(req))
|
_, p, err := s.Read(math.MaxInt32)
|
||||||
_, err := s.Read(p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -109,31 +108,26 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
|
|||||||
t.Fatalf("handleStream got %v, want %v", p, req)
|
t.Fatalf("handleStream got %v, want %v", p, req)
|
||||||
}
|
}
|
||||||
// send a response back to the client.
|
// send a response back to the client.
|
||||||
h.t.Write(s, nil, resp, &Options{})
|
h.t.Write(s, resp, &Options{})
|
||||||
// send the trailer to end the stream.
|
// send the trailer to end the stream.
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
|
||||||
header := make([]byte, 5)
|
|
||||||
for {
|
for {
|
||||||
if _, err := s.Read(header); err != nil {
|
_, msg, err := s.Read(math.MaxInt32)
|
||||||
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Fatalf("Error on server while reading data header: %v", err)
|
t.Errorf("Error on server while reading data header: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
sz := binary.BigEndian.Uint32(header[1:])
|
if err := h.t.Write(s, msg, &Options{}); err != nil {
|
||||||
msg := make([]byte, int(sz))
|
t.Errorf("Error on server while writing: %v", err)
|
||||||
if _, err := s.Read(msg); err != nil {
|
return
|
||||||
t.Fatalf("Error on server while reading message: %v", err)
|
|
||||||
}
|
}
|
||||||
buf := make([]byte, sz+5)
|
|
||||||
buf[0] = byte(0)
|
|
||||||
binary.BigEndian.PutUint32(buf[1:], uint32(sz))
|
|
||||||
copy(buf[5:], msg)
|
|
||||||
h.t.Write(s, nil, buf, &Options{})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,12 +183,10 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
|
|||||||
req = expectedRequestLarge
|
req = expectedRequestLarge
|
||||||
resp = expectedResponseLarge
|
resp = expectedResponseLarge
|
||||||
}
|
}
|
||||||
p := make([]byte, len(req))
|
|
||||||
|
|
||||||
// Wait before reading. Give time to client to start sending
|
// Wait before reading. Give time to client to start sending
|
||||||
// before server starts reading.
|
// before server starts reading.
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
_, err := s.Read(p)
|
_, p, err := s.Read(math.MaxInt32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
||||||
return
|
return
|
||||||
@ -205,7 +197,7 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// send a response back to the client.
|
// send a response back to the client.
|
||||||
if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
|
if err := h.t.Write(s, resp, &Options{}); err != nil {
|
||||||
t.Errorf("server Write got %v, want <nil>", err)
|
t.Errorf("server Write got %v, want <nil>", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -223,8 +215,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
|
|||||||
req = expectedRequestLarge
|
req = expectedRequestLarge
|
||||||
resp = expectedResponseLarge
|
resp = expectedResponseLarge
|
||||||
}
|
}
|
||||||
p := make([]byte, len(req))
|
_, p, err := s.Read(math.MaxInt32)
|
||||||
_, err := s.Read(p)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
||||||
return
|
return
|
||||||
@ -237,7 +228,7 @@ func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
|
|||||||
// Wait before sending. Give time to client to start reading
|
// Wait before sending. Give time to client to start reading
|
||||||
// before server starts sending.
|
// before server starts sending.
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
if err := h.t.Write(s, nil, resp, &Options{}); err != nil {
|
if err := h.t.Write(s, resp, &Options{}); err != nil {
|
||||||
t.Errorf("server Write got %v, want <nil>", err)
|
t.Errorf("server Write got %v, want <nil>", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -442,7 +433,7 @@ func TestInflightStreamClosing(t *testing.T) {
|
|||||||
serr := StreamError{Desc: "client connection is closing"}
|
serr := StreamError{Desc: "client connection is closing"}
|
||||||
go func() {
|
go func() {
|
||||||
defer close(donec)
|
defer close(donec)
|
||||||
if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr {
|
if _, _, err := stream.Read(math.MaxInt32); err != serr {
|
||||||
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
|
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -858,15 +849,14 @@ func TestClientSendAndReceive(t *testing.T) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF {
|
if err := ct.Write(s1, expectedRequest, &opts); err != nil && err != io.EOF {
|
||||||
t.Fatalf("failed to send data: %v", err)
|
t.Fatalf("failed to send data: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponse))
|
_, p, recvErr := s1.Read(math.MaxInt32)
|
||||||
_, recvErr := s1.Read(p)
|
|
||||||
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
|
if recvErr != nil || !bytes.Equal(p, expectedResponse) {
|
||||||
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
|
t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
|
||||||
}
|
}
|
||||||
_, recvErr = s1.Read(p)
|
_, _, recvErr = s1.Read(math.MaxInt32)
|
||||||
if recvErr != io.EOF {
|
if recvErr != io.EOF {
|
||||||
t.Fatalf("Error: %v; want <EOF>", recvErr)
|
t.Fatalf("Error: %v; want <EOF>", recvErr)
|
||||||
}
|
}
|
||||||
@ -895,16 +885,15 @@ func performOneRPC(ct ClientTransport) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF {
|
if err := ct.Write(s, expectedRequest, &opts); err == nil || err == io.EOF {
|
||||||
time.Sleep(5 * time.Millisecond)
|
time.Sleep(5 * time.Millisecond)
|
||||||
// The following s.Recv()'s could error out because the
|
// The following s.Recv()'s could error out because the
|
||||||
// underlying transport is gone.
|
// underlying transport is gone.
|
||||||
//
|
//
|
||||||
// Read response
|
// Read response
|
||||||
p := make([]byte, len(expectedResponse))
|
s.Read(math.MaxInt32)
|
||||||
s.Read(p)
|
|
||||||
// Read io.EOF
|
// Read io.EOF
|
||||||
s.Read(p)
|
s.Read(math.MaxInt32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -939,14 +928,13 @@ func TestLargeMessage(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
|
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
|
||||||
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
||||||
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
t.Errorf("s.Read(math.MaxInt32) = %v, %v, want %v, <nil>", p, err, expectedResponse)
|
||||||
t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
|
|
||||||
}
|
}
|
||||||
if _, err = s.Read(p); err != io.EOF {
|
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -974,19 +962,18 @@ func TestLargeMessageWithDelayRead(t *testing.T) {
|
|||||||
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
|
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
|
||||||
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
|
||||||
|
|
||||||
// Give time to server to begin sending before client starts reading.
|
// Give time to server to begin sending before client starts reading.
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
||||||
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err = s.Read(p); err != io.EOF {
|
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -1017,16 +1004,15 @@ func TestLargeMessageDelayWrite(t *testing.T) {
|
|||||||
|
|
||||||
// Give time to server to start reading before client starts sending.
|
// Give time to server to start reading before client starts sending.
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
|
if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
|
||||||
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p := make([]byte, len(expectedResponseLarge))
|
if _, p, err := s.Read(math.MaxInt32); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
||||||
if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
|
|
||||||
t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
|
t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err = s.Read(p); err != io.EOF {
|
if _, _, err = s.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
t.Errorf("Failed to complete the stream %v; want <EOF>", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -1047,19 +1033,10 @@ func TestGracefulClose(t *testing.T) {
|
|||||||
t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
|
t.Fatalf("NewStream(_, _) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
msg := make([]byte, 1024)
|
msg := make([]byte, 1024)
|
||||||
outgoingHeader := make([]byte, 5)
|
if err := ct.Write(s, msg, &Options{}); err != nil {
|
||||||
outgoingHeader[0] = byte(0)
|
|
||||||
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(len(msg)))
|
|
||||||
incomingHeader := make([]byte, 5)
|
|
||||||
if err := ct.Write(s, outgoingHeader, msg, &Options{}); err != nil {
|
|
||||||
t.Fatalf("Error while writing: %v", err)
|
t.Fatalf("Error while writing: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := s.Read(incomingHeader); err != nil {
|
if _, _, err := s.Read(math.MaxInt32); err != nil {
|
||||||
t.Fatalf("Error while reading: %v", err)
|
|
||||||
}
|
|
||||||
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
|
||||||
recvMsg := make([]byte, int(sz))
|
|
||||||
if _, err := s.Read(recvMsg); err != nil {
|
|
||||||
t.Fatalf("Error while reading: %v", err)
|
t.Fatalf("Error while reading: %v", err)
|
||||||
}
|
}
|
||||||
if err = ct.GracefulClose(); err != nil {
|
if err = ct.GracefulClose(); err != nil {
|
||||||
@ -1075,14 +1052,14 @@ func TestGracefulClose(t *testing.T) {
|
|||||||
if err == errStreamDrain {
|
if err == errStreamDrain {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ct.Write(str, nil, nil, &Options{Last: true})
|
ct.Write(str, nil, &Options{Last: true})
|
||||||
if _, err := str.Read(make([]byte, 8)); err != errStreamDrain {
|
if _, _, err := str.Read(math.MaxInt32); err != errStreamDrain {
|
||||||
t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain)
|
t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, errStreamDrain)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
ct.Write(s, nil, nil, &Options{Last: true})
|
ct.Write(s, nil, &Options{Last: true})
|
||||||
if _, err := s.Read(incomingHeader); err != io.EOF {
|
if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
// The stream which was created before graceful close can still proceed.
|
// The stream which was created before graceful close can still proceed.
|
||||||
@ -1110,13 +1087,13 @@ func TestLargeMessageSuspension(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
// Write should not be done successfully due to flow control.
|
// Write should not be done successfully due to flow control.
|
||||||
msg := make([]byte, initialWindowSize*8)
|
msg := make([]byte, initialWindowSize*8)
|
||||||
ct.Write(s, nil, msg, &Options{})
|
ct.Write(s, msg, &Options{})
|
||||||
err = ct.Write(s, nil, msg, &Options{Last: true})
|
err = ct.Write(s, msg, &Options{Last: true})
|
||||||
if err != errStreamDone {
|
if err != errStreamDone {
|
||||||
t.Fatalf("Write got %v, want io.EOF", err)
|
t.Fatalf("Write got %v, want io.EOF", err)
|
||||||
}
|
}
|
||||||
expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
|
expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
|
||||||
if _, err := s.Read(make([]byte, 8)); err != expectedErr {
|
if _, _, err := s.Read(math.MaxInt32); err != expectedErr {
|
||||||
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
|
t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr)
|
||||||
}
|
}
|
||||||
ct.Close()
|
ct.Close()
|
||||||
@ -1305,7 +1282,7 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
|
|||||||
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
|
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
|
||||||
}
|
}
|
||||||
// Exhaust client's connection window.
|
// Exhaust client's connection window.
|
||||||
if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := st.Write(sstream1, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
||||||
t.Fatalf("Server failed to write data. Err: %v", err)
|
t.Fatalf("Server failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
notifyChan = make(chan struct{})
|
notifyChan = make(chan struct{})
|
||||||
@ -1330,17 +1307,17 @@ func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
|
|||||||
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
|
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
|
||||||
}
|
}
|
||||||
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
|
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
|
||||||
if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := st.Write(sstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
||||||
t.Fatalf("Server failed to write data. Err: %v", err)
|
t.Fatalf("Server failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client should be able to read data on second stream.
|
// Client should be able to read data on second stream.
|
||||||
if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, _, err := cstream2.Read(math.MaxInt32); err != nil {
|
||||||
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client should be able to read data on first stream.
|
// Client should be able to read data on first stream.
|
||||||
if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, _, err := cstream1.Read(math.MaxInt32); err != nil {
|
||||||
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1373,7 +1350,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create 1st stream. Err: %v", err)
|
t.Fatalf("Failed to create 1st stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
// Exhaust server's connection window.
|
// Exhaust server's connection window.
|
||||||
if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
|
if err := client.Write(cstream1, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
|
||||||
t.Fatalf("Client failed to write data. Err: %v", err)
|
t.Fatalf("Client failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
//Client should be able to create another stream and send data on it.
|
//Client should be able to create another stream and send data on it.
|
||||||
@ -1381,7 +1358,7 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
|
t.Fatalf("Failed to create 2nd stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
if err := client.Write(cstream2, make([]byte, defaultWindowSize), &Options{}); err != nil {
|
||||||
t.Fatalf("Client failed to write data. Err: %v", err)
|
t.Fatalf("Client failed to write data. Err: %v", err)
|
||||||
}
|
}
|
||||||
// Get the streams on server.
|
// Get the streams on server.
|
||||||
@ -1403,11 +1380,11 @@ func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
|
|||||||
}
|
}
|
||||||
st.mu.Unlock()
|
st.mu.Unlock()
|
||||||
// Reading from the stream on server should succeed.
|
// Reading from the stream on server should succeed.
|
||||||
if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
|
if _, _, err := sstream1.Read(math.MaxInt32); err != nil {
|
||||||
t.Fatalf("_.Read(_) = %v, want <nil>", err)
|
t.Fatalf("_.Read(_) = %v, want <nil>", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
|
if _, _, err := sstream1.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Fatalf("_.Read(_) = %v, want io.EOF", err)
|
t.Fatalf("_.Read(_) = %v, want io.EOF", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1616,11 +1593,10 @@ func TestEncodingRequiredStatus(t *testing.T) {
|
|||||||
Last: true,
|
Last: true,
|
||||||
Delay: false,
|
Delay: false,
|
||||||
}
|
}
|
||||||
if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != errStreamDone {
|
if err := ct.Write(s, expectedRequest, &opts); err != nil && err != errStreamDone {
|
||||||
t.Fatalf("Failed to write the request: %v", err)
|
t.Fatalf("Failed to write the request: %v", err)
|
||||||
}
|
}
|
||||||
p := make([]byte, http2MaxFrameLen)
|
if _, _, err := s.Read(math.MaxInt32); err != io.EOF {
|
||||||
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
|
|
||||||
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
|
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
|
||||||
@ -1640,8 +1616,7 @@ func TestInvalidHeaderField(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p := make([]byte, http2MaxFrameLen)
|
_, _, err = s.Read(math.MaxInt32)
|
||||||
_, err = s.trReader.(*transportReader).Read(p)
|
|
||||||
if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
if se, ok := err.(StreamError); !ok || se.Code != codes.Internal || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
|
||||||
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
|
t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.Internal, expectedInvalidHeaderField)
|
||||||
}
|
}
|
||||||
@ -1764,26 +1739,17 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
|
|||||||
t.Fatalf("Failed to create stream. Err: %v", err)
|
t.Fatalf("Failed to create stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
msg := make([]byte, msgSize)
|
msg := make([]byte, msgSize)
|
||||||
buf := make([]byte, msgSize+5)
|
|
||||||
buf[0] = byte(0)
|
|
||||||
binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
|
|
||||||
copy(buf[5:], msg)
|
|
||||||
opts := Options{}
|
opts := Options{}
|
||||||
header := make([]byte, 5)
|
|
||||||
for i := 1; i <= 10; i++ {
|
for i := 1; i <= 10; i++ {
|
||||||
if err := ct.Write(cstream, nil, buf, &opts); err != nil {
|
if err := ct.Write(cstream, msg, &opts); err != nil {
|
||||||
t.Fatalf("Error on client while writing message: %v", err)
|
t.Fatalf("Error on client while writing message: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := cstream.Read(header); err != nil {
|
_, recvMsg, err := cstream.Read(math.MaxInt32)
|
||||||
t.Fatalf("Error on client while reading data frame header: %v", err)
|
if err != nil {
|
||||||
}
|
|
||||||
sz := binary.BigEndian.Uint32(header[1:])
|
|
||||||
recvMsg := make([]byte, int(sz))
|
|
||||||
if _, err := cstream.Read(recvMsg); err != nil {
|
|
||||||
t.Fatalf("Error on client while reading data: %v", err)
|
t.Fatalf("Error on client while reading data: %v", err)
|
||||||
}
|
}
|
||||||
if len(recvMsg) != len(msg) {
|
if !bytes.Equal(recvMsg, msg) {
|
||||||
t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
|
t.Fatalf("Message received by client(len: %d) not equal to what was expected(len: %d)", len(recvMsg), len(msg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var sstream *Stream
|
var sstream *Stream
|
||||||
@ -1794,8 +1760,8 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
|
|||||||
st.mu.Unlock()
|
st.mu.Unlock()
|
||||||
loopyServerStream := st.loopy.estdStreams[sstream.id]
|
loopyServerStream := st.loopy.estdStreams[sstream.id]
|
||||||
loopyClientStream := ct.loopy.estdStreams[cstream.id]
|
loopyClientStream := ct.loopy.estdStreams[cstream.id]
|
||||||
ct.Write(cstream, nil, nil, &Options{Last: true}) // Close the stream.
|
ct.Write(cstream, nil, &Options{Last: true}) // Close the stream.
|
||||||
if _, err := cstream.Read(header); err != io.EOF {
|
if _, _, err := cstream.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected an EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
// Sleep for a little to make sure both sides flush out their buffers.
|
// Sleep for a little to make sure both sides flush out their buffers.
|
||||||
@ -1816,11 +1782,11 @@ func testFlowControlAccountCheck(t *testing.T, msgSize int, wc windowSizeConfig)
|
|||||||
t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota)
|
t.Fatalf("Account mismatch: server transport inflow(%d) != server unacked(%d) + client sendQuota(%d)", st.fc.limit, st.fc.unacked, ct.loopy.sendQuota)
|
||||||
}
|
}
|
||||||
// Check stream flow control.
|
// Check stream flow control.
|
||||||
if int(cstream.fc.limit+cstream.fc.delta-cstream.fc.pendingData-cstream.fc.pendingUpdate) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
|
if int(cstream.fc.limit)-int(cstream.fc.rcvd) != int(st.loopy.oiws)-loopyServerStream.bytesOutStanding {
|
||||||
t.Fatalf("Account mismatch: client stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.delta, cstream.fc.pendingData, cstream.fc.pendingUpdate, st.loopy.oiws, loopyServerStream.bytesOutStanding)
|
t.Fatalf("Account mismatch: client stream inflow limit(%d) - rcvd(%d) != server outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", cstream.fc.limit, cstream.fc.rcvd, st.loopy.oiws, loopyServerStream.bytesOutStanding)
|
||||||
}
|
}
|
||||||
if int(sstream.fc.limit+sstream.fc.delta-sstream.fc.pendingData-sstream.fc.pendingUpdate) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding {
|
if int(sstream.fc.limit)-int(sstream.fc.rcvd) != int(ct.loopy.oiws)-loopyClientStream.bytesOutStanding {
|
||||||
t.Fatalf("Account mismatch: server stream inflow limit(%d) + delta(%d) - pendingData(%d) - pendingUpdate(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.delta, sstream.fc.pendingData, sstream.fc.pendingUpdate, ct.loopy.oiws, loopyClientStream.bytesOutStanding)
|
t.Fatalf("Account mismatch: server stream inflow limit(%d) - rcvd(%d) != client outgoing InitialWindowSize(%d) - outgoingStream.bytesOutStanding(%d)", sstream.fc.limit, sstream.fc.rcvd, ct.loopy.oiws, loopyClientStream.bytesOutStanding)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2000,8 +1966,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
|
|||||||
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
|
stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
|
||||||
defer cleanUp()
|
defer cleanUp()
|
||||||
want := httpStatusConvTab[httpStatus]
|
want := httpStatusConvTab[httpStatus]
|
||||||
buf := make([]byte, 8)
|
_, _, err := stream.Read(math.MaxInt32)
|
||||||
_, err := stream.Read(buf)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
|
t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want)
|
||||||
}
|
}
|
||||||
@ -2017,8 +1982,7 @@ func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders)
|
|||||||
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
|
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
|
||||||
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
|
stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
|
||||||
defer cleanUp()
|
defer cleanUp()
|
||||||
buf := make([]byte, 8)
|
_, _, err := stream.Read(math.MaxInt32)
|
||||||
_, err := stream.Read(buf)
|
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
|
t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
|
||||||
}
|
}
|
||||||
@ -2035,45 +1999,25 @@ func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
|
|||||||
// If any error occurs on a call to Stream.Read, future calls
|
// If any error occurs on a call to Stream.Read, future calls
|
||||||
// should continue to return that same error.
|
// should continue to return that same error.
|
||||||
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||||
testRecvBuffer := newRecvBuffer()
|
s := newStream(context.Background())
|
||||||
s := &Stream{
|
|
||||||
ctx: context.Background(),
|
|
||||||
buf: testRecvBuffer,
|
|
||||||
requestRead: func(int) {},
|
|
||||||
}
|
|
||||||
s.trReader = &transportReader{
|
|
||||||
reader: &recvBufferReader{
|
|
||||||
ctx: s.ctx,
|
|
||||||
ctxDone: s.ctx.Done(),
|
|
||||||
recv: s.buf,
|
|
||||||
},
|
|
||||||
windowHandler: func(int) {},
|
|
||||||
}
|
|
||||||
testData := make([]byte, 1)
|
|
||||||
testData[0] = 5
|
|
||||||
testErr := errors.New("test error")
|
testErr := errors.New("test error")
|
||||||
s.write(recvMsg{data: testData, err: testErr})
|
s.notifyErr(testErr)
|
||||||
|
|
||||||
inBuf := make([]byte, 1)
|
pf, inBuf, actualErr := s.Read(math.MaxInt32)
|
||||||
actualCount, actualErr := s.Read(inBuf)
|
if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() {
|
||||||
if actualCount != 0 {
|
t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
|
||||||
t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
|
|
||||||
}
|
|
||||||
if actualErr.Error() != testErr.Error() {
|
|
||||||
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.write(recvMsg{data: testData, err: nil})
|
testData := make([]byte, 6)
|
||||||
s.write(recvMsg{data: testData, err: errors.New("different error from first")})
|
testData[0] = byte(1)
|
||||||
|
binary.BigEndian.PutUint32(testData[1:], uint32(1))
|
||||||
|
s.consume(testData, 0)
|
||||||
|
s.notifyErr(errors.New("different error from first"))
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
inBuf := make([]byte, 1)
|
pf, inBuf, actualErr := s.Read(math.MaxInt32)
|
||||||
actualCount, actualErr := s.Read(inBuf)
|
if pf != false || inBuf != nil || actualErr.Error() != testErr.Error() {
|
||||||
if actualCount != 0 {
|
t.Errorf("%v, %v, %v := s.Read(_) differs; want false, <nil>, %v", pf, inBuf, actualErr, testErr)
|
||||||
t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
|
|
||||||
}
|
|
||||||
if actualErr.Error() != testErr.Error() {
|
|
||||||
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2113,11 +2057,7 @@ func runPingPongTest(t *testing.T, msgSize int) {
|
|||||||
t.Fatalf("Failed to create stream. Err: %v", err)
|
t.Fatalf("Failed to create stream. Err: %v", err)
|
||||||
}
|
}
|
||||||
msg := make([]byte, msgSize)
|
msg := make([]byte, msgSize)
|
||||||
outgoingHeader := make([]byte, 5)
|
|
||||||
outgoingHeader[0] = byte(0)
|
|
||||||
binary.BigEndian.PutUint32(outgoingHeader[1:], uint32(msgSize))
|
|
||||||
opts := &Options{}
|
opts := &Options{}
|
||||||
incomingHeader := make([]byte, 5)
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
timer := time.NewTimer(time.Second * 5)
|
timer := time.NewTimer(time.Second * 5)
|
||||||
@ -2127,23 +2067,22 @@ func runPingPongTest(t *testing.T, msgSize int) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
ct.Write(stream, nil, nil, &Options{Last: true})
|
ct.Write(stream, nil, &Options{Last: true})
|
||||||
if _, err := stream.Read(incomingHeader); err != io.EOF {
|
if _, _, err := stream.Read(math.MaxInt32); err != io.EOF {
|
||||||
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
t.Fatalf("Client expected EOF from the server. Got: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
if err := ct.Write(stream, outgoingHeader, msg, opts); err != nil {
|
if err := ct.Write(stream, msg, opts); err != nil {
|
||||||
t.Fatalf("Error on client while writing message. Err: %v", err)
|
t.Fatalf("Error on client while writing message. Err: %v", err)
|
||||||
}
|
}
|
||||||
if _, err := stream.Read(incomingHeader); err != nil {
|
_, recvMsg, err := stream.Read(math.MaxInt32)
|
||||||
t.Fatalf("Error on client while reading data header. Err: %v", err)
|
if err != nil {
|
||||||
}
|
|
||||||
sz := binary.BigEndian.Uint32(incomingHeader[1:])
|
|
||||||
recvMsg := make([]byte, int(sz))
|
|
||||||
if _, err := stream.Read(recvMsg); err != nil {
|
|
||||||
t.Fatalf("Error on client while reading data. Err: %v", err)
|
t.Fatalf("Error on client while reading data. Err: %v", err)
|
||||||
}
|
}
|
||||||
|
if !bytes.Equal(recvMsg, msg) {
|
||||||
|
t.Fatalf("%v != %v", recvMsg, msg)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user