Files
pgcat/src/pool.rs

178 lines
5.0 KiB
Rust
Raw Normal View History

2022-02-03 16:25:05 -08:00
use async_trait::async_trait;
2022-02-05 13:15:53 -08:00
use bb8::{ManageConnection, PooledConnection};
2022-02-05 10:02:13 -08:00
use chrono::naive::NaiveDateTime;
2022-02-03 16:25:05 -08:00
2022-02-05 10:02:13 -08:00
use crate::config::{Address, User};
2022-02-03 16:25:05 -08:00
use crate::errors::Error;
use crate::server::Server;
2022-02-05 10:02:13 -08:00
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
};
/// Ban list shared by all clones of a `ReplicaPool`: bad servers go in here, keyed by
/// address and mapped to the (naive, UTC-derived) time at which they were banned.
pub type BanList = Arc<Mutex<HashMap<Address, NaiveDateTime>>>;
/// Shared atomic counter; `ReplicaPool` uses it as its round-robin cursor.
pub type Counter = Arc<AtomicUsize>;
/// Shared, mutex-protected map from `(i32, i32)` keys to `(i32, i32, String, String)`
/// values; a clone of it is handed to every server started by `ServerPool`.
/// NOTE(review): presumably client (process id, secret key) -> server (process id,
/// secret key, host, port) for query cancellation — confirm against `Server`.
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
2022-02-03 16:25:05 -08:00
pub struct ServerPool {
2022-02-05 13:15:53 -08:00
replica_pool: ReplicaPool,
user: User,
2022-02-03 16:25:05 -08:00
database: String,
2022-02-04 16:01:35 -08:00
client_server_map: ClientServerMap,
2022-02-03 16:25:05 -08:00
}
impl ServerPool {
2022-02-04 16:01:35 -08:00
pub fn new(
2022-02-05 13:15:53 -08:00
replica_pool: ReplicaPool,
user: User,
2022-02-04 16:01:35 -08:00
database: &str,
client_server_map: ClientServerMap,
) -> ServerPool {
2022-02-03 16:25:05 -08:00
ServerPool {
2022-02-05 13:15:53 -08:00
replica_pool: replica_pool,
user: user,
2022-02-03 16:25:05 -08:00
database: database.to_string(),
2022-02-04 16:01:35 -08:00
client_server_map: client_server_map,
2022-02-03 16:25:05 -08:00
}
}
}
#[async_trait]
impl ManageConnection for ServerPool {
type Connection = Server;
type Error = Error;
/// Attempts to create a new connection.
2022-02-03 16:25:05 -08:00
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
2022-02-05 13:19:50 -08:00
println!(">> Getting new connection from the pool");
2022-02-05 13:15:53 -08:00
let address = self.replica_pool.get();
match Server::startup(
&address.host,
&address.port,
&self.user.name,
&self.user.password,
&self.database,
2022-02-04 16:01:35 -08:00
self.client_server_map.clone(),
)
2022-02-05 13:15:53 -08:00
.await
{
Ok(server) => {
self.replica_pool.unban(&address);
Ok(server)
}
Err(err) => {
self.replica_pool.ban(&address);
Err(err)
}
}
2022-02-03 16:25:05 -08:00
}
/// Determines if the connection is still connected to the database.
2022-02-03 17:32:04 -08:00
async fn is_valid(&self, conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
let server = &mut *conn;
2022-02-03 18:02:50 -08:00
// Client disconnected before cleaning up
if server.in_transaction() {
return Err(Error::DirtyServer);
}
2022-02-03 17:32:04 -08:00
// If this fails, the connection will be closed and another will be grabbed from the pool quietly :-).
// Failover, step 1, complete.
match tokio::time::timeout(
tokio::time::Duration::from_millis(1000),
server.query("SELECT 1"),
)
.await
{
Ok(_) => Ok(()),
2022-02-05 13:15:53 -08:00
Err(_err) => {
println!(">> Unhealthy!");
self.replica_pool.ban(&server.address());
Err(Error::ServerTimeout)
}
2022-02-03 17:32:04 -08:00
}
2022-02-03 16:25:05 -08:00
}
/// Synchronously determine if the connection is no longer usable, if possible.
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.is_bad()
2022-02-03 16:25:05 -08:00
}
}
2022-02-05 10:02:13 -08:00
/// A collection of servers, which could either be a single primary,
/// many sharded primaries or replicas.
#[derive(Clone)]
pub struct ReplicaPool {
addresses: Vec<Address>,
round_robin: Counter,
banlist: BanList,
}
impl ReplicaPool {
2022-02-05 13:15:53 -08:00
pub async fn new(addresses: Vec<Address>) -> ReplicaPool {
2022-02-05 10:02:13 -08:00
ReplicaPool {
addresses: addresses,
round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(HashMap::new())),
}
}
2022-02-05 13:15:53 -08:00
pub fn ban(&self, address: &Address) {
println!(">> Banning {:?}", address);
2022-02-05 10:02:13 -08:00
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.lock().unwrap();
guard.insert(address.clone(), now);
}
2022-02-05 13:15:53 -08:00
pub fn unban(&self, address: &Address) {
2022-02-05 10:02:13 -08:00
let mut guard = self.banlist.lock().unwrap();
guard.remove(address);
}
pub fn is_banned(&self, address: &Address) -> bool {
let mut guard = self.banlist.lock().unwrap();
// Everything is banned, nothig is banned
if guard.len() == self.addresses.len() {
guard.clear();
2022-02-05 13:15:53 -08:00
drop(guard);
println!(">> Unbanning all");
2022-02-05 10:02:13 -08:00
return false;
}
// I expect this to miss 99.9999% of the time.
match guard.get(address) {
Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc();
if now.timestamp() - timestamp.timestamp() > 60 {
// 1 minute
guard.remove(address);
false
} else {
true
}
}
None => false,
}
}
2022-02-05 13:15:53 -08:00
pub fn get(&self) -> Address {
2022-02-05 10:02:13 -08:00
loop {
// We'll never hit a 64-bit overflow right....right? :-)
let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len();
let address = &self.addresses[index];
if !self.is_banned(address) {
2022-02-05 13:15:53 -08:00
return address.clone();
2022-02-05 10:02:13 -08:00
} else {
continue;
}
}
}
}