proper failover

This commit is contained in:
Lev Kokotov
2022-02-05 13:15:53 -08:00
parent a6574acbc3
commit 8479c74354
5 changed files with 81 additions and 65 deletions

View File

@@ -9,8 +9,9 @@ use tokio::net::TcpStream;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{ClientServerMap, ReplicaPool}; use crate::pool::{ClientServerMap, ServerPool};
use crate::server::Server; use crate::server::Server;
use bb8::Pool;
/// The client state. /// The client state.
pub struct Client { pub struct Client {
@@ -121,7 +122,7 @@ impl Client {
} }
/// Client loop. We handle all messages between the client and the database here. /// Client loop. We handle all messages between the client and the database here.
pub async fn handle(&mut self, mut pool: ReplicaPool) -> Result<(), Error> { pub async fn handle(&mut self, pool: Pool<ServerPool>) -> Result<(), Error> {
// Special: cancelling existing running query // Special: cancelling existing running query
if self.cancel_mode { if self.cancel_mode {
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
@@ -150,7 +151,6 @@ impl Client {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::ClientDisconnected), Err(_) => return Err(Error::ClientDisconnected),
}; };
let pool = pool.get().1;
let mut proxy = pool.get().await.unwrap(); let mut proxy = pool.get().await.unwrap();
let server = &mut *proxy; let server = &mut *proxy;

View File

@@ -1,4 +1,4 @@
#[derive(Clone, PartialEq, Hash, std::cmp::Eq)] #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address { pub struct Address {
pub host: String, pub host: String,
pub port: String, pub port: String,

View File

@@ -6,6 +6,7 @@ extern crate tokio;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use bb8::Pool;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@@ -19,8 +20,14 @@ mod server;
// Support for query cancellation: this maps our process_ids and // Support for query cancellation: this maps our process_ids and
// secret keys to the backend's. // secret keys to the backend's.
use config::{Address, User}; use config::{Address, User};
use pool::{ClientServerMap, ReplicaPool}; use pool::{ClientServerMap, ReplicaPool, ServerPool};
//
// Poor man's config
//
/// Number of pooled server connections budgeted per replica address.
/// `main` multiplies this by the number of configured addresses to get the
/// `max_size` of the single shared bb8 pool.
const POOL_SIZE: u32 = 15;
/// Main!
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
println!("> Welcome to PgCat! Meow."); println!("> Welcome to PgCat! Meow.");
@@ -38,7 +45,7 @@ async fn main() {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Note in the logs that it will fetch two connections! // Replica pool.
let addresses = vec![ let addresses = vec![
Address { Address {
host: "127.0.0.1".to_string(), host: "127.0.0.1".to_string(),
@@ -49,17 +56,35 @@ async fn main() {
port: "5432".to_string(), port: "5432".to_string(),
}, },
]; ];
let num_addresses = addresses.len() as u32;
let user = User { let user = User {
name: "lev".to_string(), name: "lev".to_string(),
password: "lev".to_string(), password: "lev".to_string(),
}; };
let replica_pool = ReplicaPool::new(addresses, user, "lev", client_server_map.clone()).await; let database = "lev";
let replica_pool = ReplicaPool::new(addresses).await;
let manager = ServerPool::new(replica_pool, user, database, client_server_map.clone());
// We are round-robining, so ideally the replicas will be equally loaded.
// Therefore, we are allocating number of replicas * pool size of connections.
// However, if a replica dies, the remaining replicas will share the burden,
// also equally.
//
// Note that failover in this case could bring down the remaining replicas, so
// in certain situations, e.g. when replicas are running hot already, failover
// is not at all desirable!!
let pool = Pool::builder()
.max_size(POOL_SIZE * num_addresses)
.build(manager)
.await
.unwrap();
loop { loop {
let pool = pool.clone();
let client_server_map = client_server_map.clone(); let client_server_map = client_server_map.clone();
let replica_pool = replica_pool.clone();
let (socket, addr) = match listener.accept().await { let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr), Ok((socket, addr)) => (socket, addr),
@@ -77,7 +102,7 @@ async fn main() {
Ok(mut client) => { Ok(mut client) => {
println!(">> Client {:?} authenticated successfully!", addr); println!(">> Client {:?} authenticated successfully!", addr);
match client.handle(replica_pool).await { match client.handle(pool).await {
Ok(()) => { Ok(()) => {
println!(">> Client {:?} disconnected.", addr); println!(">> Client {:?} disconnected.", addr);
} }

View File

@@ -1,5 +1,5 @@
use async_trait::async_trait; use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection}; use bb8::{ManageConnection, PooledConnection};
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use crate::config::{Address, User}; use crate::config::{Address, User};
@@ -18,28 +18,22 @@ pub type Counter = Arc<AtomicUsize>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub struct ServerPool { pub struct ServerPool {
host: String, replica_pool: ReplicaPool,
port: String, user: User,
user: String,
password: String,
database: String, database: String,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
} }
impl ServerPool { impl ServerPool {
pub fn new( pub fn new(
host: &str, replica_pool: ReplicaPool,
port: &str, user: User,
user: &str,
password: &str,
database: &str, database: &str,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
) -> ServerPool { ) -> ServerPool {
ServerPool { ServerPool {
host: host.to_string(), replica_pool: replica_pool,
port: port.to_string(), user: user,
user: user.to_string(),
password: password.to_string(),
database: database.to_string(), database: database.to_string(),
client_server_map: client_server_map, client_server_map: client_server_map,
} }
@@ -53,20 +47,28 @@ impl ManageConnection for ServerPool {
/// Attempts to create a new connection. /// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> { async fn connect(&self) -> Result<Self::Connection, Self::Error> {
println!(">> Getting connetion from pool"); let address = self.replica_pool.get();
// println!(">> Getting connetion from pool");
// match Server::startup(
// TODO: Pick a random connection from a replica pool here. &address.host,
// &address.port,
Ok(Server::startup( &self.user.name,
&self.host, &self.user.password,
&self.port,
&self.user,
&self.password,
&self.database, &self.database,
self.client_server_map.clone(), self.client_server_map.clone(),
) )
.await?) .await
{
Ok(server) => {
self.replica_pool.unban(&address);
Ok(server)
}
Err(err) => {
self.replica_pool.ban(&address);
Err(err)
}
}
} }
/// Determines if the connection is still connected to the database. /// Determines if the connection is still connected to the database.
@@ -87,7 +89,11 @@ impl ManageConnection for ServerPool {
.await .await
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_err) => Err(Error::ServerTimeout), Err(_err) => {
println!(">> Unhealthy!");
self.replica_pool.ban(&server.address());
Err(Error::ServerTimeout)
}
} }
} }
@@ -101,7 +107,7 @@ impl ManageConnection for ServerPool {
/// many sharded primaries or replicas. /// many sharded primaries or replicas.
#[derive(Clone)] #[derive(Clone)]
pub struct ReplicaPool { pub struct ReplicaPool {
replicas: Vec<Pool<ServerPool>>, // replicas: Vec<Pool<ServerPool>>,
addresses: Vec<Address>, addresses: Vec<Address>,
// user: User, // user: User,
round_robin: Counter, round_robin: Counter,
@@ -109,47 +115,22 @@ pub struct ReplicaPool {
} }
impl ReplicaPool { impl ReplicaPool {
pub async fn new( pub async fn new(addresses: Vec<Address>) -> ReplicaPool {
addresses: Vec<Address>,
user: User,
database: &str,
client_server_map: ClientServerMap,
) -> ReplicaPool {
let mut replicas = Vec::new();
for address in &addresses {
let client_server_map = client_server_map.clone();
let manager = ServerPool::new(
&address.host,
&address.port,
&user.name,
&user.password,
database,
client_server_map,
);
let pool = Pool::builder().max_size(15).build(manager).await.unwrap();
replicas.push(pool);
}
ReplicaPool { ReplicaPool {
addresses: addresses, addresses: addresses,
replicas: replicas,
// user: user,
round_robin: Arc::new(AtomicUsize::new(0)), round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(HashMap::new())), banlist: Arc::new(Mutex::new(HashMap::new())),
} }
} }
pub fn ban(&mut self, address: &Address) { pub fn ban(&self, address: &Address) {
println!(">> Banning {:?}", address);
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
guard.insert(address.clone(), now); guard.insert(address.clone(), now);
} }
pub fn unban(&mut self, address: &Address) { pub fn unban(&self, address: &Address) {
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
guard.remove(address); guard.remove(address);
} }
@@ -160,6 +141,8 @@ impl ReplicaPool {
// Everything is banned, nothing is banned // Everything is banned, nothing is banned
if guard.len() == self.addresses.len() { if guard.len() == self.addresses.len() {
guard.clear(); guard.clear();
drop(guard);
println!(">> Unbanning all");
return false; return false;
} }
@@ -180,14 +163,14 @@ impl ReplicaPool {
} }
} }
pub fn get(&mut self) -> (Address, Pool<ServerPool>) { pub fn get(&self) -> Address {
loop { loop {
// We'll never hit a 64-bit overflow right....right? :-) // We'll never hit a 64-bit overflow right....right? :-)
let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len(); let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len();
let address = &self.addresses[index]; let address = &self.addresses[index];
if !self.is_banned(address) { if !self.is_banned(address) {
return (address.clone(), self.replicas[index].clone()); return address.clone();
} else { } else {
continue; continue;
} }

View File

@@ -8,6 +8,7 @@ use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use crate::config::Address;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::ClientServerMap; use crate::ClientServerMap;
@@ -350,4 +351,11 @@ impl Server {
.query(&format!("SET application_name = '{}'", name)) .query(&format!("SET application_name = '{}'", name))
.await?) .await?)
} }
/// Builds an owned `Address` from this server's `host` and `port` fields,
/// converting each one with `to_string()`.
pub fn address(&self) -> Address {
    let host = self.host.to_string();
    let port = self.port.to_string();
    Address { host, port }
}
} }