refactor(kms): share a KMS client for all KMS operations (#744)

This commit is contained in:
Sanchith Hegde
2023-03-15 21:53:38 +05:30
committed by GitHub
parent 5c9bec9f53
commit a3ff2e8d4f
14 changed files with 347 additions and 306 deletions

View File

@ -3,6 +3,8 @@ use error_stack::{report, IntoReport, ResultExt};
use masking::{PeekInterface, StrongSecret};
use router_env::{instrument, tracing};
#[cfg(feature = "kms")]
use crate::services::kms;
use crate::{
configs::settings,
consts,
@ -12,43 +14,40 @@ use crate::{
types::{api, storage, transformers::ForeignInto},
utils,
};
#[cfg(feature = "kms")]
use crate::{routes::metrics, services::kms};
pub static HASH_KEY: tokio::sync::OnceCell<StrongSecret<[u8; PlaintextApiKey::HASH_KEY_LEN]>> =
static HASH_KEY: tokio::sync::OnceCell<StrongSecret<[u8; PlaintextApiKey::HASH_KEY_LEN]>> =
tokio::sync::OnceCell::const_new();
pub async fn get_hash_key(
api_key_config: &settings::ApiKeys,
) -> errors::RouterResult<StrongSecret<[u8; PlaintextApiKey::HASH_KEY_LEN]>> {
#[cfg(feature = "kms")]
let hash_key = kms::KeyHandler::get_kms_decrypted_key(
&api_key_config.aws_region,
&api_key_config.aws_key_id,
api_key_config.kms_encrypted_hash_key.clone(),
)
.await
.map_err(|error| {
metrics::AWS_KMS_FAILURES.add(&metrics::CONTEXT, 1, &[]);
error
})
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("Failed to KMS decrypt API key hashing key")?;
#[cfg(feature = "kms")] kms_config: &settings::Kms,
) -> errors::RouterResult<&'static StrongSecret<[u8; PlaintextApiKey::HASH_KEY_LEN]>> {
HASH_KEY
.get_or_try_init(|| async {
#[cfg(feature = "kms")]
let hash_key = kms::get_kms_client(kms_config)
.await
.decrypt(&api_key_config.kms_encrypted_hash_key)
.await
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("Failed to KMS decrypt API key hashing key")?;
#[cfg(not(feature = "kms"))]
let hash_key = &api_key_config.hash_key;
#[cfg(not(feature = "kms"))]
let hash_key = &api_key_config.hash_key;
<[u8; PlaintextApiKey::HASH_KEY_LEN]>::try_from(
hex::decode(hash_key)
<[u8; PlaintextApiKey::HASH_KEY_LEN]>::try_from(
hex::decode(hash_key)
.into_report()
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("API key hash key has invalid hexadecimal data")?
.as_slice(),
)
.into_report()
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("API key hash key has invalid hexadecimal data")?
.as_slice(),
)
.into_report()
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("The API hashing key has incorrect length")
.map(StrongSecret::new)
.attach_printable("The API hashing key has incorrect length")
.map(StrongSecret::new)
})
.await
}
// Defining new types `PlaintextApiKey` and `HashedApiKey` in the hopes of reducing the possibility
@ -119,12 +118,16 @@ impl PlaintextApiKey {
pub async fn create_api_key(
store: &dyn StorageInterface,
api_key_config: &settings::ApiKeys,
#[cfg(feature = "kms")] kms_config: &settings::Kms,
api_key: api::CreateApiKeyRequest,
merchant_id: String,
) -> RouterResponse<api::CreateApiKeyResponse> {
let hash_key = HASH_KEY
.get_or_try_init(|| get_hash_key(api_key_config))
.await?;
let hash_key = get_hash_key(
api_key_config,
#[cfg(feature = "kms")]
kms_config,
)
.await?;
let plaintext_api_key = PlaintextApiKey::new(consts::API_KEY_LENGTH);
let api_key = storage::ApiKeyNew {
key_id: PlaintextApiKey::new_key_id(),
@ -248,10 +251,13 @@ mod tests {
let settings = settings::Settings::new().expect("invalid settings");
let plaintext_api_key = PlaintextApiKey::new(consts::API_KEY_LENGTH);
let hash_key = HASH_KEY
.get_or_try_init(|| get_hash_key(&settings.api_keys))
.await
.unwrap();
let hash_key = get_hash_key(
&settings.api_keys,
#[cfg(feature = "kms")]
&settings.kms,
)
.await
.unwrap();
let hashed_api_key = plaintext_api_key.keyed_hash(hash_key.peek());
assert_ne!(