Skip to content

Move replica banning to its own task #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ cd tests/ruby
sudo gem install bundler
bundle install
bundle exec ruby tests.rb || exit 1
bundle exec rspec *_spec.rb || exit 1
bundle exec rspec *_spec.rb --format documentation || exit 1
cd ../..

#
Expand Down
2 changes: 1 addition & 1 deletion src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ where
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
let banned = pool.is_banned(address, Some(address.role));
let banned = pool.is_banned(address);

res.put(data_row(&vec![
address.name(), // name
Expand Down
322 changes: 322 additions & 0 deletions src/bans.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
use arc_swap::ArcSwap;
use log::error;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;

use crate::config::get_ban_time;
use crate::config::Address;
use crate::pool::PoolIdentifier;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::time::Duration;
use tokio::time::Instant;

pub type BanList = HashMap<PoolIdentifier, HashMap<Address, BanEntry>>;

#[derive(Debug, Clone, Copy)]
pub enum BanReason {
FailedHealthCheck,
MessageSendFailed,
MessageReceiveFailed,
FailedCheckout,
StatementTimeout,
#[allow(dead_code)]
ManualBan,
}
#[derive(Debug, Clone)]
pub struct BanEntry {
reason: BanReason,
time: Instant,
duration: Duration,
}

#[derive(Debug, Clone)]
pub enum BanEvent {
Ban {
pool_id: PoolIdentifier,
address: Address,
reason: BanReason,
},
Unban {
pool_id: PoolIdentifier,
address: Address,
},
CleanUpBanList,
}

impl BanEntry {
pub fn has_expired(&self) -> bool {
return Instant::now().duration_since(self.time) > self.duration;
}

pub fn is_active(&self) -> bool {
!self.has_expired()
}
}
static BANLIST: Lazy<ArcSwap<BanList>> = Lazy::new(|| ArcSwap::from_pointee(BanList::default()));

static BAN_MANAGER: Lazy<ArcSwap<BanManager>> =
Lazy::new(|| ArcSwap::from_pointee(BanManager::default()));

#[derive(Clone, Debug)]
pub struct BanManager {
channel_to_worker: Sender<BanEvent>,
}

impl Default for BanManager {
fn default() -> BanManager {
let (channel_to_worker, _rx) = channel(1000);
Copy link
Contributor

@levkk levkk Oct 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud:

This channel can get saturated during incidents. Imagine we have 30k clients that together produce 50k QPS. If a replica goes down, all of a sudden we'll get 50k banning events which will quickly saturate this channel. The async worker task may not even get scheduled to ban the replica.

With an RW lock, one of the clients will ban, and all the others will read the banlist and see that it's banned immediately. RW locks are probably more expensive than ArcSwap (something we should actually validate), but they are more effective than channels I think (something we should investigate).

I think the middle ground can be a Mutex for banning, and an ArcSwap for reading the list, so only one task gets to ban at a time instead of a thundering herd (pseudo-code):

fn ban() {
  if (is_banned()) return;

  let guard = ban_mutex.lock();

  let new_ban_list = ...;

  arc_swap.set(new_ban_list);
}

fn is_banned() {
   let ban_list = (*arc_swap);
   ban_list.contains_key(address);
}

The thesis that this architecture allows for other agents than clients to ban seems shaky. Another agent, e.g. an admin command, can easily use the pool.ban and pool.unban methods, for example, and it would require very little changes to the code and no changes to the arch.

BanManager { channel_to_worker }
}
}

impl BanManager {
/// Create a new Reporter instance.
pub fn new(channel_to_worker: Sender<BanEvent>) -> BanManager {
BanManager { channel_to_worker }
}

/// Send statistics to the task keeping track of stats.
async fn send(&self, event: BanEvent) {
let result = self.channel_to_worker.send(event.clone()).await;

match result {
Ok(_) => (()),
Err(err) => error!("Failed to send ban event {:?}", err),
};
}

pub async fn report_failed_checkout(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::FailedCheckout,
};
self.send(event).await
}

pub async fn report_failed_healthcheck(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::FailedHealthCheck,
};
self.send(event).await
}

pub async fn report_server_send_failed(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::MessageSendFailed,
};
self.send(event).await
}

pub async fn report_server_receive_failed(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::MessageReceiveFailed,
};
self.send(event).await
}

pub async fn report_statement_timeout(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::StatementTimeout,
};
self.send(event).await
}

#[allow(dead_code)]
pub async fn report_manual_ban(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Ban {
pool_id: pool_id.clone(),
address: address.clone(),
reason: BanReason::ManualBan,
};
self.send(event).await
}

pub async fn unban(&self, pool_id: &PoolIdentifier, address: &Address) {
let event = BanEvent::Unban {
pool_id: pool_id.clone(),
address: address.clone(),
};
self.send(event).await;
}

pub fn banlist(&self, pool_id: &PoolIdentifier) -> HashMap<Address, BanEntry> {
match (*(*BANLIST.load())).get(pool_id) {
Some(banlist) => banlist.clone(),
None => HashMap::default(),
}
}

