Allow to set shard and set sharding key without quotes (#43)

* Allow to set shard and set sharding key without quotes

* cover it

* dont look for these in the middle of another query

* friendly regex

* its own response to set shard key
This commit is contained in:
Lev Kokotov
2022-02-24 12:16:24 -08:00
committed by GitHub
parent 5972b6fa52
commit a784883611
2 changed files with 78 additions and 12 deletions

View File

@@ -229,11 +229,17 @@ impl Client {
}
}
Some((Command::SetShard, _)) | Some((Command::SetShardingKey, _)) => {
Some((Command::SetShard, _)) => {
custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?;
continue;
}
Some((Command::SetShardingKey, _)) => {
custom_protocol_response_ok(&mut self.write, &format!("SET SHARDING KEY"))
.await?;
continue;
}
Some((Command::SetServerRole, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
continue;

View File

@@ -5,17 +5,17 @@ use crate::sharding::{Sharder, ShardingFunction};
use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::RegexSet;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
const CUSTOM_SQL_REGEXES: [&str; 5] = [
r"(?i)SET SHARDING KEY TO '[0-9]+'",
r"(?i)SET SHARD TO '[0-9]+'",
r"(?i)SHOW SHARD",
r"(?i)SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)'",
r"(?i)SHOW SERVER ROLE",
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SET SHARD TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SHOW SHARD *;? *$",
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
r"(?i)^ *SHOW SERVER ROLE *;? *$",
];
#[derive(PartialEq, Debug)]
@@ -27,8 +27,12 @@ pub enum Command {
ShowServerRole,
}
// Quick test
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
// Capture value
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
pub struct QueryRouter {
// By default, queries go here, unless we have better information
// about what the client wants.
@@ -63,6 +67,21 @@ impl QueryRouter {
}
};
let list: Vec<_> = CUSTOM_SQL_REGEXES
.iter()
.map(|rgx| Regex::new(rgx).unwrap())
.collect();
// Impossible
if list.len() != set.len() {
return false;
}
match CUSTOM_SQL_REGEX_LIST.set(list) {
Ok(_) => true,
Err(_) => return false,
};
match CUSTOM_SQL_REGEX_SET.set(set) {
Ok(_) => true,
Err(_) => false,
@@ -113,6 +132,11 @@ impl QueryRouter {
None => return None,
};
let regex_list = match CUSTOM_SQL_REGEX_LIST.get() {
Some(regex_list) => regex_list,
None => return None,
};
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
if matches.len() != 1 {
@@ -130,7 +154,19 @@ impl QueryRouter {
let mut value = match command {
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => {
query.split("'").collect::<Vec<&str>>()[1].to_string()
// Capture value. I know this re-runs the regex engine, but I haven't
// figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
//
// I think this is faster than running the Regex engine 5 times, so
// this is a strong maybe for me so far.
match regex_list[matches[0]].captures(&query) {
Some(captures) => match captures.get(1) {
Some(value) => value.as_str().to_string(),
None => return None,
},
None => return None,
}
}
Command::ShowShard => self.shard().to_string(),
@@ -411,14 +447,38 @@ mod test {
"set server role to 'any'",
"set server role to 'auto'",
"show server role",
// No quotes
"SET SHARDING KEY TO 11235",
"SET SHARD TO 15",
// Spaces and semicolon
" SET SHARDING KEY TO 11235 ; ",
" SET SHARD TO 15; ",
" SET SHARDING KEY TO 11235 ;",
" SET SERVER ROLE TO 'primary'; ",
" SET SERVER ROLE TO 'primary' ; ",
" SET SERVER ROLE TO 'primary' ;",
];
// Which regexes it'll match to in the list
let matches = [
0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3,
];
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
for test in &tests {
let matches: Vec<_> = set.matches(test).into_iter().collect();
for (i, test) in tests.iter().enumerate() {
assert!(list[matches[i]].is_match(test));
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
}
assert_eq!(matches.len(), 1);
let bad = [
"SELECT * FROM table",
"SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query
];
for query in &bad {
assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0);
}
}
@@ -428,7 +488,7 @@ mod test {
let mut qr = QueryRouter::new();
// SetShardingKey
let query = simple_query("SET SHARDING KEY TO '13'");
let query = simple_query("SET SHARDING KEY TO 13");
assert_eq!(
qr.try_execute_command(query),
Some((Command::SetShardingKey, String::from("1")))