From c27a7d30dc44427322485ac0f21c1d102d07cd87 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 8 Feb 2022 09:25:59 -0800 Subject: [PATCH] config support; started more sharding --- Cargo.lock | 29 ++++++++++++++++ Cargo.toml | 5 ++- pgcat.toml | 60 +++++++++++++++++++++++++++++++++ src/client.rs | 5 +-- src/config.rs | 84 +++++++++++++++++++++++++++++++++++++++------- src/errors.rs | 1 + src/main.rs | 50 +++++++++++++++------------- src/pool.rs | 68 ++++++++++++++++++++++++++++++++++++-- src/server.rs | 22 +++++++++++++ src/sharding.rs | 88 +++++++++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 372 insertions(+), 40 deletions(-) create mode 100644 pgcat.toml diff --git a/Cargo.lock b/Cargo.lock index 00053f9..f4909f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -315,8 +315,11 @@ dependencies = [ "chrono", "md-5", "rand", + "serde", + "serde_derive", "sha-1", "tokio", + "toml", ] [[package]] @@ -410,6 +413,23 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "serde" +version = "1.0.136" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789" + +[[package]] +name = "serde_derive" +version = "1.0.136" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sha-1" version = "0.10.0" @@ -494,6 +514,15 @@ dependencies = [ "syn", ] +[[package]] +name = "toml" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa" +dependencies = [ + "serde", +] + [[package]] name = "typenum" version = "1.15.0" diff --git a/Cargo.toml b/Cargo.toml index d29a9c2..3cb0a6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,7 @@ bb8 = "*" async-trait = "*" rand = "*" chrono = "0.4" -sha-1 = "*" \ No newline at end of file +sha-1 = "*" +toml = "*" +serde = "*" +serde_derive = "*" \ No newline at end of file diff --git a/pgcat.toml b/pgcat.toml new file mode 100644 index 0000000..bb41049 --- /dev/null +++ b/pgcat.toml @@ -0,0 +1,60 @@ +# +# PgCat config example. +# + +# +# General pooler settings +[general] + +# What IP to run on, 0.0.0.0 means accessible from everywhere. +host = "0.0.0.0" + +# Port to run on, same as PgBouncer used in this example. +port = 6432 + +# How many connections to allocate per server. +pool_size = 15 + +# Pool mode (see PgBouncer docs for more). +# session: one server connection per connected client +# transaction: one server connection per client transaction +pool_mode = "transaction" + +# How long to wait before aborting a server connection (ms). +connect_timeout = 5000 + +# How much time to give `SELECT 1` health check query to return with a result (ms). +healthcheck_timeout = 1000 + +# For how long to ban a server if it fails a health check (seconds). +ban_time = 60 # Seconds + +# +# User to use for authentication against the server. +[user] +name = "lev" +password = "lev" + + +# +# Shards in the cluster +[shards] + +# Shard 0 +[shards.0] + +# [ host, port ] +servers = [ + [ "127.0.0.1", 5432 ], + [ "localhost", 5432 ], +] +# Database name (e.g. "postgres") +database = "lev" + +[shards.1] +# [ host, port ] +servers = [ + [ "127.0.0.1", 5432 ], + [ "localhost", 5432 ], +] +database = "lev" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 5502a3e..b7e4b15 100644 --- a/src/client.rs +++ b/src/client.rs @@ -48,6 +48,7 @@ impl Client { pub async fn startup( mut stream: TcpStream, client_server_map: ClientServerMap, + transaction_mode: bool, ) -> Result { loop { // Could be StartupMessage or SSLRequest @@ -100,7 +101,7 @@ impl Client { write: write, buffer: BytesMut::with_capacity(8196), cancel_mode: false, - transaction_mode: true, + transaction_mode: transaction_mode, process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, @@ -119,7 +120,7 @@ impl Client { write: write, buffer: BytesMut::with_capacity(8196), cancel_mode: true, - transaction_mode: true, + transaction_mode: transaction_mode, process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, diff --git a/src/config.rs b/src/config.rs index 736d6a5..bfc5619 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,24 +1,84 @@ +use serde_derive::Deserialize; +use std::collections::HashMap; +use std::path::Path; +use tokio::fs::File; +use tokio::io::AsyncReadExt; +use toml; + +use crate::errors::Error; + #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] pub struct Address { pub host: String, pub port: String, } -#[derive(Clone, PartialEq, Hash, std::cmp::Eq)] +#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)] pub struct User { pub name: String, pub password: String, } -// #[derive(Clone)] -// pub struct Config { -// pools: HashMap>, -// } +#[derive(Deserialize, Debug, Clone)] +pub struct General { + pub host: String, + pub port: i16, + pub pool_size: u32, + pub pool_mode: String, + pub connect_timeout: u64, + pub healthcheck_timeout: u64, + pub ban_time: i64, +} -// impl Config { -// pub fn new() -> Config { -// Config { -// pools: HashMap::new(), -// } -// } -// } +#[derive(Deserialize, Debug, Clone)] +pub struct Shard { + pub servers: Vec<(String, u16)>, + pub database: String, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct Config { + pub general: General, + pub user: User, + pub shards: HashMap, +} + +pub async fn parse(path: &str) -> Result { + // let path = Path::new(path); + let mut contents = String::new(); + let mut file = match File::open(path).await { + Ok(file) => file, + Err(err) => { + println!("> Config error: {:?}", err); + return Err(Error::BadConfig); + } + }; + + match file.read_to_string(&mut contents).await { + Ok(_) => (), + Err(err) => { + println!("> Config error: {:?}", err); + return Err(Error::BadConfig); + } + }; + + // let config: toml::Value = match toml::from_str(&contents) { + // Ok(config) => config, + // Err(err) => { + // println!("> Config error: {:?}", err); + // return Err(Error::BadConfig); + // } + // }; + + // println!("Config: {:?}", config); + + let config: Config = match toml::from_str(&contents) { + Ok(config) => config, + Err(err) => { + println!("> Config error: {:?}", err); + return Err(Error::BadConfig); + } + }; + + Ok(config) +} diff --git a/src/errors.rs b/src/errors.rs index cb7e756..5ee43e3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -7,4 +7,5 @@ pub enum Error { ServerError, ServerTimeout, DirtyServer, + BadConfig, } diff --git a/src/main.rs b/src/main.rs index dcab2b2..63972bf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,10 @@ extern crate async_trait; extern crate bb8; extern crate bytes; extern crate md5; +extern crate serde; +extern crate serde_derive; extern crate tokio; +extern crate toml; use tokio::net::TcpListener; @@ -42,8 +45,15 @@ use pool::{ClientServerMap, ConnectionPool}; async fn main() { println!("> Welcome to PgCat! Meow."); - let addr = "0.0.0.0:6432"; - let listener = match TcpListener::bind(addr).await { + let config = match config::parse("pgcat.toml").await { + Ok(config) => config, + Err(err) => { + return; + } + }; + + let addr = format!("{}:{}", config.general.host, config.general.port); + let listener = match TcpListener::bind(&addr).await { Ok(sock) => sock, Err(err) => { println!("> Error: {:?}", err); @@ -53,28 +63,21 @@ async fn main() { println!("> Running on {}", addr); + // Tracks which client is connected to which server for query cancellation. let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); - // Replica pool. - let addresses = vec![ - Address { - host: "127.0.0.1".to_string(), - port: "5432".to_string(), - }, - Address { - host: "localhost".to_string(), - port: "5433".to_string(), - }, - ]; + println!("> Pool size: {}", config.general.pool_size); + println!("> Pool mode: {}", config.general.pool_mode); + println!("> Ban time: {}s", config.general.ban_time); + println!( + "> Healthcheck timeout: {}ms", + config.general.healthcheck_timeout + ); - let user = User { - name: "lev".to_string(), - password: "lev".to_string(), - }; + let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await; + let transaction_mode = config.general.pool_mode == "transaction"; - let database = "lev"; - - let pool = ConnectionPool::new(addresses, user, database, client_server_map.clone()).await; + println!("> Waiting for clients..."); loop { let pool = pool.clone(); @@ -90,9 +93,12 @@ async fn main() { // Client goes to another thread, bye. tokio::task::spawn(async move { - println!(">> Client {:?} connected.", addr); + println!( + ">> Client {:?} connected, transaction pooling: {}", + addr, transaction_mode + ); - match client::Client::startup(socket, client_server_map).await { + match client::Client::startup(socket, client_server_map, transaction_mode).await { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); diff --git a/src/pool.rs b/src/pool.rs index 2a54d55..073c20a 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; use chrono::naive::NaiveDateTime; -use crate::config::{Address, User}; +use crate::config::{Address, Config, User}; use crate::errors::Error; use crate::server::Server; @@ -31,15 +31,17 @@ const CONNECT_TIMEOUT: u64 = 5000; // How much time to give the server to answer a SELECT 1 query. const HEALTHCHECK_TIMEOUT: u64 = 1000; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ConnectionPool { databases: Vec>>, addresses: Vec>, round_robin: Counter, banlist: BanList, + healthcheck_timeout: u64, } impl ConnectionPool { + // Construct the connection pool for a single-shard cluster. pub async fn new( addresses: Vec
, user: User, @@ -71,10 +73,70 @@ impl ConnectionPool { addresses: vec![addresses], round_robin: Arc::new(AtomicUsize::new(0)), banlist: Arc::new(Mutex::new(vec![HashMap::new()])), + healthcheck_timeout: HEALTHCHECK_TIMEOUT, } } - /// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded. + /// Construct the connection pool from a config file. + pub async fn from_config(config: Config, client_server_map: ClientServerMap) -> ConnectionPool { + let mut shards = Vec::new(); + let mut addresses = Vec::new(); + let mut banlist = Vec::new(); + let mut shard_ids = config + .shards + .clone() + .into_keys() + .map(|x| x.to_string()) + .collect::>(); + shard_ids.sort_by_key(|k| k.parse::().unwrap()); + + for shard in shard_ids { + let shard = &config.shards[&shard]; + let mut pools = Vec::new(); + let mut replica_addresses = Vec::new(); + + for server in &shard.servers { + let address = Address { + host: server.0.clone(), + port: server.1.to_string(), + }; + + let manager = ServerPool::new( + address.clone(), + config.user.clone(), + &shard.database, + client_server_map.clone(), + ); + + let pool = Pool::builder() + .max_size(config.general.pool_size) + .connection_timeout(std::time::Duration::from_millis( + config.general.connect_timeout, + )) + .test_on_check_out(false) + .build(manager) + .await + .unwrap(); + + pools.push(pool); + replica_addresses.push(address); + } + + shards.push(pools); + addresses.push(replica_addresses); + banlist.push(HashMap::new()); + } + + ConnectionPool { + databases: shards, + addresses: addresses, + round_robin: Arc::new(AtomicUsize::new(0)), + banlist: Arc::new(Mutex::new(banlist)), + healthcheck_timeout: config.general.healthcheck_timeout, + } + } + + /// Get a connection from the pool. pub async fn get( &self, shard: Option, diff --git a/src/server.rs b/src/server.rs index 96c619c..0a7e31f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,17 +15,38 @@ use crate::ClientServerMap; /// Server state. pub struct Server { + // Server host, e.g. localhost host: String, + + // Server port: e.g. 5432 port: String, + + // Buffered read socket read: BufReader, + + // Unbuffered write socket (our client code buffers) write: OwnedWriteHalf, + + // Our server response buffer buffer: BytesMut, + + // Server information the server sent us over on startup server_info: BytesMut, + + // Backend id and secret key used for query cancellation. backend_id: i32, secret_key: i32, + + // Is the server inside a transaction at the moment. in_transaction: bool, + + // Is there more data for the client to read. data_available: bool, + + // Is the server broken? We'll remote it from the pool if so. bad: bool, + + // Mapping of clients and servers used for query cancellation. client_server_map: ClientServerMap, } @@ -48,6 +69,7 @@ impl Server { } }; + // Send the startup packet. startup(&mut stream, user, database).await?; let mut server_info = BytesMut::with_capacity(25); diff --git a/src/sharding.rs b/src/sharding.rs index bf976e5..39c955f 100644 --- a/src/sharding.rs +++ b/src/sharding.rs @@ -1,5 +1,8 @@ use sha1::{Digest, Sha1}; +// https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/include/catalog/partition.h#L20 +const PARTITION_HASH_SEED: u64 = 0x7A5B22367996DCFD; + pub struct Sharder { shards: usize, } @@ -9,6 +12,8 @@ impl Sharder { Sharder { shards: shards } } + /// Use SHA1 to pick a shard for the key. The key can be anything, + /// including an int or a string. pub fn sha1(&self, key: &[u8]) -> usize { let mut hasher = Sha1::new(); hasher.update(key); @@ -17,6 +22,81 @@ impl Sharder { let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap()); i as usize % self.shards } + + /// Hash function used by Postgres to determine which partition + /// to put the row in when using HASH(column) partitioning. + /// Source: https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/common/hashfn.c#L631 + pub fn pg_bigint_hash(&self, key: i64) -> usize { + let mut lohalf = key as u32; + let hihalf = (key >> 32) as u32; + println!("{}, {}", lohalf, hihalf); + lohalf ^= if key >= 0 { hihalf } else { !hihalf }; + println!("Low half: {}", lohalf); + Self::pg_u32_hash(lohalf) as usize % self.shards + } + + fn rot(x: u32, k: u32) -> u32 { + ((x) << (k)) | ((x) >> (32 - (k))) + } + + #[inline] + fn mix(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) { + a = a.wrapping_sub(c); + a ^= Self::rot(c, 4); + c = c.wrapping_add(b); + b = b.wrapping_add(a); + b ^= Self::rot(a, 6); + a = a.wrapping_add(c); + c = c.wrapping_add(b); + c ^= Self::rot(b, 8); + b = b.wrapping_add(a); + a = a.wrapping_add(c); + a ^= Self::rot(c, 16); + c = c.wrapping_add(b); + b = b.wrapping_add(a); + b ^= Self::rot(a, 19); + a = a.wrapping_add(c); + c = c.wrapping_add(b); + c ^= Self::rot(b, 4); + b = b.wrapping_add(a); + (a, b, c) + } + + #[inline] + fn _final(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) { + c ^= b; + c = c.wrapping_add(Self::rot(b, 14)); + a ^= c; + a = a.wrapping_add(Self::rot(c, 11)); + b ^= a; + b = b.wrapping_add(Self::rot(a, 25)); + c ^= b; + c = c.wrapping_add(Self::rot(b, 16)); + a ^= c; + a = a.wrapping_add(Self::rot(c, 4)); + b ^= a; + b = b.wrapping_add(Self::rot(a, 14)); + c ^= b; + c = c.wrapping_add(Self::rot(b, 24)); + (a, b, c) + } + + fn pg_u32_hash(val: u32) -> u64 { + let mut a: u32 = 0x9e3779b9 + 4 + 3923095; + let mut b = a; + let c = a; + let seed = PARTITION_HASH_SEED; + + a = a.wrapping_add((seed >> 32) as u32); + b = b.wrapping_add(seed as u32); + let (mut a, b, c) = Self::mix(a, b, c); + + a = a.wrapping_add(val); + + let (a, b, c) = Self::_final(a, b, c); + + (b as u64) << 32 | c as u64 + } } #[cfg(test)] @@ -30,4 +110,12 @@ mod test { let shard = sharder.sha1(key); assert_eq!(shard, 1); } + + #[test] + fn test_pg_bigint_hash() { + let sharder = Sharder::new(4); + let key = 23423423 as i64; + let shard = sharder.pg_bigint_hash(key); + assert_eq!(shard, 0); + } }