config support; started more sharding

This commit is contained in:
Lev Kokotov
2022-02-08 09:25:59 -08:00
parent ef2aab3c61
commit c27a7d30dc
10 changed files with 372 additions and 40 deletions

29
Cargo.lock generated
View File

@@ -315,8 +315,11 @@ dependencies = [
"chrono",
"md-5",
"rand",
"serde",
"serde_derive",
"sha-1",
"tokio",
"toml",
]
[[package]]
@@ -410,6 +413,23 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.136"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce31e24b01e1e524df96f1c2fdd054405f8d7376249a5110886fb4b658484789"
[[package]]
name = "serde_derive"
version = "1.0.136"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sha-1"
version = "0.10.0"
@@ -494,6 +514,15 @@ dependencies = [
"syn",
]
[[package]]
name = "toml"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31142970826733df8241ef35dc040ef98c679ab14d7c3e54d827099b3acecaa"
dependencies = [
"serde",
]
[[package]]
name = "typenum"
version = "1.15.0"

View File

@@ -14,3 +14,6 @@ async-trait = "*"
rand = "*"
chrono = "0.4"
sha-1 = "*"
toml = "*"
serde = "*"
serde_derive = "*"

60
pgcat.toml Normal file
View File

@@ -0,0 +1,60 @@
#
# PgCat config example.
#
#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example.
port = 6432
# How many connections to allocate per server.
pool_size = 15
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# How long to wait for a server connection from the pool before giving up (ms).
connect_timeout = 5000
# How much time to give `SELECT 1` health check query to return with a result (ms).
healthcheck_timeout = 1000
# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds
#
# User to use for authentication against the server.
[user]
name = "lev"
password = "lev"
#
# Shards in the cluster
[shards]
# Shard 0
[shards.0]
# [ host, port ]
servers = [
[ "127.0.0.1", 5432 ],
[ "localhost", 5432 ],
]
# Database name (e.g. "postgres")
database = "lev"
[shards.1]
# [ host, port ]
servers = [
[ "127.0.0.1", 5432 ],
[ "localhost", 5432 ],
]
database = "lev"

View File

