From 67a89c3557b074ed8bc6e2d45160503fb33bc896 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Mon, 26 Sep 2022 12:28:34 -0500 Subject: [PATCH 1/7] wip --- src/admin.rs | 2 +- src/client.rs | 18 +++-- src/main.rs | 3 + src/pool.rs | 165 ++++++++++++++++++++++++++------------------ src/query_router.rs | 1 + 5 files changed, 115 insertions(+), 74 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index ed2d3de3..d6891afc 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -277,7 +277,7 @@ where for server in 0..pool.servers(shard) { let address = pool.address(shard, server); let pool_state = pool.pool_state(shard, server); - let banned = pool.is_banned(address, Some(address.role)); + let banned = pool.is_banned(address); res.put(data_row(&vec![ address.name(), // name diff --git a/src/client.rs b/src/client.rs index 3b0b0ea7..ce7a63a3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,6 @@ /// Handle clients by pretending to be a PostgreSQL server. use bytes::{Buf, BufMut, BytesMut}; -use log::{debug, error, info, trace}; +use log::{debug, error, info, trace, warn}; use std::collections::HashMap; use std::time::Instant; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; @@ -9,6 +9,7 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; +use crate::bans::BanReason; use crate::config::{get_config, Address}; use crate::constants::*; use crate::errors::Error; @@ -436,7 +437,7 @@ where ); if password_hash != password_response { - debug!("Password authentication failed"); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); wrong_password(&mut write, username).await?; return Err(Error::ClientError); @@ -458,6 +459,7 @@ where ) .await?; + warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); return Err(Error::ClientError); } }; @@ -466,7 +468,7 @@ where let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt); if password_hash != password_response { - debug!("Password authentication failed"); + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); wrong_password(&mut write, username).await?; return Err(Error::ClientError); @@ -658,6 +660,8 @@ where ), ) .await?; + + warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name.clone()); return Err(Error::ClientError); } }; @@ -1034,7 +1038,7 @@ where match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, self.process_id, BanReason::MessageSendFailed); Err(err) } } @@ -1056,7 +1060,7 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), @@ -1071,7 +1075,7 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, self.process_id); + pool.ban(address, self.process_id, BanReason::StatementTimeout); error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -1080,7 +1084,7 @@ where match server.recv().await { Ok(message) => 
Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/main.rs b/src/main.rs index 0d4bd37a..c2a6fb6d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,6 +52,7 @@ use std::sync::Arc; use tokio::sync::broadcast; mod admin; +mod bans; mod client; mod config; mod constants; @@ -139,6 +140,8 @@ async fn main() { let (stats_tx, stats_rx) = mpsc::channel(100_000); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); + bans::start_ban_manager(); + // Connection pool that allows to query all shards and replicas. match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), diff --git a/src/pool.rs b/src/pool.rs index 34af354a..0774aa1d 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -2,16 +2,16 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; use bytes::BytesMut; -use chrono::naive::NaiveDateTime; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; -use parking_lot::{Mutex, RwLock}; +use parking_lot::Mutex; use rand::seq::SliceRandom; use rand::thread_rng; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; +use crate::bans::{self, BanReason, BanReporter}; use crate::config::{get_config, Address, Role, User}; use crate::errors::Error; @@ -19,14 +19,12 @@ use crate::server::Server; use crate::sharding::ShardingFunction; use crate::stats::{get_reporter, Reporter}; -pub type BanList = Arc>>>; pub type ClientServerMap = Arc>>; pub type PoolMap = HashMap<(String, String), ConnectionPool>; /// The connection pool, globally available. /// This is atomic and safe and read-optimized. /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); - /// Pool mode: /// - transaction: server serves one transaction, /// - session: server is attached to the client. @@ -54,6 +52,8 @@ pub struct PoolSettings { // Number of shards. pub shards: usize, + pub name: String, + // Connecting user. pub user: User, @@ -76,6 +76,7 @@ impl Default for PoolSettings { pool_mode: PoolMode::Transaction, shards: 1, user: User::default(), + name: String::default(), default_role: None, query_parser_enabled: false, primary_reads_enabled: true, @@ -96,7 +97,7 @@ pub struct ConnectionPool { /// List of banned addresses (see above) /// that should not be queried. - banlist: BanList, + ban_reporter: BanReporter, /// The statistics aggregator runs in a separate task /// and receives stats from clients, servers, and the pool. 
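A note for reviewers: `POOLS` above and the `BANLIST` introduced later in this series share the same read-copy-update pattern: readers take a lock-free snapshot with `ArcSwap::load`, while writers clone the whole map, mutate the clone, and publish it with `store`. A minimal standalone sketch of that pattern (the names here are illustrative, not from this patch):

    use std::collections::HashMap;
    use std::sync::Arc;

    use arc_swap::ArcSwap;
    use once_cell::sync::Lazy;

    // Illustrative global map, always swapped atomically as a whole.
    static SNAPSHOT: Lazy<ArcSwap<HashMap<String, u64>>> =
        Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));

    // Readers never block; they pin the current Arc and read through it.
    fn read(key: &str) -> Option<u64> {
        SNAPSHOT.load().get(key).copied()
    }

    // Writers pay the full cost: clone the map, mutate the clone, publish it.
    fn write(key: &str, value: u64) {
        let mut next = (**SNAPSHOT.load()).clone();
        next.insert(key.to_string(), value);
        SNAPSHOT.store(Arc::new(next));
    }

    fn main() {
        write("replica-1", 1);
        assert_eq!(read("replica-1"), Some(1));
    }

This trades expensive writes for reads that cost a single atomic load, which fits a ban list that changes rarely but is consulted on every query.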
@@ -124,7 +125,6 @@ impl ConnectionPool { for (_, user) in &pool_config.users { let mut shards = Vec::new(); let mut addresses = Vec::new(); - let mut banlist = Vec::new(); let mut shard_ids = pool_config .shards .clone() @@ -196,7 +196,6 @@ impl ConnectionPool { shards.push(pools); addresses.push(servers); - banlist.push(HashMap::new()); } assert_eq!(shards.len(), addresses.len()); @@ -204,10 +203,11 @@ impl ConnectionPool { let mut pool = ConnectionPool { databases: shards, addresses: addresses, - banlist: Arc::new(RwLock::new(banlist)), + ban_reporter: bans::get_ban_handler(), stats: get_reporter(), server_info: BytesMut::new(), settings: PoolSettings { + name: pool_name.clone(), pool_mode: match pool_config.pool_mode.as_str() { "transaction" => PoolMode::Transaction, "session" => PoolMode::Session, @@ -326,7 +326,7 @@ impl ConnectionPool { None => break, }; - if self.is_banned(&address, role) { + if self.is_banned(&address) { debug!("Address {:?} is banned", address); continue; } @@ -342,7 +342,7 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(&address, process_id); + self.ban(&address, process_id, BanReason::FailedCheckout); self.stats.client_checkout_error(process_id, address.id); continue; } @@ -397,7 +397,7 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id); + self.ban(&address, process_id, BanReason::FailedHealthCheck); continue; } }, @@ -411,7 +411,7 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id); + self.ban(&address, process_id, BanReason::FailedHealthCheck); continue; } } @@ -421,74 +421,107 @@ impl ConnectionPool { } /// Ban an address (i.e. replica). It no longer will serve - /// traffic for any new transactions. Existing transactions on that replica - /// will finish successfully or error out to the clients. - pub fn ban(&self, address: &Address, client_id: i32) { + /// traffic for any new transactions. If this call bans the last + /// replica in a shard, we unban all replicas + pub fn ban(&self, address: &Address, client_id: i32, reason: BanReason) { error!("Banning {:?}", address); self.stats.client_ban_error(client_id, address.id); - let now = chrono::offset::Utc::now().naive_utc(); - let mut guard = self.banlist.write(); - guard[address.shard].insert(address.clone(), now); - } + // Primaries cannot be banned + if address.role == Role::Primary { + return; + } - /// Clear the replica to receive traffic again. Takes effect immediately - /// for all new transactions. - pub fn _unban(&self, address: &Address) { - let mut guard = self.banlist.write(); - guard[address.shard].remove(address); - } + // We check if banning this address will result in all replica being banned + // If so, we unban all replicas instead + let pool_banned_addresses = self.ban_reporter.banlist( + self.settings.name.clone(), + self.settings.user.username.clone(), + ); - /// Check if a replica can serve traffic. If all replicas are banned, - /// we unban all of them. Better to try then not to. 
- pub fn is_banned(&self, address: &Address, role: Option) -> bool { - let replicas_available = match role { - Some(Role::Replica) => self.addresses[address.shard] + let unbanned_count = self.addresses[address.shard] + .iter() + .filter(|addr| addr.role == Role::Replica) + .filter(|addr| + // Return true if address is not banned + match pool_banned_addresses.get(addr) { + Some(ban_entry) => ban_entry.has_expired(), + // We assume the address that is to be banned is already banned + None => address != *addr, + }) + .count(); + if unbanned_count == 0 { + // All replicas are banned + // Unban everything + warn!("Unbanning all replicas."); + self.addresses[address.shard] .iter() .filter(|addr| addr.role == Role::Replica) - .count(), - None => self.addresses[address.shard].len(), - Some(Role::Primary) => return false, // Primary cannot be banned. - }; + .for_each(|address| self.unban(address)); + return; + } - debug!("Available targets for {:?}: {}", role, replicas_available); + match reason { + BanReason::FailedHealthCheck => self.ban_reporter.report_failed_healthcheck( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ), + BanReason::MessageSendFailed => self.ban_reporter.report_server_send_failed( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ), + BanReason::MessageReceiveFailed => self.ban_reporter.report_server_receive_failed( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ), + BanReason::StatementTimeout => self.ban_reporter.report_statement_timeout( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ), + BanReason::FailedCheckout => self.ban_reporter.report_failed_checkout( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ), + BanReason::ManualBan => unreachable!(), + } + } - let guard = self.banlist.read(); + /// Clear the replica to receive traffic again. ban/unban operations + /// are not synchronous but are typically very fast + pub fn unban(&self, address: &Address) { + self.ban_reporter.unban( + self.settings.name.clone(), + self.settings.user.username.clone(), + address.clone(), + ); + } - // Everything is banned = nothing is banned. - if guard[address.shard].len() == replicas_available { - drop(guard); - let mut guard = self.banlist.write(); - guard[address.shard].clear(); - drop(guard); - warn!("Unbanning all replicas."); + /// Check if a replica can serve traffic. + /// This is a hot codepath, called for each query during + /// the routing phase, we should keep it as fast as possible + pub fn is_banned(&self, address: &Address) -> bool { + if address.role == Role::Primary { return false; } - // I expect this to miss 99.9999% of the time. - match guard[address.shard].get(address) { - Some(timestamp) => { - let now = chrono::offset::Utc::now().naive_utc(); - let config = get_config(); - - // Ban expired. 
-            if now.timestamp() - timestamp.timestamp() > config.general.ban_time {
-                drop(guard);
-                warn!("Unbanning {:?}", address);
-                let mut guard = self.banlist.write();
-                guard[address.shard].remove(address);
-                false
-            } else {
-                debug!("{:?} is banned", address);
-                true
-            }
-        }
-
-            None => {
-                debug!("{:?} is ok", address);
-                false
-            }
+        let pool_banned_addresses = self.ban_reporter.banlist(
+            self.settings.name.clone(),
+            self.settings.user.username.clone(),
+        );
+        if pool_banned_addresses.len() == 0 {
+            // We should hit this branch most of the time
+            return false;
         }
+
+        return match pool_banned_addresses.get(address) {
+            Some(ban_entry) => ban_entry.has_expired(),
+            None => false,
+        };
     }

     /// Get the number of configured shards.
diff --git a/src/query_router.rs b/src/query_router.rs
index f9d5f0b3..a52b311b 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -627,6 +627,7 @@ mod test {
             pool_mode: PoolMode::Transaction,
             shards: 0,
             user: crate::config::User::default(),
+            name: String::from("some_pool"),
             default_role: Some(Role::Replica),
             query_parser_enabled: true,
             primary_reads_enabled: false,

From 6d587bf9027ded611eea84a53b82cb7436f42b2a Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf
Date: Mon, 26 Sep 2022 19:12:13 -0500
Subject: [PATCH 2/7] add ban.rs

---
 src/bans.rs | 341 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 341 insertions(+)
 create mode 100644 src/bans.rs

diff --git a/src/bans.rs b/src/bans.rs
new file mode 100644
index 00000000..1b239f88
--- /dev/null
+++ b/src/bans.rs
@@ -0,0 +1,341 @@
+use arc_swap::ArcSwap;
+use log::error;
+use once_cell::sync::Lazy;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::config::get_config;
+use crate::config::Address;
+use tokio::sync::mpsc;
+use tokio::sync::mpsc::error::TrySendError;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::time::Duration;
+use tokio::time::Instant;
+
+pub type GBanList = HashMap<(String, String), HashMap<Address, BanEntry>>;
+
+#[derive(Debug, Clone, Copy)]
+pub enum BanReason {
+    FailedHealthCheck,
+    MessageSendFailed,
+    MessageReceiveFailed,
+    FailedCheckout,
+    StatementTimeout,
+    #[allow(dead_code)]
+    ManualBan,
+}
+#[derive(Debug, Clone)]
+pub struct BanEntry {
+    reason: BanReason,
+    time: Instant,
+    duration: Duration,
+}
+
+#[derive(Debug, Clone)]
+pub enum BanManagerEvent {
+    Ban {
+        address: Address,
+        pool_name: String,
+        username: String,
+        reason: BanReason,
+    },
+    Unban {
+        address: Address,
+        pool_name: String,
+        username: String,
+    },
+    CleanUpBanList,
+}
+
+impl BanEntry {
+    pub fn has_expired(&self) -> bool {
+        return Instant::now().duration_since(self.time) > self.duration;
+    }
+
+    pub fn is_active(&self) -> bool {
+        !self.has_expired()
+    }
+}
+static BANLIST: Lazy<ArcSwap<GBanList>> = Lazy::new(|| ArcSwap::from_pointee(GBanList::default()));
+
+static BAN_REPORTER: Lazy<ArcSwap<BanReporter>> =
+    Lazy::new(|| ArcSwap::from_pointee(BanReporter::default()));
+
+#[derive(Clone, Debug)]
+pub struct BanReporter {
+    channel_to_worker: Sender<BanManagerEvent>,
+}
+
+impl Default for BanReporter {
+    fn default() -> BanReporter {
+        let (channel_to_worker, _rx) = channel(1000);
+        BanReporter { channel_to_worker }
+    }
+}
+
+impl BanReporter {
+    /// Create a new Reporter instance.
+    pub fn new(channel_to_worker: Sender<BanManagerEvent>) -> BanReporter {
+        BanReporter { channel_to_worker }
+    }
+
+    /// Send statistics to the task keeping track of stats.
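Two reviewer notes on the code above and below. First, the `Default` impl drops `_rx` immediately, so a default `BanReporter` has no listening worker and anything it sends hits the `Closed` arm below; that is harmless as long as pools are created after `start_ban_manager()` swaps in the real reporter, which `main()` arranges. Second, `try_send` is deliberately non-blocking, so a full queue drops ban events instead of stalling query processing. A minimal sketch of that `try_send` contract (standalone, not from this patch):

    use tokio::sync::mpsc::{channel, error::TrySendError};

    #[tokio::main]
    async fn main() {
        let (tx, rx) = channel::<u32>(1);

        // try_send never awaits: it either enqueues immediately or fails.
        assert!(tx.try_send(1).is_ok());
        assert!(matches!(tx.try_send(2), Err(TrySendError::Full(2))));

        // Once the receiver is dropped, every send reports Closed; this is
        // the arm a receiver-less default reporter would always hit.
        drop(rx);
        assert!(matches!(tx.try_send(3), Err(TrySendError::Closed(3))));
    }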
+ fn send(&self, event: BanManagerEvent) { + let result = self.channel_to_worker.try_send(event.clone()); + + match result { + Ok(_) => (()), + Err(err) => match err { + TrySendError::Full { .. } => error!("event dropped, buffer full"), + TrySendError::Closed { .. } => error!("event dropped, channel closed"), + }, + }; + } + + pub fn report_failed_checkout(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::FailedCheckout, + }; + self.send(event); + } + + + pub fn report_failed_healthcheck(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::FailedHealthCheck, + }; + self.send(event); + } + + pub fn report_server_send_failed(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::MessageSendFailed, + }; + self.send(event); + } + + pub fn report_server_receive_failed(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::MessageReceiveFailed, + }; + self.send(event); + } + + pub fn report_statement_timeout(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::StatementTimeout, + }; + self.send(event); + } + + #[allow(dead_code)] + pub fn report_manual_ban(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Ban { + address: address, + pool_name: pool_name, + username: username, + reason: BanReason::ManualBan, + }; + self.send(event); + } + + pub fn unban(&self, pool_name: String, username: String, address: Address) { + let event = BanManagerEvent::Unban { + address: address, + pool_name: pool_name, + username: username, + }; + self.send(event); + } + + pub fn banlist(&self, pool_name: String, username: String) -> HashMap { + let k = (pool_name, username); + match (*(*BANLIST.load())).get(&k) { + Some(banlist) => banlist.clone(), + None => HashMap::default(), + } + } + + #[allow(dead_code)] + pub fn is_banned(&self, pool_name: String, username: String, address: Address) -> bool { + let k = (pool_name, username); + match (*(*BANLIST.load())).get(&k) { + Some(pool_banlist) => match pool_banlist.get(&address) { + Some(ban_entry) => ban_entry.is_active(), + None => false, + }, + None => false, + } + } +} + +pub struct BanWorker { + work_queue_tx: Sender, + work_queue_rx: Receiver, +} + +impl BanWorker { + pub fn new() -> BanWorker { + let (work_queue_tx, work_queue_rx) = mpsc::channel(100_000); + BanWorker { + work_queue_tx, + work_queue_rx, + } + } + + pub fn get_reporter(&self) -> BanReporter { + BanReporter::new(self.work_queue_tx.clone()) + } + + pub async fn start(&mut self) { + let mut internal_ban_list: GBanList = GBanList::default(); + let tx = self.work_queue_tx.clone(); + + tokio::task::spawn(async move { + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60)); + loop { + interval.tick().await; + match tx.try_send(BanManagerEvent::CleanUpBanList) { + Ok(_) => (), + Err(err) => match err { + TrySendError::Full(_) => (), + TrySendError::Closed(_) => (), + }, + } + 
} + }); + + loop { + let event = match self.work_queue_rx.recv().await { + Some(stat) => stat, + None => { + return; + } + }; + + match event { + BanManagerEvent::Ban { + address, + pool_name, + username, + reason, + } => { + if self.ban(&mut internal_ban_list, address, pool_name, username, reason) { + // Ban list was changed, let's publish a new one + self.publish_banlist(&internal_ban_list); + } + } + BanManagerEvent::Unban { + address, + pool_name, + username, + } => { + if self.unban(&mut internal_ban_list, address, pool_name, username) { + // Ban list was changed, let's publish a new one + self.publish_banlist(&internal_ban_list); + } + } + BanManagerEvent::CleanUpBanList => { + self.cleanup_ban_list(&mut internal_ban_list); + } + }; + } + } + + fn publish_banlist(&self, internal_ban_list: &GBanList) { + BANLIST.store(Arc::new(internal_ban_list.clone())); + } + + fn cleanup_ban_list(&self, internal_ban_list: &mut GBanList) { + for (_, v) in internal_ban_list { + v.retain(|_k, v| v.is_active()); + } + } + + fn unban( + &self, + internal_ban_list: &mut GBanList, + address: Address, + pool_name: String, + username: String, + ) -> bool { + let k = (pool_name, username); + match internal_ban_list.get_mut(&k) { + Some(banlist) => { + if banlist.remove(&address).is_none() { + // Was Already not banned? Let's avoid publishing a new list + return false; + } + } + None => return false, // Was Already not banned? Let's avoid publishing a new list + } + return true; + } + + fn ban( + &self, + internal_ban_list: &mut GBanList, + address: Address, + pool_name: String, + username: String, + reason: BanReason, + ) -> bool { + let k = (pool_name.clone(), username.clone()); + let ban_time = Instant::now(); // Technically, ban time is when client made the call but this should be close enough + let config = get_config(); + let ban_duration = match reason { + BanReason::FailedHealthCheck + | BanReason::MessageReceiveFailed + | BanReason::MessageSendFailed + | BanReason::FailedCheckout + | BanReason::StatementTimeout => { + Duration::from_secs(config.general.ban_time.try_into().unwrap()) + } + BanReason::ManualBan => Duration::from_secs(86400), + }; + + let pool_banlist = internal_ban_list.entry(k).or_insert(HashMap::default()); + + let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry { + reason: reason, + time: ban_time, + duration: ban_duration, + }); + + let old_banned_until = ban_entry.time + ban_entry.duration; + let new_banned_until = ban_time + ban_duration; + if new_banned_until >= old_banned_until { + ban_entry.duration = ban_duration; + ban_entry.time = ban_time; + ban_entry.reason = reason; + } + + return true; + } +} + +pub fn start_ban_manager() { + let mut worker = BanWorker::new(); + BAN_REPORTER.store(Arc::new(worker.get_reporter())); + + tokio::task::spawn(async move { worker.start().await }); +} + +pub fn get_ban_handler() -> BanReporter { + return (*(*BAN_REPORTER.load())).clone(); +} From 7ca8f5329b5a03e72b8a93a68e04758b5ab8f6c0 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Sat, 8 Oct 2022 16:40:42 -0500 Subject: [PATCH 3/7] Optimize ban checks --- src/bans.rs | 10 ++++++---- src/config.rs | 5 +++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/bans.rs b/src/bans.rs index 1b239f88..549c1423 100644 --- a/src/bans.rs +++ b/src/bans.rs @@ -4,7 +4,7 @@ use once_cell::sync::Lazy; use std::collections::HashMap; use std::sync::Arc; -use crate::config::get_config; +use crate::config::get_ban_time; use crate::config::Address; use 
tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; @@ -295,20 +295,22 @@ impl BanWorker { username: String, reason: BanReason, ) -> bool { + let ban_duration_from_conf = get_ban_time(); let k = (pool_name.clone(), username.clone()); - let ban_time = Instant::now(); // Technically, ban time is when client made the call but this should be close enough - let config = get_config(); let ban_duration = match reason { BanReason::FailedHealthCheck | BanReason::MessageReceiveFailed | BanReason::MessageSendFailed | BanReason::FailedCheckout | BanReason::StatementTimeout => { - Duration::from_secs(config.general.ban_time.try_into().unwrap()) + Duration::from_secs(ban_duration_from_conf.try_into().unwrap()) } BanReason::ManualBan => Duration::from_secs(86400), }; + // Technically, ban time is when client made the call but this should be close enough + let ban_time = Instant::now(); + let pool_banlist = internal_ban_list.entry(k).or_insert(HashMap::default()); let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry { diff --git a/src/config.rs b/src/config.rs index 1cb37595..b4303df2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -644,6 +644,11 @@ pub fn get_config() -> Config { (*(*CONFIG.load())).clone() } +pub fn get_ban_time() -> i64 { + (*(*CONFIG.load())).general.ban_time +} + + /// Parse the configuration file located at the path. pub async fn parse(path: &str) -> Result<(), Error> { let mut contents = String::new(); From 94a00de660bee6b11598bce729c75bfd866162a2 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Sat, 8 Oct 2022 19:29:46 -0500 Subject: [PATCH 4/7] Address comments --- .circleci/run_tests.sh | 2 +- src/bans.rs | 199 ++++++++++++++------------------ src/client.rs | 12 +- src/config.rs | 1 - src/main.rs | 2 +- src/pool.rs | 122 +++++++++----------- src/query_router.rs | 3 +- tests/docker/docker-compose.yml | 1 + 8 files changed, 156 insertions(+), 186 deletions(-) diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 6ffef8ba..bd68b88e 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -91,7 +91,7 @@ cd tests/ruby sudo gem install bundler bundle install bundle exec ruby tests.rb || exit 1 -bundle exec rspec *_spec.rb || exit 1 +bundle exec rspec *_spec.rb --format documentation || exit 1 cd ../.. 
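Stepping back to patch 3 for a moment: the new `get_ban_time()` accessor exists because `get_config()` clones the entire `Config` on every call, which is needlessly expensive on the per-query ban-check path, while reading a single `i64` through the `ArcSwap` guard avoids the clone entirely. A simplified sketch of the difference (the `Config` fields are illustrative; the real struct nests `ban_time` under `general`):

    use arc_swap::ArcSwap;
    use once_cell::sync::Lazy;
    use std::sync::Arc;

    #[derive(Clone, Default)]
    struct Config {
        ban_time: i64,
        // ... the real struct carries many more fields, which is
        // exactly what makes the clone below costly.
    }

    static CONFIG: Lazy<ArcSwap<Config>> =
        Lazy::new(|| ArcSwap::from_pointee(Config::default()));

    // Clones the whole struct: fine for occasional callers.
    fn get_config() -> Config {
        (*(*CONFIG.load())).clone()
    }

    // Copies one field out of the pinned snapshot: cheap on hot paths.
    fn get_ban_time() -> i64 {
        (*(*CONFIG.load())).ban_time
    }

    fn main() {
        let _keep_arc_in_scope: Arc<Config> = CONFIG.load_full();
        assert_eq!(get_config().ban_time, get_ban_time());
    }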
# diff --git a/src/bans.rs b/src/bans.rs index 549c1423..cf8c7af9 100644 --- a/src/bans.rs +++ b/src/bans.rs @@ -6,13 +6,14 @@ use std::sync::Arc; use crate::config::get_ban_time; use crate::config::Address; +use crate::pool::PoolIdentifier; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::time::Duration; use tokio::time::Instant; -pub type GBanList = HashMap<(String, String), HashMap>; +pub type BanList = HashMap>; #[derive(Debug, Clone, Copy)] pub enum BanReason { @@ -32,17 +33,15 @@ pub struct BanEntry { } #[derive(Debug, Clone)] -pub enum BanManagerEvent { +pub enum BanEvent { Ban { + pool_id: PoolIdentifier, address: Address, - pool_name: String, - username: String, reason: BanReason, }, Unban { + pool_id: PoolIdentifier, address: Address, - pool_name: String, - username: String, }, CleanUpBanList, } @@ -56,126 +55,113 @@ impl BanEntry { !self.has_expired() } } -static BANLIST: Lazy> = Lazy::new(|| ArcSwap::from_pointee(GBanList::default())); +static BANLIST: Lazy> = Lazy::new(|| ArcSwap::from_pointee(BanList::default())); -static BAN_REPORTER: Lazy> = - Lazy::new(|| ArcSwap::from_pointee(BanReporter::default())); +static BAN_MANAGER: Lazy> = + Lazy::new(|| ArcSwap::from_pointee(BanManager::default())); #[derive(Clone, Debug)] -pub struct BanReporter { - channel_to_worker: Sender, +pub struct BanManager { + channel_to_worker: Sender, } -impl Default for BanReporter { - fn default() -> BanReporter { +impl Default for BanManager { + fn default() -> BanManager { let (channel_to_worker, _rx) = channel(1000); - BanReporter { channel_to_worker } + BanManager { channel_to_worker } } } -impl BanReporter { +impl BanManager { /// Create a new Reporter instance. - pub fn new(channel_to_worker: Sender) -> BanReporter { - BanReporter { channel_to_worker } + pub fn new(channel_to_worker: Sender) -> BanManager { + BanManager { channel_to_worker } } /// Send statistics to the task keeping track of stats. - fn send(&self, event: BanManagerEvent) { - let result = self.channel_to_worker.try_send(event.clone()); + async fn send(&self, event: BanEvent) { + let result = self.channel_to_worker.send(event.clone()).await; match result { Ok(_) => (()), - Err(err) => match err { - TrySendError::Full { .. } => error!("event dropped, buffer full"), - TrySendError::Closed { .. 
} => error!("event dropped, channel closed"), - }, + Err(err) => error!("Failed to send ban event {:?}", err), }; } - pub fn report_failed_checkout(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_failed_checkout(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::FailedCheckout, }; - self.send(event); + self.send(event).await } - - pub fn report_failed_healthcheck(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_failed_healthcheck(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::FailedHealthCheck, }; - self.send(event); + self.send(event).await } - pub fn report_server_send_failed(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_server_send_failed(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::MessageSendFailed, }; - self.send(event); + self.send(event).await } - pub fn report_server_receive_failed(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_server_receive_failed(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::MessageReceiveFailed, }; - self.send(event); + self.send(event).await } - pub fn report_statement_timeout(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_statement_timeout(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::StatementTimeout, }; - self.send(event); + self.send(event).await } #[allow(dead_code)] - pub fn report_manual_ban(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Ban { - address: address, - pool_name: pool_name, - username: username, + pub async fn report_manual_ban(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Ban { + pool_id: pool_id.clone(), + address: address.clone(), reason: BanReason::ManualBan, }; - self.send(event); + self.send(event).await } - pub fn unban(&self, pool_name: String, username: String, address: Address) { - let event = BanManagerEvent::Unban { - address: address, - pool_name: pool_name, - username: username, + pub async fn unban(&self, pool_id: &PoolIdentifier, address: &Address) { + let event = BanEvent::Unban { + pool_id: pool_id.clone(), + address: address.clone(), }; - self.send(event); + self.send(event).await; } - pub fn banlist(&self, pool_name: String, username: String) -> HashMap { - let k = (pool_name, username); - match (*(*BANLIST.load())).get(&k) { + pub fn banlist(&self, pool_id: &PoolIdentifier) -> HashMap { + match 
(*(*BANLIST.load())).get(pool_id) { Some(banlist) => banlist.clone(), None => HashMap::default(), } } #[allow(dead_code)] - pub fn is_banned(&self, pool_name: String, username: String, address: Address) -> bool { - let k = (pool_name, username); - match (*(*BANLIST.load())).get(&k) { - Some(pool_banlist) => match pool_banlist.get(&address) { + pub fn is_banned(&self, pool_id: &PoolIdentifier, address: &Address) -> bool { + match (*(*BANLIST.load())).get(pool_id) { + Some(pool_banlist) => match pool_banlist.get(address) { Some(ban_entry) => ban_entry.is_active(), None => false, }, @@ -185,8 +171,8 @@ impl BanReporter { } pub struct BanWorker { - work_queue_tx: Sender, - work_queue_rx: Receiver, + work_queue_tx: Sender, + work_queue_rx: Receiver, } impl BanWorker { @@ -198,19 +184,19 @@ impl BanWorker { } } - pub fn get_reporter(&self) -> BanReporter { - BanReporter::new(self.work_queue_tx.clone()) + pub fn get_reporter(&self) -> BanManager { + BanManager::new(self.work_queue_tx.clone()) } pub async fn start(&mut self) { - let mut internal_ban_list: GBanList = GBanList::default(); + let mut internal_ban_list: BanList = BanList::default(); let tx = self.work_queue_tx.clone(); tokio::task::spawn(async move { let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60)); loop { interval.tick().await; - match tx.try_send(BanManagerEvent::CleanUpBanList) { + match tx.try_send(BanEvent::CleanUpBanList) { Ok(_) => (), Err(err) => match err { TrySendError::Full(_) => (), @@ -229,39 +215,34 @@ impl BanWorker { }; match event { - BanManagerEvent::Ban { + BanEvent::Ban { + pool_id, address, - pool_name, - username, reason, } => { - if self.ban(&mut internal_ban_list, address, pool_name, username, reason) { + if self.ban(&mut internal_ban_list, &pool_id, &address, reason) { // Ban list was changed, let's publish a new one self.publish_banlist(&internal_ban_list); } } - BanManagerEvent::Unban { - address, - pool_name, - username, - } => { - if self.unban(&mut internal_ban_list, address, pool_name, username) { + BanEvent::Unban { pool_id, address } => { + if self.unban(&mut internal_ban_list, &pool_id, &address) { // Ban list was changed, let's publish a new one self.publish_banlist(&internal_ban_list); } } - BanManagerEvent::CleanUpBanList => { + BanEvent::CleanUpBanList => { self.cleanup_ban_list(&mut internal_ban_list); } }; } } - fn publish_banlist(&self, internal_ban_list: &GBanList) { + fn publish_banlist(&self, internal_ban_list: &BanList) { BANLIST.store(Arc::new(internal_ban_list.clone())); } - fn cleanup_ban_list(&self, internal_ban_list: &mut GBanList) { + fn cleanup_ban_list(&self, internal_ban_list: &mut BanList) { for (_, v) in internal_ban_list { v.retain(|_k, v| v.is_active()); } @@ -269,34 +250,30 @@ impl BanWorker { fn unban( &self, - internal_ban_list: &mut GBanList, - address: Address, - pool_name: String, - username: String, + internal_ban_list: &mut BanList, + pool_id: &PoolIdentifier, + address: &Address, ) -> bool { - let k = (pool_name, username); - match internal_ban_list.get_mut(&k) { + match internal_ban_list.get_mut(pool_id) { Some(banlist) => { if banlist.remove(&address).is_none() { - // Was Already not banned? Let's avoid publishing a new list + // Was already not banned? Let's avoid publishing a new list return false; } } - None => return false, // Was Already not banned? Let's avoid publishing a new list + None => return false, // Was already not banned? 
Let's avoid publishing a new list } return true; } fn ban( &self, - internal_ban_list: &mut GBanList, - address: Address, - pool_name: String, - username: String, + internal_ban_list: &mut BanList, + pool_id: &PoolIdentifier, + address: &Address, reason: BanReason, ) -> bool { let ban_duration_from_conf = get_ban_time(); - let k = (pool_name.clone(), username.clone()); let ban_duration = match reason { BanReason::FailedHealthCheck | BanReason::MessageReceiveFailed @@ -311,7 +288,9 @@ impl BanWorker { // Technically, ban time is when client made the call but this should be close enough let ban_time = Instant::now(); - let pool_banlist = internal_ban_list.entry(k).or_insert(HashMap::default()); + let pool_banlist = internal_ban_list + .entry(pool_id.clone()) + .or_insert(HashMap::default()); let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry { reason: reason, @@ -331,13 +310,13 @@ impl BanWorker { } } -pub fn start_ban_manager() { +pub fn start_ban_worker() { let mut worker = BanWorker::new(); - BAN_REPORTER.store(Arc::new(worker.get_reporter())); + BAN_MANAGER.store(Arc::new(worker.get_reporter())); tokio::task::spawn(async move { worker.start().await }); } -pub fn get_ban_handler() -> BanReporter { - return (*(*BAN_REPORTER.load())).clone(); +pub fn get_ban_manager() -> BanManager { + return (*(*BAN_MANAGER.load())).clone(); } diff --git a/src/client.rs b/src/client.rs index 0f7400b6..1c4d9daa 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1038,7 +1038,8 @@ where match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageSendFailed); + pool.ban(address, self.process_id, BanReason::MessageSendFailed) + .await; Err(err) } } @@ -1060,7 +1061,8 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed) + .await; error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), @@ -1075,7 +1077,8 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, self.process_id, BanReason::StatementTimeout); + pool.ban(address, self.process_id, BanReason::StatementTimeout) + .await; error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -1084,7 +1087,8 @@ where match server.recv().await { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed) + .await; error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/config.rs b/src/config.rs index b4303df2..d0564f7e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -648,7 +648,6 @@ pub fn get_ban_time() -> i64 { (*(*CONFIG.load())).general.ban_time } - /// Parse the configuration file located at the path. pub async fn parse(path: &str) -> Result<(), Error> { let mut contents = String::new(); diff --git a/src/main.rs b/src/main.rs index 705af379..62d8751c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,7 +141,7 @@ async fn main() { let (stats_tx, stats_rx) = mpsc::channel(100_000); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); - bans::start_ban_manager(); + bans::start_ban_worker(); // Connection pool that allows to query all shards and replicas. 
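One behavioral change in this patch is easy to miss: `send` now awaits `Sender::send` instead of calling `try_send`, so when the worker's queue fills up, callers wait for room rather than dropping ban events, which is why `ban`, `unban`, and the call sites above all became `async`. A small standalone sketch of the difference between the two send modes:

    use tokio::sync::mpsc::channel;
    use tokio::time::{timeout, Duration};

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = channel::<u32>(1);
        tx.send(1).await.unwrap(); // fills the only slot

        // try_send would fail fast here; send() parks until there is room.
        let blocked = timeout(Duration::from_millis(50), tx.send(2)).await;
        assert!(blocked.is_err(), "send() back-pressures on a full queue");

        let _ = rx.recv().await; // drain one slot ...
        tx.send(2).await.unwrap(); // ... and the same send now succeeds
    }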
match ConnectionPool::from_config(client_server_map.clone()).await { diff --git a/src/pool.rs b/src/pool.rs index f6e15a67..b72bd358 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Instant; -use crate::bans::{self, BanReporter, BanReason}; +use crate::bans::{self, BanManager, BanReason}; use crate::config::{get_config, Address, PoolMode, Role, User}; use crate::errors::Error; @@ -64,7 +64,7 @@ pub struct PoolSettings { // Number of shards. pub shards: usize, - pub name: String, + pub pool_id: PoolIdentifier, // Connecting user. pub user: User, @@ -87,8 +87,8 @@ impl Default for PoolSettings { PoolSettings { pool_mode: PoolMode::Transaction, shards: 1, + pool_id: PoolIdentifier::new("", ""), user: User::default(), - name: String::default(), default_role: None, query_parser_enabled: false, primary_reads_enabled: true, @@ -107,9 +107,8 @@ pub struct ConnectionPool { /// failover and load balancing deterministically. addresses: Vec>, - /// List of banned addresses (see above) - /// that should not be queried. - ban_reporter: BanReporter, + /// Reference to the global ban manager + ban_manager: BanManager, /// The statistics aggregator runs in a separate task /// and receives stats from clients, servers, and the pool. @@ -237,11 +236,11 @@ impl ConnectionPool { let mut pool = ConnectionPool { databases: shards, addresses: addresses, - ban_reporter: bans::get_ban_handler(), + ban_manager: bans::get_ban_manager(), stats: get_reporter(), server_info: BytesMut::new(), settings: PoolSettings { - name: pool_name.clone(), + pool_id: PoolIdentifier::new(pool_name, &user.username), pool_mode: pool_config.pool_mode, // shards: pool_config.shards.clone(), shards: shard_ids.len(), @@ -369,7 +368,8 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(&address, process_id, BanReason::FailedCheckout); + self.ban(&address, process_id, BanReason::FailedCheckout) + .await; self.stats.client_checkout_error(process_id, address.id); continue; } @@ -424,7 +424,8 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id, BanReason::FailedHealthCheck); + self.ban(&address, process_id, BanReason::FailedHealthCheck) + .await; continue; } }, @@ -438,7 +439,8 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id, BanReason::FailedHealthCheck); + self.ban(&address, process_id, BanReason::FailedHealthCheck) + .await; continue; } } @@ -450,7 +452,7 @@ impl ConnectionPool { /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. 
If this call bans the last /// replica in a shard, we unban all replicas - pub fn ban(&self, address: &Address, client_id: i32, reason: BanReason) { + pub async fn ban(&self, address: &Address, client_id: i32, reason: BanReason) { error!("Banning {:?}", address); self.stats.client_ban_error(client_id, address.id); @@ -461,71 +463,67 @@ impl ConnectionPool { // We check if banning this address will result in all replica being banned // If so, we unban all replicas instead - let pool_banned_addresses = self.ban_reporter.banlist( - self.settings.name.clone(), - self.settings.user.username.clone(), - ); + let pool_banned_addresses = self.ban_manager.banlist(&self.settings.pool_id); let unbanned_count = self.addresses[address.shard] - .iter() - .filter(|addr| addr.role == Role::Replica) - .filter(|addr| + .iter() + .filter(|addr| addr.role == Role::Replica) + .filter(|addr| // Return true if address is not banned match pool_banned_addresses.get(addr) { Some(ban_entry) => ban_entry.has_expired(), // We assume the address that is to be banned is already banned None => address != *addr, }) - .count(); + .count(); if unbanned_count == 0 { // All replicas are banned // Unban everything warn!("Unbanning all replicas."); - self.addresses[address.shard] - .iter() - .filter(|addr| addr.role == Role::Replica) - .for_each(|address| self.unban(address)); + for address in &self.addresses[address.shard] { + if address.role == Role::Replica { + self.unban(address).await + } + } return; } match reason { - BanReason::FailedHealthCheck => self.ban_reporter.report_failed_healthcheck( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ), - BanReason::MessageSendFailed => self.ban_reporter.report_server_send_failed( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ), - BanReason::MessageReceiveFailed => self.ban_reporter.report_server_receive_failed( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ), - BanReason::StatementTimeout => self.ban_reporter.report_statement_timeout( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ), - BanReason::FailedCheckout => self.ban_reporter.report_failed_checkout( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ), + BanReason::FailedHealthCheck => { + self.ban_manager + .report_failed_healthcheck(&self.settings.pool_id, &address) + .await + } + BanReason::MessageSendFailed => { + self.ban_manager + .report_server_send_failed(&self.settings.pool_id, &address) + .await + } + BanReason::MessageReceiveFailed => { + self.ban_manager + .report_server_receive_failed(&self.settings.pool_id, &address) + .await + } + BanReason::StatementTimeout => { + self.ban_manager + .report_statement_timeout(&self.settings.pool_id, &address) + .await + } + BanReason::FailedCheckout => { + self.ban_manager + .report_failed_checkout(&self.settings.pool_id, &address) + .await + } BanReason::ManualBan => unreachable!(), } } /// Clear the replica to receive traffic again. ban/unban operations /// are not synchronous but are typically very fast - pub fn unban(&self, address: &Address) { - self.ban_reporter.unban( - self.settings.name.clone(), - self.settings.user.username.clone(), - address.clone(), - ); + pub async fn unban(&self, address: &Address) { + self.ban_manager + .unban(&self.settings.pool_id, &address) + .await; } /// Check if a replica can serve traffic. 
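`PoolIdentifier` carries most of the weight in this patch as the key of the global ban list, but its definition lives in a pool.rs hunk that is not shown in this excerpt. Judging from the `PoolIdentifier::new(pool_name, &user.username)` call sites and its use as a `HashMap` key, its shape is roughly the following (an assumed sketch; field names may differ from the actual code):

    /// Assumed (database, user) pair keying the ban list; reconstructed
    /// from call sites in this patch, not copied from the source.
    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    pub struct PoolIdentifier {
        pub db: String,
        pub user: String,
    }

    impl PoolIdentifier {
        pub fn new(db: &str, user: &str) -> PoolIdentifier {
            PoolIdentifier {
                db: db.to_string(),
                user: user.to_string(),
            }
        }
    }

Keying bans by (database, user) rather than by address alone means the same replica can be banned for one pool while still serving another, which preserves the per-pool semantics of the old `banlist` field.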
@@ -536,19 +534,7 @@ impl ConnectionPool { return false; } - let pool_banned_addresses = self.ban_reporter.banlist( - self.settings.name.clone(), - self.settings.user.username.clone(), - ); - if pool_banned_addresses.len() == 0 { - // We should hit this branch most of the time - return false; - } - - return match pool_banned_addresses.get(address) { - Some(ban_entry) => ban_entry.has_expired(), - None => false, - }; + return self.ban_manager.is_banned(&self.settings.pool_id, address); } /// Get the number of configured shards. diff --git a/src/query_router.rs b/src/query_router.rs index 09e8e9f3..3a5ee432 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -361,6 +361,7 @@ mod test { use super::*; use crate::config::PoolMode; use crate::messages::simple_query; + use crate::pool::PoolIdentifier; use crate::sharding::ShardingFunction; use bytes::BufMut; @@ -626,8 +627,8 @@ mod test { let pool_settings = PoolSettings { pool_mode: PoolMode::Transaction, shards: 0, + pool_id: PoolIdentifier::new("db", "user"), user: crate::config::User::default(), - name: String::from("some_pool"), default_role: Some(Role::Replica), query_parser_enabled: true, primary_reads_enabled: false, diff --git a/tests/docker/docker-compose.yml b/tests/docker/docker-compose.yml index d86e2399..0f7eb011 100644 --- a/tests/docker/docker-compose.yml +++ b/tests/docker/docker-compose.yml @@ -41,6 +41,7 @@ services: command: ["bash", "/app/tests/docker/run.sh"] environment: RUSTFLAGS: "-C instrument-coverage" + RUST_BACKTRACE: "1" LLVM_PROFILE_FILE: "pgcat-%m.profraw" volumes: - ../../:/app/ From 27e246eaba8287ae7ed92252413ea17e87bfa8ce Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Sun, 9 Oct 2022 09:39:12 -0500 Subject: [PATCH 5/7] Add test for the unban all replicas case --- tests/ruby/load_balancing_spec.rb | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/ruby/load_balancing_spec.rb b/tests/ruby/load_balancing_spec.rb index bd98a831..d1f38b3e 100644 --- a/tests/ruby/load_balancing_spec.rb +++ b/tests/ruby/load_balancing_spec.rb @@ -29,6 +29,40 @@ end end + context "when all replicas are down" do + it "unbans all replicas" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'replica'") + + 20.times { conn.async_exec("SELECT 9") } + + expected_share = QUERY_COUNT / (processes.all_databases.count - 2) + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + processes[:replicas][0].take_down do + processes[:replicas][1].take_down do + processes[:replicas][2].take_down do + 3.times do + conn.async_exec("SELECT 9") + rescue + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.async_exec("SET SERVER ROLE to 'replica'") + end + end + end + end + + 50.times { conn.async_exec("SELECT 1 + 2") } + + # If all replicas were unbanned, we expect each replica to get at least + # on query after the unbanning event + processes.replicas.each do |instance| + queries_routed = instance.count_select_1_plus_2 + expect(queries_routed).to be > 1 + end + end + end + + context "when some replicas are down" do it "balances query volume between working instances" do conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) From b56bad764b1f6fb4e5e27bb5362350a34af579d6 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Mon, 10 Oct 2022 13:12:12 -0500 Subject: [PATCH 6/7] Remove channel --- src/bans.rs | 364 
++++++++++++++------------------------------------ src/client.rs | 12 +- src/main.rs | 2 - src/pool.rs | 49 +++---- 4 files changed, 119 insertions(+), 308 deletions(-) diff --git a/src/bans.rs b/src/bans.rs index cf8c7af9..a17a14ae 100644 --- a/src/bans.rs +++ b/src/bans.rs @@ -1,20 +1,15 @@ use arc_swap::ArcSwap; -use log::error; + use once_cell::sync::Lazy; +use parking_lot::Mutex; use std::collections::HashMap; use std::sync::Arc; use crate::config::get_ban_time; use crate::config::Address; use crate::pool::PoolIdentifier; -use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::time::Duration; use tokio::time::Instant; - -pub type BanList = HashMap>; - #[derive(Debug, Clone, Copy)] pub enum BanReason { FailedHealthCheck, @@ -31,21 +26,6 @@ pub struct BanEntry { time: Instant, duration: Duration, } - -#[derive(Debug, Clone)] -pub enum BanEvent { - Ban { - pool_id: PoolIdentifier, - address: Address, - reason: BanReason, - }, - Unban { - pool_id: PoolIdentifier, - address: Address, - }, - CleanUpBanList, -} - impl BanEntry { pub fn has_expired(&self) -> bool { return Instant::now().duration_since(self.time) > self.duration; @@ -55,268 +35,124 @@ impl BanEntry { !self.has_expired() } } +type BanList = HashMap>; static BANLIST: Lazy> = Lazy::new(|| ArcSwap::from_pointee(BanList::default())); - -static BAN_MANAGER: Lazy> = - Lazy::new(|| ArcSwap::from_pointee(BanManager::default())); - -#[derive(Clone, Debug)] -pub struct BanManager { - channel_to_worker: Sender, -} - -impl Default for BanManager { - fn default() -> BanManager { - let (channel_to_worker, _rx) = channel(1000); - BanManager { channel_to_worker } - } -} - -impl BanManager { - /// Create a new Reporter instance. - pub fn new(channel_to_worker: Sender) -> BanManager { - BanManager { channel_to_worker } - } - - /// Send statistics to the task keeping track of stats. 
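The design that replaces the worker: reads keep the lock-free `ArcSwap` snapshot, while `BANLIST_MUTEX` serializes the rare writers so two clone-and-publish cycles cannot race and silently lose each other's updates; `unban` and `ban` below also repeat their fast-path check after taking the lock to skip work another writer already finished. A minimal sketch of that writer-side pattern (illustrative names, with a `HashSet` standing in for the ban map):

    use std::collections::HashSet;
    use std::sync::Arc;

    use arc_swap::ArcSwap;
    use once_cell::sync::Lazy;
    use parking_lot::Mutex;

    static SET: Lazy<ArcSwap<HashSet<String>>> =
        Lazy::new(|| ArcSwap::from_pointee(HashSet::new()));
    // The u8 is a dummy: the mutex only orders writers, it guards no data.
    static SET_MUTEX: Lazy<Mutex<u8>> = Lazy::new(|| Mutex::new(0));

    fn contains(item: &str) -> bool {
        SET.load().contains(item) // lock-free read
    }

    fn insert(item: &str) {
        if contains(item) {
            return; // fast path: already present, no lock taken
        }
        let _guard = SET_MUTEX.lock();
        if contains(item) {
            return; // double-check: another writer won the race
        }
        let mut next = (**SET.load()).clone();
        next.insert(item.to_string());
        SET.store(Arc::new(next)); // publish while still holding the lock
    }

    fn main() {
        insert("replica-1");
        assert!(contains("replica-1"));
    }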
- async fn send(&self, event: BanEvent) { - let result = self.channel_to_worker.send(event.clone()).await; - - match result { - Ok(_) => (()), - Err(err) => error!("Failed to send ban event {:?}", err), - }; - } - - pub async fn report_failed_checkout(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::FailedCheckout, - }; - self.send(event).await - } - - pub async fn report_failed_healthcheck(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::FailedHealthCheck, - }; - self.send(event).await - } - - pub async fn report_server_send_failed(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::MessageSendFailed, - }; - self.send(event).await - } - - pub async fn report_server_receive_failed(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::MessageReceiveFailed, - }; - self.send(event).await - } - - pub async fn report_statement_timeout(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::StatementTimeout, - }; - self.send(event).await - } - - #[allow(dead_code)] - pub async fn report_manual_ban(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Ban { - pool_id: pool_id.clone(), - address: address.clone(), - reason: BanReason::ManualBan, - }; - self.send(event).await - } - - pub async fn unban(&self, pool_id: &PoolIdentifier, address: &Address) { - let event = BanEvent::Unban { - pool_id: pool_id.clone(), - address: address.clone(), - }; - self.send(event).await; - } - - pub fn banlist(&self, pool_id: &PoolIdentifier) -> HashMap { - match (*(*BANLIST.load())).get(pool_id) { - Some(banlist) => banlist.clone(), - None => HashMap::default(), - } - } - - #[allow(dead_code)] - pub fn is_banned(&self, pool_id: &PoolIdentifier, address: &Address) -> bool { - match (*(*BANLIST.load())).get(pool_id) { - Some(pool_banlist) => match pool_banlist.get(address) { - Some(ban_entry) => ban_entry.is_active(), - None => false, - }, - None => false, +static BANLIST_MUTEX: Lazy> = Lazy::new(|| Mutex::new(0)); + +pub fn unban(pool_id: &PoolIdentifier, address: &Address) { + if !is_banned(pool_id, address) { + // Already not banned? No need to do any work + return; + } + let _guard = BANLIST_MUTEX.lock(); + if !is_banned(pool_id, address) { + // Maybe it was unbanned between our initial check and locking the mutex + // In that case, we don't need to do any work + return; + } + + let mut global_banlist = (**BANLIST.load()).clone(); + + match global_banlist.get_mut(pool_id) { + Some(pool_banlist) => { + if pool_banlist.remove(&address).is_none() { + // Was already not banned? Let's avoid publishing a new list + return; + } else { + // Banlist was updated, let's publish a new version for readers + BANLIST.store(Arc::new(global_banlist)); + } } + None => return, // Was already not banned? 
Let's avoid publishing a new list } } -pub struct BanWorker { - work_queue_tx: Sender, - work_queue_rx: Receiver, -} - -impl BanWorker { - pub fn new() -> BanWorker { - let (work_queue_tx, work_queue_rx) = mpsc::channel(100_000); - BanWorker { - work_queue_tx, - work_queue_rx, +fn ban(pool_id: &PoolIdentifier, address: &Address, reason: BanReason) { + if is_banned(pool_id, address) { + // Already banned? No need to do any work + return; + } + let _guard = BANLIST_MUTEX.lock(); + if is_banned(pool_id, address) { + // Maybe it was banned between our initial check and locking the mutex + // In that case, we don't need to do any work + return; + } + + let ban_duration_from_conf = get_ban_time(); + let ban_duration = match reason { + BanReason::FailedHealthCheck + | BanReason::MessageReceiveFailed + | BanReason::MessageSendFailed + | BanReason::FailedCheckout + | BanReason::StatementTimeout => { + Duration::from_secs(ban_duration_from_conf.try_into().unwrap()) } - } + BanReason::ManualBan => Duration::from_secs(86400), + }; - pub fn get_reporter(&self) -> BanManager { - BanManager::new(self.work_queue_tx.clone()) - } - - pub async fn start(&mut self) { - let mut internal_ban_list: BanList = BanList::default(); - let tx = self.work_queue_tx.clone(); - - tokio::task::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60)); - loop { - interval.tick().await; - match tx.try_send(BanEvent::CleanUpBanList) { - Ok(_) => (), - Err(err) => match err { - TrySendError::Full(_) => (), - TrySendError::Closed(_) => (), - }, - } - } - }); - - loop { - let event = match self.work_queue_rx.recv().await { - Some(stat) => stat, - None => { - return; - } - }; - - match event { - BanEvent::Ban { - pool_id, - address, - reason, - } => { - if self.ban(&mut internal_ban_list, &pool_id, &address, reason) { - // Ban list was changed, let's publish a new one - self.publish_banlist(&internal_ban_list); - } - } - BanEvent::Unban { pool_id, address } => { - if self.unban(&mut internal_ban_list, &pool_id, &address) { - // Ban list was changed, let's publish a new one - self.publish_banlist(&internal_ban_list); - } - } - BanEvent::CleanUpBanList => { - self.cleanup_ban_list(&mut internal_ban_list); - } - }; - } - } + let ban_time = Instant::now(); + let mut global_banlist = (**BANLIST.load()).clone(); + let pool_banlist = global_banlist.entry(pool_id.clone()).or_insert(HashMap::default()); - fn publish_banlist(&self, internal_ban_list: &BanList) { - BANLIST.store(Arc::new(internal_ban_list.clone())); - } + let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry { + reason: reason, + time: ban_time, + duration: ban_duration, + }); - fn cleanup_ban_list(&self, internal_ban_list: &mut BanList) { - for (_, v) in internal_ban_list { - v.retain(|_k, v| v.is_active()); - } + let old_banned_until = ban_entry.time + ban_entry.duration; + let new_banned_until = ban_time + ban_duration; + if new_banned_until >= old_banned_until { + ban_entry.duration = ban_duration; + ban_entry.time = ban_time; + ban_entry.reason = reason; } - fn unban( - &self, - internal_ban_list: &mut BanList, - pool_id: &PoolIdentifier, - address: &Address, - ) -> bool { - match internal_ban_list.get_mut(pool_id) { - Some(banlist) => { - if banlist.remove(&address).is_none() { - // Was already not banned? Let's avoid publishing a new list - return false; - } - } - None => return false, // Was already not banned? 
Let's avoid publishing a new list - } - return true; - } + // Clean up + pool_banlist.retain(|_k, v| v.is_active()); - fn ban( - &self, - internal_ban_list: &mut BanList, - pool_id: &PoolIdentifier, - address: &Address, - reason: BanReason, - ) -> bool { - let ban_duration_from_conf = get_ban_time(); - let ban_duration = match reason { - BanReason::FailedHealthCheck - | BanReason::MessageReceiveFailed - | BanReason::MessageSendFailed - | BanReason::FailedCheckout - | BanReason::StatementTimeout => { - Duration::from_secs(ban_duration_from_conf.try_into().unwrap()) - } - BanReason::ManualBan => Duration::from_secs(86400), - }; + BANLIST.store(Arc::new(global_banlist)); +} - // Technically, ban time is when client made the call but this should be close enough - let ban_time = Instant::now(); +pub fn report_failed_checkout(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::FailedCheckout); +} - let pool_banlist = internal_ban_list - .entry(pool_id.clone()) - .or_insert(HashMap::default()); +pub fn report_failed_healthcheck(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::FailedHealthCheck); +} - let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry { - reason: reason, - time: ban_time, - duration: ban_duration, - }); +pub fn report_server_send_failed(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::MessageSendFailed); +} - let old_banned_until = ban_entry.time + ban_entry.duration; - let new_banned_until = ban_time + ban_duration; - if new_banned_until >= old_banned_until { - ban_entry.duration = ban_duration; - ban_entry.time = ban_time; - ban_entry.reason = reason; - } +pub fn report_server_receive_failed(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::MessageReceiveFailed); +} - return true; - } +pub fn report_statement_timeout(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::StatementTimeout); } -pub fn start_ban_worker() { - let mut worker = BanWorker::new(); - BAN_MANAGER.store(Arc::new(worker.get_reporter())); +#[allow(dead_code)] +pub fn report_manual_ban(pool_id: &PoolIdentifier, address: &Address) { + ban(pool_id, address, BanReason::ManualBan); +} - tokio::task::spawn(async move { worker.start().await }); +pub fn banlist(pool_id: &PoolIdentifier) -> HashMap { + match (**BANLIST.load()).get(pool_id) { + Some(banlist) => banlist.clone(), + None => HashMap::default(), + } } -pub fn get_ban_manager() -> BanManager { - return (*(*BAN_MANAGER.load())).clone(); +pub fn is_banned(pool_id: &PoolIdentifier, address: &Address) -> bool { + match (**BANLIST.load()).get(pool_id) { + Some(pool_banlist) => match pool_banlist.get(address) { + Some(ban_entry) => ban_entry.is_active(), + None => false, + }, + None => false, + } } diff --git a/src/client.rs b/src/client.rs index 1c4d9daa..0f7400b6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1038,8 +1038,7 @@ where match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageSendFailed) - .await; + pool.ban(address, self.process_id, BanReason::MessageSendFailed); Err(err) } } @@ -1061,8 +1060,7 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageReceiveFailed) - .await; + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); error_response_terminal( &mut self.write, &format!("error receiving data from server: 
{:?}", err), @@ -1077,8 +1075,7 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, self.process_id, BanReason::StatementTimeout) - .await; + pool.ban(address, self.process_id, BanReason::StatementTimeout); error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -1087,8 +1084,7 @@ where match server.recv().await { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id, BanReason::MessageReceiveFailed) - .await; + pool.ban(address, self.process_id, BanReason::MessageReceiveFailed); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/main.rs b/src/main.rs index 62d8751c..eb82df9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,8 +141,6 @@ async fn main() { let (stats_tx, stats_rx) = mpsc::channel(100_000); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); - bans::start_ban_worker(); - // Connection pool that allows to query all shards and replicas. match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), diff --git a/src/pool.rs b/src/pool.rs index b72bd358..d9c6ec64 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Instant; -use crate::bans::{self, BanManager, BanReason}; +use crate::bans::{self, BanReason}; use crate::config::{get_config, Address, PoolMode, Role, User}; use crate::errors::Error; @@ -107,9 +107,6 @@ pub struct ConnectionPool { /// failover and load balancing deterministically. addresses: Vec>, - /// Reference to the global ban manager - ban_manager: BanManager, - /// The statistics aggregator runs in a separate task /// and receives stats from clients, servers, and the pool. stats: Reporter, @@ -236,7 +233,6 @@ impl ConnectionPool { let mut pool = ConnectionPool { databases: shards, addresses: addresses, - ban_manager: bans::get_ban_manager(), stats: get_reporter(), server_info: BytesMut::new(), settings: PoolSettings { @@ -368,8 +364,7 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(&address, process_id, BanReason::FailedCheckout) - .await; + self.ban(&address, process_id, BanReason::FailedCheckout); self.stats.client_checkout_error(process_id, address.id); continue; } @@ -424,8 +419,7 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id, BanReason::FailedHealthCheck) - .await; + self.ban(&address, process_id, BanReason::FailedHealthCheck); continue; } }, @@ -439,8 +433,7 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, process_id, BanReason::FailedHealthCheck) - .await; + self.ban(&address, process_id, BanReason::FailedHealthCheck); continue; } } @@ -452,7 +445,7 @@ impl ConnectionPool { /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. 
From 04c704bf774f56e902ccba8609d2186a312f4dbd Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf
Date: Tue, 11 Oct 2022 08:16:21 -0500
Subject: [PATCH 7/7] fmt

---
 src/bans.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/bans.rs b/src/bans.rs
index a17a14ae..17e194c3 100644
--- a/src/bans.rs
+++ b/src/bans.rs
@@ -93,7 +93,9 @@ fn ban(pool_id: &PoolIdentifier, address: &Address, reason: BanReason) {
 
     let ban_time = Instant::now();
     let mut global_banlist = (**BANLIST.load()).clone();
-    let pool_banlist = global_banlist.entry(pool_id.clone()).or_insert(HashMap::default());
+    let pool_banlist = global_banlist
+        .entry(pool_id.clone())
+        .or_insert(HashMap::default());
 
     let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry {
         reason: reason,
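
A final aside on the chain that PATCH 7 reformats: or_insert(HashMap::default())
constructs the default map eagerly even when the key already exists. The behavior is
identical here, but the lazy forms are the usual idiom (and what clippy's or_fun_call
lint suggests), as this small comparison shows:

// All three insertions are equivalent for an empty-map default;
// they differ only in when the default value gets constructed.
use std::collections::HashMap;

fn main() {
    let mut map: HashMap<String, HashMap<String, u32>> = HashMap::new();

    // As in the patch: the default is built unconditionally.
    map.entry("a".into()).or_insert(HashMap::default());
    // Lazy: the closure runs only when the key is vacant.
    map.entry("b".into()).or_insert_with(HashMap::new);
    // Shortest equivalent for types implementing Default.
    map.entry("c".into()).or_default();

    assert_eq!(map.len(), 3);
}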