split encode into three functions (#2058)

This commit is contained in:
dfawley
2018-05-11 13:47:10 -07:00
committed by GitHub
parent b75baa103c
commit 091a800143
5 changed files with 122 additions and 101 deletions

View File

@ -105,12 +105,13 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
} }
} }
// send a response back to end the stream. // send a response back to end the stream.
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) data, err := encode(testCodec{}, &expectedResponse)
if err != nil { if err != nil {
t.Errorf("Failed to encode the response: %v", err) t.Errorf("Failed to encode the response: %v", err)
return return
} }
h.t.Write(s, hdr, data, &transport.Options{}) hdr, payload := msgHeader(data, nil)
h.t.Write(s, hdr, payload, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, "")) h.t.WriteStatus(s, status.New(codes.OK, ""))
} }

View File

@ -419,8 +419,8 @@ func (o CustomCodecCallOption) after(c *callInfo) {}
type payloadFormat uint8 type payloadFormat uint8
const ( const (
compressionNone payloadFormat = iota // no compression compressionNone payloadFormat = 0 // no compression
compressionMade compressionMade payloadFormat = 1 // compressed
) )
// parser reads complete gRPC messages from the underlying reader. // parser reads complete gRPC messages from the underlying reader.
@ -477,65 +477,82 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
return pf, msg, nil return pf, msg, nil
} }
// encode serializes msg and returns a buffer of message header and a buffer of msg. // encode serializes msg and returns a buffer containing the message, or an
// If msg is nil, it generates the message header and an empty msg buffer. // error if it is too large to be transmitted by grpc. If msg is nil, it
// TODO(ddyihai): eliminate extra Compressor parameter. // generates an empty message.
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { func encode(c baseCodec, msg interface{}) ([]byte, error) {
var ( if msg == nil { // NOTE: typed nils will not be caught by this check
b []byte return nil, nil
cbuf *bytes.Buffer }
) b, err := c.Marshal(msg)
const ( if err != nil {
payloadLen = 1 return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
sizeLen = 4
)
if msg != nil {
var err error
b, err = c.Marshal(msg)
if err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
if outPayload != nil {
outPayload.Payload = msg
// TODO truncate large payload.
outPayload.Data = b
outPayload.Length = len(b)
}
if compressor != nil || cp != nil {
cbuf = new(bytes.Buffer)
// Has compressor, check Compressor is set by UseCompressor first.
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
z.Close()
} else {
// If Compressor is not set by UseCompressor, use default Compressor
if err := cp.Do(cbuf, b); err != nil {
return nil, nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
}
b = cbuf.Bytes()
}
} }
if uint(len(b)) > math.MaxUint32 { if uint(len(b)) > math.MaxUint32 {
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
} }
return b, nil
}
bufHeader := make([]byte, payloadLen+sizeLen) // compress returns the input bytes compressed by compressor or cp. If both
if compressor != nil || cp != nil { // compressors are nil, returns nil.
bufHeader[0] = byte(compressionMade) //
// TODO(dfawley): eliminate cp parameter by wrapping Compressor in an encoding.Compressor.
func compress(in []byte, cp Compressor, compressor encoding.Compressor) ([]byte, error) {
if compressor == nil && cp == nil {
return nil, nil
}
wrapErr := func(err error) error {
return status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
cbuf := &bytes.Buffer{}
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(in); err != nil {
return nil, wrapErr(err)
}
if err := z.Close(); err != nil {
return nil, wrapErr(err)
}
} else { } else {
bufHeader[0] = byte(compressionNone) if err := cp.Do(cbuf, in); err != nil {
return nil, wrapErr(err)
}
}
return cbuf.Bytes(), nil
}
const (
payloadLen = 1
sizeLen = 4
headerLen = payloadLen + sizeLen
)
// msgHeader returns a 5-byte header for the message being transmitted and the
// payload, which is compData if non-nil or data otherwise.
func msgHeader(data, compData []byte) (hdr []byte, payload []byte) {
hdr = make([]byte, headerLen)
if compData != nil {
hdr[0] = byte(compressionMade)
data = compData
} else {
hdr[0] = byte(compressionNone)
} }
// Write length of b into buf // Write length of payload into buf
binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b))) binary.BigEndian.PutUint32(hdr[payloadLen:], uint32(len(data)))
if outPayload != nil { return hdr, data
outPayload.WireLength = payloadLen + sizeLen + len(b) }
func outPayload(client bool, msg interface{}, data, payload []byte, t time.Time) *stats.OutPayload {
return &stats.OutPayload{
Client: client,
Payload: msg,
Data: data,
Length: len(data),
WireLength: len(payload) + headerLen,
SentTime: t,
} }
return bufHeader, b, nil
} }
func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool) *status.Status {

View File

@ -105,23 +105,25 @@ func TestEncode(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
// input // input
msg proto.Message msg proto.Message
cp Compressor
// outputs // outputs
hdr []byte hdr []byte
data []byte data []byte
err error err error
}{ }{
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, {nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
} { } {
hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) data, err := encode(encoding.GetCodec(protoenc.Name), test.msg)
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { if err != test.err || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) t.Errorf("encode(_, %v) = %v, %v; want %v, %v", test.msg, data, err, test.data, test.err)
continue
}
if hdr, _ := msgHeader(data, nil); !bytes.Equal(hdr, test.hdr) {
t.Errorf("msgHeader(%v, false) = %v; want %v", data, hdr, test.hdr)
} }
} }
} }
func TestCompress(t *testing.T) { func TestCompress(t *testing.T) {
bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression) bestCompressor, err := NewGZIPCompressorWithLevel(gzip.BestCompression)
if err != nil { if err != nil {
t.Fatalf("Could not initialize gzip compressor with best compression.") t.Fatalf("Could not initialize gzip compressor with best compression.")
@ -214,12 +216,12 @@ func TestParseDialTarget(t *testing.T) {
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {
cdc := encoding.GetCodec(protoenc.Name) cdc := encoding.GetCodec(protoenc.Name)
msg := &perfpb.Buffer{Body: make([]byte, mSize)} msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil) encodeData, _ := encode(cdc, msg)
encodedSz := int64(len(encodeHdr) + len(encodeData)) encodedSz := int64(len(encodeData))
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
encode(cdc, msg, nil, nil, nil) encode(cdc, msg)
} }
b.SetBytes(encodedSz) b.SetBytes(encodedSz)
} }

