Skip to content

Commit c5aa660

Browse files
committed
Fix rustls feature
1 parent 174f40f commit c5aa660

File tree

6 files changed

+115
-49
lines changed

6 files changed

+115
-49
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ readme = "README.md"
99
repository = "https://github.com/blackbeam/mysql_async"
1010
version = "0.34.0"
1111
exclude = ["test/*"]
12-
edition = "2018"
12+
edition = "2021"
1313
categories = ["asynchronous", "database"]
1414

1515
[dependencies]

src/error/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
pub use url::ParseError;
1010

11-
mod tls;
11+
pub mod tls;
1212

1313
use mysql_common::{
1414
named_params::MixedParamsError, params::MissingNamedParameterError,

src/error/tls/rustls_error.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
use std::fmt::Display;
44

5+
use rustls::server::VerifierBuilderError;
6+
57
#[derive(Debug)]
68
pub enum TlsError {
79
Tls(rustls::Error),
810
Pki(webpki::Error),
911
InvalidDnsName(webpki::InvalidDnsNameError),
12+
VerifierBuilderError(VerifierBuilderError),
1013
}
1114

1215
impl From<TlsError> for crate::Error {
@@ -15,6 +18,12 @@ impl From<TlsError> for crate::Error {
1518
}
1619
}
1720

21+
impl From<VerifierBuilderError> for TlsError {
22+
fn from(e: VerifierBuilderError) -> Self {
23+
TlsError::VerifierBuilderError(e)
24+
}
25+
}
26+
1827
impl From<rustls::Error> for TlsError {
1928
fn from(e: rustls::Error) -> Self {
2029
TlsError::Tls(e)
@@ -57,6 +66,7 @@ impl std::error::Error for TlsError {
5766
TlsError::Tls(e) => Some(e),
5867
TlsError::Pki(e) => Some(e),
5968
TlsError::InvalidDnsName(e) => Some(e),
69+
TlsError::VerifierBuilderError(e) => Some(e),
6070
}
6171
}
6272
}
@@ -67,6 +77,7 @@ impl Display for TlsError {
6777
TlsError::Tls(e) => e.fmt(f),
6878
TlsError::Pki(e) => e.fmt(f),
6979
TlsError::InvalidDnsName(e) => e.fmt(f),
80+
TlsError::VerifierBuilderError(e) => e.fmt(f),
7081
}
7182
}
7283
}

src/io/tls/rustls_io.rs

Lines changed: 85 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
#![cfg(feature = "rustls-tls")]
22

3-
use std::{convert::TryInto, sync::Arc};
3+
use std::sync::Arc;
44

55
use rustls::{
6-
client::{ServerCertVerifier, WebPkiVerifier},
7-
Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore,
6+
client::{
7+
danger::{ServerCertVerified, ServerCertVerifier},
8+
WebPkiServerVerifier,
9+
},
10+
pki_types::{CertificateDer, ServerName},
11+
ClientConfig, RootCertStore,
812
};
913

1014
use rustls_pemfile::certs;
1115
use tokio_rustls::TlsConnector;
1216

13-
use crate::{io::Endpoint, Result, SslOpts};
17+
use crate::{io::Endpoint, Result, SslOpts, TlsError};
1418

1519
impl SslOpts {
16-
async fn load_root_certs(&self) -> crate::Result<Vec<Certificate>> {
20+
async fn load_root_certs(&self) -> crate::Result<Vec<CertificateDer<'static>>> {
1721
let mut output = Vec::new();
1822

1923
for root_cert in self.root_certs() {
2024
let root_cert_data = root_cert.read().await?;
2125
let mut seen = false;
22-
for cert in certs(&mut &*root_cert_data)? {
26+
for cert in certs(&mut &*root_cert_data) {
2327
seen = true;
24-
output.push(Certificate(cert));
28+
output.push(cert?);
2529
}
2630

2731
if !seen && !root_cert_data.is_empty() {
28-
output.push(Certificate(root_cert_data.into_owned()));
32+
output.push(CertificateDer::from(root_cert_data.into_owned()));
2933
}
3034
}
3135

@@ -42,21 +46,13 @@ impl Endpoint {
4246
}
4347

4448
let mut root_store = RootCertStore::empty();
45-
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
46-
OwnedTrustAnchor::from_subject_spki_name_constraints(
47-
ta.subject,
48-
ta.spki,
49-
ta.name_constraints,
50-
)
51-
}));
49+
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned()));
5250

