From 2cc6a09fba1a8d74ecce54fa6947f1eca1a28344 Mon Sep 17 00:00:00 2001 From: Mostafa Abdelraouf Date: Mon, 6 Mar 2023 06:10:59 -0600 Subject: [PATCH] Add Manual host banning to PgCat (#340) Sometimes we want an admin to be able to ban a host for some time to route traffic away from that host for reasons like partial outages, replication lag, and scheduled maintenance. We can achieve this today using a configuration update, but a quicker approach is to send a control command to PgCat that bans the replica for some specified duration. This command does not change the current banning rules: primaries cannot be banned, and when all replicas are banned, all replicas are unbanned. --- src/admin.rs | 172 +++++++++++++++++++++++++++++++++++++++ src/client.rs | 13 +-- src/errors.rs | 1 + src/pool.rs | 57 +++++++++++-- tests/ruby/admin_spec.rb | 70 ++++++++++++++++ 5 files changed, 300 insertions(+), 13 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 7f7ada5..c90f28e 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -1,9 +1,12 @@ +use crate::config::Role; +use crate::pool::BanReason; /// Admin database. 
use bytes::{Buf, BufMut, BytesMut}; use log::{error, info, trace}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; @@ -53,6 +56,14 @@ where let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect(); match query_parts[0].to_ascii_uppercase().as_str() { + "BAN" => { + trace!("BAN"); + ban(stream, query_parts).await + } + "UNBAN" => { + trace!("UNBAN"); + unban(stream, query_parts).await + } "RELOAD" => { trace!("RELOAD"); reload(stream, client_server_map).await @@ -74,6 +85,10 @@ where shutdown(stream).await } "SHOW" => match query_parts[1].to_ascii_uppercase().as_str() { + "BANS" => { + trace!("SHOW BANS"); + show_bans(stream).await + } "CONFIG" => { trace!("SHOW CONFIG"); show_config(stream).await @@ -350,6 +365,163 @@ where custom_protocol_response_ok(stream, "SET").await } +/// Bans a host from being used +async fn ban(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let host = match tokens.get(1) { + Some(host) => host, + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, + }; + + let duration_seconds = match tokens.get(2) { + Some(duration_seconds) => match duration_seconds.parse::() { + Ok(duration_seconds) => duration_seconds, + Err(_) => { + return error_response(stream, "duration_seconds must be an integer").await; + } + }, + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, + }; + + if duration_seconds <= 0 { + return error_response(stream, "duration_seconds must be >= 0").await; + } + + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + for (id, pool) in 
get_all_pools().iter() { + for address in pool.get_addresses_from_host(host) { + if !pool.is_banned(&address) { + pool.ban(&address, BanReason::AdminBan(duration_seconds), -1); + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host, + ])); + } + } + } + + res.put(command_complete("BAN")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + +/// Clear a host for use +async fn unban(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let host = match tokens.get(1) { + Some(host) => host, + None => return error_response(stream, "UNBAN command requires a hostname to unban").await, + }; + + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + for (id, pool) in get_all_pools().iter() { + for address in pool.get_addresses_from_host(host) { + if pool.is_banned(&address) { + pool.unban(&address); + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host, + ])); + } + } + } + + res.put(command_complete("UNBAN")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + +/// Shows all the bans +async fn show_bans(stream: &mut T) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ("reason", DataType::Text), + ("ban_time", DataType::Text), + ("ban_duration_seconds", DataType::Text), + ("ban_remaining_seconds", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + // The block should be pretty quick so we cache the time outside + 
let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + for (id, pool) in get_all_pools().iter() { + for (address, (ban_reason, ban_time)) in pool.get_bans().iter() { + let ban_duration = match ban_reason { + BanReason::AdminBan(duration) => *duration, + _ => pool.settings.ban_time, + }; + let remaining = ban_duration - (now - ban_time.timestamp()); + if remaining <= 0 { + continue; + } + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host.clone(), + format!("{:?}", ban_reason), + ban_time.to_string(), + ban_duration.to_string(), + remaining.to_string(), + ])); + } + } + + res.put(command_complete("SHOW BANS")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + /// Reload the configuration file without restarting the process. async fn reload(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error> where diff --git a/src/client.rs b/src/client.rs index 11f4cc6..2e04c2a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,9 @@ +use crate::errors::Error; +use crate::pool::BanReason; /// Handle clients by pretending to be a PostgreSQL server. 
use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; + use std::collections::HashMap; use std::time::Instant; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; @@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::config::{get_config, Address, PoolMode}; use crate::constants::*; -use crate::errors::Error; + use crate::messages::*; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; @@ -1111,7 +1114,7 @@ where match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageSendFailed, self.process_id); Err(err) } } @@ -1133,7 +1136,7 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageReceiveFailed, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), @@ -1148,7 +1151,7 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, self.process_id); + pool.ban(address, BanReason::StatementTimeout, self.process_id); error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -1157,7 +1160,7 @@ where match server.recv().await { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageReceiveFailed, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/errors.rs b/src/errors.rs index 4ac23a8..310243c 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,6 +6,7 @@ pub enum Error { SocketError(String), ClientBadStartup, ProtocolSyncError(String), + BadQuery(String), ServerError, BadConfig, AllServersDown, diff --git a/src/pool.rs b/src/pool.rs 
index cca9577..0a0a53f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -29,7 +29,7 @@ pub type SecretKey = i32; pub type ServerHost = String; pub type ServerPort = u16; -pub type BanList = Arc>>>; +pub type BanList = Arc>>>; pub type ClientServerMap = Arc>>; pub type PoolMap = HashMap; @@ -38,6 +38,17 @@ pub type PoolMap = HashMap; /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); +// Reasons for banning a server. +#[derive(Debug, PartialEq, Clone)] +pub enum BanReason { + FailedHealthCheck, + MessageSendFailed, + MessageReceiveFailed, + FailedCheckout, + StatementTimeout, + AdminBan(i64), +} + /// An identifier for a PgCat pool, /// a database visible to clients. #[derive(Hash, Debug, Clone, PartialEq, Eq)] @@ -489,7 +500,7 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(address, client_process_id); + self.ban(address, BanReason::FailedCheckout, client_process_id); self.stats .client_checkout_error(client_process_id, address.id); continue; @@ -582,14 +593,14 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, client_process_id); + self.ban(&address, BanReason::FailedHealthCheck, client_process_id); return false; } /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. 
- pub fn ban(&self, address: &Address, client_id: i32) { + pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) { // Primary can never be banned if address.role == Role::Primary { return; @@ -599,12 +610,12 @@ impl ConnectionPool { let mut guard = self.banlist.write(); error!("Banning {:?}", address); self.stats.client_ban_error(client_id, address.id); - guard[address.shard].insert(address.clone(), now); + guard[address.shard].insert(address.clone(), (reason, now)); } /// Clear the replica to receive traffic again. Takes effect immediately /// for all new transactions. - pub fn _unban(&self, address: &Address) { + pub fn unban(&self, address: &Address) { let mut guard = self.banlist.write(); guard[address.shard].remove(address); } @@ -653,9 +664,14 @@ impl ConnectionPool { // Check if ban time is expired let read_guard = self.banlist.read(); let exceeded_ban_time = match read_guard[address.shard].get(address) { - Some(timestamp) => { + Some((ban_reason, timestamp)) => { let now = chrono::offset::Utc::now().naive_utc(); - now.timestamp() - timestamp.timestamp() > self.settings.ban_time + match ban_reason { + BanReason::AdminBan(duration) => { + now.timestamp() - timestamp.timestamp() > *duration + } + _ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time, + } } None => return true, }; @@ -679,6 +695,31 @@ impl ConnectionPool { self.databases.len() } + pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> { + let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new(); + let guard = self.banlist.read(); + for banlist in guard.iter() { + for (address, (reason, timestamp)) in banlist.iter() { + bans.push((address.clone(), (reason.clone(), timestamp.clone()))); + } + } + return bans; + } + + /// Get the address from the host url + pub fn get_addresses_from_host(&self, host: &str) -> Vec
{ + let mut addresses = Vec::new(); + for shard in 0..self.shards() { + for server in 0..self.servers(shard) { + let address = self.address(shard, server); + if address.host == host { + addresses.push(address.clone()); + } + } + } + addresses + } + /// Get the number of servers (primary and replicas) /// configured for a shard. pub fn servers(&self, shard: usize) -> usize { diff --git a/tests/ruby/admin_spec.rb b/tests/ruby/admin_spec.rb index 7836415..f69c3df 100644 --- a/tests/ruby/admin_spec.rb +++ b/tests/ruby/admin_spec.rb @@ -287,6 +287,76 @@ describe "Admin" do end end + describe "Manual Banning" do + let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) } + before do + new_configs = processes.pgcat.current_config + # Prevent immediate unbanning when we ban localhost + new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][0][0] = "127.0.0.1" + new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][1][0] = "127.0.0.1" + processes.pgcat.update_config(new_configs) + processes.pgcat.reload_config + end + + describe "BAN/UNBAN and SHOW BANS" do + it "bans/unbans hosts" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + # Returns a list of the banned addresses + results = admin_conn.async_exec("BAN localhost 10").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Subsequent calls should yield no results + results = admin_conn.async_exec("BAN localhost 10").to_a + expect(results.count).to eq(0) + + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Returns a list of the unbanned addresses + results = admin_conn.async_exec("UNBAN localhost").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Subsequent calls should yield no results + results = admin_conn.async_exec("UNBAN localhost").to_a + 
expect(results.count).to eq(0) + + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(0) + end + + it "honors ban duration" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + # Returns a list of the banned addresses + results = admin_conn.async_exec("BAN localhost 1").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + sleep(2) + + # After 2 seconds the ban should be lifted + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(0) + end + + it "can handle bad input" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + expect { admin_conn.async_exec("BAN").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a -5").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a 0").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a a a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("UNBAN").to_a }.to raise_error(PG::SystemError) + end + end + end + describe "SHOW users" do it "returns the right users" do admin_conn = PG::connect(processes.pgcat.admin_connection_string)