From c5aa660e5174157be89b397a340a4677159ddc03 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 19 Mar 2024 00:13:15 +0300 Subject: [PATCH] Fix rustls feature --- Cargo.toml | 2 +- src/error/mod.rs | 2 +- src/error/tls/rustls_error.rs | 11 ++++ src/io/tls/rustls_io.rs | 120 ++++++++++++++++++++++++---------- src/lib.rs | 3 +- src/opts/rustls_opts.rs | 26 ++++---- 6 files changed, 115 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2e0f56de..da62f97d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" version = "0.34.0" exclude = ["test/*"] -edition = "2018" +edition = "2021" categories = ["asynchronous", "database"] [dependencies] diff --git a/src/error/mod.rs b/src/error/mod.rs index 983f274e..ffb788f8 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -8,7 +8,7 @@ pub use url::ParseError; -mod tls; +pub mod tls; use mysql_common::{ named_params::MixedParamsError, params::MissingNamedParameterError, diff --git a/src/error/tls/rustls_error.rs b/src/error/tls/rustls_error.rs index 2ee67d39..faae88db 100644 --- a/src/error/tls/rustls_error.rs +++ b/src/error/tls/rustls_error.rs @@ -2,11 +2,14 @@ use std::fmt::Display; +use rustls::server::VerifierBuilderError; + #[derive(Debug)] pub enum TlsError { Tls(rustls::Error), Pki(webpki::Error), InvalidDnsName(webpki::InvalidDnsNameError), + VerifierBuilderError(VerifierBuilderError), } impl From for crate::Error { @@ -15,6 +18,12 @@ impl From for crate::Error { } } +impl From for TlsError { + fn from(e: VerifierBuilderError) -> Self { + TlsError::VerifierBuilderError(e) + } +} + impl From for TlsError { fn from(e: rustls::Error) -> Self { TlsError::Tls(e) @@ -57,6 +66,7 @@ impl std::error::Error for TlsError { TlsError::Tls(e) => Some(e), TlsError::Pki(e) => Some(e), TlsError::InvalidDnsName(e) => Some(e), + TlsError::VerifierBuilderError(e) => Some(e), } } } @@ -67,6 +77,7 @@ impl Display for TlsError { TlsError::Tls(e) => e.fmt(f), TlsError::Pki(e) => e.fmt(f), TlsError::InvalidDnsName(e) => e.fmt(f), + TlsError::VerifierBuilderError(e) => e.fmt(f), } } } diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index e8757d0b..76080ff2 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -1,31 +1,35 @@ #![cfg(feature = "rustls-tls")] -use std::{convert::TryInto, sync::Arc}; +use std::sync::Arc; use rustls::{ - client::{ServerCertVerifier, WebPkiVerifier}, - Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, + client::{ + danger::{ServerCertVerified, ServerCertVerifier}, + WebPkiServerVerifier, + }, + pki_types::{CertificateDer, ServerName}, + ClientConfig, RootCertStore, }; use rustls_pemfile::certs; use tokio_rustls::TlsConnector; -use crate::{io::Endpoint, Result, SslOpts}; +use crate::{io::Endpoint, Result, SslOpts, TlsError}; impl SslOpts { - async fn load_root_certs(&self) -> crate::Result> { + async fn load_root_certs(&self) -> crate::Result>> { let mut output = Vec::new(); for root_cert in self.root_certs() { let root_cert_data = root_cert.read().await?; let mut seen = false; - for cert in certs(&mut &*root_cert_data)? { + for cert in certs(&mut &*root_cert_data) { seen = true; - output.push(Certificate(cert)); + output.push(cert?); } if !seen && !root_cert_data.is_empty() { - output.push(Certificate(root_cert_data.into_owned())); + output.push(CertificateDer::from(root_cert_data.into_owned())); } } @@ -42,21 +46,13 @@ impl Endpoint { } let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); for cert in ssl_opts.load_root_certs().await? { - root_store.add(&cert)?; + root_store.add(cert)?; } - let config_builder = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store.clone()); + let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone()); let mut config = if let Some(identity) = ssl_opts.client_identity() { let (cert_chain, priv_key) = identity.load().await?; @@ -65,12 +61,13 @@ impl Endpoint { config_builder.with_no_client_auth() }; - let server_name = domain - .as_str() - .try_into() - .map_err(|_| webpki::InvalidDnsNameError)?; + let server_name = ServerName::try_from(domain.as_str()) + .map_err(|_| webpki::InvalidDnsNameError)? + .to_owned(); let mut dangerous = config.dangerous(); - let web_pki_verifier = WebPkiVerifier::new(root_store, None); + let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .map_err(TlsError::from)?; let dangerous_verifier = DangerousVerifier::new( ssl_opts.accept_invalid_certs(), ssl_opts.skip_domain_validation(), @@ -97,17 +94,18 @@ impl Endpoint { } } +#[derive(Debug)] struct DangerousVerifier { accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, } impl DangerousVerifier { fn new( accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, ) -> Self { Self { accept_invalid_certs, @@ -118,23 +116,51 @@ impl DangerousVerifier { } impl ServerCertVerifier for DangerousVerifier { + // fn verify_server_cert( + // &self, + // end_entity: &Certificate, + // intermediates: &[Certificate], + // server_name: &rustls::ServerName, + // scts: &mut dyn Iterator, + // ocsp_response: &[u8], + // now: std::time::SystemTime, + // ) -> std::result::Result { + // if self.accept_invalid_certs { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } else { + // match self.verifier.verify_server_cert( + // end_entity, + // intermediates, + // server_name, + // scts, + // ocsp_response, + // now, + // ) { + // Ok(assertion) => Ok(assertion), + // Err(ref e) + // if e.to_string().contains("NotValidForName") && self.skip_domain_validation => + // { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } + // Err(e) => Err(e), + // } + // } + // } fn verify_server_cert( &self, - end_entity: &Certificate, - intermediates: &[Certificate], - server_name: &rustls::ServerName, - scts: &mut dyn Iterator, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, ocsp_response: &[u8], - now: std::time::SystemTime, - ) -> std::result::Result { + now: rustls::pki_types::UnixTime, + ) -> std::prelude::v1::Result { if self.accept_invalid_certs { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } else { match self.verifier.verify_server_cert( end_entity, intermediates, server_name, - scts, ocsp_response, now, ) { @@ -142,10 +168,34 @@ impl ServerCertVerifier for DangerousVerifier { Err(ref e) if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } Err(e) => Err(e), } } } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.verifier.supported_verify_schemes() + } } diff --git a/src/lib.rs b/src/lib.rs index f0605910..3e2c7983 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -466,7 +466,8 @@ pub use self::conn::pool::Pool; #[doc(inline)] pub use self::error::{ - DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError, + tls::TlsError, DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, + UrlError, }; #[doc(inline)] diff --git a/src/opts/rustls_opts.rs b/src/opts/rustls_opts.rs index 562846d8..ef954ea9 100644 --- a/src/opts/rustls_opts.rs +++ b/src/opts/rustls_opts.rs @@ -1,6 +1,6 @@ #![cfg(feature = "rustls-tls")] -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer}; use rustls_pemfile::{certs, rsa_private_keys}; use std::{borrow::Cow, path::Path}; @@ -50,27 +50,31 @@ impl ClientIdentity { self.priv_key.borrow() } - pub(crate) async fn load(&self) -> crate::Result<(Vec, PrivateKey)> { + pub(crate) async fn load( + &self, + ) -> crate::Result<(Vec>, PrivateKeyDer<'static>)> { let cert_data = self.cert_chain.read().await?; let key_data = self.priv_key.read().await?; let mut cert_chain = Vec::new(); if std::str::from_utf8(&cert_data).is_err() { - cert_chain.push(Certificate(cert_data.into_owned())); + cert_chain.push(CertificateDer::from(cert_data.into_owned())); } else { - for cert in certs(&mut &*cert_data)? { - cert_chain.push(Certificate(cert)); + for cert in certs(&mut &*cert_data) { + cert_chain.push(cert?); } } let priv_key = if std::str::from_utf8(&key_data).is_err() { - Some(PrivateKey(key_data.into_owned())) + Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from( + key_data.into_owned(), + ))) } else { - rsa_private_keys(&mut &*key_data)? - .into_iter() - .take(1) - .map(PrivateKey) - .next() + let mut priv_key = None; + for key in rsa_private_keys(&mut &*key_data).take(1) { + priv_key = Some(PrivateKeyDer::Pkcs1(key?.clone_key())); + } + priv_key }; Ok((