diff --git a/src/client.rs b/src/client.rs index 2d7c974..0fe1859 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,8 +9,9 @@ use tokio::net::TcpStream; use crate::errors::Error; use crate::messages::*; -use crate::pool::{ClientServerMap, ReplicaPool}; +use crate::pool::{ClientServerMap, ServerPool}; use crate::server::Server; +use bb8::Pool; /// The client state. pub struct Client { @@ -121,7 +122,7 @@ impl Client { } /// Client loop. We handle all messages between the client and the database here. - pub async fn handle(&mut self, mut pool: ReplicaPool) -> Result<(), Error> { + pub async fn handle(&mut self, pool: Pool) -> Result<(), Error> { // Special: cancelling existing running query if self.cancel_mode { let (process_id, secret_key, address, port) = { @@ -150,7 +151,6 @@ impl Client { Ok(_) => (), Err(_) => return Err(Error::ClientDisconnected), }; - let pool = pool.get().1; let mut proxy = pool.get().await.unwrap(); let server = &mut *proxy; diff --git a/src/config.rs b/src/config.rs index c0983a1..736d6a5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -#[derive(Clone, PartialEq, Hash, std::cmp::Eq)] +#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] pub struct Address { pub host: String, pub port: String, diff --git a/src/main.rs b/src/main.rs index ad62664..cfe935f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ extern crate tokio; use tokio::net::TcpListener; +use bb8::Pool; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -19,8 +20,14 @@ mod server; // Support for query cancellation: this maps our process_ids and // secret keys to the backend's. use config::{Address, User}; -use pool::{ClientServerMap, ReplicaPool}; +use pool::{ClientServerMap, ReplicaPool, ServerPool}; +// +// Poor man's config +// +const POOL_SIZE: u32 = 15; + +/// Main! #[tokio::main] async fn main() { println!("> Welcome to PgCat! Meow."); @@ -38,7 +45,7 @@ async fn main() { let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); - // Note in the logs that it will fetch two connections! + // Replica pool. let addresses = vec![ Address { host: "127.0.0.1".to_string(), @@ -49,17 +56,35 @@ async fn main() { port: "5432".to_string(), }, ]; + let num_addresses = addresses.len() as u32; let user = User { name: "lev".to_string(), password: "lev".to_string(), }; - let replica_pool = ReplicaPool::new(addresses, user, "lev", client_server_map.clone()).await; + let database = "lev"; + + let replica_pool = ReplicaPool::new(addresses).await; + let manager = ServerPool::new(replica_pool, user, database, client_server_map.clone()); + + // We are round-robining, so ideally the replicas will be equally loaded. + // Therefore, we are allocating number of replicas * pool size of connections. + // However, if a replica dies, the remaining replicas will share the burden, + // also equally. + // + // Note that failover in this case could bring down the remaining replicas, so + // in certain situations, e.g. when replicas are running hot already, failover + // is not at all desirable!! + let pool = Pool::builder() + .max_size(POOL_SIZE * num_addresses) + .build(manager) + .await + .unwrap(); loop { + let pool = pool.clone(); let client_server_map = client_server_map.clone(); - let replica_pool = replica_pool.clone(); let (socket, addr) = match listener.accept().await { Ok((socket, addr)) => (socket, addr), @@ -77,7 +102,7 @@ async fn main() { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); - match client.handle(replica_pool).await { + match client.handle(pool).await { Ok(()) => { println!(">> Client {:?} disconnected.", addr); } diff --git a/src/pool.rs b/src/pool.rs index cac9fb5..7f31f77 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,5 +1,5 @@ use async_trait::async_trait; -use bb8::{ManageConnection, Pool, PooledConnection}; +use bb8::{ManageConnection, PooledConnection}; use chrono::naive::NaiveDateTime; use crate::config::{Address, User}; @@ -18,28 +18,22 @@ pub type Counter = Arc; pub type ClientServerMap = Arc>>; pub struct ServerPool { - host: String, - port: String, - user: String, - password: String, + replica_pool: ReplicaPool, + user: User, database: String, client_server_map: ClientServerMap, } impl ServerPool { pub fn new( - host: &str, - port: &str, - user: &str, - password: &str, + replica_pool: ReplicaPool, + user: User, database: &str, client_server_map: ClientServerMap, ) -> ServerPool { ServerPool { - host: host.to_string(), - port: port.to_string(), - user: user.to_string(), - password: password.to_string(), + replica_pool: replica_pool, + user: user, database: database.to_string(), client_server_map: client_server_map, } @@ -53,20 +47,28 @@ impl ManageConnection for ServerPool { /// Attempts to create a new connection. async fn connect(&self) -> Result { - println!(">> Getting connetion from pool"); + let address = self.replica_pool.get(); + // println!(">> Getting connetion from pool"); - // - // TODO: Pick a random connection from a replica pool here. - // - Ok(Server::startup( - &self.host, - &self.port, - &self.user, - &self.password, + match Server::startup( + &address.host, + &address.port, + &self.user.name, + &self.user.password, &self.database, self.client_server_map.clone(), ) - .await?) + .await + { + Ok(server) => { + self.replica_pool.unban(&address); + Ok(server) + } + Err(err) => { + self.replica_pool.ban(&address); + Err(err) + } + } } /// Determines if the connection is still connected to the database. @@ -87,7 +89,11 @@ impl ManageConnection for ServerPool { .await { Ok(_) => Ok(()), - Err(_err) => Err(Error::ServerTimeout), + Err(_err) => { + println!(">> Unhealthy!"); + self.replica_pool.ban(&server.address()); + Err(Error::ServerTimeout) + } } } @@ -101,7 +107,7 @@ impl ManageConnection for ServerPool { /// many sharded primaries or replicas. #[derive(Clone)] pub struct ReplicaPool { - replicas: Vec>, + // replicas: Vec>, addresses: Vec
, // user: User, round_robin: Counter, @@ -109,47 +115,22 @@ pub struct ReplicaPool { } impl ReplicaPool { - pub async fn new( - addresses: Vec
, - user: User, - database: &str, - client_server_map: ClientServerMap, - ) -> ReplicaPool { - let mut replicas = Vec::new(); - - for address in &addresses { - let client_server_map = client_server_map.clone(); - - let manager = ServerPool::new( - &address.host, - &address.port, - &user.name, - &user.password, - database, - client_server_map, - ); - - let pool = Pool::builder().max_size(15).build(manager).await.unwrap(); - - replicas.push(pool); - } - + pub async fn new(addresses: Vec
) -> ReplicaPool { ReplicaPool { addresses: addresses, - replicas: replicas, - // user: user, round_robin: Arc::new(AtomicUsize::new(0)), banlist: Arc::new(Mutex::new(HashMap::new())), } } - pub fn ban(&mut self, address: &Address) { + pub fn ban(&self, address: &Address) { + println!(">> Banning {:?}", address); let now = chrono::offset::Utc::now().naive_utc(); let mut guard = self.banlist.lock().unwrap(); guard.insert(address.clone(), now); } - pub fn unban(&mut self, address: &Address) { + pub fn unban(&self, address: &Address) { let mut guard = self.banlist.lock().unwrap(); guard.remove(address); } @@ -160,6 +141,8 @@ impl ReplicaPool { // Everything is banned, nothig is banned if guard.len() == self.addresses.len() { guard.clear(); + drop(guard); + println!(">> Unbanning all"); return false; } @@ -180,14 +163,14 @@ impl ReplicaPool { } } - pub fn get(&mut self) -> (Address, Pool) { + pub fn get(&self) -> Address { loop { // We'll never hit a 64-bit overflow right....right? :-) let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len(); let address = &self.addresses[index]; if !self.is_banned(address) { - return (address.clone(), self.replicas[index].clone()); + return address.clone(); } else { continue; } diff --git a/src/server.rs b/src/server.rs index fb7b85d..ea6c200 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; +use crate::config::Address; use crate::errors::Error; use crate::messages::*; use crate::ClientServerMap; @@ -350,4 +351,11 @@ impl Server { .query(&format!("SET application_name = '{}'", name)) .await?) } + + pub fn address(&self) -> Address { + Address { + host: self.host.to_string(), + port: self.port.to_string(), + } + } }