mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-28 03:06:29 +00:00
replica pool & banlist
This commit is contained in:
114
src/pool.rs
114
src/pool.rs
@@ -1,9 +1,21 @@
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, PooledConnection};
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::errors::Error;
|
||||
use crate::server::Server;
|
||||
use crate::ClientServerMap;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
};
|
||||
|
||||
// Banlist: bad servers go in here.
|
||||
pub type BanList = Arc<Mutex<HashMap<Address, NaiveDateTime>>>;
|
||||
pub type Counter = Arc<AtomicUsize>;
|
||||
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
|
||||
|
||||
pub struct ServerPool {
|
||||
host: String,
|
||||
@@ -84,3 +96,101 @@ impl ManageConnection for ServerPool {
|
||||
conn.is_bad()
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection of servers, which could either be a single primary,
|
||||
/// many sharded primaries or replicas.
|
||||
#[derive(Clone)]
|
||||
pub struct ReplicaPool {
|
||||
replicas: Vec<Pool<ServerPool>>,
|
||||
addresses: Vec<Address>,
|
||||
// user: User,
|
||||
round_robin: Counter,
|
||||
banlist: BanList,
|
||||
}
|
||||
|
||||
impl ReplicaPool {
|
||||
pub async fn new(
|
||||
addresses: Vec<Address>,
|
||||
user: User,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
) -> ReplicaPool {
|
||||
let mut replicas = Vec::new();
|
||||
|
||||
for address in &addresses {
|
||||
let client_server_map = client_server_map.clone();
|
||||
|
||||
let manager = ServerPool::new(
|
||||
&address.host,
|
||||
&address.port,
|
||||
&user.name,
|
||||
&user.password,
|
||||
database,
|
||||
client_server_map,
|
||||
);
|
||||
|
||||
let pool = Pool::builder().max_size(15).build(manager).await.unwrap();
|
||||
|
||||
replicas.push(pool);
|
||||
}
|
||||
|
||||
ReplicaPool {
|
||||
addresses: addresses,
|
||||
replicas: replicas,
|
||||
// user: user,
|
||||
round_robin: Arc::new(AtomicUsize::new(0)),
|
||||
banlist: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ban(&mut self, address: &Address) {
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let mut guard = self.banlist.lock().unwrap();
|
||||
guard.insert(address.clone(), now);
|
||||
}
|
||||
|
||||
pub fn unban(&mut self, address: &Address) {
|
||||
let mut guard = self.banlist.lock().unwrap();
|
||||
guard.remove(address);
|
||||
}
|
||||
|
||||
pub fn is_banned(&self, address: &Address) -> bool {
|
||||
let mut guard = self.banlist.lock().unwrap();
|
||||
|
||||
// Everything is banned, nothig is banned
|
||||
if guard.len() == self.addresses.len() {
|
||||
guard.clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
// I expect this to miss 99.9999% of the time.
|
||||
match guard.get(address) {
|
||||
Some(timestamp) => {
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
if now.timestamp() - timestamp.timestamp() > 60 {
|
||||
// 1 minute
|
||||
guard.remove(address);
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&mut self) -> (Address, Pool<ServerPool>) {
|
||||
loop {
|
||||
// We'll never hit a 64-bit overflow right....right? :-)
|
||||
let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len();
|
||||
|
||||
let address = &self.addresses[index];
|
||||
if !self.is_banned(address) {
|
||||
return (address.clone(), self.replicas[index].clone());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user