Skip to content
This repository was archived by the owner on May 21, 2025. It is now read-only.

Commit 4f16b82

Browse files
committed
allow programs to bring their own transport to the connection
1 parent 1b5e73c commit 4f16b82

File tree

10 files changed

+130
-25
lines changed

10 files changed

+130
-25
lines changed

examples/custom_transport.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use futures_util::{Sink, SinkExt, Stream, StreamExt, TryStreamExt};
2+
use rustls::{ClientConfig, KeyLogFile, RootCertStore};
3+
use std::error::Error;
4+
use std::future::ready;
5+
use std::sync::Arc;
6+
use steam_vent::connection::UnAuthenticatedConnection;
7+
use steam_vent::message::flatten_multi;
8+
use steam_vent::{NetworkError, RawNetMessage, ServerList};
9+
use tokio_tungstenite::tungstenite::Message as WsMessage;
10+
use tokio_tungstenite::{connect_async_tls_with_config, Connector};
11+
12+
#[tokio::main]
13+
async fn main() -> Result<(), Box<dyn Error>> {
14+
tracing_subscriber::fmt::init();
15+
16+
let server_list = ServerList::discover().await?;
17+
let (sender, receiver) = connect(&server_list.pick_ws()).await?;
18+
let connection = UnAuthenticatedConnection::from_sender_receiver(sender, receiver).await?;
19+
let _connection = connection.anonymous().await?;
20+
21+
Ok(())
22+
}
23+
24+
// this is just a copy of the standard websocket transport implementation, functioning as an example
25+
// how to implement a websocket transport
26+
pub async fn connect(
27+
addr: &str,
28+
) -> Result<
29+
(
30+
impl Sink<RawNetMessage, Error = NetworkError>,
31+
impl Stream<Item = Result<RawNetMessage, NetworkError>>,
32+
),
33+
NetworkError,
34+
> {
35+
rustls::crypto::aws_lc_rs::default_provider()
36+
.install_default()
37+
.ok(); // can only be once called
38+
let mut root_store = RootCertStore::empty();
39+
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
40+
let mut tls_config = ClientConfig::builder()
41+
.with_root_certificates(root_store)
42+
.with_no_client_auth();
43+
tls_config.key_log = Arc::new(KeyLogFile::new());
44+
let tls_config = Connector::Rustls(Arc::new(tls_config));
45+
let (stream, _) = connect_async_tls_with_config(addr, None, false, Some(tls_config)).await?;
46+
let (raw_write, raw_read) = stream.split();
47+
48+
Ok((
49+
raw_write.with(|msg: RawNetMessage| ready(Ok(WsMessage::binary(msg.into_bytes())))),
50+
flatten_multi(
51+
raw_read
52+
.map_err(NetworkError::from)
53+
.map_ok(|raw| raw.into_data())
54+
.map(|res| res.and_then(RawNetMessage::read)),
55+
),
56+
))
57+
}

protobuf/common/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub trait RpcMessageWithKind: RpcMessage {
4242
const KIND: Self::KindEnum;
4343
}
4444

45+
/// A generic wrapper for "kind" constants used by network messages
4546
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
4647
pub struct MsgKind(pub i32);
4748

