correct load balancing

This commit is contained in:
Lev Kokotov
2022-02-05 18:20:53 -08:00
parent 9ac5614d50
commit e0ca175129
3 changed files with 187 additions and 145 deletions

View File

@@ -9,9 +9,8 @@ use tokio::net::TcpStream;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{ClientServerMap, ServerPool}; use crate::pool::{ClientServerMap, ConnectionPool};
use crate::server::Server; use crate::server::Server;
use bb8::Pool;
/// The client state. /// The client state.
pub struct Client { pub struct Client {
@@ -125,7 +124,7 @@ impl Client {
} }
/// Client loop. We handle all messages between the client and the database here. /// Client loop. We handle all messages between the client and the database here.
pub async fn handle(&mut self, pool: Pool<ServerPool>) -> Result<(), Error> { pub async fn handle(&mut self, pool: ConnectionPool) -> Result<(), Error> {
// Special: cancelling existing running query // Special: cancelling existing running query
if self.cancel_mode { if self.cancel_mode {
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
@@ -148,13 +147,17 @@ impl Client {
loop { loop {
// Only grab a connection once we have some traffic on the socket // Only grab a connection once we have some traffic on the socket
// TODO: this is not the most optimal way to share servers. // TODO: this is not the most optimal way to share servers.
let mut peek_buf = vec![0u8; 2]; // let mut peek_buf = vec![0u8; 2];
match self.read.get_mut().peek(&mut peek_buf).await { // match self.read.get_mut().peek(&mut peek_buf).await {
Ok(_) => (), // Ok(_) => (),
Err(_) => return Err(Error::ClientDisconnected), // Err(_) => return Err(Error::ClientDisconnected),
}; // };
let mut proxy = pool.get().await.unwrap(); let message = read_message(&mut self.read).await?;
self.buffer.put(message);
let mut proxy = pool.get(None).await.unwrap().0;
let server = &mut *proxy; let server = &mut *proxy;
// TODO: maybe don't do this, I don't think it's useful. // TODO: maybe don't do this, I don't think it's useful.
@@ -164,18 +167,28 @@ impl Client {
server.claim(self.process_id, self.secret_key); server.claim(self.process_id, self.secret_key);
loop { loop {
let mut message = match read_message(&mut self.read).await { let mut message = match self.buffer.len() {
Ok(message) => message, 0 => {
Err(err) => { match read_message(&mut self.read).await {
if server.in_transaction() { Ok(message) => message,
// TODO: this is what PgBouncer does Err(err) => {
// which leads to connection thrashing. if server.in_transaction() {
// // TODO: this is what PgBouncer does
// I think we could issue a ROLLBACK here instead. // which leads to connection thrashing.
server.mark_bad(); //
} // I think we could issue a ROLLBACK here instead.
server.mark_bad();
}
return Err(err); return Err(err);
}
}
}
_ => {
let message = self.buffer.clone();
self.buffer.clear();
message
} }
}; };

View File

@@ -21,7 +21,6 @@ extern crate tokio;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use bb8::Pool;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@@ -35,12 +34,7 @@ mod server;
// Support for query cancellation: this maps our process_ids and // Support for query cancellation: this maps our process_ids and
// secret keys to the backend's. // secret keys to the backend's.
use config::{Address, User}; use config::{Address, User};
use pool::{ClientServerMap, ReplicaPool, ServerPool}; use pool::{ClientServerMap, ConnectionPool};
//
// Poor man's config
//
const POOL_SIZE: u32 = 15;
/// Main! /// Main!
#[tokio::main] #[tokio::main]
@@ -71,7 +65,6 @@ async fn main() {
port: "5432".to_string(), port: "5432".to_string(),
}, },
]; ];
let num_addresses = addresses.len() as u32;
let user = User { let user = User {
name: "lev".to_string(), name: "lev".to_string(),
@@ -80,9 +73,7 @@ async fn main() {
let database = "lev"; let database = "lev";
let replica_pool = ReplicaPool::new(addresses).await; let pool = ConnectionPool::new(addresses, user, database, client_server_map.clone()).await;
let manager = ServerPool::new(replica_pool, user, database, client_server_map.clone());
// We are round-robining, so ideally the replicas will be equally loaded. // We are round-robining, so ideally the replicas will be equally loaded.
// Therefore, we are allocating number of replicas * pool size of connections. // Therefore, we are allocating number of replicas * pool size of connections.
// However, if a replica dies, the remaining replicas will share the burden, // However, if a replica dies, the remaining replicas will share the burden,
@@ -91,11 +82,6 @@ async fn main() {
// Note that failover in this case could bring down the remaining replicas, so // Note that failover in this case could bring down the remaining replicas, so
// in certain situations, e.g. when replicas are running hot already, failover // in certain situations, e.g. when replicas are running hot already, failover
// is not at all desirable!! // is not at all desirable!!
let pool = Pool::builder()
.max_size(POOL_SIZE * num_addresses)
.build(manager)
.await
.unwrap();
loop { loop {
let pool = pool.clone(); let pool = pool.clone();

View File

@@ -1,6 +1,6 @@
/// Pooling and failover and banlist. /// Pooling and failover and banlist.
use async_trait::async_trait; use async_trait::async_trait;
use bb8::{ManageConnection, PooledConnection}; use bb8::{ManageConnection, Pool, PooledConnection};
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use crate::config::{Address, User}; use crate::config::{Address, User};
@@ -21,112 +21,119 @@ pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, Stri
// 60 seconds of ban time. // 60 seconds of ban time.
// After that, the replica will be allowed to serve traffic again. // After that, the replica will be allowed to serve traffic again.
const BAN_TIME: i64 = 60; const BAN_TIME: i64 = 60;
//
// Poor man's config
//
/// Maximum number of server connections each per-address `bb8` pool may hold
/// (passed to `Pool::builder().max_size(..)` when the pools are built).
const POOL_SIZE: u32 = 15;
pub struct ServerPool {
replica_pool: ReplicaPool,
user: User,
database: String,
client_server_map: ClientServerMap,
}
impl ServerPool {
pub fn new(
replica_pool: ReplicaPool,
user: User,
database: &str,
client_server_map: ClientServerMap,
) -> ServerPool {
ServerPool {
replica_pool: replica_pool,
user: user,
database: database.to_string(),
client_server_map: client_server_map,
}
}
}
#[async_trait]
impl ManageConnection for ServerPool {
type Connection = Server;
type Error = Error;
/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
println!(">> Getting new connection from the pool");
let address = self.replica_pool.get();
match Server::startup(
&address.host,
&address.port,
&self.user.name,
&self.user.password,
&self.database,
self.client_server_map.clone(),
)
.await
{
Ok(server) => {
self.replica_pool.unban(&address);
Ok(server)
}
Err(err) => {
self.replica_pool.ban(&address);
Err(err)
}
}
}
/// Determines if the connection is still connected to the database.
async fn is_valid(&self, conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
let server = &mut *conn;
// Client disconnected before cleaning up
if server.in_transaction() {
return Err(Error::DirtyServer);
}
// If this fails, the connection will be closed and another will be grabbed from the pool quietly :-).
// Failover, step 1, complete.
match tokio::time::timeout(
tokio::time::Duration::from_millis(1000),
server.query("SELECT 1"),
)
.await
{
Ok(_) => Ok(()),
Err(_err) => {
println!(">> Unhealthy!");
self.replica_pool.ban(&server.address());
Err(Error::ServerTimeout)
}
}
}
/// Synchronously determine if the connection is no longer usable, if possible.
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_bad()
}
}
/// A collection of addresses, which could either be a single primary,
/// many sharded primaries or replicas.
#[derive(Clone)] #[derive(Clone)]
pub struct ReplicaPool { pub struct ConnectionPool {
databases: Vec<Pool<ServerPool>>,
addresses: Vec<Address>, addresses: Vec<Address>,
round_robin: Counter, round_robin: Counter,
banlist: BanList, banlist: BanList,
} }
impl ReplicaPool { impl ConnectionPool {
/// Create a new replica pool. Addresses must be known in advance. pub async fn new(
pub async fn new(addresses: Vec<Address>) -> ReplicaPool { addresses: Vec<Address>,
ReplicaPool { user: User,
database: &str,
client_server_map: ClientServerMap,
) -> ConnectionPool {
let mut databases = Vec::new();
for address in &addresses {
let manager = ServerPool::new(
address.clone(),
user.clone(),
database,
client_server_map.clone(),
);
let pool = Pool::builder()
.max_size(POOL_SIZE)
.connection_timeout(std::time::Duration::from_millis(5000))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
databases.push(pool);
}
ConnectionPool {
databases: databases,
addresses: addresses, addresses: addresses,
round_robin: Arc::new(AtomicUsize::new(0)), round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(HashMap::new())), banlist: Arc::new(Mutex::new(HashMap::new())),
} }
} }
/// Get a connection from the pool, together with the address it belongs to.
///
/// * `Some(index)` requests one specific database (e.g. a shard). There is no
///   failover and no health check on this path.
/// * `None` picks the next database in round-robin order, skipping banned
///   addresses. The chosen connection must answer `SELECT 1` within one second;
///   otherwise the address is banned and the next candidate is tried.
///
/// # Errors
///
/// Returns `Error::ServerTimeout` when a specific shard was requested and its
/// pool could not hand out a connection. The round-robin path does not return
/// an error; it keeps trying candidates until one passes the health check.
///
/// # Panics
///
/// Panics if `index` is out of range, or, on the round-robin path, if no
/// databases are configured (modulo by zero).
pub async fn get(
    &self,
    index: Option<usize>,
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
    match index {
        // Asking for a specific database, must be sharded.
        // No failover here.
        Some(index) => {
            assert!(index < self.databases.len());

            match self.databases[index].get().await {
                Ok(conn) => Ok((conn, self.addresses[index].clone())),
                Err(err) => {
                    println!(">> Shard {} down: {:?}", index, err);
                    Err(Error::ServerTimeout)
                }
            }
        }

        // Any database is fine, we're using round-robin here.
        // Failover included if the server doesn't answer a health check.
        None => loop {
            let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len();
            let address = self.addresses[index].clone();

            if self.is_banned(&address) {
                continue;
            }

            // Check if we can connect
            let mut conn = match self.databases[index].get().await {
                Ok(conn) => conn,
                Err(err) => {
                    println!(">> Banning replica {}, error: {:?}", index, err);
                    self.ban(&address);
                    continue;
                }
            };

            // Check if this server is alive with a health check.
            let health_check = tokio::time::timeout(
                tokio::time::Duration::from_millis(1000),
                conn.query("SELECT 1"),
            )
            .await;

            // The outer `Ok` only means the query finished before the timeout;
            // the inner result says whether the query itself succeeded. Both
            // must be `Ok` for the server to count as healthy.
            // NOTE(review): assumes `Server::query` returns a `Result` — confirm.
            if matches!(health_check, Ok(Ok(_))) {
                return Ok((conn, address));
            }

            println!(
                ">> Banning replica {} because of failed health check",
                index
            );

            // The connection may have a half-finished query on it. Mark it bad so
            // `has_broken` makes the pool discard it instead of reusing it.
            conn.mark_bad();
            self.ban(&address);
        },
    }
}
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica /// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients. /// will finish successfully or error out to the clients.
@@ -150,7 +157,7 @@ impl ReplicaPool {
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
// Everything is banned, nothing is banned // Everything is banned, nothing is banned
if guard.len() == self.addresses.len() { if guard.len() == self.databases.len() {
guard.clear(); guard.clear();
drop(guard); drop(guard);
println!(">> Unbanning all replicas."); println!(">> Unbanning all replicas.");
@@ -173,22 +180,58 @@ impl ReplicaPool {
None => false, None => false,
} }
} }
}
/// Get a replica to route the query to. pub struct ServerPool {
/// Will attempt to fetch a healthy replica. It will also address: Address,
/// round-robin them for reasonably equal load. Round-robin is done user: User,
/// per transaction. database: String,
pub fn get(&self) -> Address { client_server_map: ClientServerMap,
loop { }
// We'll never hit a 64-bit overflow right....right? :-)
let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len();
let address = &self.addresses[index]; impl ServerPool {
if !self.is_banned(address) { pub fn new(
return address.clone(); address: Address,
} else { user: User,
continue; database: &str,
} client_server_map: ClientServerMap,
) -> ServerPool {
ServerPool {
address: address,
user: user,
database: database.to_string(),
client_server_map: client_server_map,
} }
} }
} }
#[async_trait]
impl ManageConnection for ServerPool {
type Connection = Server;
type Error = Error;
/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
println!(">> Getting new connection from the pool");
Server::startup(
&self.address.host,
&self.address.port,
&self.user.name,
&self.user.password,
&self.database,
self.client_server_map.clone(),
)
.await
}
/// Determines if the connection is still connected to the database.
async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
Ok(())
}
/// Synchronously determine if the connection is no longer usable, if possible.
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_bad()
}
}