Add Manual host banning to PgCat (#340)

Sometimes we want an admin to be able to ban a host for some time to route traffic away from that host for reasons like partial outages, replication lag, and scheduled maintenance.

We can achieve this today using a configuration update but a quicker approach is to send a control command to PgCat that bans the replica for some specified duration.

This command does not change the current banning rules like

Primaries cannot be banned
When all replicas are banned, all replicas are unbanned
This commit is contained in:
Mostafa Abdelraouf
2023-03-06 06:10:59 -06:00
committed by GitHub
parent 8a0da10a87
commit 2cc6a09fba
5 changed files with 300 additions and 13 deletions

View File

@@ -1,9 +1,12 @@
use crate::config::Role;
use crate::pool::BanReason;
/// Admin database.
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::Instant;
use crate::config::{get_config, reload_config, VERSION};
@@ -53,6 +56,14 @@ where
let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();
match query_parts[0].to_ascii_uppercase().as_str() {
"BAN" => {
trace!("BAN");
ban(stream, query_parts).await
}
"UNBAN" => {
trace!("UNBAN");
unban(stream, query_parts).await
}
"RELOAD" => {
trace!("RELOAD");
reload(stream, client_server_map).await
@@ -74,6 +85,10 @@ where
shutdown(stream).await
}
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"BANS" => {
trace!("SHOW BANS");
show_bans(stream).await
}
"CONFIG" => {
trace!("SHOW CONFIG");
show_config(stream).await
@@ -350,6 +365,163 @@ where
custom_protocol_response_ok(stream, "SET").await
}
/// Bans a host from being used
async fn ban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let host = match tokens.get(1) {
Some(host) => host,
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
};
let duration_seconds = match tokens.get(2) {
Some(duration_seconds) => match duration_seconds.parse::<i64>() {
Ok(duration_seconds) => duration_seconds,
Err(_) => {
return error_response(stream, "duration_seconds must be an integer").await;
}
},
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
};
if duration_seconds <= 0 {
return error_response(stream, "duration_seconds must be >= 0").await;
}
let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if !pool.is_banned(&address) {
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host,
]));
}
}
}
res.put(command_complete("BAN"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Clear a host for use
async fn unban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let host = match tokens.get(1) {
Some(host) => host,
None => return error_response(stream, "UNBAN command requires a hostname to unban").await,
};
let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if pool.is_banned(&address) {
pool.unban(&address);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host,
]));
}
}
}
res.put(command_complete("UNBAN"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Shows all the bans
async fn show_bans<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let columns = vec![
("db", DataType::Text),
("user", DataType::Text),
("role", DataType::Text),
("host", DataType::Text),
("reason", DataType::Text),
("ban_time", DataType::Text),
("ban_duration_seconds", DataType::Text),
("ban_remaining_seconds", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
// The block should be pretty quick so we cache the time outside
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs() as i64;
for (id, pool) in get_all_pools().iter() {
for (address, (ban_reason, ban_time)) in pool.get_bans().iter() {
let ban_duration = match ban_reason {
BanReason::AdminBan(duration) => *duration,
_ => pool.settings.ban_time,
};
let remaining = ban_duration - (now - ban_time.timestamp());
if remaining <= 0 {
continue;
}
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
address.role.to_string(),
address.host.clone(),
format!("{:?}", ban_reason),
ban_time.to_string(),
ban_duration.to_string(),
remaining.to_string(),
]));
}
}
res.put(command_complete("SHOW BANS"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Reload the configuration file without restarting the process.
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
where

View File

@@ -1,6 +1,9 @@
use crate::errors::Error;
use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use std::collections::HashMap;
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
@@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address, PoolMode};
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
@@ -1111,7 +1114,7 @@ where
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageSendFailed, self.process_id);
Err(err)
}
}
@@ -1133,7 +1136,7 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
@@ -1148,7 +1151,7 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
pool.ban(address, BanReason::StatementTimeout, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
@@ -1157,7 +1160,7 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),

View File

