mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Graceful shutdown and refactor (#144)
* Graceful shutdown and refactor * ok * _Graceful_ shutdown * Remove hardcoded setting * clean up * end * timeout * hmm * hmm! * bash * bash * hmm * maybe maybe * Adds tests and move non-admin connection rejection to startup (#145) * Move error response * Adds tests and removes unused variable * Adds debug log Co-authored-by: zainkabani <77307340+zainkabani@users.noreply.github.com>
This commit is contained in:
158
src/pool.rs
158
src/pool.rs
@@ -12,40 +12,74 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::config::{get_config, Address, Role, Shard, User};
|
||||
use crate::config::{get_config, Address, Role, User};
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::server::Server;
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
|
||||
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
|
||||
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
|
||||
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, u16)>>>;
|
||||
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
|
||||
/// The connection pool, globally available.
|
||||
/// This is atomic and safe and read-optimized.
|
||||
/// The pool is recreated dynamically when the config is reloaded.
|
||||
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
|
||||
|
||||
/// Pool mode:
|
||||
/// - transaction: server serves one transaction,
|
||||
/// - session: server is attached to the client.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum PoolMode {
|
||||
Session,
|
||||
Transaction,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PoolMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match *self {
|
||||
PoolMode::Session => write!(f, "session"),
|
||||
PoolMode::Transaction => write!(f, "transaction"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Pool settings.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PoolSettings {
|
||||
pub pool_mode: String,
|
||||
pub shards: HashMap<String, Shard>,
|
||||
/// Transaction or Session.
|
||||
pub pool_mode: PoolMode,
|
||||
|
||||
// Number of shards.
|
||||
pub shards: usize,
|
||||
|
||||
// Connecting user.
|
||||
pub user: User,
|
||||
pub default_role: String,
|
||||
|
||||
// Default server role to connect to.
|
||||
pub default_role: Option<Role>,
|
||||
|
||||
// Enable/disable query parser.
|
||||
pub query_parser_enabled: bool,
|
||||
|
||||
// Read from the primary as well or not.
|
||||
pub primary_reads_enabled: bool,
|
||||
pub sharding_function: String,
|
||||
|
||||
// Sharding function.
|
||||
pub sharding_function: ShardingFunction,
|
||||
}
|
||||
|
||||
impl Default for PoolSettings {
|
||||
fn default() -> PoolSettings {
|
||||
PoolSettings {
|
||||
pool_mode: String::from("transaction"),
|
||||
shards: HashMap::from([(String::from("1"), Shard::default())]),
|
||||
pool_mode: PoolMode::Transaction,
|
||||
shards: 1,
|
||||
user: User::default(),
|
||||
default_role: String::from("any"),
|
||||
default_role: None,
|
||||
query_parser_enabled: false,
|
||||
primary_reads_enabled: true,
|
||||
sharding_function: "pg_bigint_hash".to_string(),
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -73,6 +107,7 @@ pub struct ConnectionPool {
|
||||
/// on pool creation and save the K messages here.
|
||||
server_info: BytesMut,
|
||||
|
||||
/// Pool configuration.
|
||||
pub settings: PoolSettings,
|
||||
}
|
||||
|
||||
@@ -80,11 +115,13 @@ impl ConnectionPool {
|
||||
/// Construct the connection pool from the configuration.
|
||||
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
|
||||
let config = get_config();
|
||||
let mut new_pools = PoolMap::default();
|
||||
|
||||
let mut new_pools = HashMap::new();
|
||||
let mut address_id = 0;
|
||||
|
||||
for (pool_name, pool_config) in &config.pools {
|
||||
for (_user_index, user_info) in &pool_config.users {
|
||||
// There is one pool per database/user pair.
|
||||
for (_, user) in &pool_config.users {
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
@@ -98,8 +135,8 @@ impl ConnectionPool {
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
|
||||
for shard_idx in shard_ids {
|
||||
let shard = &pool_config.shards[&shard_idx];
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut address_index = 0;
|
||||
@@ -119,12 +156,12 @@ impl ConnectionPool {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: server.0.clone(),
|
||||
port: server.1.to_string(),
|
||||
port: server.1 as u16,
|
||||
role: role,
|
||||
address_index,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user_info.username.clone(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
};
|
||||
|
||||
@@ -137,14 +174,14 @@ impl ConnectionPool {
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user_info.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
get_reporter(),
|
||||
);
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user_info.pool_size)
|
||||
.max_size(user.pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(
|
||||
config.general.connect_timeout,
|
||||
))
|
||||
@@ -171,13 +208,27 @@ impl ConnectionPool {
|
||||
stats: get_reporter(),
|
||||
server_info: BytesMut::new(),
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode.clone(),
|
||||
shards: pool_config.shards.clone(),
|
||||
user: user_info.clone(),
|
||||
default_role: pool_config.default_role.clone(),
|
||||
pool_mode: match pool_config.pool_mode.as_str() {
|
||||
"transaction" => PoolMode::Transaction,
|
||||
"session" => PoolMode::Session,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled.clone(),
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function.clone(),
|
||||
sharding_function: match pool_config.sharding_function.as_str() {
|
||||
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
|
||||
"sha1" => ShardingFunction::Sha1,
|
||||
_ => unreachable!(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -190,7 +241,9 @@ impl ConnectionPool {
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert((pool_name.clone(), user.username.clone()), pool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,8 +260,8 @@ impl ConnectionPool {
|
||||
async fn validate(&mut self) -> Result<(), Error> {
|
||||
let mut server_infos = Vec::new();
|
||||
for shard in 0..self.shards() {
|
||||
for index in 0..self.servers(shard) {
|
||||
let connection = match self.databases[shard][index].get().await {
|
||||
for server in 0..self.servers(shard) {
|
||||
let connection = match self.databases[shard][server].get().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Shard {} down or misconfigured: {:?}", shard, err);
|
||||
@@ -229,6 +282,7 @@ impl ConnectionPool {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
server_infos.push(server_info);
|
||||
}
|
||||
}
|
||||
@@ -239,6 +293,8 @@ impl ConnectionPool {
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
|
||||
// We're assuming all servers are identical.
|
||||
// TODO: not true.
|
||||
self.server_info = server_infos[0].clone();
|
||||
|
||||
Ok(())
|
||||
@@ -252,9 +308,8 @@ impl ConnectionPool {
|
||||
process_id: i32, // client id
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
let now = Instant::now();
|
||||
let mut candidates: Vec<Address> = self.addresses[shard]
|
||||
.clone()
|
||||
.into_iter()
|
||||
let mut candidates: Vec<&Address> = self.addresses[shard]
|
||||
.iter()
|
||||
.filter(|address| address.role == role)
|
||||
.collect();
|
||||
|
||||
@@ -271,7 +326,8 @@ impl ConnectionPool {
|
||||
None => break,
|
||||
};
|
||||
|
||||
if self.is_banned(&address, address.shard, role) {
|
||||
if self.is_banned(&address, role) {
|
||||
debug!("Address {:?} is banned", address);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -286,8 +342,7 @@ impl ConnectionPool {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Banning instance {:?}, error: {:?}", address, err);
|
||||
self.ban(&address, address.shard, process_id);
|
||||
self.stats.client_disconnecting(process_id, address.id);
|
||||
self.ban(&address, process_id);
|
||||
self.stats
|
||||
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
|
||||
continue;
|
||||
@@ -301,6 +356,9 @@ impl ConnectionPool {
|
||||
let require_healthcheck =
|
||||
server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay;
|
||||
|
||||
// Do not issue a health check unless it's been a little while
|
||||
// since we last checked the server is ok.
|
||||
// Health checks are pretty expensive.
|
||||
if !require_healthcheck {
|
||||
self.stats
|
||||
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
|
||||
@@ -314,7 +372,7 @@ impl ConnectionPool {
|
||||
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(healthcheck_timeout),
|
||||
server.query(";"),
|
||||
server.query(";"), // Cheap query (query parser not used in PG)
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -337,7 +395,7 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, address.shard, process_id);
|
||||
self.ban(&address, process_id);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
@@ -351,44 +409,44 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, address.shard, process_id);
|
||||
self.ban(&address, process_id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Err(Error::AllServersDown);
|
||||
|
||||
Err(Error::AllServersDown)
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
/// traffic for any new transactions. Existing transactions on that replica
|
||||
/// will finish successfully or error out to the clients.
|
||||
pub fn ban(&self, address: &Address, shard: usize, process_id: i32) {
|
||||
pub fn ban(&self, address: &Address, process_id: i32) {
|
||||
self.stats.client_disconnecting(process_id, address.id);
|
||||
self.stats
|
||||
.checkout_time(Instant::now().elapsed().as_micros(), process_id, address.id);
|
||||
|
||||
error!("Banning {:?}", address);
|
||||
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let mut guard = self.banlist.write();
|
||||
guard[shard].insert(address.clone(), now);
|
||||
guard[address.shard].insert(address.clone(), now);
|
||||
}
|
||||
|
||||
/// Clear the replica to receive traffic again. Takes effect immediately
|
||||
/// for all new transactions.
|
||||
pub fn _unban(&self, address: &Address, shard: usize) {
|
||||
pub fn _unban(&self, address: &Address) {
|
||||
let mut guard = self.banlist.write();
|
||||
guard[shard].remove(address);
|
||||
guard[address.shard].remove(address);
|
||||
}
|
||||
|
||||
/// Check if a replica can serve traffic. If all replicas are banned,
|
||||
/// we unban all of them. Better to try then not to.
|
||||
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
|
||||
pub fn is_banned(&self, address: &Address, role: Option<Role>) -> bool {
|
||||
let replicas_available = match role {
|
||||
Some(Role::Replica) => self.addresses[shard]
|
||||
Some(Role::Replica) => self.addresses[address.shard]
|
||||
.iter()
|
||||
.filter(|addr| addr.role == Role::Replica)
|
||||
.count(),
|
||||
None => self.addresses[shard].len(),
|
||||
None => self.addresses[address.shard].len(),
|
||||
Some(Role::Primary) => return false, // Primary cannot be banned.
|
||||
};
|
||||
|
||||
@@ -397,17 +455,17 @@ impl ConnectionPool {
|
||||
let guard = self.banlist.read();
|
||||
|
||||
// Everything is banned = nothing is banned.
|
||||
if guard[shard].len() == replicas_available {
|
||||
if guard[address.shard].len() == replicas_available {
|
||||
drop(guard);
|
||||
let mut guard = self.banlist.write();
|
||||
guard[shard].clear();
|
||||
guard[address.shard].clear();
|
||||
drop(guard);
|
||||
warn!("Unbanning all replicas.");
|
||||
return false;
|
||||
}
|
||||
|
||||
// I expect this to miss 99.9999% of the time.
|
||||
match guard[shard].get(address) {
|
||||
match guard[address.shard].get(address) {
|
||||
Some(timestamp) => {
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let config = get_config();
|
||||
@@ -417,7 +475,7 @@ impl ConnectionPool {
|
||||
drop(guard);
|
||||
warn!("Unbanning {:?}", address);
|
||||
let mut guard = self.banlist.write();
|
||||
guard[shard].remove(address);
|
||||
guard[address.shard].remove(address);
|
||||
false
|
||||
} else {
|
||||
debug!("{:?} is banned", address);
|
||||
@@ -554,6 +612,7 @@ pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
|
||||
}
|
||||
}
|
||||
|
||||
/// How many total servers we have in the config.
|
||||
pub fn get_number_of_addresses() -> usize {
|
||||
get_all_pools()
|
||||
.iter()
|
||||
@@ -561,6 +620,7 @@ pub fn get_number_of_addresses() -> usize {
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Get a pointer to all configured pools.
|
||||
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
|
||||
return (*(*POOLS.load())).clone();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user