From f3652b31acb22ab9e8a401b9e99edd7a863c4ee9 Mon Sep 17 00:00:00 2001 From: Jaroslav Beran Date: Thu, 23 Jan 2025 17:17:52 +0100 Subject: [PATCH] Update rustls to version 0.23 Updates the code to compile with newer versions of `rustls` and `rustls_pemfile`. Replaces `lazy_static` with `std::sync::LazyLock`, because `lazy_static` could not handle `async move` construct for unknown reason, and it's recommended to replace it anyways. Fixes some warnings generated by `cargo clippy`, mostly lifetime elisions. I didn't fix the warnings regarding the large size difference between variants of enums, because I wasn't sure about the correct fix. The `Poll:Ok` would probably still return the large data instead of the Box anyways. --- Cargo.toml | 11 ++-- src/common/tls_state.rs | 10 +--- src/connector.rs | 21 +++----- src/rusttls/stream.rs | 10 ++-- src/rusttls/test_stream.rs | 28 +++++----- tests/test.rs | 102 +++++++++++++++++++------------------ 6 files changed, 88 insertions(+), 94 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 94a1969..33b23d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-tls" -version = "0.13.0" +version = "0.14.0" authors = [ "The async-rs developers", "Florian Gilcher ", @@ -23,11 +23,11 @@ appveyor = { repository = "async-std/async-tls" } [dependencies] futures-io = "0.3.5" futures-core = "0.3.5" -rustls = "0.21" -rustls-pemfile = "1.0" +rustls = "0.23.21" +rustls-pemfile = "2.2" # webpki = { version = "0.22.0", optional = true } -rustls-webpki = { version = "0.101.4", optional = true } -webpki-roots = { version = "0.22.3", optional = true } +rustls-webpki = { version = "0.102", optional = true } +webpki-roots = { version = "0.26", optional = true } [features] default = ["client", "server"] @@ -36,7 +36,6 @@ early-data = [] server = [] [dev-dependencies] -lazy_static = "1" futures-executor = "0.3.5" futures-util = { version = "0.3.5", features = ["io"] } async-std = { version = "1.11", features = ["unstable"] } diff --git a/src/common/tls_state.rs b/src/common/tls_state.rs index 276d334..d01f6e5 100644 --- a/src/common/tls_state.rs +++ b/src/common/tls_state.rs @@ -24,16 +24,10 @@ impl TlsState { } pub(crate) fn writeable(&self) -> bool { - match *self { - TlsState::WriteShutdown | TlsState::FullyShutdown => false, - _ => true, - } + !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) } pub(crate) fn readable(self) -> bool { - match self { - TlsState::ReadShutdown | TlsState::FullyShutdown => false, - _ => true, - } + !matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown) } } diff --git a/src/connector.rs b/src/connector.rs index 6f674ca..aac8f5f 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -3,7 +3,8 @@ use crate::common::tls_state::TlsState; use crate::client; use futures_io::{AsyncRead, AsyncWrite}; -use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName}; +use rustls::pki_types::ServerName; +use rustls::{ClientConfig, ClientConnection, RootCertStore}; use std::convert::TryFrom; use std::future::Future; use std::io; @@ -64,16 +65,10 @@ impl From for TlsConnector { impl Default for TlsConnector { fn default() -> Self { - let mut root_certs = RootCertStore::empty(); - root_certs.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + let root_certs = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }; let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_certs) .with_no_client_auth(); Arc::new(config).into() @@ -103,7 +98,7 @@ impl TlsConnector { /// The function will return a `Connect` Future, representing the connecting part of a Tls /// handshake. It will resolve when the handshake is over. #[inline] - pub fn connect<'a, IO>(&self, domain: impl AsRef, stream: IO) -> Connect + pub fn connect(&self, domain: impl AsRef, stream: IO) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, { @@ -112,12 +107,12 @@ impl TlsConnector { // NOTE: Currently private, exposing ClientConnection exposes rusttls // Early data should be exposed differently - fn connect_with<'a, IO, F>(&self, domain: impl AsRef, stream: IO, f: F) -> Connect + fn connect_with(&self, domain: impl AsRef, stream: IO, f: F) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientConnection), { - let domain = match ServerName::try_from(domain.as_ref()) { + let domain = match ServerName::try_from(domain.as_ref().to_owned()) { Ok(domain) => domain, Err(_) => { return Connect(ConnectInner::Error(Some(io::Error::new( diff --git a/src/rusttls/stream.rs b/src/rusttls/stream.rs index bee787d..2629884 100644 --- a/src/rusttls/stream.rs +++ b/src/rusttls/stream.rs @@ -153,7 +153,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> Stream<'a, IO> { cx: &'a mut Context<'b>, } - impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { + impl Read for Reader<'_, '_, T> { fn read(&mut self, buf: &mut [u8]) -> io::Result { match Pin::new(&mut self.io).poll_read(self.cx, buf) { Poll::Ready(result) => result, @@ -253,7 +253,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> Stream<'a, IO> { } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls for Stream<'a, IO> { +impl WriteTls for Stream<'_, IO> { fn write_tls(&mut self, cx: &mut Context) -> io::Result { // TODO writev @@ -262,7 +262,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls for Stream<'a, IO> { cx: &'a mut Context<'b>, } - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { + impl Write for Writer<'_, '_, T> { fn write(&mut self, buf: &[u8]) -> io::Result { match Pin::new(&mut self.io).poll_write(self.cx, buf) { Poll::Ready(result) => result, @@ -283,7 +283,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls for Stream<'a, IO> { } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'a, IO> { +impl AsyncRead for Stream<'_, IO> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context, @@ -312,7 +312,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'a, IO> { } } -impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<'a, IO> { +impl AsyncWrite for Stream<'_, IO> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { let this = self.get_mut(); diff --git a/src/rusttls/test_stream.rs b/src/rusttls/test_stream.rs index c8dc175..0f6112d 100644 --- a/src/rusttls/test_stream.rs +++ b/src/rusttls/test_stream.rs @@ -4,9 +4,10 @@ use futures_io::{AsyncRead, AsyncWrite}; use futures_util::io::{AsyncReadExt, AsyncWriteExt}; use futures_util::task::{noop_waker_ref, Context}; use futures_util::{future, ready}; +use rustls::pki_types::{PrivateKeyDer, ServerName}; use rustls::{ - Certificate, ClientConfig, ClientConnection, ConnectionCommon, PrivateKey, RootCertStore, - ServerConfig, ServerConnection, ServerName, + ClientConfig, ClientConnection, ConnectionCommon, RootCertStore, + ServerConfig, ServerConnection, }; use rustls_pemfile::{certs, pkcs8_private_keys}; use std::convert::TryFrom; @@ -17,7 +18,7 @@ use std::task::Poll; struct Good<'a, D>(&'a mut ConnectionCommon); -impl<'a, D> AsyncRead for Good<'a, D> { +impl AsyncRead for Good<'_, D> { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -27,7 +28,7 @@ impl<'a, D> AsyncRead for Good<'a, D> { } } -impl<'a, D> AsyncWrite for Good<'a, D> { +impl AsyncWrite for Good<'_, D> { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -223,12 +224,14 @@ fn make_pair() -> (ServerConnection, ClientConnection) { const CHAIN: &str = include_str!("../../tests/end.chain"); const RSA: &str = include_str!("../../tests/end.rsa"); - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let cert = cert.into_iter().map(Certificate).collect(); - let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let key = PrivateKey(keys.pop().unwrap()); + let cert = certs(&mut BufReader::new(Cursor::new(CERT))) + .collect::,_>>() + .unwrap(); + let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))) + .collect::,_>>() + .unwrap(); + let key = PrivateKeyDer::Pkcs8(keys.pop().unwrap()); let sconfig = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(cert, key) .unwrap(); @@ -236,11 +239,12 @@ fn make_pair() -> (ServerConnection, ClientConnection) { let domain = ServerName::try_from("localhost").unwrap(); let mut root_store = RootCertStore::empty(); - let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))).unwrap(); - let (added, ignored) = root_store.add_parsable_certificates(&chain); + let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))) + .collect::,_>>() + .unwrap(); + let (added, ignored) = root_store.add_parsable_certificates(chain); assert!(added >= 1 && ignored == 0); let cconfig = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let client = ClientConnection::new(Arc::new(cconfig), domain); diff --git a/tests/test.rs b/tests/test.rs index c7b5997..c074450 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -4,61 +4,65 @@ use async_std::net::{TcpListener, TcpStream}; use async_std::prelude::*; use async_std::task; use async_tls::{TlsAcceptor, TlsConnector}; -use lazy_static::lazy_static; -use rustls::{Certificate, ClientConfig, PrivateKey, RootCertStore, ServerConfig}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::{ClientConfig, RootCertStore, ServerConfig}; use rustls_pemfile::{certs, pkcs8_private_keys}; use std::io::{BufReader, Cursor}; use std::net::SocketAddr; use std::sync::Arc; +use std::sync::LazyLock; const CERT: &str = include_str!("end.cert"); const CHAIN: &str = include_str!("end.chain"); const RSA: &str = include_str!("end.rsa"); -lazy_static! { - static ref TEST_SERVER: (SocketAddr, &'static str, Vec>) = { - let cert = certs(&mut BufReader::new(Cursor::new(CERT))).unwrap(); - let cert = cert.into_iter().map(Certificate).collect(); - let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))).unwrap(); - let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let key = PrivateKey(keys.pop().unwrap()); - let sconfig = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(cert, key) - .unwrap(); - let acceptor = TlsAcceptor::from(Arc::new(sconfig)); - - let (send, recv) = bounded(1); - - task::spawn(async move { - let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - let listener = TcpListener::bind(&addr).await?; - - send.send(listener.local_addr()?).await.unwrap(); - - let mut incoming = listener.incoming(); - while let Some(stream) = incoming.next().await { - let acceptor = acceptor.clone(); - task::spawn(async move { - use futures_util::io::AsyncReadExt; - let stream = acceptor.accept(stream?).await?; - let (mut reader, mut writer) = stream.split(); - io::copy(&mut reader, &mut writer).await?; - Ok(()) as io::Result<()> - }); - } - - Ok(()) as io::Result<()> - }); - - let addr = task::block_on(async move { recv.recv().await.unwrap() }); - (addr, "localhost", chain) - }; -} - -fn start_server() -> &'static (SocketAddr, &'static str, Vec>) { - &*TEST_SERVER +static TEST_SERVER: LazyLock<(SocketAddr, &'static str, Vec>)> = LazyLock::new(|| { + let cert = certs(&mut BufReader::new(Cursor::new(CERT))) + .collect::,_>>() + .unwrap(); + let chain = certs(&mut BufReader::new(Cursor::new(CHAIN))) + .collect::,_>>() + .unwrap(); + let mut keys = pkcs8_private_keys(&mut BufReader::new(Cursor::new(RSA))) + .map(|res| res.map(PrivateKeyDer::Pkcs8)) + .collect::,_>>() + .unwrap(); + let key = keys.pop().unwrap(); + let sconfig = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert, key) + .unwrap(); + let acceptor = TlsAcceptor::from(Arc::new(sconfig)); + + let (send, recv) = bounded(1); + + task::spawn(async move { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(&addr).await?; + + send.send(listener.local_addr()?).await.unwrap(); + + let mut incoming = listener.incoming(); + while let Some(stream) = incoming.next().await { + let acceptor = acceptor.clone(); + task::spawn(async move { + use futures_util::io::AsyncReadExt; + let stream = acceptor.accept(stream?).await?; + let (mut reader, mut writer) = stream.split(); + io::copy(&mut reader, &mut writer).await?; + Ok(()) as io::Result<()> + }); + } + + Ok(()) as io::Result<()> + }); + + let addr = task::block_on(async { recv.recv().await.unwrap() }); + (addr, "localhost", chain) +}); + +fn start_server() -> &'static (SocketAddr, &'static str, Vec>) { + &TEST_SERVER } async fn start_client(addr: SocketAddr, domain: &str, config: Arc) -> io::Result<()> { @@ -82,10 +86,9 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc) fn pass() { let (addr, domain, chain) = start_server(); let mut root_store = RootCertStore::empty(); - let (added, ignored) = root_store.add_parsable_certificates(&chain); + let (added, ignored) = root_store.add_parsable_certificates(chain.clone()); assert!(added >= 1 && ignored == 0); let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); task::block_on(start_client(*addr, domain, Arc::new(config))).unwrap(); @@ -95,10 +98,9 @@ fn pass() { fn fail() { let (addr, domain, chain) = start_server(); let mut root_store = RootCertStore::empty(); - let (added, ignored) = root_store.add_parsable_certificates(&chain); + let (added, ignored) = root_store.add_parsable_certificates(chain.clone()); assert!(added >= 1 && ignored == 0); let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); let config = Arc::new(config);