revert buffer reuse (#3338)

* Revert "stream: fix returnBuffers race during retry (#3293)"

This reverts commit ede71d589cc36a6adff7244ce220516f0b3e446b.

* Revert "codec/proto: reuse of marshal byte buffers (#3167)"

This reverts commit 642675125e198ce612ea9caff4bf75d3a4a45667.
This commit is contained in:
Menghan Li
2020-01-27 13:30:41 -08:00
committed by GitHub
parent 7afcfdd66b
commit 8c50fc2565
15 changed files with 72 additions and 333 deletions

View File

@ -8,13 +8,8 @@ into bytes and vice-versa for the purposes of network transmission.
## Codecs (Serialization and Deserialization) ## Codecs (Serialization and Deserialization)
A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and
deserialize a byte slice back into a message (`Unmarshal`). Optionally, a deserialize a byte slice back into a message (`Unmarshal`). `Codec`s are
`ReturnBuffer` method to potentially reuse the byte slice returned by the registered by name into a global registry maintained in the `encoding` package.
`Marshal` method may also be implemented; note that this is an experimental
feature with an API that is still in flux.
`Codec`s are registered by name into a global registry maintained in the
`encoding` package.
### Implementing a `Codec` ### Implementing a `Codec`

View File

@ -31,20 +31,6 @@ type baseCodec interface {
Unmarshal(data []byte, v interface{}) error Unmarshal(data []byte, v interface{}) error
} }
// A bufferReturner requires a ReturnBuffer method to be implemented. Once a
// Marshal caller is done with the returned byte buffer, they can choose to
// return it back to the encoding library for re-use using this method.
type bufferReturner interface {
// If implemented in a codec, this function may be called with the byte
// buffer returned by Marshal after gRPC is done with the buffer.
//
// gRPC will not call ReturnBuffer after it's done with the buffer if any of
// the following is true:
// 1. Stats handlers are used.
// 2. Binlogs are enabled.
ReturnBuffer(buf []byte)
}
var _ baseCodec = Codec(nil) var _ baseCodec = Codec(nil)
var _ baseCodec = encoding.Codec(nil) var _ baseCodec = encoding.Codec(nil)

View File

