From e0ca175129e767beef21ee9ec0bf7adec129efdb Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 5 Feb 2022 18:20:53 -0800 Subject: [PATCH] correct load balancing --- src/client.rs | 53 ++++++---- src/main.rs | 18 +--- src/pool.rs | 261 +++++++++++++++++++++++++++++--------------------- 3 files changed, 187 insertions(+), 145 deletions(-) diff --git a/src/client.rs b/src/client.rs index 4accc8e..5b0e218 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,9 +9,8 @@ use tokio::net::TcpStream; use crate::errors::Error; use crate::messages::*; -use crate::pool::{ClientServerMap, ServerPool}; +use crate::pool::{ClientServerMap, ConnectionPool}; use crate::server::Server; -use bb8::Pool; /// The client state. pub struct Client { @@ -125,7 +124,7 @@ impl Client { } /// Client loop. We handle all messages between the client and the database here. - pub async fn handle(&mut self, pool: Pool) -> Result<(), Error> { + pub async fn handle(&mut self, pool: ConnectionPool) -> Result<(), Error> { // Special: cancelling existing running query if self.cancel_mode { let (process_id, secret_key, address, port) = { @@ -148,13 +147,17 @@ impl Client { loop { // Only grab a connection once we have some traffic on the socket // TODO: this is not the most optimal way to share servers. - let mut peek_buf = vec![0u8; 2]; + // let mut peek_buf = vec![0u8; 2]; - match self.read.get_mut().peek(&mut peek_buf).await { - Ok(_) => (), - Err(_) => return Err(Error::ClientDisconnected), - }; - let mut proxy = pool.get().await.unwrap(); + // match self.read.get_mut().peek(&mut peek_buf).await { + // Ok(_) => (), + // Err(_) => return Err(Error::ClientDisconnected), + // }; + let message = read_message(&mut self.read).await?; + + self.buffer.put(message); + + let mut proxy = pool.get(None).await.unwrap().0; let server = &mut *proxy; // TODO: maybe don't do this, I don't think it's useful. @@ -164,18 +167,28 @@ impl Client { server.claim(self.process_id, self.secret_key); loop { - let mut message = match read_message(&mut self.read).await { - Ok(message) => message, - Err(err) => { - if server.in_transaction() { - // TODO: this is what PgBouncer does - // which leads to connection thrashing. - // - // I think we could issue a ROLLBACK here instead. - server.mark_bad(); - } + let mut message = match self.buffer.len() { + 0 => { + match read_message(&mut self.read).await { + Ok(message) => message, + Err(err) => { + if server.in_transaction() { + // TODO: this is what PgBouncer does + // which leads to connection thrashing. + // + // I think we could issue a ROLLBACK here instead. + server.mark_bad(); + } - return Err(err); + return Err(err); + } + } + } + + _ => { + let message = self.buffer.clone(); + self.buffer.clear(); + message } }; diff --git a/src/main.rs b/src/main.rs index 37d5933..ed9e738 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,6 @@ extern crate tokio; use tokio::net::TcpListener; -use bb8::Pool; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -35,12 +34,7 @@ mod server; // Support for query cancellation: this maps our process_ids and // secret keys to the backend's. use config::{Address, User}; -use pool::{ClientServerMap, ReplicaPool, ServerPool}; - -// -// Poor man's config -// -const POOL_SIZE: u32 = 15; +use pool::{ClientServerMap, ConnectionPool}; /// Main! #[tokio::main] @@ -71,7 +65,6 @@ async fn main() { port: "5432".to_string(), }, ]; - let num_addresses = addresses.len() as u32; let user = User { name: "lev".to_string(), @@ -80,9 +73,7 @@ async fn main() { let database = "lev"; - let replica_pool = ReplicaPool::new(addresses).await; - let manager = ServerPool::new(replica_pool, user, database, client_server_map.clone()); - + let pool = ConnectionPool::new(addresses, user, database, client_server_map.clone()).await; // We are round-robining, so ideally the replicas will be equally loaded. // Therefore, we are allocating number of replicas * pool size of connections. // However, if a replica dies, the remaining replicas will share the burden, @@ -91,11 +82,6 @@ async fn main() { // Note that failover in this case could bring down the remaining replicas, so // in certain situations, e.g. when replicas are running hot already, failover // is not at all desirable!! - let pool = Pool::builder() - .max_size(POOL_SIZE * num_addresses) - .build(manager) - .await - .unwrap(); loop { let pool = pool.clone(); diff --git a/src/pool.rs b/src/pool.rs index e68d6c8..049fb2e 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,6 +1,6 @@ /// Pooling and failover and banlist. use async_trait::async_trait; -use bb8::{ManageConnection, PooledConnection}; +use bb8::{ManageConnection, Pool, PooledConnection}; use chrono::naive::NaiveDateTime; use crate::config::{Address, User}; @@ -21,112 +21,119 @@ pub type ClientServerMap = Arc ServerPool { - ServerPool { - replica_pool: replica_pool, - user: user, - database: database.to_string(), - client_server_map: client_server_map, - } - } -} - -#[async_trait] -impl ManageConnection for ServerPool { - type Connection = Server; - type Error = Error; - - /// Attempts to create a new connection. - async fn connect(&self) -> Result { - println!(">> Getting new connection from the pool"); - let address = self.replica_pool.get(); - - match Server::startup( - &address.host, - &address.port, - &self.user.name, - &self.user.password, - &self.database, - self.client_server_map.clone(), - ) - .await - { - Ok(server) => { - self.replica_pool.unban(&address); - Ok(server) - } - Err(err) => { - self.replica_pool.ban(&address); - Err(err) - } - } - } - - /// Determines if the connection is still connected to the database. - async fn is_valid(&self, conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { - let server = &mut *conn; - - // Client disconnected before cleaning up - if server.in_transaction() { - return Err(Error::DirtyServer); - } - - // If this fails, the connection will be closed and another will be grabbed from the pool quietly :-). - // Failover, step 1, complete. - match tokio::time::timeout( - tokio::time::Duration::from_millis(1000), - server.query("SELECT 1"), - ) - .await - { - Ok(_) => Ok(()), - Err(_err) => { - println!(">> Unhealthy!"); - self.replica_pool.ban(&server.address()); - Err(Error::ServerTimeout) - } - } - } - - /// Synchronously determine if the connection is no longer usable, if possible. - fn has_broken(&self, conn: &mut Self::Connection) -> bool { - conn.is_bad() - } -} - -/// A collection of addresses, which could either be a single primary, -/// many sharded primaries or replicas. #[derive(Clone)] -pub struct ReplicaPool { +pub struct ConnectionPool { + databases: Vec>, addresses: Vec
, round_robin: Counter, banlist: BanList, } -impl ReplicaPool { - /// Create a new replica pool. Addresses must be known in advance. - pub async fn new(addresses: Vec
) -> ReplicaPool { - ReplicaPool { +impl ConnectionPool { + pub async fn new( + addresses: Vec
, + user: User, + database: &str, + client_server_map: ClientServerMap, + ) -> ConnectionPool { + let mut databases = Vec::new(); + + for address in &addresses { + let manager = ServerPool::new( + address.clone(), + user.clone(), + database, + client_server_map.clone(), + ); + let pool = Pool::builder() + .max_size(POOL_SIZE) + .connection_timeout(std::time::Duration::from_millis(5000)) + .test_on_check_out(false) + .build(manager) + .await + .unwrap(); + + databases.push(pool); + } + + ConnectionPool { + databases: databases, addresses: addresses, round_robin: Arc::new(AtomicUsize::new(0)), banlist: Arc::new(Mutex::new(HashMap::new())), } } + /// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded. + pub async fn get( + &self, + index: Option, + ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { + match index { + // Asking for a specific database, must be sharded. + // No failover here. + Some(index) => { + assert!(index < self.databases.len()); + match self.databases[index].get().await { + Ok(conn) => Ok((conn, self.addresses[index].clone())), + Err(err) => { + println!(">> Shard {} down: {:?}", index, err); + Err(Error::ServerTimeout) + } + } + } + + // Any database is fine, we're using round-robin here. + // Failover included if the server doesn't answer a health check. + None => { + loop { + let index = + self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len(); + let address = self.addresses[index].clone(); + + if self.is_banned(&address) { + continue; + } + + // Check if we can connect + let mut conn = match self.databases[index].get().await { + Ok(conn) => conn, + Err(err) => { + println!(">> Banning replica {}, error: {:?}", index, err); + self.ban(&address); + continue; + } + }; + + // Check if this server is alive with a health check + let server = &mut *conn; + + match tokio::time::timeout( + tokio::time::Duration::from_millis(1000), + server.query("SELECT 1"), + ) + .await + { + Ok(_) => return Ok((conn, address)), + Err(_) => { + println!( + ">> Banning replica {} because of failed health check", + index + ); + self.ban(&address); + continue; + } + } + } + } + } + } + /// Ban an address (i.e. replica). It no longer will serve /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. @@ -150,7 +157,7 @@ impl ReplicaPool { let mut guard = self.banlist.lock().unwrap(); // Everything is banned, nothig is banned - if guard.len() == self.addresses.len() { + if guard.len() == self.databases.len() { guard.clear(); drop(guard); println!(">> Unbanning all replicas."); @@ -173,22 +180,58 @@ impl ReplicaPool { None => false, } } +} - /// Get a replica to route the query to. - /// Will attempt to fetch a healthy replica. It will also - /// round-robin them for reasonably equal load. Round-robin is done - /// per transaction. - pub fn get(&self) -> Address { - loop { - // We'll never hit a 64-bit overflow right....right? :-) - let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len(); +pub struct ServerPool { + address: Address, + user: User, + database: String, + client_server_map: ClientServerMap, +} - let address = &self.addresses[index]; - if !self.is_banned(address) { - return address.clone(); - } else { - continue; - } +impl ServerPool { + pub fn new( + address: Address, + user: User, + database: &str, + client_server_map: ClientServerMap, + ) -> ServerPool { + ServerPool { + address: address, + user: user, + database: database.to_string(), + client_server_map: client_server_map, } } } + +#[async_trait] +impl ManageConnection for ServerPool { + type Connection = Server; + type Error = Error; + + /// Attempts to create a new connection. + async fn connect(&self) -> Result { + println!(">> Getting new connection from the pool"); + + Server::startup( + &self.address.host, + &self.address.port, + &self.user.name, + &self.user.password, + &self.database, + self.client_server_map.clone(), + ) + .await + } + + /// Determines if the connection is still connected to the database. + async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { + Ok(()) + } + + /// Synchronously determine if the connection is no longer usable, if possible. + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.is_bad() + } +}