This commit is contained in:
Lev Kokotov
2022-02-05 19:43:48 -08:00
parent e0ca175129
commit b943ff3fa6
5 changed files with 109 additions and 56 deletions

21
Cargo.lock generated
View File

@@ -72,6 +72,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "cpufeatures"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469"
dependencies = [
"libc",
]
[[package]]
name = "crypto-common"
version = "0.1.1"
@@ -342,6 +351,7 @@ dependencies = [
"chrono",
"md-5",
"rand",
"sha-1",
"tokio",
]
@@ -400,6 +410,17 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "sha-1"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"

View File

@@ -13,3 +13,4 @@ bb8 = "*"
async-trait = "*"
rand = "*"
chrono = "0.4"
sha-1 = "*"

View File

@@ -30,6 +30,7 @@ mod errors;
mod messages;
mod pool;
mod server;
mod sharding;
// Support for query cancellation: this maps our process_ids and
// secret keys to the backend's.
@@ -41,7 +42,7 @@ use pool::{ClientServerMap, ConnectionPool};
async fn main() {
println!("> Welcome to PgCat! Meow.");
let addr = "0.0.0.0:5433";
let addr = "0.0.0.0:6432";
let listener = match TcpListener::bind(addr).await {
Ok(sock) => sock,
Err(err) => {

View File

@@ -28,8 +28,8 @@ const POOL_SIZE: u32 = 15;
#[derive(Clone)]
pub struct ConnectionPool {
databases: Vec<Pool<ServerPool>>,
addresses: Vec<Address>,
databases: Vec<Vec<Pool<ServerPool>>>,
addresses: Vec<Vec<Address>>,
round_robin: Counter,
banlist: BanList,
}
@@ -62,8 +62,8 @@ impl ConnectionPool {
}
ConnectionPool {
databases: databases,
addresses: addresses,
databases: vec![databases],
addresses: vec![addresses],
round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(HashMap::new())),
}
@@ -72,63 +72,56 @@ impl ConnectionPool {
/// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded.
pub async fn get(
&self,
index: Option<usize>,
shard: Option<usize>,
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
match index {
// Asking for a specific database, must be sharded.
// No failover here.
Some(index) => {
assert!(index < self.databases.len());
match self.databases[index].get().await {
Ok(conn) => Ok((conn, self.addresses[index].clone())),
Err(err) => {
println!(">> Shard {} down: {:?}", index, err);
Err(Error::ServerTimeout)
}
}
// Set this to false to gain ~3-4% speed.
let with_health_check = true;
let shard = match shard {
Some(shard) => shard,
None => 0, // TODO: pick a shard at random
};
loop {
let index =
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
let address = self.addresses[shard][index].clone();
if self.is_banned(&address) {
continue;
}
// Any database is fine, we're using round-robin here.
// Failover included if the server doesn't answer a health check.
None => {
loop {
let index =
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len();
let address = self.addresses[index].clone();
// Check if we can connect
let mut conn = match self.databases[shard][index].get().await {
Ok(conn) => conn,
Err(err) => {
println!(">> Banning replica {}, error: {:?}", index, err);
self.ban(&address);
continue;
}
};
if self.is_banned(&address) {
continue;
}
if !with_health_check {
return Ok((conn, address));
}
// Check if we can connect
let mut conn = match self.databases[index].get().await {
Ok(conn) => conn,
Err(err) => {
println!(">> Banning replica {}, error: {:?}", index, err);
self.ban(&address);
continue;
}
};
// // Check if this server is alive with a health check
let server = &mut *conn;
// Check if this server is alive with a health check
let server = &mut *conn;
match tokio::time::timeout(
tokio::time::Duration::from_millis(1000),
server.query("SELECT 1"),
)
.await
{
Ok(_) => return Ok((conn, address)),
Err(_) => {
println!(
">> Banning replica {} because of failed health check",
index
);
self.ban(&address);
continue;
}
}
match tokio::time::timeout(
tokio::time::Duration::from_millis(1000),
server.query("SELECT 1"),
)
.await
{
Ok(_) => return Ok((conn, address)),
Err(_) => {
println!(
">> Banning replica {} because of failed health check",
index
);
self.ban(&address);
continue;
}
}
}
@@ -180,6 +173,10 @@ impl ConnectionPool {
None => false,
}
}
pub fn shards(&self) -> usize {
self.databases.len()
}
}
pub struct ServerPool {

33
src/sharding.rs Normal file
View File

@@ -0,0 +1,33 @@
use sha1::{Digest, Sha1};
pub struct Sharder {
shards: usize,
}
impl Sharder {
pub fn new(shards: usize) -> Sharder {
Sharder { shards: shards }
}
pub fn sha1(&self, key: &[u8]) -> usize {
let mut hasher = Sha1::new();
hasher.update(key);
let result = hasher.finalize_reset();
let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap());
i as usize % self.shards
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_sha1() {
let sharder = Sharder::new(12);
let key = b"1234";
let shard = sharder.sha1(key);
assert_eq!(shard, 1);
}
}