@ -75,11 +75,6 @@ func GetCompressor(name string) Compressor {
// Codec defines the interface gRPC uses to encode and decode messages. Note // Codec defines the interface gRPC uses to encode and decode messages. Note
// that implementations of this interface must be thread safe; a Codec's // that implementations of this interface must be thread safe; a Codec's
// methods can be called from concurrent goroutines. // methods can be called from concurrent goroutines.
//
// Optionally, if a ReturnBuffer(buf []byte) is implemented, it may be called
// to return the byte slice it received from the Marshal function after gRPC is
// done with it. The codec may reuse this byte slice in a future Marshal
// operation to reduce the application's memory footprint.
type Codec interface { type Codec interface {
// Marshal returns the wire format of v. // Marshal returns the wire format of v.
Marshal(v interface{}) ([]byte, error) Marshal(v interface{}) ([]byte, error)

View File

@ -21,6 +21,7 @@
package proto package proto
import ( import (
"math"
"sync" "sync"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -37,16 +38,29 @@ func init() {
// codec is a Codec implementation with protobuf. It is the default codec for gRPC. // codec is a Codec implementation with protobuf. It is the default codec for gRPC.
type codec struct{} type codec struct{}
func marshal(v interface{}, pb *proto.Buffer) ([]byte, error) { type cachedProtoBuffer struct {
protoMsg := v.(proto.Message) lastMarshaledSize uint32
newSlice := returnBufferPool.Get().([]byte) proto.Buffer
}
pb.SetBuf(newSlice) func capToMaxInt32(val int) uint32 {
pb.Reset() if val > math.MaxInt32 {
if err := pb.Marshal(protoMsg); err != nil { return uint32(math.MaxInt32)
}
return uint32(val)
}
func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
protoMsg := v.(proto.Message)
newSlice := make([]byte, 0, cb.lastMarshaledSize)
cb.SetBuf(newSlice)
cb.Reset()
if err := cb.Marshal(protoMsg); err != nil {
return nil, err return nil, err
} }
out := pb.Bytes() out := cb.Bytes()
cb.lastMarshaledSize = capToMaxInt32(len(out))
return out, nil return out, nil
} }
@ -56,12 +70,12 @@ func (codec) Marshal(v interface{}) ([]byte, error) {
return pm.Marshal() return pm.Marshal()
} }
pb := protoBufferPool.Get().(*proto.Buffer) cb := protoBufferPool.Get().(*cachedProtoBuffer)
out, err := marshal(v, pb) out, err := marshal(v, cb)
// put back buffer and lose the ref to the slice // put back buffer and lose the ref to the slice
pb.SetBuf(nil) cb.SetBuf(nil)
protoBufferPool.Put(pb) protoBufferPool.Put(cb)
return out, err return out, err
} }
@ -74,39 +88,23 @@ func (codec) Unmarshal(data []byte, v interface{}) error {
return pu.Unmarshal(data) return pu.Unmarshal(data)
} }
pb := protoBufferPool.Get().(*proto.Buffer) cb := protoBufferPool.Get().(*cachedProtoBuffer)
pb.SetBuf(data) cb.SetBuf(data)
err := pb.Unmarshal(protoMsg) err := cb.Unmarshal(protoMsg)
pb.SetBuf(nil) cb.SetBuf(nil)
protoBufferPool.Put(pb) protoBufferPool.Put(cb)
return err return err
} }
func (codec) ReturnBuffer(data []byte) {
// Make sure we set the length of the buffer to zero so that future appends
// will start from the zeroeth byte, not append to the previous, stale data.
//
// Apparently, sync.Pool with non-pointer objects (slices, in this case)
// causes small allocations because of how interface{} works under the hood.
// This isn't a problem for us, however, because we're more concerned with
// _how_ much that allocation is. Ideally, we'd be using bytes.Buffer as the
// Marshal return value to remove even that allocation, but we can't change
// the Marshal interface at this point.
returnBufferPool.Put(data[:0])
}
func (codec) Name() string { func (codec) Name() string {
return Name return Name
} }
var protoBufferPool = &sync.Pool{ var protoBufferPool = &sync.Pool{
New: func() interface{} { New: func() interface{} {
return &proto.Buffer{} return &cachedProtoBuffer{
}, Buffer: proto.Buffer{},
} lastMarshaledSize: 16,
}
var returnBufferPool = &sync.Pool{
New: func() interface{} {
return make([]byte, 0, 16)
}, },
} }

View File

@ -127,53 +127,3 @@ func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
} }
} }
} }
func TestBufferReuse(t *testing.T) {
c := codec{}
marshal := func(toMarshal []byte) []byte {
protoIn := &codec_perf.Buffer{Body: toMarshal}
b, err := c.Marshal(protoIn)
if err != nil {
t.Errorf("codec.Marshal(%v) failed: %v", protoIn, err)
}
// We cannot expect the actual pointer to be the same because sync.Pool
// during GC pauses.
bc := append([]byte(nil), b...)
c.ReturnBuffer(b)
return bc
}
unmarshal := func(b []byte) []byte {
protoOut := &codec_perf.Buffer{}
if err := c.Unmarshal(b, protoOut); err != nil {
t.Errorf("codec.Unarshal(%v) failed: %v", protoOut, err)
}
return protoOut.GetBody()
}
check := func(in []byte, out []byte) {
if len(in) != len(out) {
t.Errorf("unequal lengths: len(in=%v)=%d, len(out=%v)=%d", in, len(in), out, len(out))
}
for i := 0; i < len(in); i++ {
if in[i] != out[i] {
t.Errorf("unequal values: in[%d] = %v, out[%d] = %v", i, in[i], i, out[i])
}
}
}
// To test that the returned buffer does not have unexpected data at the end,
// we use a second input data that is smaller than the first.
in1 := []byte{1, 2, 3}
b1 := marshal(in1)
in2 := []byte{4, 5}
b2 := marshal(in2)
out1 := unmarshal(b1)
out2 := unmarshal(b2)
check(in1, out1)
check(in2, out2)
}

View File

