@ -24,6 +24,7 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -347,7 +348,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
||||
ht.stats.HandleRPC(s.ctx, inHeader)
|
||||
}
|
||||
s.trReader = &transportReader{
|
||||
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf},
|
||||
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
|
||||
windowHandler: func(int) {},
|
||||
}
|
||||
|
||||
@ -361,7 +362,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
||||
for buf := make([]byte, readSize); ; {
|
||||
n, err := req.Body.Read(buf)
|
||||
if n > 0 {
|
||||
s.buf.put(recvMsg{data: buf[:n:n]})
|
||||
s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
|
||||
buf = buf[n:]
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -117,6 +117,8 @@ type http2Client struct {
|
||||
|
||||
onGoAway func(GoAwayReason)
|
||||
onClose func()
|
||||
|
||||
bufferPool *bufferPool
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
|
||||
@ -249,6 +251,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
|
||||
onGoAway: onGoAway,
|
||||
onClose: onClose,
|
||||
keepaliveEnabled: keepaliveEnabled,
|
||||
bufferPool: newBufferPool(),
|
||||
}
|
||||
t.controlBuf = newControlBuffer(t.ctxDone)
|
||||
if opts.InitialWindowSize >= defaultWindowSize {
|
||||
@ -367,6 +370,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
|
||||
closeStream: func(err error) {
|
||||
t.CloseStream(s, err)
|
||||
},
|
||||
freeBuffer: t.bufferPool.put,
|
||||
},
|
||||
windowHandler: func(n int) {
|
||||
t.updateWindow(s, uint32(n))
|
||||
@ -946,9 +950,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
||||
// Can this copy be eliminated?
|
||||
if len(f.Data()) > 0 {
|
||||
data := make([]byte, len(f.Data()))
|
||||
copy(data, f.Data())
|
||||
s.write(recvMsg{data: data})
|
||||
buffer := t.bufferPool.get()
|
||||
buffer.Reset()
|
||||
buffer.Write(f.Data())
|
||||
s.write(recvMsg{buffer: buffer})
|
||||
}
|
||||
}
|
||||
// The server has closed the stream without sending trailers. Record that
|
||||
|
@ -124,6 +124,7 @@ type http2Server struct {
|
||||
// Fields below are for channelz metric collection.
|
||||
channelzID int64 // channelz unique identification number
|
||||
czData *channelzData
|
||||
bufferPool *bufferPool
|
||||
}
|
||||
|
||||
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
|
||||
@ -225,6 +226,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
||||
kep: kep,
|
||||
initialWindowSize: iwz,
|
||||
czData: new(channelzData),
|
||||
bufferPool: newBufferPool(),
|
||||
}
|
||||
t.controlBuf = newControlBuffer(t.ctxDone)
|
||||
if dynamicWindow {
|
||||
@ -410,9 +412,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
||||
s.trReader = &transportReader{
|
||||
reader: &recvBufferReader{
|
||||
ctx: s.ctx,
|
||||
ctxDone: s.ctxDone,
|
||||
recv: s.buf,
|
||||
ctx: s.ctx,
|
||||
ctxDone: s.ctxDone,
|
||||
recv: s.buf,
|
||||
freeBuffer: t.bufferPool.put,
|
||||
},
|
||||
windowHandler: func(n int) {
|
||||
t.updateWindow(s, uint32(n))
|
||||
@ -596,9 +599,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
|
||||
// guarantee f.Data() is consumed before the arrival of next frame.
|
||||
// Can this copy be eliminated?
|
||||
if len(f.Data()) > 0 {
|
||||
data := make([]byte, len(f.Data()))
|
||||
copy(data, f.Data())
|
||||
s.write(recvMsg{data: data})
|
||||
buffer := t.bufferPool.get()
|
||||
buffer.Reset()
|
||||
buffer.Write(f.Data())
|
||||
s.write(recvMsg{buffer: buffer})
|
||||
}
|
||||
}
|
||||
if f.Header().Flags.Has(http2.FlagDataEndStream) {
|
||||
|
@ -22,6 +22,7 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -39,10 +40,32 @@ import (
|
||||
"google.golang.org/grpc/tap"
|
||||
)
|
||||
|
||||
type bufferPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newBufferPool() *bufferPool {
|
||||
return &bufferPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *bufferPool) get() *bytes.Buffer {
|
||||
return p.pool.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
func (p *bufferPool) put(b *bytes.Buffer) {
|
||||
p.pool.Put(b)
|
||||
}
|
||||
|
||||
// recvMsg represents the received msg from the transport. All transport
|
||||
// protocol specific info has been removed.
|
||||
type recvMsg struct {
|
||||
data []byte
|
||||
buffer *bytes.Buffer
|
||||
// nil: received some data
|
||||
// io.EOF: stream is completed. data is nil.
|
||||
// other non-nil error: transport failure. data is nil.
|
||||
@ -117,8 +140,9 @@ type recvBufferReader struct {
|
||||
ctx context.Context
|
||||
ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
|
||||
recv *recvBuffer
|
||||
last []byte // Stores the remaining data in the previous calls.
|
||||
last *bytes.Buffer // Stores the remaining data in the previous calls.
|
||||
err error
|
||||
freeBuffer func(*bytes.Buffer)
|
||||
}
|
||||
|
||||
// Read reads the next len(p) bytes from last. If last is drained, it tries to
|
||||
@ -128,10 +152,13 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
if r.last != nil && len(r.last) > 0 {
|
||||
if r.last != nil {
|
||||
// Read remaining data left in last call.
|
||||
copied := copy(p, r.last)
|
||||
r.last = r.last[copied:]
|
||||
copied, _ := r.last.Read(p)
|
||||
if r.last.Len() == 0 {
|
||||
r.freeBuffer(r.last)
|
||||
r.last = nil
|
||||
}
|
||||
return copied, nil
|
||||
}
|
||||
if r.closeStream != nil {
|
||||
@ -170,8 +197,13 @@ func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
copied := copy(p, m.data)
|
||||
r.last = m.data[copied:]
|
||||
copied, _ := m.buffer.Read(p)
|
||||
if m.buffer.Len() == 0 {
|
||||
r.freeBuffer(m.buffer)
|
||||
r.last = nil
|
||||
} else {
|
||||
r.last = m.buffer
|
||||
}
|
||||
return copied, nil
|
||||
}
|
||||
|
||||
|
@ -1968,16 +1968,18 @@ func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||
}
|
||||
s.trReader = &transportReader{
|
||||
reader: &recvBufferReader{
|
||||
ctx: s.ctx,
|
||||
ctxDone: s.ctx.Done(),
|
||||
recv: s.buf,
|
||||
ctx: s.ctx,
|
||||
ctxDone: s.ctx.Done(),
|
||||
recv: s.buf,
|
||||
freeBuffer: func(*bytes.Buffer) {},
|
||||
},
|
||||
windowHandler: func(int) {},
|
||||
}
|
||||
testData := make([]byte, 1)
|
||||
testData[0] = 5
|
||||
testBuffer := bytes.NewBuffer(testData)
|
||||
testErr := errors.New("test error")
|
||||
s.write(recvMsg{data: testData, err: testErr})
|
||||
s.write(recvMsg{buffer: testBuffer, err: testErr})
|
||||
|
||||
inBuf := make([]byte, 1)
|
||||
actualCount, actualErr := s.Read(inBuf)
|
||||
@ -1988,8 +1990,8 @@ func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
|
||||
t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error())
|
||||
}
|
||||
|
||||
s.write(recvMsg{data: testData, err: nil})
|
||||
s.write(recvMsg{data: testData, err: errors.New("different error from first")})
|
||||
s.write(recvMsg{buffer: testBuffer, err: nil})
|
||||
s.write(recvMsg{buffer: testBuffer, err: errors.New("different error from first")})
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
inBuf := make([]byte, 1)
|
||||
|
Reference in New Issue
Block a user