diff --git a/pgcat.toml b/pgcat.toml
index 0187c16..26f9d7d 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -85,6 +85,12 @@ query_parser_enabled = true
 # queries. The primary can always be explicitly selected with our custom protocol.
 primary_reads_enabled = true
 
+# Allow sharding commands to be passed as statement comments instead of
+# separate commands. If these are unset this functionality is disabled.
+# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
+# shard_id_regex = '/\* shard_id: (\d+) \*/'
+# regex_search_limit = 1000 # only look at the first 1000 bytes of SQL statements
+
 # So what if you wanted to implement a different hashing function,
 # or you've already built one and you want this pooler to use it?
 #
diff --git a/src/config.rs b/src/config.rs
index 6943423..392acfb 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -2,6 +2,7 @@
 use arc_swap::ArcSwap;
 use log::{error, info};
 use once_cell::sync::Lazy;
+use regex::Regex;
 use serde_derive::{Deserialize, Serialize};
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::hash::Hash;
@@ -342,8 +343,15 @@ pub struct Pool {
     #[serde(default = "Pool::default_automatic_sharding_key")]
     pub automatic_sharding_key: Option<String>,
 
+    pub sharding_key_regex: Option<String>,
+    pub shard_id_regex: Option<String>,
+    pub regex_search_limit: Option<usize>,
+
     pub shards: BTreeMap<String, Shard>,
     pub users: BTreeMap<String, User>,
+    // Note: don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
+    // impossible to serialize simple fields after complex objects (tables). See
+    // https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903
 }
 
 impl Pool {
@@ -387,6 +395,18 @@ impl Pool {
             shard.validate()?;
         }
 
+        for (option, name) in [
+            (&self.shard_id_regex, "shard_id_regex"),
+            (&self.sharding_key_regex, "sharding_key_regex"),
+        ] {
+            if let Some(regex) = option {
+                if let Err(parse_err) = Regex::new(regex.as_str()) {
+                    error!("{} is not a valid Regex: {}", name, parse_err);
+                    return Err(Error::BadConfig);
+                }
+            }
+        }
+
         Ok(())
     }
 }
@@ -405,6 +425,9 @@ impl Default for Pool {
             automatic_sharding_key: None,
             connect_timeout: None,
             idle_timeout: None,
+            sharding_key_regex: None,
+            shard_id_regex: None,
+            regex_search_limit: Some(1000),
         }
     }
 }
diff --git a/src/pool.rs b/src/pool.rs
index 702b617..cbe5b50 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -8,6 +8,7 @@ use once_cell::sync::Lazy;
 use parking_lot::{Mutex, RwLock};
 use rand::seq::SliceRandom;
 use rand::thread_rng;
+use regex::Regex;
 use std::collections::{HashMap, HashSet};
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -104,6 +105,15 @@ pub struct PoolSettings {
 
     // Ban time
     pub ban_time: i64,
+
+    // Regex for searching for the sharding key in SQL statements
+    pub sharding_key_regex: Option<Regex>,
+
+    // Regex for searching for the shard id in SQL statements
+    pub shard_id_regex: Option<Regex>,
+
+    // Limit how much of each query is searched for a potential shard regex match
+    pub regex_search_limit: usize,
 }
 
 impl Default for PoolSettings {
@@ -121,6 +131,9 @@ impl Default for PoolSettings {
             healthcheck_delay: General::default_healthcheck_delay(),
             healthcheck_timeout: General::default_healthcheck_timeout(),
             ban_time: General::default_ban_time(),
+            sharding_key_regex: None,
+            shard_id_regex: None,
+            regex_search_limit: 1000,
         }
     }
 }
@@ -300,6 +313,15 @@ impl ConnectionPool {
                         healthcheck_delay: config.general.healthcheck_delay,
                         healthcheck_timeout: config.general.healthcheck_timeout,
                         ban_time: config.general.ban_time,
+                        sharding_key_regex: pool_config
+                            .sharding_key_regex
+                            .clone()
+                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                        shard_id_regex: pool_config
+                            .shard_id_regex
+                            .clone()
+                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                        regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
                     },
                     validated: Arc::new(AtomicBool::new(false)),
                     paused: Arc::new(AtomicBool::new(false)),
diff --git a/src/query_router.rs b/src/query_router.rs
index 28d899d..bf07db7 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -14,6 +14,7 @@ use crate::messages::BytesMutReader;
 use crate::pool::PoolSettings;
 use crate::sharding::Sharder;
 
+use std::cmp;
 use std::collections::BTreeSet;
 use std::io::Cursor;
 
