Client should have a check on maximum size of received message size.
This commit is contained in:
3
call.go
3
call.go
@ -36,7 +36,6 @@ package grpc
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
@ -73,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
|
||||
}
|
||||
}
|
||||
for {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
|
||||
if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
@ -36,6 +36,7 @@ package grpc
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -87,23 +88,33 @@ var (
|
||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||
// values passed to Dial.
|
||||
type dialOptions struct {
|
||||
unaryInt UnaryClientInterceptor
|
||||
streamInt StreamClientInterceptor
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoffStrategy
|
||||
balancer Balancer
|
||||
block bool
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
scChan <-chan ServiceConfig
|
||||
copts transport.ConnectOptions
|
||||
unaryInt UnaryClientInterceptor
|
||||
streamInt StreamClientInterceptor
|
||||
codec Codec
|
||||
cp Compressor
|
||||
dc Decompressor
|
||||
bs backoffStrategy
|
||||
balancer Balancer
|
||||
block bool
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
scChan <-chan ServiceConfig
|
||||
copts transport.ConnectOptions
|
||||
maxMsgSize int
|
||||
}
|
||||
|
||||
const defaultClientMaxMsgSize = math.MaxInt32
|
||||
|
||||
// DialOption configures how we set up the connection.
|
||||
type DialOption func(*dialOptions)
|
||||
|
||||
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
|
||||
func WithMaxMsgSize(s int) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.maxMsgSize = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
|
||||
func WithCodec(c Codec) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
@ -304,6 +315,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||
ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
|
||||
defer cancel()
|
||||
}
|
||||
if cc.dopts.maxMsgSize == 0 {
|
||||
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
|
||||
}
|
||||
|
||||
defer func() {
|
||||
select {
|
||||
|
||||
43
stream.go
43
stream.go
@ -37,7 +37,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -208,13 +207,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
||||
break
|
||||
}
|
||||
cs := &clientStream{
|
||||
opts: opts,
|
||||
c: c,
|
||||
desc: desc,
|
||||
codec: cc.dopts.codec,
|
||||
cp: cc.dopts.cp,
|
||||
dc: cc.dopts.dc,
|
||||
cancel: cancel,
|
||||
opts: opts,
|
||||
c: c,
|
||||
desc: desc,
|
||||
codec: cc.dopts.codec,
|
||||
cp: cc.dopts.cp,
|
||||
dc: cc.dopts.dc,
|
||||
maxMsgSize: cc.dopts.maxMsgSize,
|
||||
cancel: cancel,
|
||||
|
||||
put: put,
|
||||
t: t,
|
||||
@ -259,17 +259,18 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
||||
|
||||
// clientStream implements a client side Stream.
|
||||
type clientStream struct {
|
||||
opts []CallOption
|
||||
c callInfo
|
||||
t transport.ClientTransport
|
||||
s *transport.Stream
|
||||
p *parser
|
||||
desc *StreamDesc
|
||||
codec Codec
|
||||
cp Compressor
|
||||
cbuf *bytes.Buffer
|
||||
dc Decompressor
|
||||
cancel context.CancelFunc
|
||||
opts []CallOption
|
||||
c callInfo
|
||||
t transport.ClientTransport
|
||||
s *transport.Stream
|
||||
p *parser
|
||||
desc *StreamDesc
|
||||
codec Codec
|
||||
cp Compressor
|
||||
cbuf *bytes.Buffer
|
||||
dc Decompressor
|
||||
maxMsgSize int
|
||||
cancel context.CancelFunc
|
||||
|
||||
tracing bool // set to EnableTracing when the clientStream is created.
|
||||
|
||||
@ -382,7 +383,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||
Client: true,
|
||||
}
|
||||
}
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload)
|
||||
defer func() {
|
||||
// err != nil indicates the termination of the stream.
|
||||
if err != nil {
|
||||
@ -405,7 +406,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||
}
|
||||
// Special handling for client streaming rpc.
|
||||
// This recv expects EOF or errors, so we don't collect inPayload.
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
|
||||
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil)
|
||||
cs.closeTransportStream(err)
|
||||
if err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
|
||||
@ -570,6 +570,9 @@ func (te *test) clientConn() *grpc.ClientConn {
|
||||
if te.streamClientInt != nil {
|
||||
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
|
||||
}
|
||||
if te.maxMsgSize > 0 {
|
||||
opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
|
||||
}
|
||||
switch te.e.security {
|
||||
case "tls":
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||
@ -1427,22 +1430,33 @@ func testExceedMsgLimit(t *testing.T, e env) {
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
argSize := int32(te.maxMsgSize + 1)
|
||||
const respSize = 1
|
||||
const smallSize = 1
|
||||
|
||||
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test on server side for unary RPC
|
||||
req := &testpb.SimpleRequest{
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseSize: proto.Int32(respSize),
|
||||
ResponseSize: proto.Int32(smallSize),
|
||||
Payload: payload,
|
||||
}
|
||||
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
|
||||
}
|
||||
// test on client side for unary RPC
|
||||
req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
|
||||
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
|
||||
}
|
||||
|
||||
// test on server side for streaming RPC
|
||||
stream, err := tc.FullDuplexCall(te.ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -1469,6 +1483,21 @@ func testExceedMsgLimit(t *testing.T, e env) {
|
||||
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
|
||||
}
|
||||
|
||||
// test on client side for streaming RPC
|
||||
stream, err = tc.FullDuplexCall(te.ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
}
|
||||
respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1)
|
||||
sreq.Payload = smallPayload
|
||||
if err := stream.Send(sreq); err != nil {
|
||||
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
|
||||
}
|
||||
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPeerClientSide(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user