diff --git a/Cargo.lock b/Cargo.lock index 3bb02fa..8cf1dcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,6 +72,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "cpufeatures" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" +dependencies = [ + "libc", +] + [[package]] name = "crypto-common" version = "0.1.1" @@ -342,6 +351,7 @@ dependencies = [ "chrono", "md-5", "rand", + "sha-1", "tokio", ] @@ -400,6 +410,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" diff --git a/Cargo.toml b/Cargo.toml index 3424891..bc8034a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,4 @@ bb8 = "*" async-trait = "*" rand = "*" chrono = "0.4" +sha-1 = "*" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index ed9e738..e2e4dd5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,6 +30,7 @@ mod errors; mod messages; mod pool; mod server; +mod sharding; // Support for query cancellation: this maps our process_ids and // secret keys to the backend's. @@ -41,7 +42,7 @@ use pool::{ClientServerMap, ConnectionPool}; async fn main() { println!("> Welcome to PgCat! Meow."); - let addr = "0.0.0.0:5433"; + let addr = "0.0.0.0:6432"; let listener = match TcpListener::bind(addr).await { Ok(sock) => sock, Err(err) => { diff --git a/src/pool.rs b/src/pool.rs index 049fb2e..e4097b9 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -28,8 +28,8 @@ const POOL_SIZE: u32 = 15; #[derive(Clone)] pub struct ConnectionPool { - databases: Vec>, - addresses: Vec
, + databases: Vec>>, + addresses: Vec>, round_robin: Counter, banlist: BanList, } @@ -62,8 +62,8 @@ impl ConnectionPool { } ConnectionPool { - databases: databases, - addresses: addresses, + databases: vec![databases], + addresses: vec![addresses], round_robin: Arc::new(AtomicUsize::new(0)), banlist: Arc::new(Mutex::new(HashMap::new())), } @@ -72,63 +72,56 @@ impl ConnectionPool { /// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded. pub async fn get( &self, - index: Option, + shard: Option, ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { - match index { - // Asking for a specific database, must be sharded. - // No failover here. - Some(index) => { - assert!(index < self.databases.len()); - match self.databases[index].get().await { - Ok(conn) => Ok((conn, self.addresses[index].clone())), - Err(err) => { - println!(">> Shard {} down: {:?}", index, err); - Err(Error::ServerTimeout) - } - } + // Set this to false to gain ~3-4% speed. + let with_health_check = true; + + let shard = match shard { + Some(shard) => shard, + None => 0, // TODO: pick a shard at random + }; + + loop { + let index = + self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len(); + let address = self.addresses[shard][index].clone(); + + if self.is_banned(&address) { + continue; } - // Any database is fine, we're using round-robin here. - // Failover included if the server doesn't answer a health check. - None => { - loop { - let index = - self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len(); - let address = self.addresses[index].clone(); + // Check if we can connect + let mut conn = match self.databases[shard][index].get().await { + Ok(conn) => conn, + Err(err) => { + println!(">> Banning replica {}, error: {:?}", index, err); + self.ban(&address); + continue; + } + }; - if self.is_banned(&address) { - continue; - } + if !with_health_check { + return Ok((conn, address)); + } - // Check if we can connect - let mut conn = match self.databases[index].get().await { - Ok(conn) => conn, - Err(err) => { - println!(">> Banning replica {}, error: {:?}", index, err); - self.ban(&address); - continue; - } - }; + // // Check if this server is alive with a health check + let server = &mut *conn; - // Check if this server is alive with a health check - let server = &mut *conn; - - match tokio::time::timeout( - tokio::time::Duration::from_millis(1000), - server.query("SELECT 1"), - ) - .await - { - Ok(_) => return Ok((conn, address)), - Err(_) => { - println!( - ">> Banning replica {} because of failed health check", - index - ); - self.ban(&address); - continue; - } - } + match tokio::time::timeout( + tokio::time::Duration::from_millis(1000), + server.query("SELECT 1"), + ) + .await + { + Ok(_) => return Ok((conn, address)), + Err(_) => { + println!( + ">> Banning replica {} because of failed health check", + index + ); + self.ban(&address); + continue; } } } @@ -180,6 +173,10 @@ impl ConnectionPool { None => false, } } + + pub fn shards(&self) -> usize { + self.databases.len() + } } pub struct ServerPool { diff --git a/src/sharding.rs b/src/sharding.rs new file mode 100644 index 0000000..bf976e5 --- /dev/null +++ b/src/sharding.rs @@ -0,0 +1,33 @@ +use sha1::{Digest, Sha1}; + +pub struct Sharder { + shards: usize, +} + +impl Sharder { + pub fn new(shards: usize) -> Sharder { + Sharder { shards: shards } + } + + pub fn sha1(&self, key: &[u8]) -> usize { + let mut hasher = Sha1::new(); + hasher.update(key); + let result = hasher.finalize_reset(); + + let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap()); + i as usize % self.shards + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_sha1() { + let sharder = Sharder::new(12); + let key = b"1234"; + let shard = sharder.sha1(key); + assert_eq!(shard, 1); + } +}