@ -25,7 +25,6 @@ import (
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
@ -75,24 +74,11 @@ func ignore(g string) bool {
return false return false
} }
var lastStacktraceSize uint32 = 4 << 10
// interestingGoroutines returns all goroutines we care about for the purpose of // interestingGoroutines returns all goroutines we care about for the purpose of
// leak checking. It excludes testing or runtime ones. // leak checking. It excludes testing or runtime ones.
func interestingGoroutines() (gs []string) { func interestingGoroutines() (gs []string) {
n := atomic.LoadUint32(&lastStacktraceSize) buf := make([]byte, 2<<20)
buf := make([]byte, n) buf = buf[:runtime.Stack(buf, true)]
for {
nb := uint32(runtime.Stack(buf, true))
if nb < uint32(len(buf)) {
buf = buf[:nb]
break
}
n <<= 1
buf = make([]byte, n)
}
atomic.StoreUint32(&lastStacktraceSize, n)
for _, g := range strings.Split(string(buf), "\n\n") { for _, g := range strings.Split(string(buf), "\n\n") {
if !ignore(g) { if !ignore(g) {
gs = append(gs, g) gs = append(gs, g)

View File

@ -34,9 +34,8 @@ var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
} }
type itemNode struct { type itemNode struct {
it interface{} it interface{}
onDequeue func() next *itemNode
next *itemNode
} }
type itemList struct { type itemList struct {
@ -44,8 +43,8 @@ type itemList struct {
tail *itemNode tail *itemNode
} }
func (il *itemList) enqueue(i interface{}, onDequeue func()) { func (il *itemList) enqueue(i interface{}) {
n := &itemNode{it: i, onDequeue: onDequeue} n := &itemNode{it: i}
if il.tail == nil { if il.tail == nil {
il.head, il.tail = n, n il.head, il.tail = n, n
return return
@ -64,14 +63,11 @@ func (il *itemList) dequeue() interface{} {
if il.head == nil { if il.head == nil {
return nil return nil
} }
i, onDequeue := il.head.it, il.head.onDequeue i := il.head.it
il.head = il.head.next il.head = il.head.next
if il.head == nil { if il.head == nil {
il.tail = nil il.tail = nil
} }
if onDequeue != nil {
onDequeue()
}
return i return i
} }
@ -140,7 +136,6 @@ type dataFrame struct {
// onEachWrite is called every time // onEachWrite is called every time
// a part of d is written out. // a part of d is written out.
onEachWrite func() onEachWrite func()
rb *ReturnBuffer
} }
func (*dataFrame) isTransportResponseFrame() bool { return false } func (*dataFrame) isTransportResponseFrame() bool { return false }
@ -334,7 +329,7 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (b
wakeUp = true wakeUp = true
c.consumerWaiting = false c.consumerWaiting = false
} }
c.list.enqueue(it, nil) c.list.enqueue(it)
if it.isTransportResponseFrame() { if it.isTransportResponseFrame() {
c.transportResponseFrames++ c.transportResponseFrames++
if c.transportResponseFrames == maxQueuedTransportResponseFrames { if c.transportResponseFrames == maxQueuedTransportResponseFrames {
@ -621,7 +616,7 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
if str.state != empty { // either active or waiting on stream quota. if str.state != empty { // either active or waiting on stream quota.
// add it str's list of items. // add it str's list of items.
str.itl.enqueue(h, nil) str.itl.enqueue(h)
return nil return nil
} }
if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil { if err := l.writeHeader(h.streamID, h.endStream, h.hf, h.onWrite); err != nil {
@ -636,7 +631,7 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error {
itl: &itemList{}, itl: &itemList{},
wq: h.wq, wq: h.wq,
} }
str.itl.enqueue(h, nil) str.itl.enqueue(h)
return l.originateStream(str) return l.originateStream(str)
} }
@ -707,11 +702,7 @@ func (l *loopyWriter) preprocessData(df *dataFrame) error {
} }
// If we got data for a stream it means that // If we got data for a stream it means that
// stream was originated and the headers were sent out. // stream was originated and the headers were sent out.
var onDequeue func() str.itl.enqueue(df)
if df.rb != nil {
onDequeue = df.rb.Done
}
str.itl.enqueue(df, onDequeue)
if str.state == empty { if str.state == empty {
str.state = active str.state = active
l.activeStreams.enqueue(str) l.activeStreams.enqueue(str)
@ -735,12 +726,6 @@ func (l *loopyWriter) outFlowControlSizeRequestHandler(o *outFlowControlSizeRequ
func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error { func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
c.onWrite() c.onWrite()
if str, ok := l.estdStreams[c.streamID]; ok { if str, ok := l.estdStreams[c.streamID]; ok {
// Dequeue all items from the stream's item list. This would call any pending onDequeue functions.
if str.state == active {
for !str.itl.isEmpty() {
str.itl.dequeue()
}
}
// On the server side it could be a trailers-only response or // On the server side it could be a trailers-only response or
// a RST_STREAM before stream initialization thus the stream might // a RST_STREAM before stream initialization thus the stream might
// not be established yet. // not be established yet.

View File

@ -263,23 +263,12 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
} }
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
rb := opts.ReturnBuffer return ht.do(func() {
if rb != nil {
rb.Add(1)
}
err := ht.do(func() {
ht.writeCommonHeaders(s) ht.writeCommonHeaders(s)
ht.rw.Write(hdr) ht.rw.Write(hdr)
ht.rw.Write(data) ht.rw.Write(data)
ht.rw.(http.Flusher).Flush() ht.rw.(http.Flusher).Flush()
if rb != nil {
rb.Done()
}
}) })
if rb != nil && err != nil {
rb.Done()
}
return err
} }
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {

View File

@ -847,7 +847,6 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
endStream: opts.Last, endStream: opts.Last,
rb: opts.ReturnBuffer,
} }
if hdr != nil || data != nil { // If it's not an empty data frame. if hdr != nil || data != nil { // If it's not an empty data frame.
// Add some data to grpc message header so that we can equally // Add some data to grpc message header so that we can equally
@ -864,9 +863,6 @@ func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
return err return err
} }
} }
if df.rb != nil {
df.rb.Add(1)
}
return t.controlBuf.put(df) return t.controlBuf.put(df)
} }

View File

@ -923,7 +923,6 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
h: hdr, h: hdr,
d: data, d: data,
onEachWrite: t.setResetPingStrikes, onEachWrite: t.setResetPingStrikes,
rb: opts.ReturnBuffer,
} }
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
select { select {
@ -933,9 +932,6 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
} }
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
} }
if df.rb != nil {
df.rb.Add(1)
}
return t.controlBuf.put(df) return t.controlBuf.put(df)
} }

