Graceful shutdown and refactor (#144)

* Graceful shutdown and refactor

* ok

* _Graceful_ shutdown

* Remove hardcoded setting

* clean up

* end

* timeout

* hmm

* hmm!

* bash

* bash

* hmm

* maybe maybe

* Adds tests and move non-admin connection rejection to startup (#145)

* Move error response

* Adds tests and removes unused variable

* Adds debug log

Co-authored-by: zainkabani <77307340+zainkabani@users.noreply.github.com>
This commit is contained in:
Lev Kokotov
2022-08-25 06:40:56 -07:00
committed by GitHub
parent c054ff068d
commit 9d84d6f131
9 changed files with 602 additions and 382 deletions

View File

@@ -12,40 +12,74 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::config::{get_config, Address, Role, Shard, User};
use crate::config::{get_config, Address, Role, User};
use crate::errors::Error;
use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, u16)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
/// Pool mode:
/// - transaction: server serves one transaction,
/// - session: server is attached to the client.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PoolMode {
Session,
Transaction,
}
impl std::fmt::Display for PoolMode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
PoolMode::Session => write!(f, "session"),
PoolMode::Transaction => write!(f, "transaction"),
}
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
/// Transaction or Session.
pub pool_mode: PoolMode,
// Number of shards.
pub shards: usize,
// Connecting user.
pub user: User,
pub default_role: String,
// Default server role to connect to.
pub default_role: Option<Role>,
// Enable/disable query parser.
pub query_parser_enabled: bool,
// Read from the primary as well or not.
pub primary_reads_enabled: bool,
pub sharding_function: String,
// Sharding function.
pub sharding_function: ShardingFunction,
}
impl Default for PoolSettings {
fn default() -> PoolSettings {
PoolSettings {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
pool_mode: PoolMode::Transaction,
shards: 1,
user: User::default(),
default_role: String::from("any"),
default_role: None,
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
sharding_function: ShardingFunction::PgBigintHash,
}
}
}
@@ -73,6 +107,7 @@ pub struct ConnectionPool {
/// on pool creation and save the K messages here.
server_info: BytesMut,
/// Pool configuration.
pub settings: PoolSettings,
}
@@ -80,11 +115,13 @@ impl ConnectionPool {
/// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let config = get_config();
let mut new_pools = PoolMap::default();
let mut new_pools = HashMap::new();
let mut address_id = 0;
for (pool_name, pool_config) in &config.pools {
for (_user_index, user_info) in &pool_config.users {
// There is one pool per database/user pair.
for (_, user) in &pool_config.users {
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
@@ -98,8 +135,8 @@ impl ConnectionPool {
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids {
let shard = &pool_config.shards[&shard_idx];
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut address_index = 0;
@@ -119,12 +156,12 @@ impl ConnectionPool {
id: address_id,
database: shard.database.clone(),
host: server.0.clone(),
port: server.1.to_string(),
port: server.1 as u16,
role: role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user_info.username.clone(),
username: user.username.clone(),
pool_name: pool_name.clone(),
};
@@ -137,14 +174,14 @@ impl ConnectionPool {
let manager = ServerPool::new(
address.clone(),
user_info.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let pool = Pool::builder()
.max_size(user_info.pool_size)
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
@@ -171,13 +208,27 @@ impl ConnectionPool {
stats: get_reporter(),
server_info: BytesMut::new(),
settings: PoolSettings {
pool_mode: pool_config.pool_mode.clone(),
shards: pool_config.shards.clone(),
user: user_info.clone(),
default_role: pool_config.default_role.clone(),
pool_mode: match pool_config.pool_mode.as_str() {
"transaction" => PoolMode::Transaction,
"session" => PoolMode::Session,
_ => unreachable!(),
},
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function.clone(),
sharding_function: match pool_config.sharding_function.as_str() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
},
},
};
@@ -190,7 +241,9 @@ impl ConnectionPool {
return Err(err);
}
};
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
// There is one pool per database/user pair.
new_pools.insert((pool_name.clone(), user.username.clone()), pool);
}
}
@@ -207,8 +260,8 @@ impl ConnectionPool {
async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new();
for shard in 0..self.shards() {
for index in 0..self.servers(shard) {
let connection = match self.databases[shard][index].get().await {
for server in 0..self.servers(shard) {
let connection = match self.databases[shard][server].get().await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -229,6 +282,7 @@ impl ConnectionPool {
);
}
}
server_infos.push(server_info);
}
}
@@ -239,6 +293,8 @@ impl ConnectionPool {
return Err(Error::AllServersDown);
}
// We're assuming all servers are identical.
// TODO: not true.
self.server_info = server_infos[0].clone();
Ok(())
@@ -252,9 +308,8 @@ impl ConnectionPool {
process_id: i32, // client id
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now();
let mut candidates: Vec<Address> = self.addresses[shard]
.clone()
.into_iter()
let mut candidates: Vec<&Address> = self.addresses[shard]
.iter()
.filter(|address| address.role == role)
.collect();
@@ -271,7 +326,8 @@ impl ConnectionPool {
None => break,
};
if self.is_banned(&address, address.shard, role) {
if self.is_banned(&address, role) {
debug!("Address {:?} is banned", address);
continue;
}
@@ -286,8 +342,7 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(&address, address.shard, process_id);
self.stats.client_disconnecting(process_id, address.id);
self.ban(&address, process_id);
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
continue;
@@ -301,6 +356,9 @@ impl ConnectionPool {
let require_healthcheck =
server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay;
// Do not issue a health check unless it's been a little while
// since we last checked the server is ok.
// Health checks are pretty expensive.
if !require_healthcheck {
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
@@ -314,7 +372,7 @@ impl ConnectionPool {
match tokio::time::timeout(
tokio::time::Duration::from_millis(healthcheck_timeout),
server.query(";"),
server.query(";"), // Cheap query (query parser not used in PG)
)
.await
{
@@ -337,7 +395,7 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, address.shard, process_id);
self.ban(&address, process_id);
continue;
}
},
@@ -351,44 +409,44 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, address.shard, process_id);
self.ban(&address, process_id);
continue;
}
}
}
return Err(Error::AllServersDown);
Err(Error::AllServersDown)
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, shard: usize, process_id: i32) {
pub fn ban(&self, address: &Address, process_id: i32) {
self.stats.client_disconnecting(process_id, address.id);
self.stats
.checkout_time(Instant::now().elapsed().as_micros(), process_id, address.id);
error!("Banning {:?}", address);
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
guard[shard].insert(address.clone(), now);
guard[address.shard].insert(address.clone(), now);
}
/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address, shard: usize) {
pub fn _unban(&self, address: &Address) {
let mut guard = self.banlist.write();
guard[shard].remove(address);
guard[address.shard].remove(address);
}
/// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try then not to.
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
pub fn is_banned(&self, address: &Address, role: Option<Role>) -> bool {
let replicas_available = match role {
Some(Role::Replica) => self.addresses[shard]
Some(Role::Replica) => self.addresses[address.shard]
.iter()
.filter(|addr| addr.role == Role::Replica)
.count(),
None => self.addresses[shard].len(),
None => self.addresses[address.shard].len(),
Some(Role::Primary) => return false, // Primary cannot be banned.
};
@@ -397,17 +455,17 @@ impl ConnectionPool {
let guard = self.banlist.read();
// Everything is banned = nothing is banned.
if guard[shard].len() == replicas_available {
if guard[address.shard].len() == replicas_available {
drop(guard);
let mut guard = self.banlist.write();
guard[shard].clear();
guard[address.shard].clear();
drop(guard);
warn!("Unbanning all replicas.");
return false;
}
// I expect this to miss 99.9999% of the time.
match guard[shard].get(address) {
match guard[address.shard].get(address) {
Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc();
let config = get_config();
@@ -417,7 +475,7 @@ impl ConnectionPool {
drop(guard);
warn!("Unbanning {:?}", address);
let mut guard = self.banlist.write();
guard[shard].remove(address);
guard[address.shard].remove(address);
false
} else {
debug!("{:?} is banned", address);
@@ -554,6 +612,7 @@ pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
}
}
/// How many total servers we have in the config.
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
@@ -561,6 +620,7 @@ pub fn get_number_of_addresses() -> usize {
.sum()
}
/// Get a pointer to all configured pools.
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone();
}