@@ -48,6 +48,7 @@ impl Client {
pub async fn startup(
mut stream: TcpStream,
client_server_map: ClientServerMap,
transaction_mode: bool,
) -> Result<Client, Error> {
loop {
// Could be StartupMessage or SSLRequest
@@ -100,7 +101,7 @@ impl Client {
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: true,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
@@ -119,7 +120,7 @@ impl Client {
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: true,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,

View File

@@ -1,24 +1,84 @@
use serde_derive::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use toml;
use crate::errors::Error;
/// Location of a single server: the host and port pair listed under a
/// shard's `servers` in the config.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address {
// Server host, e.g. "127.0.0.1" or "localhost".
pub host: String,
// Server port, kept as a string (the pool converts the numeric config value).
pub port: String,
}
#[derive(Clone, PartialEq, Hash, std::cmp::Eq)]
/// User to use for authentication against the server; deserialized from the
/// `[user]` section of the config file.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
pub struct User {
// User name.
pub name: String,
// Password for that user.
pub password: String,
}
// #[derive(Clone)]
// pub struct Config {
// pools: HashMap<String, Pool<ServerPool>>,
// }
/// General pooler settings, deserialized from the `[general]` section of the
/// config file.
#[derive(Deserialize, Debug, Clone)]
pub struct General {
    /// IP address to listen on; `0.0.0.0` means accessible from everywhere.
    pub host: String,
    /// TCP port to listen on. Unsigned 16-bit so the whole valid port range
    /// (up to 65535) can be configured; a signed 16-bit field rejects any
    /// port above 32767.
    pub port: u16,
    /// How many connections to allocate per server.
    pub pool_size: u32,
    /// Pool mode: `"session"` or `"transaction"`.
    pub pool_mode: String,
    /// Timeout handed to the connection pool builder, in milliseconds.
    pub connect_timeout: u64,
    /// How much time to give the health check query to return, in milliseconds.
    pub healthcheck_timeout: u64,
    /// For how long to ban a server that fails a health check, in seconds.
    pub ban_time: i64,
}
// impl Config {
// pub fn new() -> Config {
// Config {
// pools: HashMap::new(),
// }
// }
/// One shard of the cluster; deserialized from a `[shards.<id>]` section of
/// the config file.
#[derive(Deserialize, Debug, Clone)]
pub struct Shard {
// Servers of this shard as (host, port) pairs.
pub servers: Vec<(String, u16)>,
// Database name (e.g. "postgres").
pub database: String,
}
/// Whole configuration file: the `[general]`, `[user]` and `[shards]` sections.
#[derive(Deserialize, Debug, Clone)]
pub struct Config {
// General pooler settings.
pub general: General,
// User to use for authentication against the server.
pub user: User,
// Shards in the cluster, keyed by shard id as written in the file ("0", "1", ...).
pub shards: HashMap<String, Shard>,
}
pub async fn parse(path: &str) -> Result<Config, Error> {
// let path = Path::new(path);
let mut contents = String::new();
let mut file = match File::open(path).await {
Ok(file) => file,
Err(err) => {
println!("> Config error: {:?}", err);
return Err(Error::BadConfig);
}
};
match file.read_to_string(&mut contents).await {
Ok(_) => (),
Err(err) => {
println!("> Config error: {:?}", err);
return Err(Error::BadConfig);
}
};
// let config: toml::Value = match toml::from_str(&contents) {
// Ok(config) => config,
// Err(err) => {
// println!("> Config error: {:?}", err);
// return Err(Error::BadConfig);
// }
// };
// println!("Config: {:?}", config);
let config: Config = match toml::from_str(&contents) {
Ok(config) => config,
Err(err) => {
println!("> Config error: {:?}", err);
return Err(Error::BadConfig);
}
};
Ok(config)
}

View File

@@ -7,4 +7,5 @@ pub enum Error {
ServerError,
ServerTimeout,
DirtyServer,
BadConfig,
}

View File

@@ -17,7 +17,10 @@ extern crate async_trait;
extern crate bb8;
extern crate bytes;
extern crate md5;
extern crate serde;
extern crate serde_derive;
extern crate tokio;
extern crate toml;
use tokio::net::TcpListener;
@@ -42,8 +45,15 @@ use pool::{ClientServerMap, ConnectionPool};
async fn main() {
println!("> Welcome to PgCat! Meow.");
let addr = "0.0.0.0:6432";
let listener = match TcpListener::bind(addr).await {
let config = match config::parse("pgcat.toml").await {
Ok(config) => config,
Err(err) => {
return;
}
};
let addr = format!("{}:{}", config.general.host, config.general.port);
let listener = match TcpListener::bind(&addr).await {
Ok(sock) => sock,
Err(err) => {
println!("> Error: {:?}", err);
@@ -53,28 +63,21 @@ async fn main() {
println!("> Running on {}", addr);
// Tracks which client is connected to which server for query cancellation.
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Replica pool.
let addresses = vec![
Address {
host: "127.0.0.1".to_string(),
port: "5432".to_string(),
},
Address {
host: "localhost".to_string(),
port: "5433".to_string(),
},
];
println!("> Pool size: {}", config.general.pool_size);
println!("> Pool mode: {}", config.general.pool_mode);
println!("> Ban time: {}s", config.general.ban_time);
println!(
"> Healthcheck timeout: {}ms",
config.general.healthcheck_timeout
);
let user = User {
name: "lev".to_string(),
password: "lev".to_string(),
};
let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
let transaction_mode = config.general.pool_mode == "transaction";
let database = "lev";
let pool = ConnectionPool::new(addresses, user, database, client_server_map.clone()).await;
println!("> Waiting for clients...");
loop {
let pool = pool.clone();
@@ -90,9 +93,12 @@ async fn main() {
// Client goes to another thread, bye.
tokio::task::spawn(async move {
println!(">> Client {:?} connected.", addr);
println!(
">> Client {:?} connected, transaction pooling: {}",
addr, transaction_mode
);
match client::Client::startup(socket, client_server_map).await {
match client::Client::startup(socket, client_server_map, transaction_mode).await {
Ok(mut client) => {
println!(">> Client {:?} authenticated successfully!", addr);

View File

@@ -3,7 +3,7 @@ use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use chrono::naive::NaiveDateTime;
use crate::config::{Address, User};
use crate::config::{Address, Config, User};
use crate::errors::Error;
use crate::server::Server;
@@ -31,15 +31,17 @@ const CONNECT_TIMEOUT: u64 = 5000;
// How much time to give the server to answer a SELECT 1 query.
const HEALTHCHECK_TIMEOUT: u64 = 1000;
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct ConnectionPool {
databases: Vec<Vec<Pool<ServerPool>>>,
addresses: Vec<Vec<Address>>,
round_robin: Counter,
banlist: BanList,
healthcheck_timeout: u64,
}
impl ConnectionPool {
// Construct the connection pool for a single-shard cluster.
pub async fn new(
addresses: Vec<Address>,
user: User,
@@ -71,10 +73,70 @@ impl ConnectionPool {
addresses: vec![addresses],
round_robin: Arc::new(AtomicUsize::new(0)),
banlist: Arc::new(Mutex::new(vec![HashMap::new()])),
healthcheck_timeout: HEALTHCHECK_TIMEOUT,
}
}
/// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded.
/// Construct the connection pool from a config file.
///
/// Shards are processed in ascending numeric order of their ids. For every
/// server of a shard one pool is built, using the configured pool size and
/// connect timeout; each shard also gets its own (initially empty) ban list.
///
/// # Panics
///
/// Panics if a shard id is not an integer or if a server pool cannot be built.
pub async fn from_config(config: Config, client_server_map: ClientServerMap) -> ConnectionPool {
    // Borrow the keys instead of cloning the whole shard map just to list them.
    let mut shard_ids: Vec<&String> = config.shards.keys().collect();
    shard_ids.sort_by_key(|id| {
        id.parse::<i64>()
            .expect("shard ids in the config must be integers")
    });

    let mut shards = Vec::with_capacity(shard_ids.len());
    let mut addresses = Vec::with_capacity(shard_ids.len());
    let mut banlist = Vec::with_capacity(shard_ids.len());

    for shard_id in shard_ids {
        let shard = &config.shards[shard_id];
        let mut pools = Vec::with_capacity(shard.servers.len());
        let mut replica_addresses = Vec::with_capacity(shard.servers.len());

        for server in &shard.servers {
            let address = Address {
                host: server.0.clone(),
                port: server.1.to_string(),
            };

            let manager = ServerPool::new(
                address.clone(),
                config.user.clone(),
                &shard.database,
                client_server_map.clone(),
            );

            let pool = Pool::builder()
                .max_size(config.general.pool_size)
                .connection_timeout(std::time::Duration::from_millis(
                    config.general.connect_timeout,
                ))
                .test_on_check_out(false)
                .build(manager)
                .await
                .expect("failed to build the server connection pool");

            pools.push(pool);
            replica_addresses.push(address);
        }

        shards.push(pools);
        addresses.push(replica_addresses);
        banlist.push(HashMap::new());
    }

    ConnectionPool {
        databases: shards,
        addresses: addresses,
        round_robin: Arc::new(AtomicUsize::new(0)),
        banlist: Arc::new(Mutex::new(banlist)),
        healthcheck_timeout: config.general.healthcheck_timeout,
    }
}
/// Get a connection from the pool.
pub async fn get(
&self,
shard: Option<usize>,

View File

@@ -15,17 +15,38 @@ use crate::ClientServerMap;
/// Server state.
pub struct Server {
// Server host, e.g. localhost
host: String,
// Server port, e.g. 5432
port: String,
// Buffered read socket
read: BufReader<OwnedReadHalf>,
// Unbuffered write socket (our client code buffers)
write: OwnedWriteHalf,
// Our server response buffer
buffer: BytesMut,
// Server information the server sent us over on startup
server_info: BytesMut,
// Backend id and secret key used for query cancellation.
backend_id: i32,
secret_key: i32,
// Is the server inside a transaction at the moment.
in_transaction: bool,
// Is there more data for the client to read.
data_available: bool,
// Is the server broken? We'll remove it from the pool if so.
bad: bool,
// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap,
}
@@ -48,6 +69,7 @@ impl Server {
}
};
// Send the startup packet.
startup(&mut stream, user, database).await?;
let mut server_info = BytesMut::with_capacity(25);

View File

@@ -1,5 +1,8 @@
use sha1::{Digest, Sha1};
// https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/include/catalog/partition.h#L20
const PARTITION_HASH_SEED: u64 = 0x7A5B22367996DCFD;
/// Picks a shard for a sharding key.
pub struct Sharder {
// Number of shards; hash values are reduced modulo this count.
shards: usize,
}
@@ -9,6 +12,8 @@ impl Sharder {
Sharder { shards: shards }
}
/// Use SHA1 to pick a shard for the key. The key can be anything,
/// including an int or a string.
pub fn sha1(&self, key: &[u8]) -> usize {
let mut hasher = Sha1::new();
hasher.update(key);
@@ -17,6 +22,81 @@ impl Sharder {
let i = u32::from_le_bytes(result[result.len() - 4..result.len()].try_into().unwrap());
i as usize % self.shards
}
/// Hash function used by Postgres to determine which partition
/// to put the row in when using HASH(column) partitioning.
/// Source: https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/common/hashfn.c#L631
///
/// Returns the shard index in `0..self.shards` for a `BIGINT` key.
///
/// # Panics
///
/// Panics if the sharder was created with zero shards (modulo by zero).
pub fn pg_bigint_hash(&self, key: i64) -> usize {
    // Same folding as Postgres' `hashint8extended`: fold the high half into
    // the low half, complemented for negative values.
    let mut lohalf = key as u32;
    let hihalf = (key >> 32) as u32;
    lohalf ^= if key >= 0 { hihalf } else { !hihalf };

    let hash = Self::pg_u32_hash(lohalf);

    // Postgres folds each column hash into the row hash with
    // `hash_combine64(row_hash, hash)`, starting from 0. With a zero
    // accumulator that reduces to adding the combine constant.
    let row_hash = hash.wrapping_add(0x49a0f4dd15e5a8e3);

    // Take the modulo on the full 64-bit value so that 32-bit targets do not
    // silently drop the upper half first.
    (row_hash % self.shards as u64) as usize
}
/// Rotate `x` left by `k` bits.
///
/// Uses the standard library rotate, which is well defined for every `k`
/// (including 0, where a hand-written `x >> (32 - k)` overflows the shift).
fn rot(x: u32, k: u32) -> u32 {
    x.rotate_left(k)
}
/// The `mix` step of the hash used by Postgres (`mix(a, b, c)` macro in
/// `hashfn.c`).
///
/// Each of the six rounds has the shape
/// `x -= z; x ^= rot(z, k); z += y;` — the first operation of every round is
/// a wrapping subtraction, not an addition.
#[inline]
fn mix(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) {
    a = a.wrapping_sub(c);
    a ^= Self::rot(c, 4);
    c = c.wrapping_add(b);

    b = b.wrapping_sub(a);
    b ^= Self::rot(a, 6);
    a = a.wrapping_add(c);

    c = c.wrapping_sub(b);
    c ^= Self::rot(b, 8);
    b = b.wrapping_add(a);

    a = a.wrapping_sub(c);
    a ^= Self::rot(c, 16);
    c = c.wrapping_add(b);

    b = b.wrapping_sub(a);
    b ^= Self::rot(a, 19);
    a = a.wrapping_add(c);

    c = c.wrapping_sub(b);
    c ^= Self::rot(b, 4);
    b = b.wrapping_add(a);

    (a, b, c)
}
/// The `final` step of the hash used by Postgres (`final(a, b, c)` macro in
/// `hashfn.c`).
///
/// Each of the seven rounds has the shape `x ^= y; x -= rot(y, k);` — the
/// rotated word is subtracted (wrapping), not added.
#[inline]
fn _final(mut a: u32, mut b: u32, mut c: u32) -> (u32, u32, u32) {
    c ^= b;
    c = c.wrapping_sub(Self::rot(b, 14));

    a ^= c;
    a = a.wrapping_sub(Self::rot(c, 11));

    b ^= a;
    b = b.wrapping_sub(Self::rot(a, 25));

    c ^= b;
    c = c.wrapping_sub(Self::rot(b, 16));

    a ^= c;
    a = a.wrapping_sub(Self::rot(c, 4));

    b ^= a;
    b = b.wrapping_sub(Self::rot(a, 14));

    c ^= b;
    c = c.wrapping_sub(Self::rot(b, 24));

    (a, b, c)
}
/// Seeded 64-bit hash of a single `u32`, using `PARTITION_HASH_SEED`.
///
/// The three state words start from the same constant, the seed halves are
/// added to the first two, the state is mixed, `val` is added to the first
/// word, and the finalised second and third words form the result
/// (second word in the upper 32 bits).
fn pg_u32_hash(val: u32) -> u64 {
    let init: u32 = 0x9e3779b9 + 4 + 3923095;

    let seed_hi = (PARTITION_HASH_SEED >> 32) as u32;
    let seed_lo = PARTITION_HASH_SEED as u32;

    let (a, b, c) = Self::mix(
        init.wrapping_add(seed_hi),
        init.wrapping_add(seed_lo),
        init,
    );
    let (_, hi, lo) = Self::_final(a.wrapping_add(val), b, c);

    ((hi as u64) << 32) | (lo as u64)
}
}
#[cfg(test)]
@@ -30,4 +110,12 @@ mod test {
let shard = sharder.sha1(key);
assert_eq!(shard, 1);
}
// Checks which of 4 shards `pg_bigint_hash` picks for one fixed key.
// NOTE(review): the expected shard below is not cross-checked here against a
// real Postgres hash-partitioned table — confirm the value.
#[test]
fn test_pg_bigint_hash() {
let sharder = Sharder::new(4);
let key = 23423423 as i64;
let shard = sharder.pg_bigint_hash(key);
assert_eq!(shard, 0);
}
}