From b3271e0d653de1c585b1b5db95447045b0453b06 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 2 Feb 2025 17:27:39 +0000 Subject: [PATCH] split state_accessor Signed-off-by: Jason Volk --- src/service/rooms/state_accessor/mod.rs | 634 +----------------- .../rooms/state_accessor/room_state.rs | 90 +++ .../rooms/state_accessor/server_can.rs | 73 ++ src/service/rooms/state_accessor/state.rs | 320 +++++++++ src/service/rooms/state_accessor/user_can.rs | 187 ++++++ 5 files changed, 684 insertions(+), 620 deletions(-) create mode 100644 src/service/rooms/state_accessor/room_state.rs create mode 100644 src/service/rooms/state_accessor/server_can.rs create mode 100644 src/service/rooms/state_accessor/state.rs create mode 100644 src/service/rooms/state_accessor/user_can.rs diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index bed8d210a..b7952ce69 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,23 +1,19 @@ +mod room_state; +mod server_can; +mod state; +mod user_can; + use std::{ - borrow::Borrow, fmt::Write, - ops::Deref, sync::{Arc, Mutex as StdMutex, Mutex}, }; use conduwuit::{ - at, err, error, pair_of, - pdu::PduBuilder, - utils, - utils::{ - math::{usize_from_f64, Expected}, - result::FlatOk, - stream::{BroadbandExt, IterStream, ReadyExt, TryExpect}, - }, - Err, Error, PduEvent, Result, + err, utils, + utils::math::{usize_from_f64, Expected}, + Result, }; -use database::{Deserialized, Map}; -use futures::{future::try_join, FutureExt, Stream, StreamExt, TryFutureExt}; +use database::Map; use lru_cache::LruCache; use ruma::{ events::{ @@ -29,29 +25,19 @@ use ruma::{ guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent, RoomMembership}, - member::{MembershipState, RoomMemberEventContent}, + member::RoomMemberEventContent, name::RoomNameEventContent, - power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, topic::RoomTopicEventContent, }, - StateEventType, TimelineEventType, + StateEventType, }, room::RoomType, space::SpaceRoomJoinRule, - EventEncryptionAlgorithm, EventId, JsOption, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, - OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + EventEncryptionAlgorithm, JsOption, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, + OwnedUserId, RoomId, UserId, }; -use serde::Deserialize; -use crate::{ - rooms, - rooms::{ - short::{ShortEventId, ShortStateHash, ShortStateKey}, - state::RoomMutexGuard, - state_compressor::{compress_state_event, parse_compressed_state_event, CompressedState}, - }, - Dep, -}; +use crate::{rooms, rooms::short::ShortStateHash, Dep}; pub struct Service { pub server_visibility_cache: Mutex>, @@ -143,508 +129,6 @@ impl crate::Service for Service { } impl Service { - /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). - pub async fn room_state_get_content( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result - where - T: for<'de> Deserialize<'de>, - { - self.room_state_get(room_id, event_type, state_key) - .await - .and_then(|event| event.get_content()) - } - - /// Returns the full room state. - #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_full<'a>( - &'a self, - room_id: &'a RoomId, - ) -> impl Stream> + Send + 'a { - self.services - .state - .get_room_shortstatehash(room_id) - .map_ok(|shortstatehash| self.state_full(shortstatehash).map(Ok)) - .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) - .try_flatten_stream() - } - - /// Returns the full room state pdus - #[tracing::instrument(skip(self), level = "debug")] - pub fn room_state_full_pdus<'a>( - &'a self, - room_id: &'a RoomId, - ) -> impl Stream> + Send + 'a { - self.services - .state - .get_room_shortstatehash(room_id) - .map_ok(|shortstatehash| self.state_full_pdus(shortstatehash).map(Ok)) - .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) - .try_flatten_stream() - } - - /// Returns a single EventId from `room_id` with key (`event_type`, - /// `state_key`). - #[tracing::instrument(skip(self), level = "debug")] - pub async fn room_state_get_id( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result - where - Id: for<'de> Deserialize<'de> + Sized + ToOwned, - ::Owned: Borrow, - { - self.services - .state - .get_room_shortstatehash(room_id) - .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) - .await - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - #[tracing::instrument(skip(self), level = "debug")] - pub async fn room_state_get( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result { - self.services - .state - .get_room_shortstatehash(room_id) - .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) - .await - } - - /// The user was a joined member at this state (potentially in the past) - #[inline] - async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id).await == MembershipState::Join - } - - /// The user was an invited or joined room member at this state (potentially - /// in the past) - #[inline] - async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { - let s = self.user_membership(shortstatehash, user_id).await; - s == MembershipState::Join || s == MembershipState::Invite - } - - /// Get membership for given user in state - async fn user_membership( - &self, - shortstatehash: ShortStateHash, - user_id: &UserId, - ) -> MembershipState { - self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) - .await - .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) - } - - /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). - pub async fn state_get_content( - &self, - shortstatehash: ShortStateHash, - event_type: &StateEventType, - state_key: &str, - ) -> Result - where - T: for<'de> Deserialize<'de>, - { - self.state_get(shortstatehash, event_type, state_key) - .await - .and_then(|event| event.get_content()) - } - - #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_contains( - &self, - shortstatehash: ShortStateHash, - event_type: &StateEventType, - state_key: &str, - ) -> bool { - let Ok(shortstatekey) = self - .services - .short - .get_shortstatekey(event_type, state_key) - .await - else { - return false; - }; - - self.state_contains_shortstatekey(shortstatehash, shortstatekey) - .await - } - - #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_contains_shortstatekey( - &self, - shortstatehash: ShortStateHash, - shortstatekey: ShortStateKey, - ) -> bool { - let start = compress_state_event(shortstatekey, 0); - let end = compress_state_event(shortstatekey, u64::MAX); - - self.load_full_state(shortstatehash) - .map_ok(|full_state| full_state.range(start..end).next().copied()) - .await - .flat_ok() - .is_some() - } - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - pub async fn state_get( - &self, - shortstatehash: ShortStateHash, - event_type: &StateEventType, - state_key: &str, - ) -> Result { - self.state_get_id(shortstatehash, event_type, state_key) - .and_then(|event_id: OwnedEventId| async move { - self.services.timeline.get_pdu(&event_id).await - }) - .await - } - - /// Returns a single EventId from `room_id` with key (`event_type`, - /// `state_key`). - #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_get_id( - &self, - shortstatehash: ShortStateHash, - event_type: &StateEventType, - state_key: &str, - ) -> Result - where - Id: for<'de> Deserialize<'de> + Sized + ToOwned, - ::Owned: Borrow, - { - let shorteventid = self - .state_get_shortid(shortstatehash, event_type, state_key) - .await?; - - self.services - .short - .get_eventid_from_short(shorteventid) - .await - } - - /// Returns a single EventId from `room_id` with key (`event_type`, - /// `state_key`). - #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_get_shortid( - &self, - shortstatehash: ShortStateHash, - event_type: &StateEventType, - state_key: &str, - ) -> Result { - let shortstatekey = self - .services - .short - .get_shortstatekey(event_type, state_key) - .await?; - - let start = compress_state_event(shortstatekey, 0); - let end = compress_state_event(shortstatekey, u64::MAX); - self.load_full_state(shortstatehash) - .map_ok(|full_state| { - full_state - .range(start..end) - .next() - .copied() - .map(parse_compressed_state_event) - .map(at!(1)) - .ok_or(err!(Request(NotFound("Not found in room state")))) - }) - .await? - } - - /// Returns the state events removed between the interval (present in .0 but - /// not in .1) - #[inline] - pub fn state_removed( - &self, - shortstatehash: pair_of!(ShortStateHash), - ) -> impl Stream + Send + '_ { - self.state_added((shortstatehash.1, shortstatehash.0)) - } - - /// Returns the state events added between the interval (present in .1 but - /// not in .0) - #[tracing::instrument(skip(self), level = "debug")] - pub fn state_added<'a>( - &'a self, - shortstatehash: pair_of!(ShortStateHash), - ) -> impl Stream + Send + 'a { - let a = self.load_full_state(shortstatehash.0); - let b = self.load_full_state(shortstatehash.1); - try_join(a, b) - .map_ok(|(a, b)| b.difference(&a).copied().collect::>()) - .map_ok(IterStream::try_stream) - .try_flatten_stream() - .expect_ok() - .map(parse_compressed_state_event) - } - - pub fn state_full( - &self, - shortstatehash: ShortStateHash, - ) -> impl Stream + Send + '_ { - self.state_full_pdus(shortstatehash) - .ready_filter_map(|pdu| { - Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)) - }) - } - - pub fn state_full_pdus( - &self, - shortstatehash: ShortStateHash, - ) -> impl Stream + Send + '_ { - let short_ids = self - .state_full_shortids(shortstatehash) - .expect_ok() - .map(at!(1)); - - self.services - .short - .multi_get_eventid_from_short(short_ids) - .ready_filter_map(Result::ok) - .broad_filter_map(move |event_id: OwnedEventId| async move { - self.services.timeline.get_pdu(&event_id).await.ok() - }) - } - - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self), level = "debug")] - pub fn state_full_ids<'a, Id>( - &'a self, - shortstatehash: ShortStateHash, - ) -> impl Stream + Send + 'a - where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a, - ::Owned: Borrow, - { - let shortids = self - .state_full_shortids(shortstatehash) - .expect_ok() - .unzip() - .shared(); - - let shortstatekeys = shortids - .clone() - .map(at!(0)) - .map(Vec::into_iter) - .map(IterStream::stream) - .flatten_stream(); - - let shorteventids = shortids - .map(at!(1)) - .map(Vec::into_iter) - .map(IterStream::stream) - .flatten_stream(); - - self.services - .short - .multi_get_eventid_from_short(shorteventids) - .zip(shortstatekeys) - .ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?))) - } - - pub fn state_full_shortids( - &self, - shortstatehash: ShortStateHash, - ) -> impl Stream> + Send + '_ { - self.load_full_state(shortstatehash) - .map_ok(|full_state| { - full_state - .deref() - .iter() - .copied() - .map(parse_compressed_state_event) - .collect() - }) - .map_ok(|vec: Vec<_>| vec.into_iter().try_stream()) - .try_flatten_stream() - } - - async fn load_full_state( - &self, - shortstatehash: ShortStateHash, - ) -> Result> { - self.services - .state_compressor - .load_shortstatehash_info(shortstatehash) - .map_err(|e| err!(Database("Missing state IDs: {e}"))) - .map_ok(|vec| vec.last().expect("at least one layer").full_state.clone()) - .await - } - - /// Returns the state hash for this pdu. - pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { - const BUFSIZE: usize = size_of::(); - - self.services - .short - .get_shorteventid(event_id) - .and_then(|shorteventid| { - self.db - .shorteventid_shortstatehash - .aqry::(&shorteventid) - }) - .await - .deserialized() - } - - /// Whether a server is allowed to see an event through federation, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip_all, level = "trace")] - pub async fn server_can_see_event( - &self, - origin: &ServerName, - room_id: &RoomId, - event_id: &EventId, - ) -> bool { - let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { - return true; - }; - - if let Some(visibility) = self - .server_visibility_cache - .lock() - .expect("locked") - .get_mut(&(origin.to_owned(), shortstatehash)) - { - return *visibility; - } - - let history_visibility = self - .state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "") - .await - .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { - c.history_visibility - }); - - let current_server_members = self - .services - .state_cache - .room_members(room_id) - .ready_filter(|member| member.server_name() == origin); - - let visibility = match history_visibility { - | HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, - | HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members - .any(|member| self.user_was_invited(shortstatehash, member)) - .await - }, - | HistoryVisibility::Joined => { - // Allow if any member on requested server was joined, else deny - current_server_members - .any(|member| self.user_was_joined(shortstatehash, member)) - .await - }, - | _ => { - error!("Unknown history visibility {history_visibility}"); - false - }, - }; - - self.server_visibility_cache - .lock() - .expect("locked") - .insert((origin.to_owned(), shortstatehash), visibility); - - visibility - } - - /// Whether a user is allowed to see an event, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip_all, level = "trace")] - pub async fn user_can_see_event( - &self, - user_id: &UserId, - room_id: &RoomId, - event_id: &EventId, - ) -> bool { - let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { - return true; - }; - - if let Some(visibility) = self - .user_visibility_cache - .lock() - .expect("locked") - .get_mut(&(user_id.to_owned(), shortstatehash)) - { - return *visibility; - } - - let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; - - let history_visibility = self - .state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "") - .await - .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { - c.history_visibility - }); - - let visibility = match history_visibility { - | HistoryVisibility::WorldReadable => true, - | HistoryVisibility::Shared => currently_member, - | HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny - self.user_was_invited(shortstatehash, user_id).await - }, - | HistoryVisibility::Joined => { - // Allow if any member on requested server was joined, else deny - self.user_was_joined(shortstatehash, user_id).await - }, - | _ => { - error!("Unknown history visibility {history_visibility}"); - false - }, - }; - - self.user_visibility_cache - .lock() - .expect("locked") - .insert((user_id.to_owned(), shortstatehash), visibility); - - visibility - } - - /// Whether a user is allowed to see an event, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip_all, level = "trace")] - pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { - if self.services.state_cache.is_joined(user_id, room_id).await { - return true; - } - - let history_visibility = self - .room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") - .await - .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { - c.history_visibility - }); - - match history_visibility { - | HistoryVisibility::Invited => - self.services.state_cache.is_invited(user_id, room_id).await, - | HistoryVisibility::WorldReadable => true, - | _ => false, - } - } - pub async fn get_name(&self, room_id: &RoomId) -> Result { self.room_state_get_content(room_id, &StateEventType::RoomName, "") .await @@ -669,28 +153,6 @@ impl Service { .await } - pub async fn user_can_invite( - &self, - room_id: &RoomId, - sender: &UserId, - target_user: &UserId, - state_lock: &RoomMutexGuard, - ) -> bool { - self.services - .timeline - .create_hash_and_sign_event( - PduBuilder::state( - target_user.into(), - &RoomMemberEventContent::new(MembershipState::Invite), - ), - sender, - room_id, - state_lock, - ) - .await - .is_ok() - } - /// Checks if guests are able to view room content without joining pub async fn is_world_readable(&self, room_id: &RoomId) -> bool { self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") @@ -726,74 +188,6 @@ impl Service { .map(|c: RoomTopicEventContent| c.topic) } - /// Checks if a given user can redact a given event - /// - /// If federation is true, it allows redaction events from any user of the - /// same server as the original event sender - pub async fn user_can_redact( - &self, - redacts: &EventId, - sender: &UserId, - room_id: &RoomId, - federation: bool, - ) -> Result { - let redacting_event = self.services.timeline.get_pdu(redacts).await; - - if redacting_event - .as_ref() - .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomCreate) - { - return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding."))); - } - - if redacting_event - .as_ref() - .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomServerAcl) - { - return Err!(Request(Forbidden( - "Redacting m.room.server_acl will result in the room being inaccessible for \ - everyone (empty allow key), forbidding." - ))); - } - - if let Ok(pl_event_content) = self - .room_state_get_content::( - room_id, - &StateEventType::RoomPowerLevels, - "", - ) - .await - { - let pl_event: RoomPowerLevels = pl_event_content.into(); - Ok(pl_event.user_can_redact_event_of_other(sender) - || pl_event.user_can_redact_own_event(sender) - && if let Ok(redacting_event) = redacting_event { - if federation { - redacting_event.sender.server_name() == sender.server_name() - } else { - redacting_event.sender == sender - } - } else { - false - }) - } else { - // Falling back on m.room.create to judge power level - if let Ok(room_create) = self - .room_state_get(room_id, &StateEventType::RoomCreate, "") - .await - { - Ok(room_create.sender == sender - || redacting_event - .as_ref() - .is_ok_and(|redacting_event| redacting_event.sender == sender)) - } else { - Err(Error::bad_database( - "No m.room.power_levels or m.room.create events in database for room", - )) - } - } - } - /// Returns the join rule (`SpaceRoomJoinRule`) for a given room pub async fn get_join_rule( &self, diff --git a/src/service/rooms/state_accessor/room_state.rs b/src/service/rooms/state_accessor/room_state.rs new file mode 100644 index 000000000..98a82cea7 --- /dev/null +++ b/src/service/rooms/state_accessor/room_state.rs @@ -0,0 +1,90 @@ +use std::borrow::Borrow; + +use conduwuit::{err, implement, PduEvent, Result}; +use futures::{Stream, StreamExt, TryFutureExt}; +use ruma::{events::StateEventType, EventId, RoomId}; +use serde::Deserialize; + +/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). +#[implement(super::Service)] +pub async fn room_state_get_content( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, +) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.room_state_get(room_id, event_type, state_key) + .await + .and_then(|event| event.get_content()) +} + +/// Returns the full room state. +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn room_state_full<'a>( + &'a self, + room_id: &'a RoomId, +) -> impl Stream> + Send + 'a { + self.services + .state + .get_room_shortstatehash(room_id) + .map_ok(|shortstatehash| self.state_full(shortstatehash).map(Ok)) + .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .try_flatten_stream() +} + +/// Returns the full room state pdus +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn room_state_full_pdus<'a>( + &'a self, + room_id: &'a RoomId, +) -> impl Stream> + Send + 'a { + self.services + .state + .get_room_shortstatehash(room_id) + .map_ok(|shortstatehash| self.state_full_pdus(shortstatehash).map(Ok)) + .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .try_flatten_stream() +} + +/// Returns a single EventId from `room_id` with key (`event_type`, +/// `state_key`). +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, +) -> Result +where + Id: for<'de> Deserialize<'de> + Sized + ToOwned, + ::Owned: Borrow, +{ + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key)) + .await +} + +/// Returns a single PDU from `room_id` with key (`event_type`, +/// `state_key`). +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn room_state_get( + &self, + room_id: &RoomId, + event_type: &StateEventType, + state_key: &str, +) -> Result { + self.services + .state + .get_room_shortstatehash(room_id) + .and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key)) + .await +} diff --git a/src/service/rooms/state_accessor/server_can.rs b/src/service/rooms/state_accessor/server_can.rs new file mode 100644 index 000000000..4d8342275 --- /dev/null +++ b/src/service/rooms/state_accessor/server_can.rs @@ -0,0 +1,73 @@ +use conduwuit::{error, implement, utils::stream::ReadyExt}; +use futures::StreamExt; +use ruma::{ + events::{ + room::history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + StateEventType, + }, + EventId, RoomId, ServerName, +}; + +/// Whether a server is allowed to see an event through federation, based on +/// the room's history_visibility at that event's state. +#[implement(super::Service)] +#[tracing::instrument(skip_all, level = "trace")] +pub async fn server_can_see_event( + &self, + origin: &ServerName, + room_id: &RoomId, + event_id: &EventId, +) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; + }; + + if let Some(visibility) = self + .server_visibility_cache + .lock() + .expect("locked") + .get_mut(&(origin.to_owned(), shortstatehash)) + { + return *visibility; + } + + let history_visibility = self + .state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }); + + let current_server_members = self + .services + .state_cache + .room_members(room_id) + .ready_filter(|member| member.server_name() == origin); + + let visibility = match history_visibility { + | HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, + | HistoryVisibility::Invited => { + // Allow if any member on requesting server was AT LEAST invited, else deny + current_server_members + .any(|member| self.user_was_invited(shortstatehash, member)) + .await + }, + | HistoryVisibility::Joined => { + // Allow if any member on requested server was joined, else deny + current_server_members + .any(|member| self.user_was_joined(shortstatehash, member)) + .await + }, + | _ => { + error!("Unknown history visibility {history_visibility}"); + false + }, + }; + + self.server_visibility_cache + .lock() + .expect("locked") + .insert((origin.to_owned(), shortstatehash), visibility); + + visibility +} diff --git a/src/service/rooms/state_accessor/state.rs b/src/service/rooms/state_accessor/state.rs new file mode 100644 index 000000000..c47a5693d --- /dev/null +++ b/src/service/rooms/state_accessor/state.rs @@ -0,0 +1,320 @@ +use std::{borrow::Borrow, ops::Deref, sync::Arc}; + +use conduwuit::{ + at, err, implement, pair_of, + utils::{ + result::FlatOk, + stream::{BroadbandExt, IterStream, ReadyExt, TryExpect}, + }, + PduEvent, Result, +}; +use database::Deserialized; +use futures::{future::try_join, FutureExt, Stream, StreamExt, TryFutureExt}; +use ruma::{ + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + StateEventType, + }, + EventId, OwnedEventId, UserId, +}; +use serde::Deserialize; + +use crate::rooms::{ + short::{ShortEventId, ShortStateHash, ShortStateKey}, + state_compressor::{compress_state_event, parse_compressed_state_event, CompressedState}, +}; + +/// The user was a joined member at this state (potentially in the past) +#[implement(super::Service)] +#[inline] +pub async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).await == MembershipState::Join +} + +/// The user was an invited or joined room member at this state (potentially +/// in the past) +#[implement(super::Service)] +#[inline] +pub async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { + let s = self.user_membership(shortstatehash, user_id).await; + s == MembershipState::Join || s == MembershipState::Invite +} + +/// Get membership for given user in state +#[implement(super::Service)] +pub async fn user_membership( + &self, + shortstatehash: ShortStateHash, + user_id: &UserId, +) -> MembershipState { + self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) + .await + .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) +} + +/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). +#[implement(super::Service)] +pub async fn state_get_content( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, +) -> Result +where + T: for<'de> Deserialize<'de>, +{ + self.state_get(shortstatehash, event_type, state_key) + .await + .and_then(|event| event.get_content()) +} + +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn state_contains( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, +) -> bool { + let Ok(shortstatekey) = self + .services + .short + .get_shortstatekey(event_type, state_key) + .await + else { + return false; + }; + + self.state_contains_shortstatekey(shortstatehash, shortstatekey) + .await +} + +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn state_contains_shortstatekey( + &self, + shortstatehash: ShortStateHash, + shortstatekey: ShortStateKey, +) -> bool { + let start = compress_state_event(shortstatekey, 0); + let end = compress_state_event(shortstatekey, u64::MAX); + + self.load_full_state(shortstatehash) + .map_ok(|full_state| full_state.range(start..=end).next().copied()) + .await + .flat_ok() + .is_some() +} + +/// Returns a single PDU from `room_id` with key (`event_type`, +/// `state_key`). +#[implement(super::Service)] +pub async fn state_get( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, +) -> Result { + self.state_get_id(shortstatehash, event_type, state_key) + .and_then(|event_id: OwnedEventId| async move { + self.services.timeline.get_pdu(&event_id).await + }) + .await +} + +/// Returns a single EventId from `room_id` with key (`event_type`, +/// `state_key`). +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn state_get_id( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, +) -> Result +where + Id: for<'de> Deserialize<'de> + Sized + ToOwned, + ::Owned: Borrow, +{ + let shorteventid = self + .state_get_shortid(shortstatehash, event_type, state_key) + .await?; + + self.services + .short + .get_eventid_from_short(shorteventid) + .await +} + +/// Returns a single EventId from `room_id` with key (`event_type`, +/// `state_key`). +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub async fn state_get_shortid( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, +) -> Result { + let shortstatekey = self + .services + .short + .get_shortstatekey(event_type, state_key) + .await?; + + let start = compress_state_event(shortstatekey, 0); + let end = compress_state_event(shortstatekey, u64::MAX); + self.load_full_state(shortstatehash) + .map_ok(|full_state| { + full_state + .range(start..=end) + .next() + .copied() + .map(parse_compressed_state_event) + .map(at!(1)) + .ok_or(err!(Request(NotFound("Not found in room state")))) + }) + .await? +} + +/// Returns the state events removed between the interval (present in .0 but +/// not in .1) +#[implement(super::Service)] +#[inline] +pub fn state_removed( + &self, + shortstatehash: pair_of!(ShortStateHash), +) -> impl Stream + Send + '_ { + self.state_added((shortstatehash.1, shortstatehash.0)) +} + +/// Returns the state events added between the interval (present in .1 but +/// not in .0) +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn state_added<'a>( + &'a self, + shortstatehash: pair_of!(ShortStateHash), +) -> impl Stream + Send + 'a { + let a = self.load_full_state(shortstatehash.0); + let b = self.load_full_state(shortstatehash.1); + try_join(a, b) + .map_ok(|(a, b)| b.difference(&a).copied().collect::>()) + .map_ok(IterStream::try_stream) + .try_flatten_stream() + .expect_ok() + .map(parse_compressed_state_event) +} + +#[implement(super::Service)] +pub fn state_full( + &self, + shortstatehash: ShortStateHash, +) -> impl Stream + Send + '_ { + self.state_full_pdus(shortstatehash) + .ready_filter_map(|pdu| { + Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)) + }) +} + +#[implement(super::Service)] +pub fn state_full_pdus( + &self, + shortstatehash: ShortStateHash, +) -> impl Stream + Send + '_ { + let short_ids = self + .state_full_shortids(shortstatehash) + .expect_ok() + .map(at!(1)); + + self.services + .short + .multi_get_eventid_from_short(short_ids) + .ready_filter_map(Result::ok) + .broad_filter_map(move |event_id: OwnedEventId| async move { + self.services.timeline.get_pdu(&event_id).await.ok() + }) +} + +/// Builds a StateMap by iterating over all keys that start +/// with state_hash, this gives the full state for the given state_hash. +#[implement(super::Service)] +#[tracing::instrument(skip(self), level = "debug")] +pub fn state_full_ids<'a, Id>( + &'a self, + shortstatehash: ShortStateHash, +) -> impl Stream + Send + 'a +where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a, + ::Owned: Borrow, +{ + let shortids = self + .state_full_shortids(shortstatehash) + .expect_ok() + .unzip() + .shared(); + + let shortstatekeys = shortids + .clone() + .map(at!(0)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + let shorteventids = shortids + .map(at!(1)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + self.services + .short + .multi_get_eventid_from_short(shorteventids) + .zip(shortstatekeys) + .ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?))) +} + +#[implement(super::Service)] +pub fn state_full_shortids( + &self, + shortstatehash: ShortStateHash, +) -> impl Stream> + Send + '_ { + self.load_full_state(shortstatehash) + .map_ok(|full_state| { + full_state + .deref() + .iter() + .copied() + .map(parse_compressed_state_event) + .collect() + }) + .map_ok(|vec: Vec<_>| vec.into_iter().try_stream()) + .try_flatten_stream() +} + +#[implement(super::Service)] +async fn load_full_state(&self, shortstatehash: ShortStateHash) -> Result> { + self.services + .state_compressor + .load_shortstatehash_info(shortstatehash) + .map_err(|e| err!(Database("Missing state IDs: {e}"))) + .map_ok(|vec| vec.last().expect("at least one layer").full_state.clone()) + .await +} + +/// Returns the state hash for this pdu. +#[implement(super::Service)] +pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + const BUFSIZE: usize = size_of::(); + + self.services + .short + .get_shorteventid(event_id) + .and_then(|shorteventid| { + self.db + .shorteventid_shortstatehash + .aqry::(&shorteventid) + }) + .await + .deserialized() +} diff --git a/src/service/rooms/state_accessor/user_can.rs b/src/service/rooms/state_accessor/user_can.rs new file mode 100644 index 000000000..725a4fba3 --- /dev/null +++ b/src/service/rooms/state_accessor/user_can.rs @@ -0,0 +1,187 @@ +use conduwuit::{error, implement, pdu::PduBuilder, Err, Error, Result}; +use ruma::{ + events::{ + room::{ + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + member::{MembershipState, RoomMemberEventContent}, + power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, + }, + StateEventType, TimelineEventType, + }, + EventId, RoomId, UserId, +}; + +use crate::rooms::state::RoomMutexGuard; + +/// Checks if a given user can redact a given event +/// +/// If federation is true, it allows redaction events from any user of the +/// same server as the original event sender +#[implement(super::Service)] +pub async fn user_can_redact( + &self, + redacts: &EventId, + sender: &UserId, + room_id: &RoomId, + federation: bool, +) -> Result { + let redacting_event = self.services.timeline.get_pdu(redacts).await; + + if redacting_event + .as_ref() + .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomCreate) + { + return Err!(Request(Forbidden("Redacting m.room.create is not safe, forbidding."))); + } + + if redacting_event + .as_ref() + .is_ok_and(|pdu| pdu.kind == TimelineEventType::RoomServerAcl) + { + return Err!(Request(Forbidden( + "Redacting m.room.server_acl will result in the room being inaccessible for \ + everyone (empty allow key), forbidding." + ))); + } + + if let Ok(pl_event_content) = self + .room_state_get_content::( + room_id, + &StateEventType::RoomPowerLevels, + "", + ) + .await + { + let pl_event: RoomPowerLevels = pl_event_content.into(); + Ok(pl_event.user_can_redact_event_of_other(sender) + || pl_event.user_can_redact_own_event(sender) + && if let Ok(redacting_event) = redacting_event { + if federation { + redacting_event.sender.server_name() == sender.server_name() + } else { + redacting_event.sender == sender + } + } else { + false + }) + } else { + // Falling back on m.room.create to judge power level + if let Ok(room_create) = self + .room_state_get(room_id, &StateEventType::RoomCreate, "") + .await + { + Ok(room_create.sender == sender + || redacting_event + .as_ref() + .is_ok_and(|redacting_event| redacting_event.sender == sender)) + } else { + Err(Error::bad_database( + "No m.room.power_levels or m.room.create events in database for room", + )) + } + } +} + +/// Whether a user is allowed to see an event, based on +/// the room's history_visibility at that event's state. +#[implement(super::Service)] +#[tracing::instrument(skip_all, level = "trace")] +pub async fn user_can_see_event( + &self, + user_id: &UserId, + room_id: &RoomId, + event_id: &EventId, +) -> bool { + let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else { + return true; + }; + + if let Some(visibility) = self + .user_visibility_cache + .lock() + .expect("locked") + .get_mut(&(user_id.to_owned(), shortstatehash)) + { + return *visibility; + } + + let currently_member = self.services.state_cache.is_joined(user_id, room_id).await; + + let history_visibility = self + .state_get_content(shortstatehash, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }); + + let visibility = match history_visibility { + | HistoryVisibility::WorldReadable => true, + | HistoryVisibility::Shared => currently_member, + | HistoryVisibility::Invited => { + // Allow if any member on requesting server was AT LEAST invited, else deny + self.user_was_invited(shortstatehash, user_id).await + }, + | HistoryVisibility::Joined => { + // Allow if any member on requested server was joined, else deny + self.user_was_joined(shortstatehash, user_id).await + }, + | _ => { + error!("Unknown history visibility {history_visibility}"); + false + }, + }; + + self.user_visibility_cache + .lock() + .expect("locked") + .insert((user_id.to_owned(), shortstatehash), visibility); + + visibility +} + +/// Whether a user is allowed to see an event, based on +/// the room's history_visibility at that event's state. +#[implement(super::Service)] +#[tracing::instrument(skip_all, level = "trace")] +pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool { + if self.services.state_cache.is_joined(user_id, room_id).await { + return true; + } + + let history_visibility = self + .room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "") + .await + .map_or(HistoryVisibility::Shared, |c: RoomHistoryVisibilityEventContent| { + c.history_visibility + }); + + match history_visibility { + | HistoryVisibility::Invited => + self.services.state_cache.is_invited(user_id, room_id).await, + | HistoryVisibility::WorldReadable => true, + | _ => false, + } +} + +#[implement(super::Service)] +pub async fn user_can_invite( + &self, + room_id: &RoomId, + sender: &UserId, + target_user: &UserId, + state_lock: &RoomMutexGuard, +) -> bool { + self.services + .timeline + .create_hash_and_sign_event( + PduBuilder::state( + target_user.into(), + &RoomMemberEventContent::new(MembershipState::Invite), + ), + sender, + room_id, + state_lock, + ) + .await + .is_ok() +}