split encode into three functions (#2058)
This commit is contained in:
@ -105,12 +105,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// send a response back to end the stream.
|
// send a response back to end the stream.
|
||||||
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
|
data, err := encode(testCodec{}, &expectedResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to encode the response: %v", err)
|
t.Errorf("Failed to encode the response: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.t.Write(s, hdr, data, &transport.Options{})
|
hdr, payload := msgHeader(data, nil)
|
||||||
|
h.t.Write(s, hdr, payload, &transport.Options{})
|
||||||
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
h.t.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
119
rpc_util.go
119
rpc_util.go
@ -419,8 +419,8 @@ func (o CustomCodecCallOption) after(c *callInfo) {}
|
|||||||
type payloadFormat uint8
|
type payloadFormat uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
compressionNone payloadFormat = iota // no compression
|
compressionNone payloadFormat = 0 // no compression
|
||||||
compressionMade
|
compressionMade payloadFormat = 1 // compressed
|
||||||
)
|
)
|
||||||
|
|
||||||
// parser reads complete gRPC messages from the underlying reader.
|
// parser reads complete gRPC messages from the underlying reader.
|
||||||
@ -477,65 +477,82 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
|
|||||||
return pf, msg, nil
|
return pf, msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// encode serializes msg and returns a buffer of message header and a buffer of msg.
|
// encode serializes msg and returns a buffer containing the message, or an
|
||||||
// If msg is nil, it generates the message header and an empty msg buffer.
|
// error if it is too large to be transmitted by grpc. If msg is nil, it
|
||||||
// TODO(ddyihai): eliminate extra Compressor parameter.
|
// generates an empty message.
|
||||||
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) {
|
func encode(c baseCodec, msg interface{}) ([]byte, error) {
|
||||||
var (
|
if msg == nil { // NOTE: typed nils will not be caught by this check
|
||||||
b []byte
|
return nil, nil
|
||||||
cbuf *bytes.Buffer
|
}
|
||||||
)
|
b, err := c.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
||||||
|
}
|
||||||
|
if uint(len(b)) > math.MaxUint32 {
|
||||||
|
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compress returns the input bytes compressed by compressor or cp. If both
|
||||||
|
// compressors are nil, returns nil.
|
||||||
|
//
|
||||||
|
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
|
||||||
|
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
|
||||||
|
if compressor == nil && cp == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
wrapErr := func(err error) error {
|
||||||
|
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
||||||
|
}
|
||||||
|
cbuf := &bytes.Buffer{}
|
||||||
|
if compressor != nil {
|
||||||
|
z, _ := compressor.Compress(cbuf)
|
||||||
|
if _, err := z.Write(in); err != nil {
|
||||||
|
return nil, wrapErr(err)
|
||||||
|
}
|
||||||
|
if err := z.Close(); err != nil {
|
||||||
|
return nil, wrapErr(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := cp.Do(cbuf, in); err != nil {
|
||||||
|
return nil, wrapErr(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cbuf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
payloadLen = 1
|
payloadLen = 1
|
||||||
sizeLen = 4
|
sizeLen = 4
|
||||||
|
headerLen = payloadLen + sizeLen
|
||||||
)
|
)
|
||||||
if msg != nil {
|
|
||||||
var err error
|
// msgHeader returns a 5-byte header for the message being transmitted and the
|
||||||
b, err = c.Marshal(msg)
|
// payload, which is compData if non-nil or data otherwise.
|
||||||
if err != nil {
|
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
|
hdr = make([]byte, headerLen)
|
||||||
}
|
if compData != nil {
|
||||||
if outPayload != nil {
|
hdr[0] = byte(compressionMade)
|
||||||
outPayload.Payload = msg
|
data = compData
|
||||||
// TODO truncate large payload.
|
|
||||||
outPayload.Data = b
|
|
||||||
outPayload.Length = len(b)
|
|
||||||
}
|
|
||||||
if compressor != nil || cp != nil {
|
|
||||||
cbuf = new(bytes.Buffer)
|
|
||||||
// Has compressor, check Compressor is set by UseCompressor first.
|
|
||||||
if compressor != nil {
|
|
||||||
z, _ := compressor.Compress(cbuf)
|
|
||||||
if _, err := z.Write(b); err != nil {
|
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
|
||||||
}
|
|
||||||
z.Close()
|
|
||||||
} else {
|
} else {
|
||||||
// If Compressor is not set by UseCompressor, use default Compressor
|
hdr[0] = byte(compressionNone)
|
||||||
if err := cp.Do(cbuf, b); err != nil {
|
|
||||||
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b = cbuf.Bytes()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if uint(len(b)) > math.MaxUint32 {
|
|
||||||
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bufHeader := make([]byte, payloadLen+sizeLen)
|
// Write length of payload into buf
|
||||||
if compressor != nil || cp != nil {
|
binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
|
||||||
bufHeader[0] = byte(compressionMade)
|
return hdr, data
|
||||||
} else {
|
|
||||||
bufHeader[0] = byte(compressionNone)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write length of b into buf
|
func outPayload(client bool, msg interface{}, data, payload []byte, t time.Time) *stats.OutPayload {
|
||||||
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
|
return &stats.OutPayload{
|
||||||
if outPayload != nil {
|
Client: client,
|
||||||
outPayload.WireLength = payloadLen + sizeLen + len(b)
|
Payload: msg,
|
||||||
|
Data: data,
|
||||||
|
Length: len(data),
|
||||||
|
WireLength: len(payload) + headerLen,
|
||||||
|
SentTime: t,
|
||||||
}
|
}
|
||||||
return bufHeader, b, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
|
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {
|
||||||
|
@ -105,23 +105,25 @@ func TestEncode(t *testing.T) {
|
|||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
// input
|
// input
|
||||||
msg proto.Message
|
msg proto.Message
|
||||||
cp Compressor
|
|
||||||
// outputs
|
// outputs
|
||||||
hdr []byte
|
hdr []byte
|
||||||
data []byte
|
data []byte
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
|
{nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
|
||||||
} {
|
} {
|
||||||
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil)
|
data, err := encode(encoding.GetCodec(protoenc.Name), test.msg)
|
||||||
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
|
if err != test.err || !bytes.Equal(data, test.data) {
|
||||||
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
|
t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hdr, _ := msgHeader(data, nil); !bytes.Equal(hdr, test.hdr) {
|
||||||
|
t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompress(t *testing.T) {
|
func TestCompress(t *testing.T) {
|
||||||
|
|
||||||
bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression)
|
bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Could not initialize gzip compressor with best compression.")
|
t.Fatalf("Could not initialize gzip compressor with best compression.")
|
||||||
@ -214,12 +216,12 @@ func TestParseDialTarget(t *testing.T) {
|
|||||||
func bmEncode(b *testing.B, mSize int) {
|
func bmEncode(b *testing.B, mSize int) {
|
||||||
cdc := encoding.GetCodec(protoenc.Name)
|
cdc := encoding.GetCodec(protoenc.Name)
|
||||||
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
|
||||||
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil)
|
encodeData, _ := encode(cdc, msg)
|
||||||
encodedSz := int64(len(encodeHdr) + len(encodeData))
|
encodedSz := int64(len(encodeData))
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
encode(cdc, msg, nil, nil, nil)
|
encode(cdc, msg)
|
||||||
}
|
}
|
||||||
b.SetBytes(encodedSz)
|
b.SetBytes(encodedSz)
|
||||||
}
|
}
|
||||||
|
26
server.go
26
server.go
@ -827,24 +827,24 @@ func (s *Server) incrCallsFailed() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
|
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
|
||||||
var (
|
data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
|
||||||
outPayload *stats.OutPayload
|
|
||||||
)
|
|
||||||
if s.opts.statsHandler != nil {
|
|
||||||
outPayload = &stats.OutPayload{}
|
|
||||||
}
|
|
||||||
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Errorln("grpc: server failed to encode response: ", err)
|
grpclog.Errorln("grpc: server failed to encode response: ", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(data) > s.opts.maxSendMessageSize {
|
compData, err := compress(data, cp, comp)
|
||||||
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
|
if err != nil {
|
||||||
|
grpclog.Errorln("grpc: server failed to compress response: ", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
err = t.Write(stream, hdr, data, opts)
|
hdr, payload := msgHeader(data, compData)
|
||||||
if err == nil && outPayload != nil {
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
outPayload.SentTime = time.Now()
|
if len(payload) > s.opts.maxSendMessageSize {
|
||||||
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
|
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
|
||||||
|
}
|
||||||
|
err = t.Write(stream, hdr, payload, opts)
|
||||||
|
if err == nil && s.opts.statsHandler != nil {
|
||||||
|
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
49
stream.go
49
stream.go
@ -476,27 +476,27 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
}
|
}
|
||||||
var outPayload *stats.OutPayload
|
data, err := encode(cs.codec, m)
|
||||||
if a.statsHandler != nil {
|
|
||||||
outPayload = &stats.OutPayload{
|
|
||||||
Client: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(data) > *cs.c.maxSendMessageSize {
|
compData, err := compress(data, cs.cp, cs.comp)
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
hdr, payload := msgHeader(data, compData)
|
||||||
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
|
if len(payload) > *cs.c.maxSendMessageSize {
|
||||||
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.c.maxSendMessageSize)
|
||||||
|
}
|
||||||
|
|
||||||
if !cs.desc.ClientStreams {
|
if !cs.desc.ClientStreams {
|
||||||
cs.sentLast = true
|
cs.sentLast = true
|
||||||
}
|
}
|
||||||
err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams})
|
err = a.t.Write(a.s, hdr, payload, &transport.Options{Last: !cs.desc.ClientStreams})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if outPayload != nil {
|
if a.statsHandler != nil {
|
||||||
outPayload.SentTime = time.Now()
|
a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payload, time.Now()))
|
||||||
a.statsHandler.HandleRPC(a.ctx, outPayload)
|
|
||||||
}
|
}
|
||||||
if channelz.IsOn() {
|
if channelz.IsOn() {
|
||||||
a.t.IncrMsgSent()
|
a.t.IncrMsgSent()
|
||||||
@ -706,23 +706,24 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||||||
ss.t.IncrMsgSent()
|
ss.t.IncrMsgSent()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var outPayload *stats.OutPayload
|
data, err := encode(ss.codec, m)
|
||||||
if ss.statsHandler != nil {
|
|
||||||
outPayload = &stats.OutPayload{}
|
|
||||||
}
|
|
||||||
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(data) > ss.maxSendMessageSize {
|
compData, err := compress(data, ss.cp, ss.comp)
|
||||||
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
|
hdr, payload := msgHeader(data, compData)
|
||||||
|
// TODO(dfawley): should we be checking len(data) instead?
|
||||||
|
if len(payload) > ss.maxSendMessageSize {
|
||||||
|
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
|
||||||
|
}
|
||||||
|
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
if outPayload != nil {
|
if ss.statsHandler != nil {
|
||||||
outPayload.SentTime = time.Now()
|
ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now()))
|
||||||
ss.statsHandler.HandleRPC(ss.s.Context(), outPayload)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user