View File

@ -33,7 +33,6 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
@ -588,9 +587,6 @@ type Options struct {
// Last indicates whether this write is the last piece for // Last indicates whether this write is the last piece for
// this stream. // this stream.
Last bool Last bool
// If non-nil, ReturnBuffer.Done() should be called in order to return some
// allocated buffer back to a sync pool.
ReturnBuffer *ReturnBuffer
} }
// CallHdr carries the information of a particular RPC. // CallHdr carries the information of a particular RPC.
@ -810,37 +806,3 @@ func ContextErr(err error) error {
} }
return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err)
} }
// ReturnBuffer contains a function holding a closure that can return a byte
// slice back to the encoder for reuse. This function is called when the
// counter c reaches 0, which happens when all Add calls have called their
// corresponding Done calls. All operations on ReturnBuffer are
// concurrency-safe.
type ReturnBuffer struct {
c int32
f func()
}
// NewReturnBuffer allocates and returns a *ReturnBuffer.
func NewReturnBuffer(c int32, f func()) *ReturnBuffer {
return &ReturnBuffer{c: c, f: f}
}
// Add increments an internal counter atomically.
func (rb *ReturnBuffer) Add(n int32) {
atomic.AddInt32(&rb.c, n)
}
// Done decrements the internal counter and executes the closured ReturnBuffer
// function if the internal counter reaches zero.
func (rb *ReturnBuffer) Done() {
nc := atomic.AddInt32(&rb.c, -1)
if nc < 0 {
// Same behaviour as sync.WaitGroup, this should NEVER happen. And if it
// does happen, it's better to terminate early than silently continue with
// corrupt data.
grpclog.Fatalln("grpc: ReturnBuffer negative counter")
} else if nc == 0 {
rb.f()
}
}

View File