5351
for cert in ssl_opts.load_root_certs().await? {
54-
root_store.add(&cert)?;
52+
root_store.add(cert)?;
5553
}
5654

57-
let config_builder = ClientConfig::builder()
58-
.with_safe_defaults()
59-
.with_root_certificates(root_store.clone());
55+
let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone());
6056

6157
let mut config = if let Some(identity) = ssl_opts.client_identity() {
6258
let (cert_chain, priv_key) = identity.load().await?;
@@ -65,12 +61,13 @@ impl Endpoint {
6561
config_builder.with_no_client_auth()
6662
};
6763

68-
let server_name = domain
69-
.as_str()
70-
.try_into()
71-
.map_err(|_| webpki::InvalidDnsNameError)?;
64+
let server_name = ServerName::try_from(domain.as_str())
65+
.map_err(|_| webpki::InvalidDnsNameError)?
66+
.to_owned();
7267
let mut dangerous = config.dangerous();
73-
let web_pki_verifier = WebPkiVerifier::new(root_store, None);
68+
let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store))
69+
.build()
70+
.map_err(TlsError::from)?;
7471
let dangerous_verifier = DangerousVerifier::new(
7572
ssl_opts.accept_invalid_certs(),
7673
ssl_opts.skip_domain_validation(),
@@ -97,17 +94,18 @@ impl Endpoint {
9794
}
9895
}
9996

97+
#[derive(Debug)]
10098
struct DangerousVerifier {
10199
accept_invalid_certs: bool,
102100
skip_domain_validation: bool,
103-
verifier: WebPkiVerifier,
101+
verifier: Arc<WebPkiServerVerifier>,
104102
}
105103