@@ -6,6 +6,7 @@ pub enum Error {
SocketError(String),
ClientBadStartup,
ProtocolSyncError(String),
BadQuery(String),
ServerError,
BadConfig,
AllServersDown,

View File

@@ -29,7 +29,7 @@ pub type SecretKey = i32;
pub type ServerHost = String;
pub type ServerPort = u16;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
pub type ClientServerMap =
Arc<Mutex<HashMap<(ProcessId, SecretKey), (ProcessId, SecretKey, ServerHost, ServerPort)>>>;
pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
@@ -38,6 +38,17 @@ pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
/// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
// Reasons for banning a server.
#[derive(Debug, PartialEq, Clone)]
pub enum BanReason {
FailedHealthCheck,
MessageSendFailed,
MessageReceiveFailed,
FailedCheckout,
StatementTimeout,
AdminBan(i64),
}
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
@@ -489,7 +500,7 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, client_process_id);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
continue;
@@ -582,14 +593,14 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
return false;
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
// Primary can never be banned
if address.role == Role::Primary {
return;
@@ -599,12 +610,12 @@ impl ConnectionPool {
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
guard[address.shard].insert(address.clone(), now);
guard[address.shard].insert(address.clone(), (reason, now));
}
/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address) {
pub fn unban(&self, address: &Address) {
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
}
@@ -653,9 +664,14 @@ impl ConnectionPool {
// Check if ban time is expired
let read_guard = self.banlist.read();
let exceeded_ban_time = match read_guard[address.shard].get(address) {
Some(timestamp) => {
Some((ban_reason, timestamp)) => {
let now = chrono::offset::Utc::now().naive_utc();
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
match ban_reason {
BanReason::AdminBan(duration) => {
now.timestamp() - timestamp.timestamp() > *duration
}
_ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time,
}
}
None => return true,
};
@@ -679,6 +695,31 @@ impl ConnectionPool {
self.databases.len()
}
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
}
}
return bans;
}
/// Get the address from the host url
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new();
for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let address = self.address(shard, server);
if address.host == host {
addresses.push(address.clone());
}
}
}
addresses
}
/// Get the number of servers (primary and replicas)
/// configured for a shard.
pub fn servers(&self, shard: usize) -> usize {

View File

@@ -287,6 +287,76 @@ describe "Admin" do
end
end
describe "Manual Banning" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) }
before do
new_configs = processes.pgcat.current_config
# Prevent immediate unbanning when we ban localhost
new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][0][0] = "127.0.0.1"
new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][1][0] = "127.0.0.1"
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
end
describe "BAN/UNBAN and SHOW BANS" do
it "bans/unbans hosts" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
# Returns a list of the banned addresses
results = admin_conn.async_exec("BAN localhost 10").to_a
expect(results.count).to eq(2)
expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"])
# Subsequent calls should yield no results
results = admin_conn.async_exec("BAN localhost 10").to_a
expect(results.count).to eq(0)
results = admin_conn.async_exec("SHOW BANS").to_a
expect(results.count).to eq(2)
expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"])
# Returns a list of the unbanned addresses
results = admin_conn.async_exec("UNBAN localhost").to_a
expect(results.count).to eq(2)
expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"])
# Subsequent calls should yield no results
results = admin_conn.async_exec("UNBAN localhost").to_a
expect(results.count).to eq(0)
results = admin_conn.async_exec("SHOW BANS").to_a
expect(results.count).to eq(0)
end
it "honors ban duration" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
# Returns a list of the banned addresses
results = admin_conn.async_exec("BAN localhost 1").to_a
expect(results.count).to eq(2)
expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"])
sleep(2)
# After 2 seconds the ban should be lifted
results = admin_conn.async_exec("SHOW BANS").to_a
expect(results.count).to eq(0)
end
it "can handle bad input" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
expect { admin_conn.async_exec("BAN").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("BAN a").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("BAN a a").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("BAN a -5").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("BAN a 0").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("BAN a a a").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("UNBAN").to_a }.to raise_error(PG::SystemError)
end
end
end
describe "SHOW users" do
it "returns the right users" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)