@ -28,10 +28,9 @@ import (
// This API is EXPERIMENTAL. // This API is EXPERIMENTAL.
type PreparedMsg struct { type PreparedMsg struct {
// Struct for preparing msg before sending them // Struct for preparing msg before sending them
encodedData []byte encodedData []byte
hdr []byte hdr []byte
payload []byte payload []byte
returnBuffer func()
} }
// Encode marshalls and compresses the message using the codec and compressor for the stream. // Encode marshalls and compresses the message using the codec and compressor for the stream.
@ -56,14 +55,6 @@ func (p *PreparedMsg) Encode(s Stream, msg interface{}) error {
return err return err
} }
p.encodedData = data p.encodedData = data
if cap(data) >= bufferReuseThreshold {
if bcodec, ok := rpcInfo.preloaderInfo.codec.(bufferReturner); ok {
p.returnBuffer = func() {
bcodec.ReturnBuffer(data)
}
}
}
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp) compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp)
if err != nil { if err != nil {
return err return err

View File

@ -841,24 +841,12 @@ func (s *Server) incrCallsFailed() {
atomic.AddInt64(&s.czData.callsFailed, 1) atomic.AddInt64(&s.czData.callsFailed, 1)
} }
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor, attemptBufferReuse bool) error { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
codec := s.getCodec(stream.ContentSubtype()) data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
data, err := encode(codec, msg)
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err) grpclog.Errorln("grpc: server failed to encode response: ", err)
return err return err
} }
if attemptBufferReuse && len(data) >= bufferReuseThreshold {
if bcodec, ok := codec.(bufferReturner); ok {
rb := transport.NewReturnBuffer(1, func() {
bcodec.ReturnBuffer(data)
})
opts.ReturnBuffer = rb
defer rb.Done()
}
}
compData, err := compress(data, cp, comp) compData, err := compress(data, cp, comp)
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to compress response: ", err) grpclog.Errorln("grpc: server failed to compress response: ", err)
@ -1067,8 +1055,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
trInfo.tr.LazyLog(stringer("OK"), false) trInfo.tr.LazyLog(stringer("OK"), false)
} }
opts := &transport.Options{Last: true} opts := &transport.Options{Last: true}
err = s.sendResponse(t, stream, reply, cp, opts, comp, sh == nil && binlog == nil)
if err != nil { if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF { if err == io.EOF {
// The entire stream is done (for unary RPC only). // The entire stream is done (for unary RPC only).
return err return err
@ -1209,12 +1197,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ss.binlog.Log(logEntry) ss.binlog.Log(logEntry)
} }
// Stats handlers and binlog handlers are allowed to retain references to
// this slice internally. We may not, therefore, return this to the pool.
if ss.statsHandler == nil && ss.binlog == nil {
ss.attemptBufferReuse = true
}
// If dc is set and matches the stream's compression, use it. Otherwise, try // If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp. // to find a matching registered compressor for decomp.
if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc {

104
stream.go
View File

@ -278,10 +278,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
} }
cs.binlog = binarylog.GetMethodLogger(method) cs.binlog = binarylog.GetMethodLogger(method)
// Stats handlers and binlog handlers are allowed to retain references to
// this slice internally. We may not, therefore, return this to the pool.
cs.attemptBufferReuse = sh == nil && cs.binlog == nil
cs.callInfo.stream = cs cs.callInfo.stream = cs
// Only this initial attempt has stats/tracing. // Only this initial attempt has stats/tracing.
// TODO(dfawley): move to newAttempt when per-attempt stats are implemented. // TODO(dfawley): move to newAttempt when per-attempt stats are implemented.
@ -426,12 +422,6 @@ type clientStream struct {
committed bool // active attempt committed for retry? committed bool // active attempt committed for retry?
buffer []func(a *csAttempt) error // operations to replay on retry buffer []func(a *csAttempt) error // operations to replay on retry
bufferSize int // current size of buffer bufferSize int // current size of buffer
// This is per-stream array instead of a per-attempt one because there may be
// multiple attempts working on the same data, but we may not free the same
// buffer twice.
returnBuffers []*transport.ReturnBuffer
attemptBufferReuse bool
} }
// csAttempt implements a single transport stream attempt within a // csAttempt implements a single transport stream attempt within a
@ -458,12 +448,8 @@ type csAttempt struct {
} }
func (cs *clientStream) commitAttemptLocked() { func (cs *clientStream) commitAttemptLocked() {
cs.buffer = nil
cs.committed = true cs.committed = true
for _, rb := range cs.returnBuffers { cs.buffer = nil
rb.Done()
}
cs.returnBuffers = nil
} }
func (cs *clientStream) commitAttempt() { func (cs *clientStream) commitAttempt() {
@ -710,45 +696,24 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
} }
// load hdr, payload, data // load hdr, payload, data
hdr, payload, data, f, err := prepareMsg(m, cs.codec, cs.cp, cs.comp, cs.attemptBufferReuse) hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
if err != nil { if err != nil {
return err return err
} }
var rb *transport.ReturnBuffer
if f != nil {
rb = transport.NewReturnBuffer(1, f)
}
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > *cs.callInfo.maxSendMessageSize { if len(payload) > *cs.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
} }
msgBytes := data // Store the pointer before setting to nil. For binary logging. msgBytes := data // Store the pointer before setting to nil. For binary logging.
op := func(a *csAttempt) error { op := func(a *csAttempt) error {
err := a.sendMsg(m, hdr, payload, data, rb) err := a.sendMsg(m, hdr, payload, data)
// nil out the message and uncomp when replaying; they are only needed for // nil out the message and uncomp when replaying; they are only needed for
// stats which is disabled for subsequent attempts. // stats which is disabled for subsequent attempts.
m, data = nil, nil m, data = nil, nil
return err return err
} }
err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) }) err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
if rb != nil {
// If this stream is not committed, the buffer needs to be kept for future
// attempts. It's ref-count will be subtracted when committing.
//
// If this stream is already committed, the ref-count can be subtracted
// here.
cs.mu.Lock()
if !cs.committed {
cs.returnBuffers = append(cs.returnBuffers, rb)
} else {
rb.Done()
}
cs.mu.Unlock()
}
if cs.binlog != nil && err == nil { if cs.binlog != nil && err == nil {
cs.binlog.Log(&binarylog.ClientMessage{ cs.binlog.Log(&binarylog.ClientMessage{
OnClientSide: true, OnClientSide: true,
@ -833,7 +798,6 @@ func (cs *clientStream) finish(err error) {
cs.mu.Unlock() cs.mu.Unlock()
return return
} }
cs.finished = true cs.finished = true
cs.commitAttemptLocked() cs.commitAttemptLocked()
cs.mu.Unlock() cs.mu.Unlock()
@ -869,7 +833,7 @@ func (cs *clientStream) finish(err error) {
cs.cancel() cs.cancel()
} }
func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte, rb *transport.ReturnBuffer) error { func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error {
cs := a.cs cs := a.cs
if a.trInfo != nil { if a.trInfo != nil {
a.mu.Lock() a.mu.Lock()
@ -878,8 +842,7 @@ func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte, rb *transpor
} }
a.mu.Unlock() a.mu.Unlock()
} }
if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams}); err != nil {
if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams, ReturnBuffer: rb}); err != nil {
if !cs.desc.ClientStreams { if !cs.desc.ClientStreams {
// For non-client-streaming RPCs, we return nil instead of EOF on error // For non-client-streaming RPCs, we return nil instead of EOF on error
// because the generated code requires it. finish is not called; RecvMsg() // because the generated code requires it. finish is not called; RecvMsg()
@ -1202,24 +1165,18 @@ func (as *addrConnStream) SendMsg(m interface{}) (err error) {
as.sentLast = true as.sentLast = true
} }
// load hdr, payload, data, returnBuffer // load hdr, payload, data
hdr, payld, _, f, err := prepareMsg(m, as.codec, as.cp, as.comp, true) hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp)
if err != nil { if err != nil {
return err return err
} }
var rb *transport.ReturnBuffer
if f != nil {
rb = transport.NewReturnBuffer(1, f)
defer rb.Done()
}
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payld) > *as.callInfo.maxSendMessageSize { if len(payld) > *as.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize)
} }
if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams, ReturnBuffer: rb}); err != nil { if err := as.t.Write(as.s, hdr, payld, &transport.Options{Last: !as.desc.ClientStreams}); err != nil {
if !as.desc.ClientStreams { if !as.desc.ClientStreams {
// For non-client-streaming RPCs, we return nil instead of EOF on error // For non-client-streaming RPCs, we return nil instead of EOF on error
// because the generated code requires it. finish is not called; RecvMsg() // because the generated code requires it. finish is not called; RecvMsg()
@ -1390,8 +1347,6 @@ type serverStream struct {
serverHeaderBinlogged bool serverHeaderBinlogged bool
mu sync.Mutex // protects trInfo.tr after the service handler runs. mu sync.Mutex // protects trInfo.tr after the service handler runs.
attemptBufferReuse bool
} }
func (ss *serverStream) Context() context.Context { func (ss *serverStream) Context() context.Context {
@ -1453,23 +1408,17 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
} }
}() }()
// load hdr, payload, returnBuffer, data // load hdr, payload, data
hdr, payload, data, f, err := prepareMsg(m, ss.codec, ss.cp, ss.comp, ss.attemptBufferReuse) hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil { if err != nil {
return err return err
} }
var rb *transport.ReturnBuffer
if f != nil {
rb = transport.NewReturnBuffer(1, f)
defer rb.Done()
}
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > ss.maxSendMessageSize { if len(payload) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
} }
if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false, ReturnBuffer: rb}); err != nil { if err := ss.t.Write(ss.s, hdr, payload, &transport.Options{Last: false}); err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
if ss.binlog != nil { if ss.binlog != nil {
@ -1558,44 +1507,23 @@ func MethodFromServerStream(stream ServerStream) (string, bool) {
return Method(stream.Context()) return Method(stream.Context())
} }
// Threshold beyond which buffer reuse should apply.
//
// TODO(adtac): make this an option in the future so that the user can
// configure it per-RPC or even per-message?
const bufferReuseThreshold = 1024
// prepareMsg returns the hdr, payload and data // prepareMsg returns the hdr, payload and data
// using the compressors passed or using the // using the compressors passed or using the
// passed preparedmsg // passed preparedmsg
func prepareMsg(m interface{}, codec baseCodec, cp Compressor, comp encoding.Compressor, attemptBufferReuse bool) (hdr, payload, data []byte, returnBuffer func(), err error) { func prepareMsg(m interface{}, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) {
if preparedMsg, ok := m.(*PreparedMsg); ok { if preparedMsg, ok := m.(*PreparedMsg); ok {
f := preparedMsg.returnBuffer return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil
if !attemptBufferReuse {
f = nil
}
return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, f, nil
} }
// The input interface is not a prepared msg. // The input interface is not a prepared msg.
// Marshal and Compress the data at this point // Marshal and Compress the data at this point
data, err = encode(codec, m) data, err = encode(codec, m)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, err
} }
if attemptBufferReuse && cap(data) >= bufferReuseThreshold {
if bcodec, ok := codec.(bufferReturner); ok {
returnBuffer = func() {
bcodec.ReturnBuffer(data)
}
}
}
compData, err := compress(data, cp, comp) compData, err := compress(data, cp, comp)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, err
} }
hdr, payload = msgHeader(data, compData) hdr, payload = msgHeader(data, compData)
return hdr, payload, data, returnBuffer, nil return hdr, payload, data, nil
} }

2
vet.sh
View File

@ -117,7 +117,7 @@ fi
# TODO(dfawley): don't use deprecated functions in examples or first-party # TODO(dfawley): don't use deprecated functions in examples or first-party
# plugins. # plugins.
SC_OUT="$(mktemp)" SC_OUT="$(mktemp)"
staticcheck -go 1.9 -checks 'inherit,-ST1015,-SA6002' ./... > "${SC_OUT}" || true staticcheck -go 1.9 -checks 'inherit,-ST1015' ./... > "${SC_OUT}" || true
# Error if anything other than deprecation warnings are printed. # Error if anything other than deprecation warnings are printed.
(! grep -v "is deprecated:.*SA1019" "${SC_OUT}") (! grep -v "is deprecated:.*SA1019" "${SC_OUT}")
# Only ignore the following deprecated types/fields/functions. # Only ignore the following deprecated types/fields/functions.