From a784883611c77df3b32db3958661464176d9ed2d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 24 Feb 2022 12:16:24 -0800 Subject: [PATCH] Allow to set shard and set sharding key without quotes (#43) * Allow to set shard and set sharding key without quotes * cover it * dont look for these in the middle of another query * friendly regex * its own response to set shard key --- src/client.rs | 8 ++++- src/query_router.rs | 82 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/src/client.rs b/src/client.rs index 3c1eeea..fb1997d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -229,11 +229,17 @@ impl Client { } } - Some((Command::SetShard, _)) | Some((Command::SetShardingKey, _)) => { + Some((Command::SetShard, _)) => { custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?; continue; } + Some((Command::SetShardingKey, _)) => { + custom_protocol_response_ok(&mut self.write, &format!("SET SHARDING KEY")) + .await?; + continue; + } + Some((Command::SetServerRole, _)) => { custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; continue; diff --git a/src/query_router.rs b/src/query_router.rs index 3f309ea..737c4d1 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -5,17 +5,17 @@ use crate::sharding::{Sharder, ShardingFunction}; use bytes::{Buf, BytesMut}; use log::{debug, error}; use once_cell::sync::OnceCell; -use regex::RegexSet; +use regex::{Regex, RegexSet}; use sqlparser::ast::Statement::{Query, StartTransaction}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; const CUSTOM_SQL_REGEXES: [&str; 5] = [ - r"(?i)SET SHARDING KEY TO '[0-9]+'", - r"(?i)SET SHARD TO '[0-9]+'", - r"(?i)SHOW SHARD", - r"(?i)SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)'", - r"(?i)SHOW SERVER ROLE", + r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$", + r"(?i)^ *SET SHARD TO '?([0-9]+)'? *;? *$", + r"(?i)^ *SHOW SHARD *;? *$", + r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$", + r"(?i)^ *SHOW SERVER ROLE *;? *$", ]; #[derive(PartialEq, Debug)] @@ -27,8 +27,12 @@ pub enum Command { ShowServerRole, } +// Quick test static CUSTOM_SQL_REGEX_SET: OnceCell = OnceCell::new(); +// Capture value +static CUSTOM_SQL_REGEX_LIST: OnceCell> = OnceCell::new(); + pub struct QueryRouter { // By default, queries go here, unless we have better information // about what the client wants. @@ -63,6 +67,21 @@ impl QueryRouter { } }; + let list: Vec<_> = CUSTOM_SQL_REGEXES + .iter() + .map(|rgx| Regex::new(rgx).unwrap()) + .collect(); + + // Impossible + if list.len() != set.len() { + return false; + } + + match CUSTOM_SQL_REGEX_LIST.set(list) { + Ok(_) => true, + Err(_) => return false, + }; + match CUSTOM_SQL_REGEX_SET.set(set) { Ok(_) => true, Err(_) => false, @@ -113,6 +132,11 @@ impl QueryRouter { None => return None, }; + let regex_list = match CUSTOM_SQL_REGEX_LIST.get() { + Some(regex_list) => regex_list, + None => return None, + }; + let matches: Vec<_> = regex_set.matches(&query).into_iter().collect(); if matches.len() != 1 { @@ -130,7 +154,19 @@ impl QueryRouter { let mut value = match command { Command::SetShardingKey | Command::SetShard | Command::SetServerRole => { - query.split("'").collect::>()[1].to_string() + // Capture value. I know this re-runs the regex engine, but I haven't + // figured out a better way just yet. I think I can write a single Regex + // that matches all 5 custom SQL patterns, but maybe that's not very legible? + // + // I think this is faster than running the Regex engine 5 times, so + // this is a strong maybe for me so far. + match regex_list[matches[0]].captures(&query) { + Some(captures) => match captures.get(1) { + Some(value) => value.as_str().to_string(), + None => return None, + }, + None => return None, + } } Command::ShowShard => self.shard().to_string(), @@ -411,14 +447,38 @@ mod test { "set server role to 'any'", "set server role to 'auto'", "show server role", + // No quotes + "SET SHARDING KEY TO 11235", + "SET SHARD TO 15", + // Spaces and semicolon + " SET SHARDING KEY TO 11235 ; ", + " SET SHARD TO 15; ", + " SET SHARDING KEY TO 11235 ;", + " SET SERVER ROLE TO 'primary'; ", + " SET SERVER ROLE TO 'primary' ; ", + " SET SERVER ROLE TO 'primary' ;", ]; + // Which regexes it'll match to in the list + let matches = [ + 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3, + ]; + + let list = CUSTOM_SQL_REGEX_LIST.get().unwrap(); let set = CUSTOM_SQL_REGEX_SET.get().unwrap(); - for test in &tests { - let matches: Vec<_> = set.matches(test).into_iter().collect(); + for (i, test) in tests.iter().enumerate() { + assert!(list[matches[i]].is_match(test)); + assert_eq!(set.matches(test).into_iter().collect::>().len(), 1); + } - assert_eq!(matches.len(), 1); + let bad = [ + "SELECT * FROM table", + "SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query + ]; + + for query in &bad { + assert_eq!(set.matches(query).into_iter().collect::>().len(), 0); } } @@ -428,7 +488,7 @@ mod test { let mut qr = QueryRouter::new(); // SetShardingKey - let query = simple_query("SET SHARDING KEY TO '13'"); + let query = simple_query("SET SHARDING KEY TO 13"); assert_eq!( qr.try_execute_command(query), Some((Command::SetShardingKey, String::from("1")))