Add Manual host banning to PgCat (#340)

Sometimes we want an admin to be able to ban a host for some time to route traffic away from that host for reasons like partial outages, replication lag, and scheduled maintenance.

We can achieve this today using a configuration update but a quicker approach is to send a control command to PgCat that bans the replica for some specified duration.

This command does not change the current banning rules like

Primaries cannot be banned
When all replicas are banned, all replicas are unbanned
This commit is contained in:
Mostafa Abdelraouf
2023-03-06 06:10:59 -06:00
committed by GitHub
parent 8a0da10a87
commit 2cc6a09fba
5 changed files with 300 additions and 13 deletions

View File

@@ -29,7 +29,7 @@ pub type SecretKey = i32;
pub type ServerHost = String;
pub type ServerPort = u16;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
pub type ClientServerMap =
Arc<Mutex<HashMap<(ProcessId, SecretKey), (ProcessId, SecretKey, ServerHost, ServerPort)>>>;
pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
@@ -38,6 +38,17 @@ pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
/// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
// Reasons for banning a server.
#[derive(Debug, PartialEq, Clone)]
pub enum BanReason {
FailedHealthCheck,
MessageSendFailed,
MessageReceiveFailed,
FailedCheckout,
StatementTimeout,
AdminBan(i64),
}
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
@@ -489,7 +500,7 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, client_process_id);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
continue;
@@ -582,14 +593,14 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
return false;
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
// Primary can never be banned
if address.role == Role::Primary {
return;
@@ -599,12 +610,12 @@ impl ConnectionPool {
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
guard[address.shard].insert(address.clone(), now);
guard[address.shard].insert(address.clone(), (reason, now));
}
/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address) {
pub fn unban(&self, address: &Address) {
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
}
@@ -653,9 +664,14 @@ impl ConnectionPool {
// Check if ban time is expired
let read_guard = self.banlist.read();
let exceeded_ban_time = match read_guard[address.shard].get(address) {
Some(timestamp) => {
Some((ban_reason, timestamp)) => {
let now = chrono::offset::Utc::now().naive_utc();
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
match ban_reason {
BanReason::AdminBan(duration) => {
now.timestamp() - timestamp.timestamp() > *duration
}
_ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time,
}
}
None => return true,
};
@@ -679,6 +695,31 @@ impl ConnectionPool {
self.databases.len()
}
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
}
}
return bans;
}
/// Get the address from the host url
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new();
for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let address = self.address(shard, server);
if address.host == host {
addresses.push(address.clone());
}
}
}
addresses
}
/// Get the number of servers (primary and replicas)
/// configured for a shard.
pub fn servers(&self, shard: usize) -> usize {