mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
sharding
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -72,6 +72,15 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.1"
|
||||
@@ -342,6 +351,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"md-5",
|
||||
"rand",
|
||||
"sha-1",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@@ -400,6 +410,17 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.0"
|
||||
|
||||
@@ -13,3 +13,4 @@ bb8 = "*"
|
||||
async-trait = "*"
|
||||
rand = "*"
|
||||
chrono = "0.4"
|
||||
sha-1 = "*"
|
||||
@@ -30,6 +30,7 @@ mod errors;
|
||||
mod messages;
|
||||
mod pool;
|
||||
mod server;
|
||||
mod sharding;
|
||||
|
||||
// Support for query cancellation: this maps our process_ids and
|
||||
// secret keys to the backend's.
|
||||
@@ -41,7 +42,7 @@ use pool::{ClientServerMap, ConnectionPool};
|
||||
async fn main() {
|
||||
println!("> Welcome to PgCat! Meow.");
|
||||
|
||||
let addr = "0.0.0.0:5433";
|
||||
let addr = "0.0.0.0:6432";
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
Ok(sock) => sock,
|
||||
Err(err) => {
|
||||
|
||||
107
src/pool.rs
107
src/pool.rs
@@ -28,8 +28,8 @@ const POOL_SIZE: u32 = 15;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectionPool {
|
||||
databases: Vec<Pool<ServerPool>>,
|
||||
addresses: Vec<Address>,
|
||||
databases: Vec<Vec<Pool<ServerPool>>>,
|
||||
addresses: Vec<Vec<Address>>,
|
||||
round_robin: Counter,
|
||||
banlist: BanList,
|
||||
}
|
||||
@@ -62,8 +62,8 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
ConnectionPool {
|
||||
databases: databases,
|
||||
addresses: addresses,
|
||||
databases: vec![databases],
|
||||
addresses: vec![addresses],
|
||||
round_robin: Arc::new(AtomicUsize::new(0)),
|
||||
banlist: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
@@ -72,63 +72,56 @@ impl ConnectionPool {
|
||||
/// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded.
|
||||
pub async fn get(
|
||||
&self,
|
||||
index: Option<usize>,
|
||||
shard: Option<usize>,
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
match index {
|
||||
// Asking for a specific database, must be sharded.
|
||||
// No failover here.
|
||||
Some(index) => {
|
||||
assert!(index < self.databases.len());
|
||||
match self.databases[index].get().await {
|
||||
Ok(conn) => Ok((conn, self.addresses[index].clone())),
|
||||
Err(err) => {
|
||||
println!(">> Shard {} down: {:?}", index, err);
|
||||
Err(Error::ServerTimeout)
|
||||
}
|
||||
}
|
||||
// Set this to false to gain ~3-4% speed.
|
||||
let with_health_check = true;
|
||||
|
||||
let shard = match shard {
|
||||
Some(shard) => shard,
|
||||
None => 0, // TODO: pick a shard at random
|
||||
};
|
||||
|
||||
loop {
|
||||
let index =
|
||||
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
|
||||
let address = self.addresses[shard][index].clone();
|
||||
|
||||
if self.is_banned(&address) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Any database is fine, we're using round-robin here.
|
||||
// Failover included if the server doesn't answer a health check.
|
||||
None => {
|
||||
loop {
|
||||
let index =
|
||||
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len();
|
||||
let address = self.addresses[index].clone();
|
||||
// Check if we can connect
|
||||
let mut conn = match self.databases[shard][index].get().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
println!(">> Banning replica {}, error: {:?}", index, err);
|
||||
self.ban(&address);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if self.is_banned(&address) {
|
||||
continue;
|
||||
}
|
||||
if !with_health_check {
|
||||
return Ok((conn, address));
|
||||
}
|
||||
|
||||
// Check if we can connect
|
||||
let mut conn = match self.databases[index].get().await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
println!(">> Banning replica {}, error: {:?}", index, err);
|
||||
self.ban(&address);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// // Check if this server is alive with a health check
|
||||
let server = &mut *conn;
|
||||
|
||||
// Check if this server is alive with a health check
|
||||
let server = &mut *conn;
|
||||
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(1000),
|
||||
server.query("SELECT 1"),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => return Ok((conn, address)),
|
||||
Err(_) => {
|
||||
println!(
|
||||
">> Banning replica {} because of failed health check",
|
||||
index
|
||||
);
|
||||
self.ban(&address);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(1000),
|
||||
server.query("SELECT 1"),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(_) => return Ok((conn, address)),
|
||||
Err(_) => {
|
||||
println!(
|
||||
">> Banning replica {} because of failed health check",
|
||||
index
|
||||
);
|
||||
self.ban(&address);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -180,6 +173,10 @@ impl ConnectionPool {
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shards(&self) -> usize {
|
||||
self.databases.len()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerPool {
|
||||
|
||||
33
src/sharding.rs
Normal file
33
src/sharding.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use sha1::{Digest, Sha1};
|
||||
|
||||
pub struct Sharder {
|
||||
shards: usize,
|
||||
}
|
||||
|
||||
impl Sharder {
|
||||
pub fn new(shards: usize) -> Sharder {
|
||||
Sharder { shards: shards }
|
||||
}
|
||||
|
||||
pub fn sha1(&self, key: &[u8]) -> usize {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(key);
|
||||
let result = hasher.finalize_reset();
|
||||
|
||||
let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap());
|
||||
i as usize % self.shards
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sha1() {
|
||||
let sharder = Sharder::new(12);
|
||||
let key = b"1234";
|
||||
let shard = sharder.sha1(key);
|
||||
assert_eq!(shard, 1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user