Skip to content

Commit f7dd681

Browse files
fix early-data implementation
1 parent c58588a commit f7dd681

File tree

3 files changed

+114
-14
lines changed

3 files changed

+114
-14
lines changed

src/client.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ where
163163
let (pos, data) = &mut this.early_data;
164164

165165
// write early data
166-
if let Some(mut early_data) = stream.session.early_data() {
166+
if let Some(mut early_data) = stream.conn.client_early_data() {
167167
let len = match early_data.write(buf) {
168168
Ok(n) => n,
169169
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
@@ -176,12 +176,12 @@ where
176176
}
177177

178178
// complete handshake
179-
if stream.session.is_handshaking() {
179+
if stream.conn.is_handshaking() {
180180
ready!(stream.complete_io(cx))?;
181181
}
182182

183183
// write early data (fallback)
184-
if !stream.session.is_early_data_accepted() {
184+
if !stream.conn.is_early_data_accepted() {
185185
while *pos < data.len() {
186186
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
187187
*pos += len;

src/rusttls/stream.rs

+109-9
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,117 @@
11
use futures_core::ready;
22
use futures_io::{AsyncRead, AsyncWrite};
3-
use rustls::ConnectionCommon;
3+
#[cfg(feature = "early-data")]
4+
use rustls::client::WriteEarlyData;
5+
use rustls::{ClientConnection, IoState, Reader, ServerConnection, Writer};
46
use std::io::{self, Read, Write};
57
use std::marker::Unpin;
68
use std::pin::Pin;
79
use std::task::{Context, Poll};
810

9-
pub struct Stream<'a, IO, D> {
11+
pub struct Stream<'a, IO> {
1012
pub io: &'a mut IO,
11-
pub conn: &'a mut ConnectionCommon<D>,
13+
pub conn: Conn<'a>,
1214
pub eof: bool,
1315
}
1416

17+
pub(crate) enum Conn<'a> {
18+
Client(&'a mut ClientConnection),
19+
Server(&'a mut ServerConnection),
20+
}
21+
22+
impl Conn<'_> {
23+
pub(crate) fn is_handshaking(&self) -> bool {
24+
match self {
25+
Conn::Client(c) => c.is_handshaking(),
26+
Conn::Server(c) => c.is_handshaking(),
27+
}
28+
}
29+
30+
pub(crate) fn wants_write(&self) -> bool {
31+
match self {
32+
Conn::Client(c) => c.wants_write(),
33+
Conn::Server(c) => c.wants_write(),
34+
}
35+
}
36+
37+
pub(crate) fn wants_read(&self) -> bool {
38+
match self {
39+
Conn::Client(c) => c.wants_read(),
40+
Conn::Server(c) => c.wants_read(),
41+
}
42+
}
43+
44+
pub(crate) fn write_tls(&mut self, wr: &mut dyn io::Write) -> Result<usize, io::Error> {
45+
match self {
46+
Conn::Client(c) => c.write_tls(wr),
47+
Conn::Server(c) => c.write_tls(wr),
48+
}
49+
}
50+
51+
pub(crate) fn reader(&mut self) -> Reader {
52+
match self {
53+
Conn::Client(c) => c.reader(),
54+
Conn::Server(c) => c.reader(),
55+
}
56+
}
57+
58+
pub(crate) fn writer(&mut self) -> Writer {
59+
match self {
60+
Conn::Client(c) => c.writer(),
61+
Conn::Server(c) => c.writer(),
62+
}
63+
}
64+
65+
pub(crate) fn send_close_notify(&mut self) {
66+
match self {
67+
Conn::Client(c) => c.send_close_notify(),
68+
Conn::Server(c) => c.send_close_notify(),
69+
}
70+
}
71+
72+
pub(crate) fn read_tls(&mut self, rd: &mut dyn io::Read) -> Result<usize, io::Error> {
73+
match self {
74+
Conn::Client(c) => c.read_tls(rd),
75+
Conn::Server(c) => c.read_tls(rd),
76+
}
77+
}
78+
79+
pub(crate) fn process_new_packets(&mut self) -> Result<IoState, rustls::Error> {
80+
match self {
81+
Conn::Client(c) => c.process_new_packets(),
82+
Conn::Server(c) => c.process_new_packets(),
83+
}
84+
}
85+
86+
#[cfg(feature = "early-data")]
87+
pub(crate) fn is_early_data_accepted(&self) -> bool {
88+
match self {
89+
Conn::Client(c) => c.is_early_data_accepted(),
90+
Conn::Server(_) => false,
91+
}
92+
}
93+
94+
#[cfg(feature = "early-data")]
95+
pub(crate) fn client_early_data(&mut self) -> Option<WriteEarlyData<'_>> {
96+
match self {
97+
Conn::Client(c) => c.early_data(),
98+
Conn::Server(_) => None,
99+
}
100+
}
101+
}
102+
103+
impl<'a> From<&'a mut ClientConnection> for Conn<'a> {
104+
fn from(conn: &'a mut ClientConnection) -> Self {
105+
Conn::Client(conn)
106+
}
107+
}
108+
109+
impl<'a> From<&'a mut ServerConnection> for Conn<'a> {
110+
fn from(conn: &'a mut ServerConnection) -> Self {
111+
Conn::Server(conn)
112+
}
113+
}
114+
15115
trait WriteTls<IO: AsyncWrite> {
16116
fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize>;
17117
}
@@ -23,11 +123,11 @@ enum Focus {
23123
Writable,
24124
}
25125

26-
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> Stream<'a, IO, D> {
27-
pub fn new(io: &'a mut IO, conn: &'a mut ConnectionCommon<D>) -> Self {
126+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> Stream<'a, IO> {
127+
pub fn new(io: &'a mut IO, conn: impl Into<Conn<'a>>) -> Self {
28128
Stream {
29129
io,
30-
conn,
130+
conn: conn.into(),
31131
// The state so far is only used to detect EOF, so either Stream
32132
// or EarlyData state should both be all right.
33133
eof: false,
@@ -153,7 +253,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> Stream<'a, IO, D> {
153253
}
154254
}
155255

156-
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> WriteTls<IO> for Stream<'a, IO, D> {
256+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> WriteTls<IO> for Stream<'a, IO> {
157257
fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize> {
158258
// TODO writev
159259

@@ -183,7 +283,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> WriteTls<IO> for Stream<'
183283
}
184284
}
185285

186-
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> AsyncRead for Stream<'a, IO, D> {
286+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<'a, IO> {
187287
fn poll_read(
188288
self: Pin<&mut Self>,
189289
cx: &mut Context,
@@ -212,7 +312,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> AsyncRead for Stream<'a,
212312
}
213313
}
214314

215-
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, D: Unpin> AsyncWrite for Stream<'a, IO, D> {
315+
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<'a, IO> {
216316
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
217317
let this = self.get_mut();
218318

src/test_0rtt.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ async fn get(
2929
#[test]
3030
fn test_0rtt() {
3131
let mut root_certs = RootCertStore::empty();
32-
root_certs.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
32+
root_certs.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
3333
OwnedTrustAnchor::from_subject_spki_name_constraints(
3434
ta.subject,
3535
ta.spki,
3636
ta.name_constraints,
3737
)
3838
}));
39-
let config = ClientConfig::builder()
39+
let mut config = ClientConfig::builder()
4040
.with_safe_defaults()
4141
.with_root_certificates(root_certs)
4242
.with_no_client_auth();

0 commit comments

Comments
 (0)