From 293ea03eb034cb354359b641f9cde805ead75b45 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Thu, 1 Jan 2015 07:00:35 -0800 Subject: [PATCH] secio: buffer remainders in calls to Read() used to return io.ErrShortBuffer, but this makes client code much more complicated. we're already allocating buffers when it's too large, so might as well just keep it for later. --- crypto/secio/rw.go | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/crypto/secio/rw.go b/crypto/secio/rw.go index 2a8106b04..8094e3a36 100644 --- a/crypto/secio/rw.go +++ b/crypto/secio/rw.go @@ -76,6 +76,9 @@ type etmReader struct { msgio.Reader io.Closer + // buffer + buf []byte + // params msg msgio.ReadCloser // msgio for knowing where boundaries lie str cipher.Stream // the stream cipher to encrypt with @@ -91,20 +94,35 @@ func (r *etmReader) NextMsgLen() (int, error) { return r.msg.NextMsgLen() } +func (r *etmReader) drainBuf(buf []byte) int { + if r.buf == nil { + return 0 + } + + n := copy(buf, r.buf) + r.buf = r.buf[n:] + return n +} + func (r *etmReader) Read(buf []byte) (int, error) { - // first, check the buffer has enough space. + // first, check if we have anything in the buffer + copied := r.drainBuf(buf) + buf = buf[copied:] + if copied > 0 { + return copied, nil + // return here to avoid complicating the rest... + // user can call io.ReadFull. + } + + // check the buffer has enough space for the next msg fullLen, err := r.msg.NextMsgLen() if err != nil { return 0, err } - dataLen := fullLen - r.mac.size - if cap(buf) < dataLen { - return 0, io.ErrShortBuffer - } - buf2 := buf changed := false + // if not enough space, allocate a new buffer. if cap(buf) < fullLen { buf2 = make([]byte, fullLen) changed = true @@ -121,10 +139,15 @@ func (r *etmReader) Read(buf []byte) (int, error) { return 0, err } buf2 = buf2[:m] - if changed { - return copy(buf, buf2), nil + if !changed { + return m, nil } - return m, nil + + n = copy(buf, buf2) + if len(buf2) > len(buf) { + r.buf = buf2[len(buf):] // had some left over? save it. + } + return n, nil } func (r *etmReader) ReadMsg() ([]byte, error) {