diff --git a/credentials/alts/internal/handshaker/handshaker.go b/credentials/alts/internal/handshaker/handshaker.go index 633c7125..083a3de6 100644 --- a/credentials/alts/internal/handshaker/handshaker.go +++ b/credentials/alts/internal/handshaker/handshaker.go @@ -74,8 +74,10 @@ func init() { } } -func acquire(n int64) bool { +func acquire() bool { mu.Lock() + // If we need n to be configurable, we can pass it as an argument. + n := int64(1) success := maxPendingHandshakes-concurrentHandshakes >= n if success { concurrentHandshakes += n @@ -84,8 +86,10 @@ func acquire(n int64) bool { return success } -func release(n int64) { +func release() { mu.Lock() + // If we need n to be configurable, we can pass it as an argument. + n := int64(1) concurrentHandshakes -= n if concurrentHandshakes < 0 { mu.Unlock() @@ -182,10 +186,10 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, // ClientHandshake starts and completes a client ALTS handshaking for GCP. Once // done, ClientHandshake returns a secure connection. func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { - if !acquire(1) { + if !acquire() { return nil, nil, errDropped } - defer release(1) + defer release() if h.side != core.ClientSide { return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker") @@ -225,10 +229,10 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent // ServerHandshake starts and completes a server ALTS handshaking for GCP. Once // done, ServerHandshake returns a secure connection. func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) { - if !acquire(1) { + if !acquire() { return nil, nil, errDropped } - defer release(1) + defer release() if h.side != core.ServerSide { return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")