From 4c16ba3848aff12eaf8295bdcc1330a49ef21d6d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 9 Feb 2022 06:51:31 -0800 Subject: [PATCH] some comments --- src/client.rs | 16 ++++++++++------ src/config.rs | 2 +- src/messages.rs | 7 +++++++ src/pool.rs | 4 ++-- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/client.rs b/src/client.rs index 31b61ff..85cfbe9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -168,7 +168,7 @@ impl Client { // Active shard we're talking to. // The lifetime of this depends on the pool mode: - // - if in session mode, this lives until client disconnects or changes it, + // - if in session mode, this lives until the client disconnects, // - if in transaction mode, this lives for the duration of one transaction. let mut shard: Option = None; @@ -177,7 +177,7 @@ impl Client { // either a `Q` (query) or `P` (prepare, extended protocol). // We can parse it here before grabbing a server from the pool, // in case the client is sending some control messages, e.g. - // SET sharding_context.key = '1234'; + // SET SHARDING KEY TO 'bigint'; let mut message = read_message(&mut self.read).await?; // Parse for special select shard command. @@ -191,9 +191,6 @@ impl Client { None => (), }; - // The message is part of the regular protocol. - // self.buffer.put(message); - // Grab a server from the pool. // None = any shard let connection = pool.get(shard).await.unwrap(); @@ -361,12 +358,19 @@ impl Client { guard.remove(&(self.process_id, self.secret_key)); } + /// Determine if the query is part of our special syntax, extract + /// the shard key, and return the shard to query based on Postgres' + /// PARTITION BY HASH function. async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option { let code = buf.get_u8() as char; + // Only supporting simpe protocol here, so + // one would have to execute something like this: + // psql -c "SET SHARDING KEY TO '1234'" + // after sanitizing the value manually, which can be just done with an + // int parser, e.g. `let key = "1234".parse::().unwrap()`. match code { 'Q' => (), - // 'P' => (), _ => return None, }; diff --git a/src/config.rs b/src/config.rs index ed897b2..39218f7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -43,8 +43,8 @@ pub struct Config { pub shards: HashMap, } +/// Parse the config. pub async fn parse(path: &str) -> Result { - // let path = Path::new(path); let mut contents = String::new(); let mut file = match File::open(path).await { Ok(file) => file, diff --git a/src/messages.rs b/src/messages.rs index eb65c90..90a6700 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -26,6 +26,9 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { /// Send server parameters to the client. This will tell the client /// what server version and what's the encoding we're using. +// +// TODO: Forward these from the server instead of hardcoding. +// pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> { let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]); let server_version = @@ -138,16 +141,20 @@ pub async fn md5_password( Ok(write_all(stream, message).await?) } +/// Implements a response to our custom `SET SHARDING KEY` command. +/// This tells the client we're ready for the next query. pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> { let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]); let len = (set_complete.len() + 4) as i32; + // CommandComplete res.put_u8(b'C'); res.put_i32(len); res.put_slice(&set_complete[..]); + // ReadyForQuery (idle) res.put_u8(b'Z'); res.put_i32(5); res.put_u8(b'I'); diff --git a/src/pool.rs b/src/pool.rs index eb06816..624818a 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -182,7 +182,7 @@ impl ConnectionPool { pub fn is_banned(&self, address: &Address, shard: usize) -> bool { let mut guard = self.banlist.lock().unwrap(); - // Everything is banned, nothig is banned + // Everything is banned = nothing is banned. if guard[shard].len() == self.databases[shard].len() { guard[shard].clear(); drop(guard); @@ -194,8 +194,8 @@ impl ConnectionPool { match guard[shard].get(address) { Some(timestamp) => { let now = chrono::offset::Utc::now().naive_utc(); + // Ban expired. if now.timestamp() - timestamp.timestamp() > self.ban_time { - // 1 minute guard[shard].remove(address); false } else {