Skip to content

Commit

Permalink
refactor: introducing hyperswitch_interface crates (#3536)
Browse files Browse the repository at this point in the history
Co-authored-by: hyperswitch-bot[bot] <148525504+hyperswitch-bot[bot]@users.noreply.github.com>
  • Loading branch information
Chethan-rao and hyperswitch-bot[bot] authored Feb 12, 2024
1 parent 33df352 commit b6754a7
Show file tree
Hide file tree
Showing 44 changed files with 1,157 additions and 628 deletions.
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions crates/drainer/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use diesel::PgConnection;
#[cfg(feature = "aws_kms")]
use external_services::aws_kms::{self, decrypt::AwsKmsDecrypt};
#[cfg(feature = "hashicorp-vault")]
use external_services::hashicorp_vault::{self, decrypt::VaultFetch, Kv2};
use external_services::hashicorp_vault::{
core::{HashiCorpVault, Kv2},
decrypt::VaultFetch,
};
#[cfg(not(feature = "aws_kms"))]
use masking::PeekInterface;

Expand All @@ -28,8 +31,8 @@ pub async fn redis_connection(
pub async fn diesel_make_pg_pool(
database: &Database,
_test_transaction: bool,
#[cfg(feature = "aws_kms")] aws_kms_client: &'static aws_kms::AwsKmsClient,
#[cfg(feature = "hashicorp-vault")] hashicorp_client: &'static hashicorp_vault::HashiCorpVault,
#[cfg(feature = "aws_kms")] aws_kms_client: &'static aws_kms::core::AwsKmsClient,
#[cfg(feature = "hashicorp-vault")] hashicorp_client: &'static HashiCorpVault,
) -> PgPool {
let password = database.password.clone();
#[cfg(feature = "hashicorp-vault")]
Expand Down
4 changes: 2 additions & 2 deletions crates/drainer/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ impl Store {
&config.master_database,
test_transaction,
#[cfg(feature = "aws_kms")]
external_services::aws_kms::get_aws_kms_client(&config.kms).await,
external_services::aws_kms::core::get_aws_kms_client(&config.kms).await,
#[cfg(feature = "hashicorp-vault")]
#[allow(clippy::expect_used)]
external_services::hashicorp_vault::get_hashicorp_client(&config.hc_vault)
external_services::hashicorp_vault::core::get_hashicorp_client(&config.hc_vault)
.await
.expect("Failed while getting hashicorp client"),
)
Expand Down
6 changes: 3 additions & 3 deletions crates/drainer/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use serde::Deserialize;
use crate::errors;

#[cfg(feature = "aws_kms")]
pub type Password = aws_kms::AwsKmsValue;
pub type Password = aws_kms::core::AwsKmsValue;
#[cfg(not(feature = "aws_kms"))]
pub type Password = masking::Secret<String>;

Expand All @@ -36,9 +36,9 @@ pub struct Settings {
pub log: Log,
pub drainer: DrainerSettings,
#[cfg(feature = "aws_kms")]
pub kms: aws_kms::AwsKmsConfig,
pub kms: aws_kms::core::AwsKmsConfig,
#[cfg(feature = "hashicorp-vault")]
pub hc_vault: hashicorp_vault::HashiCorpVaultConfig,
pub hc_vault: hashicorp_vault::core::HashiCorpVaultConfig,
}

#[derive(Debug, Deserialize, Clone)]
Expand Down
1 change: 1 addition & 0 deletions crates/external_services/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ hex = "0.4.3"

# First party crates
common_utils = { version = "0.1.0", path = "../common_utils" }
hyperswitch_interfaces = { version = "0.1.0", path = "../hyperswitch_interfaces" }
masking = { version = "0.1.0", path = "../masking" }
router_env = { version = "0.1.0", path = "../router_env", features = ["log_extra_implicit_fields", "log_custom_entries_to_extra"] }
281 changes: 2 additions & 279 deletions crates/external_services/src/aws_kms.rs
Original file line number Diff line number Diff line change
@@ -1,284 +1,7 @@
//! Interactions with the AWS KMS SDK
use std::time::Instant;
pub mod core;

use aws_config::meta::region::RegionProviderChain;
use aws_sdk_kms::{config::Region, primitives::Blob, Client};
use base64::Engine;
use common_utils::errors::CustomResult;
use error_stack::{IntoReport, ResultExt};
use masking::{PeekInterface, Secret};
use router_env::logger;
/// decrypting data using the AWS KMS SDK.
pub mod decrypt;

use crate::{consts, metrics};

static AWS_KMS_CLIENT: tokio::sync::OnceCell<AwsKmsClient> = tokio::sync::OnceCell::const_new();

/// Returns a shared AWS KMS client, or initializes a new one if not previously initialized.
#[inline]
pub async fn get_aws_kms_client(config: &AwsKmsConfig) -> &'static AwsKmsClient {
AWS_KMS_CLIENT
.get_or_init(|| AwsKmsClient::new(config))
.await
}

/// Configuration parameters required for constructing a [`AwsKmsClient`].
#[derive(Clone, Debug, Default, serde::Deserialize)]
#[serde(default)]
pub struct AwsKmsConfig {
/// The AWS key identifier of the KMS key used to encrypt or decrypt data.
pub key_id: String,

/// The AWS region to send KMS requests to.
pub region: String,
}

/// Client for AWS KMS operations.
#[derive(Debug)]
pub struct AwsKmsClient {
inner_client: Client,
key_id: String,
}

impl AwsKmsClient {
/// Constructs a new AWS KMS client.
pub async fn new(config: &AwsKmsConfig) -> Self {
let region_provider = RegionProviderChain::first_try(Region::new(config.region.clone()));
let sdk_config = aws_config::from_env().region(region_provider).load().await;

Self {
inner_client: Client::new(&sdk_config),
key_id: config.key_id.clone(),
}
}

/// Decrypts the provided base64-encoded encrypted data using the AWS KMS SDK. We assume that
/// the SDK has the values required to interact with the AWS KMS APIs (`AWS_ACCESS_KEY_ID` and
/// `AWS_SECRET_ACCESS_KEY`) either set in environment variables, or that the SDK is running in
/// a machine that is able to assume an IAM role.
pub async fn decrypt(&self, data: impl AsRef<[u8]>) -> CustomResult<String, AwsKmsError> {
let start = Instant::now();
let data = consts::BASE64_ENGINE
.decode(data)
.into_report()
.change_context(AwsKmsError::Base64DecodingFailed)?;
let ciphertext_blob = Blob::new(data);

let decrypt_output = self
.inner_client
.decrypt()
.key_id(&self.key_id)
.ciphertext_blob(ciphertext_blob)
.send()
.await
.map_err(|error| {
// Logging using `Debug` representation of the error as the `Display`
// representation does not hold sufficient information.
logger::error!(aws_kms_sdk_error=?error, "Failed to AWS KMS decrypt data");
metrics::AWS_KMS_DECRYPTION_FAILURES.add(&metrics::CONTEXT, 1, &[]);
error
})
.into_report()
.change_context(AwsKmsError::DecryptionFailed)?;

let output = decrypt_output
.plaintext
.ok_or(AwsKmsError::MissingPlaintextDecryptionOutput)
.into_report()
.and_then(|blob| {
String::from_utf8(blob.into_inner())
.into_report()
.change_context(AwsKmsError::Utf8DecodingFailed)
})?;

let time_taken = start.elapsed();
metrics::AWS_KMS_DECRYPT_TIME.record(&metrics::CONTEXT, time_taken.as_secs_f64(), &[]);

Ok(output)
}

/// Encrypts the provided String data using the AWS KMS SDK. We assume that
/// the SDK has the values required to interact with the AWS KMS APIs (`AWS_ACCESS_KEY_ID` and
/// `AWS_SECRET_ACCESS_KEY`) either set in environment variables, or that the SDK is running in
/// a machine that is able to assume an IAM role.
pub async fn encrypt(&self, data: impl AsRef<[u8]>) -> CustomResult<String, AwsKmsError> {
let start = Instant::now();
let plaintext_blob = Blob::new(data.as_ref());

let encrypted_output = self
.inner_client
.encrypt()
.key_id(&self.key_id)
.plaintext(plaintext_blob)
.send()
.await
.map_err(|error| {
// Logging using `Debug` representation of the error as the `Display`
// representation does not hold sufficient information.
logger::error!(aws_kms_sdk_error=?error, "Failed to AWS KMS encrypt data");
metrics::AWS_KMS_ENCRYPTION_FAILURES.add(&metrics::CONTEXT, 1, &[]);
error
})
.into_report()
.change_context(AwsKmsError::EncryptionFailed)?;

let output = encrypted_output
.ciphertext_blob
.ok_or(AwsKmsError::MissingCiphertextEncryptionOutput)
.into_report()
.map(|blob| consts::BASE64_ENGINE.encode(blob.into_inner()))?;
let time_taken = start.elapsed();
metrics::AWS_KMS_ENCRYPT_TIME.record(&metrics::CONTEXT, time_taken.as_secs_f64(), &[]);

Ok(output)
}
}

/// Errors that could occur during AWS KMS operations.
#[derive(Debug, thiserror::Error)]
pub enum AwsKmsError {
/// An error occurred when base64 encoding input data.
#[error("Failed to base64 encode input data")]
Base64EncodingFailed,

/// An error occurred when base64 decoding input data.
#[error("Failed to base64 decode input data")]
Base64DecodingFailed,

/// An error occurred when AWS KMS decrypting input data.
#[error("Failed to AWS KMS decrypt input data")]
DecryptionFailed,

/// An error occurred when AWS KMS encrypting input data.
#[error("Failed to AWS KMS encrypt input data")]
EncryptionFailed,

/// The AWS KMS decrypted output does not include a plaintext output.
#[error("Missing plaintext AWS KMS decryption output")]
MissingPlaintextDecryptionOutput,

/// The AWS KMS encrypted output does not include a ciphertext output.
#[error("Missing ciphertext AWS KMS encryption output")]
MissingCiphertextEncryptionOutput,

/// An error occurred UTF-8 decoding AWS KMS decrypted output.
#[error("Failed to UTF-8 decode decryption output")]
Utf8DecodingFailed,

/// The AWS KMS client has not been initialized.
#[error("The AWS KMS client has not been initialized")]
AwsKmsClientNotInitialized,
}

impl AwsKmsConfig {
/// Verifies that the [`AwsKmsClient`] configuration is usable.
pub fn validate(&self) -> Result<(), &'static str> {
use common_utils::{ext_traits::ConfigExt, fp_utils::when};

when(self.key_id.is_default_or_empty(), || {
Err("KMS AWS key ID must not be empty")
})?;

when(self.region.is_default_or_empty(), || {
Err("KMS AWS region must not be empty")
})
}
}

/// A wrapper around a AWS KMS value that can be decrypted.
#[derive(Clone, Debug, Default, serde::Deserialize, Eq, PartialEq)]
#[serde(transparent)]
pub struct AwsKmsValue(Secret<String>);

impl common_utils::ext_traits::ConfigExt for AwsKmsValue {
fn is_empty_after_trim(&self) -> bool {
self.0.peek().is_empty_after_trim()
}
}

impl From<String> for AwsKmsValue {
fn from(value: String) -> Self {
Self(Secret::new(value))
}
}

impl From<Secret<String>> for AwsKmsValue {
fn from(value: Secret<String>) -> Self {
Self(value)
}
}

#[cfg(feature = "hashicorp-vault")]
#[async_trait::async_trait]
impl super::hashicorp_vault::decrypt::VaultFetch for AwsKmsValue {
async fn fetch_inner<En>(
self,
client: &super::hashicorp_vault::HashiCorpVault,
) -> error_stack::Result<Self, super::hashicorp_vault::HashiCorpError>
where
for<'a> En: super::hashicorp_vault::Engine<
ReturnType<'a, String> = std::pin::Pin<
Box<
dyn std::future::Future<
Output = error_stack::Result<
String,
super::hashicorp_vault::HashiCorpError,
>,
> + Send
+ 'a,
>,
>,
> + 'a,
{
self.0.fetch_inner::<En>(client).await.map(AwsKmsValue)
}
}

#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
#[tokio::test]
async fn check_aws_kms_encryption() {
std::env::set_var("AWS_SECRET_ACCESS_KEY", "YOUR SECRET ACCESS KEY");
std::env::set_var("AWS_ACCESS_KEY_ID", "YOUR AWS ACCESS KEY ID");
use super::*;
let config = AwsKmsConfig {
key_id: "YOUR AWS KMS KEY ID".to_string(),
region: "AWS REGION".to_string(),
};

let data = "hello".to_string();
let binding = data.as_bytes();
let kms_encrypted_fingerprint = AwsKmsClient::new(&config)
.await
.encrypt(binding)
.await
.expect("aws kms encryption failed");

println!("{}", kms_encrypted_fingerprint);
}

#[tokio::test]
async fn check_aws_kms_decrypt() {
std::env::set_var("AWS_SECRET_ACCESS_KEY", "YOUR SECRET ACCESS KEY");
std::env::set_var("AWS_ACCESS_KEY_ID", "YOUR AWS ACCESS KEY ID");
use super::*;
let config = AwsKmsConfig {
key_id: "YOUR AWS KMS KEY ID".to_string(),
region: "AWS REGION".to_string(),
};

// Should decrypt to hello
let data = "AWS KMS ENCRYPTED CIPHER".to_string();
let binding = data.as_bytes();
let kms_encrypted_fingerprint = AwsKmsClient::new(&config)
.await
.decrypt(binding)
.await
.expect("aws kms decryption failed");

println!("{}", kms_encrypted_fingerprint);
}
}
pub mod implementers;
Loading

0 comments on commit b6754a7

Please sign in to comment.