diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index 083a3de6..8bc7ceee 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -64,6 +64,9 @@ var ( concurrentHandshakes = int64(0) // errDropped occurs when maxPendingHandshakes is reached. errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached") + // errOutOfBound occurs when the handshake service returns a consumed + // bytes value larger than the buffer that was passed to it originally. + errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound") ) func init() { @@ -284,6 +287,9 @@ func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *alts var extra []byte if req.GetServerStart() != nil { + if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) { + return nil, nil, errOutOfBound + } extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():] } result, extra, err := h.processUntilDone(resp, extra) @@ -355,6 +361,9 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b return nil, nil, err } // Set extra based on handshaker service response. + if resp.GetBytesConsumed() > uint32(len(p)) { + return nil, nil, errOutOfBound + } extra = p[resp.GetBytesConsumed():] } }