This commit is contained in:
Lev Kokotov
2022-02-05 19:43:48 -08:00
parent e0ca175129
commit b943ff3fa6
5 changed files with 109 additions and 56 deletions

21
Cargo.lock generated
View File

@@ -72,6 +72,15 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "cpufeatures"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "crypto-common" name = "crypto-common"
version = "0.1.1" version = "0.1.1"
@@ -342,6 +351,7 @@ dependencies = [
"chrono", "chrono",
"md-5", "md-5",
"rand", "rand",
"sha-1",
"tokio", "tokio",
] ]
@@ -400,6 +410,17 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "sha-1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.0" version = "1.4.0"

View File

@@ -13,3 +13,4 @@ bb8 = "*"
async-trait = "*" async-trait = "*"
rand = "*" rand = "*"
chrono = "0.4" chrono = "0.4"
sha-1 = "*"

View File

@@ -30,6 +30,7 @@ mod errors;
mod messages; mod messages;
mod pool; mod pool;
mod server; mod server;
mod sharding;
// Support for query cancellation: this maps our process_ids and // Support for query cancellation: this maps our process_ids and
// secret keys to the backend's. // secret keys to the backend's.
@@ -41,7 +42,7 @@ use pool::{ClientServerMap, ConnectionPool};
async fn main() { async fn main() {
println!("> Welcome to PgCat! Meow."); println!("> Welcome to PgCat! Meow.");
let addr = "0.0.0.0:5433"; let addr = "0.0.0.0:6432";
let listener = match TcpListener::bind(addr).await { let listener = match TcpListener::bind(addr).await {
Ok(sock) => sock, Ok(sock) => sock,
Err(err) => { Err(err) => {

View File

@@ -28,8 +28,8 @@ const POOL_SIZE: u32 = 15;
#[derive(Clone)] #[derive(Clone)]
pub struct ConnectionPool { pub struct ConnectionPool {
databases: Vec<Pool<ServerPool>>, databases: Vec<Vec<Pool<ServerPool>>>,
addresses: Vec<Address>, addresses: Vec<Vec<Address>>,
round_robin: Counter, round_robin: Counter,
banlist: BanList, banlist: BanList,
} }
@@ -62,8 +62,8 @@ impl ConnectionPool {
} }
ConnectionPool { ConnectionPool {
databases: databases, databases: vec![databases],
addresses: addresses, addresses: vec![addresses],
round_robin: Arc::new(AtomicUsize::new(0)), round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(HashMap::new())), banlist: Arc::new(Mutex::new(HashMap::new())),
} }
@@ -72,36 +72,27 @@ impl ConnectionPool {
/// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded. /// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded.
pub async fn get( pub async fn get(
&self, &self,
index: Option<usize>, shard: Option<usize>,
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
match index { // Set this to false to gain ~3-4% speed.
// Asking for a specific database, must be sharded. let with_health_check = true;
// No failover here.
Some(index) => { let shard = match shard {
assert!(index < self.databases.len()); Some(shard) => shard,
match self.databases[index].get().await { None => 0, // TODO: pick a shard at random
Ok(conn) => Ok((conn, self.addresses[index].clone())), };
Err(err) => {
println!(">> Shard {} down: {:?}", index, err);
Err(Error::ServerTimeout)
}
}
}
// Any database is fine, we're using round-robin here.
// Failover included if the server doesn't answer a health check.
None => {
loop { loop {
let index = let index =
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len(); self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
let address = self.addresses[index].clone(); let address = self.addresses[shard][index].clone();
if self.is_banned(&address) { if self.is_banned(&address) {
continue; continue;
} }
// Check if we can connect // Check if we can connect
let mut conn = match self.databases[index].get().await { let mut conn = match self.databases[shard][index].get().await {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
println!(">> Banning replica {}, error: {:?}", index, err); println!(">> Banning replica {}, error: {:?}", index, err);
@@ -110,7 +101,11 @@ impl ConnectionPool {
} }
}; };
// Check if this server is alive with a health check if !with_health_check {
return Ok((conn, address));
}
// // Check if this server is alive with a health check
let server = &mut *conn; let server = &mut *conn;
match tokio::time::timeout( match tokio::time::timeout(
@@ -131,8 +126,6 @@ impl ConnectionPool {
} }
} }
} }
}
}
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica /// traffic for any new transactions. Existing transactions on that replica
@@ -180,6 +173,10 @@ impl ConnectionPool {
None => false, None => false,
} }
} }
pub fn shards(&self) -> usize {
self.databases.len()
}
} }
pub struct ServerPool { pub struct ServerPool {

33
src/sharding.rs Normal file
View File

@@ -0,0 +1,33 @@
use sha1::{Digest, Sha1};
pub struct Sharder {
shards: usize,
}
impl Sharder {
pub fn new(shards: usize) -> Sharder {
Sharder { shards: shards }
}
pub fn sha1(&self, key: &[u8]) -> usize {
let mut hasher = Sha1::new();
hasher.update(key);
let result = hasher.finalize_reset();
let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap());
i as usize % self.shards
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_sha1() {
let sharder = Sharder::new(12);
let key = b"1234";
let shard = sharder.sha1(key);
assert_eq!(shard, 1);
}
}