View File

@ -827,24 +827,24 @@ func (s *Server) incrCallsFailed() {
} }
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
var ( data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
outPayload *stats.OutPayload
)
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp)
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err) grpclog.Errorln("grpc: server failed to encode response: ", err)
return err return err
} }
if len(data) > s.opts.maxSendMessageSize { compData, err := compress(data, cp, comp)
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) if err != nil {
grpclog.Errorln("grpc: server failed to compress response: ", err)
return err
} }
err = t.Write(stream, hdr, data, opts) hdr, payload := msgHeader(data, compData)
if err == nil && outPayload != nil { // TODO(dfawley): should we be checking len(data) instead?
outPayload.SentTime = time.Now() if len(payload) > s.opts.maxSendMessageSize {
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
}
err = t.Write(stream, hdr, payload, opts)
if err == nil && s.opts.statsHandler != nil {
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
} }
return err return err
} }

View File

@ -476,27 +476,27 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
} }
a.mu.Unlock() a.mu.Unlock()
} }
var outPayload *stats.OutPayload data, err := encode(cs.codec, m)
if a.statsHandler != nil {
outPayload = &stats.OutPayload{
Client: true,
}
}
hdr, data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp)
if err != nil { if err != nil {
return err return err
} }
if len(data) > *cs.c.maxSendMessageSize { compData, err := compress(data, cs.cp, cs.comp)
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) if err != nil {
return err
} }
hdr, payload := msgHeader(data, compData)
// TODO(dfawley): should we be checking len(data) instead?
if len(payload) > *cs.c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.c.maxSendMessageSize)
}
if !cs.desc.ClientStreams { if !cs.desc.ClientStreams {
cs.sentLast = true cs.sentLast = true
} }
err = a.t.Write(a.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams}) err = a.t.Write(a.s, hdr, payload, &transport.Options{Last: !cs.desc.ClientStreams})
if err == nil { if err == nil {
if outPayload != nil { if a.statsHandler != nil {
outPayload.SentTime = time.Now() a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payload, time.Now()))
a.statsHandler.HandleRPC(a.ctx, outPayload)
} }
if channelz.IsOn() { if channelz.IsOn() {
a.t.IncrMsgSent() a.t.IncrMsgSent()
@ -706,23 +706,24 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
ss.t.IncrMsgSent() ss.t.IncrMsgSent()
} }
}() }()
var outPayload *stats.OutPayload data, err := encode(ss.codec, m)
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
hdr, data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp)
if err != nil { if err != nil {
return err return err
} }
if len(data) > ss.maxSendMessageSize { compData, err := compress(data, ss.cp, ss.comp)
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) if err != nil {
return err
} }
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil { hdr, payload := msgHeader(data, compData)
// TODO(dfawley): should we be checking len(data) instead?
if len(payload) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
}
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
if outPayload != nil { if ss.statsHandler != nil {
outPayload.SentTime = time.Now() ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, data, payload, time.Now()))
ss.statsHandler.HandleRPC(ss.s.Context(), outPayload)
} }
return nil return nil
} }