@@ -59,6 +60,10 @@ impl From<MsgKind> for i32 {
5960

6061
pub const PROTO_MASK: u32 = 0x80000000;
6162

63+
/// An enum containing "kind" constants used by the network messages
64+
///
65+
/// Though it is possible to use the generic [`MsgKind`] struct. Applications shipping their own protobufs
66+
/// are encouraged to create their own enums containing the constants in use for ease of use.
6267
pub trait MsgKindEnum: Enum + Debug {
6368
fn enum_value(&self) -> i32 {
6469
<Self as Enum>::value(self)

src/connection/filter.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::net::{JobId, RawNetMessage};
33
use dashmap::DashMap;
44
use futures_util::Stream;
55
use std::collections::VecDeque;
6+
use std::pin::pin;
67
use std::sync::{Arc, Mutex};
78
use steam_vent_proto::enums_clientserver::EMsg;
89
use steam_vent_proto::MsgKind;
@@ -59,10 +60,8 @@ pub struct MessageFilter {
5960
}
6061

6162
impl MessageFilter {
62-
pub fn new<
63-
Input: Stream<Item = crate::connection::Result<RawNetMessage>> + Send + Unpin + 'static,
64-
>(
65-
mut source: Input,
63+
pub fn new<Input: Stream<Item = crate::connection::Result<RawNetMessage>> + Send + 'static>(
64+
source: Input,
6665
) -> Self {
6766
let filter = MessageFilter {
6867
job_id_filters: Default::default(),
@@ -75,6 +74,7 @@ impl MessageFilter {
7574

7675
let filter_send = filter.clone();
7776
spawn(async move {
77+
let mut source = pin!(source);
7878
while let Some(res) = source.next().await {
7979
match res {
8080
Ok(message) => {

src/connection/raw.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use crate::message::EncodableMessage;
44
use crate::net::{NetMessageHeader, RawNetMessage};
55
use crate::session::{hello, Session};
66
use crate::transport::websocket::connect;
7-
use crate::{ConnectionError, ServerList};
7+
use crate::{ConnectionError, NetworkError, ServerList};
8+
use futures_util::{Sink, Stream};
89
use std::fmt::{Debug, Formatter};
910
use std::sync::Arc;
1011
use std::time::Duration;
@@ -34,14 +35,24 @@ impl Debug for RawConnection {
3435

3536
impl RawConnection {
3637
pub async fn connect(server_list: &ServerList) -> Result<Self, ConnectionError> {
37-
let (read, write) = connect(&server_list.pick_ws()).await?;
38-
let filter = MessageFilter::new(read);
38+
let (sender, receiver) = connect(&server_list.pick_ws()).await?;
39+
Self::from_sender_receiver(sender, receiver).await
40+
}
41+
42+
pub async fn from_sender_receiver<
43+
Sender: Sink<RawNetMessage, Error = NetworkError> + Send + 'static,
44+
Receiver: Stream<Item = Result<RawNetMessage>> + Send + 'static,
45+
>(
46+
sender: Sender,
47+
receiver: Receiver,
48+
) -> Result<Self, ConnectionError> {
49+
let filter = MessageFilter::new(receiver);
3950
let heartbeat_cancellation_token = CancellationToken::new();
4051
let mut connection = RawConnection {
4152
session: Session::default(),
4253
filter,
4354
sender: MessageSender {
44-
write: Arc::new(Mutex::new(write)),
55+
write: Arc::new(Mutex::new(Box::pin(sender))),
4556
},
4657
timeout: Duration::from_secs(10),
4758
heartbeat_cancellation_token: heartbeat_cancellation_token.clone(),

src/connection/unauthenticated.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use crate::service_method::ServiceMethodRequest;
77
use crate::session::{anonymous, login};
88
use crate::{Connection, ConnectionError, NetMessage, NetworkError, ServerList};
99
use futures_util::future::{select, Either};
10-
use futures_util::FutureExt;
1110
use futures_util::Stream;
11+
use futures_util::{FutureExt, Sink};
1212
use std::future::Future;
1313
use std::pin::pin;
1414
use steam_vent_proto::enums_clientserver::EMsg;
@@ -22,12 +22,30 @@ use tracing::{debug, error};
2222
pub struct UnAuthenticatedConnection(RawConnection);
2323

2424
impl UnAuthenticatedConnection {
25+
/// Create a connection from a sender, receiver pair.
26+
///
27+
/// This allows customizing the transport used by the connection. For example to customize the
28+
/// TLS configuration, use an existing websocket client or use a proxy.
29+
pub async fn from_sender_receiver<
30+
Sender: Sink<RawNetMessage, Error = NetworkError> + Send + 'static,
31+
Receiver: Stream<Item = Result<RawNetMessage>> + Send + 'static,
32+
>(
33+
sender: Sender,
34+
receiver: Receiver,
35+
) -> Result<Self, ConnectionError> {
36+
Ok(UnAuthenticatedConnection(
37+
RawConnection::from_sender_receiver(sender, receiver).await?,
38+
))
39+
}
40+
41+
/// Connect to a server from the server list using the default websocket transport
2542
pub async fn connect(server_list: &ServerList) -> Result<Self, ConnectionError> {
2643
Ok(UnAuthenticatedConnection(
2744
RawConnection::connect(server_list).await?,
2845
))
2946
}
3047

48+
/// Start an anonymous client session with this connection
3149
pub async fn anonymous(self) -> Result<Connection, ConnectionError> {
3250
let mut raw = self.0;
3351
raw.session = anonymous(&raw, AccountType::AnonUser).await?;
@@ -37,6 +55,7 @@ impl UnAuthenticatedConnection {
3755
Ok(connection)
3856
}
3957

58+
/// Start an anonymous server session with this connection
4059
pub async fn anonymous_server(self) -> Result<Connection, ConnectionError> {
4160
let mut raw = self.0;
4261
raw.session = anonymous(&raw, AccountType::AnonGameServer).await?;
@@ -46,6 +65,7 @@ impl UnAuthenticatedConnection {
4665
Ok(connection)
4766
}
4867

68+
/// Start a client session with this connection
4969
pub async fn login<H: AuthConfirmationHandler, G: GuardDataStore>(
5070
self,
5171
account: &str,

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ pub mod auth;
22
pub mod connection;
33
mod eresult;
44
mod game_coordinator;
5-
mod message;
5+
pub mod message;
66
mod net;
77
mod serverlist;
88
mod service_method;
@@ -15,6 +15,6 @@ pub use connection::{Connection, ConnectionTrait, ReadonlyConnection};
1515
pub use eresult::EResult;
1616
pub use game_coordinator::GameCoordinator;
1717
pub use message::NetMessage;
18-
pub use net::NetworkError;
18+
pub use net::{NetworkError, RawNetMessage};
1919
pub use serverlist::{DiscoverOptions, ServerDiscoveryError, ServerList};
2020
pub use session::{ConnectionError, LoginError};

src/message.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use thiserror::Error;
2222
use tokio_stream::Stream;
2323
use tracing::{debug, trace};
2424

25+
/// Malformed message body
2526
#[derive(Error, Debug)]
2627
#[error("Malformed message body for {0:?}: {1}")]
2728
pub struct MalformedBody(MsgKind, MessageBodyError);
@@ -32,6 +33,7 @@ impl MalformedBody {
3233
}
3334
}
3435

36+
/// Error while parsing the message body
3537
#[derive(Error, Debug)]
3638
#[non_exhaustive]
3739
pub enum MessageBodyError {
@@ -55,6 +57,11 @@ impl From<String> for MessageBodyError {
5557
}
5658
}
5759

60+
/// A message which can be encoded and/or decoded
61+
///
62+
/// Applications can implement this trait on a struct to allow sending it using
63+
/// [`raw_send_with_kind`](crate::ConnectionTrait::raw_send_with_kind). To use the higher level messages a struct also needs to implement
64+
/// [`NetMessage`]
5865
pub trait EncodableMessage: Sized + Debug + Send {
5966
fn read_body(_data: BytesMut, _header: &NetMessageHeader) -> Result<Self, MalformedBody> {
6067
panic!("Reading not implemented for {}", type_name::<Self>())
@@ -71,14 +78,15 @@ pub trait EncodableMessage: Sized + Debug + Send {
7178
fn process_header(&self, _header: &mut NetMessageHeader) {}
7279
}
7380

81+
/// A message with associated kind
7482
pub trait NetMessage: EncodableMessage {
7583
type KindEnum: MsgKindEnum;
7684
const KIND: Self::KindEnum;
7785
const IS_PROTOBUF: bool = false;
7886
}
7987

8088
#[derive(Debug, BinRead)]
81-
pub struct ChannelEncryptRequest {
89+
pub(crate) struct ChannelEncryptRequest {
8290
pub protocol: u32,
8391
#[allow(dead_code)]
8492
pub universe: u32,
@@ -99,7 +107,7 @@ impl NetMessage for ChannelEncryptRequest {
99107
}
100108

101109
#[derive(Debug, BinRead)]
102-
pub struct ChannelEncryptResult {
110+
pub(crate) struct ChannelEncryptResult {
103111
pub result: u32,
104112
}
105113

@@ -117,7 +125,7 @@ impl NetMessage for ChannelEncryptResult {
117125
}
118126

119127
#[derive(Debug)]
120-
pub struct ClientEncryptResponse {
128+
pub(crate) struct ClientEncryptResponse {
121129
pub protocol: u32,
122130
pub encrypted_key: Vec<u8>,
123131
}
@@ -164,6 +172,7 @@ impl Read for MaybeZipReader {
164172
}
165173
}
166174

175+
/// Flatten any "multi" messages in a stream of raw messages
167176
pub fn flatten_multi<S: Stream<Item = Result<RawNetMessage, NetworkError>>>(
168177
source: S,
169178
) -> impl Stream<Item = Result<RawNetMessage, NetworkError>> {
@@ -226,7 +235,7 @@ impl<R: Read> Iterator for MultiBodyIter<R> {
226235
}
227236

228237
#[derive(Debug)]
229-
pub struct ServiceMethodMessage<Request: Debug>(pub Request);
238+
pub(crate) struct ServiceMethodMessage<Request: Debug>(pub Request);
230239

231240
impl<Request: ServiceMethodRequest + Debug> EncodableMessage for ServiceMethodMessage<Request> {
232241
fn read_body(data: BytesMut, _header: &NetMessageHeader) -> Result<Self, MalformedBody> {
@@ -259,7 +268,7 @@ impl<Request: ServiceMethodRequest + Debug> NetMessage for ServiceMethodMessage<
259268
}
260269

261270
#[derive(Debug)]
262-
pub struct ServiceMethodResponseMessage {
271+
pub(crate) struct ServiceMethodResponseMessage {
263272
job_name: String,
264273
body: BytesMut,
265274
}
@@ -301,7 +310,7 @@ impl NetMessage for ServiceMethodResponseMessage {
301310
}
302311

303312
#[derive(Debug, Clone)]
304-
pub struct ServiceMethodNotification {
313+
pub(crate) struct ServiceMethodNotification {
305314
pub(crate) job_name: String,
306315
body: BytesMut,
307316
}

src/net.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ impl RawNetMessage {
321321
header_buffer,
322322
})
323323
}
324+
325+
/// Return a buffer containing the raw message bytes
326+
pub fn into_bytes(self) -> BytesMut {
327+
let mut body = self.header_buffer;
328+
body.unsplit(self.data);
329+
body
330+
}
324331
}
325332

326333
impl RawNetMessage {

src/transport/tcp.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ pub async fn encode_message<T: NetMessage, S: Sink<BytesMut, Error = NetworkErro
155155
pub async fn connect<A: ToSocketAddrs + Debug>(
156156
addr: A,
157157
) -> Result<(
158-
impl Stream<Item = Result<RawNetMessage>>,
159158
impl Sink<RawNetMessage, Error = NetworkError>,
159+
impl Stream<Item = Result<RawNetMessage>>,
160160
)> {
161161
let stream = TcpStream::connect(addr).await?;
162162
debug!("connected to server");
@@ -190,6 +190,7 @@ pub async fn connect<A: ToSocketAddrs + Debug>(
190190
let key = key.plain;
191191

192192
Ok((
193+
FramedWrite::new(raw_writer.into_inner(), RawMessageEncoder { key }),
193194
flatten_multi(
194195
raw_reader
195196
.and_then(move |encrypted| {
@@ -201,6 +202,5 @@ pub async fn connect<A: ToSocketAddrs + Debug>(
201202
})
202203
.and_then(|raw| ready(RawNetMessage::read(raw))),
203204
),
204-
FramedWrite::new(raw_writer.into_inner(), RawMessageEncoder { key }),
205205
))
206206
}

src/transport/websocket.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ type Result<T, E = NetworkError> = std::result::Result<T, E>;
1515
pub async fn connect(
1616
addr: &str,
1717
) -> Result<(
18-
impl Stream<Item = Result<RawNetMessage>>,
1918
impl Sink<RawNetMessage, Error = NetworkError>,
19+
impl Stream<Item = Result<RawNetMessage>>,
2020
)> {
2121
rustls::crypto::aws_lc_rs::default_provider()
2222
.install_default()
@@ -33,16 +33,12 @@ pub async fn connect(
3333
let (raw_write, raw_read) = stream.split();
3434

3535
Ok((
36+
raw_write.with(|msg: RawNetMessage| ready(Ok(WsMessage::binary(msg.into_bytes())))),
3637
flatten_multi(
3738
raw_read
3839
.map_err(NetworkError::from)
3940
.map_ok(|raw| raw.into_data())
4041
.map(|res| res.and_then(RawNetMessage::read)),
4142
),
42-
raw_write.with(|msg: RawNetMessage| {
43-
let mut body = msg.header_buffer;
44-
body.unsplit(msg.data);
45-
ready(Ok(WsMessage::binary(body)))
46-
}),
4743
))
4844
}

0 commit comments

Comments
 (0)