diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index fc539dc29..3ab55014b 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -69,7 +69,7 @@ jobs: matrix: include: - test_service: "appflowy_cloud" - test_cmd: "--workspace --exclude appflowy-history --exclude appflowy-ai-client --features ai-test-enabled" + test_cmd: "--workspace --exclude appflowy-ai-client --features ai-test-enabled" - test_service: "appflowy_worker" test_cmd: "-p appflowy-worker" - test_service: "admin_frontend" @@ -130,7 +130,7 @@ jobs: - name: Run Tests run: | echo "Running tests for ${{ matrix.test_service }} with flags: ${{ matrix.test_cmd }}" - RUST_LOG="info" DISABLE_CI_TEST_LOG="true" cargo test ${{ matrix.test_cmd }} + RUST_LOG="info" DISABLE_CI_TEST_LOG="true" cargo test ${{ matrix.test_cmd }} -- --skip stress_test - name: Server Logs if: failure() diff --git a/.github/workflows/push_latest_docker.yml b/.github/workflows/push_latest_docker.yml index 361867969..3d451a945 100644 --- a/.github/workflows/push_latest_docker.yml +++ b/.github/workflows/push_latest_docker.yml @@ -239,7 +239,6 @@ jobs: if: always() run: docker logout - appflowy_worker_image: runs-on: ubuntu-22.04 env: diff --git a/.github/workflows/stress_test.yml b/.github/workflows/stress_test.yml new file mode 100644 index 000000000..267a15eea --- /dev/null +++ b/.github/workflows/stress_test.yml @@ -0,0 +1,49 @@ +name: AppFlowy-Cloud Stress Test + +on: [ pull_request ] + +concurrency: + group: stress-test-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: false + +env: + POSTGRES_HOST: localhost + REDIS_HOST: localhost + MINIO_HOST: localhost + SQLX_OFFLINE: true + RUST_TOOLCHAIN: "1.78" + +jobs: + test: + name: Collab Stress Tests + runs-on: self-hosted-appflowy3 + + steps: + - name: Checkout Repository + uses: actions/checkout@v3 + + - name: Install Rust Toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Copy and Rename deploy.env to .env + run: cp deploy.env .env + + - name: Replace Values in .env + run: | + sed -i '' 's|RUST_LOG=.*|RUST_LOG=debug|' .env + sed -i '' 's|API_EXTERNAL_URL=.*|API_EXTERNAL_URL=http://localhost:9999|' .env + sed -i '' 's|APPFLOWY_GOTRUE_BASE_URL=.*|APPFLOWY_GOTRUE_BASE_URL=http://localhost:9999|' .env + shell: bash + + - name: Start Docker Compose Services + run: | + docker compose -f docker-compose-stress-test.yml up -d + docker ps -a + + - name: Install Prerequisites + run: | + brew install protobuf + + - name: Run Server and Test + run: | + cargo run --package xtask -- --stress-test diff --git a/.sqlx/query-88516b9a2a424bc7697337d6f16b0d6e94b919597d709f930467423c5b4c0ec2.json b/.sqlx/query-88516b9a2a424bc7697337d6f16b0d6e94b919597d709f930467423c5b4c0ec2.json new file mode 100644 index 000000000..e2d0093d2 --- /dev/null +++ b/.sqlx/query-88516b9a2a424bc7697337d6f16b0d6e94b919597d709f930467423c5b4c0ec2.json @@ -0,0 +1,65 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT * FROM af_collab_snapshot\n WHERE workspace_id = $1 AND oid = $2 AND deleted_at IS NULL\n ORDER BY created_at DESC\n LIMIT 1;\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "sid", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "oid", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "blob", + "type_info": "Bytea" + }, + { + "ordinal": 3, + "name": "len", + "type_info": "Int4" + }, + { + "ordinal": 4, + "name": "encrypt", + "type_info": "Int4" + }, + { + "ordinal": 
5, + "name": "deleted_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "workspace_id", + "type_info": "Uuid" + }, + { + "ordinal": 7, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ] + }, + "hash": "88516b9a2a424bc7697337d6f16b0d6e94b919597d709f930467423c5b4c0ec2" +} diff --git a/Cargo.lock b/Cargo.lock index 15305ed3d..2a6ad2168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -638,6 +638,7 @@ dependencies = [ "derive_more", "dotenvy", "fancy-regex 0.11.0", + "flate2", "futures", "futures-lite", "futures-util", @@ -716,6 +717,7 @@ dependencies = [ "anyhow", "app-error", "appflowy-ai-client", + "arc-swap", "async-stream", "async-trait", "authentication", @@ -2354,11 +2356,15 @@ name = "collab-stream" version = "0.1.0" dependencies = [ "anyhow", + "async-stream", + "async-trait", "bincode", "bytes", "chrono", + "collab", "collab-entity", "futures", + "loole", "prost 0.13.3", "rand 0.8.5", "redis 0.25.4", @@ -2369,6 +2375,7 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "zstd 0.13.2", ] [[package]] @@ -3323,9 +3330,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -3354,9 +3361,9 @@ checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -4514,6 +4521,16 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loole" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2998397c725c822c6b2ba605fd9eb4c6a7a0810f1629ba3cc232ef4f0308d96" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "lru" version = "0.12.4" @@ -8666,6 +8683,7 @@ name = "xtask" version = "0.1.0" dependencies = [ "anyhow", + "futures", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index d737d9b1e..69941a014 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -159,6 +159,7 @@ http.workspace = true indexer.workspace = true [dev-dependencies] +flate2 = "1.0" once_cell = "1.19.0" tempfile = "3.9.0" assert-json-diff = "2.0.2" @@ -214,7 +215,6 @@ members = [ "libs/appflowy-ai-client", "libs/client-api-entity", # services - #"services/appflowy-history", "services/appflowy-collaborate", "services/appflowy-worker", # xtask @@ -283,6 +283,7 @@ sanitize-filename = "0.5.0" base64 = "0.22" md5 = "0.7.0" pin-project = "1.1.5" +arc-swap = { version = "1.7" } validator = "0.19" zstd = { version = "0.13.2", features = [] } chrono = { version = "0.4.39", features = [ diff --git a/deploy.env b/deploy.env index 59de9edd2..fe4667740 100644 --- a/deploy.env +++ b/deploy.env @@ -4,7 +4,7 @@ # PostgreSQL Settings POSTGRES_HOST=postgres POSTGRES_USER=postgres 
-POSTGRES_PASSWORD=changepassword +POSTGRES_PASSWORD=password POSTGRES_PORT=5432 POSTGRES_DB=postgres @@ -15,6 +15,10 @@ SUPABASE_PASSWORD=root REDIS_HOST=redis REDIS_PORT=6379 +# Minio Host +MINIO_HOST=minio +MINIO_PORT=9000 + # AppFlowy Cloud ## URL that connects to the gotrue docker container APPFLOWY_GOTRUE_BASE_URL=http://gotrue:9999 @@ -69,11 +73,12 @@ GOTRUE_DISABLE_SIGNUP=false # If you are using a different domain, you need to change the redirect_uri in the OAuth2 configuration # Make sure that this domain is accessible to the user # Make sure no endswith / +# Replace with your host name instead of localhost API_EXTERNAL_URL=http://your-host # In docker environment, `postgres` is the hostname of the postgres service # GoTrue connect to postgres using this url -GOTRUE_DATABASE_URL=postgres://supabase_auth_admin:${SUPABASE_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB} +GOTRUE_DATABASE_URL=postgres://supabase_auth_admin:${SUPABASE_PASSWORD}@postgres:${POSTGRES_PORT}/${POSTGRES_DB} # Refer to this for details: https://github.com/AppFlowy-IO/AppFlowy-Cloud/blob/main/doc/AUTHENTICATION.md # Google OAuth2 @@ -105,7 +110,7 @@ APPFLOWY_S3_CREATE_BUCKET=true # By default, Minio is used as the default file storage which uses host's file system. # Keep this as true if you are using other S3 compatible storage provider other than AWS. APPFLOWY_S3_USE_MINIO=true -APPFLOWY_S3_MINIO_URL=http://minio:9000 # change this if you are using a different address for minio +APPFLOWY_S3_MINIO_URL=http://${MINIO_HOST}:${MINIO_PORT} # change this if you are using a different address for minio APPFLOWY_S3_ACCESS_KEY=minioadmin APPFLOWY_S3_SECRET_KEY=minioadmin APPFLOWY_S3_BUCKET=appflowy diff --git a/docker-compose-stress-test.yml b/docker-compose-stress-test.yml new file mode 100644 index 000000000..58a65e641 --- /dev/null +++ b/docker-compose-stress-test.yml @@ -0,0 +1,97 @@ +services: + nginx: + restart: on-failure + image: nginx + ports: + - 80:80 # Disable this if you are using TLS + - 443:443 + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf + - ./nginx/ssl/certificate.crt:/etc/nginx/ssl/certificate.crt + - ./nginx/ssl/private_key.key:/etc/nginx/ssl/private_key.key + minio: + restart: on-failure + image: minio/minio + ports: + - 9000:9000 + - 9001:9001 + environment: + - MINIO_BROWSER_REDIRECT_URL=http://localhost:9001 + command: server /data --console-address ":9001" + + postgres: + restart: on-failure + image: pgvector/pgvector:pg16 + environment: + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_DB=${POSTGRES_DB:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} + - POSTGRES_HOST=${POSTGRES_HOST:-postgres} + - SUPABASE_USER=${SUPABASE_USER:-supabase_auth_admin} + - SUPABASE_PASSWORD=${SUPABASE_PASSWORD:-root} + ports: + - 5432:5432 + volumes: + - ./migrations/before:/docker-entrypoint-initdb.d + # comment out the following line if you want to persist data when restarting docker + #- postgres_data:/var/lib/postgresql/data + + redis: + restart: on-failure + image: redis + ports: + - 6379:6379 + + gotrue: + restart: on-failure + image: supabase/gotrue:v2.159.1 + depends_on: + - postgres + environment: + # Gotrue config: https://github.com/supabase/gotrue/blob/master/example.env + - GOTRUE_SITE_URL=appflowy-flutter:// # redirected to AppFlowy application + - URI_ALLOW_LIST=* # adjust restrict if necessary + - GOTRUE_JWT_SECRET=${GOTRUE_JWT_SECRET} # authentication secret + - GOTRUE_JWT_EXP=${GOTRUE_JWT_EXP} + - GOTRUE_DB_DRIVER=postgres + - 
API_EXTERNAL_URL=${API_EXTERNAL_URL} + - DATABASE_URL=${GOTRUE_DATABASE_URL} + - PORT=9999 + - GOTRUE_MAILER_URLPATHS_CONFIRMATION=/verify + - GOTRUE_SMTP_HOST=${GOTRUE_SMTP_HOST} # e.g. smtp.gmail.com + - GOTRUE_SMTP_PORT=${GOTRUE_SMTP_PORT} # e.g. 465 + - GOTRUE_SMTP_USER=${GOTRUE_SMTP_USER} # email sender, e.g. noreply@appflowy.io + - GOTRUE_SMTP_PASS=${GOTRUE_SMTP_PASS} # email password + - GOTRUE_SMTP_ADMIN_EMAIL=${GOTRUE_SMTP_ADMIN_EMAIL} # email with admin privileges e.g. internal@appflowy.io + - GOTRUE_SMTP_MAX_FREQUENCY=${GOTRUE_SMTP_MAX_FREQUENCY:-1ns} # set to 1ns for running tests + - GOTRUE_RATE_LIMIT_EMAIL_SENT=${GOTRUE_RATE_LIMIT_EMAIL_SENT:-100} # number of email sendable per minute + - GOTRUE_MAILER_AUTOCONFIRM=${GOTRUE_MAILER_AUTOCONFIRM:-false} # change this to true to skip email confirmation + # Google OAuth config + - GOTRUE_EXTERNAL_GOOGLE_ENABLED=${GOTRUE_EXTERNAL_GOOGLE_ENABLED} + - GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID=${GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID} + - GOTRUE_EXTERNAL_GOOGLE_SECRET=${GOTRUE_EXTERNAL_GOOGLE_SECRET} + - GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI=${GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI} + # Apple OAuth config + - GOTRUE_EXTERNAL_APPLE_ENABLED=${GOTRUE_EXTERNAL_APPLE_ENABLED} + - GOTRUE_EXTERNAL_APPLE_CLIENT_ID=${GOTRUE_EXTERNAL_APPLE_CLIENT_ID} + - GOTRUE_EXTERNAL_APPLE_SECRET=${GOTRUE_EXTERNAL_APPLE_SECRET} + - GOTRUE_EXTERNAL_APPLE_REDIRECT_URI=${GOTRUE_EXTERNAL_APPLE_REDIRECT_URI} + # GITHUB OAuth config + - GOTRUE_EXTERNAL_GITHUB_ENABLED=${GOTRUE_EXTERNAL_GITHUB_ENABLED} + - GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=${GOTRUE_EXTERNAL_GITHUB_CLIENT_ID} + - GOTRUE_EXTERNAL_GITHUB_SECRET=${GOTRUE_EXTERNAL_GITHUB_SECRET} + - GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI=${GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI} + # Discord OAuth config + - GOTRUE_EXTERNAL_DISCORD_ENABLED=${GOTRUE_EXTERNAL_DISCORD_ENABLED} + - GOTRUE_EXTERNAL_DISCORD_CLIENT_ID=${GOTRUE_EXTERNAL_DISCORD_CLIENT_ID} + - GOTRUE_EXTERNAL_DISCORD_SECRET=${GOTRUE_EXTERNAL_DISCORD_SECRET} + - GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI=${GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI} + # Prometheus Metrics + - GOTRUE_METRICS_ENABLED=true + - GOTRUE_METRICS_EXPORTER=prometheus + - GOTRUE_MAILER_TEMPLATES_CONFIRMATION=${GOTRUE_MAILER_TEMPLATES_CONFIRMATION} + ports: + - 9999:9999 + +volumes: + postgres_data: diff --git a/libs/client-api/Cargo.toml b/libs/client-api/Cargo.toml index 7a24d4363..49a1e9489 100644 --- a/libs/client-api/Cargo.toml +++ b/libs/client-api/Cargo.toml @@ -40,7 +40,7 @@ serde_json.workspace = true serde.workspace = true app-error = { workspace = true, features = ["tokio_error", "bincode_error"] } scraper = { version = "0.17.1", optional = true } -arc-swap = "1.7" +arc-swap.workspace = true shared-entity = { workspace = true } collab-rt-entity = { workspace = true } diff --git a/libs/client-api/src/collab_sync/plugin.rs b/libs/client-api/src/collab_sync/plugin.rs index 6f4f51496..8e2a51edf 100644 --- a/libs/client-api/src/collab_sync/plugin.rs +++ b/libs/client-api/src/collab_sync/plugin.rs @@ -185,7 +185,7 @@ where _event: &Event, update: &AwarenessUpdate, ) { - let payload = Message::Awareness(update.clone()).encode_v1(); + let payload = Message::Awareness(update.encode_v1()).encode_v1(); self.sync_queue.queue_msg(|msg_id| { let update_sync = UpdateSync::new(origin.clone(), object_id.to_string(), payload, msg_id); if cfg!(feature = "sync_verbose_log") { diff --git a/libs/client-api/src/collab_sync/sync_control.rs b/libs/client-api/src/collab_sync/sync_control.rs index 07510841c..8fe0e530e 100644 --- 
a/libs/client-api/src/collab_sync/sync_control.rs +++ b/libs/client-api/src/collab_sync/sync_control.rs @@ -199,7 +199,11 @@ where return Ok(false); } - trace!("🔥{} start sync, reason:{}", &sync_object.object_id, reason); + tracing::debug!( + "🔥{} restart sync due to missing update, reason:{}", + &sync_object.object_id, + reason + ); let awareness = collab.get_awareness(); let payload = gen_sync_state(awareness, &ClientSyncProtocol)?; sink.queue_init_sync(|msg_id| { @@ -236,8 +240,8 @@ where SyncReason::CollabInitialize | SyncReason::ServerCannotApplyUpdate | SyncReason::NetworkResume => { - trace!( - "🔥{} start sync, reason: {}", + tracing::debug!( + "🔥{} resume network, reason: {}", &sync_object.object_id, reason ); diff --git a/libs/collab-rt-entity/src/server_message.rs b/libs/collab-rt-entity/src/server_message.rs index e5bc26de0..94838bb65 100644 --- a/libs/collab-rt-entity/src/server_message.rs +++ b/libs/collab-rt-entity/src/server_message.rs @@ -245,12 +245,13 @@ impl AckMeta { impl Display for CollabAck { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( - "ack: [uid:{}|oid:{}|msg_id:{:?}|len:{}|code:{}]", + "ack: [uid:{}|oid:{}|msg_id:{:?}|len:{}|code:{}|seq_nr:{}]", self.origin.client_user_id().unwrap_or(0), self.object_id, self.msg_id, self.payload.len(), self.code, + self.seq_num )) } } diff --git a/libs/collab-rt-protocol/src/message.rs b/libs/collab-rt-protocol/src/message.rs index 1b659b7e3..557555521 100644 --- a/libs/collab-rt-protocol/src/message.rs +++ b/libs/collab-rt-protocol/src/message.rs @@ -1,6 +1,5 @@ use std::fmt::{Debug, Display, Formatter}; -use collab::core::awareness::AwarenessUpdate; use serde::{Deserialize, Serialize}; use thiserror::Error; use yrs::updates::decoder::{Decode, Decoder}; @@ -22,7 +21,7 @@ pub const PERMISSION_GRANTED: u8 = 1; pub enum Message { Sync(SyncMessage), Auth(Option), - Awareness(AwarenessUpdate), + Awareness(Vec), Custom(CustomMessage), } @@ -44,7 +43,7 @@ impl Encode for Message { }, Message::Awareness(update) => { encoder.write_var(MSG_AWARENESS); - encoder.write_buf(update.encode_v1()) + encoder.write_buf(update) }, Message::Custom(msg) => { encoder.write_var(MSG_CUSTOM); @@ -64,8 +63,7 @@ impl Decode for Message { }, MSG_AWARENESS => { let data = decoder.read_buf()?; - let update = AwarenessUpdate::decode_v1(data)?; - Ok(Message::Awareness(update)) + Ok(Message::Awareness(data.into())) }, MSG_AUTH => { let reason = if decoder.read_var::()? == PERMISSION_DENIED { diff --git a/libs/collab-rt-protocol/src/protocol.rs b/libs/collab-rt-protocol/src/protocol.rs index eb850e20c..3a90e49eb 100644 --- a/libs/collab-rt-protocol/src/protocol.rs +++ b/libs/collab-rt-protocol/src/protocol.rs @@ -93,6 +93,7 @@ pub trait CollabSyncProtocol { Message::Auth(reason) => self.handle_auth(collab, reason).await, //FIXME: where is the QueryAwareness protocol? Message::Awareness(update) => { + let update = AwarenessUpdate::decode_v1(&update)?; self .handle_awareness_update(message_origin, collab, update) .await @@ -117,7 +118,7 @@ pub trait CollabSyncProtocol { .map_err(|e| RTProtocolError::YrsTransaction(e.to_string()))? .state_vector(); let awareness_update = awareness.update()?; - (state_vector, awareness_update) + (state_vector, awareness_update.encode_v1()) }; // 1. 
encode doc state vector @@ -220,14 +221,14 @@ pub trait CollabSyncProtocol { } } -const LARGE_UPDATE_THRESHOLD: usize = 1024 * 1024; // 1MB +pub const LARGE_UPDATE_THRESHOLD: usize = 1024 * 1024; // 1MB #[inline] -pub async fn decode_update(update: Vec) -> Result { +pub async fn decode_update(update: Vec) -> Result { let update = if update.len() > LARGE_UPDATE_THRESHOLD { spawn_blocking(move || Update::decode_v1(&update)) .await - .map_err(|err| RTProtocolError::Internal(err.into()))? + .map_err(|err| yrs::encoding::read::Error::Custom(err.to_string()))? } else { Update::decode_v1(&update) }?; diff --git a/libs/collab-stream/Cargo.toml b/libs/collab-stream/Cargo.toml index 78f234430..12bcb1d2e 100644 --- a/libs/collab-stream/Cargo.toml +++ b/libs/collab-stream/Cargo.toml @@ -16,12 +16,16 @@ tracing = "0.1" serde = { version = "1", features = ["derive"] } bincode = "1.3.3" bytes.workspace = true +collab.workspace = true collab-entity.workspace = true serde_json.workspace = true chrono = "0.4" tokio-util = { version = "0.7" } prost.workspace = true - +async-stream.workspace = true +async-trait.workspace = true +zstd = "0.13" +loole = "0.4.0" [dev-dependencies] futures = "0.3.30" diff --git a/libs/collab-stream/src/client.rs b/libs/collab-stream/src/client.rs index 57d17fd14..2ad9d79ed 100644 --- a/libs/collab-stream/src/client.rs +++ b/libs/collab-stream/src/client.rs @@ -1,29 +1,61 @@ -use crate::error::StreamError; -use crate::pubsub::{CollabStreamPub, CollabStreamSub}; -use crate::stream::CollabStream; +use crate::collab_update_sink::{AwarenessUpdateSink, CollabUpdateSink}; +use crate::error::{internal, StreamError}; +use crate::lease::{Lease, LeaseAcquisition}; +use crate::model::{AwarenessStreamUpdate, CollabStreamUpdate, MessageId}; use crate::stream_group::{StreamConfig, StreamGroup}; +use crate::stream_router::{StreamRouter, StreamRouterOptions}; +use futures::Stream; use redis::aio::ConnectionManager; +use redis::streams::StreamReadReply; +use redis::{AsyncCommands, FromRedisValue}; +use std::sync::Arc; +use std::time::Duration; use tracing::error; -pub const CONTROL_STREAM_KEY: &str = "af_collab_control"; - #[derive(Clone)] pub struct CollabRedisStream { connection_manager: ConnectionManager, + stream_router: Arc, } impl CollabRedisStream { + pub const LEASE_TTL: Duration = Duration::from_secs(60); + pub async fn new(redis_client: redis::Client) -> Result { + let router_options = StreamRouterOptions { + worker_count: 60, + xread_streams: 100, + xread_block_millis: Some(5000), + xread_count: None, + }; + let stream_router = Arc::new(StreamRouter::with_options(&redis_client, router_options)?); let connection_manager = redis_client.get_connection_manager().await?; - Ok(Self::new_with_connection_manager(connection_manager)) + Ok(Self::new_with_connection_manager( + connection_manager, + stream_router, + )) } - pub fn new_with_connection_manager(connection_manager: ConnectionManager) -> Self { - Self { connection_manager } + pub fn new_with_connection_manager( + connection_manager: ConnectionManager, + stream_router: Arc, + ) -> Self { + Self { + connection_manager, + stream_router, + } } - pub async fn stream(&self, workspace_id: &str, oid: &str) -> CollabStream { - CollabStream::new(workspace_id, oid, self.connection_manager.clone()) + pub async fn lease( + &self, + workspace_id: &str, + object_id: &str, + ) -> Result, StreamError> { + let lease_key = format!("af:{}:{}:snapshot_lease", workspace_id, object_id); + self + .connection_manager + .lease(lease_key, Self::LEASE_TTL) + .await 
} pub async fn collab_control_stream( @@ -46,7 +78,7 @@ impl CollabRedisStream { Ok(group) } - pub async fn collab_update_stream( + pub async fn collab_update_stream_group( &self, workspace_id: &str, oid: &str, @@ -66,29 +98,106 @@ impl CollabRedisStream { group.ensure_consumer_group().await?; Ok(group) } -} -pub struct PubSubClient { - redis_client: redis::Client, - connection_manager: ConnectionManager, -} + pub fn collab_update_sink(&self, workspace_id: &str, object_id: &str) -> CollabUpdateSink { + let stream_key = CollabStreamUpdate::stream_key(workspace_id, object_id); + CollabUpdateSink::new(self.connection_manager.clone(), stream_key) + } -impl PubSubClient { - pub async fn new(redis_client: redis::Client) -> Result { - let connection_manager = redis_client.get_connection_manager().await?; - Ok(Self { - redis_client, - connection_manager, - }) + pub fn awareness_update_sink(&self, workspace_id: &str, object_id: &str) -> AwarenessUpdateSink { + let stream_key = AwarenessStreamUpdate::stream_key(workspace_id, object_id); + AwarenessUpdateSink::new(self.connection_manager.clone(), stream_key) + } + + /// Reads all collab updates for a given `workspace_id`:`object_id` entry, starting + /// from a given message id. Once Redis stream return no more results, the stream will be closed. + pub async fn current_collab_updates( + &self, + workspace_id: &str, + object_id: &str, + since: Option, + ) -> Result, StreamError> { + let stream_key = CollabStreamUpdate::stream_key(workspace_id, object_id); + let since = since.unwrap_or_default().to_string(); + let mut conn = self.connection_manager.clone(); + let mut result = Vec::new(); + let mut reply: StreamReadReply = conn.xread(&[&stream_key], &[&since]).await?; + if let Some(key) = reply.keys.pop() { + if key.key == stream_key { + for stream_id in key.ids { + let message_id = MessageId::try_from(stream_id.id)?; + let stream_update = CollabStreamUpdate::try_from(stream_id.map)?; + result.push((message_id, stream_update)); + } + } + } + Ok(result) } - pub async fn collab_pub(&self) -> CollabStreamPub { - CollabStreamPub::new(self.connection_manager.clone()) + /// Reads all collab updates for a given `workspace_id`:`object_id` entry, starting + /// from a given message id. This stream will be kept alive and pass over all future messages + /// coming from corresponding Redis stream until explicitly closed. + pub fn live_collab_updates( + &self, + workspace_id: &str, + object_id: &str, + since: Option, + ) -> impl Stream> { + let stream_key = CollabStreamUpdate::stream_key(workspace_id, object_id); + let since = since.map(|id| id.to_string()); + let mut reader = self.stream_router.observe(stream_key, since); + async_stream::try_stream! { + while let Some((message_id, fields)) = reader.recv().await { + tracing::trace!("incoming collab update `{}`", message_id); + let message_id = MessageId::try_from(message_id).map_err(|e| internal(e.to_string()))?; + let collab_update = CollabStreamUpdate::try_from(fields)?; + yield (message_id, collab_update); + } + } } - #[allow(deprecated)] - pub async fn collab_sub(&self) -> Result { - let conn = self.redis_client.get_async_connection().await?; - Ok(CollabStreamSub::new(conn)) + pub fn awareness_updates( + &self, + workspace_id: &str, + object_id: &str, + since: Option, + ) -> impl Stream> { + let stream_key = AwarenessStreamUpdate::stream_key(workspace_id, object_id); + let since = since.map(|id| id.to_string()); + let mut reader = self.stream_router.observe(stream_key, since); + async_stream::try_stream! 
{
+      while let Some((message_id, fields)) = reader.recv().await {
+        tracing::trace!("incoming awareness update `{}`", message_id);
+        let awareness_update = AwarenessStreamUpdate::try_from(fields)?;
+        yield awareness_update;
+      }
+    }
+  }
+
+  pub async fn prune_stream(
+    &self,
+    stream_key: &str,
+    mut message_id: MessageId,
+  ) -> Result<usize, StreamError> {
+    let mut conn = self.connection_manager.clone();
+    // we want to delete everything <= message_id
+    message_id.sequence_number += 1;
+    let value = conn
+      .send_packed_command(
+        redis::cmd("XTRIM")
+          .arg(stream_key)
+          .arg("MINID")
+          .arg(format!("{}", message_id)),
+      )
+      .await?;
+    let count = usize::from_redis_value(&value)?;
+    drop(conn);
+    tracing::debug!(
+      "pruned redis stream `{}` <= `{}` ({} objects)",
+      stream_key,
+      message_id,
+      count
+    );
+    Ok(count)
+  }
 }
diff --git a/libs/collab-stream/src/collab_update_sink.rs b/libs/collab-stream/src/collab_update_sink.rs
new file mode 100644
index 000000000..db1a82fd2
--- /dev/null
+++ b/libs/collab-stream/src/collab_update_sink.rs
@@ -0,0 +1,66 @@
+use crate::error::StreamError;
+use crate::model::{AwarenessStreamUpdate, CollabStreamUpdate, MessageId};
+use redis::aio::ConnectionManager;
+use redis::cmd;
+use tokio::sync::Mutex;
+
+pub struct CollabUpdateSink {
+  conn: Mutex<ConnectionManager>,
+  stream_key: String,
+}
+
+impl CollabUpdateSink {
+  pub fn new(conn: ConnectionManager, stream_key: String) -> Self {
+    CollabUpdateSink {
+      conn: conn.into(),
+      stream_key,
+    }
+  }
+
+  pub async fn send(&self, msg: &CollabStreamUpdate) -> Result<MessageId, StreamError> {
+    let mut lock = self.conn.lock().await;
+    let msg_id: MessageId = cmd("XADD")
+      .arg(&self.stream_key)
+      .arg("*")
+      .arg("flags")
+      .arg(msg.flags)
+      .arg("sender")
+      .arg(msg.sender.to_string())
+      .arg("data")
+      .arg(&*msg.data)
+      .query_async(&mut *lock)
+      .await?;
+    Ok(msg_id)
+  }
+}
+
+pub struct AwarenessUpdateSink {
+  conn: Mutex<ConnectionManager>,
+  stream_key: String,
+}
+
+impl AwarenessUpdateSink {
+  pub fn new(conn: ConnectionManager, stream_key: String) -> Self {
+    AwarenessUpdateSink {
+      conn: conn.into(),
+      stream_key,
+    }
+  }
+
+  pub async fn send(&self, msg: &AwarenessStreamUpdate) -> Result<MessageId, StreamError> {
+    let mut lock = self.conn.lock().await;
+    let msg_id: MessageId = cmd("XADD")
+      .arg(&self.stream_key)
+      .arg("MAXLEN")
+      .arg("~")
+      .arg(100) // cap the awareness stream at roughly the 100 most recent updates
+      .arg("*")
+      .arg("sender")
+      .arg(msg.sender.to_string())
+      .arg("data")
+      .arg(&*msg.data)
+      .query_async(&mut *lock)
+      .await?;
+    Ok(msg_id)
+  }
+}
diff --git a/libs/collab-stream/src/error.rs b/libs/collab-stream/src/error.rs
index 06da19325..b0241a938 100644
--- a/libs/collab-stream/src/error.rs
+++ b/libs/collab-stream/src/error.rs
@@ -32,6 +32,12 @@ pub enum StreamError {
   #[error(transparent)]
   BinCodeSerde(#[from] bincode::Error),
 
+  #[error("failed to decode update: {0}")]
+  UpdateError(#[from] collab::preclude::encoding::read::Error),
+
+  #[error("I/O error: {0}")]
+  IO(#[from] std::io::Error),
+
   #[error("Internal error: {0}")]
   Internal(anyhow::Error),
 }
diff --git a/libs/collab-stream/src/lease.rs b/libs/collab-stream/src/lease.rs
new file mode 100644
index 000000000..9a80e7693
--- /dev/null
+++ b/libs/collab-stream/src/lease.rs
@@ -0,0 +1,145 @@
+use crate::error::StreamError;
+use async_trait::async_trait;
+use redis::aio::ConnectionManager;
+use redis::Value;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+const RELEASE_SCRIPT: &str = r#"
+if redis.call("GET", KEYS[1]) == ARGV[1] then
+  return redis.call("DEL", KEYS[1])
+else
+  return 0
+end
+"#;
+
+pub struct LeaseAcquisition {
+ conn: Option, + stream_key: String, + token: u128, +} + +impl LeaseAcquisition { + pub async fn release(&mut self) -> Result { + if let Some(conn) = self.conn.take() { + Self::release_internal(conn, &self.stream_key, self.token).await + } else { + Ok(false) + } + } + + async fn release_internal>( + mut conn: ConnectionManager, + stream_key: S, + token: u128, + ) -> Result { + let script = redis::Script::new(RELEASE_SCRIPT); + let result: i32 = script + .key(stream_key.as_ref()) + .arg(token.to_le_bytes().as_slice()) + .invoke_async(&mut conn) + .await?; + Ok(result == 1) + } +} + +impl Drop for LeaseAcquisition { + fn drop(&mut self) { + if let Some(conn) = self.conn.take() { + let stream_key = self.stream_key.clone(); + let token = self.token; + tokio::spawn(async move { + if let Err(err) = Self::release_internal(conn, stream_key, token).await { + tracing::error!("error while releasing lease (drop): {}", err); + } + }); + } + } +} + +/// This is Redlock algorithm implementation. +/// See: https://redis.io/docs/latest/commands/set#patterns +#[async_trait] +pub trait Lease { + /// Attempt to acquire lease on a stream for a given time-to-live. + /// Returns `None` if the lease could not be acquired. + async fn lease( + &self, + stream_key: String, + ttl: Duration, + ) -> Result, StreamError>; +} + +#[async_trait] +impl Lease for ConnectionManager { + async fn lease( + &self, + stream_key: String, + ttl: Duration, + ) -> Result, StreamError> { + let mut conn = self.clone(); + let ttl = ttl.as_millis() as u64; + let token = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); + tracing::trace!("acquiring lease `{}` for {}ms", stream_key, ttl); + let result: Value = redis::cmd("SET") + .arg(&stream_key) + .arg(token.to_le_bytes().as_slice()) + .arg("NX") + .arg("PX") + .arg(ttl) + .query_async(&mut conn) + .await?; + + match result { + Value::Okay => Ok(Some(LeaseAcquisition { + conn: Some(conn), + stream_key, + token, + })), + o => { + tracing::trace!("lease locked: {:?}", o); + Ok(None) + }, + } + } +} + +#[cfg(test)] +mod test { + use crate::lease::Lease; + use redis::Client; + + #[tokio::test] + async fn lease_acquisition() { + let redis_client = Client::open("redis://localhost:6379").unwrap(); + let conn = redis_client.get_connection_manager().await.unwrap(); + + let l1 = conn + .lease("stream1".into(), std::time::Duration::from_secs(1)) + .await + .unwrap(); + + assert!(l1.is_some(), "should successfully acquire lease"); + + let l2 = conn + .lease("stream1".into(), std::time::Duration::from_secs(1)) + .await + .unwrap(); + + assert!(l2.is_none(), "should fail to acquire lease"); + + l1.unwrap().release().await.unwrap(); + + let l3 = conn + .lease("stream1".into(), std::time::Duration::from_secs(1)) + .await + .unwrap(); + + assert!( + l3.is_some(), + "should successfully acquire lease after it was released" + ); + } +} diff --git a/libs/collab-stream/src/lib.rs b/libs/collab-stream/src/lib.rs index bd7e2d9e6..ff2c0cad6 100644 --- a/libs/collab-stream/src/lib.rs +++ b/libs/collab-stream/src/lib.rs @@ -1,6 +1,8 @@ pub mod client; +pub mod collab_update_sink; pub mod error; +pub mod lease; pub mod model; pub mod pubsub; -pub mod stream; pub mod stream_group; +pub mod stream_router; diff --git a/libs/collab-stream/src/model.rs b/libs/collab-stream/src/model.rs index 0ee436a41..d74eb77a1 100644 --- a/libs/collab-stream/src/model.rs +++ b/libs/collab-stream/src/model.rs @@ -1,12 +1,14 @@ use crate::error::{internal, StreamError}; use bytes::Bytes; +use 
collab::core::origin::{CollabClient, CollabOrigin}; +use collab::preclude::updates::decoder::Decode; use collab_entity::proto::collab::collab_update_event::Update; use collab_entity::{proto, CollabType}; use prost::Message; use redis::streams::StreamId; -use redis::{FromRedisValue, RedisError, RedisResult, Value}; +use redis::{FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value}; use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::str::FromStr; @@ -20,12 +22,21 @@ use std::str::FromStr; /// /// An example message ID might look like this: 1631020452097-0. In this example, 1631020452097 is /// the timestamp in milliseconds, and 0 is the sequence number. -#[derive(Debug, Clone)] +#[derive(Debug, Copy, Clone, Default, Ord, PartialOrd, Eq, PartialEq)] pub struct MessageId { pub timestamp_ms: u64, pub sequence_number: u16, } +impl MessageId { + pub fn new(timestamp_ms: u64, sequence_number: u16) -> Self { + MessageId { + timestamp_ms, + sequence_number, + } + } +} + impl Display for MessageId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}-{}", self.timestamp_ms, self.sequence_number) @@ -355,8 +366,218 @@ impl TryFrom for StreamBinary { } } +pub struct CollabStreamUpdate { + pub data: Vec, // yrs::Update::encode_v1 + pub sender: CollabOrigin, + pub flags: UpdateFlags, +} + +impl CollabStreamUpdate { + pub fn new(data: B, sender: CollabOrigin, flags: F) -> Self + where + B: Into>, + F: Into, + { + CollabStreamUpdate { + data: data.into(), + sender, + flags: flags.into(), + } + } + + /// Returns Redis stream key, that's storing entries mapped to/from [CollabStreamUpdate]. + pub fn stream_key(workspace_id: &str, object_id: &str) -> String { + // use `:` separator as it adheres to Redis naming conventions + format!("af:{}:{}:updates", workspace_id, object_id) + } + + pub fn into_update(self) -> Result { + let bytes = if self.flags.is_compressed() { + zstd::decode_all(std::io::Cursor::new(self.data))? + } else { + self.data + }; + let update = if self.flags.is_v1_encoded() { + collab::preclude::Update::decode_v1(&bytes)? + } else { + collab::preclude::Update::decode_v2(&bytes)? + }; + Ok(update) + } +} + +impl TryFrom> for CollabStreamUpdate { + type Error = StreamError; + + fn try_from(fields: HashMap) -> Result { + let sender = match fields.get("sender") { + None => CollabOrigin::Empty, + Some(sender) => { + let raw_origin = String::from_redis_value(sender)?; + collab_origin_from_str(&raw_origin)? + }, + }; + let flags = match fields.get("flags") { + None => UpdateFlags::default(), + Some(flags) => u8::from_redis_value(flags).unwrap_or(0).into(), + }; + let data_raw = fields + .get("data") + .ok_or_else(|| internal("expecting field `data`"))?; + let data: Vec = FromRedisValue::from_redis_value(data_raw)?; + Ok(CollabStreamUpdate { + data, + sender, + flags, + }) + } +} + +pub struct AwarenessStreamUpdate { + pub data: Vec, // AwarenessUpdate::encode_v1 + pub sender: CollabOrigin, +} + +impl AwarenessStreamUpdate { + /// Returns Redis stream key, that's storing entries mapped to/from [AwarenessStreamUpdate]. 
+ pub fn stream_key(workspace_id: &str, object_id: &str) -> String { + format!("af:{}:{}:awareness", workspace_id, object_id) + } +} + +impl TryFrom> for AwarenessStreamUpdate { + type Error = StreamError; + + fn try_from(fields: HashMap) -> Result { + let sender = match fields.get("sender") { + None => CollabOrigin::Empty, + Some(sender) => { + let raw_origin = String::from_redis_value(sender)?; + collab_origin_from_str(&raw_origin)? + }, + }; + let data_raw = fields + .get("data") + .ok_or_else(|| internal("expecting field `data`"))?; + let data: Vec = FromRedisValue::from_redis_value(data_raw)?; + Ok(AwarenessStreamUpdate { data, sender }) + } +} + +//FIXME: this should be `impl FromStr for CollabOrigin` +fn collab_origin_from_str(value: &str) -> RedisResult { + match value { + "" => Ok(CollabOrigin::Empty), + "server" => Ok(CollabOrigin::Server), + other => { + let mut split = other.split('|'); + match (split.next(), split.next()) { + (Some(uid), Some(device_id)) | (Some(device_id), Some(uid)) + if uid.starts_with("uid:") && device_id.starts_with("device_id:") => + { + let uid = uid.trim_start_matches("uid:"); + let device_id = device_id.trim_start_matches("device_id:").to_string(); + let uid: i64 = uid + .parse() + .map_err(|err| internal(format!("failed to parse uid: {}", err)))?; + Ok(CollabOrigin::Client(CollabClient { uid, device_id })) + }, + _ => Err(internal(format!( + "couldn't parse collab origin from `{}`", + other + ))), + } + }, + } +} + +#[repr(transparent)] +#[derive(Copy, Clone, Eq, PartialEq, Default)] +pub struct UpdateFlags(u8); + +impl UpdateFlags { + /// Flag bit to mark if update is encoded using [EncoderV2] (if set) or [EncoderV1] (if clear). + pub const IS_V2_ENCODED: u8 = 0b0000_0001; + /// Flag bit to mark if update is compressed. 
+ pub const IS_COMPRESSED: u8 = 0b0000_0010; + + #[inline] + pub fn is_v2_encoded(&self) -> bool { + self.0 & Self::IS_V2_ENCODED != 0 + } + + #[inline] + pub fn is_v1_encoded(&self) -> bool { + !self.is_v2_encoded() + } + + #[inline] + pub fn is_compressed(&self) -> bool { + self.0 & Self::IS_COMPRESSED != 0 + } +} + +impl ToRedisArgs for UpdateFlags { + #[inline] + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + self.0.write_redis_args(out) + } +} + +impl From for UpdateFlags { + #[inline] + fn from(value: u8) -> Self { + UpdateFlags(value) + } +} + +impl Display for UpdateFlags { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if !self.is_v2_encoded() { + write!(f, ".v1")?; + } else { + write!(f, ".v2")?; + } + + if self.is_compressed() { + write!(f, ".zstd")?; + } + + Ok(()) + } +} + #[cfg(test)] mod test { + use crate::model::collab_origin_from_str; + use collab::core::origin::{CollabClient, CollabOrigin}; + + #[test] + fn parse_collab_origin_empty() { + let expected = CollabOrigin::Empty; + let actual = collab_origin_from_str(&expected.to_string()).unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn parse_collab_origin_server() { + let expected = CollabOrigin::Server; + let actual = collab_origin_from_str(&expected.to_string()).unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn parse_collab_origin_client() { + let expected = CollabOrigin::Client(CollabClient { + uid: 123, + device_id: "test-device".to_string(), + }); + let actual = collab_origin_from_str(&expected.to_string()).unwrap(); + assert_eq!(actual, expected); + } #[test] fn test_collab_update_event_decoding() { diff --git a/libs/collab-stream/src/stream.rs b/libs/collab-stream/src/stream.rs deleted file mode 100644 index 148bceb24..000000000 --- a/libs/collab-stream/src/stream.rs +++ /dev/null @@ -1,99 +0,0 @@ -use crate::error::StreamError; -use crate::model::{MessageId, StreamBinary, StreamMessage, StreamMessageByStreamKey}; -use redis::aio::ConnectionManager; -use redis::streams::{StreamMaxlen, StreamReadOptions}; -use redis::{pipe, AsyncCommands, Pipeline, RedisError}; - -pub struct CollabStream { - connection_manager: ConnectionManager, - stream_key: String, -} - -impl CollabStream { - pub fn new(workspace_id: &str, oid: &str, connection_manager: ConnectionManager) -> Self { - let stream_key = format!("af_collab-{}-{}", workspace_id, oid); - Self { - connection_manager, - stream_key, - } - } - - /// Inserts a single message into the Redis stream. - pub async fn insert_message(&mut self, message: StreamBinary) -> Result { - let tuple = message.into_tuple_array(); - let message_id = self - .connection_manager - .xadd(&self.stream_key, "*", tuple.as_slice()) - .await?; - Ok(message_id) - } - - /// Inserts multiple messages into the Redis stream using a pipeline. - /// - pub async fn insert_messages(&mut self, messages: Vec) -> Result<(), StreamError> { - let mut pipe = pipe(); - for message in messages { - let tuple = message.into_tuple_array(); - let _: &mut Pipeline = pipe.xadd(&self.stream_key, "*", tuple.as_slice()); - } - let () = pipe.query_async(&mut self.connection_manager).await?; - Ok(()) - } - - /// Fetches the next message from a Redis stream after a specified entry. 
- /// - pub async fn next(&mut self) -> Result, StreamError> { - let options = StreamReadOptions::default().count(1).block(100); - let map: StreamMessageByStreamKey = self - .connection_manager - .xread_options(&[&self.stream_key], &["$"], &options) - .await?; - - let (_, mut messages) = map - .0 - .into_iter() - .next() - .ok_or_else(|| StreamError::UnexpectedValue("Empty stream".into()))?; - - debug_assert_eq!(messages.len(), 1); - Ok(messages.pop()) - } - - pub async fn next_after( - &mut self, - after: Option, - ) -> Result, StreamError> { - let message_id = after - .map(|ct| ct.to_string()) - .unwrap_or_else(|| "$".to_string()); - - let options = StreamReadOptions::default().group("1", "2").block(100); - let map: StreamMessageByStreamKey = self - .connection_manager - .xread_options(&[&self.stream_key], &[&message_id], &options) - .await?; - - let (_, mut messages) = map - .0 - .into_iter() - .next() - .ok_or_else(|| StreamError::UnexpectedValue("Empty stream".into()))?; - - debug_assert_eq!(messages.len(), 1); - Ok(messages.pop()) - } - - pub async fn read_all_message(&mut self) -> Result, StreamError> { - let read_messages: Vec = - self.connection_manager.xrange_all(&self.stream_key).await?; - Ok(read_messages.into_iter().map(Into::into).collect()) - } - - pub async fn clear(&mut self) -> Result<(), RedisError> { - let () = self - .connection_manager - .xtrim(&self.stream_key, StreamMaxlen::Equals(0)) - .await?; - Ok(()) - } -} diff --git a/libs/collab-stream/src/stream_group.rs b/libs/collab-stream/src/stream_group.rs index 8361f9eb7..8b45de50c 100644 --- a/libs/collab-stream/src/stream_group.rs +++ b/libs/collab-stream/src/stream_group.rs @@ -418,12 +418,12 @@ impl StreamGroup { _ = interval.tick() => { if let Ok(len) = get_stream_length(&mut connection_manager, &stream_key).await { if len + 100 > max_len { - warn!("stream len is going to exceed the max len: {}, current: {}", max_len, len); + warn!("stream `{}` len is going to exceed the max len: {}, current: {}", stream_key, max_len, len); } } } _ = cancel_token.cancelled() => { - trace!("Stream length check task cancelled."); + trace!("Stream `{}` length check task cancelled.", stream_key); break; } } diff --git a/libs/collab-stream/src/stream_router.rs b/libs/collab-stream/src/stream_router.rs new file mode 100644 index 000000000..fb048c1a9 --- /dev/null +++ b/libs/collab-stream/src/stream_router.rs @@ -0,0 +1,373 @@ +use loole::{Receiver, Sender}; +use redis::streams::{StreamReadOptions, StreamReadReply}; +use redis::Client; +use redis::Commands; +use redis::Connection; +use redis::RedisError; +use redis::RedisResult; +use redis::Value; +use std::collections::HashMap; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering::SeqCst; +use std::sync::Arc; +use std::thread::{sleep, JoinHandle}; +use std::time::Duration; + +/// Redis stream key. +pub type StreamKey = String; + +/// Channel returned by [StreamRouter::observe], that allows to receive messages retrieved by +/// the router. +pub type StreamReader = tokio::sync::mpsc::UnboundedReceiver<(String, RedisMap)>; + +/// Redis stream router used to multiplex multiple number of Redis stream read requests over a +/// fixed number of Redis connections. 
+pub struct StreamRouter {
+  buf: Sender<StreamHandle>,
+  alive: Arc<AtomicBool>,
+  #[allow(dead_code)]
+  workers: Vec<Worker>,
+}
+
+impl StreamRouter {
+  pub fn new(client: &Client) -> Result<Self, RedisError> {
+    Self::with_options(client, Default::default())
+  }
+
+  pub fn with_options(client: &Client, options: StreamRouterOptions) -> Result<Self, RedisError> {
+    let alive = Arc::new(AtomicBool::new(true));
+    let (tx, rx) = loole::unbounded();
+    let mut workers = Vec::with_capacity(options.worker_count);
+    for worker_id in 0..options.worker_count {
+      let conn = client.get_connection()?;
+      let worker = Worker::new(
+        worker_id,
+        conn,
+        tx.clone(),
+        rx.clone(),
+        alive.clone(),
+        &options,
+      );
+      workers.push(worker);
+    }
+    tracing::info!("started Redis stream router with {} workers", workers.len());
+    Ok(Self {
+      buf: tx,
+      workers,
+      alive,
+    })
+  }
+
+  pub fn observe(&self, stream_key: StreamKey, last_id: Option<String>) -> StreamReader {
+    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+    let last_id = last_id.unwrap_or_else(|| "0".to_string());
+    let h = StreamHandle::new(stream_key, last_id, tx);
+    self.buf.send(h).unwrap();
+    rx
+  }
+}
+
+impl Drop for StreamRouter {
+  fn drop(&mut self) {
+    self.alive.store(false, SeqCst);
+  }
+}
+
+/// Options used to configure [StreamRouter].
+#[derive(Debug, Clone)]
+pub struct StreamRouterOptions {
+  /// Number of worker threads. Each worker thread has its own Redis connection.
+  /// Default: the number of CPU threads, though this can vary under specific circumstances.
+  pub worker_count: usize,
+  /// How many Redis streams a single Redis poll worker can read at a time.
+  /// Default: 100
+  pub xread_streams: usize,
+  /// How long a poll worker will be blocked while waiting for a Redis `XREAD` request to respond.
+  /// This blocks a worker thread and doesn't affect other threads.
+  ///
+  /// If set to `None` it won't block and will return immediately, which gives the best
+  /// responsiveness but can lead to unnecessary busy loops causing CPU spikes even when idle.
+  ///
+  /// Default: `Some(0)`, meaning it blocks for an indefinite amount of time.
+  pub xread_block_millis: Option<usize>,
+  /// How many messages a single worker's `XREAD` request is allowed to return.
+  /// Default: `None` (unbounded).
+ pub xread_count: Option, +} + +impl Default for StreamRouterOptions { + fn default() -> Self { + StreamRouterOptions { + worker_count: std::thread::available_parallelism().unwrap().get(), + xread_streams: 100, + xread_block_millis: Some(0), + xread_count: None, + } + } +} + +struct Worker { + _handle: JoinHandle<()>, +} + +impl Worker { + fn new( + worker_id: usize, + conn: Connection, + tx: Sender, + rx: Receiver, + alive: Arc, + options: &StreamRouterOptions, + ) -> Self { + let mut xread_options = StreamReadOptions::default(); + if let Some(block_millis) = options.xread_block_millis { + xread_options = xread_options.block(block_millis); + } + if let Some(count) = options.xread_count { + xread_options = xread_options.count(count); + } + let count = options.xread_streams; + let handle = std::thread::spawn(move || { + if let Err(err) = Self::process_streams(conn, tx, rx, alive, xread_options, count) { + tracing::error!("worker {} failed: {}", worker_id, err); + } + }); + Self { _handle: handle } + } + + fn process_streams( + mut conn: Connection, + tx: Sender, + rx: Receiver, + alive: Arc, + options: StreamReadOptions, + count: usize, + ) -> RedisResult<()> { + let mut stream_keys = Vec::with_capacity(count); + let mut message_ids = Vec::with_capacity(count); + let mut senders = HashMap::with_capacity(count); + while alive.load(SeqCst) { + if !Self::read_buf(&rx, &mut stream_keys, &mut message_ids, &mut senders) { + break; // rx channel has closed + } + + let key_count = stream_keys.len(); + if key_count == 0 { + tracing::warn!("Bug: read empty buf"); + sleep(Duration::from_millis(100)); + continue; + } + + let result: StreamReadReply = conn.xread_options(&stream_keys, &message_ids, &options)?; + + let mut msgs = 0; + for stream in result.keys { + let mut remove_sender = false; + if let Some((sender, idx)) = senders.get(stream.key.as_str()) { + for id in stream.ids { + let message_id = id.id; + let value = id.map; + message_ids[*idx].clone_from(&message_id); //TODO: optimize + msgs += 1; + if let Err(err) = sender.send((message_id, value)) { + tracing::warn!("failed to send: {}", err); + remove_sender = true; + } + } + } + + if remove_sender { + senders.remove(stream.key.as_str()); + } + } + + if msgs > 0 { + tracing::trace!( + "XREAD: read total of {} messages for {} streams", + msgs, + key_count + ); + } + Self::schedule_back(&tx, &mut stream_keys, &mut message_ids, &mut senders); + } + Ok(()) + } + + fn schedule_back( + tx: &Sender, + keys: &mut Vec, + ids: &mut Vec, + senders: &mut HashMap<&str, (StreamSender, usize)>, + ) { + let keys = keys.drain(..); + let mut ids = ids.drain(..); + for key in keys { + if let Some(last_id) = ids.next() { + if let Some((sender, _)) = senders.remove(key.as_str()) { + let h = StreamHandle::new(key, last_id, sender); + if let Err(err) = tx.send(h) { + tracing::warn!("failed to reschedule: {}", err); + break; + } + } + } + } + senders.clear(); + } + + fn read_buf( + rx: &Receiver, + stream_keys: &mut Vec, + message_ids: &mut Vec, + senders: &mut HashMap<&'static str, (StreamSender, usize)>, + ) -> bool { + // try to receive first element - block thread if there's none + let mut count = stream_keys.capacity(); + if let Ok(h) = rx.recv() { + // senders and stream_keys have bound lifetimes and fixed internal buffers + // since API users are using StreamKeys => String, we want to avoid allocations + let key_ref: &'static str = unsafe { std::mem::transmute(h.key.as_str()) }; + senders.insert(key_ref, (h.sender, stream_keys.len())); + 
stream_keys.push(h.key); + message_ids.push(h.last_id.to_string()); + + count -= 1; + if count == 0 { + return true; + } + + // try to fill more without blocking if there's anything on the receiver + while let Ok(h) = rx.try_recv() { + let key_ref: &'static str = unsafe { std::mem::transmute(h.key.as_str()) }; + senders.insert(key_ref, (h.sender, stream_keys.len())); + stream_keys.push(h.key); + message_ids.push(h.last_id.to_string()); + + count -= 1; + if count == 0 { + return true; + } + } + true + } else { + false + } + } +} + +type RedisMap = HashMap; +type StreamSender = tokio::sync::mpsc::UnboundedSender<(String, RedisMap)>; + +struct StreamHandle { + key: StreamKey, + last_id: String, + sender: StreamSender, +} + +impl StreamHandle { + fn new(key: StreamKey, last_id: String, sender: StreamSender) -> Self { + StreamHandle { + key, + last_id, + sender, + } + } +} + +#[cfg(test)] +mod test { + use crate::stream_router::StreamRouter; + use rand::random; + use redis::{Client, Commands, FromRedisValue}; + use tokio::task::JoinSet; + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn multi_worker_preexisting_messages() { + const ROUTES_COUNT: usize = 200; + const MSG_PER_ROUTE: usize = 10; + let mut client = Client::open("redis://127.0.0.1/").unwrap(); + let keys = init_streams(&mut client, ROUTES_COUNT, MSG_PER_ROUTE); + + let router = StreamRouter::new(&client).unwrap(); + + let mut join_set = JoinSet::new(); + for key in keys { + let mut observer = router.observe(key.clone(), None); + join_set.spawn(async move { + for i in 0..MSG_PER_ROUTE { + let (_msg_id, map) = observer.recv().await.unwrap(); + let value = String::from_redis_value(&map["data"]).unwrap(); + assert_eq!(value, format!("{}-{}", key, i)); + } + }); + } + + while let Some(t) = join_set.join_next().await { + t.unwrap(); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn multi_worker_live_messages() { + const ROUTES_COUNT: usize = 200; + const MSG_PER_ROUTE: usize = 10; + let mut client = Client::open("redis://127.0.0.1/").unwrap(); + let keys = init_streams(&mut client, ROUTES_COUNT, 0); + + let router = StreamRouter::new(&client).unwrap(); + + let mut join_set = JoinSet::new(); + for key in keys.iter() { + let mut observer = router.observe(key.clone(), None); + let key = key.clone(); + join_set.spawn(async move { + for i in 0..MSG_PER_ROUTE { + let (_msg_id, map) = observer.recv().await.unwrap(); + let value = String::from_redis_value(&map["data"]).unwrap(); + assert_eq!(value, format!("{}-{}", key, i)); + } + }); + } + + for msg_idx in 0..MSG_PER_ROUTE { + for key in keys.iter() { + let data = format!("{}-{}", key, msg_idx); + let _: String = client.xadd(key, "*", &[("data", data)]).unwrap(); + } + } + + while let Some(t) = join_set.join_next().await { + t.unwrap(); + } + } + + #[tokio::test] + async fn stream_reader_continue_from() { + let mut client = Client::open("redis://127.0.0.1/").unwrap(); + let key = format!("test:{}:{}", random::(), 0); + let _: String = client.xadd(&key, "*", &[("data", 1)]).unwrap(); + let m2: String = client.xadd(&key, "*", &[("data", 2)]).unwrap(); + let m3: String = client.xadd(&key, "*", &[("data", 3)]).unwrap(); + + let router = StreamRouter::new(&client).unwrap(); + let mut observer = router.observe(key, Some(m2)); + + let (msg_id, m) = observer.recv().await.unwrap(); + assert_eq!(msg_id, m3); + assert_eq!(u32::from_redis_value(&m["data"]).unwrap(), 3); + } + + fn init_streams(client: &mut Client, stream_count: usize, msgs_per_stream: 
usize) -> Vec { + let test_prefix: u32 = random(); + let mut keys = Vec::with_capacity(stream_count); + for worker_idx in 0..stream_count { + let key = format!("test:{}:{}", test_prefix, worker_idx); + for msg_idx in 0..msgs_per_stream { + let data = format!("{}-{}", key, msg_idx); + let _: String = client.xadd(&key, "*", &[("data", data)]).unwrap(); + } + keys.push(key); + } + keys + } +} diff --git a/libs/collab-stream/tests/collab_stream_test/mod.rs b/libs/collab-stream/tests/collab_stream_test/mod.rs index a7c84e32e..bba24ebfa 100644 --- a/libs/collab-stream/tests/collab_stream_test/mod.rs +++ b/libs/collab-stream/tests/collab_stream_test/mod.rs @@ -1,4 +1,3 @@ -mod pubsub_test; mod stream_group_test; mod stream_test; mod test_util; diff --git a/libs/collab-stream/tests/collab_stream_test/pubsub_test.rs b/libs/collab-stream/tests/collab_stream_test/pubsub_test.rs deleted file mode 100644 index e0e8b6f54..000000000 --- a/libs/collab-stream/tests/collab_stream_test/pubsub_test.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::collab_stream_test::test_util::{pubsub_client, random_i64}; - -use collab_stream::pubsub::PubSubMessage; - -use futures::StreamExt; -use std::time::Duration; -use tokio::time::sleep; - -#[tokio::test] -async fn pubsub_test() { - let oid = format!("o{}", random_i64()); - let client_1 = pubsub_client().await; - let client_2 = pubsub_client().await; - - let mut publish = client_1.collab_pub().await; - let send_msg = PubSubMessage { - workspace_id: "1".to_string(), - oid: oid.clone(), - }; - - let cloned_msg = send_msg.clone(); - tokio::spawn(async move { - sleep(Duration::from_secs(1)).await; - match publish.publish(cloned_msg).await { - Ok(_) => {}, - Err(err) => { - panic!("failed to publish message: {:?}", err); - }, - } - }); - - let subscriber = client_2.collab_sub().await.unwrap(); - let mut pubsub = subscriber.subscribe().await.unwrap(); - let receive_msg = pubsub.next().await.unwrap().unwrap(); - - assert_eq!(send_msg.workspace_id, receive_msg.workspace_id); -} diff --git a/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs index 371888687..b6404cbeb 100644 --- a/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/stream_group_test.rs @@ -10,7 +10,7 @@ async fn single_group_read_message_test() { let oid = format!("o{}", random_i64()); let client = stream_client().await; let mut group = client - .collab_update_stream(workspace_id, &oid, "g1") + .collab_update_stream_group(workspace_id, &oid, "g1") .await .unwrap(); let msg = StreamBinary(vec![1, 2, 3, 4, 5]); @@ -18,7 +18,7 @@ async fn single_group_read_message_test() { { let client = stream_client().await; let mut group = client - .collab_update_stream(workspace_id, &oid, "g2") + .collab_update_stream_group(workspace_id, &oid, "g2") .await .unwrap(); group.insert_binary(msg).await.unwrap(); @@ -45,7 +45,7 @@ async fn single_group_async_read_message_test() { let oid = format!("o{}", random_i64()); let client = stream_client().await; let mut group = client - .collab_update_stream(workspace_id, &oid, "g1") + .collab_update_stream_group(workspace_id, &oid, "g1") .await .unwrap(); @@ -54,7 +54,7 @@ async fn single_group_async_read_message_test() { { let client = stream_client().await; let mut group = client - .collab_update_stream(workspace_id, &oid, "g2") + .collab_update_stream_group(workspace_id, &oid, "g2") .await .unwrap(); group.insert_binary(msg).await.unwrap(); @@ -79,14 
+79,23 @@ async fn single_group_async_read_message_test() { async fn different_group_read_undelivered_message_test() { let oid = format!("o{}", random_i64()); let client = stream_client().await; - let mut group_1 = client.collab_update_stream("w1", &oid, "g1").await.unwrap(); - let mut group_2 = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + let mut group_1 = client + .collab_update_stream_group("w1", &oid, "g1") + .await + .unwrap(); + let mut group_2 = client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); let msg = StreamBinary(vec![1, 2, 3, 4, 5]); { let client = stream_client().await; - let mut group = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + let mut group = client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); group.insert_binary(msg).await.unwrap(); } @@ -105,14 +114,23 @@ async fn different_group_read_undelivered_message_test() { async fn different_group_read_message_test() { let oid = format!("o{}", random_i64()); let client = stream_client().await; - let mut group_1 = client.collab_update_stream("w1", &oid, "g1").await.unwrap(); - let mut group_2 = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + let mut group_1 = client + .collab_update_stream_group("w1", &oid, "g1") + .await + .unwrap(); + let mut group_2 = client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); let msg = StreamBinary(vec![1, 2, 3, 4, 5]); { let client = stream_client().await; - let mut group = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + let mut group = client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); group.insert_binary(msg).await.unwrap(); } let msg = group_1 @@ -140,13 +158,13 @@ async fn read_specific_num_of_message_test() { let object_id = format!("o{}", random_i64()); let client = stream_client().await; let mut group_1 = client - .collab_update_stream("w1", &object_id, "g1") + .collab_update_stream_group("w1", &object_id, "g1") .await .unwrap(); { let client = stream_client().await; let mut group = client - .collab_update_stream("w1", &object_id, "g2") + .collab_update_stream_group("w1", &object_id, "g2") .await .unwrap(); let mut messages = vec![]; @@ -177,13 +195,13 @@ async fn read_all_message_test() { let object_id = format!("o{}", random_i64()); let client = stream_client().await; let mut group = client - .collab_update_stream("w1", &object_id, "g1") + .collab_update_stream_group("w1", &object_id, "g1") .await .unwrap(); { let client = stream_client().await; let mut group_2 = client - .collab_update_stream("w1", &object_id, "g2") + .collab_update_stream_group("w1", &object_id, "g2") .await .unwrap(); let mut messages = vec![]; @@ -211,10 +229,16 @@ async fn group_already_exist_test() { let client = stream_client().await; // create group - client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); // create same group - client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); } #[tokio::test] @@ -223,7 +247,10 @@ async fn group_not_exist_test() { let client = stream_client().await; // create group - let mut group = client.collab_update_stream("w1", &oid, "g2").await.unwrap(); + let mut group = client + .collab_update_stream_group("w1", &oid, "g2") + .await + .unwrap(); group.destroy_group().await; let err = group diff --git 
a/libs/collab-stream/tests/collab_stream_test/stream_test.rs b/libs/collab-stream/tests/collab_stream_test/stream_test.rs index a8108f99c..8b1378917 100644 --- a/libs/collab-stream/tests/collab_stream_test/stream_test.rs +++ b/libs/collab-stream/tests/collab_stream_test/stream_test.rs @@ -1,50 +1 @@ -use crate::collab_stream_test::test_util::{random_i64, stream_client}; -use collab_stream::model::StreamBinary; -#[tokio::test] -async fn read_single_message_test() { - let oid = format!("o{}", random_i64()); - let client_2 = stream_client().await; - let mut stream_2 = client_2.stream("w1", &oid).await; - - let (tx, mut rx) = tokio::sync::mpsc::channel(1); - tokio::spawn(async move { - let msg = stream_2.next().await.unwrap(); - tx.send(msg).await.unwrap(); - }); - - let msg = StreamBinary(vec![1, 2, 3]); - { - let client_1 = stream_client().await; - let mut stream_1 = client_1.stream("w1", &oid).await; - stream_1.insert_message(msg).await.unwrap(); - } - - let msg = rx.recv().await.unwrap().unwrap(); - assert_eq!(msg.data, vec![1, 2, 3]); -} - -#[tokio::test] -async fn read_multiple_messages_test() { - let oid = format!("o{}", random_i64()); - let client_2 = stream_client().await; - let mut stream_2 = client_2.stream("w1", &oid).await; - stream_2.clear().await.unwrap(); - - { - let client_1 = stream_client().await; - let mut stream_1 = client_1.stream("w1", &oid).await; - let messages = vec![ - StreamBinary(vec![1, 2, 3]), - StreamBinary(vec![4, 5, 6]), - StreamBinary(vec![7, 8, 9]), - ]; - stream_1.insert_messages(messages).await.unwrap(); - } - - let msg = stream_2.read_all_message().await.unwrap(); - assert_eq!(msg.len(), 3); - assert_eq!(*msg[0], vec![1, 2, 3]); - assert_eq!(*msg[1], vec![4, 5, 6]); - assert_eq!(*msg[2], vec![7, 8, 9]); -} diff --git a/libs/collab-stream/tests/collab_stream_test/test_util.rs b/libs/collab-stream/tests/collab_stream_test/test_util.rs index f00cfc57c..f3b03fa51 100644 --- a/libs/collab-stream/tests/collab_stream_test/test_util.rs +++ b/libs/collab-stream/tests/collab_stream_test/test_util.rs @@ -1,5 +1,5 @@ use anyhow::Context; -use collab_stream::client::{CollabRedisStream, PubSubClient}; +use collab_stream::client::CollabRedisStream; use rand::{thread_rng, Rng}; pub async fn redis_client() -> redis::Client { @@ -17,14 +17,6 @@ pub async fn stream_client() -> CollabRedisStream { .unwrap() } -pub async fn pubsub_client() -> PubSubClient { - let redis_client = redis_client().await; - PubSubClient::new(redis_client) - .await - .context("failed to create pubsub client") - .unwrap() -} - pub fn random_i64() -> i64 { let mut rng = thread_rng(); let num: i64 = rng.gen(); diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 6fb179c98..bb5e2db1b 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -107,16 +107,16 @@ impl Display for CollabParams { } impl CollabParams { - pub fn new( + pub fn new>( object_id: T, collab_type: CollabType, - encoded_collab_v1: Vec, + encoded_collab_v1: B, ) -> Self { let object_id = object_id.to_string(); Self { object_id, collab_type, - encoded_collab_v1: Bytes::from(encoded_collab_v1), + encoded_collab_v1: encoded_collab_v1.into(), } } @@ -217,7 +217,7 @@ pub struct InsertSnapshotParams { #[validate(custom(function = "validate_not_empty_str"))] pub object_id: String, #[validate(custom(function = "validate_not_empty_payload"))] - pub data: Bytes, + pub doc_state: Bytes, #[validate(custom(function = "validate_not_empty_str"))] pub workspace_id: String, pub collab_type: 
CollabType, @@ -254,13 +254,13 @@ impl Display for QueryCollabParams { } impl QueryCollabParams { - pub fn new( + pub fn new, T2: Into>( object_id: T1, collab_type: CollabType, workspace_id: T2, ) -> Self { - let workspace_id = workspace_id.to_string(); - let object_id = object_id.to_string(); + let workspace_id = workspace_id.into(); + let object_id = object_id.into(); let inner = QueryCollab { object_id, collab_type, diff --git a/libs/database/src/collab/collab_db_ops.rs b/libs/database/src/collab/collab_db_ops.rs index 2d2e52a09..6be3d3315 100644 --- a/libs/database/src/collab/collab_db_ops.rs +++ b/libs/database/src/collab/collab_db_ops.rs @@ -445,6 +445,28 @@ pub async fn select_snapshot( Ok(row) } +#[inline] +pub async fn select_latest_snapshot( + pg_pool: &PgPool, + workspace_id: &Uuid, + object_id: &str, +) -> Result, Error> { + let row = sqlx::query_as!( + AFSnapshotRow, + r#" + SELECT * FROM af_collab_snapshot + WHERE workspace_id = $1 AND oid = $2 AND deleted_at IS NULL + ORDER BY created_at DESC + LIMIT 1; + "#, + workspace_id, + object_id + ) + .fetch_optional(pg_pool) + .await?; + Ok(row) +} + /// Returns list of snapshots for given object_id in descending order of creation time. pub async fn get_all_collab_snapshot_meta( pg_pool: &PgPool, diff --git a/libs/database/src/collab/collab_storage.rs b/libs/database/src/collab/collab_storage.rs index b3af1f65e..6a6cefc6e 100644 --- a/libs/database/src/collab/collab_storage.rs +++ b/libs/database/src/collab/collab_storage.rs @@ -6,6 +6,7 @@ use database_entity::dto::{ QueryCollabParams, QueryCollabResult, SnapshotData, }; +use crate::collab::CollabType; use collab::entity::EncodedCollab; use serde::{Deserialize, Serialize}; use sqlx::Transaction; @@ -147,6 +148,13 @@ pub trait CollabStorage: Send + Sync + 'static { snapshot_id: &i64, ) -> AppResult; + async fn get_latest_snapshot( + &self, + workspace_id: &str, + object_id: &str, + collab_type: CollabType, + ) -> AppResult>; + /// Returns list of snapshots for given object_id in descending order of creation time. 
async fn get_collab_snapshot_list( &self, diff --git a/libs/indexer/src/scheduler.rs b/libs/indexer/src/scheduler.rs index f92fdf212..7c014f0ad 100644 --- a/libs/indexer/src/scheduler.rs +++ b/libs/indexer/src/scheduler.rs @@ -8,7 +8,6 @@ use crate::vector::embedder::Embedder; use crate::vector::open_ai; use app_error::AppError; use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse}; -use collab::lock::RwLock; use collab::preclude::Collab; use collab_document::document::DocumentBody; use collab_entity::CollabType; @@ -243,7 +242,7 @@ impl IndexerScheduler { &self, workspace_id: &str, object_id: &str, - collab: &Arc>, + collab: &Collab, collab_type: &CollabType, ) -> Result<(), AppError> { if !self.index_enabled() { @@ -256,11 +255,9 @@ impl IndexerScheduler { match collab_type { CollabType::Document => { - let lock = collab.read().await; - let txn = lock.transact(); - let text = DocumentBody::from_collab(&lock) + let txn = collab.transact(); + let text = DocumentBody::from_collab(collab) .and_then(|body| body.to_plain_text(txn, false, true).ok()); - drop(lock); // release the read lock ASAP if let Some(text) = text { if !text.is_empty() { @@ -268,7 +265,7 @@ impl IndexerScheduler { Uuid::parse_str(workspace_id)?, object_id.to_string(), collab_type.clone(), - UnindexedData::UnindexedText(text), + UnindexedData::Text(text), ); self.embed_immediately(pending)?; } @@ -491,7 +488,7 @@ fn process_collab( ) -> Result)>, AppError> { if let Some(indexer) = indexer { let chunks = match data { - UnindexedData::UnindexedText(text) => { + UnindexedData::Text(text) => { indexer.create_embedded_chunks_from_text(object_id.to_string(), text, embedder.model())? }, }; @@ -543,13 +540,13 @@ impl UnindexedCollabTask { #[derive(Debug, Serialize, Deserialize)] pub enum UnindexedData { - UnindexedText(String), + Text(String), } impl UnindexedData { pub fn is_empty(&self) -> bool { match self { - UnindexedData::UnindexedText(text) => text.is_empty(), + UnindexedData::Text(text) => text.is_empty(), } } } diff --git a/services/appflowy-collaborate/Cargo.toml b/services/appflowy-collaborate/Cargo.toml index eaa20eff9..11661d828 100644 --- a/services/appflowy-collaborate/Cargo.toml +++ b/services/appflowy-collaborate/Cargo.toml @@ -61,6 +61,7 @@ thiserror = "1.0.56" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } anyhow.workspace = true bytes.workspace = true +arc-swap.workspace = true ureq = { version = "2.1.0", features = ["json"] } collab = { workspace = true } diff --git a/services/appflowy-collaborate/src/actix_ws/client/rt_client.rs b/services/appflowy-collaborate/src/actix_ws/client/rt_client.rs index eed5d61cd..6d3b1e848 100644 --- a/services/appflowy-collaborate/src/actix_ws/client/rt_client.rs +++ b/services/appflowy-collaborate/src/actix_ws/client/rt_client.rs @@ -107,7 +107,7 @@ where user: self.user.clone(), message, }) - .map_err(|err| RealtimeError::Internal(err.into())) + .map_err(|err| RealtimeError::SendWSMessageFailed(err.to_string())) } } diff --git a/services/appflowy-collaborate/src/api.rs b/services/appflowy-collaborate/src/api.rs index 35db9ab50..80f4ae02c 100644 --- a/services/appflowy-collaborate/src/api.rs +++ b/services/appflowy-collaborate/src/api.rs @@ -113,7 +113,6 @@ async fn post_realtime_message_stream_handler( bytes.extend_from_slice(&item?); } - event!(tracing::Level::INFO, "message len: {}", bytes.len()); let device_id = device_id.to_string(); let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?; diff --git 
a/services/appflowy-collaborate/src/application.rs b/services/appflowy-collaborate/src/application.rs index 4875155e1..8643600b0 100644 --- a/services/appflowy-collaborate/src/application.rs +++ b/services/appflowy-collaborate/src/application.rs @@ -23,6 +23,10 @@ use tracing::info; use crate::actix_ws::server::RealtimeServerActor; use crate::api::{collab_scope, ws_scope}; use crate::collab::access_control::CollabStorageAccessControlImpl; +use access_control::casbin::access::AccessControl; +use collab_stream::stream_router::{StreamRouter, StreamRouterOptions}; +use database::file::s3_client_impl::AwsS3BucketClientImpl; + use crate::collab::cache::CollabCache; use crate::collab::storage::CollabStorageImpl; use crate::command::{CLCommandReceiver, CLCommandSender}; @@ -31,8 +35,6 @@ use crate::pg_listener::PgListeners; use crate::snapshot::SnapshotControl; use crate::state::{AppMetrics, AppState, UserCache}; use crate::CollaborationServer; -use access_control::casbin::access::AccessControl; -use database::file::s3_client_impl::AwsS3BucketClientImpl; use indexer::collab_indexer::IndexerProvider; use indexer::scheduler::{IndexerConfiguration, IndexerScheduler}; @@ -78,9 +80,10 @@ pub async fn run_actix_server( )), state.metrics.realtime_metrics.clone(), rt_cmd_recv, + state.redis_stream_router.clone(), + state.redis_connection_manager.clone(), Duration::from_secs(config.collab.group_persistence_interval_secs), - config.collab.edit_state_max_count, - config.collab.edit_state_max_secs, + Duration::from_secs(config.collab.group_prune_grace_period_secs), state.indexer_scheduler.clone(), ) .await @@ -107,7 +110,8 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Result Result Result { +async fn get_redis_client( + redis_uri: &str, + worker_count: usize, +) -> Result<(redis::aio::ConnectionManager, Arc), Error> { info!("Connecting to redis with uri: {}", redis_uri); - let manager = redis::Client::open(redis_uri) - .context("failed to connect to redis")? + let client = redis::Client::open(redis_uri).context("failed to connect to redis")?; + + let router = StreamRouter::with_options( + &client, + StreamRouterOptions { + worker_count, + xread_streams: 100, + xread_block_millis: Some(5000), + xread_count: None, + }, + )?; + + let manager = client .get_connection_manager() .await .context("failed to get the connection manager")?; - Ok(manager) + Ok((manager, router.into())) } async fn get_connection_pool(setting: &DatabaseSetting) -> Result { diff --git a/services/appflowy-collaborate/src/collab/cache/collab_cache.rs b/services/appflowy-collaborate/src/collab/cache/collab_cache.rs index b90f9b7cf..92e699364 100644 --- a/services/appflowy-collaborate/src/collab/cache/collab_cache.rs +++ b/services/appflowy-collaborate/src/collab/cache/collab_cache.rs @@ -1,4 +1,6 @@ +use bytes::Bytes; use collab::entity::EncodedCollab; +use collab_entity::CollabType; use futures_util::{stream, StreamExt}; use itertools::{Either, Itertools}; use sqlx::{PgPool, Transaction}; @@ -186,6 +188,11 @@ impl CollabCache { // when the data is written to the disk cache but fails to be written to the memory cache // we log the error and continue. 
+ self.cache_collab(object_id, collab_type, encode_collab_data); + Ok(()) + } + + fn cache_collab(&self, object_id: String, collab_type: CollabType, encode_collab_data: Bytes) { let mem_cache = self.mem_cache.clone(); tokio::spawn(async move { if let Err(err) = mem_cache @@ -203,20 +210,6 @@ impl CollabCache { ); } }); - - Ok(()) - } - - pub async fn get_encode_collab_from_disk( - &self, - workspace_id: &str, - query: QueryCollab, - ) -> Result { - let encode_collab = self - .disk_cache - .get_collab_encoded_from_disk(workspace_id, query) - .await?; - Ok(encode_collab) } pub async fn insert_encode_collab_to_disk( @@ -225,10 +218,12 @@ impl CollabCache { uid: &i64, params: CollabParams, ) -> Result<(), AppError> { + let p = params.clone(); self .disk_cache .upsert_collab(workspace_id, uid, params) .await?; + self.cache_collab(p.object_id, p.collab_type, p.encoded_collab_v1); Ok(()) } diff --git a/services/appflowy-collaborate/src/collab/cache/disk_cache.rs b/services/appflowy-collaborate/src/collab/cache/disk_cache.rs index ca342e436..aa0959703 100644 --- a/services/appflowy-collaborate/src/collab/cache/disk_cache.rs +++ b/services/appflowy-collaborate/src/collab/cache/disk_cache.rs @@ -357,7 +357,7 @@ impl CollabDiskCache { while let Err(err) = s3.put_blob(&key, doc_state.clone().into(), None).await { match err { AppError::ServiceTemporaryUnavailable(err) if retries > 0 => { - tracing::info!( + tracing::debug!( "S3 service is temporarily unavailable: {}. Remaining retries: {}", err, retries @@ -371,6 +371,7 @@ impl CollabDiskCache { }, } } + tracing::trace!("saved collab to S3: {}", key); Ok(()) } diff --git a/services/appflowy-collaborate/src/collab/cache/mem_cache.rs b/services/appflowy-collaborate/src/collab/cache/mem_cache.rs index c79f87277..b74612abc 100644 --- a/services/appflowy-collaborate/src/collab/cache/mem_cache.rs +++ b/services/appflowy-collaborate/src/collab/cache/mem_cache.rs @@ -154,6 +154,7 @@ impl CollabMemCache { timestamp: i64, expiration_seconds: Option, ) -> redis::RedisResult<()> { + tracing::trace!("insert collab {} to memory cache", object_id); self .insert_data_with_timestamp(object_id, data, timestamp, expiration_seconds) .await diff --git a/services/appflowy-collaborate/src/collab/storage.rs b/services/appflowy-collaborate/src/collab/storage.rs index 30aae4a07..bdc0b06b0 100644 --- a/services/appflowy-collaborate/src/collab/storage.rs +++ b/services/appflowy-collaborate/src/collab/storage.rs @@ -545,6 +545,18 @@ where .await } + async fn get_latest_snapshot( + &self, + workspace_id: &str, + object_id: &str, + collab_type: CollabType, + ) -> AppResult> { + self + .snapshot_control + .get_latest_snapshot(workspace_id, object_id, collab_type) + .await + } + async fn get_collab_snapshot_list( &self, workspace_id: &str, diff --git a/services/appflowy-collaborate/src/config.rs b/services/appflowy-collaborate/src/config.rs index 19c4ecad7..cd02fbe54 100644 --- a/services/appflowy-collaborate/src/config.rs +++ b/services/appflowy-collaborate/src/config.rs @@ -16,6 +16,7 @@ pub struct Config { pub gotrue: GoTrueSetting, pub collab: CollabSetting, pub redis_uri: Secret, + pub redis_worker_count: usize, pub ai: AISettings, pub s3: S3Setting, } @@ -127,6 +128,7 @@ pub struct GoTrueSetting { #[derive(Clone, Debug)] pub struct CollabSetting { pub group_persistence_interval_secs: u64, + pub group_prune_grace_period_secs: u64, pub edit_state_max_count: u32, pub edit_state_max_secs: i64, pub s3_collab_threshold: u64, @@ -198,11 +200,14 @@ pub fn get_configuration() -> Result { 
"60", ) .parse()?, + group_prune_grace_period_secs: get_env_var("APPFLOWY_COLLAB_GROUP_GRACE_PERIOD_SECS", "60") + .parse()?, edit_state_max_count: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_COUNT", "100").parse()?, edit_state_max_secs: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_SECS", "60").parse()?, s3_collab_threshold: get_env_var("APPFLOWY_COLLAB_S3_THRESHOLD", "8000").parse()?, }, redis_uri: get_env_var("APPFLOWY_REDIS_URI", "redis://localhost:6379").into(), + redis_worker_count: get_env_var("APPFLOWY_REDIS_WORKERS", "60").parse()?, ai: AISettings { port: get_env_var("APPFLOWY_AI_SERVER_PORT", "5001").parse()?, host: get_env_var("APPFLOWY_AI_SERVER_HOST", "localhost"), diff --git a/services/appflowy-collaborate/src/error.rs b/services/appflowy-collaborate/src/error.rs index feb33e42d..ec8bca424 100644 --- a/services/appflowy-collaborate/src/error.rs +++ b/services/appflowy-collaborate/src/error.rs @@ -60,6 +60,27 @@ pub enum RealtimeError { #[error("Collab redis stream error: {0}")] StreamError(#[from] StreamError), + + #[error("Cannot create group: {0}")] + CannotCreateGroup(String), + + #[error("BinCodeCollab error: {0}")] + BincodeEncode(String), + + #[error("Failed to create snapshot: {0}")] + CreateSnapshotFailed(String), + + #[error("Failed to get latest snapshot: {0}")] + GetLatestSnapshotFailed(String), + + #[error("Collab Schema Error: {0}")] + CollabSchemaError(String), + + #[error("failed to obtain lease: {0}")] + Lease(Box), + + #[error("failed to send ws message: {0}")] + SendWSMessageFailed(String), } #[derive(Debug)] diff --git a/services/appflowy-collaborate/src/group/cmd.rs b/services/appflowy-collaborate/src/group/cmd.rs index 7f2bce2d9..907c26600 100644 --- a/services/appflowy-collaborate/src/group/cmd.rs +++ b/services/appflowy-collaborate/src/group/cmd.rs @@ -6,6 +6,11 @@ use async_stream::stream; use bytes::Bytes; use collab::core::origin::{CollabClient, CollabOrigin}; use collab::entity::EncodedCollab; +use dashmap::DashMap; +use futures_util::StreamExt; +use std::collections::HashMap; +use std::sync::Arc; + use collab_entity::CollabType; use collab_rt_entity::user::RealtimeUser; use collab_rt_entity::CollabAck; @@ -13,10 +18,7 @@ use collab_rt_entity::{ AckCode, ClientCollabMessage, MessageByObjectId, ServerCollabMessage, SinkMessage, UpdateSync, }; use collab_rt_protocol::{Message, SyncMessage}; -use dashmap::DashMap; use database::collab::CollabStorage; -use futures_util::StreamExt; -use std::sync::Arc; use tracing::{error, instrument, trace, warn}; use yrs::updates::encoder::Encode; use yrs::StateVector; @@ -147,7 +149,10 @@ where }, GroupCommand::GenerateCollabEmbedding { object_id } => { if let Some(group) = self.group_manager.get_group(&object_id).await { - group.generate_embeddings().await; + match group.generate_embeddings().await { + Ok(_) => trace!("successfully created embeddings for {}", object_id), + Err(err) => trace!("failed to create embeddings for {}: {}", object_id, err), + } } }, GroupCommand::CalculateMissingUpdate { @@ -334,16 +339,15 @@ where }; if let Some(group) = self.group_manager.get_group(&object_id).await { - let (collab_message_sender, _collab_message_receiver) = futures::channel::mpsc::channel(1); let (mut message_by_oid_sender, message_by_oid_receiver) = futures::channel::mpsc::channel(1); group.subscribe( &server_rt_user, CollabOrigin::Server, - collab_message_sender, + NullSender::default(), message_by_oid_receiver, ); - let message = MessageByObjectId::new_with_message(object_id.clone(), messages); - if let Err(err) = 
message_by_oid_sender.try_send(message) { + let message = HashMap::from([(object_id.clone(), messages)]); + if let Err(err) = message_by_oid_sender.try_send(MessageByObjectId(message)) { error!( "failed to send message to group: {}, object_id: {}", err, object_id diff --git a/services/appflowy-collaborate/src/group/group_init.rs b/services/appflowy-collaborate/src/group/group_init.rs index 4a6d3c0ce..2977a7b6e 100644 --- a/services/appflowy-collaborate/src/group/group_init.rs +++ b/services/appflowy-collaborate/src/group/group_init.rs @@ -1,46 +1,74 @@ -use std::fmt::Display; -use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU32, Ordering}; -use std::sync::Arc; -use std::time::Duration; - +use crate::error::RealtimeError; +use anyhow::anyhow; +use app_error::AppError; +use arc_swap::ArcSwap; +use collab::core::collab::DataSource; use collab::core::origin::CollabOrigin; use collab::entity::EncodedCollab; use collab::lock::RwLock; use collab::preclude::Collab; use collab_entity::CollabType; use collab_rt_entity::user::RealtimeUser; -use collab_rt_entity::CollabMessage; -use collab_rt_entity::MessageByObjectId; +use collab_rt_entity::{ + AckCode, AwarenessSync, BroadcastSync, CollabAck, MessageByObjectId, MsgId, +}; +use collab_rt_entity::{ClientCollabMessage, CollabMessage}; +use collab_rt_protocol::{Message, MessageReader, RTProtocolError, SyncMessage}; +use collab_stream::client::CollabRedisStream; +use collab_stream::collab_update_sink::{AwarenessUpdateSink, CollabUpdateSink}; + +use crate::metrics::CollabRealtimeMetrics; +use bytes::Bytes; +use collab_document::document::DocumentBody; +use collab_stream::error::StreamError; +use collab_stream::model::{AwarenessStreamUpdate, CollabStreamUpdate, MessageId, UpdateFlags}; use dashmap::DashMap; +use database::collab::{CollabStorage, GetCollabOrigin}; +use database_entity::dto::{CollabParams, QueryCollabParams}; +use futures::{pin_mut, Sink, Stream}; use futures_util::{SinkExt, StreamExt}; +use indexer::scheduler::{IndexerScheduler, UnindexedCollabTask, UnindexedData}; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; +use tokio::time::MissedTickBehavior; use tokio_util::sync::CancellationToken; -use tracing::{event, info, trace}; -use yrs::{ReadTxn, StateVector}; - -use collab_stream::error::StreamError; - -use crate::error::RealtimeError; -use crate::group::broadcast::{CollabBroadcast, Subscription}; -use crate::group::persistence::GroupPersistence; -use crate::metrics::CollabRealtimeMetrics; -use database::collab::CollabStorage; -use indexer::scheduler::IndexerScheduler; +use tracing::{error, info, trace}; +use uuid::Uuid; +use yrs::updates::decoder::{Decode, DecoderV1}; +use yrs::updates::encoder::{Encode, Encoder, EncoderV1}; +use yrs::{ReadTxn, StateVector, Update}; /// A group used to manage a single [Collab] object pub struct CollabGroup { - pub workspace_id: String, - pub object_id: String, - collab: Arc>, + state: Arc, +} + +/// Inner state of [CollabGroup] that's private and hidden behind Arc, so that it can be moved into +/// tasks. +struct CollabGroupState { + workspace_id: String, + object_id: String, collab_type: CollabType, - /// A broadcast used to propagate updates produced by yrs [yrs::Doc] and [Awareness] - /// to subscribes. - broadcast: CollabBroadcast, /// A list of subscribers to this group. Each subscriber will receive updates from the /// broadcast. 
subscribers: DashMap, - metrics_calculate: Arc, - indexer_scheduler: Arc, - cancel: CancellationToken, + persister: CollabPersister, + metrics: Arc, + /// Cancellation token triggered when current collab group is about to be stopped. + /// This will also shut down all subsequent [Subscription]s. + shutdown: CancellationToken, + last_activity: ArcSwap, + seq_no: AtomicU32, + /// The most recent state vector from a redis update. + state_vector: RwLock, +} + +impl Drop for CollabGroup { + fn drop(&mut self) { + // we're going to use state shutdown to cancel subsequent tasks + self.state.shutdown.cancel(); + } } impl CollabGroup { @@ -50,126 +78,359 @@ impl CollabGroup { workspace_id: String, object_id: String, collab_type: CollabType, - collab: Collab, - metrics_calculate: Arc, + metrics: Arc, storage: Arc, - is_new_collab: bool, + collab_redis_stream: Arc, persistence_interval: Duration, - edit_state_max_count: u32, - edit_state_max_secs: i64, + prune_grace_period: Duration, + state_vector: StateVector, indexer_scheduler: Arc, ) -> Result where S: CollabStorage, { - let edit_state = Arc::new(EditState::new( - edit_state_max_count, - edit_state_max_secs, - is_new_collab, - )); - let broadcast = CollabBroadcast::new(&object_id, 1000, edit_state.clone(), &collab); - let cancel = CancellationToken::new(); - let collab = Arc::new(RwLock::new(collab)); - tokio::spawn( - GroupPersistence::new( - workspace_id.clone(), - object_id.clone(), - uid, - storage, - edit_state.clone(), - collab.clone(), - collab_type.clone(), - persistence_interval, - indexer_scheduler.clone(), - cancel.clone(), - ) - .run(), + let is_new_collab = state_vector.is_empty(); + let persister = CollabPersister::new( + uid, + workspace_id.clone(), + object_id.clone(), + collab_type.clone(), + storage, + collab_redis_stream, + indexer_scheduler, + metrics.clone(), + prune_grace_period, ); - Ok(Self { + let state = Arc::new(CollabGroupState { workspace_id, object_id, collab_type, - collab, - broadcast, - subscribers: Default::default(), - metrics_calculate, - cancel, - indexer_scheduler, - }) + subscribers: DashMap::new(), + metrics, + shutdown: CancellationToken::new(), + persister, + last_activity: ArcSwap::new(Instant::now().into()), + seq_no: AtomicU32::new(0), + state_vector: state_vector.into(), + }); + + /* + NOTE: we don't want to pass `Weak` to tasks and terminate them when they + cannot be upgraded since we want to be sure that ie. when collab group is to be removed, + that we're going to call for a final save of the document state. + + For that we use `CancellationToken` instead, which is racing against internal loops of child + tasks and triggered when this `CollabGroup` is dropped. 
+ */ + + // setup task used to receive collab updates from Redis + { + let state = state.clone(); + tokio::spawn(async move { + if let Err(err) = Self::inbound_task(state).await { + tracing::warn!("failed to receive collab update: {}", err); + } + }); + } + + // setup task used to receive awareness updates from Redis + { + let state = state.clone(); + tokio::spawn(async move { + if let Err(err) = Self::inbound_awareness_task(state).await { + tracing::warn!("failed to receive awareness update: {}", err); + } + }); + } + + // setup periodic snapshot + { + tokio::spawn(Self::snapshot_task( + state.clone(), + persistence_interval, + is_new_collab, + )); + } + + Ok(Self { state }) } - /// Generate embedding for the current Collab immediately - /// - pub async fn generate_embeddings(&self) { - let result = self - .indexer_scheduler - .index_collab_immediately( - &self.workspace_id, - &self.object_id, - &self.collab, - &self.collab_type, - ) - .await; - match result { - Ok(_) => { - trace!( - "successfully indexed embeddings for {} {}/{}", - self.collab_type, - self.workspace_id, - self.object_id - ); - }, + #[inline] + pub fn workspace_id(&self) -> &str { + &self.state.workspace_id + } + + #[inline] + #[allow(dead_code)] + pub fn object_id(&self) -> &str { + &self.state.object_id + } + + pub fn is_cancelled(&self) -> bool { + self.state.shutdown.is_cancelled() + } + + /// Task used to receive collab updates from Redis. + async fn inbound_task(state: Arc) -> Result<(), RealtimeError> { + let updates = state.persister.collab_redis_stream.live_collab_updates( + &state.workspace_id, + &state.object_id, + None, + ); + pin_mut!(updates); + loop { + tokio::select! { + _ = state.shutdown.cancelled() => { + break; + } + res = updates.next() => { + match res { + Some(Ok((_message_id, update))) => { + Self::handle_inbound_update(&state, update).await; + }, + Some(Err(err)) => { + tracing::warn!("failed to handle incoming update for collab `{}`: {}", state.object_id, err); + break; + }, + None => { + break; + } + } + } + } + } + Ok(()) + } + + async fn handle_inbound_update(state: &CollabGroupState, update: CollabStreamUpdate) { + // update state vector based on incoming message + match Update::decode_v1(&update.data) { + Ok(update) => state + .state_vector + .write() + .await + .merge(update.state_vector()), Err(err) => { - trace!( - "failed to index embeddings for collab {} {}/{}: {}", - self.collab_type, - self.workspace_id, - self.object_id, + tracing::error!( + "received malformed update for collab `{}`: {}", + state.object_id, err ); + return; }, } + + let seq_num = state.seq_no.fetch_add(1, Ordering::SeqCst) + 1; + tracing::trace!( + "broadcasting collab update from {} ({} bytes) - seq_num: {}", + update.sender, + update.data.len(), + seq_num + ); + let payload = Message::Sync(SyncMessage::Update(update.data)).encode_v1(); + let message = BroadcastSync::new(update.sender, state.object_id.clone(), payload, seq_num); + for mut e in state.subscribers.iter_mut() { + let subscription = e.value_mut(); + if message.origin == subscription.collab_origin { + continue; // don't send update to its sender + } + + if let Err(err) = subscription.sink.send(message.clone().into()).await { + tracing::debug!( + "failed to send collab `{}` update to `{}`: {}", + state.object_id, + subscription.collab_origin, + err + ); + } + + state.last_activity.store(Arc::new(Instant::now())); + } + } + + /// Task used to receive awareness updates from Redis. 
+ async fn inbound_awareness_task(state: Arc) -> Result<(), RealtimeError> { + let updates = state.persister.collab_redis_stream.awareness_updates( + &state.workspace_id, + &state.object_id, + None, + ); + pin_mut!(updates); + loop { + tokio::select! { + _ = state.shutdown.cancelled() => { + break; + } + res = updates.next() => { + match res { + Some(Ok(awareness_update)) => { + Self::handle_inbound_awareness(&state, awareness_update).await; + }, + Some(Err(err)) => { + tracing::warn!("failed to handle incoming update for collab `{}`: {}", state.object_id, err); + break; + }, + None => { + break; + } + } + } + } + } + Ok(()) + } + + async fn handle_inbound_awareness(state: &CollabGroupState, update: AwarenessStreamUpdate) { + tracing::trace!( + "broadcasting awareness update from {} ({} bytes)", + update.sender, + update.data.len() + ); + let sender = update.sender; + let message = AwarenessSync::new( + state.object_id.clone(), + Message::Awareness(update.data).encode_v1(), + CollabOrigin::Empty, + ); + for mut e in state.subscribers.iter_mut() { + let subscription = e.value_mut(); + if sender == subscription.collab_origin { + continue; // don't send update to its sender + } + + if let Err(err) = subscription.sink.send(message.clone().into()).await { + tracing::debug!( + "failed to send awareness `{}` update to `{}`: {}", + state.object_id, + subscription.collab_origin, + err + ); + } + + state.last_activity.store(Arc::new(Instant::now())); + } + } + + async fn snapshot_task(state: Arc, interval: Duration, is_new_collab: bool) { + if is_new_collab { + tracing::trace!("persisting new collab for {}", state.object_id); + if let Err(err) = state.persister.save().await { + tracing::warn!( + "failed to persist new document `{}`: {}", + state.object_id, + err + ); + } + } + + let mut snapshot_tick = tokio::time::interval(interval); + // if saving took longer than snapshot_tick, just skip it over and try in the next round + snapshot_tick.set_missed_tick_behavior(MissedTickBehavior::Skip); + + loop { + tokio::select! 
{ + _ = snapshot_tick.tick() => { + if let Err(err) = state.persister.save().await { + tracing::warn!("failed to persist collab `{}/{}`: {}", state.workspace_id, state.object_id, err); + } + }, + _ = state.shutdown.cancelled() => { + if let Err(err) = state.persister.save().await { + tracing::warn!("failed to persist collab on shutdown `{}/{}`: {}", state.workspace_id, state.object_id, err); + } + break; + } + } + } + } + + /// Generate embedding for the current Collab immediately + /// + pub async fn generate_embeddings(&self) -> Result<(), AppError> { + let collab = self + .encode_collab() + .await + .map_err(|e| AppError::Internal(e.into()))?; + let collab = Collab::new_with_source( + CollabOrigin::Server, + self.object_id(), + DataSource::DocStateV1(collab.doc_state.into()), + vec![], + false, + ) + .map_err(|e| AppError::Internal(e.into()))?; + let workspace_id = &self.state.workspace_id; + let object_id = &self.state.object_id; + let collab_type = &self.state.collab_type; + self + .state + .persister + .indexer_scheduler + .index_collab_immediately(workspace_id, object_id, &collab, collab_type) + .await } pub async fn calculate_missing_update( &self, state_vector: StateVector, ) -> Result, RealtimeError> { - let update = { - let guard = self.collab.read().await; - let txn = guard.transact(); - txn.encode_state_as_update_v1(&state_vector) - }; + { + // first check if we need to send any updates + let collab_sv = self.state.state_vector.read().await; + if *collab_sv <= state_vector { + return Ok(vec![]); + } + } + + let encoded_collab = self.encode_collab().await?; + let collab = Collab::new_with_source( + CollabOrigin::Server, + self.object_id(), + DataSource::DocStateV1(encoded_collab.doc_state.into()), + vec![], + false, + )?; + let update = collab.transact().encode_state_as_update_v1(&state_vector); Ok(update) } pub async fn encode_collab(&self) -> Result { - let lock = self.collab.read().await; - let encode_collab = lock.encode_collab_v1(|collab| { + let snapshot = self.state.persister.load_compact().await?; + let encode_collab = snapshot.collab.encode_collab_v1(|collab| { self + .state .collab_type .validate_require_data(collab) - .map_err(|err| RealtimeError::Internal(err.into())) + .map_err(|err| RealtimeError::CollabSchemaError(err.to_string())) })?; Ok(encode_collab) } pub fn contains_user(&self, user: &RealtimeUser) -> bool { - self.subscribers.contains_key(user) + self.state.subscribers.contains_key(user) } pub fn remove_user(&self, user: &RealtimeUser) { - if self.subscribers.remove(user).is_some() { - trace!("{} remove subscriber from group: {}", self.object_id, user); + if self.state.subscribers.remove(user).is_some() { + trace!( + "{} remove subscriber from group: {}", + self.state.object_id, + user + ); } } pub fn user_count(&self) -> usize { - self.subscribers.len() + self.state.subscribers.len() + } + + pub fn modified_at(&self) -> Instant { + *self.state.last_activity.load_full() } /// Subscribes a new connection to the broadcast group for collaborative activities. 
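// A minimal, self-contained sketch of the save-loop pattern used by `snapshot_task` above,
// assuming only plain tokio / tokio-util APIs: an interval with `MissedTickBehavior::Skip`
// (a slow save skips missed ticks instead of firing them back-to-back) raced against a
// `CancellationToken` so that shutdown still performs one final save. `run_periodic_flush`
// and `flush` are illustrative stand-ins for the group state and `persister.save()`.
use std::time::Duration;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;

async fn run_periodic_flush<F, Fut>(shutdown: CancellationToken, period: Duration, mut flush: F)
where
  F: FnMut() -> Fut,
  Fut: std::future::Future<Output = ()>,
{
  let mut tick = tokio::time::interval(period);
  // if a flush overruns the period, skip the missed ticks rather than bursting to catch up
  tick.set_missed_tick_behavior(MissedTickBehavior::Skip);
  loop {
    tokio::select! {
      _ = tick.tick() => flush().await,
      _ = shutdown.cancelled() => {
        // final flush on shutdown, mirroring the save-on-cancel branch above
        flush().await;
        break;
      }
    }
  }
}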
+ /// pub fn subscribe( &self, user: &RealtimeUser, @@ -177,57 +438,383 @@ impl CollabGroup { sink: Sink, stream: Stream, ) where - Sink: SinkExt + Clone + Send + Sync + Unpin + 'static, - Stream: StreamExt + Send + Sync + Unpin + 'static, - >::Error: std::error::Error + Send + Sync, + Sink: SubscriptionSink + Clone + 'static, + Stream: SubscriptionStream + 'static, { // create new subscription for new subscriber - let sub = self.broadcast.subscribe( - user, - subscriber_origin, - sink, + let subscriber_shutdown = self.state.shutdown.child_token(); + + tokio::spawn(Self::receive_from_client_task( + self.state.clone(), + sink.clone(), stream, - Arc::downgrade(&self.collab), - self.metrics_calculate.clone(), - self.cancel.child_token(), - ); + subscriber_origin.clone(), + )); - if let Some(old) = self.subscribers.insert((*user).clone(), sub) { - tracing::warn!("{}: remove old subscriber: {}", &self.object_id, user); - drop(old); + let sub = Subscription::new(sink, subscriber_origin, subscriber_shutdown); + if self + .state + .subscribers + .insert((*user).clone(), sub) + .is_some() + { + tracing::warn!("{}: remove old subscriber: {}", &self.state.object_id, user); } if cfg!(debug_assertions) { - event!( - tracing::Level::TRACE, + trace!( "{}: add new subscriber, current group member: {}", - &self.object_id, + &self.state.object_id, self.user_count(), ); } trace!( "[realtime]:{} new subscriber:{}, connect at:{}, connected members: {}", - self.object_id, + self.state.object_id, user.user_device(), user.connect_at, - self.subscribers.len(), + self.state.subscribers.len(), + ); + } + + async fn receive_from_client_task( + state: Arc, + mut sink: Sink, + mut stream: Stream, + origin: CollabOrigin, + ) where + Sink: SubscriptionSink + 'static, + Stream: SubscriptionStream + 'static, + { + loop { + tokio::select! { + _ = state.shutdown.cancelled() => { + break; + } + msg = stream.next() => { + match msg { + None => break, + Some(msg) => if let Err(err) = Self::handle_messages(&state, &mut sink, msg).await { + tracing::warn!( + "collab `{}` failed to handle message from `{}`: {}", + state.object_id, + origin, + err + ); + + } + } + } + } + } + } + + async fn handle_messages( + state: &CollabGroupState, + sink: &mut Sink, + msg: MessageByObjectId, + ) -> Result<(), RealtimeError> + where + Sink: SubscriptionSink + 'static, + { + for (message_object_id, messages) in msg.0 { + if state.object_id != message_object_id { + error!( + "Expect object id:{} but got:{}", + state.object_id, message_object_id + ); + continue; + } + for message in messages { + match Self::handle_client_message(state, message).await { + Ok(response) => { + trace!("[realtime]: sending response: {}", response); + match sink.send(response.into()).await { + Ok(()) => {}, + Err(err) => { + trace!("[realtime]: send failed: {}", err); + break; + }, + } + }, + Err(err) => { + error!( + "Error handling collab message for object_id: {}: {}", + message_object_id, err + ); + break; + }, + } + } + } + Ok(()) + } + + /// Handle the message sent from the client + async fn handle_client_message( + state: &CollabGroupState, + collab_msg: ClientCollabMessage, + ) -> Result { + let msg_id = collab_msg.msg_id(); + let message_origin = collab_msg.origin().clone(); + + // If the payload is empty, we don't need to apply any updates. + // Currently, only the ping message should have an empty payload.
+ if collab_msg.payload().is_empty() { + if !matches!(collab_msg, ClientCollabMessage::ClientCollabStateCheck(_)) { + error!("received unexpected empty payload message:{}", collab_msg); + } + return Ok(CollabAck::new( + message_origin, + state.object_id.to_string(), + msg_id, + state.seq_no.load(Ordering::SeqCst), + )); + } + + trace!( + "Applying client updates: {}, origin:{}", + collab_msg, + message_origin ); + + let payload = collab_msg.payload(); + + // Handle the message (large updates are decoded on a blocking task inside handle_sync_step2) + let result = Self::handle_message(state, payload, &message_origin, msg_id).await; + + match result { + Ok(inner_result) => match inner_result { + Some(response) => Ok(response), + None => Err(RealtimeError::UnexpectedData("No ack response")), + }, + Err(err) => Err(RealtimeError::Internal(anyhow!( + "failed to handle message:{}", + err + ))), + } + } + + async fn handle_message( + state: &CollabGroupState, + payload: &[u8], + message_origin: &CollabOrigin, + msg_id: MsgId, + ) -> Result, RealtimeError> { + let mut decoder = DecoderV1::from(payload); + let reader = MessageReader::new(&mut decoder); + let mut ack_response = None; + for msg in reader { + match msg { + Ok(msg) => { + match Self::handle_protocol_message(state, message_origin, msg).await { + Ok(payload) => { + // One ClientCollabMessage can have multiple Yrs [Message] in it, but we only need to + // send one ack back to the client. + if ack_response.is_none() { + ack_response = Some( + CollabAck::new( + CollabOrigin::Server, + state.object_id.to_string(), + msg_id, + state.seq_no.load(Ordering::SeqCst), + ) + .with_payload(payload.unwrap_or_default()), + ); + } + }, + Err(err) => { + tracing::warn!("[realtime]: failed to handle message: {}", msg_id); + state.metrics.apply_update_failed_count.inc(); + + let code = Self::ack_code_from_error(&err); + let payload = match err { + RTProtocolError::MissUpdates { + state_vector_v1, + reason: _, + } => state_vector_v1.unwrap_or_default(), + _ => vec![], + }; + + ack_response = Some( + CollabAck::new( + CollabOrigin::Server, + state.object_id.to_string(), + msg_id, + state.seq_no.load(Ordering::SeqCst), + ) + .with_code(code) + .with_payload(payload), + ); + + break; + }, + } + }, + Err(e) => { + error!("{} => parse sync message failed: {:?}", state.object_id, e); + break; + }, + } + } + Ok(ack_response) + } + + async fn handle_protocol_message( + state: &CollabGroupState, + origin: &CollabOrigin, + msg: Message, + ) -> Result>, RTProtocolError> { + match msg { + Message::Sync(msg) => match msg { + SyncMessage::SyncStep1(sv) => Self::handle_sync_step1(state, &sv).await, + SyncMessage::SyncStep2(update) => Self::handle_sync_step2(state, origin, update).await, + SyncMessage::Update(update) => Self::handle_update(state, origin, update).await, + }, + //FIXME: where is the QueryAwareness protocol?
+ Message::Awareness(update) => Self::handle_awareness_update(state, origin, update).await, + Message::Auth(_reason) => Ok(None), + Message::Custom(_msg) => Ok(None), + } + } + + async fn handle_sync_step1( + state: &CollabGroupState, + remote_sv: &StateVector, + ) -> Result>, RTProtocolError> { + if let Ok(sv) = state.state_vector.try_read() { + // we optimistically try to obtain state vector lock for a fast track: + // if the remote sv is up-to-date with the current one, we don't need to do anything + match sv.partial_cmp(remote_sv) { + Some(std::cmp::Ordering::Equal) => return Ok(None), // client and server are in sync + Some(std::cmp::Ordering::Less) => { + // server is behind client + let msg = Message::Sync(SyncMessage::SyncStep1(sv.clone())); + return Ok(Some(msg.encode_v1())); + }, + Some(std::cmp::Ordering::Greater) | None => { /* server has some new updates */ }, + } + } + + // we need to reconstruct document state on the server side + tracing::debug!("loading collab {}", state.object_id); + let snapshot = state + .persister + .load_compact() + .await + .map_err(|err| RTProtocolError::Internal(err.into()))?; + + // prepare document state update and state vector + let tx = snapshot.collab.transact(); + let doc_state = tx.encode_state_as_update_v1(remote_sv); + let local_sv = tx.state_vector(); + drop(tx); + + // Send the missing document state back to the client, e.g. after it returns online from offline editing. + tracing::trace!("sending missing data to client ({} bytes)", doc_state.len()); + let mut encoder = EncoderV1::new(); + Message::Sync(SyncMessage::SyncStep2(doc_state)).encode(&mut encoder); + //FIXME: this should never happen as response to sync step 1 from the client, but rather be + // sent when a connection is established + Message::Sync(SyncMessage::SyncStep1(local_sv)).encode(&mut encoder); + Ok(Some(encoder.to_vec())) + } + + async fn handle_sync_step2( + state: &CollabGroupState, + origin: &CollabOrigin, + update: Vec, + ) -> Result>, RTProtocolError> { + state.metrics.collab_size.observe(update.len() as f64); + + let start = tokio::time::Instant::now(); + // we try to decode update to make sure it's not malformed and to extract state vector + let (update, decoded_update) = if update.len() <= collab_rt_protocol::LARGE_UPDATE_THRESHOLD { + let decoded_update = Update::decode_v1(&update)?; + (update, decoded_update) + } else { + tokio::task::spawn_blocking(move || { + let decoded_update = Update::decode_v1(&update)?; + Ok::<(Vec, yrs::Update), yrs::encoding::read::Error>((update, decoded_update)) + }) + .await + .map_err(|err| RTProtocolError::Internal(err.into()))??
+ }; + let missing_updates = { + let state_vector = state.state_vector.read().await; + match state_vector.partial_cmp(&decoded_update.state_vector_lower()) { + None | Some(std::cmp::Ordering::Less) => Some(state_vector.clone()), + _ => None, + } + }; + + if let Some(missing_updates) = missing_updates { + let msg = Message::Sync(SyncMessage::SyncStep1(missing_updates)); + tracing::debug!("subscriber {} send update with missing data", origin); + Ok(Some(msg.encode_v1())) + } else { + state + .persister + .send_update(origin.clone(), update) + .await + .map_err(|err| RTProtocolError::Internal(err.into()))?; + let elapsed = start.elapsed(); + + state + .metrics + .load_collab_time + .observe(elapsed.as_millis() as f64); + + Ok(None) + } + } + + async fn handle_update( + state: &CollabGroupState, + origin: &CollabOrigin, + update: Vec, + ) -> Result>, RTProtocolError> { + Self::handle_sync_step2(state, origin, update).await + } + + async fn handle_awareness_update( + state: &CollabGroupState, + origin: &CollabOrigin, + update: Vec, + ) -> Result>, RTProtocolError> { + state + .persister + .send_awareness(origin, update) + .await + .map_err(|err| RTProtocolError::Internal(err.into()))?; + Ok(None) + } + + #[inline] + fn ack_code_from_error(error: &RTProtocolError) -> AckCode { + match error { + RTProtocolError::YrsTransaction(_) => AckCode::Retry, + RTProtocolError::YrsApplyUpdate(_) => AckCode::CannotApplyUpdate, + RTProtocolError::YrsEncodeState(_) => AckCode::EncodeStateAsUpdateFail, + RTProtocolError::MissUpdates { .. } => AckCode::MissUpdate, + _ => AckCode::Internal, + } } /// Check if the group is active. A group is considered active if it has at least one /// subscriber pub fn is_inactive(&self) -> bool { - let modified_at = *self.broadcast.modified_at.lock(); + let modified_at = self.modified_at(); // In debug mode, we set the timeout to 60 seconds if cfg!(debug_assertions) { trace!( "Group:{}:{} is inactive for {} seconds, subscribers: {}", - self.object_id, - self.collab_type, + self.state.object_id, + self.state.collab_type, modified_at.elapsed().as_secs(), - self.subscribers.len() + self.state.subscribers.len() ); modified_at.elapsed().as_secs() > 60 * 3 } else { @@ -241,14 +828,14 @@ impl CollabGroup { if elapsed_secs > MAXIMUM_SECS { info!( "Group:{}:{} is inactive for {} seconds, subscribers: {}", - self.object_id, - self.collab_type, + self.state.object_id, + self.state.collab_type, modified_at.elapsed().as_secs(), - self.subscribers.len() + self.state.subscribers.len() ); true } else { - self.subscribers.is_empty() + self.state.subscribers.is_empty() } } else { false @@ -267,7 +854,7 @@ impl CollabGroup { /// A `u64` representing the timeout duration in seconds for the collaboration type in question. #[inline] fn timeout_secs(&self) -> u64 { - match self.collab_type { + match self.state.collab_type { CollabType::Document => 30 * 60, // 30 minutes CollabType::Database | CollabType::DatabaseRow => 30 * 60, // 30 minutes CollabType::WorkspaceDatabase | CollabType::Folder | CollabType::UserAwareness => 6 * 60 * 60, // 6 hours, @@ -278,117 +865,392 @@ impl CollabGroup { } } -impl Drop for CollabGroup { - fn drop(&mut self) { - self.cancel.cancel(); - } +pub trait SubscriptionSink: + Sink + Send + Sync + Unpin +{ +} +impl SubscriptionSink for T where + T: Sink + Send + Sync + Unpin +{ } -pub(crate) struct EditState { - /// Clients rely on `edit_count` to verify message ordering. 
A non-continuous sequence suggests - /// missing updates, prompting the client to request an initial synchronization. - /// Continuous sequence numbers ensure the client receives and displays updates in the correct order. - /// - edit_counter: AtomicU32, - prev_edit_count: AtomicU32, - prev_flush_timestamp: AtomicI64, - - max_edit_count: u32, - max_secs: i64, - /// Indicate the collab object is just created in the client and not exist in server database. - is_new_create: AtomicBool, +pub trait SubscriptionStream: Stream + Send + Sync + Unpin {} +impl SubscriptionStream for T where T: Stream + Send + Sync + Unpin {} + +struct Subscription { + collab_origin: CollabOrigin, + sink: Box, + shutdown: CancellationToken, } -impl Display for EditState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "EditState {{ edit_counter: {}, prev_edit_count: {}, max_edit_count: {}, max_secs: {}, is_new: {}", - self.edit_counter.load(Ordering::SeqCst), - self.prev_edit_count.load(Ordering::SeqCst), - self.max_edit_count, - self.max_secs, - self.is_new_create.load(Ordering::SeqCst), - ) +impl Subscription { + fn new(sink: S, collab_origin: CollabOrigin, shutdown: CancellationToken) -> Self + where + S: SubscriptionSink + 'static, + { + Subscription { + sink: Box::new(sink), + collab_origin, + shutdown, + } } } -impl EditState { - fn new(max_edit_count: u32, max_secs: i64, is_new_create: bool) -> Self { +impl Drop for Subscription { + fn drop(&mut self) { + tracing::trace!("closing subscription: {}", self.collab_origin); + self.shutdown.cancel(); + } +} + +struct CollabPersister { + uid: i64, + workspace_id: String, + object_id: String, + collab_type: CollabType, + storage: Arc, + collab_redis_stream: Arc, + indexer_scheduler: Arc, + metrics: Arc, + update_sink: CollabUpdateSink, + awareness_sink: AwarenessUpdateSink, + /// A grace period for prunning Redis collab updates. Instead of deleting all messages we + /// read right away, we give 1min for other potential client to catch up. 
+ prune_grace_period: Duration, +} + +impl CollabPersister { + #[allow(clippy::too_many_arguments)] + pub fn new( + uid: i64, + workspace_id: String, + object_id: String, + collab_type: CollabType, + storage: Arc, + collab_redis_stream: Arc, + indexer_scheduler: Arc, + metrics: Arc, + prune_grace_period: Duration, + ) -> Self { + let update_sink = collab_redis_stream.collab_update_sink(&workspace_id, &object_id); + let awareness_sink = collab_redis_stream.awareness_update_sink(&workspace_id, &object_id); Self { - edit_counter: AtomicU32::new(0), - prev_edit_count: Default::default(), - prev_flush_timestamp: AtomicI64::new(chrono::Utc::now().timestamp()), - max_edit_count, - max_secs, - is_new_create: AtomicBool::new(is_new_create), + uid, + workspace_id, + object_id, + collab_type, + storage, + collab_redis_stream, + indexer_scheduler, + metrics, + update_sink, + awareness_sink, + prune_grace_period, } } - pub(crate) fn edit_count(&self) -> u32 { - self.edit_counter.load(Ordering::SeqCst) + async fn send_update( + &self, + sender: CollabOrigin, + update: Vec, + ) -> Result { + let len = update.len(); + // send updates to redis queue + let update = CollabStreamUpdate::new(update, sender, UpdateFlags::default()); + let msg_id = self.update_sink.send(&update).await?; + tracing::trace!( + "persisted update from {} ({} bytes) - msg id: {}", + update.sender, + len, + msg_id + ); + Ok(msg_id) } - /// Increments the edit count and returns the old value - pub(crate) fn increment_edit_count(&self) -> u32 { - self - .edit_counter - .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| { - Some(current + 1) - }) - // safety: unwrap when returning the new value - .unwrap() + async fn send_awareness( + &self, + sender_session: &CollabOrigin, + awareness_update: Vec, + ) -> Result { + // send awareness updates to redis queue: + // QUESTION: is it needed? Maybe we could reuse update_sink? + let len = awareness_update.len(); + let update = AwarenessStreamUpdate { + data: awareness_update, + sender: sender_session.clone(), + }; + let msg_id = self.awareness_sink.send(&update).await?; + tracing::trace!( + "persisted awareness from {} ({} bytes) - msg id: {}", + update.sender, + len, + msg_id + ); + Ok(msg_id) } - pub(crate) fn tick(&self) { - self - .prev_edit_count - .store(self.edit_counter.load(Ordering::SeqCst), Ordering::SeqCst); + /// Loads collab without its history. Used for handling y-sync protocol messages. + async fn load_compact(&self) -> Result { + tracing::trace!("requested to load compact collab {}", self.object_id); + // 1. Try to load the latest snapshot from storage + let start = Instant::now(); + let mut collab = match self.load_collab_full().await? { + Some(collab) => collab, + None => Collab::new_with_origin(CollabOrigin::Server, self.object_id.clone(), vec![], false), + }; + self.metrics.load_collab_count.inc(); + + // 2. consume all Redis updates on top of it (keep redis msg id) + let mut last_message_id = None; + let mut tx = collab.transact_mut(); + let updates = self + .collab_redis_stream + .current_collab_updates( + &self.workspace_id, + &self.object_id, + None, //TODO: store Redis last msg id somewhere in doc state snapshot and replay from there + ) + .await?; + let mut i = 0; + for (message_id, update) in updates { + i += 1; + let update: Update = update.into_update()?; + tx.apply_update(update) + .map_err(|err| RTProtocolError::YrsApplyUpdate(err.to_string()))?; + last_message_id = Some(message_id); //TODO: shouldn't this happen before decoding? 
+ self.metrics.apply_update_count.inc(); + } + drop(tx); + tracing::trace!( + "loaded collab compact state: {} replaying {} updates", + self.object_id, + i + ); self - .prev_flush_timestamp - .store(chrono::Utc::now().timestamp(), Ordering::SeqCst); - } + .metrics + .load_collab_time + .observe(start.elapsed().as_millis() as f64); - pub(crate) fn is_edit(&self) -> bool { - self.edit_counter.load(Ordering::SeqCst) != self.prev_edit_count.load(Ordering::SeqCst) + // now we have the most recent version of the document + let snapshot = CollabSnapshot { + collab, + last_message_id, + }; + Ok(snapshot) } - pub(crate) fn is_new_create(&self) -> bool { - self.is_new_create.load(Ordering::SeqCst) - } + /// Returns a collab state (with GC turned off), but only if there were any pending updates + /// waiting to be merged into main document state. + async fn load_if_changed(&self) -> Result, RealtimeError> { + // 1. load pending Redis updates + let updates = self + .collab_redis_stream + .current_collab_updates(&self.workspace_id, &self.object_id, None) + .await?; + + let start = Instant::now(); + let mut i = 0; + let mut collab = None; + let mut last_message_id = None; + for (message_id, update) in updates { + i += 1; + let update: Update = update.into_update()?; + if collab.is_none() { + collab = Some(match self.load_collab_full().await? { + Some(collab) => collab, + None => { + Collab::new_with_origin(CollabOrigin::Server, self.object_id.clone(), vec![], false) + }, + }) + }; + let collab = collab.as_mut().unwrap(); + collab + .transact_mut() + .apply_update(update) + .map_err(|err| RTProtocolError::YrsApplyUpdate(err.to_string()))?; + last_message_id = Some(message_id); //TODO: shouldn't this happen before decoding? + self.metrics.apply_update_count.inc(); + } - pub(crate) fn set_is_new_create(&self, is_new: bool) { - self.is_new_create.store(is_new, Ordering::SeqCst); + // if there were no Redis updates, collab is still not initialized + match collab { + Some(collab) => { + self.metrics.load_full_collab_count.inc(); + let elapsed = start.elapsed(); + self + .metrics + .load_collab_time + .observe(elapsed.as_millis() as f64); + tracing::trace!( + "loaded collab full state: {} replaying {} updates in {:?}", + self.object_id, + i, + elapsed + ); + { + let tx = collab.transact(); + if tx.store().pending_update().is_some() || tx.store().pending_ds().is_some() { + tracing::trace!( + "loaded collab {} is incomplete: has pending data", + self.object_id + ); + } + } + Ok(Some(CollabSnapshot { + collab, + last_message_id, + })) + }, + None => Ok(None), + } } - pub(crate) fn should_save_to_disk(&self) -> bool { - if self.is_new_create.load(Ordering::Relaxed) { - return true; + async fn save(&self) -> Result<(), RealtimeError> { + // load collab but only if there were pending updates in Redis + if let Some(mut snapshot) = self.load_if_changed().await? { + tracing::debug!("requesting save for collab {}", self.object_id); + if let Some(message_id) = snapshot.last_message_id { + // non-nil message_id means that we had to update the most recent collab state snapshot + // with new updates from Redis. 
This means that our snapshot state is newer than the last + // persisted one in the database + self.save_attempt(&mut snapshot.collab, message_id).await?; + } + } else { + tracing::trace!("collab {} state has not changed", self.object_id); } + Ok(()) + } - let current_edit_count = self.edit_counter.load(Ordering::SeqCst); - let prev_edit_count = self.prev_edit_count.load(Ordering::SeqCst); + /// Tries to save provided `snapshot`. This snapshot is expected to have **GC turned off**, as + /// first it will try to save it as a historical snapshot (will all updates available), then it + /// will generate another (compact) snapshot variant that will be used as main one for loading + /// for the sake of y-sync protocol. + async fn save_attempt( + &self, + collab: &mut Collab, + message_id: MessageId, + ) -> Result<(), RealtimeError> { + // try to acquire snapshot lease - it's possible that multiple web services will try to + // perform snapshot at the same time, so we'll use lease to let only one of them atm. + if let Some(mut lease) = self + .collab_redis_stream + .lease(&self.workspace_id, &self.object_id) + .await? + { + let doc_state_light = collab + .transact() + .encode_state_as_update_v1(&StateVector::default()); + let light_len = doc_state_light.len(); + self.write_collab(doc_state_light).await?; - // If the collab is new, save it to disk and reset the flag - if self.is_new_create.load(Ordering::SeqCst) { - return true; - } + match self.collab_type { + CollabType::Document => { + let txn = collab.transact(); + if let Some(text) = DocumentBody::from_collab(collab) + .and_then(|body| body.to_plain_text(txn, false, true).ok()) + { + self.index_collab_content(text); + } + }, + _ => { + // TODO(nathan): support other collab type + }, + } - if current_edit_count == prev_edit_count { - return false; + tracing::debug!( + "persisted collab {} snapshot at {}: {} bytes", + self.object_id, + message_id, + light_len + ); + + // 3. 
finally we can drop Redis messages + let now = SystemTime::UNIX_EPOCH.elapsed().unwrap().as_millis(); + let msg_id = MessageId { + timestamp_ms: (now - self.prune_grace_period.as_millis()) as u64, + sequence_number: 0, + }; + let stream_key = CollabStreamUpdate::stream_key(&self.workspace_id, &self.object_id); + self + .collab_redis_stream + .prune_stream(&stream_key, msg_id) + .await?; + + let _ = lease.release().await; } - // Check if the edit count exceeds the maximum allowed since the last save - let edit_count_exceeded = (current_edit_count > prev_edit_count) - && ((current_edit_count - prev_edit_count) >= self.max_edit_count); + Ok(()) + } - // Calculate the time since the last flush and check if it exceeds the maximum allowed - let now = chrono::Utc::now().timestamp(); - let prev_flush_timestamp = self.prev_flush_timestamp.load(Ordering::SeqCst); - let time_exceeded = - (now > prev_flush_timestamp) && (now - prev_flush_timestamp >= self.max_secs); + async fn write_collab(&self, doc_state_v1: Vec) -> Result<(), RealtimeError> { + let encoded_collab = EncodedCollab::new_v1(Default::default(), doc_state_v1) + .encode_to_bytes() + .map(Bytes::from) + .map_err(|err| RealtimeError::BincodeEncode(err.to_string()))?; + self + .metrics + .collab_size + .observe(encoded_collab.len() as f64); + let params = CollabParams::new(&self.object_id, self.collab_type.clone(), encoded_collab); + self + .storage + .queue_insert_or_update_collab(&self.workspace_id, &self.uid, params, true) + .await + .map_err(|err| RealtimeError::Internal(err.into()))?; + Ok(()) + } - // Determine if we should save based on either condition being met - edit_count_exceeded || (current_edit_count != prev_edit_count && time_exceeded) + fn index_collab_content(&self, text: String) { + if let Ok(workspace_id) = Uuid::parse_str(&self.workspace_id) { + let indexed_collab = UnindexedCollabTask::new( + workspace_id, + self.object_id.clone(), + self.collab_type.clone(), + UnindexedData::Text(text), + ); + if let Err(err) = self + .indexer_scheduler + .index_pending_collab_one(indexed_collab, false) + { + tracing::warn!( + "failed to index collab `{}/{}`: {}", + self.workspace_id, + self.object_id, + err + ); + } + } } + + async fn load_collab_full(&self) -> Result, RealtimeError> { + // we didn't find a snapshot, or we want a lightweight collab version + let params = QueryCollabParams::new( + self.object_id.clone(), + self.collab_type.clone(), + self.workspace_id.clone(), + ); + let result = self + .storage + .get_encode_collab(GetCollabOrigin::Server, params, false) + .await; + let doc_state = match result { + Ok(encoded_collab) => encoded_collab.doc_state, + Err(AppError::RecordNotFound(_)) => return Ok(None), + Err(err) => return Err(RealtimeError::Internal(err.into())), + }; + + let collab: Collab = Collab::new_with_source( + CollabOrigin::Server, + &self.object_id, + DataSource::DocStateV1(doc_state.into()), + vec![], + false, + )?; + Ok(Some(collab)) + } +} + +pub struct CollabSnapshot { + pub collab: Collab, + pub last_message_id: Option, } diff --git a/services/appflowy-collaborate/src/group/manager.rs b/services/appflowy-collaborate/src/group/manager.rs index 7b1e13dac..485caa0b0 100644 --- a/services/appflowy-collaborate/src/group/manager.rs +++ b/services/appflowy-collaborate/src/group/manager.rs @@ -1,25 +1,26 @@ use std::sync::Arc; use std::time::Duration; +use access_control::collab::RealtimeAccessControl; +use app_error::AppError; use collab::core::collab::DataSource; use collab::core::origin::CollabOrigin; use 
collab::entity::EncodedCollab; use collab::preclude::Collab; use collab_entity::CollabType; -use tracing::{instrument, trace}; - -use access_control::collab::RealtimeAccessControl; -use app_error::AppError; use collab_rt_entity::user::RealtimeUser; use collab_rt_entity::CollabMessage; +use collab_stream::client::CollabRedisStream; +use database::collab::{CollabStorage, GetCollabOrigin}; +use database_entity::dto::QueryCollabParams; +use tracing::{instrument, trace}; +use yrs::{ReadTxn, StateVector}; use crate::client::client_msg_router::ClientMessageRouter; -use crate::error::{CreateGroupFailedReason, RealtimeError}; +use crate::error::RealtimeError; use crate::group::group_init::CollabGroup; use crate::group::state::GroupManagementState; use crate::metrics::CollabRealtimeMetrics; -use database::collab::{CollabStorage, GetCollabOrigin}; -use database_entity::dto::QueryCollabParams; use indexer::scheduler::IndexerScheduler; pub struct GroupManager { @@ -27,9 +28,9 @@ pub struct GroupManager { storage: Arc, access_control: Arc, metrics_calculate: Arc, + collab_redis_stream: Arc, persistence_interval: Duration, - edit_state_max_count: u32, - edit_state_max_secs: i64, + prune_grace_period: Duration, indexer_scheduler: Arc, } @@ -42,19 +43,20 @@ where storage: Arc, access_control: Arc, metrics_calculate: Arc, + collab_stream: CollabRedisStream, persistence_interval: Duration, - edit_state_max_count: u32, - edit_state_max_secs: i64, + prune_grace_period: Duration, indexer_scheduler: Arc, ) -> Result { + let collab_stream = Arc::new(collab_stream); Ok(Self { state: GroupManagementState::new(metrics_calculate.clone()), storage, access_control, metrics_calculate, + collab_redis_stream: collab_stream, persistence_interval, - edit_state_max_count, - edit_state_max_secs, + prune_grace_period, indexer_scheduler, }) } @@ -87,17 +89,18 @@ where client_msg_router: &mut ClientMessageRouter, ) -> Result<(), RealtimeError> { // Lock the group and subscribe the user to the group. - if let Some(group) = self.state.get_mut_group(object_id).await { + if let Some(mut e) = self.state.get_mut_group(object_id).await { + let group = e.value_mut(); trace!("[realtime]: {} subscribe group:{}", user, object_id,); let (sink, stream) = client_msg_router.init_client_communication::( - &group.workspace_id, + group.workspace_id(), user, object_id, self.access_control.clone(), ); group.subscribe(user, message_origin.clone(), sink, stream); // explicitly drop the group to release the lock. 
- drop(group); + drop(e); self.state.insert_user(user, object_id)?; } else { @@ -115,29 +118,23 @@ where object_id: &str, collab_type: CollabType, ) -> Result<(), RealtimeError> { - let mut is_new_collab = false; let params = QueryCollabParams::new(object_id, collab_type.clone(), workspace_id); - - let result = load_collab(user.uid, object_id, params, self.storage.clone()).await; - let (collab, _encode_collab) = { - let (mut collab, encode_collab) = match result { - Ok(value) => value, - Err(err) => { - if err.is_record_not_found() { - is_new_collab = true; - let collab = Collab::new_with_origin(CollabOrigin::Server, object_id, vec![], false); - let encode_collab = collab.encode_collab_v1(|_| Ok::<_, RealtimeError>(()))?; - (collab, encode_collab) - } else { - return Err(RealtimeError::CreateGroupFailed( - CreateGroupFailedReason::CannotGetCollabData, - )); - } - }, - }; - - collab.initialize(); - (collab, encode_collab) + let res = self + .storage + .get_encode_collab(GetCollabOrigin::Server, params, false) + .await; + let state_vector = match res { + Ok(collab) => Collab::new_with_source( + CollabOrigin::Server, + object_id, + DataSource::DocStateV1(collab.doc_state.into()), + vec![], + false, + )? + .transact() + .state_vector(), + Err(err) if err.is_record_not_found() => StateVector::default(), + Err(err) => return Err(RealtimeError::CannotCreateGroup(err.to_string())), }; trace!( @@ -148,25 +145,25 @@ where collab_type ); - let group = Arc::new(CollabGroup::new( + let group = CollabGroup::new( user.uid, workspace_id.to_string(), object_id.to_string(), collab_type, - collab, self.metrics_calculate.clone(), self.storage.clone(), - is_new_collab, + self.collab_redis_stream.clone(), self.persistence_interval, - self.edit_state_max_count, - self.edit_state_max_secs, + self.prune_grace_period, + state_vector, self.indexer_scheduler.clone(), - )?); + )?; self.state.insert_group(object_id, group); Ok(()) } } +#[allow(dead_code)] #[instrument(level = "trace", skip_all)] async fn load_collab( uid: i64, diff --git a/services/appflowy-collaborate/src/group/mod.rs b/services/appflowy-collaborate/src/group/mod.rs index 61190f640..f48e22ab7 100644 --- a/services/appflowy-collaborate/src/group/mod.rs +++ b/services/appflowy-collaborate/src/group/mod.rs @@ -1,9 +1,6 @@ -pub(crate) mod broadcast; pub(crate) mod cmd; pub(crate) mod group_init; pub(crate) mod manager; mod null_sender; -mod persistence; mod plugin; -pub(crate) mod protocol; mod state; diff --git a/services/appflowy-collaborate/src/group/plugin/history_plugin.rs b/services/appflowy-collaborate/src/group/plugin/history_plugin.rs index 7d108461a..eb1ff2718 100644 --- a/services/appflowy-collaborate/src/group/plugin/history_plugin.rs +++ b/services/appflowy-collaborate/src/group/plugin/history_plugin.rs @@ -64,7 +64,7 @@ where let data = encode_collab.doc_state; let params = InsertSnapshotParams { object_id, - data, + doc_state: data, workspace_id, collab_type, }; diff --git a/services/appflowy-collaborate/src/group/protocol.rs b/services/appflowy-collaborate/src/group/protocol.rs deleted file mode 100644 index ec33792fa..000000000 --- a/services/appflowy-collaborate/src/group/protocol.rs +++ /dev/null @@ -1,127 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use collab::core::collab::{TransactionExt, TransactionMutExt}; -use collab::core::origin::CollabOrigin; -use tokio::time::Instant; -use yrs::updates::encoder::{Encode, Encoder, EncoderV1}; -use yrs::{ReadTxn, StateVector, Transact}; - -use 
collab_rt_protocol::CollabSyncProtocol; -use collab_rt_protocol::{ - decode_update, CollabRef, CustomMessage, Message, RTProtocolError, SyncMessage, -}; - -use crate::CollabRealtimeMetrics; - -#[derive(Clone)] -pub struct ServerSyncProtocol { - metrics: Arc, -} - -impl ServerSyncProtocol { - pub fn new(metrics: Arc) -> Self { - Self { metrics } - } -} - -#[async_trait] -impl CollabSyncProtocol for ServerSyncProtocol { - async fn handle_sync_step1( - &self, - collab: &CollabRef, - sv: StateVector, - ) -> Result>, RTProtocolError> { - let (doc_state, state_vector) = { - let lock = collab.read().await; - let collab = (*lock).borrow(); - - let txn = collab.get_awareness().doc().try_transact().map_err(|err| { - RTProtocolError::YrsTransaction(format!("fail to handle sync step1. error: {}", err)) - })?; - - let doc_state = txn.try_encode_state_as_update_v1(&sv).map_err(|err| { - RTProtocolError::YrsEncodeState(format!( - "fail to encode state as update. error: {}\ninit state vector: {:?}\ndocument state: {:#?}", - err, - sv, - txn.store() - )) - })?; - (doc_state, txn.state_vector()) - }; - - // Retrieve the latest document state from the client after they return online from offline editing. - let mut encoder = EncoderV1::new(); - Message::Sync(SyncMessage::SyncStep2(doc_state)).encode(&mut encoder); - - //FIXME: this should never happen as response to sync step 1 from the client, but rather be - // send when a connection is established - Message::Sync(SyncMessage::SyncStep1(state_vector)).encode(&mut encoder); - Ok(Some(encoder.to_vec())) - } - - async fn handle_sync_step2( - &self, - origin: &CollabOrigin, - collab: &CollabRef, - update: Vec, - ) -> Result>, RTProtocolError> { - self.metrics.apply_update_size.observe(update.len() as f64); - let start = Instant::now(); - let result = { - let update = decode_update(update).await?; - let mut lock = collab.write().await; - let collab = (*lock).borrow_mut(); - - let mut txn = collab - .get_awareness() - .doc() - .try_transact_mut_with(origin.clone()) - .map_err(|err| { - RTProtocolError::YrsTransaction(format!("sync step2 transaction acquire: {}", err)) - })?; - txn.try_apply_update(update).map_err(|err| { - RTProtocolError::YrsApplyUpdate(format!( - "sync step2 apply update: {}\ndocument state: {:#?}", - err, - txn.store() - )) - })?; - - // If server can't apply updates sent by client, which means the server is missing some updates - // from the client or the client is missing some updates from the server. - // If the client can't apply broadcast from server, which means the client is missing some - // updates. - match txn.store().pending_update() { - Some(_update) => { - // let state_vector_v1 = update.missing.encode_v1(); - // for the moment, we don't need to send missing updates to the client. 
passing None - // instead, which will trigger a sync step 0 on client - let state_vector_v1 = txn.state_vector().encode_v1(); - Err(RTProtocolError::MissUpdates { - state_vector_v1: Some(state_vector_v1), - reason: "server miss updates".to_string(), - }) - }, - None => Ok(None), - } - }; - - let elapsed = start.elapsed(); - self - .metrics - .apply_update_time - .observe(elapsed.as_millis() as f64); - - result - } - - async fn handle_custom_message( - &self, - _collab: &CollabRef, - _msg: CustomMessage, - ) -> Result>, RTProtocolError> { - Ok(None) - } -} diff --git a/services/appflowy-collaborate/src/group/state.rs b/services/appflowy-collaborate/src/group/state.rs index 76b18d1f9..5ca8f6513 100644 --- a/services/appflowy-collaborate/src/group/state.rs +++ b/services/appflowy-collaborate/src/group/state.rs @@ -64,7 +64,9 @@ impl GroupManagementState { loop { match self.group_by_object_id.try_get(object_id) { - TryResult::Present(group) => return Some(group.clone()), + TryResult::Present(group) => { + return Some(group.clone()); + }, TryResult::Absent => return None, TryResult::Locked => { attempts += 1; @@ -106,22 +108,29 @@ impl GroupManagementState { } } - pub(crate) fn insert_group(&self, object_id: &str, group: Arc) { - self.group_by_object_id.insert(object_id.to_string(), group); + pub(crate) fn insert_group(&self, object_id: &str, group: CollabGroup) { + self + .group_by_object_id + .insert(object_id.to_string(), group.into()); self.metrics_calculate.opening_collab_count.inc(); } pub(crate) fn contains_group(&self, object_id: &str) -> bool { - self.group_by_object_id.contains_key(object_id) + if let Some(group) = self.group_by_object_id.get(object_id) { + let cancelled = group.is_cancelled(); + !cancelled + } else { + false + } } pub(crate) fn remove_group(&self, object_id: &str) { - let entry = self.group_by_object_id.remove(object_id); - - if entry.is_none() { + let group_not_found = self.group_by_object_id.remove(object_id).is_none(); + if group_not_found { // Log error if the group doesn't exist error!("Group for object_id:{} not found", object_id); } + self .metrics_calculate .opening_collab_count diff --git a/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs b/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs index 7ca4f2053..f6b03a21e 100644 --- a/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs +++ b/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs @@ -11,7 +11,6 @@ use bytes::Bytes; use collab::core::collab::DataSource; use collab::core::origin::CollabOrigin; use collab::entity::EncodedCollab; -use collab::lock::RwLock; use collab::preclude::Collab; use collab_entity::CollabType; use dashmap::DashMap; @@ -281,7 +280,7 @@ impl IndexerScheduler { &self, workspace_id: &str, object_id: &str, - collab: &Arc>, + collab: &Collab, collab_type: &CollabType, ) -> Result<(), AppError> { if !self.index_enabled() { @@ -304,9 +303,7 @@ impl IndexerScheduler { let workspace_id = Uuid::parse_str(workspace_id)?; let embedder = self.create_embedder()?; - let lock = collab.read().await; - let chunks = indexer.create_embedded_chunks(&lock, embedder.model())?; - drop(lock); // release the read lock ASAP + let chunks = indexer.create_embedded_chunks(collab, embedder.model())?; let threads = self.threads.clone(); let tx = self.schedule_tx.clone(); diff --git a/services/appflowy-collaborate/src/lib.rs b/services/appflowy-collaborate/src/lib.rs index c72e36ad2..972fc075c 100644 --- a/services/appflowy-collaborate/src/lib.rs +++ 
b/services/appflowy-collaborate/src/lib.rs @@ -8,7 +8,7 @@ pub mod compression; pub mod config; pub mod connect_state; pub mod error; -mod group; +pub mod group; pub mod metrics; mod permission; mod pg_listener; diff --git a/services/appflowy-collaborate/src/metrics.rs b/services/appflowy-collaborate/src/metrics.rs index 3c2d15925..dea643097 100644 --- a/services/appflowy-collaborate/src/metrics.rs +++ b/services/appflowy-collaborate/src/metrics.rs @@ -8,16 +8,20 @@ pub struct CollabRealtimeMetrics { pub(crate) connected_users: Gauge, pub(crate) opening_collab_count: Gauge, pub(crate) num_of_editing_users: Gauge, + /// Number of times a compact state collab load has been done. + pub(crate) load_collab_count: Gauge, + /// Number of times a full state collab (with history) load has been done. + pub(crate) load_full_collab_count: Gauge, /// The number of apply update pub(crate) apply_update_count: Gauge, /// The number of apply update failed pub(crate) apply_update_failed_count: Gauge, - pub(crate) acquire_collab_lock_count: Gauge, - pub(crate) acquire_collab_lock_fail_count: Gauge, - /// How long it takes to apply update in milliseconds. - pub(crate) apply_update_time: Histogram, - /// How big the update is in bytes. - pub(crate) apply_update_size: Histogram, + /// How long it takes to load a collab (from snapshot and updates combined). + pub(crate) load_collab_time: Histogram, + /// How big is the collab (no history, after applying all updates). + pub(crate) collab_size: Histogram, + /// How big is the collab (with history, after applying all updates). + pub(crate) full_collab_size: Histogram, } impl CollabRealtimeMetrics { @@ -28,23 +32,30 @@ impl CollabRealtimeMetrics { num_of_editing_users: Gauge::default(), apply_update_count: Default::default(), apply_update_failed_count: Default::default(), - acquire_collab_lock_count: Default::default(), - acquire_collab_lock_fail_count: Default::default(), // when it comes to histograms we organize them by buckets or specific sizes - since our // prometheus client doesn't support Summary type, we use Histogram type instead - // time spent on apply_update in milliseconds: 1ms, 5ms, 15ms, 30ms, 100ms, 200ms, 500ms, 1s - apply_update_time: Histogram::new( + // time spent on loading collab in milliseconds: 1ms, 5ms, 15ms, 30ms, 100ms, 200ms, 500ms, 1s + load_collab_time: Histogram::new( [1.0, 5.0, 15.0, 30.0, 100.0, 200.0, 500.0, 1000.0].into_iter(), ), - // update size in bytes: 128B, 512B, 1KB, 64KB, 512KB, 1MB, 5MB, 10MB - apply_update_size: Histogram::new( + // collab size in bytes: 128B, 512B, 1KB, 64KB, 512KB, 1MB, 5MB, 10MB + collab_size: Histogram::new( [ 128.0, 512.0, 1024.0, 65536.0, 524288.0, 1048576.0, 5242880.0, 10485760.0, ] .into_iter(), ), + // collab size in bytes: 128B, 512B, 1KB, 64KB, 512KB, 1MB, 5MB, 10MB + full_collab_size: Histogram::new( + [ + 128.0, 512.0, 1024.0, 65536.0, 524288.0, 1048576.0, 5242880.0, 10485760.0, + ] + .into_iter(), + ), + load_collab_count: Default::default(), + load_full_collab_count: Default::default(), } } @@ -76,28 +87,31 @@ impl CollabRealtimeMetrics { "number of apply update failed", metrics.apply_update_failed_count.clone(), ); - realtime_registry.register( - "acquire_collab_lock_count", - "number of acquire collab lock", - metrics.acquire_collab_lock_count.clone(), + "load_collab_time", + "time spent on loading collab in milliseconds", + metrics.load_collab_time.clone(), ); realtime_registry.register( - "acquire_collab_lock_fail_count", - "number of acquire collab lock failed", - 
metrics.acquire_collab_lock_fail_count.clone(), + "collab_size", + "size of compact collab in bytes", + metrics.collab_size.clone(), ); realtime_registry.register( - "apply_update_time", - "time spent on applying collab updates in milliseconds", - metrics.apply_update_time.clone(), + "full_collab_size", + "size of full collab in bytes", + metrics.full_collab_size.clone(), ); realtime_registry.register( - "apply_update_size", - "size of updates applied to collab in bytes", - metrics.apply_update_size.clone(), + "load_collab_count", + "number of collab loads (no history)", + metrics.load_collab_count.clone(), + ); + realtime_registry.register( + "load_full_collab_count", + "number of collab loads (with history)", + metrics.load_full_collab_count.clone(), ); - metrics } } diff --git a/services/appflowy-collaborate/src/rt_server.rs b/services/appflowy-collaborate/src/rt_server.rs index 98e7b50ae..59f5917ad 100644 --- a/services/appflowy-collaborate/src/rt_server.rs +++ b/services/appflowy-collaborate/src/rt_server.rs @@ -6,8 +6,11 @@ use anyhow::{anyhow, Result}; use app_error::AppError; use collab_rt_entity::user::{RealtimeUser, UserDevice}; use collab_rt_entity::MessageByObjectId; +use collab_stream::client::CollabRedisStream; +use collab_stream::stream_router::StreamRouter; use dashmap::mapref::entry::Entry; use dashmap::DashMap; +use redis::aio::ConnectionManager; use tokio::sync::mpsc::Sender; use tokio::task::yield_now; use tokio::time::interval; @@ -50,9 +53,10 @@ where access_control: Arc, metrics: Arc, command_recv: CLCommandReceiver, + redis_stream_router: Arc, + redis_connection_manager: ConnectionManager, group_persistence_interval: Duration, - edit_state_max_count: u32, - edit_state_max_secs: i64, + prune_grace_period: Duration, indexer_scheduler: Arc, ) -> Result { let enable_custom_runtime = get_env_var("APPFLOWY_COLLABORATE_MULTI_THREAD", "false") @@ -66,14 +70,16 @@ where } let connect_state = ConnectState::new(); + let collab_stream = + CollabRedisStream::new_with_connection_manager(redis_connection_manager, redis_stream_router); let group_manager = Arc::new( GroupManager::new( storage.clone(), access_control.clone(), metrics.clone(), + collab_stream, group_persistence_interval, - edit_state_max_count, - edit_state_max_secs, + prune_grace_period, indexer_scheduler.clone(), ) .await?, diff --git a/services/appflowy-collaborate/src/snapshot/snapshot_control.rs b/services/appflowy-collaborate/src/snapshot/snapshot_control.rs index 6b5e68a96..ed7cb09e5 100644 --- a/services/appflowy-collaborate/src/snapshot/snapshot_control.rs +++ b/services/appflowy-collaborate/src/snapshot/snapshot_control.rs @@ -3,6 +3,7 @@ use std::time::Duration; use chrono::{DateTime, Utc}; use collab::entity::{EncodedCollab, EncoderVersion}; +use collab_entity::CollabType; use sqlx::PgPool; use tracing::{debug, error, trace, warn}; use validator::Validate; @@ -14,6 +15,7 @@ use database::collab::{ }; use database::file::s3_client_impl::AwsS3BucketClientImpl; use database::file::{BucketClient, ResponseBlob}; +use database::history::ops::get_latest_snapshot; use database_entity::dto::{ AFSnapshotMeta, AFSnapshotMetas, InsertSnapshotParams, SnapshotData, ZSTD_COMPRESSION_LEVEL, }; @@ -60,7 +62,6 @@ fn get_meta(objct_key: String) -> Option { } #[derive(Clone)] -// #[deprecated(note = "snapshot is implemented in the appflowy-history")] pub struct SnapshotControl { pg_pool: PgPool, s3: AwsS3BucketClientImpl, @@ -119,7 +120,7 @@ impl SnapshotControl { let timestamp = Utc::now(); let snapshot_id = 
timestamp.timestamp_millis(); let key = collab_snapshot_key(¶ms.workspace_id, ¶ms.object_id, snapshot_id); - let compressed = zstd::encode_all(params.data.as_ref(), ZSTD_COMPRESSION_LEVEL)?; + let compressed = zstd::encode_all(params.doc_state.as_ref(), ZSTD_COMPRESSION_LEVEL)?; if let Err(err) = self.s3.put_blob(&key, compressed.into(), None).await { self.collab_metrics.write_snapshot_failures.inc(); return Err(err); @@ -241,6 +242,41 @@ impl SnapshotControl { .await } + pub async fn get_latest_snapshot( + &self, + workspace_id: &str, + oid: &str, + collab_type: CollabType, + ) -> Result, AppError> { + let snapshot_prefix = collab_snapshot_prefix(workspace_id, oid); + let mut resp = self.s3.list_dir(&snapshot_prefix, 1).await?; + if let Some(key) = resp.pop() { + let resp = self.s3.get_blob(&key).await?; + let decompressed = zstd::decode_all(&*resp.to_blob())?; + let encoded_collab = EncodedCollab { + state_vector: Default::default(), + doc_state: decompressed.into(), + version: EncoderVersion::V1, + }; + Ok(Some(SnapshotData { + object_id: oid.to_string(), + encoded_collab_v1: encoded_collab.encode_to_bytes()?, + workspace_id: workspace_id.to_string(), + })) + } else { + let snapshot = get_latest_snapshot(oid, &collab_type, &self.pg_pool).await?; + Ok( + snapshot + .and_then(|row| row.snapshot_meta) + .map(|meta| SnapshotData { + object_id: oid.to_string(), + encoded_collab_v1: meta.snapshot, + workspace_id: workspace_id.to_string(), + }), + ) + } + } + async fn latest_snapshot_time( &self, workspace_id: &str, diff --git a/services/appflowy-collaborate/src/state.rs b/services/appflowy-collaborate/src/state.rs index 33f01ac4b..58eca5d22 100644 --- a/services/appflowy-collaborate/src/state.rs +++ b/services/appflowy-collaborate/src/state.rs @@ -13,6 +13,7 @@ use crate::pg_listener::PgListeners; use crate::CollabRealtimeMetrics; use access_control::metrics::AccessControlMetrics; use app_error::AppError; +use collab_stream::stream_router::StreamRouter; use database::user::{select_all_uid_uuid, select_uid_from_uuid}; use indexer::metrics::EmbeddingMetrics; use indexer::scheduler::IndexerScheduler; @@ -24,6 +25,7 @@ pub struct AppState { pub config: Arc, pub pg_listeners: Arc, pub user_cache: UserCache, + pub redis_stream_router: Arc, pub redis_connection_manager: RedisConnectionManager, pub access_control: AccessControl, pub collab_access_control_storage: Arc, diff --git a/services/appflowy-collaborate/tests/main.rs b/services/appflowy-collaborate/tests/main.rs deleted file mode 100644 index 8b1378917..000000000 --- a/services/appflowy-collaborate/tests/main.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/services/appflowy-history_deprecated/src/core/manager.rs b/services/appflowy-history_deprecated/src/core/manager.rs index dc44bdc6c..4ea7b31d6 100644 --- a/services/appflowy-history_deprecated/src/core/manager.rs +++ b/services/appflowy-history_deprecated/src/core/manager.rs @@ -239,7 +239,7 @@ async fn init_collab_handle( ) -> Result { let group_name = format!("history_{}:{}", workspace_id, object_id); let update_stream = redis_stream - .collab_update_stream(workspace_id, object_id, &group_name) + .collab_update_stream_group(workspace_id, object_id, &group_name) .await .unwrap(); diff --git a/services/appflowy-history_deprecated/tests/edit_test/recv_update_test.rs b/services/appflowy-history_deprecated/tests/edit_test/recv_update_test.rs index df89f156b..cda1396bf 100644 --- a/services/appflowy-history_deprecated/tests/edit_test/recv_update_test.rs +++ 
b/services/appflowy-history_deprecated/tests/edit_test/recv_update_test.rs @@ -27,7 +27,7 @@ async fn apply_update_stream_updates_test() { .unwrap(); let mut update_group = redis_stream - .collab_update_stream(&workspace_id, &object_id, "appflowy_cloud") + .collab_update_stream_group(&workspace_id, &object_id, "appflowy_cloud") .await .unwrap(); @@ -81,7 +81,7 @@ async fn apply_update_stream_updates_test() { // .unwrap(); // // let mut update_group = redis_stream -// .collab_update_stream(&workspace_id, &object_id, "appflowy_cloud") +// .collab_update_stream_group(&workspace_id, &object_id, "appflowy_cloud") // .await // .unwrap(); // diff --git a/services/appflowy-history_deprecated/tests/stream_test/update_stream_test.rs b/services/appflowy-history_deprecated/tests/stream_test/update_stream_test.rs index 4b2b6e16e..ab1ac86a4 100644 --- a/services/appflowy-history_deprecated/tests/stream_test/update_stream_test.rs +++ b/services/appflowy-history_deprecated/tests/stream_test/update_stream_test.rs @@ -10,7 +10,7 @@ async fn single_reader_single_sender_update_stream_test() { let object_id = uuid::Uuid::new_v4().to_string(); let mut send_group = redis_stream - .collab_update_stream(&workspace, &object_id, "write") + .collab_update_stream_group(&workspace, &object_id, "write") .await .unwrap(); for i in 0..5 { @@ -18,7 +18,7 @@ async fn single_reader_single_sender_update_stream_test() { } let mut recv_group = redis_stream - .collab_update_stream(&workspace, &object_id, "read1") + .collab_update_stream_group(&workspace, &object_id, "read1") .await .unwrap(); @@ -55,19 +55,19 @@ async fn multiple_reader_single_sender_update_stream_test() { let object_id = uuid::Uuid::new_v4().to_string(); let mut send_group = redis_stream - .collab_update_stream(&workspace, &object_id, "write") + .collab_update_stream_group(&workspace, &object_id, "write") .await .unwrap(); send_group.insert_message(vec![1, 2, 3]).await.unwrap(); send_group.insert_message(vec![4, 5, 6]).await.unwrap(); let recv_group_1 = redis_stream - .collab_update_stream(&workspace, &object_id, "read1") + .collab_update_stream_group(&workspace, &object_id, "read1") .await .unwrap(); let recv_group_2 = redis_stream - .collab_update_stream(&workspace, &object_id, "read2") + .collab_update_stream_group(&workspace, &object_id, "read2") .await .unwrap(); // Both groups should have the same messages diff --git a/services/appflowy-worker/src/indexer_worker/worker.rs b/services/appflowy-worker/src/indexer_worker/worker.rs index 39e805e93..fd0d20253 100644 --- a/services/appflowy-worker/src/indexer_worker/worker.rs +++ b/services/appflowy-worker/src/indexer_worker/worker.rs @@ -228,7 +228,7 @@ fn handle_task( task.collab_type ); let chunks = match task.data { - UnindexedData::UnindexedText(text) => indexer + UnindexedData::Text(text) => indexer .create_embedded_chunks_from_text(task.object_id.clone(), text, embedder.model()) .ok()?, }; diff --git a/src/api/workspace.rs b/src/api/workspace.rs index 882299fe5..bfb908ed0 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -743,7 +743,7 @@ async fn create_collab_handler( workspace_id_uuid, params.object_id.clone(), params.collab_type.clone(), - UnindexedData::UnindexedText(text), + UnindexedData::Text(text), ); state .indexer_scheduler @@ -875,8 +875,7 @@ async fn batch_create_collab_handler( let total_size = collab_params_list .iter() .fold(0, |acc, x| acc + x.1.encoded_collab_v1.len()); - event!( - tracing::Level::INFO, + tracing::info!( "decompressed {} collab objects in {:?}", 
collab_params_list.len(), start.elapsed() @@ -903,7 +902,7 @@ async fn batch_create_collab_handler( workspace_id_uuid, value.1.object_id.clone(), value.1.collab_type.clone(), - UnindexedData::UnindexedText(text), + UnindexedData::Text(text), ) }) .ok(), @@ -922,8 +921,7 @@ async fn batch_create_collab_handler( .batch_insert_new_collab(&workspace_id, &uid, collab_params_list) .await?; - event!( - tracing::Level::INFO, + tracing::info!( "inserted collab objects to disk in {:?}, total size:{}", start.elapsed(), total_size @@ -1373,7 +1371,7 @@ async fn create_collab_snapshot_handler( .create_snapshot(InsertSnapshotParams { object_id, workspace_id, - data, + doc_state: data, collab_type, }) .await?; @@ -1462,7 +1460,7 @@ async fn update_collab_handler( workspace_id_uuid, params.object_id.clone(), params.collab_type.clone(), - UnindexedData::UnindexedText(text), + UnindexedData::Text(text), ); state .indexer_scheduler @@ -1984,7 +1982,6 @@ async fn post_realtime_message_stream_handler( bytes.extend_from_slice(&item?); } - event!(tracing::Level::INFO, "message len: {}", bytes.len()); let device_id = device_id.to_string(); let message = parser_realtime_msg(bytes.freeze(), req.clone()).await?; diff --git a/src/application.rs b/src/application.rs index 2f938b1eb..8f76e2ec5 100644 --- a/src/application.rs +++ b/src/application.rs @@ -42,6 +42,7 @@ use appflowy_collaborate::collab::storage::CollabStorageImpl; use appflowy_collaborate::command::{CLCommandReceiver, CLCommandSender}; use appflowy_collaborate::snapshot::SnapshotControl; use appflowy_collaborate::CollaborationServer; +use collab_stream::stream_router::{StreamRouter, StreamRouterOptions}; use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage}; use indexer::collab_indexer::IndexerProvider; use indexer::scheduler::{IndexerConfiguration, IndexerScheduler}; @@ -134,9 +135,10 @@ pub async fn run_actix_server( state.realtime_access_control.clone(), state.metrics.realtime_metrics.clone(), rt_cmd_recv, + state.redis_stream_router.clone(), + state.redis_connection_manager.clone(), Duration::from_secs(config.collab.group_persistence_interval_secs), - config.collab.edit_state_max_count, - config.collab.edit_state_max_secs, + Duration::from_secs(config.collab.group_prune_grace_period_secs), state.indexer_scheduler.clone(), ) .await @@ -246,7 +248,8 @@ pub async fn init_state(config: &Config, rt_cmd_tx: CLCommandSender) -> Result Result Result { +async fn get_redis_client( + redis_uri: &str, + worker_count: usize, +) -> Result<(redis::aio::ConnectionManager, Arc), Error> { info!("Connecting to redis with uri: {}", redis_uri); - let manager = redis::Client::open(redis_uri) - .context("failed to connect to redis")? 
+ let client = redis::Client::open(redis_uri).context("failed to connect to redis")?; + + let router = StreamRouter::with_options( + &client, + StreamRouterOptions { + worker_count, + xread_streams: 100, + xread_block_millis: Some(5000), + xread_count: None, + }, + )?; + + let manager = client .get_connection_manager() .await .context("failed to get the connection manager")?; - Ok(manager) + Ok((manager, router.into())) } pub async fn get_aws_s3_client(s3_setting: &S3Setting) -> Result { diff --git a/src/biz/workspace/ops.rs b/src/biz/workspace/ops.rs index d93f51989..e3ad76f2f 100644 --- a/src/biz/workspace/ops.rs +++ b/src/biz/workspace/ops.rs @@ -741,7 +741,7 @@ pub async fn broadcast_update( oid: &str, encoded_update: Vec, ) -> Result<(), AppError> { - tracing::info!("broadcasting update to group: {}", oid); + tracing::trace!("broadcasting update to group: {}", oid); let payload = Message::Sync(SyncMessage::Update(encoded_update)).encode_v1(); let msg = ClientCollabMessage::ClientUpdateSync { data: UpdateSync { diff --git a/src/config/config.rs b/src/config/config.rs index 107fba8d5..13f74c56c 100644 --- a/src/config/config.rs +++ b/src/config/config.rs @@ -19,6 +19,7 @@ pub struct Config { pub application: ApplicationSetting, pub websocket: WebsocketSetting, pub redis_uri: Secret, + pub redis_worker_count: usize, pub s3: S3Setting, pub appflowy_ai: AppFlowyAISetting, pub grpc_history: GrpcHistorySetting, @@ -143,6 +144,7 @@ pub struct GrpcHistorySetting { #[derive(Clone, Debug)] pub struct CollabSetting { pub group_persistence_interval_secs: u64, + pub group_prune_grace_period_secs: u64, pub edit_state_max_count: u32, pub edit_state_max_secs: i64, pub s3_collab_threshold: u64, @@ -224,6 +226,7 @@ pub fn get_configuration() -> Result { min_client_version: get_env_var("APPFLOWY_WEBSOCKET_CLIENT_MIN_VERSION", "0.5.0").parse()?, }, redis_uri: get_env_var("APPFLOWY_REDIS_URI", "redis://localhost:6379").into(), + redis_worker_count: get_env_var("APPFLOWY_REDIS_WORKERS", "60").parse()?, s3: S3Setting { create_bucket: get_env_var("APPFLOWY_S3_CREATE_BUCKET", "true") .parse() @@ -250,6 +253,8 @@ pub fn get_configuration() -> Result { "60", ) .parse()?, + group_prune_grace_period_secs: get_env_var("APPFLOWY_COLLAB_GROUP_GRACE_PERIOD_SECS", "60") + .parse()?, edit_state_max_count: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_COUNT", "100").parse()?, edit_state_max_secs: get_env_var("APPFLOWY_COLLAB_EDIT_STATE_MAX_SECS", "60").parse()?, s3_collab_threshold: get_env_var("APPFLOWY_COLLAB_S3_THRESHOLD", "8000").parse()?, diff --git a/src/state.rs b/src/state.rs index 9195781d4..1659d43de 100644 --- a/src/state.rs +++ b/src/state.rs @@ -17,6 +17,7 @@ use appflowy_collaborate::collab::cache::CollabCache; use appflowy_collaborate::collab::storage::CollabAccessControlStorage; use appflowy_collaborate::metrics::CollabMetrics; use appflowy_collaborate::CollabRealtimeMetrics; +use collab_stream::stream_router::StreamRouter; use database::file::s3_client_impl::{AwsS3BucketClientImpl, S3BucketStorage}; use database::user::{select_all_uid_uuid, select_uid_from_uuid}; use gotrue::grant::{Grant, PasswordGrant}; @@ -39,6 +40,7 @@ pub struct AppState { pub user_cache: UserCache, pub id_gen: Arc>, pub gotrue_client: gotrue::api::Client, + pub redis_stream_router: Arc, pub redis_connection_manager: RedisConnectionManager, pub collab_cache: CollabCache, pub collab_access_control_storage: Arc, diff --git a/tests/collab/asset/automerge-paper.json.gz b/tests/collab/asset/automerge-paper.json.gz new file mode 100644 
index 000000000..aa7f67ec0 Binary files /dev/null and b/tests/collab/asset/automerge-paper.json.gz differ diff --git a/tests/collab/awareness_test.rs b/tests/collab/awareness_test.rs index d6f45f318..f7de1345b 100644 --- a/tests/collab/awareness_test.rs +++ b/tests/collab/awareness_test.rs @@ -24,13 +24,14 @@ async fn viewing_document_editing_users_test() { let owner_uid = owner.uid().await; let clients = owner.get_connect_users(&object_id).await; - assert_eq!(clients.len(), 1); + assert_eq!(clients.len(), 1, "guest shouldn't be connected yet"); assert_eq!(clients[0], owner_uid); guest .open_collab(&workspace_id, &object_id, collab_type) .await; guest.wait_object_sync_complete(&object_id).await.unwrap(); + sleep(Duration::from_secs(1)).await; // after guest open the collab, it will emit an awareness that contains the user id of guest. // This awareness will be sent to the server. Server will broadcast the awareness to all the clients @@ -42,7 +43,7 @@ async fn viewing_document_editing_users_test() { let mut expected_clients = [owner_uid, guest_uid]; expected_clients.sort(); - assert_eq!(clients.len(), 2); + assert_eq!(clients.len(), 2, "expected owner and member connected"); assert_eq!(clients, expected_clients); // simulate the guest close the collab guest.clean_awareness_state(&object_id).await; @@ -50,7 +51,7 @@ async fn viewing_document_editing_users_test() { sleep(Duration::from_secs(5)).await; guest.wait_object_sync_complete(&object_id).await.unwrap(); let clients = owner.get_connect_users(&object_id).await; - assert_eq!(clients.len(), 1); + assert_eq!(clients.len(), 1, "expected only owner connected"); assert_eq!(clients[0], owner_uid); // simulate the guest open the collab again diff --git a/tests/collab/mod.rs b/tests/collab/mod.rs index a5a65c407..2223bf81b 100644 --- a/tests/collab/mod.rs +++ b/tests/collab/mod.rs @@ -7,6 +7,8 @@ mod missing_update_test; mod multi_devices_edit; mod permission_test; mod single_device_edit; +mod snapshot_test; mod storage_test; +mod stress_test; pub mod util; mod web_edit; diff --git a/tests/collab/snapshot_test.rs b/tests/collab/snapshot_test.rs new file mode 100644 index 000000000..2c71f23b4 --- /dev/null +++ b/tests/collab/snapshot_test.rs @@ -0,0 +1,85 @@ +use client_api_test::{assert_server_collab, TestClient}; +use collab::core::collab::DataSource; +use collab::core::origin::CollabOrigin; +use collab::entity::EncodedCollab; +use collab::preclude::{Collab, JsonValue}; +use collab_entity::CollabType; +use serde_json::json; + +#[tokio::test] +async fn read_write_snapshot() { + let mut c = TestClient::new_user().await; + + // prepare initial document + let wid = c.workspace_id().await; + let oid = c.create_and_edit_collab(&wid, CollabType::Unknown).await; + c.open_collab(&wid, &oid, CollabType::Unknown).await; + c.insert_into(&oid, "title", "t1").await; + c.wait_object_sync_complete(&oid).await.unwrap(); + assert_server_collab( + &wid, + &mut c.api_client, + &oid, + &CollabType::Unknown, + 10, + json!({"title": "t1"}), + ) + .await + .unwrap(); + // create the 1st snapshot + let m1 = c + .create_snapshot(&wid, &oid, CollabType::Unknown) + .await + .unwrap(); + + c.insert_into(&oid, "title", "t2").await; + c.wait_object_sync_complete(&oid).await.unwrap(); + assert_server_collab( + &wid, + &mut c.api_client, + &oid, + &CollabType::Unknown, + 10, + json!({"title": "t2"}), + ) + .await + .unwrap(); + // create the 2nd snapshot + let m2 = c + .create_snapshot(&wid, &oid, CollabType::Unknown) + .await + .unwrap(); + + let snapshots = 
c.get_snapshot_list(&wid, &oid).await.unwrap(); + assert_eq!(snapshots.0.len(), 2, "expecting 2 snapshots"); + + // retrieve state + verify_snapshot_state(&c, &wid, &oid, &m1.snapshot_id, json!({"title": "t1"})).await; + verify_snapshot_state(&c, &wid, &oid, &m2.snapshot_id, json!({"title": "t2"})).await; +} + +async fn verify_snapshot_state( + c: &TestClient, + workspace_id: &str, + oid: &str, + snapshot_id: &i64, + expected: JsonValue, +) { + let snapshot = c + .get_snapshot(workspace_id, oid, snapshot_id) + .await + .unwrap(); + + // retrieve state + let encoded_collab = EncodedCollab::decode_from_bytes(&snapshot.encoded_collab_v1).unwrap(); + let collab = Collab::new_with_source( + CollabOrigin::Empty, + oid, + DataSource::DocStateV1(encoded_collab.doc_state.into()), + vec![], + true, + ) + .unwrap(); + let actual = collab.to_json_value(); + assert_eq!(actual, expected); +} diff --git a/tests/collab/stress_test.rs b/tests/collab/stress_test.rs new file mode 100644 index 000000000..7dacb1d9d --- /dev/null +++ b/tests/collab/stress_test.rs @@ -0,0 +1,81 @@ +use std::sync::Arc; +use std::time::Duration; + +use collab_entity::CollabType; +use serde_json::json; +use tokio::time::sleep; +use uuid::Uuid; + +use super::util::TestScenario; +use client_api_test::{assert_server_collab, TestClient}; +use database_entity::dto::AFRole; + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_test_run_multiple_text_edits() { + const READER_COUNT: usize = 1; + let test_scenario = Arc::new(TestScenario::open( + "./tests/collab/asset/automerge-paper.json.gz", + )); + // create writer + let mut writer = TestClient::new_user().await; + sleep(Duration::from_secs(2)).await; // sleep 2 secs to make sure we don't register users too quickly in GoTrue + + let object_id = Uuid::new_v4().to_string(); + let workspace_id = writer.workspace_id().await; + + writer + .open_collab(&workspace_id, &object_id, CollabType::Unknown) + .await; + + // create readers and invite them into the same workspace + let mut readers = Vec::with_capacity(READER_COUNT); + for _ in 0..READER_COUNT { + let mut reader = TestClient::new_user().await; + sleep(Duration::from_secs(2)).await; // sleep 2 secs to make sure we don't register users too quickly in GoTrue + writer + .invite_and_accepted_workspace_member(&workspace_id, &reader, AFRole::Member) + .await + .unwrap(); + + reader + .open_collab(&workspace_id, &object_id, CollabType::Unknown) + .await; + + readers.push(reader); + } + + // run test scenario + let collab = writer.collabs.get(&object_id).unwrap().collab.clone(); + let expected = test_scenario.execute(collab, 20_000).await; + + // wait for the writer to complete sync + writer.wait_object_sync_complete(&object_id).await.unwrap(); + + // wait for the readers to complete sync + let mut tasks = Vec::with_capacity(READER_COUNT); + for reader in readers.iter() { + let fut = reader.wait_object_sync_complete(&object_id); + tasks.push(fut); + } + let results = futures::future::join_all(tasks).await; + + // make sure that the readers are in the correct state + for res in results { + res.unwrap(); + } + + for mut reader in readers.drain(..) 
{ + assert_server_collab( + &workspace_id, + &mut reader.api_client, + &object_id, + &CollabType::Unknown, + 10, + json!({ + "text-id": &expected, + }), + ) + .await + .unwrap(); + } +} diff --git a/tests/collab/util.rs b/tests/collab/util.rs index e6528e0c2..ae04f9cfa 100644 --- a/tests/collab/util.rs +++ b/tests/collab/util.rs @@ -204,3 +204,97 @@ pub async fn redis_connection_manager() -> ConnectionManager { } } } + +use std::io::{BufReader, Read}; + +use collab::preclude::MapExt; +use flate2::bufread::GzDecoder; +use serde::Deserialize; +use yrs::{GetString, Text, TextRef}; + +use client_api_test::CollabRef; + +/// (position, delete length, insert content). +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +pub struct TestPatch(pub usize, pub usize, pub String); + +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +pub struct TestTxn { + // time: String, // ISO String. Unused. + pub patches: Vec, +} + +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] +pub struct TestScenario { + #[serde(default)] + pub using_byte_positions: bool, + + #[serde(rename = "startContent")] + pub start_content: String, + #[serde(rename = "endContent")] + pub end_content: String, + + pub txns: Vec, +} + +impl TestScenario { + /// Load the testing data at the specified file. If the filename ends in .gz, it will be + /// transparently uncompressed. + /// + /// This method panics if the file does not exist, or is corrupt. It'd be better to have a try_ + /// variant of this method, but given this is mostly for benchmarking and testing, I haven't felt + /// the need to write that code. + pub fn open(fpath: &str) -> TestScenario { + // let start = SystemTime::now(); + // let mut file = File::open("benchmark_data/automerge-paper.json.gz").unwrap(); + let file = std::fs::File::open(fpath).unwrap(); + + let mut reader = BufReader::new(file); + // We could pass the GzDecoder straight to serde, but it makes it way slower to parse for + // some reason. 
+ let mut raw_json = vec![]; + + if fpath.ends_with(".gz") { + let mut reader = GzDecoder::new(reader); + reader.read_to_end(&mut raw_json).unwrap(); + } else { + reader.read_to_end(&mut raw_json).unwrap(); + } + + let data: TestScenario = serde_json::from_reader(raw_json.as_slice()).unwrap(); + data + } + + pub async fn execute(&self, collab: CollabRef, step_count: usize) -> String { + let mut i = 0; + for t in self.txns.iter().take(step_count) { + i += 1; + if i % 10_000 == 0 { + tracing::trace!("Executed {}/{} steps", i, step_count); + } + let mut lock = collab.write().await; + let collab = lock.borrow_mut(); + let mut txn = collab.context.transact_mut(); + let txt = collab.data.get_or_init_text(&mut txn, "text-id"); + for patch in t.patches.iter() { + let at = patch.0; + let delete = patch.1; + let content = patch.2.as_str(); + + if delete != 0 { + txt.remove_range(&mut txn, at as u32, delete as u32); + } + if !content.is_empty() { + txt.insert(&mut txn, at as u32, content); + } + } + } + + // validate after applying all patches + let lock = collab.read().await; + let collab = lock.borrow(); + let txn = collab.context.transact(); + let txt: TextRef = collab.data.get_with_txn(&txn, "text-id").unwrap(); + txt.get_string(&txn) + } +} diff --git a/tests/workspace/quick_note.rs b/tests/workspace/quick_note.rs index e60d11c63..d2bdc1321 100644 --- a/tests/workspace/quick_note.rs +++ b/tests/workspace/quick_note.rs @@ -21,8 +21,8 @@ async fn quick_note_crud_test() { // To ensure that the creation time is different time::sleep(Duration::from_millis(1)).await; } - let quick_note_id_1 = quick_note_ids[0]; - let quick_note_id_2 = quick_note_ids[1]; + let _quick_note_id_1 = quick_note_ids[0]; + let _quick_note_id_2 = quick_note_ids[1]; let quick_notes = client .api_client .list_quick_notes(workspace_uuid, None, None, None) @@ -30,9 +30,11 @@ async fn quick_note_crud_test() { .expect("list quick notes"); assert_eq!(quick_notes.quick_notes.len(), 2); assert!(!quick_notes.has_more); - assert_eq!(quick_notes.quick_notes[0].id, quick_note_id_2); - assert_eq!(quick_notes.quick_notes[1].id, quick_note_id_1); + let mut notes_sorted_by_created_at_asc = quick_notes.quick_notes.clone(); + notes_sorted_by_created_at_asc.sort_by(|a, b| a.created_at.cmp(&b.created_at)); + let quick_note_id_1 = notes_sorted_by_created_at_asc[0].id; + let quick_note_id_2 = notes_sorted_by_created_at_asc[1].id; let data_1 = json!([ { "type": "paragraph", diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 284e1869c..728e6a28b 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" [dependencies] anyhow = "1.0" -tokio = { version = "1", features = ["full"] } \ No newline at end of file +tokio = { version = "1", features = ["full"] } +futures = "0.3.31" \ No newline at end of file diff --git a/xtask/src/main.rs b/xtask/src/main.rs index ea4de3513..8b9d86815 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,61 +1,112 @@ use anyhow::{anyhow, Context, Result}; +use std::process::Stdio; use tokio::process::Command; use tokio::select; +use tokio::time::{sleep, Duration}; -/// Using 'cargo run --package xtask' to run servers in parallel. -/// 1. AppFlowy Cloud -/// 2. AppFlowy History -/// 3. AppFlowy Indexer +/// Run servers: +/// cargo run --package xtask /// -/// Before running this command, make sure the other dependencies servers are running. For example, -/// Redis, Postgres, etc. 
+/// Run servers and stress tests: +/// cargo run --package xtask -- --stress-test +/// +/// Note: test start with 'stress_test' will be run as stress tests #[tokio::main] async fn main() -> Result<()> { - let appflowy = "appflowy_cloud"; - let worker = "appflowy_worker"; + let is_stress_test = std::env::args().any(|arg| arg == "--stress-test"); let target_dir = "./target"; std::env::set_var("CARGO_TARGET_DIR", target_dir); - kill_existing_process(appflowy).await?; - kill_existing_process(worker).await?; + let appflowy_cloud_bin_name = "appflowy_cloud"; + let worker_bin_name = "appflowy_worker"; - let enable_runtime_profile = false; - let mut appflowy_cloud_cmd = Command::new("cargo"); + // Step 1: Kill existing processes + kill_existing_process(appflowy_cloud_bin_name).await?; + kill_existing_process(worker_bin_name).await?; - appflowy_cloud_cmd - .env("RUSTFLAGS", "--cfg tokio_unstable") - .args(["run", "--features"]); - if enable_runtime_profile { - appflowy_cloud_cmd.args(["history,tokio-runtime-profile"]); - } else { - appflowy_cloud_cmd.args(["history"]); - } - - let mut appflowy_cloud_handle = appflowy_cloud_cmd - .spawn() - .context("Failed to start AppFlowy-Cloud process")?; + // Step 2: Start servers sequentially + println!("Starting {} server...", appflowy_cloud_bin_name); + let mut appflowy_cloud_cmd = spawn_server( + "cargo", + &["run", "--features", "history"], + appflowy_cloud_bin_name, + is_stress_test, + )?; + wait_for_readiness(appflowy_cloud_bin_name).await?; - let mut appflowy_worker_handle = Command::new("cargo") - .args([ + println!("Starting {} server...", worker_bin_name); + let mut appflowy_worker_cmd = spawn_server( + "cargo", + &[ "run", "--manifest-path", "./services/appflowy-worker/Cargo.toml", - ]) - .spawn() - .context("Failed to start AppFlowy-Worker process")?; + ], + worker_bin_name, + is_stress_test, + )?; + wait_for_readiness(worker_bin_name).await?; + + println!("All servers are up and running."); + // Step 3: Run stress tests if flag is set + let stress_test_cmd = if is_stress_test { + println!("Running stress tests (tests starting with 'stress_test')..."); + Some( + Command::new("cargo") + .args(["test", "stress_test", "--", "--nocapture"]) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) + .spawn() + .context("Failed to start stress test process")?, + ) + } else { + None + }; + + // Step 4: Monitor all processes select! 
{ - status = appflowy_cloud_handle.wait() => { - handle_process_exit(status?, appflowy)?; + status = appflowy_cloud_cmd.wait() => { + handle_process_exit(status?, appflowy_cloud_bin_name)?; + }, + status = appflowy_worker_cmd.wait() => { + handle_process_exit(status?, worker_bin_name)?; + }, + status = async { + if let Some(mut stress_cmd) = stress_test_cmd { + stress_cmd.wait().await + } else { + futures::future::pending().await + } + } => { + if is_stress_test { + handle_process_exit(status?, "cargo test stress_test")?; + } }, - status = appflowy_worker_handle.wait() => { - handle_process_exit(status?, worker)?; - } } Ok(()) } +fn spawn_server( + command: &str, + args: &[&str], + name: &str, + suppress_output: bool, +) -> Result<tokio::process::Child> { + println!("Spawning {} process...", name); + let mut cmd = Command::new(command); + cmd.args(args); + + if suppress_output { + cmd.stdout(Stdio::null()).stderr(Stdio::null()); + } + + cmd + .spawn() + .context(format!("Failed to start {} process", name)) +} + async fn kill_existing_process(process_identifier: &str) -> Result<()> { let _ = Command::new("pkill") .arg("-f") @@ -79,3 +130,10 @@ fn handle_process_exit(status: std::process::ExitStatus, process_name: &str) -> )) } } + +async fn wait_for_readiness(process_name: &str) -> Result<()> { + println!("Waiting for {} to be ready...", process_name); + sleep(Duration::from_secs(3)).await; + println!("{} is ready.", process_name); + Ok(()) +}
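Note on the pruning step in `save_attempt` above: Redis stream entries are only trimmed up to a cutoff that lags the current time by the configured grace period (`APPFLOWY_COLLAB_GROUP_GRACE_PERIOD_SECS`, 60s by default in config.rs), so consumers that are still catching up retain access to recent updates even after a snapshot has been persisted. Below is a minimal sketch of that cutoff calculation; the `MessageId` struct here is a hypothetical stand-in for collab-stream's message id (millisecond timestamp plus sequence number), not the real type.

use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Hypothetical stand-in for collab-stream's MessageId (millisecond timestamp + sequence number).
struct MessageId {
    timestamp_ms: u64,
    sequence_number: u64,
}

// Everything at or before `now - grace_period` is safe to trim from the update stream,
// because the snapshot that was just persisted already contains those updates.
fn prune_cutoff(grace_period: Duration) -> MessageId {
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before the UNIX epoch")
        .as_millis() as u64;
    MessageId {
        timestamp_ms: now_ms.saturating_sub(grace_period.as_millis() as u64),
        sequence_number: 0,
    }
}

fn main() {
    // 60s mirrors the APPFLOWY_COLLAB_GROUP_GRACE_PERIOD_SECS default.
    let cutoff = prune_cutoff(Duration::from_secs(60));
    println!(
        "prune stream entries up to {}-{}",
        cutoff.timestamp_ms, cutoff.sequence_number
    );
}

In the diff the same cutoff is passed to `prune_stream`, so only entries older than the grace period are removed once the lease holder has written the snapshot.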