diff --git a/src/admin.rs b/src/admin.rs index 7f7ada5..c90f28e 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -1,9 +1,12 @@ +use crate::config::Role; +use crate::pool::BanReason; /// Admin database. use bytes::{Buf, BufMut, BytesMut}; use log::{error, info, trace}; use nix::sys::signal::{self, Signal}; use nix::unistd::Pid; use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; @@ -53,6 +56,14 @@ where let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect(); match query_parts[0].to_ascii_uppercase().as_str() { + "BAN" => { + trace!("BAN"); + ban(stream, query_parts).await + } + "UNBAN" => { + trace!("UNBAN"); + unban(stream, query_parts).await + } "RELOAD" => { trace!("RELOAD"); reload(stream, client_server_map).await @@ -74,6 +85,10 @@ where shutdown(stream).await } "SHOW" => match query_parts[1].to_ascii_uppercase().as_str() { + "BANS" => { + trace!("SHOW BANS"); + show_bans(stream).await + } "CONFIG" => { trace!("SHOW CONFIG"); show_config(stream).await @@ -350,6 +365,163 @@ where custom_protocol_response_ok(stream, "SET").await } +/// Bans a host from being used +async fn ban(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let host = match tokens.get(1) { + Some(host) => host, + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, + }; + + let duration_seconds = match tokens.get(2) { + Some(duration_seconds) => match duration_seconds.parse::() { + Ok(duration_seconds) => duration_seconds, + Err(_) => { + return error_response(stream, "duration_seconds must be an integer").await; + } + }, + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, + }; + + if duration_seconds <= 0 { + return error_response(stream, "duration_seconds must be >= 0").await; + } + + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + for (id, pool) in get_all_pools().iter() { + for address in pool.get_addresses_from_host(host) { + if !pool.is_banned(&address) { + pool.ban(&address, BanReason::AdminBan(duration_seconds), -1); + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host, + ])); + } + } + } + + res.put(command_complete("BAN")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + +/// Clear a host for use +async fn unban(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let host = match tokens.get(1) { + Some(host) => host, + None => return error_response(stream, "UNBAN command requires a hostname to unban").await, + }; + + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + for (id, pool) in get_all_pools().iter() { + for address in pool.get_addresses_from_host(host) { + if pool.is_banned(&address) { + pool.unban(&address); + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host, + ])); + } + } + } + + res.put(command_complete("UNBAN")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + +/// Shows all the bans +async fn show_bans(stream: &mut T) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let columns = vec![ + ("db", DataType::Text), + ("user", DataType::Text), + ("role", DataType::Text), + ("host", DataType::Text), + ("reason", DataType::Text), + ("ban_time", DataType::Text), + ("ban_duration_seconds", DataType::Text), + ("ban_remaining_seconds", DataType::Text), + ]; + let mut res = BytesMut::new(); + res.put(row_description(&columns)); + + // The block should be pretty quick so we cache the time outside + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() as i64; + + for (id, pool) in get_all_pools().iter() { + for (address, (ban_reason, ban_time)) in pool.get_bans().iter() { + let ban_duration = match ban_reason { + BanReason::AdminBan(duration) => *duration, + _ => pool.settings.ban_time, + }; + let remaining = ban_duration - (now - ban_time.timestamp()); + if remaining <= 0 { + continue; + } + res.put(data_row(&vec![ + id.db.clone(), + id.user.clone(), + address.role.to_string(), + address.host.clone(), + format!("{:?}", ban_reason), + ban_time.to_string(), + ban_duration.to_string(), + remaining.to_string(), + ])); + } + } + + res.put(command_complete("SHOW BANS")); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await +} + /// Reload the configuration file without restarting the process. async fn reload(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error> where diff --git a/src/client.rs b/src/client.rs index 11f4cc6..2e04c2a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,6 +1,9 @@ +use crate::errors::Error; +use crate::pool::BanReason; /// Handle clients by pretending to be a PostgreSQL server. use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; + use std::collections::HashMap; use std::time::Instant; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; @@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::config::{get_config, Address, PoolMode}; use crate::constants::*; -use crate::errors::Error; + use crate::messages::*; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; @@ -1111,7 +1114,7 @@ where match server.send(message).await { Ok(_) => Ok(()), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageSendFailed, self.process_id); Err(err) } } @@ -1133,7 +1136,7 @@ where Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageReceiveFailed, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), @@ -1148,7 +1151,7 @@ where address, pool.settings.user.username ); server.mark_bad(); - pool.ban(address, self.process_id); + pool.ban(address, BanReason::StatementTimeout, self.process_id); error_response_terminal(&mut self.write, "pool statement timeout").await?; Err(Error::StatementTimeout) } @@ -1157,7 +1160,7 @@ where match server.recv().await { Ok(message) => Ok(message), Err(err) => { - pool.ban(address, self.process_id); + pool.ban(address, BanReason::MessageReceiveFailed, self.process_id); error_response_terminal( &mut self.write, &format!("error receiving data from server: {:?}", err), diff --git a/src/errors.rs b/src/errors.rs index 4ac23a8..310243c 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,6 +6,7 @@ pub enum Error { SocketError(String), ClientBadStartup, ProtocolSyncError(String), + BadQuery(String), ServerError, BadConfig, AllServersDown, diff --git a/src/pool.rs b/src/pool.rs index cca9577..0a0a53f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -29,7 +29,7 @@ pub type SecretKey = i32; pub type ServerHost = String; pub type ServerPort = u16; -pub type BanList = Arc>>>; +pub type BanList = Arc>>>; pub type ClientServerMap = Arc>>; pub type PoolMap = HashMap; @@ -38,6 +38,17 @@ pub type PoolMap = HashMap; /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); +// Reasons for banning a server. +#[derive(Debug, PartialEq, Clone)] +pub enum BanReason { + FailedHealthCheck, + MessageSendFailed, + MessageReceiveFailed, + FailedCheckout, + StatementTimeout, + AdminBan(i64), +} + /// An identifier for a PgCat pool, /// a database visible to clients. #[derive(Hash, Debug, Clone, PartialEq, Eq)] @@ -489,7 +500,7 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(address, client_process_id); + self.ban(address, BanReason::FailedCheckout, client_process_id); self.stats .client_checkout_error(client_process_id, address.id); continue; @@ -582,14 +593,14 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, client_process_id); + self.ban(&address, BanReason::FailedHealthCheck, client_process_id); return false; } /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. - pub fn ban(&self, address: &Address, client_id: i32) { + pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) { // Primary can never be banned if address.role == Role::Primary { return; @@ -599,12 +610,12 @@ impl ConnectionPool { let mut guard = self.banlist.write(); error!("Banning {:?}", address); self.stats.client_ban_error(client_id, address.id); - guard[address.shard].insert(address.clone(), now); + guard[address.shard].insert(address.clone(), (reason, now)); } /// Clear the replica to receive traffic again. Takes effect immediately /// for all new transactions. - pub fn _unban(&self, address: &Address) { + pub fn unban(&self, address: &Address) { let mut guard = self.banlist.write(); guard[address.shard].remove(address); } @@ -653,9 +664,14 @@ impl ConnectionPool { // Check if ban time is expired let read_guard = self.banlist.read(); let exceeded_ban_time = match read_guard[address.shard].get(address) { - Some(timestamp) => { + Some((ban_reason, timestamp)) => { let now = chrono::offset::Utc::now().naive_utc(); - now.timestamp() - timestamp.timestamp() > self.settings.ban_time + match ban_reason { + BanReason::AdminBan(duration) => { + now.timestamp() - timestamp.timestamp() > *duration + } + _ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time, + } } None => return true, }; @@ -679,6 +695,31 @@ impl ConnectionPool { self.databases.len() } + pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> { + let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new(); + let guard = self.banlist.read(); + for banlist in guard.iter() { + for (address, (reason, timestamp)) in banlist.iter() { + bans.push((address.clone(), (reason.clone(), timestamp.clone()))); + } + } + return bans; + } + + /// Get the address from the host url + pub fn get_addresses_from_host(&self, host: &str) -> Vec
{ + let mut addresses = Vec::new(); + for shard in 0..self.shards() { + for server in 0..self.servers(shard) { + let address = self.address(shard, server); + if address.host == host { + addresses.push(address.clone()); + } + } + } + addresses + } + /// Get the number of servers (primary and replicas) /// configured for a shard. pub fn servers(&self, shard: usize) -> usize { diff --git a/tests/ruby/admin_spec.rb b/tests/ruby/admin_spec.rb index 7836415..f69c3df 100644 --- a/tests/ruby/admin_spec.rb +++ b/tests/ruby/admin_spec.rb @@ -287,6 +287,76 @@ describe "Admin" do end end + describe "Manual Banning" do + let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) } + before do + new_configs = processes.pgcat.current_config + # Prevent immediate unbanning when we ban localhost + new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][0][0] = "127.0.0.1" + new_configs["pools"]["sharded_db"]["shards"]["0"]["servers"][1][0] = "127.0.0.1" + processes.pgcat.update_config(new_configs) + processes.pgcat.reload_config + end + + describe "BAN/UNBAN and SHOW BANS" do + it "bans/unbans hosts" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + # Returns a list of the banned addresses + results = admin_conn.async_exec("BAN localhost 10").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Subsequent calls should yield no results + results = admin_conn.async_exec("BAN localhost 10").to_a + expect(results.count).to eq(0) + + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Returns a list of the unbanned addresses + results = admin_conn.async_exec("UNBAN localhost").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + # Subsequent calls should yield no results + results = admin_conn.async_exec("UNBAN localhost").to_a + expect(results.count).to eq(0) + + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(0) + end + + it "honors ban duration" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + # Returns a list of the banned addresses + results = admin_conn.async_exec("BAN localhost 1").to_a + expect(results.count).to eq(2) + expect(results.map{ |r| r["host"] }.uniq).to eq(["localhost"]) + + sleep(2) + + # After 2 seconds the ban should be lifted + results = admin_conn.async_exec("SHOW BANS").to_a + expect(results.count).to eq(0) + end + + it "can handle bad input" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + expect { admin_conn.async_exec("BAN").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a -5").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a 0").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("BAN a a a").to_a }.to raise_error(PG::SystemError) + expect { admin_conn.async_exec("UNBAN").to_a }.to raise_error(PG::SystemError) + end + end + end + describe "SHOW users" do it "returns the right users" do admin_conn = PG::connect(processes.pgcat.admin_connection_string)