Skip to content

Commit

Permalink
Merge pull request #687 from hatoo/mtls
Browse files Browse the repository at this point in the history
Support mtls
  • Loading branch information
hatoo authored Feb 8, 2025
2 parents 365d8ab + 3aa16f8 commit 3eae0c1
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 17 deletions.
80 changes: 78 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,21 @@ rlimit = "0.10.1"
[dev-dependencies]
assert_cmd = "2.0.14"
axum = { version = "0.8.1", features = ["http2"] }
axum-server = { version = "0.7.1", features = ["tls-rustls"] }
bytes = "1.6"
float-cmp = "0.10.0"
http-mitm-proxy = "0.12.0"
jsonschema = "0.28.1"
lazy_static = "1.5.0"
predicates = "3.1.0"
rcgen = "0.13.1"
# features = ["aws_lc_rs"] is a workaround for mac & native-tls
# https://github.com/sfackler/rust-native-tls/issues/225
rcgen = { version = "0.13.1", features = ["aws_lc_rs"] }
regex = "1.10.5"
tempfile = "3.10.1"
rustls = "0.23.18"

[target.'cfg(unix)'.dev-dependencies]
tempfile = "3.10.1"
actix-web = "4"

[profile.pgo]
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ Options:
Lookup only ipv6.
--ipv4
Lookup only ipv4.
--cacert <CACERT>
(TLS) Use the specified certificate file to verify the peer. Native certificate store is used even if this argument is specified.
--cert <CERT>
(TLS) Use the specified client certificate file. --key must be also specified
--key <KEY>
(TLS) Use the specified client key file. --cert must be also specified
--insecure
Accept invalid certs.
--connect-to <CONNECT_TO>
Expand Down
4 changes: 2 additions & 2 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ mod test_db {
#[cfg(feature = "vsock")]
vsock_addr: None,
#[cfg(feature = "rustls")]
rustls_configs: crate::tls_config::RuslsConfigs::new(false),
rustls_configs: crate::tls_config::RuslsConfigs::new(false, None, None),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false),
native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false, None, None),
};
let result = store(&client, ":memory:", start, &test_vec);
assert_eq!(result.unwrap(), 2);
Expand Down
44 changes: 38 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,21 @@ Note: If qps is specified, burst will be ignored",
ipv6: bool,
#[arg(help = "Lookup only ipv4.", long = "ipv4")]
ipv4: bool,
#[arg(
help = "(TLS) Use the specified certificate file to verify the peer. Native certificate store is used even if this argument is specified.",
long
)]
cacert: Option<PathBuf>,
#[arg(
help = "(TLS) Use the specified client certificate file. --key must be also specified",
long
)]
cert: Option<PathBuf>,
#[arg(
help = "(TLS) Use the specified client key file. --cert must be also specified",
long
)]
key: Option<PathBuf>,
#[arg(help = "Accept invalid certs.", long = "insecure")]
insecure: bool,
#[arg(
Expand Down Expand Up @@ -520,6 +535,13 @@ async fn run() -> anyhow::Result<()> {
let (config, mut resolver_opts) = system_resolv_conf()?;
resolver_opts.ip_strategy = ip_strategy;
let resolver = hickory_resolver::AsyncResolver::tokio(config, resolver_opts);
let cacert = opts.cacert.as_deref().map(std::fs::read).transpose()?;
let client_auth = match (opts.cert, opts.key) {
(Some(cert), Some(key)) => Some((std::fs::read(cert)?, std::fs::read(key)?)),
(None, None) => None,
// TODO: Ensure it on clap
_ => anyhow::bail!("Both --cert and --key must be specified"),
};

let client = Arc::new(client::Client {
aws_config,
Expand All @@ -542,9 +564,21 @@ async fn run() -> anyhow::Result<()> {
#[cfg(feature = "vsock")]
vsock_addr: opts.vsock_addr.map(|v| v.0),
#[cfg(feature = "rustls")]
rustls_configs: tls_config::RuslsConfigs::new(opts.insecure),
rustls_configs: tls_config::RuslsConfigs::new(
opts.insecure,
cacert.as_deref(),
client_auth
.as_ref()
.map(|(cert, key)| (cert.as_slice(), key.as_slice())),
),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: tls_config::NativeTlsConnectors::new(opts.insecure),
native_tls_connectors: tls_config::NativeTlsConnectors::new(
opts.insecure,
cacert.as_deref(),
client_auth
.as_ref()
.map(|(cert, key)| (cert.as_slice(), key.as_slice())),
),
});

if !opts.no_pre_lookup {
Expand Down Expand Up @@ -595,10 +629,8 @@ async fn run() -> anyhow::Result<()> {
match work_mode {
WorkMode::Debug => {
let mut print_config = print_config;
if let Err(e) = client::work_debug(&mut print_config.output, client).await {
eprintln!("{e}");
}
std::process::exit(libc::EXIT_SUCCESS)
client::work_debug(&mut print_config.output, client).await?;
return Ok(());
}
WorkMode::FixedNumber {
n_requests,
Expand Down
48 changes: 43 additions & 5 deletions src/tls_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,38 @@ pub struct RuslsConfigs {

#[cfg(feature = "rustls")]
impl RuslsConfigs {
pub fn new(insecure: bool) -> Self {
pub fn new(
insecure: bool,
cacert_pem: Option<&[u8]>,
client_auth: Option<(&[u8], &[u8])>,
) -> Self {
use rustls_pki_types::pem::PemObject;
use std::sync::Arc;

let mut root_cert_store = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store.clone())
.with_no_client_auth();

if let Some(cacert_pem) = cacert_pem {
for der in rustls_pki_types::CertificateDer::pem_slice_iter(cacert_pem) {
root_cert_store.add(der.unwrap()).unwrap();
}
}

let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store);

let mut config = if let Some((cert, key)) = client_auth {
let certs = rustls_pki_types::CertificateDer::pem_slice_iter(cert)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let key = rustls_pki_types::PrivateKeyDer::from_pem_slice(key).unwrap();

builder.with_client_auth_cert(certs, key).unwrap()
} else {
builder.with_no_client_auth()
};
if insecure {
config
.dangerous()
Expand Down Expand Up @@ -50,15 +71,32 @@ pub struct NativeTlsConnectors {

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
impl NativeTlsConnectors {
pub fn new(insecure: bool) -> Self {
pub fn new(
insecure: bool,
cacert_pem: Option<&[u8]>,
client_auth: Option<(&[u8], &[u8])>,
) -> Self {
let new = |is_http2: bool| {
let mut connector_builder = native_tls::TlsConnector::builder();

if let Some(cacert_pem) = cacert_pem {
let cert = native_tls::Certificate::from_pem(cacert_pem)
.expect("Failed to parse cacert_pem");
connector_builder.add_root_certificate(cert);
}

if insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

if let Some((cert, key)) = client_auth {
let cert = native_tls::Identity::from_pkcs8(cert, key)
.expect("Failed to parse client_auth cert/key");
connector_builder.identity(cert);
}

if is_http2 {
connector_builder.request_alpns(&["h2"]);
}
Expand Down
Loading

0 comments on commit 3eae0c1

Please sign in to comment.