transport: share recv buffers (#2813)

transport: share recv buffers
This commit is contained in:
Can Guler
2019-06-20 15:01:58 -07:00
committed by GitHub
parent 712624e686
commit eca11cb9e4
5 changed files with 68 additions and 24 deletions

View File

@ -24,6 +24,7 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -347,7 +348,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ht.stats.HandleRPC(s.ctx, inHeader) ht.stats.HandleRPC(s.ctx, inHeader)
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
@ -361,7 +362,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
for buf := make([]byte, readSize); ; { for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf) n, err := req.Body.Read(buf)
if n > 0 { if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]}) s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
buf = buf[n:] buf = buf[n:]
} }
if err != nil { if err != nil {

View File

@ -117,6 +117,8 @@ type http2Client struct {
onGoAway func(GoAwayReason) onGoAway func(GoAwayReason)
onClose func() onClose func()
bufferPool *bufferPool
} }
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
@ -249,6 +251,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
onGoAway: onGoAway, onGoAway: onGoAway,
onClose: onClose, onClose: onClose,
keepaliveEnabled: keepaliveEnabled, keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
if opts.InitialWindowSize >= defaultWindowSize { if opts.InitialWindowSize >= defaultWindowSize {
@ -367,6 +370,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
closeStream: func(err error) { closeStream: func(err error) {
t.CloseStream(s, err) t.CloseStream(s, err)
}, },
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -946,9 +950,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
data := make([]byte, len(f.Data())) buffer := t.bufferPool.get()
copy(data, f.Data()) buffer.Reset()
s.write(recvMsg{data: data}) buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
} }
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that

View File

@ -124,6 +124,7 @@ type http2Server struct {
// Fields below are for channelz metric collection. // Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number channelzID int64 // channelz unique identification number
czData *channelzData czData *channelzData
bufferPool *bufferPool
} }
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
@ -225,6 +226,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
kep: kep, kep: kep,
initialWindowSize: iwz, initialWindowSize: iwz,
czData: new(channelzData), czData: new(channelzData),
bufferPool: newBufferPool(),
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
if dynamicWindow { if dynamicWindow {
@ -410,9 +412,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctxDone, ctxDone: s.ctxDone,
recv: s.buf, recv: s.buf,
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -596,9 +599,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
data := make([]byte, len(f.Data())) buffer := t.bufferPool.get()
copy(data, f.Data()) buffer.Reset()
s.write(recvMsg{data: data}) buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
} }
} }
if f.Header().Flags.Has(http2.FlagDataEndStream) { if f.Header().Flags.Has(http2.FlagDataEndStream) {

View File

@ -22,6 +22,7 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -39,10 +40,32 @@ import (
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
) )
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
},
}
}
func (p *bufferPool) get() *bytes.Buffer {
return p.pool.Get().(*bytes.Buffer)
}
func (p *bufferPool) put(b *bytes.Buffer) {
p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport // recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed. // protocol specific info has been removed.
type recvMsg struct { type recvMsg struct {
data []byte buffer *bytes.Buffer
// nil: received some data // nil: received some data
// io.EOF: stream is completed. data is nil. // io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil. // other non-nil error: transport failure. data is nil.
@ -117,8 +140,9 @@ type recvBufferReader struct {
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance). ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer recv *recvBuffer
last []byte // Stores the remaining data in the previous calls. last *bytes.Buffer // Stores the remaining data in the previous calls.
err error err error
freeBuffer func(*bytes.Buffer)
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to // Read reads the next len(p) bytes from last. If last is drained, it tries to
@ -128,10 +152,13 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }
if r.last != nil && len(r.last) > 0 { if r.last != nil {
// Read remaining data left in last call. // Read remaining data left in last call.
copied := copy(p, r.last) copied, _ := r.last.Read(p)
r.last = r.last[copied:] if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil return copied, nil
} }
if r.closeStream != nil { if r.closeStream != nil {
@ -170,8 +197,13 @@ func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error
if m.err != nil { if m.err != nil {
return 0, m.err return 0, m.err
} }
copied := copy(p, m.data) copied, _ := m.buffer.Read(p)
r.last = m.data[copied:] if m.buffer.Len() == 0 {
r.freeBuffer(m.buffer)
r.last = nil
} else {
r.last = m.buffer
}
return copied, nil return copied, nil
} }

View File

@ -1968,16 +1968,18 @@ func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctx.Done(), ctxDone: s.ctx.Done(),
recv: s.buf, recv: s.buf,
freeBuffer: func(*bytes.Buffer) {},
}, },
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
testData := make([]byte, 1) testData := make([]byte, 1)
testData[0] = 5 testData[0] = 5
testBuffer := bytes.NewBuffer(testData)
testErr := errors.New("test error") testErr := errors.New("test error")
s.write(recvMsg{data: testData, err: testErr}) s.write(recvMsg{buffer: testBuffer, err: testErr})
inBuf := make([]byte, 1) inBuf := make([]byte, 1)
actualCount, actualErr := s.Read(inBuf) actualCount, actualErr := s.Read(inBuf)
@ -1988,8 +1990,8 @@ func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
} }
s.write(recvMsg{data: testData, err: nil}) s.write(recvMsg{buffer: testBuffer, err: nil})
s.write(recvMsg{data: testData, err: errors.New("different error from first")}) s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")})
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
inBuf := make([]byte, 1) inBuf := make([]byte, 1)