106104
impl DangerousVerifier {
107105
fn new(
108106
accept_invalid_certs: bool,
109107
skip_domain_validation: bool,
110-
verifier: WebPkiVerifier,
108+
verifier: Arc<WebPkiServerVerifier>,
111109
) -> Self {
112110
Self {
113111
accept_invalid_certs,
@@ -118,34 +116,86 @@ impl DangerousVerifier {
118116
}
119117

120118
impl ServerCertVerifier for DangerousVerifier {
119+
// fn verify_server_cert(
120+
// &self,
121+
// end_entity: &Certificate,
122+
// intermediates: &[Certificate],
123+
// server_name: &rustls::ServerName,
124+
// scts: &mut dyn Iterator<Item = &[u8]>,
125+
// ocsp_response: &[u8],
126+
// now: std::time::SystemTime,
127+
// ) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
128+
// if self.accept_invalid_certs {
129+
// Ok(rustls::client::ServerCertVerified::assertion())
130+
// } else {
131+
// match self.verifier.verify_server_cert(
132+
// end_entity,
133+
// intermediates,
134+
// server_name,
135+
// scts,
136+
// ocsp_response,
137+
// now,
138+
// ) {
139+
// Ok(assertion) => Ok(assertion),
140+
// Err(ref e)
141+
// if e.to_string().contains("NotValidForName") && self.skip_domain_validation =>
142+
// {
143+
// Ok(rustls::client::ServerCertVerified::assertion())
144+
// }
145+
// Err(e) => Err(e),
146+
// }
147+
// }
148+
// }
121149
fn verify_server_cert(
122150
&self,
123-
end_entity: &Certificate,
124-
intermediates: &[Certificate],
125-
server_name: &rustls::ServerName,
126-
scts: &mut dyn Iterator<Item = &[u8]>,
151+
end_entity: &CertificateDer<'_>,
152+
intermediates: &[CertificateDer<'_>],
153+
server_name: &rustls::pki_types::ServerName<'_>,
127154
ocsp_response: &[u8],
128-
now: std::time::SystemTime,
129-
) -> std::result::Result<rustls::client::ServerCertVerified, rustls::Error> {
155+
now: rustls::pki_types::UnixTime,
156+
) -> std::prelude::v1::Result<ServerCertVerified, rustls::Error> {
130157
if self.accept_invalid_certs {
131-
Ok(rustls::client::ServerCertVerified::assertion())
158+
Ok(ServerCertVerified::assertion())
132159
} else {
133160
match self.verifier.verify_server_cert(
134161
end_entity,
135162
intermediates,
136163
server_name,
137-
scts,
138164
ocsp_response,
139165
now,
140166
) {
141167
Ok(assertion) => Ok(assertion),
142168
Err(ref e)
143169
if e.to_string().contains("NotValidForName") && self.skip_domain_validation =>
144170
{
145-
Ok(rustls::client::ServerCertVerified::assertion())
171+
Ok(ServerCertVerified::assertion())
146172
}
147173
Err(e) => Err(e),
148174
}
149175
}
150176
}
177+
178+
fn verify_tls12_signature(
179+
&self,
180+
message: &[u8],
181+
cert: &CertificateDer<'_>,
182+
dss: &rustls::DigitallySignedStruct,
183+
) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
184+
{
185+
self.verifier.verify_tls12_signature(message, cert, dss)
186+
}
187+
188+
fn verify_tls13_signature(
189+
&self,
190+
message: &[u8],
191+
cert: &CertificateDer<'_>,
192+
dss: &rustls::DigitallySignedStruct,
193+
) -> std::prelude::v1::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error>
194+
{
195+
self.verifier.verify_tls13_signature(message, cert, dss)
196+
}
197+
198+
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
199+
self.verifier.supported_verify_schemes()
200+
}
151201
}

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,8 @@ pub use self::conn::pool::Pool;
466466

467467
#[doc(inline)]
468468
pub use self::error::{
469-
DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError,
469+
tls::TlsError, DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError,
470+
UrlError,
470471
};
471472

472473
#[doc(inline)]

src/opts/rustls_opts.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![cfg(feature = "rustls-tls")]
22

3-
use rustls::{Certificate, PrivateKey};
3+
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer};
44
use rustls_pemfile::{certs, rsa_private_keys};
55

66
use std::{borrow::Cow, path::Path};
@@ -50,27 +50,31 @@ impl ClientIdentity {
5050
self.priv_key.borrow()
5151
}
5252

53-
pub(crate) async fn load(&self) -> crate::Result<(Vec<Certificate>, PrivateKey)> {
53+
pub(crate) async fn load(
54+
&self,
55+
) -> crate::Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
5456
let cert_data = self.cert_chain.read().await?;
5557
let key_data = self.priv_key.read().await?;
5658

5759
let mut cert_chain = Vec::new();
5860
if std::str::from_utf8(&cert_data).is_err() {
59-
cert_chain.push(Certificate(cert_data.into_owned()));
61+
cert_chain.push(CertificateDer::from(cert_data.into_owned()));
6062
} else {
61-
for cert in certs(&mut &*cert_data)? {
62-
cert_chain.push(Certificate(cert));
63+
for cert in certs(&mut &*cert_data) {
64+
cert_chain.push(cert?);
6365
}
6466
}
6567

6668
let priv_key = if std::str::from_utf8(&key_data).is_err() {
67-
Some(PrivateKey(key_data.into_owned()))
69+
Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(
70+
key_data.into_owned(),
71+
)))
6872
} else {
69-
rsa_private_keys(&mut &*key_data)?
70-
.into_iter()
71-
.take(1)
72-
.map(PrivateKey)
73-
.next()
73+
let mut priv_key = None;
74+
for key in rsa_private_keys(&mut &*key_data).take(1) {
75+
priv_key = Some(PrivateKeyDer::Pkcs1(key?.clone_key()));
76+
}
77+
priv_key
7478
};
7579

7680
Ok((

0 commit comments

Comments
 (0)