diff --git a/Cargo.lock b/Cargo.lock index 2e7829266..fdf36837e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -523,6 +523,7 @@ dependencies = [ "casbin", "chrono", "client-api", + "client-api-test-util", "collab", "collab-entity", "collab-folder", @@ -556,7 +557,7 @@ dependencies = [ "secrecy", "serde", "serde_json", - "shared_entity", + "shared-entity", "snowflake", "sqlx", "tempfile", @@ -1226,19 +1227,49 @@ dependencies = [ "serde", "serde_json", "serde_repr", - "shared_entity", + "shared-entity", "thiserror", "tokio", "tokio-retry", "tokio-stream", - "tokio-tungstenite", "tracing", "url", "uuid", + "websocket", "workspace-template", "yrs", ] +[[package]] +name = "client-api-test-util" +version = "0.1.0" +dependencies = [ + "assert-json-diff", + "bytes", + "client-api", + "collab", + "collab-entity", + "collab-folder", + "database-entity", + "dotenv", + "gotrue", + "image", + "lazy_static", + "mime", + "once_cell", + "opener", + "reqwest", + "scraper", + "serde_json", + "shared-entity", + "tempfile", + "tokio", + "tokio-stream", + "tracing", + "tracing-subscriber", + "uuid", +] + [[package]] name = "collab" version = "0.1.0" @@ -1330,6 +1361,16 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -2618,9 +2659,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" dependencies = [ "wasm-bindgen", ] @@ -3839,6 +3880,7 @@ dependencies = [ "serde_json", "thiserror", "tokio-tungstenite", + "websocket", "yrs", ] @@ -4255,6 +4297,18 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.3" @@ -4321,6 +4375,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scoped_threadpool" version = "0.1.9" @@ -4546,7 +4606,7 @@ dependencies = [ ] [[package]] -name = "shared_entity" +name = "shared-entity" version = "0.1.0" dependencies = [ "actix-web", @@ -5285,9 +5345,13 @@ dependencies = [ "futures-util", "log", "native-tls", + "rustls 0.21.7", + "rustls-native-certs", "tokio", "tokio-native-tls", + "tokio-rustls", "tungstenite", + "webpki-roots 0.25.2", ] [[package]] @@ -5498,6 +5562,7 @@ dependencies = [ "log", "native-tls", "rand 0.8.5", + "rustls 0.21.7", "sha1", "thiserror", "url", @@ -5706,9 +5771,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -5716,9 +5781,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" dependencies = [ "bumpalo", "log", @@ -5731,9 +5796,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.37" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461" dependencies = [ "cfg-if", "js-sys", @@ -5743,9 +5808,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5753,9 +5818,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", @@ -5766,9 +5831,34 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.89" +version = "0.2.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" + +[[package]] +name = "wasm-bindgen-test" +version = "0.3.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +checksum = "139bd73305d50e1c1c4333210c0db43d989395b64a237bd35c10ef3832a7f70c" +dependencies = [ + "console_error_panic_hook", + "js-sys", + "scoped-tls", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70072aebfe5da66d2716002c729a14e4aec4da0e23cc2ea66323dac541c93928" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] [[package]] name = "wasm-streams" @@ -5783,6 +5873,16 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasm-test" +version = "0.1.0" +dependencies = [ + "client-api", + "client-api-test-util", + "tokio", + "wasm-bindgen-test", +] + [[package]] name = "web-sys" version = "0.3.64" @@ -5827,6 +5927,22 @@ version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +[[package]] +name = "websocket" +version = "0.1.0" +dependencies = [ + "futures-channel", + "futures-util", + "http", + "httparse", + "js-sys", + "thiserror", + "tokio", + "tokio-tungstenite", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "weezl" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index 9b3688aca..dc982a963 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ gotrue = { path = "libs/gotrue" } gotrue-entity = { path = "libs/gotrue-entity" } infra = { path = "libs/infra" } app-error = { workspace = true, features = ["sqlx_error", "actix_web_error", "tokio_error"] } -shared_entity = { path = "libs/shared-entity", features = ["cloud"] } +shared-entity = { path = "libs/shared-entity", features = ["cloud"] } workspace-template = { workspace = true } realtime-entity.workspace = true @@ -91,6 +91,7 @@ once_cell = "1.19.0" tempfile = "3.9.0" assert-json-diff = "2.0.2" scraper = "0.17.1" +client-api-test-util = { path = "libs/client-api-test-util" } client-api = { path = "libs/client-api", features = ["collab-sync", "test_util"] } opener = "0.6.1" image = "0.23.14" @@ -120,13 +121,16 @@ members = [ "libs/app_error", "libs/workspace-template", "libs/encrypt", - "libs/realtime-protocol" + "libs/realtime-protocol", + "libs/websocket", + "libs/client-api-test-util", "libs/wasm-test", ] [workspace.dependencies] realtime-entity = { path = "libs/realtime-entity" } realtime-protocol = { path = "libs/realtime-protocol" } database-entity = { path = "libs/database-entity" } +shared-entity = { path = "libs/shared-entity" } app-error = { path = "libs/app_error" } serde_json = "1.0.111" serde = { version = "1.0.195", features = ["derive"] } @@ -138,6 +142,12 @@ anyhow = "1.0.79" tokio = { version = "1.35", features = ["sync"] } yrs = "0.17.2" bincode = "1.3.3" +websocket = { path = "libs/websocket" } +collab = { version = "0.1.0" } +collab-folder = { version = "0.1.0" } +tracing = { version = "0.1"} +collab-entity = { version = "0.1.0" } +gotrue = { path = "libs/gotrue" } [profile.release] lto = true diff --git a/libs/client-api-test-util/Cargo.toml b/libs/client-api-test-util/Cargo.toml new file mode 100644 index 000000000..9c8616e29 --- /dev/null +++ b/libs/client-api-test-util/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "client-api-test-util" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.5.0" +mime = "0.3.17" +serde_json = "1.0.111" +tokio = { version = "1.0", features = ["sync"] } +tokio-stream = "0.1.14" +tracing.workspace = true +collab-folder.workspace = true +collab = { workspace = true, features = ["async-plugin"] } +client-api = { path = "../client-api", features = ["collab-sync", "test_util"] } +once_cell = "1.19.0" +tempfile = "3.9.0" +assert-json-diff = "2.0.2" +scraper = "0.17.1" +opener = "0.6.1" +image = "0.23.14" +database-entity.workspace = true +collab-entity.workspace = true +shared-entity.workspace = true +tracing-subscriber = { version = "0.3.18", features = ["registry", "env-filter", "ansi", "json"] } +uuid = "1.6.1" +lazy_static = "1.4.0" +dotenv = "0.15.0" +reqwest = "0.11.23" +gotrue.workspace = true \ No newline at end of file diff --git a/libs/client-api-test-util/src/client.rs b/libs/client-api-test-util/src/client.rs new file mode 100644 index 000000000..829b353d4 --- /dev/null +++ b/libs/client-api-test-util/src/client.rs @@ -0,0 +1,50 @@ +use client_api::{Client, ClientConfiguration}; +use dotenv::dotenv; +use lazy_static::lazy_static; +use std::borrow::Cow; +use std::env; +use tracing::warn; + +lazy_static! { + pub static ref LOCALHOST_URL: Cow<'static, str> = + get_env_var("LOCALHOST_URL", "http://localhost:8000"); + pub static ref LOCALHOST_WS: Cow<'static, str> = + get_env_var("LOCALHOST_WS", "ws://localhost:8000/ws"); + pub static ref LOCALHOST_GOTRUE: Cow<'static, str> = + get_env_var("LOCALHOST_GOTRUE", "http://localhost:9999"); +} + +fn get_env_var<'default>(key: &str, default: &'default str) -> Cow<'default, str> { + dotenv().ok(); + match env::var(key) { + Ok(value) => Cow::Owned(value), + Err(_) => { + warn!("could not read env var {}: using default: {}", key, default); + Cow::Borrowed(default) + }, + } +} + +/// Return a client that connects to the local host. It requires to run the server locally. +/// ```shell +/// ./build/run_local_server.sh +/// ``` +pub fn localhost_client() -> Client { + Client::new( + &LOCALHOST_URL, + &LOCALHOST_WS, + &LOCALHOST_GOTRUE, + ClientConfiguration::default(), + ) +} + +pub async fn workspace_id_from_client(c: &Client) -> String { + c.get_workspaces() + .await + .unwrap() + .0 + .first() + .unwrap() + .workspace_id + .to_string() +} diff --git a/libs/client-api-test-util/src/lib.rs b/libs/client-api-test-util/src/lib.rs new file mode 100644 index 000000000..21c0c59c8 --- /dev/null +++ b/libs/client-api-test-util/src/lib.rs @@ -0,0 +1,9 @@ +mod client; +mod log; +mod test_client; +mod user; + +pub use client::*; +pub use log::*; +pub use test_client::*; +pub use user::*; diff --git a/tests/util/mod.rs b/libs/client-api-test-util/src/log.rs similarity index 67% rename from tests/util/mod.rs rename to libs/client-api-test-util/src/log.rs index 3f6829623..8a8ba21ff 100644 --- a/tests/util/mod.rs +++ b/libs/client-api-test-util/src/log.rs @@ -1,10 +1,10 @@ -use std::sync::Once; -use tracing_subscriber::fmt::Subscriber; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::EnvFilter; - -pub(crate) mod test_client; +#[cfg(not(target_arch = "wasm32"))] +use { + std::sync::Once, + tracing_subscriber::{fmt::Subscriber, util::SubscriberInitExt, EnvFilter}, +}; +#[cfg(not(target_arch = "wasm32"))] pub fn setup_log() { static START: Once = Once::new(); START.call_once(|| { @@ -20,3 +20,6 @@ pub fn setup_log() { subscriber.try_init().unwrap(); }); } + +#[cfg(target_arch = "wasm32")] +pub fn setup_log() {} diff --git a/tests/util/test_client.rs b/libs/client-api-test-util/src/test_client.rs similarity index 89% rename from tests/util/test_client.rs rename to libs/client-api-test-util/src/test_client.rs index 4cd68cf1a..8440e23f5 100644 --- a/tests/util/test_client.rs +++ b/libs/client-api-test-util/src/test_client.rs @@ -1,9 +1,10 @@ +use crate::{localhost_client, setup_log}; use assert_json_diff::{ assert_json_eq, assert_json_include, assert_json_matches_no_panic, CompareMode, Config, }; use bytes::Bytes; use client_api::collab_sync::{SinkConfig, SyncObject, SyncPlugin}; -use client_api::{WSClient, WSClientConfig}; +use client_api::ws::{WSClient, WSClientConfig}; use collab::core::collab::MutexCollab; use collab::core::collab_plugin::EncodedCollab; use collab::core::collab_state::SyncState; @@ -23,32 +24,30 @@ use shared_entity::dto::workspace_dto::{ BlobMetadata, CreateWorkspaceMember, WorkspaceMemberChangeset, WorkspaceSpaceUsage, }; use shared_entity::response::AppResponseError; -use sqlx::types::Uuid; use std::collections::HashMap; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::time::{timeout, Duration}; use tokio_stream::StreamExt; +use uuid::Uuid; -use crate::localhost_client; -use crate::user::utils::{generate_unique_registered_user, User}; -use crate::util::setup_log; +use crate::user::{generate_unique_registered_user, User}; -pub(crate) struct TestClient { +pub struct TestClient { pub user: User, pub ws_client: WSClient, pub api_client: client_api::Client, pub collab_by_object_id: HashMap, pub device_id: String, } -pub(crate) struct TestCollab { +pub struct TestCollab { #[allow(dead_code)] pub origin: CollabOrigin, pub collab: Arc, } impl TestClient { - pub(crate) async fn new(device_id: String, registered_user: User, start_ws_conn: bool) -> Self { + pub async fn new(device_id: String, registered_user: User, start_ws_conn: bool) -> Self { setup_log(); let api_client = localhost_client(); api_client @@ -81,24 +80,24 @@ impl TestClient { } } - pub(crate) async fn new_user() -> Self { + pub async fn new_user() -> Self { let registered_user = generate_unique_registered_user().await; let device_id = Uuid::new_v4().to_string(); Self::new(device_id, registered_user, true).await } - pub(crate) async fn new_user_without_ws_conn() -> Self { + pub async fn new_user_without_ws_conn() -> Self { let registered_user = generate_unique_registered_user().await; let device_id = Uuid::new_v4().to_string(); Self::new(device_id, registered_user, false).await } - pub(crate) async fn user_with_new_device(registered_user: User) -> Self { + pub async fn user_with_new_device(registered_user: User) -> Self { let device_id = Uuid::new_v4().to_string(); Self::new(device_id, registered_user, true).await } - pub(crate) async fn add_workspace_member( + pub async fn add_workspace_member( &self, workspace_id: &str, other_client: &TestClient, @@ -110,15 +109,15 @@ impl TestClient { .unwrap(); } - pub(crate) async fn get_user_workspace_info(&self) -> AFUserWorkspaceInfo { + pub async fn get_user_workspace_info(&self) -> AFUserWorkspaceInfo { self.api_client.get_user_workspace_info().await.unwrap() } - pub(crate) async fn open_workspace(&self, workspace_id: &str) -> AFWorkspace { + pub async fn open_workspace(&self, workspace_id: &str) -> AFWorkspace { self.api_client.open_workspace(workspace_id).await.unwrap() } - pub(crate) async fn get_user_folder(&self) -> Folder { + pub async fn get_user_folder(&self) -> Folder { let uid = self.uid().await; let workspace_id = self.workspace_id().await; let data = self @@ -141,7 +140,7 @@ impl TestClient { .unwrap() } - pub(crate) async fn try_update_workspace_member( + pub async fn try_update_workspace_member( &self, workspace_id: &str, other_client: &TestClient, @@ -158,7 +157,7 @@ impl TestClient { .await } - pub(crate) async fn try_add_workspace_member( + pub async fn try_add_workspace_member( &self, workspace_id: &str, other_client: &TestClient, @@ -171,7 +170,7 @@ impl TestClient { .await } - pub(crate) async fn try_remove_workspace_member( + pub async fn try_remove_workspace_member( &self, workspace_id: &str, other_client: &TestClient, @@ -191,7 +190,7 @@ impl TestClient { .unwrap() } - pub(crate) async fn add_client_as_collab_member( + pub async fn add_client_as_collab_member( &self, workspace_id: &str, object_id: &str, @@ -211,7 +210,7 @@ impl TestClient { .unwrap(); } - pub(crate) async fn update_collab_member_access_level( + pub async fn update_collab_member_access_level( &self, workspace_id: &str, object_id: &str, @@ -231,13 +230,13 @@ impl TestClient { .unwrap(); } - pub(crate) async fn wait_object_sync_complete(&self, object_id: &str) { + pub async fn wait_object_sync_complete(&self, object_id: &str) { self .wait_object_sync_complete_with_secs(object_id, 20) .await; } - pub(crate) async fn wait_object_sync_complete_with_secs(&self, object_id: &str, secs: u64) { + pub async fn wait_object_sync_complete_with_secs(&self, object_id: &str, secs: u64) { let mut sync_state = self .collab_by_object_id .get(object_id) @@ -281,7 +280,7 @@ impl TestClient { .unwrap() } - pub(crate) async fn workspace_id(&self) -> String { + pub async fn workspace_id(&self) -> String { self .api_client .get_workspaces() @@ -294,16 +293,16 @@ impl TestClient { .to_string() } - pub(crate) async fn email(&self) -> String { + pub async fn email(&self) -> String { self.api_client.get_profile().await.unwrap().email.unwrap() } - pub(crate) async fn uid(&self) -> i64 { + pub async fn uid(&self) -> i64 { self.api_client.get_profile().await.unwrap().uid } #[allow(dead_code)] - pub(crate) async fn get_snapshot( + pub async fn get_snapshot( &self, workspace_id: &str, object_id: &str, @@ -321,7 +320,7 @@ impl TestClient { .await } - pub(crate) async fn create_snapshot( + pub async fn create_snapshot( &self, workspace_id: &str, object_id: &str, @@ -333,7 +332,7 @@ impl TestClient { .await } - pub(crate) async fn get_snapshot_list( + pub async fn get_snapshot_list( &self, workspace_id: &str, object_id: &str, @@ -344,7 +343,7 @@ impl TestClient { .await } - pub(crate) async fn create_collab_list( + pub async fn create_collab_list( &mut self, workspace_id: &str, params: Vec, @@ -355,7 +354,7 @@ impl TestClient { .await } - pub(crate) async fn batch_get_collab( + pub async fn batch_get_collab( &mut self, workspace_id: &str, params: Vec, @@ -364,7 +363,7 @@ impl TestClient { } #[allow(clippy::await_holding_lock)] - pub(crate) async fn create_and_edit_collab( + pub async fn create_and_edit_collab( &mut self, workspace_id: &str, collab_type: CollabType, @@ -377,7 +376,7 @@ impl TestClient { } #[allow(clippy::await_holding_lock)] - pub(crate) async fn create_and_edit_collab_with_data( + pub async fn create_and_edit_collab_with_data( &mut self, object_id: String, workspace_id: &str, @@ -438,14 +437,14 @@ impl TestClient { self.wait_object_sync_complete(&object_id).await; } - pub(crate) async fn open_workspace_collab(&mut self, workspace_id: &str) { + pub async fn open_workspace_collab(&mut self, workspace_id: &str) { self .open_collab(workspace_id, workspace_id, CollabType::Folder) .await; } #[allow(clippy::await_holding_lock)] - pub(crate) async fn open_collab( + pub async fn open_collab( &mut self, workspace_id: &str, object_id: &str, @@ -482,11 +481,11 @@ impl TestClient { .insert(object_id.to_string(), test_collab); } - pub(crate) async fn disconnect(&self) { + pub async fn disconnect(&self) { self.ws_client.disconnect().await; } - pub(crate) async fn reconnect(&self) { + pub async fn reconnect(&self) { self .ws_client .connect( @@ -600,7 +599,7 @@ pub async fn assert_server_collab( } } -pub(crate) async fn assert_client_collab( +pub async fn assert_client_collab( client: &mut TestClient, object_id: &str, key: &str, @@ -638,7 +637,7 @@ pub(crate) async fn assert_client_collab( } } -pub(crate) async fn assert_client_collab_include_value( +pub async fn assert_client_collab_include_value( client: &mut TestClient, object_id: &str, expected: Value, diff --git a/tests/user/utils.rs b/libs/client-api-test-util/src/user.rs similarity index 82% rename from tests/user/utils.rs rename to libs/client-api-test-util/src/user.rs index 1f24d35fb..6882b75d8 100644 --- a/tests/user/utils.rs +++ b/libs/client-api-test-util/src/user.rs @@ -1,13 +1,11 @@ +use crate::client::{localhost_client, LOCALHOST_GOTRUE}; +use crate::log::setup_log; use client_api::Client; -use dotenvy::dotenv; - -use sqlx::types::Uuid; - +use dotenv::dotenv; use lazy_static::lazy_static; +use uuid::Uuid; -use crate::util::setup_log; -use crate::{localhost_client, LOCALHOST_GOTRUE}; - +#[cfg(not(target_arch = "wasm32"))] lazy_static! { pub static ref ADMIN_USER: User = { dotenv().ok(); @@ -18,6 +16,17 @@ lazy_static! { }; } +#[cfg(target_arch = "wasm32")] +lazy_static! { + pub static ref ADMIN_USER: User = { + dotenv().ok(); + User { + email: "admin@example.com".to_string(), + password: "password".to_string(), + } + }; +} + #[derive(Clone, Debug)] pub struct User { pub email: String, diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index 3ba286111..424994aba 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -13,7 +13,7 @@ anyhow = "1.0.79" serde_repr = "0.1.18" gotrue = { path = "../gotrue" } gotrue-entity = { path = "../gotrue-entity" } -shared_entity = { path = "../shared-entity" } +shared-entity = { path = "../shared-entity" } tracing = { version = "0.1" } thiserror = "1.0.56" bytes = "1.5" @@ -32,6 +32,7 @@ mime = "0.3.17" tokio-stream = { version = "0.1.14" } realtime-entity = { workspace = true } chrono = "0.4" +websocket = { workspace = true, features = ["native-tls"] } collab = { version = "0.1.0", optional = true } collab-entity = { version = "0.1.0" } @@ -57,12 +58,13 @@ features = ["sync", "net"] workspace = true features = ["tungstenite"] -[target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio-tungstenite] -version = "0.20.1" -features = ["native-tls"] [features] collab-sync = ["collab", "yrs"] test_util = ["scraper"] template = ["workspace-template"] wasm_build = [] + + +[profile.dev] +debug = true diff --git a/libs/client-api/src/collab_sync/plugin.rs b/libs/client-api/src/collab_sync/plugin.rs index d62e993c3..9c6912b1a 100644 --- a/libs/client-api/src/collab_sync/plugin.rs +++ b/libs/client-api/src/collab_sync/plugin.rs @@ -15,7 +15,7 @@ use crate::collab_sync::{SinkConfig, SyncQueue}; use tokio_stream::wrappers::WatchStream; use tracing::trace; -use crate::{ConnectState, WSConnectStateReceiver}; +use crate::ws::{ConnectState, WSConnectStateReceiver}; use yrs::updates::encoder::Encode; pub struct SyncPlugin { diff --git a/libs/client-api/src/lib.rs b/libs/client-api/src/lib.rs index 81d09699d..211735478 100644 --- a/libs/client-api/src/lib.rs +++ b/libs/client-api/src/lib.rs @@ -30,9 +30,10 @@ if_wasm! { mod wasm; #[allow(unused_imports)] pub use wasm::*; - pub use wasm::ws_wasm::*; } +pub mod ws; + pub mod error { pub use shared_entity::response::AppResponseError; pub use shared_entity::response::ErrorCode; diff --git a/libs/client-api/src/native/http_native.rs b/libs/client-api/src/native/http_native.rs index 406f8e114..e40c4197d 100644 --- a/libs/client-api/src/native/http_native.rs +++ b/libs/client-api/src/native/http_native.rs @@ -1,6 +1,7 @@ use crate::http::log_request_id; use crate::retry::{RefreshTokenAction, RefreshTokenRetryCondition}; -use crate::{spawn_blocking_brotli_compress, Client, WSClientHttpSender, WSError}; +use crate::ws::{WSClientHttpSender, WSError}; +use crate::{spawn_blocking_brotli_compress, Client}; use app_error::AppError; use async_trait::async_trait; use database_entity::dto::CollabParams; @@ -9,6 +10,7 @@ use prost::Message; use realtime_entity::realtime_proto::HttpRealtimeMessage; use reqwest::{Body, Method}; use shared_entity::response::{AppResponse, AppResponseError}; +use std::future::Future; use std::sync::atomic::Ordering; use std::time::Duration; use tokio_retry::strategy::FixedInterval; @@ -20,7 +22,7 @@ impl Client { pub async fn post_realtime_msg( &self, device_id: &str, - msg: tokio_tungstenite::tungstenite::Message, + msg: websocket::Message, ) -> Result<(), AppResponseError> { let device_id = device_id.to_string(); let payload = @@ -143,14 +145,18 @@ impl Client { #[async_trait] impl WSClientHttpSender for Client { - async fn send_ws_msg( - &self, - device_id: &str, - message: tokio_tungstenite::tungstenite::Message, - ) -> Result<(), WSError> { + async fn send_ws_msg(&self, device_id: &str, message: websocket::Message) -> Result<(), WSError> { self .post_realtime_msg(device_id, message) .await .map_err(|err| WSError::Internal(anyhow::Error::from(err))) } } + +pub fn spawn(future: T) -> tokio::task::JoinHandle +where + T: Future + Send + 'static, + T::Output: Send + 'static, +{ + tokio::spawn(future) +} diff --git a/libs/client-api/src/native/mod.rs b/libs/client-api/src/native/mod.rs index 16cd51169..52888cfef 100644 --- a/libs/client-api/src/native/mod.rs +++ b/libs/client-api/src/native/mod.rs @@ -1,7 +1,5 @@ mod http_native; pub mod retry; -mod ws; #[allow(unused_imports)] pub use http_native::*; -pub use ws::*; diff --git a/libs/client-api/src/native/ws/msg.rs b/libs/client-api/src/native/ws/msg.rs deleted file mode 100644 index 680e8a85e..000000000 --- a/libs/client-api/src/native/ws/msg.rs +++ /dev/null @@ -1,97 +0,0 @@ -use crate::ws::WSError; -use realtime_entity::collab_msg::CollabMessage; -use serde::{Deserialize, Serialize}; -use serde_repr::{Deserialize_repr, Serialize_repr}; -use tokio_tungstenite::tungstenite::Message; - -#[derive(Debug, Copy, Clone, Serialize_repr, Deserialize_repr, Eq, PartialEq, Hash)] -#[repr(u8)] -pub enum BusinessID { - CollabId = 1, -} - -/// The message sent through WebSocket. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ClientRealtimeMessage { - pub business_id: BusinessID, - pub object_id: String, - pub payload: Vec, -} - -impl ClientRealtimeMessage { - pub fn new(business_id: BusinessID, object_id: String, payload: Vec) -> Self { - Self { - business_id, - object_id, - payload, - } - } -} - -impl TryFrom<&[u8]> for ClientRealtimeMessage { - type Error = WSError; - - fn try_from(bytes: &[u8]) -> Result { - let msg = serde_json::from_slice::(bytes)?; - Ok(msg) - } -} - -impl TryFrom> for ClientRealtimeMessage { - type Error = WSError; - - fn try_from(bytes: Vec) -> Result { - let msg = serde_json::from_slice::(&bytes)?; - Ok(msg) - } -} - -impl TryFrom<&Message> for ClientRealtimeMessage { - type Error = WSError; - - fn try_from(value: &Message) -> Result { - match value { - Message::Binary(bytes) => { - let msg = serde_json::from_slice::(bytes)?; - Ok(msg) - }, - _ => Err(WSError::UnsupportedMsgType), - } - } -} - -impl From for Message { - fn from(msg: ClientRealtimeMessage) -> Self { - let bytes = serde_json::to_vec(&msg).unwrap_or_default(); - Message::Binary(bytes) - } -} - -impl From for ClientRealtimeMessage { - fn from(msg: CollabMessage) -> Self { - let business_id = BusinessID::CollabId; - let object_id = msg.object_id().to_string(); - let payload = msg.to_vec(); - Self { - business_id, - object_id, - payload, - } - } -} - -impl TryFrom for CollabMessage { - type Error = WSError; - - fn try_from(value: ClientRealtimeMessage) -> Result { - let msg = - CollabMessage::from_vec(&value.payload).map_err(|e| WSError::Internal(Box::new(e)))?; - Ok(msg) - } -} - -impl From for Result { - fn from(msg: ClientRealtimeMessage) -> Self { - CollabMessage::try_from(msg) - } -} diff --git a/libs/client-api/src/wasm/http_wasm.rs b/libs/client-api/src/wasm/http_wasm.rs index 825321269..db8be2287 100644 --- a/libs/client-api/src/wasm/http_wasm.rs +++ b/libs/client-api/src/wasm/http_wasm.rs @@ -1,25 +1,27 @@ use crate::http::RefreshTokenRet; +use crate::ws::{WSClientHttpSender, WSError}; use crate::Client; use app_error::gotrue::GoTrueError; -use app_error::AppError; +use app_error::ErrorCode; use async_trait::async_trait; use database_entity::dto::CollabParams; use gotrue::grant::{Grant, RefreshTokenGrant}; use shared_entity::response::AppResponseError; +use std::future::Future; use std::sync::atomic::Ordering; -use std::time::Duration; -use tokio_retry::strategy::FixedInterval; -use tokio_retry::RetryIf; -use tracing::{event, instrument}; +use tracing::instrument; impl Client { pub async fn create_collab_list( &self, workspace_id: &str, - params_list: Vec, + _params_list: Vec, ) -> Result<(), AppResponseError> { let _url = self.batch_create_collab_url(workspace_id); - todo!() + Err(AppResponseError::new( + ErrorCode::Unhandled, + "not implemented", + )) } #[instrument(level = "debug", skip_all, err)] @@ -58,3 +60,22 @@ impl Client { Ok(()) } } + +pub fn spawn(future: T) -> tokio::task::JoinHandle +where + T: Future + 'static, + T::Output: Send + 'static, +{ + tokio::task::spawn_local(future) +} + +#[async_trait] +impl WSClientHttpSender for Client { + async fn send_ws_msg( + &self, + _device_id: &str, + _message: websocket::Message, + ) -> Result<(), WSError> { + Err(WSError::Internal(anyhow::Error::msg("not supported"))) + } +} diff --git a/libs/client-api/src/wasm/mod.rs b/libs/client-api/src/wasm/mod.rs index a5eb2d79e..3b11fc117 100644 --- a/libs/client-api/src/wasm/mod.rs +++ b/libs/client-api/src/wasm/mod.rs @@ -1,4 +1,3 @@ mod http_wasm; -pub mod ws_wasm; pub use http_wasm::*; diff --git a/libs/client-api/src/wasm/ws_wasm.rs b/libs/client-api/src/wasm/ws_wasm.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/client-api/src/native/ws/client.rs b/libs/client-api/src/ws/client.rs similarity index 94% rename from libs/client-api/src/native/ws/client.rs rename to libs/client-api/src/ws/client.rs index 6c71f8016..7f81b7914 100644 --- a/libs/client-api/src/native/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -1,28 +1,24 @@ use futures_util::{SinkExt, StreamExt}; -use std::borrow::Cow; - use parking_lot::RwLock; +use std::borrow::Cow; use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::{Arc, Weak}; use std::time::Duration; use tokio::sync::broadcast::{channel, Receiver, Sender}; -use crate::native::ping::ServerFixIntervalPing; -use crate::native::ws::retry::ConnectAction; -use crate::{ConnectState, ConnectStateNotify, WSError, WebSocketChannel}; +use crate::spawn; +use crate::ws::ping::ServerFixIntervalPing; +use crate::ws::retry::ConnectAction; +use crate::ws::{ConnectState, ConnectStateNotify, WSError, WebSocketChannel}; use realtime_entity::collab_msg::CollabMessage; use realtime_entity::message::RealtimeMessage; use realtime_entity::user::UserMessage; use tokio::sync::{oneshot, Mutex}; use tokio_retry::strategy::FixedInterval; use tokio_retry::{Condition, RetryIf}; -use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; -use tokio_tungstenite::tungstenite::protocol::CloseFrame; -use tokio_tungstenite::tungstenite::Message; -use tokio_tungstenite::MaybeTlsStream; use tracing::{debug, error, info, trace, warn}; +use websocket::{CloseCode, CloseFrame, Message}; pub struct WSClientConfig { /// specifies the number of messages that the channel can hold at any given @@ -90,11 +86,7 @@ impl WSClient { } } - pub async fn connect( - &self, - addr: String, - device_id: &str, - ) -> Result, WSError> { + pub async fn connect(&self, addr: String, device_id: &str) -> Result<(), WSError> { self.set_state(ConnectState::Connecting).await; let (stop_tx, mut stop_rx) = oneshot::channel(); @@ -137,11 +129,6 @@ impl WSClient { } let ws_stream = conn_result?; - let addr = match ws_stream.get_ref() { - MaybeTlsStream::Plain(s) => s.local_addr().ok(), - _ => None, - }; - self.set_state(ConnectState::Connected).await; let (mut sink, mut stream) = ws_stream.split(); let weak_collab_channels = Arc::downgrade(&self.collab_channels); @@ -161,7 +148,7 @@ impl WSClient { let user_message_tx = self.user_channel.as_ref().clone(); // Receive messages from the websocket, and send them to the channels. - tokio::spawn(async move { + spawn(async move { while let Some(Ok(ws_msg)) = stream.next().await { match ws_msg { Message::Binary(_) => { @@ -226,7 +213,7 @@ impl WSClient { let mut rx = self.sender.subscribe(); let weak_http_sender = Arc::downgrade(&self.http_sender); let device_id = device_id.to_string(); - tokio::spawn(async move { + spawn(async move { loop { tokio::select! { _ = &mut stop_rx => break, @@ -257,7 +244,7 @@ impl WSClient { } }); - Ok(addr) + Ok(()) } /// Return a [WebSocketChannel] that can be used to send messages to the websocket. Caller should diff --git a/libs/client-api/src/native/ws/error.rs b/libs/client-api/src/ws/error.rs similarity index 95% rename from libs/client-api/src/native/ws/error.rs rename to libs/client-api/src/ws/error.rs index 7976052ed..747631ae5 100644 --- a/libs/client-api/src/native/ws/error.rs +++ b/libs/client-api/src/ws/error.rs @@ -1,5 +1,5 @@ use reqwest::StatusCode; -use tokio_tungstenite::tungstenite::Error; +use websocket::Error; #[derive(Debug, thiserror::Error)] pub enum WSError { diff --git a/libs/client-api/src/native/ws/handler.rs b/libs/client-api/src/ws/handler.rs similarity index 98% rename from libs/client-api/src/native/ws/handler.rs rename to libs/client-api/src/ws/handler.rs index fe5da1a7c..788e1d313 100644 --- a/libs/client-api/src/native/ws/handler.rs +++ b/libs/client-api/src/ws/handler.rs @@ -6,8 +6,8 @@ use std::task::{Context, Poll}; use tokio::sync::broadcast::{channel, Sender}; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio_stream::wrappers::UnboundedReceiverStream; -use tokio_tungstenite::tungstenite::Message; use tracing::{trace, warn}; +use websocket::Message; pub struct WebSocketChannel { object_id: String, diff --git a/libs/client-api/src/native/ws/mod.rs b/libs/client-api/src/ws/mod.rs similarity index 83% rename from libs/client-api/src/native/ws/mod.rs rename to libs/client-api/src/ws/mod.rs index 1174bf733..041b82716 100644 --- a/libs/client-api/src/native/ws/mod.rs +++ b/libs/client-api/src/ws/mod.rs @@ -1,7 +1,6 @@ mod client; mod error; mod handler; -// mod msg; pub(crate) mod ping; mod retry; mod state; @@ -9,5 +8,4 @@ mod state; pub use client::*; pub use error::*; pub use handler::*; -// pub use msg::*; pub use state::*; diff --git a/libs/client-api/src/native/ws/ping.rs b/libs/client-api/src/ws/ping.rs similarity index 96% rename from libs/client-api/src/native/ws/ping.rs rename to libs/client-api/src/ws/ping.rs index 29640d8ee..0c0ddb8a4 100644 --- a/libs/client-api/src/native/ws/ping.rs +++ b/libs/client-api/src/ws/ping.rs @@ -1,10 +1,10 @@ -use crate::{ConnectState, ConnectStateNotify}; +use crate::ws::{ConnectState, ConnectStateNotify}; use std::sync::Arc; use std::time::Duration; use tokio::sync::broadcast::Sender; use tokio::sync::mpsc::Receiver; use tokio::sync::Mutex; -use tokio_tungstenite::tungstenite::Message; +use websocket::Message; pub(crate) struct ServerFixIntervalPing { duration: Duration, diff --git a/libs/client-api/src/native/ws/retry.rs b/libs/client-api/src/ws/retry.rs similarity index 71% rename from libs/client-api/src/native/ws/retry.rs rename to libs/client-api/src/ws/retry.rs index 7913cf849..992bc7df6 100644 --- a/libs/client-api/src/native/ws/retry.rs +++ b/libs/client-api/src/ws/retry.rs @@ -1,11 +1,10 @@ use std::future::Future; use std::pin::Pin; -use crate::WSError; -use tokio::net::TcpStream; +use crate::ws::WSError; use tokio_retry::Action; -use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use tracing::info; +use websocket::{connect_async, WebSocketStream}; pub(crate) struct ConnectAction { addr: String, @@ -18,8 +17,12 @@ impl ConnectAction { } impl Action for ConnectAction { + #[cfg(not(target_arch = "wasm32"))] type Future = Pin> + Send + Sync>>; - type Item = WebSocketStream>; + + #[cfg(target_arch = "wasm32")] + type Future = Pin>>>; + type Item = WebSocketStream; type Error = WSError; fn run(&mut self) -> Self::Future { @@ -27,7 +30,7 @@ impl Action for ConnectAction { Box::pin(async move { info!("🔵websocket start connecting"); match connect_async(&cloned_addr).await { - Ok((stream, _response)) => { + Ok(stream) => { info!("🟢websocket connect success"); Ok(stream) }, diff --git a/libs/client-api/src/native/ws/state.rs b/libs/client-api/src/ws/state.rs similarity index 100% rename from libs/client-api/src/native/ws/state.rs rename to libs/client-api/src/ws/state.rs diff --git a/libs/client-api/tests/main.rs b/libs/client-api/tests/main.rs new file mode 100644 index 000000000..0de2a2b6b --- /dev/null +++ b/libs/client-api/tests/main.rs @@ -0,0 +1,2 @@ +// mod native; +// mod web; diff --git a/tests/websocket/connect.rs b/libs/client-api/tests/native/conn_test.rs similarity index 97% rename from tests/websocket/connect.rs rename to libs/client-api/tests/native/conn_test.rs index c51f626a4..788219665 100644 --- a/tests/websocket/connect.rs +++ b/libs/client-api/tests/native/conn_test.rs @@ -1,7 +1,7 @@ use std::time::SystemTime; use crate::user::utils::generate_unique_registered_user_client; -use client_api::{ConnectState, WSClient, WSClientConfig}; +use client_api::ws::{ConnectState, WSClient, WSClientConfig}; #[tokio::test] async fn realtime_connect_test() { diff --git a/libs/client-api/tests/native/mod.rs b/libs/client-api/tests/native/mod.rs new file mode 100644 index 000000000..a217458fc --- /dev/null +++ b/libs/client-api/tests/native/mod.rs @@ -0,0 +1 @@ +mod conn_test; diff --git a/libs/client-api/tests/web/conn_test.rs b/libs/client-api/tests/web/conn_test.rs new file mode 100644 index 000000000..d289fe03f --- /dev/null +++ b/libs/client-api/tests/web/conn_test.rs @@ -0,0 +1,22 @@ +use crate::user::utils::generate_unique_registered_user_client; +use client_api::ws::{ConnectState, WSClient, WSClientConfig}; +use wasm_bindgen_test::wasm_bindgen_test; + +#[wasm_bindgen_test] +async fn realtime_connect_test() { + let (c, _user) = generate_unique_registered_user_client().await; + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let mut state = ws_client.subscribe_connect_state(); + let device_id = "fake_device_id"; + loop { + tokio::select! { + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Connected { + break; + } + }, + } + } +} diff --git a/libs/client-api/tests/web/mod.rs b/libs/client-api/tests/web/mod.rs new file mode 100644 index 000000000..1aed1c9bd --- /dev/null +++ b/libs/client-api/tests/web/mod.rs @@ -0,0 +1,3 @@ +use wasm_bindgen_test::wasm_bindgen_test_configure; +wasm_bindgen_test_configure!(run_in_browser); +mod conn_test; diff --git a/libs/realtime-entity/Cargo.toml b/libs/realtime-entity/Cargo.toml index 16042e34d..9f9e1a435 100644 --- a/libs/realtime-entity/Cargo.toml +++ b/libs/realtime-entity/Cargo.toml @@ -22,6 +22,7 @@ database-entity.workspace = true yrs.workspace = true thiserror = "1.0.56" realtime-protocol.workspace = true +websocket.workspace = true [build-dependencies] protoc-bin-vendored = { version = "3.0" } diff --git a/libs/realtime-entity/src/message.rs b/libs/realtime-entity/src/message.rs index 16db1ba2f..0931ba66d 100644 --- a/libs/realtime-entity/src/message.rs +++ b/libs/realtime-entity/src/message.rs @@ -2,6 +2,7 @@ use crate::collab_msg::CollabMessage; use bytes::Bytes; use serde::{Deserialize, Serialize}; use std::fmt::Display; +use websocket::Message; #[derive(Debug, Clone, Serialize, Deserialize)] #[cfg_attr( @@ -71,10 +72,7 @@ impl TryFrom> for RealtimeMessage { } use crate::user::UserMessage; -#[cfg(feature = "tungstenite")] -use tokio_tungstenite::tungstenite::Message; -#[cfg(feature = "tungstenite")] impl TryFrom<&Message> for RealtimeMessage { type Error = anyhow::Error; @@ -86,7 +84,6 @@ impl TryFrom<&Message> for RealtimeMessage { } } -#[cfg(feature = "tungstenite")] impl TryFrom for RealtimeMessage { type Error = anyhow::Error; @@ -98,7 +95,6 @@ impl TryFrom for RealtimeMessage { } } -#[cfg(feature = "tungstenite")] impl From for Message { fn from(msg: RealtimeMessage) -> Self { let bytes = bincode::serialize(&msg).unwrap_or_default(); diff --git a/libs/shared-entity/Cargo.toml b/libs/shared-entity/Cargo.toml index 5da8365b5..28ab6f1b1 100644 --- a/libs/shared-entity/Cargo.toml +++ b/libs/shared-entity/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "shared_entity" +name = "shared-entity" version = "0.1.0" edition = "2021" diff --git a/libs/wasm-test/Cargo.toml b/libs/wasm-test/Cargo.toml new file mode 100644 index 000000000..65b6acc69 --- /dev/null +++ b/libs/wasm-test/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "wasm-test" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[dev-dependencies] +wasm-bindgen-test = "0.3.40" +client-api-test-util = { path = "../client-api-test-util" } +client-api = { path = "../client-api" } +tokio = { version = "1", features = ["sync", "macros"] } + diff --git a/libs/wasm-test/src/lib.rs b/libs/wasm-test/src/lib.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/libs/wasm-test/src/lib.rs @@ -0,0 +1 @@ + diff --git a/libs/wasm-test/tests/conn_test.rs b/libs/wasm-test/tests/conn_test.rs new file mode 100644 index 000000000..6936ef7b4 --- /dev/null +++ b/libs/wasm-test/tests/conn_test.rs @@ -0,0 +1,22 @@ +use client_api::ws::{ConnectState, WSClient, WSClientConfig}; +use client_api_test_util::generate_unique_registered_user_client; +use wasm_bindgen_test::wasm_bindgen_test; + +#[wasm_bindgen_test] +async fn realtime_connect_test() { + let (c, _user) = generate_unique_registered_user_client().await; + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let mut state = ws_client.subscribe_connect_state(); + let device_id = "fake_device_id"; + loop { + tokio::select! { + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Connected { + break; + } + }, + } + } +} diff --git a/libs/wasm-test/tests/main.rs b/libs/wasm-test/tests/main.rs new file mode 100644 index 000000000..43b580d5e --- /dev/null +++ b/libs/wasm-test/tests/main.rs @@ -0,0 +1,5 @@ +use wasm_bindgen_test::wasm_bindgen_test_configure; +wasm_bindgen_test_configure!(run_in_browser); + +// #[cfg(target_arch = "wasm32")] +// mod conn_test; diff --git a/libs/websocket/Cargo.toml b/libs/websocket/Cargo.toml new file mode 100644 index 000000000..988896bcc --- /dev/null +++ b/libs/websocket/Cargo.toml @@ -0,0 +1,48 @@ +[package] +name = "websocket" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +native-tls = ["tokio-tungstenite/native-tls"] +native-tls-vendored = ["native-tls", "tokio-tungstenite/native-tls-vendored"] +rustls-tls-native-roots = [ + "__rustls-tls", + "tokio-tungstenite/rustls-tls-native-roots", +] +rustls-tls-webpki-roots = [ + "__rustls-tls", + "tokio-tungstenite/rustls-tls-webpki-roots", +] +__rustls-tls = [] + +[dependencies] +thiserror = "1" +http = "0.2" +httparse = "1.8" +futures-util = { version = "0.3", default-features = false, features = [ + "sink", + "std", +] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio-tungstenite = "0.20" +tokio = { version = "1", default-features = false, features = ["net"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen = "0.2" +js-sys = "0.3" +futures-channel = { version = "0.3" } + +[target.'cfg(target_arch = "wasm32")'.dependencies.web-sys] +version = "0.3" +features = [ + "WebSocket", + "MessageEvent", + "CloseEvent", + "Event", + "ErrorEvent", + "BinaryType", + "Blob", +] diff --git a/libs/websocket/src/error.rs b/libs/websocket/src/error.rs new file mode 100644 index 000000000..f5b4cde62 --- /dev/null +++ b/libs/websocket/src/error.rs @@ -0,0 +1,268 @@ +use http::{header::HeaderName, Response}; +use std::{io, result, str, string}; +use thiserror::Error; + +/// These error types are copy-pasted from the tokio_tungstenite crate. +pub type Result = result::Result; + +/// Possible WebSocket errors. +#[derive(Error, Debug)] +pub enum Error { + /// WebSocket connection closed normally. This informs you of the close. + /// It's not an error as such and nothing wrong happened. + /// + /// This is returned as soon as the close handshake is finished (we have both sent and + /// received a close frame) on the server end and as soon as the server has closed the + /// underlying connection if this endpoint is a client. + /// + /// Thus when you receive this, it is safe to drop the underlying connection. + /// + /// Receiving this error means that the WebSocket object is not usable anymore and the + /// only meaningful action with it is dropping it. + #[error("Connection closed normally")] + ConnectionClosed, + /// Trying to work with already closed connection. + /// + /// Trying to read or write after receiving `ConnectionClosed` causes this. + /// + /// As opposed to `ConnectionClosed`, this indicates your code tries to operate on the + /// connection when it really shouldn't anymore, so this really indicates a programmer + /// error on your part. + #[error("Trying to work with closed connection")] + AlreadyClosed, + /// Input-output error. Apart from WouldBlock, these are generally errors with the + /// underlying connection and you should probably consider them fatal. + #[error("IO error: {0}")] + Io(#[from] io::Error), + /// TLS error. + /// + /// Note that this error variant is enabled unconditionally even if no TLS feature is enabled, + /// to provide a feature-agnostic API surface. + #[cfg(not(target_arch = "wasm32"))] + #[error("TLS error: {0}")] + Tls(#[from] tokio_tungstenite::tungstenite::error::TlsError), + /// - When reading: buffer capacity exhausted. + /// - When writing: your message is bigger than the configured max message size + /// (64MB by default). + #[error("Space limit exceeded: {0}")] + Capacity(#[from] CapacityError), + /// Protocol violation. + #[error("WebSocket protocol error: {0}")] + Protocol(#[from] ProtocolError), + #[error("Write buffer is full")] + WriteBufferFull(crate::Message), + /// UTF coding error. + #[error("UTF-8 encoding error")] + Utf8, + #[error("Attack attempt detected")] + AttackAttempt, + #[error("URL error: {0}")] + Url(#[from] UrlError), + #[error("HTTP error: {}", .0.status())] + Http(Response>>), + #[error("HTTP format error: {0}")] + HttpFormat(#[from] http::Error), + #[error("Parsing blobs is unsupported")] + BlobFormatUnsupported, + #[error("Unknown data format encountered")] + UnknownFormat, +} + +impl From for Error { + fn from(_: str::Utf8Error) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(_: string::FromUtf8Error) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(err: http::header::InvalidHeaderValue) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::header::InvalidHeaderName) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(_: http::header::ToStrError) -> Self { + Error::Utf8 + } +} + +impl From for Error { + fn from(err: http::uri::InvalidUri) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: http::status::InvalidStatusCode) -> Self { + Error::HttpFormat(err.into()) + } +} + +impl From for Error { + fn from(err: httparse::Error) -> Self { + match err { + httparse::Error::TooManyHeaders => Error::Capacity(CapacityError::TooManyHeaders), + e => Error::Protocol(ProtocolError::HttparseError(e)), + } + } +} + +/// Indicates the specific type/cause of a capacity error. +#[derive(Error, Debug, PartialEq, Eq, Clone, Copy)] +pub enum CapacityError { + /// Too many headers provided (see [`httparse::Error::TooManyHeaders`]). + #[error("Too many headers")] + TooManyHeaders, + /// Received header is too long. + /// Message is bigger than the maximum allowed size. + #[error("Message too long: {size} > {max_size}")] + MessageTooLong { + /// The size of the message. + size: usize, + /// The maximum allowed message size. + max_size: usize, + }, +} + +/// Indicates the specific type/cause of a protocol error. +#[derive(Error, Debug, PartialEq, Eq, Clone)] +pub enum ProtocolError { + /// Use of the wrong HTTP method (the WebSocket protocol requires the GET method be used). + #[error("Unsupported HTTP method used - only GET is allowed")] + WrongHttpMethod, + /// Wrong HTTP version used (the WebSocket protocol requires version 1.1 or higher). + #[error("HTTP version must be 1.1 or higher")] + WrongHttpVersion, + /// Missing `Connection: upgrade` HTTP header. + #[error("No \"Connection: upgrade\" header")] + MissingConnectionUpgradeHeader, + /// Missing `Upgrade: websocket` HTTP header. + #[error("No \"Upgrade: websocket\" header")] + MissingUpgradeWebSocketHeader, + /// Missing `Sec-WebSocket-Version: 13` HTTP header. + #[error("No \"Sec-WebSocket-Version: 13\" header")] + MissingSecWebSocketVersionHeader, + /// Missing `Sec-WebSocket-Key` HTTP header. + #[error("No \"Sec-WebSocket-Key\" header")] + MissingSecWebSocketKey, + /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. + #[error("Key mismatch in \"Sec-WebSocket-Accept\" header")] + SecWebSocketAcceptKeyMismatch, + /// Garbage data encountered after client request. + #[error("Junk after client request")] + JunkAfterRequest, + /// Custom responses must be unsuccessful. + #[error("Custom response must not be successful")] + CustomResponseSuccessful, + /// Invalid header is passed. This header is formed by the library automatically + /// and must not be overwritten by the user. + #[error("Not allowed to pass overwrite the standard header {0}")] + InvalidHeader(HeaderName), + /// No more data while still performing handshake. + #[error("Handshake not finished")] + HandshakeIncomplete, + /// Wrapper around a [`httparse::Error`] value. + #[error("httparse error: {0}")] + HttparseError(#[from] httparse::Error), + /// Not allowed to send after having sent a closing frame. + #[error("Sending after closing is not allowed")] + SendAfterClosing, + /// Remote sent data after sending a closing frame. + #[error("Remote sent after having closed")] + ReceivedAfterClosing, + /// Reserved bits in frame header are non-zero. + #[error("Reserved bits are non-zero")] + NonZeroReservedBits, + /// The server must close the connection when an unmasked frame is received. + #[error("Received an unmasked frame from client")] + UnmaskedFrameFromClient, + /// The client must close the connection when a masked frame is received. + #[error("Received a masked frame from server")] + MaskedFrameFromServer, + /// Control frames must not be fragmented. + #[error("Fragmented control frame")] + FragmentedControlFrame, + /// Control frames must have a payload of 125 bytes or less. + #[error("Control frame too big (payload must be 125 bytes or less)")] + ControlFrameTooBig, + /// Type of control frame not recognised. + #[error("Unknown control frame type: {0}")] + UnknownControlFrameType(u8), + /// Type of data frame not recognised. + #[error("Unknown data frame type: {0}")] + UnknownDataFrameType(u8), + /// Received a continue frame despite there being nothing to continue. + #[error("Continue frame but nothing to continue")] + UnexpectedContinueFrame, + /// Received data while waiting for more fragments. + #[error("While waiting for more fragments received: {0}")] + ExpectedFragment(Data), + /// Connection closed without performing the closing handshake. + #[error("Connection reset without closing handshake")] + ResetWithoutClosingHandshake, + /// Encountered an invalid opcode. + #[error("Encountered invalid opcode: {0}")] + InvalidOpcode(u8), + /// The payload for the closing frame is invalid. + #[error("Invalid close sequence")] + InvalidCloseSequence, +} + +/// Indicates the specific type/cause of URL error. +#[derive(Error, Debug, PartialEq, Eq)] +pub enum UrlError { + /// TLS is used despite not being compiled with the TLS feature enabled. + #[error("TLS support not compiled in")] + TlsFeatureNotEnabled, + /// The URL does not include a host name. + #[error("No host name in the URL")] + NoHostName, + /// Failed to connect with this URL. + #[error("Unable to connect to {0}")] + UnableToConnect(String), + /// Unsupported URL scheme used (only `ws://` or `wss://` may be used). + #[error("URL scheme not supported")] + UnsupportedUrlScheme, + /// The URL host name, though included, is empty. + #[error("URL contains empty host name")] + EmptyHostName, + /// The URL does not include a path/query. + #[error("No path/query in URL")] + NoPathOrQuery, +} + +/// Data opcodes as in RFC 6455 +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Data { + /// 0x0 denotes a continuation frame + Continue, + /// 0x1 denotes a text frame + Text, + /// 0x2 denotes a binary frame + Binary, + /// 0x3-7 are reserved for further non-control frames + Reserved(u8), +} + +impl std::fmt::Display for Data { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match *self { + Data::Continue => write!(f, "CONTINUE"), + Data::Text => write!(f, "TEXT"), + Data::Binary => write!(f, "BINARY"), + Data::Reserved(x) => write!(f, "RESERVED_DATA_{}", x), + } + } +} diff --git a/libs/websocket/src/lib.rs b/libs/websocket/src/lib.rs new file mode 100644 index 000000000..996bb41ec --- /dev/null +++ b/libs/websocket/src/lib.rs @@ -0,0 +1,20 @@ +mod error; +mod message; +#[cfg(not(target_arch = "wasm32"))] +mod native; +#[cfg(target_arch = "wasm32")] +mod web; + +pub use error::{Error, Result}; +pub use message::coding::*; +pub use message::CloseFrame; +pub use message::Message; +#[cfg(not(target_arch = "wasm32"))] +use native as ws; +#[cfg(target_arch = "wasm32")] +use web as ws; +pub use ws::WebSocketStream; + +pub async fn connect_async>(url: S) -> Result { + ws::connect_async(url.as_ref()).await +} diff --git a/libs/websocket/src/message.rs b/libs/websocket/src/message.rs new file mode 100644 index 000000000..238f4f053 --- /dev/null +++ b/libs/websocket/src/message.rs @@ -0,0 +1,330 @@ +/// An enum representing the various forms of a WebSocket message. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Message { + /// A text WebSocket message + Text(String), + /// A binary WebSocket message + Binary(Vec), + /// A close message with the optional close frame. + Close(Option>), + Ping(Vec), + Pong(Vec), +} + +impl Message { + /// Create a new text WebSocket message from a stringable. + pub fn text(string: S) -> Message + where + S: Into, + { + Message::Text(string.into()) + } + + /// Create a new binary WebSocket message by converting to Vec. + pub fn binary(bin: B) -> Message + where + B: Into>, + { + Message::Binary(bin.into()) + } + + /// Indicates whether a message is a text message. + pub fn is_text(&self) -> bool { + matches!(*self, Message::Text(_)) + } + + /// Indicates whether a message is a binary message. + pub fn is_binary(&self) -> bool { + matches!(*self, Message::Binary(_)) + } + + /// Indicates whether a message is a ping message. + pub fn is_ping(&self) -> bool { + false + } + + /// Indicates whether a message is a pong message. + pub fn is_pong(&self) -> bool { + false + } + + /// Indicates whether a message ia s close message. + pub fn is_close(&self) -> bool { + matches!(*self, Message::Close(_)) + } + + /// Get the length of the WebSocket message. + pub fn len(&self) -> usize { + match self { + Message::Text(s) => s.len(), + Message::Binary(data) => data.len(), + Message::Close(data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0), + Message::Ping(data) => data.len(), + Message::Pong(data) => data.len(), + } + } + + /// Returns true if the WebSocket message has no content. + /// For example, if the other side of the connection sent an empty string. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Consume the WebSocket and return it as binary data. + pub fn into_data(self) -> Vec { + match self { + Message::Text(string) => string.into_bytes(), + Message::Binary(data) => data, + Message::Close(None) => Vec::new(), + Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(), + Message::Ping(data) => data, + Message::Pong(data) => data, + } + } + + /// Attempt to consume the WebSocket message and convert it to a String. + pub fn into_text(self) -> Result { + match self { + Message::Text(string) => Ok(string), + Message::Binary(data) => Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?), + Message::Close(None) => Ok(String::new()), + Message::Close(Some(frame)) => Ok(frame.reason.into_owned()), + Message::Ping(data) => Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?), + Message::Pong(data) => Ok(String::from_utf8(data).map_err(|err| err.utf8_error())?), + } + } + + /// Attempt to get a &str from the WebSocket message, + /// this will try to convert binary data to utf8. + pub fn to_text(&self) -> Result<&str, crate::Error> { + match self { + Message::Text(s) => Ok(s.as_str()), + Message::Binary(data) => Ok(std::str::from_utf8(data)?), + Message::Close(None) => Ok(""), + Message::Close(Some(ref frame)) => Ok(&frame.reason), + Message::Ping(data) => Ok(std::str::from_utf8(data)?), + Message::Pong(data) => Ok(std::str::from_utf8(data)?), + } + } +} + +impl From for Message { + fn from(string: String) -> Self { + Message::text(string) + } +} + +impl<'s> From<&'s str> for Message { + fn from(string: &'s str) -> Self { + Message::text(string) + } +} + +impl<'b> From<&'b [u8]> for Message { + fn from(data: &'b [u8]) -> Self { + Message::binary(data) + } +} + +impl From> for Message { + fn from(data: Vec) -> Self { + Message::binary(data) + } +} + +impl From for Vec { + fn from(message: Message) -> Self { + message.into_data() + } +} + +impl std::convert::TryFrom for String { + type Error = crate::Error; + + fn try_from(value: Message) -> std::result::Result { + value.into_text() + } +} + +impl std::fmt::Display for Message { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> { + if let Ok(string) = self.to_text() { + write!(f, "{}", string) + } else { + write!(f, "Binary Data", self.len()) + } + } +} + +/// A struct representing the close command. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct CloseFrame<'t> { + /// The reason as a code. + pub code: coding::CloseCode, + /// The reason as text string. + pub reason: std::borrow::Cow<'t, str>, +} + +impl<'t> CloseFrame<'t> { + /// Convert into a owned string. + pub fn into_owned(self) -> CloseFrame<'static> { + CloseFrame { + code: self.code, + reason: self.reason.into_owned().into(), + } + } +} + +impl<'t> std::fmt::Display for CloseFrame<'t> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} ({})", self.reason, self.code) + } +} + +pub mod coding { + use self::CloseCode::*; + + /// Status code used to indicate why an endpoint is closing the WebSocket connection. + #[derive(Debug, Eq, PartialEq, Clone, Copy)] + pub enum CloseCode { + /// Indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + Normal, + /// Indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + Away, + /// Indicates that an endpoint is terminating the connection due + /// to a protocol error. + Protocol, + /// Indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + Unsupported, + /// Indicates that no status code was included in a closing frame. This + /// close code makes it possible to use a single method, `on_close` to + /// handle even cases where no close code was provided. + Status, + /// Indicates an abnormal closure. If the abnormal closure was due to an + /// error, this close code will not be used. Instead, the `on_error` method + /// of the handler will be called with the error. However, if the connection + /// is simply dropped, without an error, this close code will be sent to the + /// handler. + Abnormal, + /// Indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\] + /// data within a text message). + Invalid, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., Unsupported or Size) or if there + /// is a need to hide specific details about the policy. + Policy, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + Size, + /// Indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + Extension, + /// Indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + Error, + /// Indicates that the server is restarting. A client may choose to reconnect, + /// and if it does, it should use a randomized delay of 5-30 seconds between attempts. + Restart, + /// Indicates that the server is overloaded and the client should either connect + /// to a different IP (when multiple targets exist), or reconnect to the same IP + /// when a user has performed an action. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Reserved(u16), + #[doc(hidden)] + Iana(u16), + #[doc(hidden)] + Library(u16), + #[doc(hidden)] + Bad(u16), + } + + impl CloseCode { + /// Check if this CloseCode is allowed. + pub fn is_allowed(self) -> bool { + !matches!(self, Bad(_) | Reserved(_) | Status | Abnormal | Tls) + } + } + + impl std::fmt::Display for CloseCode { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let code: u16 = self.into(); + write!(f, "{}", code) + } + } + + impl From for u16 { + fn from(code: CloseCode) -> u16 { + match code { + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Reserved(code) => code, + Iana(code) => code, + Library(code) => code, + Bad(code) => code, + } + } + } + + impl<'t> From<&'t CloseCode> for u16 { + fn from(code: &'t CloseCode) -> u16 { + (*code).into() + } + } + + impl From for CloseCode { + fn from(code: u16) -> CloseCode { + match code { + 1000 => Normal, + 1001 => Away, + 1002 => Protocol, + 1003 => Unsupported, + 1005 => Status, + 1006 => Abnormal, + 1007 => Invalid, + 1008 => Policy, + 1009 => Size, + 1010 => Extension, + 1011 => Error, + 1012 => Restart, + 1013 => Again, + 1015 => Tls, + 1..=999 => Bad(code), + 1016..=2999 => Reserved(code), + 3000..=3999 => Iana(code), + 4000..=4999 => Library(code), + _ => Bad(code), + } + } + } +} diff --git a/libs/websocket/src/native.rs b/libs/websocket/src/native.rs new file mode 100644 index 000000000..316c8dfa2 --- /dev/null +++ b/libs/websocket/src/native.rs @@ -0,0 +1,244 @@ +use futures_util::{Sink, Stream, StreamExt}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_tungstenite::{ + tungstenite::{ + error::*, + protocol::{frame::coding::Data, CloseFrame}, + Message, Result, + }, + MaybeTlsStream, +}; + +pub async fn connect_async(url: &str) -> crate::Result { + let (inner, _response) = tokio_tungstenite::connect_async(url).await?; + let inner = inner.filter_map(to_fut_message as fn(_) -> _); + Ok(WebSocketStream { inner }) +} + +type TokioTungsteniteStream = + tokio_tungstenite::WebSocketStream>; +type FutMessage = futures_util::future::Ready>>; +pub struct WebSocketStream { + inner: futures_util::stream::FilterMap< + TokioTungsteniteStream, + FutMessage, + fn(Result) -> FutMessage, + >, +} + +impl Stream for WebSocketStream { + type Item = crate::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl Sink for WebSocketStream { + type Error = crate::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into) + } + + fn start_send( + mut self: Pin<&mut Self>, + item: crate::Message, + ) -> std::result::Result<(), Self::Error> { + Pin::new(&mut self.inner) + .start_send(item.into()) + .map_err(Into::into) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into) + } +} + +fn to_fut_message(msg: Result) -> FutMessage { + fn inner(msg: Result) -> Option> { + let msg = match msg { + Ok(msg) => match msg { + Message::Text(inner) => Ok(crate::Message::Text(inner)), + Message::Binary(inner) => Ok(crate::Message::Binary(inner)), + Message::Close(inner) => Ok(crate::Message::Close(inner.map(Into::into))), + Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => return None, + }, + Err(err) => Err(crate::Error::from(err)), + }; + Some(msg) + } + futures_util::future::ready(inner(msg)) +} + +impl<'a> From> for crate::message::CloseFrame<'a> { + fn from(close_frame: CloseFrame<'a>) -> Self { + crate::message::CloseFrame { + code: u16::from(close_frame.code).into(), + reason: close_frame.reason, + } + } +} + +impl<'a> From> for CloseFrame<'a> { + fn from(close_frame: crate::message::CloseFrame<'a>) -> Self { + CloseFrame { + code: u16::from(close_frame.code).into(), + reason: close_frame.reason, + } + } +} + +impl From for crate::Message { + fn from(msg: Message) -> Self { + match msg { + Message::Text(inner) => crate::Message::Text(inner), + Message::Binary(inner) => crate::Message::Binary(inner), + Message::Close(inner) => crate::Message::Close(inner.map(Into::into)), + Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => { + unreachable!("Unsendable via interface.") + }, + } + } +} + +impl From for Message { + fn from(msg: crate::Message) -> Self { + match msg { + crate::Message::Text(inner) => Message::Text(inner), + crate::Message::Binary(inner) => Message::Binary(inner), + crate::Message::Close(inner) => Message::Close(inner.map(Into::into)), + crate::Message::Ping(data) => Message::Ping(data), + crate::Message::Pong(data) => Message::Pong(data), + } + } +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::ConnectionClosed => crate::Error::ConnectionClosed, + Error::AlreadyClosed => crate::Error::AlreadyClosed, + Error::Io(inner) => crate::Error::Io(inner), + Error::Tls(inner) => crate::Error::Tls(inner), + Error::Capacity(inner) => crate::Error::Capacity(inner.into()), + Error::Protocol(inner) => crate::Error::Protocol(inner.into()), + Error::WriteBufferFull(inner) => crate::Error::WriteBufferFull(inner.into()), + Error::Utf8 => crate::Error::Utf8, + Error::AttackAttempt => crate::Error::AttackAttempt, + Error::Url(inner) => crate::Error::Url(inner.into()), + Error::Http(inner) => crate::Error::Http(inner), + Error::HttpFormat(inner) => crate::Error::HttpFormat(inner), + } + } +} + +impl From for crate::error::CapacityError { + fn from(err: CapacityError) -> Self { + match err { + CapacityError::TooManyHeaders => crate::error::CapacityError::TooManyHeaders, + CapacityError::MessageTooLong { size, max_size } => { + crate::error::CapacityError::MessageTooLong { size, max_size } + }, + } + } +} + +impl From for crate::error::UrlError { + fn from(err: UrlError) -> Self { + match err { + UrlError::TlsFeatureNotEnabled => crate::error::UrlError::TlsFeatureNotEnabled, + UrlError::NoHostName => crate::error::UrlError::NoHostName, + UrlError::UnableToConnect(inner) => crate::error::UrlError::UnableToConnect(inner), + UrlError::UnsupportedUrlScheme => crate::error::UrlError::UnsupportedUrlScheme, + UrlError::EmptyHostName => crate::error::UrlError::EmptyHostName, + UrlError::NoPathOrQuery => crate::error::UrlError::NoPathOrQuery, + } + } +} + +impl From for crate::error::ProtocolError { + fn from(err: ProtocolError) -> Self { + match err { + ProtocolError::WrongHttpMethod => crate::error::ProtocolError::WrongHttpMethod, + ProtocolError::WrongHttpVersion => crate::error::ProtocolError::WrongHttpVersion, + ProtocolError::MissingConnectionUpgradeHeader => { + crate::error::ProtocolError::MissingConnectionUpgradeHeader + }, + ProtocolError::MissingUpgradeWebSocketHeader => { + crate::error::ProtocolError::MissingUpgradeWebSocketHeader + }, + ProtocolError::MissingSecWebSocketVersionHeader => { + crate::error::ProtocolError::MissingSecWebSocketVersionHeader + }, + ProtocolError::MissingSecWebSocketKey => crate::error::ProtocolError::MissingSecWebSocketKey, + ProtocolError::SecWebSocketAcceptKeyMismatch => { + crate::error::ProtocolError::SecWebSocketAcceptKeyMismatch + }, + ProtocolError::JunkAfterRequest => crate::error::ProtocolError::JunkAfterRequest, + ProtocolError::CustomResponseSuccessful => { + crate::error::ProtocolError::CustomResponseSuccessful + }, + ProtocolError::InvalidHeader(header_name) => { + crate::error::ProtocolError::InvalidHeader(header_name) + }, + ProtocolError::HandshakeIncomplete => crate::error::ProtocolError::HandshakeIncomplete, + ProtocolError::HttparseError(inner) => crate::error::ProtocolError::HttparseError(inner), + ProtocolError::SendAfterClosing => crate::error::ProtocolError::SendAfterClosing, + ProtocolError::ReceivedAfterClosing => crate::error::ProtocolError::ReceivedAfterClosing, + ProtocolError::NonZeroReservedBits => crate::error::ProtocolError::NonZeroReservedBits, + ProtocolError::UnmaskedFrameFromClient => { + crate::error::ProtocolError::UnmaskedFrameFromClient + }, + ProtocolError::MaskedFrameFromServer => crate::error::ProtocolError::MaskedFrameFromServer, + ProtocolError::FragmentedControlFrame => crate::error::ProtocolError::FragmentedControlFrame, + ProtocolError::ControlFrameTooBig => crate::error::ProtocolError::ControlFrameTooBig, + ProtocolError::UnknownControlFrameType(inner) => { + crate::error::ProtocolError::UnknownControlFrameType(inner) + }, + ProtocolError::UnknownDataFrameType(inner) => { + crate::error::ProtocolError::UnknownDataFrameType(inner) + }, + ProtocolError::UnexpectedContinueFrame => { + crate::error::ProtocolError::UnexpectedContinueFrame + }, + ProtocolError::ExpectedFragment(inner) => { + crate::error::ProtocolError::ExpectedFragment(inner.into()) + }, + ProtocolError::ResetWithoutClosingHandshake => { + crate::error::ProtocolError::ResetWithoutClosingHandshake + }, + ProtocolError::InvalidOpcode(inner) => crate::error::ProtocolError::InvalidOpcode(inner), + ProtocolError::InvalidCloseSequence => crate::error::ProtocolError::InvalidCloseSequence, + } + } +} + +impl From for crate::error::Data { + fn from(data: Data) -> Self { + match data { + Data::Continue => crate::error::Data::Continue, + Data::Text => crate::error::Data::Text, + Data::Binary => crate::error::Data::Binary, + Data::Reserved(inner) => crate::error::Data::Reserved(inner), + } + } +} diff --git a/libs/websocket/src/web.rs b/libs/websocket/src/web.rs new file mode 100644 index 000000000..0c8d82491 --- /dev/null +++ b/libs/websocket/src/web.rs @@ -0,0 +1,246 @@ +use std::{cell::RefCell, collections::VecDeque, rc::Rc, task::Waker}; +use wasm_bindgen::{closure::Closure, JsCast}; +use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket}; + +pub async fn connect_async(url: &str) -> crate::Result { + WebSocketStream::new(url).await +} + +pub struct WebSocketStream { + inner: WebSocket, + queue: Rc>>>, + waker: Rc>>, + _on_message_callback: Closure, + _on_error_callback: Closure, + _on_close_callback: Closure, +} + +impl WebSocketStream { + async fn new(url: &str) -> crate::Result { + match web_sys::WebSocket::new(url) { + Err(_err) => Err(crate::Error::Url( + crate::error::UrlError::UnsupportedUrlScheme, + )), + Ok(ws) => { + ws.set_binary_type(web_sys::BinaryType::Arraybuffer); + + let (open_sx, open_rx) = futures_channel::oneshot::channel(); + let on_open_callback = { + let mut open_sx = Some(open_sx); + Closure::wrap(Box::new(move |_event| { + open_sx.take().map(|open_sx| open_sx.send(())); + }) as Box) + }; + ws.set_onopen(Some(on_open_callback.as_ref().unchecked_ref())); + + let (err_sx, err_rx) = futures_channel::oneshot::channel(); + let on_error_callback = { + let mut err_sx = Some(err_sx); + Closure::wrap(Box::new(move |_error_event| { + err_sx.take().map(|err_sx| err_sx.send(())); + }) as Box) + }; + ws.set_onerror(Some(on_error_callback.as_ref().unchecked_ref())); + + let result = futures_util::future::select(open_rx, err_rx).await; + ws.set_onopen(None); + ws.set_onerror(None); + let ws = match result { + futures_util::future::Either::Left((_, _)) => Ok(ws), + futures_util::future::Either::Right((_, _)) => Err(crate::Error::ConnectionClosed), + }?; + + let waker = Rc::new(RefCell::new(Option::::None)); + let queue = Rc::new(RefCell::new(VecDeque::new())); + let on_message_callback = { + let waker = Rc::clone(&waker); + let queue = Rc::clone(&queue); + Closure::wrap(Box::new(move |event: MessageEvent| { + let payload = std::convert::TryFrom::try_from(event); + queue.borrow_mut().push_back(payload); + if let Some(waker) = waker.borrow_mut().take() { + waker.wake(); + } + }) as Box) + }; + ws.set_onmessage(Some(on_message_callback.as_ref().unchecked_ref())); + + let on_error_callback = { + let waker = Rc::clone(&waker); + let queue = Rc::clone(&queue); + Closure::wrap(Box::new(move |_error_event| { + queue + .borrow_mut() + .push_back(Err(crate::Error::ConnectionClosed)); + if let Some(waker) = waker.borrow_mut().take() { + waker.wake(); + } + }) as Box) + }; + ws.set_onerror(Some(on_error_callback.as_ref().unchecked_ref())); + + let on_close_callback = { + let waker = Rc::clone(&waker); + let queue = Rc::clone(&queue); + Closure::wrap(Box::new(move |event: CloseEvent| { + queue.borrow_mut().push_back(Ok(crate::Message::Close(Some( + crate::message::CloseFrame { + code: event.code().into(), + reason: event.reason().into(), + }, + )))); + if let Some(waker) = waker.borrow_mut().take() { + waker.wake(); + } + }) as Box) + }; + ws.set_onclose(Some(on_close_callback.as_ref().unchecked_ref())); + + Ok(Self { + inner: ws, + queue, + waker, + _on_message_callback: on_message_callback, + _on_error_callback: on_error_callback, + _on_close_callback: on_close_callback, + }) + }, + } + } +} + +impl Drop for WebSocketStream { + fn drop(&mut self) { + let _r = self.inner.close(); + self.inner.set_onmessage(None); + self.inner.set_onclose(None); + self.inner.set_onerror(None); + } +} + +enum ReadyState { + Closed, + Closing, + Connecting, + Open, +} + +impl std::convert::TryFrom for ReadyState { + type Error = (); + + fn try_from(value: u16) -> Result { + match value { + web_sys::WebSocket::CLOSED => Ok(Self::Closed), + web_sys::WebSocket::CLOSING => Ok(Self::Closing), + web_sys::WebSocket::OPEN => Ok(Self::Open), + web_sys::WebSocket::CONNECTING => Ok(Self::Connecting), + _ => Err(()), + } + } +} + +mod stream { + use super::ReadyState; + use std::pin::Pin; + use std::task::{Context, Poll}; + + impl futures_util::Stream for super::WebSocketStream { + type Item = crate::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.queue.borrow().is_empty() { + *self.waker.borrow_mut() = Some(cx.waker().clone()); + + match std::convert::TryFrom::try_from(self.inner.ready_state()) { + Ok(ReadyState::Open) => Poll::Pending, + _ => None.into(), + } + } else { + self.queue.borrow_mut().pop_front().into() + } + } + } + + impl futures_util::Sink for super::WebSocketStream { + type Error = crate::Error; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + match std::convert::TryFrom::try_from(self.inner.ready_state()) { + Ok(ReadyState::Open) => Ok(()).into(), + _ => Err(crate::Error::ConnectionClosed).into(), + } + } + + fn start_send(self: Pin<&mut Self>, item: crate::Message) -> Result<(), Self::Error> { + match std::convert::TryFrom::try_from(self.inner.ready_state()) { + Ok(ReadyState::Open) => { + match item { + crate::Message::Text(text) => self + .inner + .send_with_str(&text) + .map_err(|_| crate::Error::Utf8)?, + crate::Message::Binary(bin) => self + .inner + .send_with_u8_array(&bin) + .map_err(|_| crate::Error::Utf8)?, + crate::Message::Close(frame) => match frame { + None => self + .inner + .close() + .map_err(|_| crate::Error::AlreadyClosed)?, + Some(frame) => self + .inner + .close_with_code_and_reason(frame.code.into(), &frame.reason) + .map_err(|_| crate::Error::AlreadyClosed)?, + }, + crate::Message::Ping(data) => self + .inner + .send_with_u8_array(&data) + .map_err(|_| crate::Error::Utf8)?, + crate::Message::Pong(data) => self + .inner + .send_with_u8_array(&data) + .map_err(|_| crate::Error::Utf8)?, + } + Ok(()) + }, + _ => Err(crate::Error::ConnectionClosed), + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self + .inner + .close() + .map_err(|_| crate::Error::AlreadyClosed)?; + Ok(()).into() + } + } +} + +impl std::convert::TryFrom for crate::Message { + type Error = crate::Error; + + fn try_from(event: MessageEvent) -> Result { + match event.data() { + payload if payload.is_instance_of::() => { + let buffer = js_sys::Uint8Array::new(payload.unchecked_ref()); + let mut v = vec![0; buffer.length() as usize]; + buffer.copy_to(&mut v); + Ok(crate::Message::Binary(v)) + }, + payload if payload.is_string() => match payload.as_string() { + Some(text) => Ok(crate::Message::Text(text)), + None => Err(crate::Error::Utf8), + }, + payload if payload.is_instance_of::() => { + Err(crate::Error::BlobFormatUnsupported) + }, + _ => Err(crate::Error::UnknownFormat), + } + } +} diff --git a/tests/collab/collab_curd_test.rs b/tests/collab/collab_curd_test.rs index e6e566564..3be08983c 100644 --- a/tests/collab/collab_curd_test.rs +++ b/tests/collab/collab_curd_test.rs @@ -1,4 +1,3 @@ -use crate::util::test_client::TestClient; use app_error::ErrorCode; use assert_json_diff::assert_json_include; use collab::core::collab_plugin::EncodedCollab; @@ -14,6 +13,7 @@ use reqwest::Method; use serde::Serialize; use serde_json::json; +use client_api_test_util::TestClient; use shared_entity::response::AppResponse; use uuid::Uuid; diff --git a/tests/collab/edit_permission.rs b/tests/collab/edit_permission.rs index 1934ec99e..b63ff54ea 100644 --- a/tests/collab/edit_permission.rs +++ b/tests/collab/edit_permission.rs @@ -1,4 +1,4 @@ -use crate::util::test_client::{ +use client_api_test_util::{ assert_client_collab, assert_client_collab_include_value, assert_server_collab, TestClient, }; use collab_entity::CollabType; diff --git a/tests/collab/member_crud.rs b/tests/collab/member_crud.rs index d0d1a5e38..052938874 100644 --- a/tests/collab/member_crud.rs +++ b/tests/collab/member_crud.rs @@ -1,5 +1,4 @@ -use crate::collab::workspace_id_from_client; -use crate::user::utils::generate_unique_registered_user_client; +use client_api_test_util::{generate_unique_registered_user_client, workspace_id_from_client}; use collab_entity::CollabType; use database_entity::dto::{ AFAccessLevel, CollabMemberIdentify, CreateCollabParams, InsertCollabMemberParams, diff --git a/tests/collab/mod.rs b/tests/collab/mod.rs index 4ed9a000c..bdb8bdf9e 100644 --- a/tests/collab/mod.rs +++ b/tests/collab/mod.rs @@ -1,5 +1,3 @@ -use client_api::Client; - mod collab_curd_test; mod edit_permission; mod member_crud; @@ -8,14 +6,3 @@ mod single_device_edit; mod snapshot_test; mod storage_test; mod workspace_collab; - -pub(crate) async fn workspace_id_from_client(c: &Client) -> String { - c.get_workspaces() - .await - .unwrap() - .0 - .first() - .unwrap() - .workspace_id - .to_string() -} diff --git a/tests/collab/multi_devices_edit.rs b/tests/collab/multi_devices_edit.rs index 6014b6094..896116301 100644 --- a/tests/collab/multi_devices_edit.rs +++ b/tests/collab/multi_devices_edit.rs @@ -1,9 +1,6 @@ -use crate::user::utils::generate_unique_registered_user; -use crate::util::test_client::{ - assert_client_collab, assert_client_collab_include_value, assert_server_collab, TestClient, -}; use std::time::Duration; +use client_api_test_util::*; use collab_entity::CollabType; use serde_json::json; use sqlx::types::uuid; diff --git a/tests/collab/single_device_edit.rs b/tests/collab/single_device_edit.rs index ce2fa9c0f..20240d13d 100644 --- a/tests/collab/single_device_edit.rs +++ b/tests/collab/single_device_edit.rs @@ -1,8 +1,6 @@ -use crate::util::test_client::{ - assert_client_collab_include_value, assert_server_collab, TestClient, -}; use collab_entity::CollabType; +use client_api_test_util::*; use database_entity::dto::AFAccessLevel; use serde_json::json; use uuid::Uuid; diff --git a/tests/collab/snapshot_test.rs b/tests/collab/snapshot_test.rs index ebfce3364..d3ace72b9 100644 --- a/tests/collab/snapshot_test.rs +++ b/tests/collab/snapshot_test.rs @@ -1,10 +1,10 @@ -use crate::util::test_client::{assert_server_snapshot, TestClient}; use collab::core::collab_plugin::EncodedCollab; use collab::preclude::Collab; use collab_entity::CollabType; use serde_json::{json, Value}; use std::time::Duration; +use client_api_test_util::*; use database::collab::COLLAB_SNAPSHOT_LIMIT; use uuid::Uuid; diff --git a/tests/collab/storage_test.rs b/tests/collab/storage_test.rs index 1f6f9b018..02eeeecdd 100644 --- a/tests/collab/storage_test.rs +++ b/tests/collab/storage_test.rs @@ -1,15 +1,12 @@ -use crate::{ - collab::workspace_id_from_client, user::utils::generate_unique_registered_user_client, -}; -use collab::core::collab_plugin::EncodedCollab; -use std::collections::HashMap; - use app_error::ErrorCode; +use client_api_test_util::*; +use collab::core::collab_plugin::EncodedCollab; use collab_entity::CollabType; use database_entity::dto::{ CreateCollabParams, DeleteCollabParams, QueryCollab, QueryCollabParams, QueryCollabResult, }; use sqlx::types::Uuid; +use std::collections::HashMap; #[tokio::test] async fn success_insert_collab_test() { diff --git a/tests/collab/workspace_collab.rs b/tests/collab/workspace_collab.rs index 001d579e5..74390f631 100644 --- a/tests/collab/workspace_collab.rs +++ b/tests/collab/workspace_collab.rs @@ -1,6 +1,4 @@ -use crate::util::test_client::{ - assert_client_collab, assert_client_collab_include_value, assert_server_collab, TestClient, -}; +use client_api_test_util::*; use collab_entity::CollabType; use database_entity::dto::AFRole; use serde_json::json; diff --git a/tests/gotrue/admin.rs b/tests/gotrue/admin.rs index 4e537636d..5eda6a467 100644 --- a/tests/gotrue/admin.rs +++ b/tests/gotrue/admin.rs @@ -1,15 +1,10 @@ +use client_api_test_util::*; use gotrue::{ api::Client, grant::{Grant, PasswordGrant}, params::{AdminDeleteUserParams, AdminUserParams, GenerateLinkParams}, }; -use crate::{ - localhost_client, - user::utils::{generate_unique_email, ADMIN_USER}, - LOCALHOST_GOTRUE, -}; - #[tokio::test] async fn admin_user_create_list_edit_delete() { let http_client = reqwest::Client::new(); diff --git a/tests/gotrue/health.rs b/tests/gotrue/health.rs index 6f40e09d8..2ec46d57b 100644 --- a/tests/gotrue/health.rs +++ b/tests/gotrue/health.rs @@ -1,7 +1,6 @@ +use client_api_test_util::LOCALHOST_GOTRUE; use gotrue::api::Client; -use crate::LOCALHOST_GOTRUE; - #[tokio::test] async fn gotrue_health() { let http_client = reqwest::Client::new(); diff --git a/tests/gotrue/settings.rs b/tests/gotrue/settings.rs index cb1cf062a..8de07230e 100644 --- a/tests/gotrue/settings.rs +++ b/tests/gotrue/settings.rs @@ -1,14 +1,10 @@ +use client_api_test_util::{generate_unique_email, ADMIN_USER, LOCALHOST_GOTRUE}; use gotrue::{ api::Client, grant::{Grant, PasswordGrant}, params::AdminUserParams, }; -use crate::{ - user::utils::{generate_unique_email, ADMIN_USER}, - LOCALHOST_GOTRUE, -}; - #[tokio::test] async fn gotrue_settings() { let http_client = reqwest::Client::new(); diff --git a/tests/main.rs b/tests/main.rs index dd42e372a..b1c517a4e 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1,47 +1,7 @@ -extern crate core; -use client_api::{Client, ClientConfiguration}; -use dotenvy::dotenv; -use tracing::warn; mod casbin; mod collab; mod gotrue; mod user; -mod util; -mod websocket; mod workspace; -use lazy_static::lazy_static; -use std::{borrow::Cow, env}; - -lazy_static! { - pub static ref LOCALHOST_URL: Cow<'static, str> = - get_env_var("LOCALHOST_URL", "http://localhost:8000"); - pub static ref LOCALHOST_WS: Cow<'static, str> = - get_env_var("LOCALHOST_WS", "ws://localhost:8000/ws"); - pub static ref LOCALHOST_GOTRUE: Cow<'static, str> = - get_env_var("LOCALHOST_GOTRUE", "http://localhost:9999"); -} - -fn get_env_var<'default>(key: &str, default: &'default str) -> Cow<'default, str> { - dotenv().ok(); - match env::var(key) { - Ok(value) => Cow::Owned(value), - Err(_) => { - warn!("could not read env var {}: using default: {}", key, default); - Cow::Borrowed(default) - }, - } -} - -/// Return a client that connects to the local host. It requires to run the server locally. -/// ```shell -/// ./build/run_local_server.sh -/// ``` -pub fn localhost_client() -> Client { - Client::new( - &LOCALHOST_URL, - &LOCALHOST_WS, - &LOCALHOST_GOTRUE, - ClientConfiguration::default(), - ) -} +mod websocket; diff --git a/tests/user/delete.rs b/tests/user/delete.rs index 08b4436bd..415dad7c6 100644 --- a/tests/user/delete.rs +++ b/tests/user/delete.rs @@ -1,12 +1,6 @@ +use client_api_test_util::*; use gotrue::params::{AdminDeleteUserParams, AdminUserParams}; -use crate::{ - localhost_client, - user::utils::{ - admin_user_client, generate_unique_registered_user_client, localhost_gotrue_client, - }, -}; - #[tokio::test] async fn admin_delete_create_same_user_hard() { let (client, user) = generate_unique_registered_user_client().await; diff --git a/tests/user/mod.rs b/tests/user/mod.rs index 09317ff1c..4e8712949 100644 --- a/tests/user/mod.rs +++ b/tests/user/mod.rs @@ -4,4 +4,3 @@ mod sign_in; mod sign_out; mod sign_up; mod update; -pub mod utils; diff --git a/tests/user/refresh.rs b/tests/user/refresh.rs index 55326ed5c..a17853ccd 100644 --- a/tests/user/refresh.rs +++ b/tests/user/refresh.rs @@ -1,9 +1,8 @@ use app_error::AppError; +use client_api_test_util::generate_unique_registered_user_client; use futures::future::join_all; use std::time::{Duration, SystemTime}; -use crate::user::utils::generate_unique_registered_user_client; - #[tokio::test] async fn refresh_success() { let (c, _user) = generate_unique_registered_user_client().await; diff --git a/tests/user/sign_in.rs b/tests/user/sign_in.rs index 47dd65574..23141e574 100644 --- a/tests/user/sign_in.rs +++ b/tests/user/sign_in.rs @@ -1,8 +1,5 @@ -use crate::localhost_client; -use crate::user::utils::{ - generate_sign_in_action_link, generate_unique_email, generate_unique_registered_user, -}; use app_error::ErrorCode; +use client_api_test_util::*; #[tokio::test] async fn sign_in_unknown_user() { diff --git a/tests/user/sign_out.rs b/tests/user/sign_out.rs index d863f59ba..fc660465a 100644 --- a/tests/user/sign_out.rs +++ b/tests/user/sign_out.rs @@ -1,4 +1,4 @@ -use crate::{localhost_client, user::utils::generate_unique_registered_user_client}; +use client_api_test_util::*; #[tokio::test] async fn sign_out_but_not_sign_in() { diff --git a/tests/user/sign_up.rs b/tests/user/sign_up.rs index 5965af174..4c4247c03 100644 --- a/tests/user/sign_up.rs +++ b/tests/user/sign_up.rs @@ -1,11 +1,7 @@ use app_error::ErrorCode; +use client_api_test_util::*; use gotrue_entity::dto::AuthProvider; -use crate::{ - localhost_client, - user::utils::{generate_unique_email, generate_unique_registered_user_client}, -}; - #[tokio::test] async fn sign_up_success() { let email = generate_unique_email(); diff --git a/tests/user/update.rs b/tests/user/update.rs index 8dbdfa674..e75800db5 100644 --- a/tests/user/update.rs +++ b/tests/user/update.rs @@ -1,7 +1,6 @@ -use crate::localhost_client; -use crate::user::utils::generate_unique_registered_user_client; use app_error::ErrorCode; -use client_api::{WSClient, WSClientConfig}; +use client_api::ws::{WSClient, WSClientConfig}; +use client_api_test_util::*; use serde_json::json; use shared_entity::dto::auth_dto::{UpdateUserParams, UserMetaData}; use std::time::Duration; @@ -165,7 +164,7 @@ async fn user_change_notify_test() { let mut user_change_recv = ws_client.subscribe_user_changed(); let device_id = "fake_device_id"; - let _ = ws_client + ws_client .connect(c.ws_url(device_id).await.unwrap(), device_id) .await .unwrap(); diff --git a/tests/websocket/conn_test.rs b/tests/websocket/conn_test.rs new file mode 100644 index 000000000..98cc62628 --- /dev/null +++ b/tests/websocket/conn_test.rs @@ -0,0 +1,89 @@ +use std::time::SystemTime; + +use client_api::ws::{ConnectState, WSClient, WSClientConfig}; +use client_api_test_util::generate_unique_registered_user_client; + +#[tokio::test] +async fn realtime_connect_test() { + let (c, _user) = generate_unique_registered_user_client().await; + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let mut state = ws_client.subscribe_connect_state(); + let device_id = "fake_device_id"; + loop { + tokio::select! { + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Connected { + break; + } + }, + } + } +} + +#[tokio::test] +async fn realtime_connect_after_token_exp_test() { + let (c, _user) = generate_unique_registered_user_client().await; + + // Set the token to be expired + c.token().write().as_mut().unwrap().expires_at = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let mut state = ws_client.subscribe_connect_state(); + let device_id = "fake_device_id"; + loop { + tokio::select! { + _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Connected { + break; + } + }, + } + } +} + +#[tokio::test] +async fn realtime_disconnect_test() { + let (c, _user) = generate_unique_registered_user_client().await; + let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); + let device_id = "fake_device_id"; + ws_client + .connect(c.ws_url(device_id).await.unwrap(), device_id) + .await + .unwrap(); + + let mut state = ws_client.subscribe_connect_state(); + loop { + tokio::select! { + _ = ws_client.disconnect() => {}, + value = state.recv() => { + let new_state = value.unwrap(); + if new_state == ConnectState::Closed { + break; + } + }, + } + } +} + +// use std::time::Duration; +// use tokio_tungstenite::tungstenite::Message; +// #[tokio::test] +// async fn max_frame_size() { +// let (c, _user) = generate_unique_registered_user_client().await; +// let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); +// let device_id = "fake_device_id"; +// ws_client +// .connect(c.ws_url(device_id).unwrap(), device_id) +// .await +// .unwrap(); +// +// ws_client.send(Message::Binary(vec![0; 65536])).unwrap(); +// tokio::time::sleep(Duration::from_secs(5)).await; +// } diff --git a/tests/websocket/mod.rs b/tests/websocket/mod.rs index eb7f65c88..a217458fc 100644 --- a/tests/websocket/mod.rs +++ b/tests/websocket/mod.rs @@ -1 +1 @@ -mod connect; +mod conn_test; diff --git a/tests/workspace/blob/put_and_get.rs b/tests/workspace/blob/put_and_get.rs index 8e83bac6c..b98ab8cd9 100644 --- a/tests/workspace/blob/put_and_get.rs +++ b/tests/workspace/blob/put_and_get.rs @@ -1,6 +1,5 @@ -use crate::collab::workspace_id_from_client; -use crate::user::utils::generate_unique_registered_user_client; use app_error::ErrorCode; +use client_api_test_util::{generate_unique_registered_user_client, workspace_id_from_client}; #[tokio::test] async fn get_but_not_exists() { diff --git a/tests/workspace/blob/usage.rs b/tests/workspace/blob/usage.rs index 2130e20c5..75b9084b2 100644 --- a/tests/workspace/blob/usage.rs +++ b/tests/workspace/blob/usage.rs @@ -1,4 +1,4 @@ -use crate::util::test_client::TestClient; +use client_api_test_util::TestClient; #[tokio::test] async fn workspace_usage_put_blob_test() { diff --git a/tests/workspace/member_crud.rs b/tests/workspace/member_crud.rs index ddb1b62ea..7a4b599cf 100644 --- a/tests/workspace/member_crud.rs +++ b/tests/workspace/member_crud.rs @@ -1,5 +1,5 @@ -use crate::util::test_client::TestClient; use app_error::ErrorCode; +use client_api_test_util::TestClient; use database_entity::dto::AFRole; use shared_entity::dto::workspace_dto::CreateWorkspaceMember; diff --git a/tests/workspace/template_test.rs b/tests/workspace/template_test.rs index 70fa935e5..f44785348 100644 --- a/tests/workspace/template_test.rs +++ b/tests/workspace/template_test.rs @@ -1,6 +1,4 @@ -use crate::localhost_client; -use crate::user::utils::generate_unique_email; -use crate::util::test_client::TestClient; +use client_api_test_util::*; #[tokio::test] async fn get_user_default_workspace_test() {