From a318d2f96dde58dbb760f2ac08d278ae01e1a0a4 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Wed, 24 Jan 2024 05:51:42 +0800 Subject: [PATCH 1/3] test: wasm test (#271) * test: wasm test * ci: run wasm test * fix: wasm websocket connect * chore: add logs * ci: fix --- .../{docker.yml => integration_test.yml} | 12 +++ Cargo.lock | 79 +++++++++++++--- libs/client-api-test-util/Cargo.toml | 8 +- libs/client-api-test-util/src/client.rs | 10 ++ libs/client-api-test-util/src/user.rs | 22 ++--- libs/client-api/Cargo.toml | 12 ++- libs/client-api/src/collab_sync/plugin.rs | 7 +- libs/client-api/src/collab_sync/sink.rs | 6 +- libs/client-api/src/collab_sync/sync.rs | 8 +- libs/client-api/src/http.rs | 10 ++ libs/client-api/src/native/http_native.rs | 6 +- libs/client-api/src/native/mod.rs | 5 +- libs/client-api/src/{ws => native}/ping.rs | 3 +- libs/client-api/src/native/retry.rs | 91 ++++++++++++++++++- libs/client-api/src/wasm/http_wasm.rs | 2 +- libs/client-api/src/wasm/mod.rs | 4 + libs/client-api/src/wasm/ping.rs | 50 ++++++++++ libs/client-api/src/wasm/retry.rs | 12 +++ libs/client-api/src/ws/client.rs | 72 ++++----------- libs/client-api/src/ws/handler.rs | 5 +- libs/client-api/src/ws/mod.rs | 2 - libs/client-api/src/ws/retry.rs | 41 --------- libs/wasm-test/Cargo.toml | 3 + libs/wasm-test/README.md | 25 +++++ libs/wasm-test/tests/conn_test.rs | 23 +++-- libs/wasm-test/tests/main.rs | 7 +- libs/wasm-test/tests/user_test.rs | 10 ++ src/middleware/cors_mw.rs | 5 +- tests/websocket/conn_test.rs | 16 ---- 29 files changed, 378 insertions(+), 178 deletions(-) rename .github/workflows/{docker.yml => integration_test.yml} (85%) rename libs/client-api/src/{ws => native}/ping.rs (98%) create mode 100644 libs/client-api/src/wasm/ping.rs create mode 100644 libs/client-api/src/wasm/retry.rs delete mode 100644 libs/client-api/src/ws/retry.rs create mode 100644 libs/wasm-test/README.md create mode 100644 libs/wasm-test/tests/user_test.rs diff --git a/.github/workflows/docker.yml b/.github/workflows/integration_test.yml similarity index 85% rename from .github/workflows/docker.yml rename to .github/workflows/integration_test.yml index cc24dd081..2efd75fb9 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/integration_test.yml @@ -55,3 +55,15 @@ jobs: - name: Run tests run: | cargo test + + - name: Install Node.js + uses: actions/setup-node@v2 + with: + node-version: '14' + + - name: Run WASM tests + working-directory: ./libs/wasm-test + run: | + cargo install wasm-pack + wasm-pack test --headless --firefox --features="wasm_test" + diff --git a/Cargo.lock b/Cargo.lock index fdf36837e..de9865a61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,7 +20,7 @@ dependencies = [ "futures-util", "log", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "pin-project-lite", "smallvec", "tokio", @@ -390,6 +390,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "again" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05802a5ad4d172eaf796f7047b42d0af9db513585d16d4169660a21613d34b93" +dependencies = [ + "log", + "rand 0.7.3", + "wasm-timer", +] + [[package]] name = "ahash" version = "0.7.6" @@ -1082,7 +1093,7 @@ dependencies = [ "fixedbitset", "getrandom 0.2.10", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "petgraph", "regex", "rhai", @@ -1201,6 +1212,7 @@ checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" name = "client-api" version = "0.1.0" dependencies = [ + "again", "anyhow", "app-error", "async-trait", @@ -1218,7 +1230,7 @@ dependencies = [ "gotrue-entity", "mime", "mime_guess", - "parking_lot", + "parking_lot 0.12.1", "prost", "realtime-entity", "realtime-protocol", @@ -1235,6 +1247,7 @@ dependencies = [ "tracing", "url", "uuid", + "wasm-bindgen-futures", "websocket", "workspace-template", "yrs", @@ -1268,6 +1281,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "web-sys", ] [[package]] @@ -1279,7 +1293,7 @@ dependencies = [ "async-trait", "bincode", "bytes", - "parking_lot", + "parking_lot 0.12.1", "serde", "serde_json", "serde_repr", @@ -1299,7 +1313,7 @@ dependencies = [ "collab", "collab-entity", "nanoid", - "parking_lot", + "parking_lot 0.12.1", "serde", "serde_json", "thiserror", @@ -1331,7 +1345,7 @@ dependencies = [ "chrono", "collab", "collab-entity", - "parking_lot", + "parking_lot 0.12.1", "serde", "serde_json", "serde_repr", @@ -2073,7 +2087,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot", + "parking_lot 0.12.1", ] [[package]] @@ -3191,6 +3205,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.1" @@ -3198,7 +3223,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.8", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -3535,7 +3574,7 @@ checksum = "510c4f1c9d81d556458f94c98f857748130ea9737bbd6053da497503b26ea63c" dependencies = [ "dtoa", "itoa", - "parking_lot", + "parking_lot 0.12.1", "prometheus-client-derive-encode", ] @@ -3839,7 +3878,7 @@ dependencies = [ "database-entity", "futures-util", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "realtime-entity", "realtime-protocol", "reqwest", @@ -4998,7 +5037,7 @@ checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" dependencies = [ "new_debug_unreachable", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "phf_shared 0.10.0", "precomputed-hash", "serde", @@ -5262,7 +5301,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot", + "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2 0.5.5", @@ -5880,9 +5919,25 @@ dependencies = [ "client-api", "client-api-test-util", "tokio", + "wasm-bindgen-futures", "wasm-bindgen-test", ] +[[package]] +name = "wasm-timer" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be0ecb0db480561e9a7642b5d3e4187c128914e58aa84330b9493e3eb68c5e7f" +dependencies = [ + "futures", + "js-sys", + "parking_lot 0.11.2", + "pin-utils", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.64" diff --git a/libs/client-api-test-util/Cargo.toml b/libs/client-api-test-util/Cargo.toml index 9c8616e29..7d2fd4f8c 100644 --- a/libs/client-api-test-util/Cargo.toml +++ b/libs/client-api-test-util/Cargo.toml @@ -29,4 +29,10 @@ uuid = "1.6.1" lazy_static = "1.4.0" dotenv = "0.15.0" reqwest = "0.11.23" -gotrue.workspace = true \ No newline at end of file +gotrue.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +web-sys = { version = "0.3", features = ["console"] } + +[features] +wasm_test = [] \ No newline at end of file diff --git a/libs/client-api-test-util/src/client.rs b/libs/client-api-test-util/src/client.rs index 829b353d4..057ddae8d 100644 --- a/libs/client-api-test-util/src/client.rs +++ b/libs/client-api-test-util/src/client.rs @@ -5,6 +5,7 @@ use std::borrow::Cow; use std::env; use tracing::warn; +#[cfg(not(feature = "wasm_test"))] lazy_static! { pub static ref LOCALHOST_URL: Cow<'static, str> = get_env_var("LOCALHOST_URL", "http://localhost:8000"); @@ -14,6 +15,15 @@ lazy_static! { get_env_var("LOCALHOST_GOTRUE", "http://localhost:9999"); } +// The env vars are not available in wasm32-unknown-unknown +#[cfg(feature = "wasm_test")] +lazy_static! { + pub static ref LOCALHOST_URL: Cow<'static, str> = Cow::Owned("http://localhost".to_string()); + pub static ref LOCALHOST_WS: Cow<'static, str> = Cow::Owned("ws://localhost/ws".to_string()); + pub static ref LOCALHOST_GOTRUE: Cow<'static, str> = + Cow::Owned("http://localhost/gotrue".to_string()); +} + fn get_env_var<'default>(key: &str, default: &'default str) -> Cow<'default, str> { dotenv().ok(); match env::var(key) { diff --git a/libs/client-api-test-util/src/user.rs b/libs/client-api-test-util/src/user.rs index 6882b75d8..445643ef4 100644 --- a/libs/client-api-test-util/src/user.rs +++ b/libs/client-api-test-util/src/user.rs @@ -5,24 +5,12 @@ use dotenv::dotenv; use lazy_static::lazy_static; use uuid::Uuid; -#[cfg(not(target_arch = "wasm32"))] lazy_static! { pub static ref ADMIN_USER: User = { dotenv().ok(); User { - email: std::env::var("GOTRUE_ADMIN_EMAIL").unwrap(), - password: std::env::var("GOTRUE_ADMIN_PASSWORD").unwrap(), - } - }; -} - -#[cfg(target_arch = "wasm32")] -lazy_static! { - pub static ref ADMIN_USER: User = { - dotenv().ok(); - User { - email: "admin@example.com".to_string(), - password: "password".to_string(), + email: std::env::var("GOTRUE_ADMIN_EMAIL").unwrap_or("admin@example.com".to_string()), + password: std::env::var("GOTRUE_ADMIN_PASSWORD").unwrap_or("password".to_string()), } }; } @@ -39,6 +27,12 @@ pub fn generate_unique_email() -> String { pub async fn admin_user_client() -> Client { let admin_client = localhost_client(); + #[cfg(target_arch = "wasm32")] + { + let msg = format!("{}", admin_client); + web_sys::console::log_1(&msg.into()); + } + admin_client .sign_in_password(&ADMIN_USER.email, &ADMIN_USER.password) .await diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index 424994aba..2e363e89a 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -20,7 +20,6 @@ bytes = "1.5" uuid = "1.6.1" futures-util = "0.3.30" futures-core = "0.3.30" -tokio-retry = "0.3" parking_lot = "0.12.1" brotli = "3.4.0" mime_guess = "2.0.4" @@ -45,10 +44,8 @@ database-entity.workspace = true app-error = { workspace = true, features = ["tokio_error", "bincode_error"] } scraper = { version = "0.17.1", optional = true } -[target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2", features = ["js"]} -tokio = { workspace = true, features = ["sync"]} - +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio-retry = "0.3" [target.'cfg(not(target_arch = "wasm32"))'.dependencies.tokio] workspace = true @@ -58,6 +55,11 @@ features = ["sync", "net"] workspace = true features = ["tungstenite"] +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen-futures = "0.4.40" +getrandom = { version = "0.2", features = ["js"]} +tokio = { workspace = true, features = ["sync"]} +again = "0.1.2" [features] collab-sync = ["collab", "yrs"] diff --git a/libs/client-api/src/collab_sync/plugin.rs b/libs/client-api/src/collab_sync/plugin.rs index 9c6912b1a..c0e28f602 100644 --- a/libs/client-api/src/collab_sync/plugin.rs +++ b/libs/client-api/src/collab_sync/plugin.rs @@ -15,6 +15,7 @@ use crate::collab_sync::{SinkConfig, SyncQueue}; use tokio_stream::wrappers::WatchStream; use tracing::trace; +use crate::platform_spawn; use crate::ws::{ConnectState, WSConnectStateReceiver}; use yrs::updates::encoder::Encode; @@ -63,7 +64,7 @@ where ); let mut sync_state_stream = WatchStream::new(sync_queue.subscribe_sync_state()); - tokio::spawn(async move { + platform_spawn(async move { while let Some(new_state) = sync_state_stream.next().await { if let Some(local_collab) = weak_local_collab.upgrade() { if let Some(local_collab) = local_collab.try_lock() { @@ -76,7 +77,7 @@ where let sync_queue = Arc::new(sync_queue); let weak_local_collab = collab; let weak_sync_queue = Arc::downgrade(&sync_queue); - tokio::spawn(async move { + platform_spawn(async move { while let Ok(connect_state) = ws_connect_state.recv().await { match connect_state { ConnectState::Connected => { @@ -132,7 +133,7 @@ where let object_id = self.object.object_id.clone(); let cloned_origin = origin.clone(); - tokio::spawn(async move { + platform_spawn(async move { if let Some(sync_queue) = weak_sync_queue.upgrade() { let payload = Message::Sync(SyncMessage::Update(update)).encode_v1(); sync_queue diff --git a/libs/client-api/src/collab_sync/sink.rs b/libs/client-api/src/collab_sync/sink.rs index 0cee8d2d1..9e62ce144 100644 --- a/libs/client-api/src/collab_sync/sink.rs +++ b/libs/client-api/src/collab_sync/sink.rs @@ -9,8 +9,8 @@ use crate::collab_sync::pending_msg::{MessageState, PendingMsgQueue}; use crate::collab_sync::{SyncError, SyncObject, DEFAULT_SYNC_TIMEOUT}; use futures_util::SinkExt; +use crate::platform_spawn; use realtime_entity::collab_msg::{CollabSinkMessage, MsgId}; -use tokio::spawn; use tokio::sync::{mpsc, oneshot, watch, Mutex}; use tokio::time::{interval, Instant, Interval}; use tracing::{debug, error, event, trace, warn}; @@ -96,7 +96,7 @@ where let weak_notifier = Arc::downgrade(¬ifier); let (tx, rx) = mpsc::channel(1); interval_runner_stop_tx = Some(tx); - spawn(IntervalRunner::new(*duration).run(weak_notifier, rx)); + platform_spawn(IntervalRunner::new(*duration).run(weak_notifier, rx)); } Self { uid, @@ -364,7 +364,7 @@ where } fn retry_later(weak_notifier: Weak>) { - spawn(async move { + platform_spawn(async move { interval(Duration::from_millis(100)).tick().await; if let Some(notifier) = weak_notifier.upgrade() { let _ = notifier.send(false); diff --git a/libs/client-api/src/collab_sync/sync.rs b/libs/client-api/src/collab_sync/sync.rs index 7f58003d5..9c47eb444 100644 --- a/libs/client-api/src/collab_sync/sync.rs +++ b/libs/client-api/src/collab_sync/sync.rs @@ -1,6 +1,7 @@ use crate::collab_sync::{ CollabSink, CollabSinkRunner, SinkConfig, SinkState, SyncError, SyncObject, }; +use crate::platform_spawn; use bytes::Bytes; use collab::core::awareness::Awareness; use collab::core::collab::MutexCollab; @@ -13,7 +14,6 @@ use realtime_protocol::{Message, MessageReader, SyncMessage}; use std::marker::PhantomData; use std::ops::Deref; use std::sync::{Arc, Weak}; -use tokio::spawn; use tokio::sync::watch; use tokio_stream::wrappers::WatchStream; use tracing::{error, event, trace, warn, Level}; @@ -75,7 +75,7 @@ where pause, )); - spawn(CollabSinkRunner::run(Arc::downgrade(&sink), notifier_rx)); + platform_spawn(CollabSinkRunner::run(Arc::downgrade(&sink), notifier_rx)); let cloned_protocol = protocol.clone(); let object_id = object.object_id.clone(); let stream = SyncStream::new( @@ -90,7 +90,7 @@ where let weak_sync_state = Arc::downgrade(&sync_state); let mut sink_state_stream = WatchStream::new(sink_state_rx); // Subscribe the sink state stream and update the sync state in the background. - spawn(async move { + platform_spawn(async move { while let Some(collab_state) = sink_state_stream.next().await { if let Some(sync_state) = weak_sync_state.upgrade() { match collab_state { @@ -209,7 +209,7 @@ where P: CollabSyncProtocol + Send + Sync + 'static, { let cloned_weak_collab = weak_collab.clone(); - spawn(SyncStream::::spawn_doc_stream::

( + platform_spawn(SyncStream::::spawn_doc_stream::

( origin, object_id.clone(), stream, diff --git a/libs/client-api/src/http.rs b/libs/client-api/src/http.rs index 1aac9112b..d67b36b1e 100644 --- a/libs/client-api/src/http.rs +++ b/libs/client-api/src/http.rs @@ -2,6 +2,7 @@ use crate::notify::{ClientToken, TokenStateReceiver}; use anyhow::Context; use brotli::CompressorReader; use gotrue_entity::dto::AuthProvider; +use std::fmt::{Display, Formatter}; use std::io::Read; use app_error::AppError; @@ -1169,6 +1170,15 @@ impl Client { } } +impl Display for Client { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "Client {{ base_url: {}, ws_addr: {}, gotrue_url: {} }}", + self.base_url, self.ws_addr, self.gotrue_client.base_url + )) + } +} + fn url_missing_param(param: &str) -> AppResponseError { AppError::InvalidRequest(format!("Url Missing Parameter:{}", param)).into() } diff --git a/libs/client-api/src/native/http_native.rs b/libs/client-api/src/native/http_native.rs index e40c4197d..4056ac1e2 100644 --- a/libs/client-api/src/native/http_native.rs +++ b/libs/client-api/src/native/http_native.rs @@ -1,7 +1,7 @@ use crate::http::log_request_id; -use crate::retry::{RefreshTokenAction, RefreshTokenRetryCondition}; use crate::ws::{WSClientHttpSender, WSError}; use crate::{spawn_blocking_brotli_compress, Client}; +use crate::{RefreshTokenAction, RefreshTokenRetryCondition}; use app_error::AppError; use async_trait::async_trait; use database_entity::dto::CollabParams; @@ -55,7 +55,7 @@ impl Client { .into_iter() .map(|params| { let config = self.config.clone(); - tokio::spawn(async move { + platform_spawn(async move { let data = params.to_bytes().map_err(AppError::from)?; spawn_blocking_brotli_compress( data, @@ -153,7 +153,7 @@ impl WSClientHttpSender for Client { } } -pub fn spawn(future: T) -> tokio::task::JoinHandle +pub fn platform_spawn(future: T) -> tokio::task::JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, diff --git a/libs/client-api/src/native/mod.rs b/libs/client-api/src/native/mod.rs index 52888cfef..a12859fa1 100644 --- a/libs/client-api/src/native/mod.rs +++ b/libs/client-api/src/native/mod.rs @@ -1,5 +1,8 @@ mod http_native; -pub mod retry; +mod ping; +mod retry; #[allow(unused_imports)] pub use http_native::*; +pub(crate) use ping::*; +pub(crate) use retry::*; diff --git a/libs/client-api/src/ws/ping.rs b/libs/client-api/src/native/ping.rs similarity index 98% rename from libs/client-api/src/ws/ping.rs rename to libs/client-api/src/native/ping.rs index 0c0ddb8a4..70445aba1 100644 --- a/libs/client-api/src/ws/ping.rs +++ b/libs/client-api/src/native/ping.rs @@ -1,3 +1,4 @@ +use crate::platform_spawn; use crate::ws::{ConnectState, ConnectStateNotify}; use std::sync::Arc; use std::time::Duration; @@ -51,7 +52,7 @@ impl ServerFixIntervalPing { let weak_ping_count = Arc::downgrade(&self.ping_count); let weak_state = Arc::downgrade(&self.state); let reconnect_per_ping = self.maximum_ping_count; - tokio::spawn(async move { + platform_spawn(async move { loop { tokio::select! { _ = interval.tick() => { diff --git a/libs/client-api/src/native/retry.rs b/libs/client-api/src/native/retry.rs index e9461c1f8..ef0779595 100644 --- a/libs/client-api/src/native/retry.rs +++ b/libs/client-api/src/native/retry.rs @@ -1,11 +1,16 @@ use crate::notify::ClientToken; +use crate::ws::{ConnectState, ConnectStateNotify, CurrentAddr, StateNotify, WSError}; use app_error::gotrue::GoTrueError; use gotrue::grant::{Grant, RefreshTokenGrant}; use parking_lot::RwLock; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; -use tokio_retry::{Action, Condition}; +use std::sync::{Arc, Weak}; +use std::time::Duration; +use tokio_retry::strategy::FixedInterval; +use tokio_retry::{Action, Condition, RetryIf}; +use tracing::{debug, info}; +use websocket::{connect_async, WebSocketStream}; pub(crate) struct RefreshTokenAction { token: Arc>, @@ -58,3 +63,85 @@ impl Condition for RefreshTokenRetryCondition { error.is_network_error() } } + +pub async fn retry_connect( + addr: &str, + state_notify: Weak, + current_addr: Weak, +) -> Result { + let connecting_addr = addr.to_owned(); + let stream = RetryIf::spawn( + FixedInterval::new(Duration::from_secs(6)), + ConnectAction::new(connecting_addr.clone()), + RetryCondition { + connecting_addr, + current_addr, + state_notify, + }, + ) + .await?; + Ok(stream) +} + +struct ConnectAction { + addr: String, +} + +impl ConnectAction { + fn new(addr: String) -> Self { + Self { addr } + } +} + +impl Action for ConnectAction { + type Future = Pin> + Send + Sync>>; + type Item = WebSocketStream; + type Error = WSError; + + fn run(&mut self) -> Self::Future { + let cloned_addr = self.addr.clone(); + Box::pin(async move { + info!("🔵websocket start connecting"); + match connect_async(&cloned_addr).await { + Ok(stream) => { + info!("🟢websocket connect success"); + Ok(stream) + }, + Err(e) => Err(e.into()), + } + }) + } +} + +struct RetryCondition { + connecting_addr: String, + current_addr: Weak>>, + state_notify: Weak>, +} +impl Condition for RetryCondition { + fn should_retry(&mut self, error: &WSError) -> bool { + if let WSError::AuthError(err) = error { + debug!("{}, stop retry connect", err); + if let Some(state_notify) = self.state_notify.upgrade() { + state_notify.lock().set_state(ConnectState::Unauthorized); + } + + return false; + } + + let should_retry = self + .current_addr + .upgrade() + .map(|addr| match addr.try_lock() { + None => false, + Some(addr) => match &*addr { + None => false, + Some(addr) => addr == &self.connecting_addr, + }, + }) + .unwrap_or(false); + + debug!("WSClient should_retry: {}", should_retry); + should_retry + } +} diff --git a/libs/client-api/src/wasm/http_wasm.rs b/libs/client-api/src/wasm/http_wasm.rs index db8be2287..6c5382349 100644 --- a/libs/client-api/src/wasm/http_wasm.rs +++ b/libs/client-api/src/wasm/http_wasm.rs @@ -61,7 +61,7 @@ impl Client { } } -pub fn spawn(future: T) -> tokio::task::JoinHandle +pub fn platform_spawn(future: T) -> tokio::task::JoinHandle where T: Future + 'static, T::Output: Send + 'static, diff --git a/libs/client-api/src/wasm/mod.rs b/libs/client-api/src/wasm/mod.rs index 3b11fc117..b3550a7e3 100644 --- a/libs/client-api/src/wasm/mod.rs +++ b/libs/client-api/src/wasm/mod.rs @@ -1,3 +1,7 @@ mod http_wasm; +mod ping; +mod retry; pub use http_wasm::*; +pub(crate) use ping::*; +pub(crate) use retry::*; diff --git a/libs/client-api/src/wasm/ping.rs b/libs/client-api/src/wasm/ping.rs new file mode 100644 index 000000000..086cb9236 --- /dev/null +++ b/libs/client-api/src/wasm/ping.rs @@ -0,0 +1,50 @@ +use crate::platform_spawn; +use crate::ws::{ConnectState, ConnectStateNotify}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::broadcast::Sender; +use tokio::sync::mpsc::Receiver; +use tokio::sync::Mutex; +use websocket::Message; +#[allow(dead_code)] +pub(crate) struct ServerFixIntervalPing { + duration: Duration, + ping_sender: Option>, + pong_recv: Option>, + #[allow(dead_code)] + stop_tx: tokio::sync::mpsc::Sender<()>, + stop_rx: Option>, + state: Arc>, + ping_count: Arc>, + maximum_ping_count: u32, +} + +impl ServerFixIntervalPing { + pub(crate) fn new( + duration: Duration, + state: Arc>, + ping_sender: Sender, + pong_recv: Receiver<()>, + maximum_ping_count: u32, + ) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(1000); + Self { + duration, + stop_tx: tx, + stop_rx: Some(rx), + state, + ping_sender: Some(ping_sender), + pong_recv: Some(pong_recv), + ping_count: Arc::new(Mutex::new(0)), + maximum_ping_count, + } + } + + pub(crate) async fn stop(&self) { + let _ = self.stop_tx.send(()).await; + } + + pub(crate) fn run(&mut self) { + // TODO(nathan): implement the ping for wasm + } +} diff --git a/libs/client-api/src/wasm/retry.rs b/libs/client-api/src/wasm/retry.rs new file mode 100644 index 000000000..a5f688eb8 --- /dev/null +++ b/libs/client-api/src/wasm/retry.rs @@ -0,0 +1,12 @@ +use crate::ws::{CurrentAddr, StateNotify, WSError}; +use std::sync::Weak; +use websocket::{connect_async, WebSocketStream}; + +pub async fn retry_connect( + addr: &str, + _state_notify: Weak, + _current_addr: Weak, +) -> Result { + let stream = connect_async(addr).await?; + Ok(stream) +} diff --git a/libs/client-api/src/ws/client.rs b/libs/client-api/src/ws/client.rs index 7f81b7914..cedcff1e3 100644 --- a/libs/client-api/src/ws/client.rs +++ b/libs/client-api/src/ws/client.rs @@ -7,16 +7,13 @@ use std::time::Duration; use tokio::sync::broadcast::{channel, Receiver, Sender}; -use crate::spawn; -use crate::ws::ping::ServerFixIntervalPing; -use crate::ws::retry::ConnectAction; use crate::ws::{ConnectState, ConnectStateNotify, WSError, WebSocketChannel}; +use crate::ServerFixIntervalPing; +use crate::{platform_spawn, retry_connect}; use realtime_entity::collab_msg::CollabMessage; use realtime_entity::message::RealtimeMessage; use realtime_entity::user::UserMessage; use tokio::sync::{oneshot, Mutex}; -use tokio_retry::strategy::FixedInterval; -use tokio_retry::{Condition, RetryIf}; use tracing::{debug, error, info, trace, warn}; use websocket::{CloseCode, CloseFrame, Message}; @@ -49,10 +46,12 @@ type WeakChannel = Weak>; type ChannelByObjectId = HashMap>; pub type WSConnectStateReceiver = Receiver; +pub(crate) type StateNotify = parking_lot::Mutex; +pub(crate) type CurrentAddr = parking_lot::Mutex>; pub struct WSClient { - addr: Arc>>, + addr: Arc, config: WSClientConfig, - state_notify: Arc>, + state_notify: Arc, /// Sender used to send messages to the websocket. sender: Sender, http_sender: Arc, @@ -89,24 +88,26 @@ impl WSClient { pub async fn connect(&self, addr: String, device_id: &str) -> Result<(), WSError> { self.set_state(ConnectState::Connecting).await; + // stop receiving message from client let (stop_tx, mut stop_rx) = oneshot::channel(); if let Some(old_stop_tx) = self.stop_tx.lock().await.take() { let _ = old_stop_tx.send(()); } *self.stop_tx.lock().await = Some(stop_tx); + // stop pinging *self.addr.lock() = Some(addr.clone()); if let Some(old_ping) = self.ping.lock().await.as_ref() { old_ping.stop().await; } - let retry_strategy = FixedInterval::new(Duration::from_secs(6)); - let action = ConnectAction::new(addr.clone()); - let cond = RetryCondition { - connecting_addr: addr, - addr: Arc::downgrade(&self.addr), - state_notify: Arc::downgrade(&self.state_notify), - }; + // start connecting + let conn_result = retry_connect( + &addr, + Arc::downgrade(&self.state_notify), + Arc::downgrade(&self.addr), + ) + .await; // handle websocket error when connecting or sending message let weak_state_notify = Arc::downgrade(&self.state_notify); @@ -122,8 +123,6 @@ impl WSClient { }, } }; - - let conn_result = RetryIf::spawn(retry_strategy, action, cond).await; if let Err(err) = &conn_result { handle_ws_error(err); } @@ -133,7 +132,6 @@ impl WSClient { let (mut sink, mut stream) = ws_stream.split(); let weak_collab_channels = Arc::downgrade(&self.collab_channels); let sender = self.sender.clone(); - let ping_sender = sender.clone(); let (pong_tx, pong_recv) = tokio::sync::mpsc::channel(1); let mut ping = ServerFixIntervalPing::new( @@ -148,7 +146,7 @@ impl WSClient { let user_message_tx = self.user_channel.as_ref().clone(); // Receive messages from the websocket, and send them to the channels. - spawn(async move { + platform_spawn(async move { while let Some(Ok(ws_msg)) = stream.next().await { match ws_msg { Message::Binary(_) => { @@ -213,7 +211,7 @@ impl WSClient { let mut rx = self.sender.subscribe(); let weak_http_sender = Arc::downgrade(&self.http_sender); let device_id = device_id.to_string(); - spawn(async move { + platform_spawn(async move { loop { tokio::select! { _ = &mut stop_rx => break, @@ -226,7 +224,7 @@ impl WSClient { if let Some(http_sender) = weak_http_sender.upgrade() { let cloned_device_id = device_id.clone(); // Spawn a task here in case of blocking the current loop task. - tokio::spawn(async move { + platform_spawn(async move { if let Err(err) = http_sender.send_ws_msg(&cloned_device_id, msg).await { error!("Failed to send WebSocket message over HTTP: {}", err); } @@ -309,37 +307,3 @@ impl WSClient { self.state_notify.lock().set_state(state); } } - -struct RetryCondition { - connecting_addr: String, - addr: Weak>>, - state_notify: Weak>, -} -impl Condition for RetryCondition { - fn should_retry(&mut self, error: &WSError) -> bool { - if let WSError::AuthError(err) = error { - debug!("{}, stop retry connect", err); - - if let Some(state_notify) = self.state_notify.upgrade() { - state_notify.lock().set_state(ConnectState::Unauthorized); - } - - return false; - } - - let should_retry = self - .addr - .upgrade() - .map(|addr| match addr.try_lock() { - None => false, - Some(addr) => match &*addr { - None => false, - Some(addr) => addr == &self.connecting_addr, - }, - }) - .unwrap_or(false); - - debug!("WSClient should_retry: {}", should_retry); - should_retry - } -} diff --git a/libs/client-api/src/ws/handler.rs b/libs/client-api/src/ws/handler.rs index 788e1d313..227d7ef28 100644 --- a/libs/client-api/src/ws/handler.rs +++ b/libs/client-api/src/ws/handler.rs @@ -1,3 +1,4 @@ +use crate::platform_spawn; use futures_util::Sink; use realtime_entity::message::RealtimeMessage; use std::fmt::Debug; @@ -47,7 +48,7 @@ where let (tx, mut rx) = unbounded_channel::(); let cloned_sender = self.sender.clone(); let object_id = self.object_id.clone(); - tokio::spawn(async move { + platform_spawn(async move { while let Some(msg) = rx.recv().await { let realtime_msg: RealtimeMessage = msg.into(); let _ = cloned_sender.send(realtime_msg.into()); @@ -61,7 +62,7 @@ where let (tx, rx) = unbounded_channel::>(); let mut recv = self.receiver.subscribe(); let object_id = self.object_id.clone(); - tokio::spawn(async move { + platform_spawn(async move { while let Ok(msg) = recv.recv().await { if let Err(err) = tx.send(Ok(msg)) { trace!("Failed to send message to channel stream: {}", err); diff --git a/libs/client-api/src/ws/mod.rs b/libs/client-api/src/ws/mod.rs index 041b82716..2a96a887f 100644 --- a/libs/client-api/src/ws/mod.rs +++ b/libs/client-api/src/ws/mod.rs @@ -1,8 +1,6 @@ mod client; mod error; mod handler; -pub(crate) mod ping; -mod retry; mod state; pub use client::*; diff --git a/libs/client-api/src/ws/retry.rs b/libs/client-api/src/ws/retry.rs deleted file mode 100644 index 992bc7df6..000000000 --- a/libs/client-api/src/ws/retry.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::future::Future; -use std::pin::Pin; - -use crate::ws::WSError; -use tokio_retry::Action; -use tracing::info; -use websocket::{connect_async, WebSocketStream}; - -pub(crate) struct ConnectAction { - addr: String, -} - -impl ConnectAction { - pub fn new(addr: String) -> Self { - Self { addr } - } -} - -impl Action for ConnectAction { - #[cfg(not(target_arch = "wasm32"))] - type Future = Pin> + Send + Sync>>; - - #[cfg(target_arch = "wasm32")] - type Future = Pin>>>; - type Item = WebSocketStream; - type Error = WSError; - - fn run(&mut self) -> Self::Future { - let cloned_addr = self.addr.clone(); - Box::pin(async move { - info!("🔵websocket start connecting"); - match connect_async(&cloned_addr).await { - Ok(stream) => { - info!("🟢websocket connect success"); - Ok(stream) - }, - Err(e) => Err(e.into()), - } - }) - } -} diff --git a/libs/wasm-test/Cargo.toml b/libs/wasm-test/Cargo.toml index 65b6acc69..ac8ed2702 100644 --- a/libs/wasm-test/Cargo.toml +++ b/libs/wasm-test/Cargo.toml @@ -12,4 +12,7 @@ wasm-bindgen-test = "0.3.40" client-api-test-util = { path = "../client-api-test-util" } client-api = { path = "../client-api" } tokio = { version = "1", features = ["sync", "macros"] } +wasm-bindgen-futures = "0.4" +[features] +wasm_test = ["client-api-test-util/wasm_test"] \ No newline at end of file diff --git a/libs/wasm-test/README.md b/libs/wasm-test/README.md new file mode 100644 index 000000000..830054f92 --- /dev/null +++ b/libs/wasm-test/README.md @@ -0,0 +1,25 @@ + +## Run test + +before running the test, it requires to install the [chrome driver](https://chromedriver.chromium.org/downloads). +for mac user, you can install it by brew. + +```shell +brew install chromedriver +``` + +then run the test + + +```shell +wasm-pack test --headless --chrome +``` + +## Testing in browser + +```shell +wasm-pack test --chrome +``` + +Ref: +[wasm-bindgen-test](https://rustwasm.github.io/wasm-bindgen/wasm-bindgen-test/browsers.html) \ No newline at end of file diff --git a/libs/wasm-test/tests/conn_test.rs b/libs/wasm-test/tests/conn_test.rs index 6936ef7b4..9a7e2ff50 100644 --- a/libs/wasm-test/tests/conn_test.rs +++ b/libs/wasm-test/tests/conn_test.rs @@ -3,20 +3,23 @@ use client_api_test_util::generate_unique_registered_user_client; use wasm_bindgen_test::wasm_bindgen_test; #[wasm_bindgen_test] -async fn realtime_connect_test() { +async fn wasm_websocket_connect_test() { let (c, _user) = generate_unique_registered_user_client().await; let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); let mut state = ws_client.subscribe_connect_state(); let device_id = "fake_device_id"; - loop { - tokio::select! { - _ = ws_client.connect(c.ws_url(device_id).await.unwrap(), device_id) => {}, - value = state.recv() => { - let new_state = value.unwrap(); - if new_state == ConnectState::Connected { - break; - } - }, + + wasm_bindgen_futures::spawn_local(async move { + ws_client + .connect(c.ws_url(device_id).await.unwrap(), device_id) + .await + .unwrap(); + }); + + // wait for the connect state to be connected + while let Ok(new_state) = state.recv().await { + if new_state == ConnectState::Connected { + break; } } } diff --git a/libs/wasm-test/tests/main.rs b/libs/wasm-test/tests/main.rs index 43b580d5e..631ca531e 100644 --- a/libs/wasm-test/tests/main.rs +++ b/libs/wasm-test/tests/main.rs @@ -1,5 +1,8 @@ use wasm_bindgen_test::wasm_bindgen_test_configure; wasm_bindgen_test_configure!(run_in_browser); -// #[cfg(target_arch = "wasm32")] -// mod conn_test; +#[cfg(target_arch = "wasm32")] +mod conn_test; + +#[cfg(target_arch = "wasm32")] +mod user_test; diff --git a/libs/wasm-test/tests/user_test.rs b/libs/wasm-test/tests/user_test.rs new file mode 100644 index 000000000..7216b1f29 --- /dev/null +++ b/libs/wasm-test/tests/user_test.rs @@ -0,0 +1,10 @@ +use client_api_test_util::{generate_unique_email, localhost_client}; +use wasm_bindgen_test::wasm_bindgen_test; + +#[wasm_bindgen_test] +async fn wasm_sign_up_success() { + let email = generate_unique_email(); + let password = "Hello!123#"; + let c = localhost_client(); + c.sign_up(&email, password).await.unwrap(); +} diff --git a/src/middleware/cors_mw.rs b/src/middleware/cors_mw.rs index 00f7654b8..43659d54e 100644 --- a/src/middleware/cors_mw.rs +++ b/src/middleware/cors_mw.rs @@ -10,7 +10,10 @@ pub fn default_cors() -> Cors { .allow_any_origin() .send_wildcard() .allowed_methods(vec!["GET", "POST", "PUT", "DELETE"]) - .allowed_headers(vec![http::header::ACCEPT]) + .allowed_headers(vec![ + actix_web::http::header::AUTHORIZATION, + actix_web::http::header::ACCEPT, + ]) .allowed_header(http::header::CONTENT_TYPE) .max_age(3600) } diff --git a/tests/websocket/conn_test.rs b/tests/websocket/conn_test.rs index 98cc62628..be91b365e 100644 --- a/tests/websocket/conn_test.rs +++ b/tests/websocket/conn_test.rs @@ -71,19 +71,3 @@ async fn realtime_disconnect_test() { } } } - -// use std::time::Duration; -// use tokio_tungstenite::tungstenite::Message; -// #[tokio::test] -// async fn max_frame_size() { -// let (c, _user) = generate_unique_registered_user_client().await; -// let ws_client = WSClient::new(WSClientConfig::default(), c.clone()); -// let device_id = "fake_device_id"; -// ws_client -// .connect(c.ws_url(device_id).unwrap(), device_id) -// .await -// .unwrap(); -// -// ws_client.send(Message::Binary(vec![0; 65536])).unwrap(); -// tokio::time::sleep(Duration::from_secs(5)).await; -// } From bb103531e252e2f234a64eb71ca3175da780a04f Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 24 Jan 2024 11:31:53 +0800 Subject: [PATCH 2/3] refactor: template --- .../src/document/get_started.rs | 55 +++++++++++++++++++ libs/workspace-template/src/document/mod.rs | 54 +----------------- libs/workspace-template/src/lib.rs | 13 +++-- src/biz/user.rs | 2 + 4 files changed, 67 insertions(+), 57 deletions(-) create mode 100644 libs/workspace-template/src/document/get_started.rs diff --git a/libs/workspace-template/src/document/get_started.rs b/libs/workspace-template/src/document/get_started.rs new file mode 100644 index 000000000..72bd9ad11 --- /dev/null +++ b/libs/workspace-template/src/document/get_started.rs @@ -0,0 +1,55 @@ +use crate::document::parser::JsonToDocumentParser; +use crate::hierarchy_builder::WorkspaceViewBuilder; +use crate::{TemplateData, WorkspaceTemplate}; +use async_trait::async_trait; +use collab::core::collab::MutexCollab; +use collab::core::origin::CollabOrigin; +use collab_document::document::Document; +use collab_entity::CollabType; +use collab_folder::ViewLayout; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// This template generates a document containing a 'read me' guide. +/// It ensures that at least one view is created for the document. +pub struct GetStartedDocumentTemplate; + +#[async_trait] +impl WorkspaceTemplate for GetStartedDocumentTemplate { + fn layout(&self) -> ViewLayout { + ViewLayout::Document + } + + async fn create_workspace_view( + &self, + _uid: i64, + workspace_view_builder: Arc>, + ) -> anyhow::Result { + let view_id = workspace_view_builder + .write() + .await + .with_view_builder(|view_builder| async { + view_builder + .with_name("Getting started") + .with_icon("⭐️") + .build() + }) + .await; + + // create a empty document + let data = tokio::task::spawn_blocking(|| { + let json_str = include_str!("../../assets/read_me.json"); + let document_data = JsonToDocumentParser::json_str_to_document(json_str).unwrap(); + let collab = Arc::new(MutexCollab::new(CollabOrigin::Empty, &view_id, vec![])); + let document = Document::create_with_data(collab, document_data)?; + let data = document.get_collab().encode_collab_v1(); + Ok::<_, anyhow::Error>(TemplateData { + object_id: view_id, + object_type: CollabType::Document, + object_data: data, + }) + }) + .await??; + Ok(data) + } +} diff --git a/libs/workspace-template/src/document/mod.rs b/libs/workspace-template/src/document/mod.rs index d04f22034..effd01a98 100644 --- a/libs/workspace-template/src/document/mod.rs +++ b/libs/workspace-template/src/document/mod.rs @@ -1,54 +1,2 @@ +pub mod get_started; mod parser; - -use crate::document::parser::JsonToDocumentParser; -use crate::hierarchy_builder::WorkspaceViewBuilder; -use crate::{TemplateData, WorkspaceTemplate}; -use async_trait::async_trait; -use collab::core::collab::MutexCollab; -use collab::core::origin::CollabOrigin; -use collab_document::document::Document; -use collab_entity::CollabType; -use std::sync::Arc; -use tokio::sync::RwLock; - -/// A default template for creating documents. -/// -/// This template generates a document containing a 'read me' guide. -/// It ensures that at least one view is created for the document. -pub struct DocumentTemplate; - -#[async_trait] -impl WorkspaceTemplate for DocumentTemplate { - async fn create_workspace_view( - &self, - _uid: i64, - workspace_view_builder: Arc>, - ) -> anyhow::Result { - let view_id = workspace_view_builder - .write() - .await - .with_view_builder(|view_builder| async { - view_builder - .with_name("Getting started") - .with_icon("⭐️") - .build() - }) - .await; - - // create a empty document - let data = tokio::task::spawn_blocking(|| { - let json_str = include_str!("../../assets/read_me.json"); - let document_data = JsonToDocumentParser::json_str_to_document(json_str).unwrap(); - let collab = Arc::new(MutexCollab::new(CollabOrigin::Empty, &view_id, vec![])); - let document = Document::create_with_data(collab, document_data)?; - let data = document.get_collab().encode_collab_v1(); - Ok::<_, anyhow::Error>(TemplateData { - object_id: view_id, - object_type: CollabType::Document, - object_data: data, - }) - }) - .await??; - Ok(data) - } -} diff --git a/libs/workspace-template/src/lib.rs b/libs/workspace-template/src/lib.rs index a8cedc788..616d68f82 100644 --- a/libs/workspace-template/src/lib.rs +++ b/libs/workspace-template/src/lib.rs @@ -1,4 +1,4 @@ -mod document; +pub mod document; mod hierarchy_builder; use crate::hierarchy_builder::{FlattedViews, WorkspaceViewBuilder}; @@ -17,6 +17,8 @@ use tokio::sync::RwLock; #[async_trait] pub trait WorkspaceTemplate { + fn layout(&self) -> ViewLayout; + async fn create_workspace_view( &self, uid: i64, @@ -40,9 +42,7 @@ pub struct WorkspaceTemplateBuilder { impl WorkspaceTemplateBuilder { pub fn new(uid: i64, workspace_id: &str) -> Self { - let mut handlers = WorkspaceTemplateHandlers::default(); - // register the document template handler - handlers.insert(ViewLayout::Document, Arc::new(document::DocumentTemplate)); + let handlers = WorkspaceTemplateHandlers::default(); Self { uid, workspace_id: workspace_id.to_string(), @@ -50,6 +50,11 @@ impl WorkspaceTemplateBuilder { } } + pub fn with_template(mut self, template: T) -> Self { + self.handlers.insert(template.layout(), Arc::new(template)); + self + } + pub async fn default_workspace(&self) -> Result> { let workspace_view_builder = Arc::new(RwLock::new(WorkspaceViewBuilder::new( self.workspace_id.clone(), diff --git a/src/biz/user.rs b/src/biz/user.rs index 1a4b00787..2e89a1875 100644 --- a/src/biz/user.rs +++ b/src/biz/user.rs @@ -25,6 +25,7 @@ use snowflake::Snowflake; use sqlx::{types::uuid, PgPool}; use tokio::sync::RwLock; use tracing::{debug, event, instrument}; +use workspace_template::document::get_started::GetStartedDocumentTemplate; use workspace_template::WorkspaceTemplateBuilder; /// Verify the token from the gotrue server and create the user if it is a new user @@ -89,6 +90,7 @@ where // Create the default workspace for the user. A default workspace might contain multiple // templates, e.g. a document template, a database template, etc. let templates = WorkspaceTemplateBuilder::new(new_uid, &workspace_id) + .with_template(GetStartedDocumentTemplate) .default_workspace() .await?; From bb1029077ce9116c4e895e34bdaeeb59125d74ce Mon Sep 17 00:00:00 2001 From: nathan Date: Wed, 24 Jan 2024 11:46:23 +0800 Subject: [PATCH 3/3] refactor: template --- libs/workspace-template/src/lib.rs | 19 +++++- src/biz/user.rs | 100 ++++++++++++++++++----------- 2 files changed, 79 insertions(+), 40 deletions(-) diff --git a/libs/workspace-template/src/lib.rs b/libs/workspace-template/src/lib.rs index 616d68f82..d09242d26 100644 --- a/libs/workspace-template/src/lib.rs +++ b/libs/workspace-template/src/lib.rs @@ -34,6 +34,8 @@ pub struct TemplateData { pub type WorkspaceTemplateHandlers = HashMap>; +/// A builder for creating a workspace template. +/// workspace template is a set of views that are created when a workspace is created. pub struct WorkspaceTemplateBuilder { pub uid: i64, pub workspace_id: String, @@ -50,12 +52,25 @@ impl WorkspaceTemplateBuilder { } } - pub fn with_template(mut self, template: T) -> Self { + pub fn with_template(mut self, template: T) -> Self + where + T: WorkspaceTemplate + Send + Sync + 'static, + { self.handlers.insert(template.layout(), Arc::new(template)); self } - pub async fn default_workspace(&self) -> Result> { + pub fn with_templates(mut self, templates: Vec) -> Self + where + T: WorkspaceTemplate + Send + Sync + 'static, + { + for template in templates { + self.handlers.insert(template.layout(), Arc::new(template)); + } + self + } + + pub async fn build(&self) -> Result> { let workspace_view_builder = Arc::new(RwLock::new(WorkspaceViewBuilder::new( self.workspace_id.clone(), self.uid, diff --git a/src/biz/user.rs b/src/biz/user.rs index 2e89a1875..bddeffc72 100644 --- a/src/biz/user.rs +++ b/src/biz/user.rs @@ -22,11 +22,11 @@ use database::user::{create_user, is_user_exist}; use realtime::entities::RealtimeUser; use shared_entity::dto::auth_dto::UpdateUserParams; use snowflake::Snowflake; -use sqlx::{types::uuid, PgPool}; +use sqlx::{types::uuid, PgPool, Transaction}; use tokio::sync::RwLock; use tracing::{debug, event, instrument}; use workspace_template::document::get_started::GetStartedDocumentTemplate; -use workspace_template::WorkspaceTemplateBuilder; +use workspace_template::{WorkspaceTemplate, WorkspaceTemplateBuilder}; /// Verify the token from the gotrue server and create the user if it is a new user /// Return true if the user is a new user @@ -87,42 +87,15 @@ where ) .await?; - // Create the default workspace for the user. A default workspace might contain multiple - // templates, e.g. a document template, a database template, etc. - let templates = WorkspaceTemplateBuilder::new(new_uid, &workspace_id) - .with_template(GetStartedDocumentTemplate) - .default_workspace() - .await?; - - debug!("create {} templates for user:{}", templates.len(), new_uid); - for template in templates { - let object_id = template.object_id; - let encoded_collab_v1 = template - .object_data - .encode_to_bytes() - .map_err(|err| AppError::Internal(anyhow::Error::from(err)))?; - - collab_access_control - .cache_collab_access_level( - realtime::collaborate::CollabUserId::UserId(&new_uid), - &object_id, - AFAccessLevel::FullAccess, - ) - .await?; - - insert_into_af_collab( - &mut txn, - &new_uid, - &workspace_id, - &CollabParams { - object_id, - encoded_collab_v1, - collab_type: template.object_type, - override_if_exist: false, - }, - ) - .await?; - } + // Create a workspace with the GetStarted template + create_workspace_for_user( + new_uid, + &workspace_id, + collab_access_control, + &mut txn, + vec![GetStartedDocumentTemplate], + ) + .await?; } txn .commit() @@ -131,6 +104,57 @@ where Ok(is_new) } +/// Create a workspace for a user. +/// This function generates a workspace along with its templates and stores them in the database. +/// Each template is stored as an individual collaborative object. +async fn create_workspace_for_user( + new_uid: i64, + workspace_id: &str, + collab_access_control: &C, + txn: &mut Transaction<'_, sqlx::Postgres>, + templates: Vec, +) -> Result<(), AppError> +where + C: CollabAccessControl, + T: WorkspaceTemplate + Send + Sync + 'static, +{ + let templates = WorkspaceTemplateBuilder::new(new_uid, workspace_id) + .with_templates(templates) + .build() + .await?; + + debug!("create {} templates for user:{}", templates.len(), new_uid); + for template in templates { + let object_id = template.object_id; + let encoded_collab_v1 = template + .object_data + .encode_to_bytes() + .map_err(|err| AppError::Internal(anyhow::Error::from(err)))?; + + collab_access_control + .cache_collab_access_level( + realtime::collaborate::CollabUserId::UserId(&new_uid), + &object_id, + AFAccessLevel::FullAccess, + ) + .await?; + + insert_into_af_collab( + txn, + &new_uid, + workspace_id, + &CollabParams { + object_id, + encoded_collab_v1, + collab_type: template.object_type, + override_if_exist: false, + }, + ) + .await?; + } + Ok(()) +} + pub async fn get_profile(pg_pool: &PgPool, uuid: &Uuid) -> Result { let row = select_user_profile(pg_pool, uuid) .await?