@@ -114,7 +115,60 @@ impl QueryRouter {
 
         let code = message_cursor.get_u8() as char;
 
-        // Only simple protocol supported for commands.
+        // Check for any sharding regex matches in any queries
+        match code {
+            // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
+            'P' | 'Q' => {
+                if self.pool_settings.shard_id_regex.is_some()
+                    || self.pool_settings.sharding_key_regex.is_some()
+                {
+                    // Check only the first block of bytes configured by the pool settings.
+                    // Peek at the length on a copy of the cursor so the command parsing below
+                    // still finds the cursor positioned right after the message code.
+                    let len = message_cursor.clone().get_i32() as usize;
+                    // The statement text starts after the 1-byte code and the 4-byte length;
+                    // skip that header and never slice past the end of the buffer.
+                    let seg = cmp::min(
+                        len.saturating_sub(5),
+                        self.pool_settings.regex_search_limit,
+                    );
+                    let end = cmp::min(seg.saturating_add(5), message_buffer.len());
+                    let initial_segment = String::from_utf8_lossy(&message_buffer[5..end]);
+
+                    // Check for a shard_id included in the query
+                    if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
+                        let shard_id = shard_id_regex.captures(&initial_segment).and_then(|cap| {
+                            cap.get(1).and_then(|id| id.as_str().parse::<usize>().ok())
+                        });
+                        if let Some(shard_id) = shard_id {
+                            debug!("Setting shard to {:?}", shard_id);
+                            self.set_shard(shard_id);
+                            // Skip other command processing since a sharding command was found
+                            return None;
+                        }
+                    }
+
+                    // Check for a sharding_key included in the query
+                    if let Some(sharding_key_regex) = &self.pool_settings.sharding_key_regex {
+                        let sharding_key =
+                            sharding_key_regex
+                                .captures(&initial_segment)
+                                .and_then(|cap| {
+                                    cap.get(1).and_then(|id| id.as_str().parse::<i64>().ok())
+                                });
+                        if let Some(sharding_key) = sharding_key {
+                            debug!("Setting sharding_key to {:?}", sharding_key);
+                            self.set_sharding_key(sharding_key);
+                            // Skip other command processing since a sharding command was found
+                            return None;
+                        }
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        // Only simple protocol supported for commands processed below
         if code != 'Q' {
             return None;
         }
@@ -192,13 +246,11 @@ impl QueryRouter {
 
         match command {
             Command::SetShardingKey => {
-                let sharder = Sharder::new(
-                    self.pool_settings.shards,
-                    self.pool_settings.sharding_function,
-                );
-                let shard = sharder.shard(value.parse::<i64>().unwrap());
-                self.active_shard = Some(shard);
-                value = shard.to_string();
+                // TODO: some error handling here
+                value = self
+                    .set_sharding_key(value.parse::<i64>().unwrap())
+                    .unwrap()
+                    .to_string();
             }
 
             Command::SetShard => {
@@ -465,6 +517,18 @@ impl QueryRouter {
         }
     }
 
+    /// Map `sharding_key` to a shard using the pool's shard count and sharding function,
+    /// make that shard the active one, and return the active shard.
+    fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
+        let sharder = Sharder::new(
+            self.pool_settings.shards,
+            self.pool_settings.sharding_function,
+        );
+        let shard = sharder.shard(sharding_key);
+        self.set_shard(shard);
+        self.active_shard
+    }
+
     /// Get the current desired server role we should be talking to.
     pub fn role(&self) -> Option<Role> {
         self.active_role
@@ -775,6 +839,9 @@ mod test {
             healthcheck_delay: PoolSettings::default().healthcheck_delay,
             healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
             ban_time: PoolSettings::default().ban_time,
+            sharding_key_regex: None,
+            shard_id_regex: None,
+            regex_search_limit: 1000,
         };
         let mut qr = QueryRouter::new();
         assert_eq!(qr.active_role, None);
@@ -820,4 +887,52 @@ mod test {
         )));
         assert_eq!(qr.role(), Role::Primary);
     }
+
+    #[test]
+    fn test_regex_shard_parsing() {
+        QueryRouter::setup();
+
+        let pool_settings = PoolSettings {
+            pool_mode: PoolMode::Transaction,
+            load_balancing_mode: crate::config::LoadBalancingMode::Random,
+            shards: 5,
+            user: crate::config::User::default(),
+            default_role: Some(Role::Replica),
+            query_parser_enabled: true,
+            primary_reads_enabled: false,
+            sharding_function: ShardingFunction::PgBigintHash,
+            automatic_sharding_key: Some(String::from("id")),
+            healthcheck_delay: PoolSettings::default().healthcheck_delay,
+            healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
+            ban_time: PoolSettings::default().ban_time,
+            sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
+            shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
+            regex_search_limit: 1000,
+        };
+        let mut qr = QueryRouter::new();
+        qr.update_pool_settings(pool_settings.clone());
+
+        // Shard should start out unset
+        assert_eq!(qr.active_shard, None);
+
+        // Make sure setting it works
+        let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
+        assert!(qr.try_execute_command(&q1) == None);
+        assert_eq!(qr.active_shard, Some(1));
+
+        // And make sure changing it works
+        let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
+        assert!(qr.try_execute_command(&q2) == None);
+        assert_eq!(qr.active_shard, Some(0));
+
+        // Validate setting by shard with expected shard copied from sharding.rs tests
+        let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
+        assert!(qr.try_execute_command(&q2) == None);
+        assert_eq!(qr.active_shard, Some(2));
+
+        // A comment at the very end of a short statement must be found as well
+        let q3 = simple_query("select 1 from foo; /* shard_id: 3 */");
+        assert!(qr.try_execute_command(&q3) == None);
+        assert_eq!(qr.active_shard, Some(3));
+    }
 }