#[allow(dead_code)]
pub fn is_banned(&self, pool_id: &PoolIdentifier, address: &Address) -> bool {
match (*(*BANLIST.load())).get(pool_id) {
Some(pool_banlist) => match pool_banlist.get(address) {
Some(ban_entry) => ban_entry.is_active(),
None => false,
},
None => false,
}
}
}

pub struct BanWorker {
work_queue_tx: Sender<BanEvent>,
work_queue_rx: Receiver<BanEvent>,
}

impl BanWorker {
pub fn new() -> BanWorker {
let (work_queue_tx, work_queue_rx) = mpsc::channel(100_000);
BanWorker {
work_queue_tx,
work_queue_rx,
}
}

pub fn get_reporter(&self) -> BanManager {
BanManager::new(self.work_queue_tx.clone())
}

pub async fn start(&mut self) {
let mut internal_ban_list: BanList = BanList::default();
let tx = self.work_queue_tx.clone();

tokio::task::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(60));
loop {
interval.tick().await;
match tx.try_send(BanEvent::CleanUpBanList) {
Ok(_) => (),
Err(err) => match err {
TrySendError::Full(_) => (),
TrySendError::Closed(_) => (),
},
}
}
});

loop {
let event = match self.work_queue_rx.recv().await {
Some(stat) => stat,
None => {
return;
}
};

match event {
BanEvent::Ban {
pool_id,
address,
reason,
} => {
if self.ban(&mut internal_ban_list, &pool_id, &address, reason) {
// Ban list was changed, let's publish a new one
self.publish_banlist(&internal_ban_list);
}
}
BanEvent::Unban { pool_id, address } => {
if self.unban(&mut internal_ban_list, &pool_id, &address) {
// Ban list was changed, let's publish a new one
self.publish_banlist(&internal_ban_list);
}
}
BanEvent::CleanUpBanList => {
self.cleanup_ban_list(&mut internal_ban_list);
}
};
}
}

fn publish_banlist(&self, internal_ban_list: &BanList) {
BANLIST.store(Arc::new(internal_ban_list.clone()));
}

fn cleanup_ban_list(&self, internal_ban_list: &mut BanList) {
for (_, v) in internal_ban_list {
v.retain(|_k, v| v.is_active());
}
}

fn unban(
&self,
internal_ban_list: &mut BanList,
pool_id: &PoolIdentifier,
address: &Address,
) -> bool {
match internal_ban_list.get_mut(pool_id) {
Some(banlist) => {
if banlist.remove(&address).is_none() {
// Was already not banned? Let's avoid publishing a new list
return false;
}
}
None => return false, // Was already not banned? Let's avoid publishing a new list
}
return true;
}

fn ban(
&self,
internal_ban_list: &mut BanList,
pool_id: &PoolIdentifier,
address: &Address,
reason: BanReason,
) -> bool {
let ban_duration_from_conf = get_ban_time();
let ban_duration = match reason {
BanReason::FailedHealthCheck
| BanReason::MessageReceiveFailed
| BanReason::MessageSendFailed
| BanReason::FailedCheckout
| BanReason::StatementTimeout => {
Duration::from_secs(ban_duration_from_conf.try_into().unwrap())
}
BanReason::ManualBan => Duration::from_secs(86400),
};

// Technically, ban time is when client made the call but this should be close enough
let ban_time = Instant::now();

let pool_banlist = internal_ban_list
.entry(pool_id.clone())
.or_insert(HashMap::default());

let ban_entry = pool_banlist.entry(address.clone()).or_insert(BanEntry {
reason: reason,
time: ban_time,
duration: ban_duration,
});

let old_banned_until = ban_entry.time + ban_entry.duration;
let new_banned_until = ban_time + ban_duration;
if new_banned_until >= old_banned_until {
ban_entry.duration = ban_duration;
ban_entry.time = ban_time;
ban_entry.reason = reason;
}

return true;
}
}

pub fn start_ban_worker() {
let mut worker = BanWorker::new();
BAN_MANAGER.store(Arc::new(worker.get_reporter()));

tokio::task::spawn(async move { worker.start().await });
}

pub fn get_ban_manager() -> BanManager {
return (*(*BAN_MANAGER.load())).clone();
}
13 changes: 9 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;

use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::bans::BanReason;
use crate::config::{get_config, Address, PoolMode};
use crate::constants::*;
use crate::errors::Error;
Expand Down Expand Up @@ -1037,7 +1038,8 @@ where
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, self.process_id, BanReason::MessageSendFailed)
.await;
Err(err)
}
}
Expand All @@ -1059,7 +1061,8 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, self.process_id, BanReason::MessageReceiveFailed)
.await;
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand All @@ -1074,7 +1077,8 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
pool.ban(address, self.process_id, BanReason::StatementTimeout)
.await;
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
Expand All @@ -1083,7 +1087,8 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, self.process_id, BanReason::MessageReceiveFailed)
.await;
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
Expand Down
4 changes: 4 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,10 @@ pub fn get_config() -> Config {
(*(*CONFIG.load())).clone()
}

pub fn get_ban_time() -> i64 {
(*(*CONFIG.load())).general.ban_time
}

/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();
Expand Down
Loading