From a6574acbc36d92282dbb63453307937c765e15f9 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 5 Feb 2022 10:02:13 -0800 Subject: [PATCH] replica pool & banlist --- Cargo.lock | 54 +++++++++++++++++++++++- Cargo.toml | 3 +- src/client.rs | 21 ++++++---- src/config.rs | 24 +++++++++++ src/main.rs | 49 ++++++++++++++-------- src/pool.rs | 114 +++++++++++++++++++++++++++++++++++++++++++++++++- src/server.rs | 14 ++++++- 7 files changed, 246 insertions(+), 33 deletions(-) create mode 100644 src/config.rs diff --git a/Cargo.lock b/Cargo.lock index 9f87edc..3bb02fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,6 +13,12 @@ dependencies = [ "syn", ] +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + [[package]] name = "bb8" version = "0.7.1" @@ -53,6 +59,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" +dependencies = [ + "libc", + "num-integer", + "num-traits", + "time", + "winapi", +] + [[package]] name = "crypto-common" version = "0.1.1" @@ -217,6 +236,25 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-integer" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.13.1" @@ -301,6 +339,7 @@ dependencies = [ "async-trait", "bb8", "bytes", + "chrono", "md-5", "rand", "tokio", @@ -393,6 +432,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "time" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db9e6914ab8b1ae1c260a4ae7a49b6c5611b40328a735b21862567685e73255" +dependencies = [ + "libc", + "wasi", + "winapi", +] + [[package]] name = "tokio" version = "1.16.1" @@ -443,9 +493,9 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" +version = "0.10.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" [[package]] name = "winapi" diff --git a/Cargo.toml b/Cargo.toml index 28a079d..52115f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,5 @@ bytes = "1" md-5 = "*" bb8 = "*" async-trait = "*" -rand = "*" \ No newline at end of file +rand = "*" +chrono = "0.4" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 25f5ab2..2d7c974 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,18 +1,16 @@ /// Implementation of the PostgreSQL client. /// We are pretending to the server in this scenario, /// and this module implements that. -use bb8::Pool; use bytes::{Buf, BufMut, BytesMut}; use rand::{distributions::Alphanumeric, Rng}; -use tokio::io::{AsyncReadExt, BufReader, Interest}; +use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use crate::errors::Error; use crate::messages::*; -use crate::pool::ServerPool; +use crate::pool::{ClientServerMap, ReplicaPool}; use crate::server::Server; -use crate::ClientServerMap; /// The client state. pub struct Client { @@ -123,19 +121,24 @@ impl Client { } /// Client loop. We handle all messages between the client and the database here. - pub async fn handle(&mut self, pool: Pool) -> Result<(), Error> { + pub async fn handle(&mut self, mut pool: ReplicaPool) -> Result<(), Error> { // Special: cancelling existing running query if self.cancel_mode { - let (process_id, secret_key) = { + let (process_id, secret_key, address, port) = { let guard = self.client_server_map.lock().unwrap(); match guard.get(&(self.process_id, self.secret_key)) { - Some((process_id, secret_key)) => (process_id.clone(), secret_key.clone()), + Some((process_id, secret_key, address, port)) => ( + process_id.clone(), + secret_key.clone(), + address.clone(), + port.clone(), + ), None => return Ok(()), } }; // TODO: pass actual server host and port somewhere. - return Ok(Server::cancel("127.0.0.1", "5432", process_id, secret_key).await?); + return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } loop { @@ -147,7 +150,7 @@ impl Client { Ok(_) => (), Err(_) => return Err(Error::ClientDisconnected), }; - + let pool = pool.get().1; let mut proxy = pool.get().await.unwrap(); let server = &mut *proxy; diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..c0983a1 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,24 @@ +#[derive(Clone, PartialEq, Hash, std::cmp::Eq)] +pub struct Address { + pub host: String, + pub port: String, +} + +#[derive(Clone, PartialEq, Hash, std::cmp::Eq)] +pub struct User { + pub name: String, + pub password: String, +} + +// #[derive(Clone)] +// pub struct Config { +// pools: HashMap>, +// } + +// impl Config { +// pub fn new() -> Config { +// Config { +// pools: HashMap::new(), +// } +// } +// } diff --git a/src/main.rs b/src/main.rs index b53dee4..ad62664 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,25 +4,29 @@ extern crate bytes; extern crate md5; extern crate tokio; -use bb8::Pool; use tokio::net::TcpListener; use std::collections::HashMap; use std::sync::{Arc, Mutex}; mod client; +mod config; mod errors; mod messages; mod pool; mod server; -type ClientServerMap = Arc>>; +// Support for query cancellation: this maps our process_ids and +// secret keys to the backend's. +use config::{Address, User}; +use pool::{ClientServerMap, ReplicaPool}; #[tokio::main] async fn main() { - println!("> Welcome to PgRabbit"); + println!("> Welcome to PgCat! Meow."); - let listener = match TcpListener::bind("0.0.0.0:5433").await { + let addr = "0.0.0.0:5433"; + let listener = match TcpListener::bind(addr).await { Ok(sock) => sock, Err(err) => { println!("> Error: {:?}", err); @@ -30,20 +34,32 @@ async fn main() { } }; + println!("> Running on {}", addr); + let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); - let manager = pool::ServerPool::new( - "127.0.0.1", - "5432", - "lev", - "lev", - "lev", - client_server_map.clone(), - ); - let pool = Pool::builder().max_size(15).build(manager).await.unwrap(); + + // Note in the logs that it will fetch two connections! + let addresses = vec![ + Address { + host: "127.0.0.1".to_string(), + port: "5432".to_string(), + }, + Address { + host: "localhost".to_string(), + port: "5432".to_string(), + }, + ]; + + let user = User { + name: "lev".to_string(), + password: "lev".to_string(), + }; + + let replica_pool = ReplicaPool::new(addresses, user, "lev", client_server_map.clone()).await; loop { - let pool = pool.clone(); let client_server_map = client_server_map.clone(); + let replica_pool = replica_pool.clone(); let (socket, addr) = match listener.accept().await { Ok((socket, addr)) => (socket, addr), @@ -57,14 +73,11 @@ async fn main() { tokio::task::spawn(async move { println!(">> Client {:?} connected.", addr); - let pool = pool.clone(); - let client_server_map = client_server_map.clone(); - match client::Client::startup(socket, client_server_map).await { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); - match client.handle(pool).await { + match client.handle(replica_pool).await { Ok(()) => { println!(">> Client {:?} disconnected.", addr); } diff --git a/src/pool.rs b/src/pool.rs index 65b37c6..cac9fb5 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,9 +1,21 @@ use async_trait::async_trait; -use bb8::{ManageConnection, PooledConnection}; +use bb8::{ManageConnection, Pool, PooledConnection}; +use chrono::naive::NaiveDateTime; +use crate::config::{Address, User}; use crate::errors::Error; use crate::server::Server; -use crate::ClientServerMap; + +use std::collections::HashMap; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, +}; + +// Banlist: bad servers go in here. +pub type BanList = Arc>>; +pub type Counter = Arc; +pub type ClientServerMap = Arc>>; pub struct ServerPool { host: String, @@ -84,3 +96,101 @@ impl ManageConnection for ServerPool { conn.is_bad() } } + +/// A collection of servers, which could either be a single primary, +/// many sharded primaries or replicas. +#[derive(Clone)] +pub struct ReplicaPool { + replicas: Vec>, + addresses: Vec
, + // user: User, + round_robin: Counter, + banlist: BanList, +} + +impl ReplicaPool { + pub async fn new( + addresses: Vec
, + user: User, + database: &str, + client_server_map: ClientServerMap, + ) -> ReplicaPool { + let mut replicas = Vec::new(); + + for address in &addresses { + let client_server_map = client_server_map.clone(); + + let manager = ServerPool::new( + &address.host, + &address.port, + &user.name, + &user.password, + database, + client_server_map, + ); + + let pool = Pool::builder().max_size(15).build(manager).await.unwrap(); + + replicas.push(pool); + } + + ReplicaPool { + addresses: addresses, + replicas: replicas, + // user: user, + round_robin: Arc::new(AtomicUsize::new(0)), + banlist: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub fn ban(&mut self, address: &Address) { + let now = chrono::offset::Utc::now().naive_utc(); + let mut guard = self.banlist.lock().unwrap(); + guard.insert(address.clone(), now); + } + + pub fn unban(&mut self, address: &Address) { + let mut guard = self.banlist.lock().unwrap(); + guard.remove(address); + } + + pub fn is_banned(&self, address: &Address) -> bool { + let mut guard = self.banlist.lock().unwrap(); + + // Everything is banned, nothig is banned + if guard.len() == self.addresses.len() { + guard.clear(); + return false; + } + + // I expect this to miss 99.9999% of the time. + match guard.get(address) { + Some(timestamp) => { + let now = chrono::offset::Utc::now().naive_utc(); + if now.timestamp() - timestamp.timestamp() > 60 { + // 1 minute + guard.remove(address); + false + } else { + true + } + } + + None => false, + } + } + + pub fn get(&mut self) -> (Address, Pool) { + loop { + // We'll never hit a 64-bit overflow right....right? :-) + let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len(); + + let address = &self.addresses[index]; + if !self.is_banned(address) { + return (address.clone(), self.replicas[index].clone()); + } else { + continue; + } + } + } +} diff --git a/src/server.rs b/src/server.rs index 4b280b1..fb7b85d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,6 +14,8 @@ use crate::ClientServerMap; /// Server state. pub struct Server { + host: String, + port: String, read: BufReader, write: OwnedWriteHalf, buffer: BytesMut, @@ -138,6 +140,8 @@ impl Server { let (read, write) = stream.into_split(); return Ok(Server { + host: host.to_string(), + port: port.to_string(), read: BufReader::new(read), write: write, buffer: BytesMut::with_capacity(8196), @@ -308,7 +312,15 @@ impl Server { /// Claim this server as mine for the purposes of query cancellation. pub fn claim(&mut self, process_id: i32, secret_key: i32) { let mut guard = self.client_server_map.lock().unwrap(); - guard.insert((process_id, secret_key), (self.backend_id, self.secret_key)); + guard.insert( + (process_id, secret_key), + ( + self.backend_id, + self.secret_key, + self.host.clone(), + self.port.clone(), + ), + ); } /// Execute an arbitrary query against the server.