From a0e740d30f60aadfdde7e6f0ebc30804c3f639ae Mon Sep 17 00:00:00 2001 From: zainkabani <77307340+zainkabani@users.noreply.github.com> Date: Thu, 19 Jan 2023 20:36:48 -0500 Subject: [PATCH] Refactors is_banned logic and forces health check on unban (#288) * Refactors is_banned logic and forces healthcheck on unban * typo * Make is banned log debug * addressing comments * Comment --- src/admin.rs | 2 +- src/pool.rs | 257 ++++++++++++++++++++++++++------------------ src/query_router.rs | 1 + 3 files changed, 154 insertions(+), 106 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 5879114..9d4526e 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -286,7 +286,7 @@ where for server in 0..pool.servers(shard) { let address = pool.address(shard, server); let pool_state = pool.pool_state(shard, server); - let banned = pool.is_banned(address, Some(address.role)); + let banned = pool.is_banned(address); res.put(data_row(&vec![ address.name(), // name diff --git a/src/pool.rs b/src/pool.rs index 82720aa..0f92158 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -91,6 +91,9 @@ pub struct PoolSettings { // Health check delay pub healthcheck_delay: u64, + + // Ban time + pub ban_time: i64, } impl Default for PoolSettings { @@ -107,6 +110,7 @@ impl Default for PoolSettings { automatic_sharding_key: None, healthcheck_delay: General::default_healthcheck_delay(), healthcheck_timeout: General::default_healthcheck_timeout(), + ban_time: General::default_ban_time(), } } } @@ -277,6 +281,7 @@ impl ConnectionPool { automatic_sharding_key: pool_config.automatic_sharding_key.clone(), healthcheck_delay: config.general.healthcheck_delay, healthcheck_timeout: config.general.healthcheck_timeout, + ban_time: config.general.ban_time, }, }; @@ -352,9 +357,9 @@ impl ConnectionPool { /// Get a connection from the pool. pub async fn get( &self, - shard: usize, // shard number - role: Option, // primary or replica - process_id: i32, // client id + shard: usize, // shard number + role: Option, // primary or replica + client_process_id: i32, // client id ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { let mut candidates: Vec<&Address> = self.addresses[shard] .iter() @@ -380,14 +385,20 @@ impl ConnectionPool { None => break, }; - if self.is_banned(address, role) { - debug!("Address {:?} is banned", address); - continue; + let mut force_healthcheck = false; + + if self.is_banned(address) { + if self.try_unban(&address).await { + force_healthcheck = true; + } else { + debug!("Address {:?} is banned", address); + continue; + } } // Indicate we're waiting on a server connection from a pool. let now = Instant::now(); - self.stats.client_waiting(process_id); + self.stats.client_waiting(client_process_id); // Check if we can connect let mut conn = match self.databases[address.shard][address.address_index] @@ -397,8 +408,9 @@ impl ConnectionPool { Ok(conn) => conn, Err(err) => { error!("Banning instance {:?}, error: {:?}", address, err); - self.ban(address, process_id); - self.stats.client_checkout_error(process_id, address.id); + self.ban(address, client_process_id); + self.stats + .client_checkout_error(client_process_id, address.id); continue; } }; @@ -407,83 +419,105 @@ impl ConnectionPool { let server = &mut *conn; // Will return error if timestamp is greater than current system time, which it should never be set to - let require_healthcheck = server.last_activity().elapsed().unwrap().as_millis() - > self.settings.healthcheck_delay as u128; + let require_healthcheck = force_healthcheck + || server.last_activity().elapsed().unwrap().as_millis() + > self.settings.healthcheck_delay as u128; // Do not issue a health check unless it's been a little while // since we last checked the server is ok. // Health checks are pretty expensive. if !require_healthcheck { + self.stats.checkout_time( + now.elapsed().as_micros(), + client_process_id, + server.server_id(), + ); self.stats - .checkout_time(now.elapsed().as_micros(), process_id, server.server_id()); - self.stats.server_active(process_id, server.server_id()); + .server_active(client_process_id, server.server_id()); return Ok((conn, address.clone())); } - debug!("Running health check on server {:?}", address); - - self.stats.server_tested(server.server_id()); - - match tokio::time::timeout( - tokio::time::Duration::from_millis(self.settings.healthcheck_timeout), - server.query(";"), // Cheap query as it skips the query planner - ) - .await + if self + .run_health_check(address, server, now, client_process_id) + .await { - // Check if health check succeeded. - Ok(res) => match res { - Ok(_) => { - self.stats.checkout_time( - now.elapsed().as_micros(), - process_id, - conn.server_id(), - ); - self.stats.server_active(process_id, conn.server_id()); - return Ok((conn, address.clone())); - } - - // Health check failed. - Err(err) => { - error!( - "Banning instance {:?} because of failed health check, {:?}", - address, err - ); - - // Don't leave a bad connection in the pool. - server.mark_bad(); - - self.ban(address, process_id); - continue; - } - }, - - // Health check timed out. - Err(err) => { - error!( - "Banning instance {:?} because of health check timeout, {:?}", - address, err - ); - // Don't leave a bad connection in the pool. - server.mark_bad(); - - self.ban(address, process_id); - continue; - } + return Ok((conn, address.clone())); + } else { + continue; } } Err(Error::AllServersDown) } + async fn run_health_check( + &self, + address: &Address, + server: &mut Server, + start: Instant, + client_process_id: i32, + ) -> bool { + debug!("Running health check on server {:?}", address); + + self.stats.server_tested(server.server_id()); + + match tokio::time::timeout( + tokio::time::Duration::from_millis(self.settings.healthcheck_timeout), + server.query(";"), // Cheap query as it skips the query planner + ) + .await + { + // Check if health check succeeded. + Ok(res) => match res { + Ok(_) => { + self.stats.checkout_time( + start.elapsed().as_micros(), + client_process_id, + server.server_id(), + ); + self.stats + .server_active(client_process_id, server.server_id()); + return true; + } + + // Health check failed. + Err(err) => { + error!( + "Banning instance {:?} because of failed health check, {:?}", + address, err + ); + } + }, + + // Health check timed out. + Err(err) => { + error!( + "Banning instance {:?} because of health check timeout, {:?}", + address, err + ); + } + } + + // Don't leave a bad connection in the pool. + server.mark_bad(); + + self.ban(&address, client_process_id); + return false; + } + /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. pub fn ban(&self, address: &Address, client_id: i32) { - error!("Banning {:?}", address); - self.stats.client_ban_error(client_id, address.id); + // Primary can never be banned + if address.role == Role::Primary { + return; + } let now = chrono::offset::Utc::now().naive_utc(); let mut guard = self.banlist.write(); + error!("Banning {:?}", address); + self.stats.client_ban_error(client_id, address.id); guard[address.shard].insert(address.clone(), now); } @@ -494,51 +528,13 @@ impl ConnectionPool { guard[address.shard].remove(address); } - /// Check if a replica can serve traffic. If all replicas are banned, - /// we unban all of them. Better to try then not to. - pub fn is_banned(&self, address: &Address, role: Option) -> bool { - let replicas_available = match role { - Some(Role::Replica) => self.addresses[address.shard] - .iter() - .filter(|addr| addr.role == Role::Replica) - .count(), - None => self.addresses[address.shard].len(), - Some(Role::Primary) => return false, // Primary cannot be banned. - }; - - debug!("Available targets for {:?}: {}", role, replicas_available); - + /// Check if address is banned + /// true if banned, false otherwise + pub fn is_banned(&self, address: &Address) -> bool { let guard = self.banlist.read(); - // Everything is banned = nothing is banned. - if guard[address.shard].len() == replicas_available { - drop(guard); - let mut guard = self.banlist.write(); - guard[address.shard].clear(); - drop(guard); - warn!("Unbanning all replicas."); - return false; - } - - // I expect this to miss 99.9999% of the time. match guard[address.shard].get(address) { - Some(timestamp) => { - let now = chrono::offset::Utc::now().naive_utc(); - let config = get_config(); - - // Ban expired. - if now.timestamp() - timestamp.timestamp() > config.general.ban_time { - drop(guard); - warn!("Unbanning {:?}", address); - let mut guard = self.banlist.write(); - guard[address.shard].remove(address); - false - } else { - debug!("{:?} is banned", address); - true - } - } - + Some(_) => true, None => { debug!("{:?} is ok", address); false @@ -546,6 +542,57 @@ impl ConnectionPool { } } + /// Determines trying to unban this server was successful + pub async fn try_unban(&self, address: &Address) -> bool { + // If somehow primary ends up being banned we should return true here + if address.role == Role::Primary { + return true; + } + + // Check if all replicas are banned, in that case unban all of them + let replicas_available = self.addresses[address.shard] + .iter() + .filter(|addr| addr.role == Role::Replica) + .count(); + + debug!("Available targets: {}", replicas_available); + + let read_guard = self.banlist.read(); + let all_replicas_banned = read_guard[address.shard].len() == replicas_available; + drop(read_guard); + + if all_replicas_banned { + let mut write_guard = self.banlist.write(); + warn!("Unbanning all replicas."); + write_guard[address.shard].clear(); + + return true; + } + + // Check if ban time is expired + let read_guard = self.banlist.read(); + let exceeded_ban_time = match read_guard[address.shard].get(address) { + Some(timestamp) => { + let now = chrono::offset::Utc::now().naive_utc(); + now.timestamp() - timestamp.timestamp() > self.settings.ban_time + } + None => return true, + }; + drop(read_guard); + + if exceeded_ban_time { + warn!("Unbanning {:?}", address); + let mut write_guard = self.banlist.write(); + write_guard[address.shard].remove(address); + drop(write_guard); + + true + } else { + debug!("{:?} is banned", address); + false + } + } + /// Get the number of configured shards. pub fn shards(&self) -> usize { self.databases.len() diff --git a/src/query_router.rs b/src/query_router.rs index 9f9dcd7..28d899d 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -774,6 +774,7 @@ mod test { automatic_sharding_key: Some(String::from("id")), healthcheck_delay: PoolSettings::default().healthcheck_delay, healthcheck_timeout: PoolSettings::default().healthcheck_timeout, + ban_time: PoolSettings::default().ban_time, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None);