From af6f770271d808e9abf42885e2e874e727b245aa Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Tue, 8 Feb 2022 13:11:50 -0800 Subject: [PATCH] sharded query routing --- Cargo.lock | 27 ++++++++++++++++++++++ Cargo.toml | 3 ++- src/client.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++++--- src/messages.rs | 17 ++++++++++++++ 4 files changed, 103 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4909f0..0e6f79e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f" +dependencies = [ + "memchr", +] + [[package]] name = "async-trait" version = "0.1.52" @@ -315,6 +324,7 @@ dependencies = [ "chrono", "md-5", "rand", + "regex", "serde", "serde_derive", "sha-1", @@ -407,6 +417,23 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" + [[package]] name = "scopeguard" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 3cb0a6e..019c153 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,4 +16,5 @@ chrono = "0.4" sha-1 = "*" toml = "*" serde = "*" -serde_derive = "*" \ No newline at end of file +serde_derive = "*" +regex = "1" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index b7e4b15..0ba97d0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ /// and this module implements that. 
use bytes::{Buf, BufMut, BytesMut}; use rand::{distributions::Alphanumeric, Rng}; +use regex::Regex; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; @@ -11,6 +12,9 @@ use crate::errors::Error; use crate::messages::*; use crate::pool::{ClientServerMap, ConnectionPool}; use crate::server::Server; +use crate::sharding::Sharder; + +const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';"; /// The client state. One of these is created per client. pub struct Client { @@ -39,6 +43,9 @@ pub struct Client { // Clients are mapped to servers while they use them. This allows a client // to connect and cancel a query. client_server_map: ClientServerMap, + + // sharding regex + sharding_regex: Regex, } impl Client { @@ -50,6 +57,8 @@ impl Client { client_server_map: ClientServerMap, transaction_mode: bool, ) -> Result { + let sharding_regex = Regex::new(SHARDING_REGEX).unwrap(); + loop { // Could be StartupMessage or SSLRequest // which makes this variable length. @@ -105,6 +114,7 @@ impl Client { process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, + sharding_regex: sharding_regex, }); } @@ -124,6 +134,7 @@ impl Client { process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, + sharding_regex: sharding_regex, }); } @@ -156,6 +167,12 @@ impl Client { return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } + // Active shard we're talking to. + // The lifetime of this depends on the pool mode: + // - if in session mode, this lives until client disconnects or changes it, + // - if in transaction mode, this lives for the duration of one transaction. + let mut shard: Option = None; + loop { // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). 
@@ -164,15 +181,23 @@ impl Client { // SET sharding_context.key = '1234'; let mut message = read_message(&mut self.read).await?; - // TODO: parse the message here. If it's part of our protocol, - // don't grab a server yet and continue loop. + // Parse for special select shard command. + // SET SHARDING KEY TO 'bigint'; + match self.select_shard(message.clone(), pool.shards()).await { + Some(s) => { + set_sharding_key(&mut self.write).await?; + shard = Some(s); + continue; + } + None => (), + }; // The message is part of the regular protocol. // self.buffer.put(message); // Grab a server from the pool. // None = any shard - let connection = pool.get(None).await.unwrap(); + let connection = pool.get(shard).await.unwrap(); let mut proxy = connection.0; let _address = connection.1; let server = &mut *proxy; @@ -230,6 +255,7 @@ impl Client { // Release server if !server.in_transaction() && self.transaction_mode { + shard = None; break; } } @@ -288,6 +314,7 @@ impl Client { // Release server if !server.in_transaction() && self.transaction_mode { + shard = None; break; } } @@ -314,6 +341,7 @@ impl Client { // Release the server if !server.in_transaction() && self.transaction_mode { println!("Releasing after copy done"); + shard = None; break; } } @@ -333,4 +361,30 @@ impl Client { let mut guard = self.client_server_map.lock().unwrap(); guard.remove(&(self.process_id, self.secret_key)); } + + async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option { + let code = buf.get_u8() as char; + + match code { + 'Q' => (), + // 'P' => (), + _ => return None, + }; + + let len = buf.get_i32(); + let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); // Don't read the terminating null + + if self.sharding_regex.is_match(&query) { + let shard = query.split("'").collect::>()[1]; + match shard.parse::() { + Ok(shard) => { + let sharder = Sharder::new(shards); + Some(sharder.pg_bigint_hash(shard)) + } + Err(_) => None, + } + } else { 
+ None + } + } } diff --git a/src/messages.rs b/src/messages.rs index 58459f9..eb65c90 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -138,6 +138,23 @@ pub async fn md5_password( Ok(write_all(stream, message).await?) } +pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> { + let mut res = BytesMut::with_capacity(25); + + let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]); + let len = (set_complete.len() + 4) as i32; + + res.put_u8(b'C'); + res.put_i32(len); + res.put_slice(&set_complete[..]); + + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, res).await +} + /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> { match stream.write_all(&buf).await {