Automatic shard selection from a configured sharding key.

* pgcat.toml: the example config sets `sharding_key = "id"`.
* src/config.rs: `Pool` gains `sharding_key: Option<String>` (default `None`).
* src/pool.rs: `PoolSettings` gains `sharding_key_regex: Option<Regex>`, compiled
  from the configured key (escaped so it is matched literally); a regex build
  failure is logged and reported as `Error::BadConfig`.
* src/query_router.rs: `infer_role` becomes `infer_role_and_shard`. It runs the
  regex against the query text and, on a match, hands the captured integer to
  `Sharder` to set `active_shard`. A captured value that does not fit in an i64
  is ignored instead of unwrapped. `update_pool_settings` also copies
  `query_parser_enabled` and `primary_reads_enabled` from the new settings.
* src/client.rs: call site updated for the rename.

Note: leading whitespace in this patch was reconstructed. If context lines do
not match the tree exactly, apply with `git apply --ignore-whitespace`. The
`index` hashes for src/pool.rs and src/query_router.rs were not regenerated.

diff --git a/pgcat.toml b/pgcat.toml
index 9125afd..4b7d0c5 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -82,6 +82,7 @@ primary_reads_enabled = true
 # sha1: A hashing function based on SHA1
 #
 sharding_function = "pg_bigint_hash"
+sharding_key = "id"
 
 # Credentials for users that may connect to this cluster
 [pools.sharded_db.users.0]
diff --git a/src/client.rs b/src/client.rs
index 3aac72c..62a9f29 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -662,7 +662,7 @@ where
                 // Normal query, not a custom command.
                 None => {
                     if query_router.query_parser_enabled() {
-                        query_router.infer_role(message.clone());
+                        query_router.infer_role_and_shard(message.clone());
                     }
                 }
 
diff --git a/src/config.rs b/src/config.rs
index 5c12261..004b768 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -185,6 +185,7 @@ pub struct Pool {
     pub query_parser_enabled: bool,
     pub primary_reads_enabled: bool,
     pub sharding_function: String,
+    pub sharding_key: Option<String>,
     pub shards: HashMap<String, Shard>,
     pub users: HashMap<String, User>,
 }
@@ -198,6 +199,7 @@ impl Default for Pool {
             query_parser_enabled: false,
             primary_reads_enabled: true,
             sharding_function: "pg_bigint_hash".to_string(),
+            sharding_key: None,
         }
     }
 }
diff --git a/src/pool.rs b/src/pool.rs
index dea29ad..88b0fd4 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -8,6 +8,7 @@ use once_cell::sync::Lazy;
 use parking_lot::{Mutex, RwLock};
 use rand::seq::SliceRandom;
 use rand::thread_rng;
+use regex::Regex;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Instant;
@@ -68,6 +69,9 @@ pub struct PoolSettings {
 
     // Sharding function.
     pub sharding_function: ShardingFunction,
+
+    // Automatically detect sharding key in query.
+    pub sharding_key_regex: Option<Regex>,
 }
 
 impl Default for PoolSettings {
@@ -80,6 +84,7 @@ impl Default for PoolSettings {
             query_parser_enabled: false,
             primary_reads_enabled: true,
             sharding_function: ShardingFunction::PgBigintHash,
+            sharding_key_regex: None,
         }
     }
 }
@@ -229,6 +234,20 @@ impl ConnectionPool {
                             "sha1" => ShardingFunction::Sha1,
                             _ => unreachable!(),
                         },
+                        sharding_key_regex: match &pool_config.sharding_key {
+                            Some(sharding_key) => match Regex::new(&format!(
+                                r"(?i) *{} *= *'?([0-9]+)'?",
+                                regex::escape(sharding_key)
+                            )) {
+                                Ok(regex) => Some(regex),
+                                Err(err) => {
+                                    error!("Sharding key regex error: {:?}", err);
+                                    return Err(Error::BadConfig);
+                                }
+                            },
+
+                            None => None,
+                        },
                     },
                 };
 
