refactor(router): KMS decrypt secrets when kms feature is enabled (#868)

Co-authored-by: Sanchith Hegde <22217505+SanchithHegde@users.noreply.github.com>
This commit is contained in:
chethan-rao-juspay
2023-04-17 23:59:39 +05:30
committed by GitHub
parent bc38bc47d8
commit 8905e66340
5 changed files with 127 additions and 19 deletions

View File

@ -3,10 +3,13 @@ use api_models::{payment_methods::PaymentMethodListRequest, payments::PaymentsRe
use async_trait::async_trait;
use common_utils::date_time;
use error_stack::{report, IntoReport, ResultExt};
#[cfg(feature = "kms")]
use external_services::kms;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use masking::PeekInterface;
use masking::{PeekInterface, StrongSecret};
use crate::{
configs::settings,
core::{
api_keys,
errors::{self, RouterResult},
@ -112,6 +115,31 @@ where
}
}
static ADMIN_API_KEY: tokio::sync::OnceCell<StrongSecret<String>> =
tokio::sync::OnceCell::const_new();
pub async fn get_admin_api_key(
secrets: &settings::Secrets,
#[cfg(feature = "kms")] kms_config: &kms::KmsConfig,
) -> RouterResult<&'static StrongSecret<String>> {
ADMIN_API_KEY
.get_or_try_init(|| async {
#[cfg(feature = "kms")]
let admin_api_key = kms::get_kms_client(kms_config)
.await
.decrypt(&secrets.kms_encrypted_admin_api_key)
.await
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("Failed to KMS decrypt admin API key")?;
#[cfg(not(feature = "kms"))]
let admin_api_key = secrets.admin_api_key.clone();
Ok(StrongSecret::new(admin_api_key))
})
.await
}
#[derive(Debug)]
pub struct AdminApiAuth;
@ -125,13 +153,22 @@ where
request_headers: &HeaderMap,
state: &A,
) -> RouterResult<()> {
let admin_api_key =
let request_admin_api_key =
get_api_key(request_headers).change_context(errors::ApiErrorResponse::Unauthorized)?;
let conf = state.conf();
if admin_api_key != conf.secrets.admin_api_key {
let admin_api_key = get_admin_api_key(
&conf.secrets,
#[cfg(feature = "kms")]
&conf.kms,
)
.await?;
if request_admin_api_key != admin_api_key.peek() {
Err(report!(errors::ApiErrorResponse::Unauthorized)
.attach_printable("Admin Authentication Failure"))?;
}
Ok(())
}
}
@ -213,7 +250,9 @@ where
) -> RouterResult<()> {
let mut token = get_jwt(request_headers)?;
token = strip_jwt_token(token)?;
decode_jwt::<JwtAuthPayloadFetchUnit>(token, state).map(|_| ())
decode_jwt::<JwtAuthPayloadFetchUnit>(token, state)
.await
.map(|_| ())
}
}
@ -234,7 +273,7 @@ where
) -> RouterResult<storage::MerchantAccount> {
let mut token = get_jwt(request_headers)?;
token = strip_jwt_token(token)?;
let payload = decode_jwt::<JwtAuthPayloadFetchMerchantAccount>(token, state)?;
let payload = decode_jwt::<JwtAuthPayloadFetchMerchantAccount>(token, state).await?;
state
.store()
.find_merchant_account_by_merchant_id(&payload.merchant_id)
@ -353,12 +392,44 @@ pub fn is_jwt_auth(headers: &HeaderMap) -> bool {
headers.get(crate::headers::AUTHORIZATION).is_some()
}
pub fn decode_jwt<T>(token: &str, state: &impl AppStateInfo) -> RouterResult<T>
static JWT_SECRET: tokio::sync::OnceCell<StrongSecret<String>> = tokio::sync::OnceCell::const_new();
pub async fn get_jwt_secret(
secrets: &settings::Secrets,
#[cfg(feature = "kms")] kms_config: &kms::KmsConfig,
) -> RouterResult<&'static StrongSecret<String>> {
JWT_SECRET
.get_or_try_init(|| async {
#[cfg(feature = "kms")]
let jwt_secret = kms::get_kms_client(kms_config)
.await
.decrypt(&secrets.kms_encrypted_jwt_secret)
.await
.change_context(errors::ApiErrorResponse::InternalServerError)
.attach_printable("Failed to KMS decrypt JWT secret")?;
#[cfg(not(feature = "kms"))]
let jwt_secret = secrets.jwt_secret.clone();
Ok(StrongSecret::new(jwt_secret))
})
.await
}
pub async fn decode_jwt<T>(token: &str, state: &impl AppStateInfo) -> RouterResult<T>
where
T: serde::de::DeserializeOwned,
{
let conf = state.conf();
let secret = conf.secrets.jwt_secret.as_bytes();
let secret = get_jwt_secret(
&conf.secrets,
#[cfg(feature = "kms")]
&conf.kms,
)
.await?
.peek()
.as_bytes();
let key = DecodingKey::from_secret(secret);
decode::<T>(token, &key, &Validation::new(Algorithm::HS256))
.map(|decoded| decoded.claims)