diff --git a/src/client.rs b/src/client.rs
index d9bd074..b397692 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -153,7 +153,7 @@ impl Client {
     }
 
     /// Client loop. We handle all messages between the client and the database here.
-    pub async fn handle(&mut self, pool: ConnectionPool) -> Result<(), Error> {
+    pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
         // Special: cancelling existing running query
         if self.cancel_mode {
             let (process_id, secret_key, address, port) = {
diff --git a/src/pool.rs b/src/pool.rs
index 8cf3ada..0b9b62b 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -9,20 +9,20 @@ use crate::server::Server;
 
 use std::collections::HashMap;
 use std::sync::{
-    atomic::{AtomicUsize, Ordering},
+    // atomic::{AtomicUsize, Ordering},
     Arc, Mutex,
 };
 
 // Banlist: bad servers go in here.
 pub type BanList = Arc<Mutex<Vec<HashMap<Address, NaiveDateTime>>>>;
-pub type Counter = Arc<AtomicUsize>;
+// pub type Counter = Arc<AtomicUsize>;
 pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
 
 #[derive(Clone, Debug)]
 pub struct ConnectionPool {
     databases: Vec<Vec<Pool<ServerPool>>>,
     addresses: Vec<Vec<Address>>,
-    round_robin: Counter,
+    round_robin: usize,
     banlist: BanList,
     healthcheck_timeout: u64,
     ban_time: i64,
@@ -90,10 +90,13 @@ impl ConnectionPool {
             banlist.push(HashMap::new());
         }
 
+        assert_eq!(shards.len(), addresses.len());
+        let address_len = addresses.len();
+
         ConnectionPool {
             databases: shards,
             addresses: addresses,
-            round_robin: Arc::new(AtomicUsize::new(0)),
+            round_robin: rand::random::<usize>() % address_len, // Start at a random replica
             banlist: Arc::new(Mutex::new(banlist)),
             healthcheck_timeout: config.general.healthcheck_timeout,
             ban_time: config.general.ban_time,
@@ -103,7 +106,7 @@
 
     /// Get a connection from the pool.
     pub async fn get(
-        &self,
+        &mut self,
         shard: Option<usize>,
         role: Option<Role>,
    ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
@@ -115,40 +118,48 @@
             None => 0, // TODO: pick a shard at random
         };
 
-        let mut allowed_attempts = match role {
-            // Primary-specific queries get one attempt, if the primary is down,
-            // nothing we should do about it I think. It's dangerous to retry
-            // write queries.
-            Some(Role::Primary) => {
-                // Make sure we have a primary in the pool configured.
-                let primary_present = self.addresses[shard]
+        let addresses = &self.addresses[shard];
+
+        // Make sure if a specific role is requested, it's available in the pool.
+        match role {
+            Some(role) => {
+                let role_count = addresses
                     .iter()
-                    .filter(|&db| db.role == Role::Primary)
+                    .filter(|&db| db.role == role)
                     .count();
 
-                // TODO: return this error to the client, so people don't have to look in
-                // the logs to figure out what happened.
-                if primary_present == 0 {
-                    println!(">> Error: Primary requested but none are configured.");
+                if role_count == 0 {
+                    println!(
+                        ">> Error: Role '{:?}' requested, but none are configured.",
+                        role
+                    );
+
+                    return Err(Error::AllServersDown);
                 }
-
-                // Primary gets one attempt.
-                1
             }
 
+            // Any role should be present.
+            _ => (),
+        };
+
+        let mut allowed_attempts = match role {
+            // Primary-specific queries get one attempt, if the primary is down,
+            // nothing we should do about it I think. It's dangerous to retry
+            // write queries.
+            Some(Role::Primary) => 1,
+
             // Replicas get to try as many times as there are replicas
             // and connections in the pool.
             _ => self.databases[shard].len() * self.pool_size as usize,
         };
 
         while allowed_attempts > 0 {
-            // TODO: think about making this local, so multiple clients
-            // don't compete for the same round-robin integer.
-            // Especially since we're going to be skipping (see role selection below).
-
-            let index =
-                self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
-            let address = self.addresses[shard][index].clone();
+            // Round-robin each client's queries.
+            // If a client only sends one query and then disconnects, it doesn't matter
+            // which replica it'll go to.
+            self.round_robin += 1;
+            let index = self.round_robin % addresses.len();
+            let address = &addresses[index];
 
             // Make sure you're getting a primary or a replica
             // as per request.
@@ -158,14 +169,14 @@ impl ConnectionPool {
                     // we'll do our best to pick it, but if we only
                     // have one server in the cluster, it's probably only a primary
                     // (or only a replica), so the client will just get what we have.
-                    if address.role != role && self.addresses[shard].len() > 1 {
+                    if address.role != role && addresses.len() > 1 {
                         continue;
                     }
                 }
                 None => (),
             };
 
-            if self.is_banned(&address, shard, role) {
+            if self.is_banned(address, shard, role) {
                 continue;
             }
 
@@ -177,13 +188,13 @@
                 Ok(conn) => conn,
                 Err(err) => {
                     println!(">> Banning replica {}, error: {:?}", index, err);
-                    self.ban(&address, shard);
+                    self.ban(address, shard);
                     continue;
                 }
             };
 
             if !with_health_check {
-                return Ok((conn, address));
+                return Ok((conn, address.clone()));
             }
             //
             // Check if this server is alive with a health check
@@ -197,7 +208,7 @@
             {
                 // Check if health check succeeded
                 Ok(res) => match res {
-                    Ok(_) => return Ok((conn, address)),
+                    Ok(_) => return Ok((conn, address.clone())),
                     Err(_) => {
                         println!(
                             ">> Banning replica {} because of failed health check",
@@ -206,7 +217,7 @@
 
                         // Don't leave a bad connection in the pool.
                         server.mark_bad();
-                        self.ban(&address, shard);
+                        self.ban(address, shard);
                         continue;
                     }
                 },
@@ -219,7 +230,7 @@
 
                     // Don't leave a bad connection in the pool.
                     server.mark_bad();
-                    self.ban(&address, shard);
+                    self.ban(address, shard);
                     continue;
                 }
             }
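
The gist of the change: the round-robin counter moves from one Arc<AtomicUsize> shared by every client to a plain usize owned by each client's clone of ConnectionPool, seeded at a random offset. That is also why get() now takes &mut self and Client::handle takes mut pool. Below is a minimal sketch of the two approaches side by side; SharedRoundRobin and LocalRoundRobin are illustrative names rather than pgcat types, and it assumes the rand crate (which the diff itself uses) is available as a dependency.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Before: all clients advance one shared counter, so every query
// performs a sequentially-consistent read-modify-write on the same atomic.
struct SharedRoundRobin {
    counter: Arc<AtomicUsize>,
}

impl SharedRoundRobin {
    fn next(&self, replicas: usize) -> usize {
        self.counter.fetch_add(1, Ordering::SeqCst) % replicas
    }
}

// After: each client owns its counter. The random seed matters:
// starting everyone at 0 would send every connect-query-disconnect
// client to the same replica.
#[derive(Clone)]
struct LocalRoundRobin {
    counter: usize,
}

impl LocalRoundRobin {
    fn new(replicas: usize) -> Self {
        LocalRoundRobin {
            counter: rand::random::<usize>() % replicas,
        }
    }

    // Takes &mut self, which is what forces get() to take &mut self
    // in the diff, and handle() to take its pool as mut.
    fn next(&mut self, replicas: usize) -> usize {
        self.counter += 1;
        self.counter % replicas
    }
}

The trade-off: rotation is now exact only within a single client and statistical across clients, which evens out once there are more than a handful of connections; in exchange, clients stop contending on a shared atomic whose strict ordering was already undermined by the skipping done during role selection.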