diff --git a/src/query_router.rs b/src/query_router.rs
index f9d5f0b..808857c 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -104,6 +104,8 @@ impl QueryRouter {
     /// Pool settings can change because of a config reload.
     pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
         self.pool_settings = pool_settings;
+        self.query_parser_enabled = self.pool_settings.query_parser_enabled;
+        self.primary_reads_enabled = self.pool_settings.primary_reads_enabled;
     }
 
     /// Try to parse a command and execute it.
@@ -256,7 +258,7 @@ impl QueryRouter {
     }
 
     /// Try to infer which server to connect to based on the contents of the query.
-    pub fn infer_role(&mut self, mut buf: BytesMut) -> bool {
+    pub fn infer_role_and_shard(&mut self, mut buf: BytesMut) -> bool {
         debug!("Inferring role");
 
         let code = buf.get_u8() as char;
@@ -297,6 +299,27 @@ impl QueryRouter {
             _ => return false,
         };
 
+        // First find the shard key
+        if let Some(re) = &self.pool_settings.sharding_key_regex {
+            // The digit run is client-controlled: one that overflows i64 is
+            // ignored rather than unwrapped, so it cannot panic the router.
+            let key = re
+                .captures(&query)
+                .and_then(|group| group.get(1))
+                .and_then(|value| value.as_str().parse::<i64>().ok());
+
+            if let Some(value) = key {
+                let sharder = Sharder::new(
+                    self.pool_settings.shards,
+                    self.pool_settings.sharding_function,
+                );
+                let shard = sharder.shard(value);
+                self.active_shard = Some(shard);
+
+                debug!("Automatically routing to shard {}", shard);
+            }
+        }
+
         let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
             Ok(ast) => ast,
             Err(err) => {
@@ -373,7 +396,7 @@ mod test {
     }
 
     #[test]
-    fn test_infer_role_replica() {
+    fn test_infer_role_and_shard_replica() {
         QueryRouter::setup();
         let mut qr = QueryRouter::new();
         assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
@@ -391,13 +414,13 @@ mod test {
 
         for query in queries {
             // It's a recognized query
-            assert!(qr.infer_role(query));
+            assert!(qr.infer_role_and_shard(query));
             assert_eq!(qr.role(), Some(Role::Replica));
         }
     }
 
     #[test]
-    fn test_infer_role_primary() {
+    fn test_infer_role_and_shard_primary() {
         QueryRouter::setup();
         let mut qr = QueryRouter::new();
 
@@ -410,24 +433,24 @@ mod test {
 
         for query in queries {
             // It's a recognized query
-            assert!(qr.infer_role(query));
+            assert!(qr.infer_role_and_shard(query));
             assert_eq!(qr.role(), Some(Role::Primary));
         }
     }
 
     #[test]
-    fn test_infer_role_primary_reads_enabled() {
+    fn test_infer_role_and_shard_primary_reads_enabled() {
         QueryRouter::setup();
         let mut qr = QueryRouter::new();
         let query = simple_query("SELECT * FROM items WHERE id = 5");
         assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
 
-        assert!(qr.infer_role(query));
+        assert!(qr.infer_role_and_shard(query));
         assert_eq!(qr.role(), None);
     }
 
     #[test]
-    fn test_infer_role_parse_prepared() {
+    fn test_infer_role_and_shard_parse_prepared() {
         QueryRouter::setup();
         let mut qr = QueryRouter::new();
         qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
@@ -442,7 +465,7 @@ mod test {
         res.put(prepared_stmt);
         res.put_i16(0);
 
-        assert!(qr.infer_role(res));
+        assert!(qr.infer_role_and_shard(res));
         assert_eq!(qr.role(), Some(Role::Replica));
     }
 
@@ -606,11 +629,11 @@ mod test {
         assert_eq!(qr.role(), None);
 
         let query = simple_query("INSERT INTO test_table VALUES (1)");
-        assert_eq!(qr.infer_role(query), true);
+        assert_eq!(qr.infer_role_and_shard(query), true);
         assert_eq!(qr.role(), Some(Role::Primary));
 
         let query = simple_query("SELECT * FROM test_table");
-        assert_eq!(qr.infer_role(query), true);
+        assert_eq!(qr.infer_role_and_shard(query), true);
         assert_eq!(qr.role(), Some(Role::Replica));
 
         assert!(qr.query_parser_enabled());