mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-27 18:56:30 +00:00
Allow to set shard and set sharding key without quotes (#43)
* Allow to set shard and set sharding key without quotes * cover it * dont look for these in the middle of another query * friendly regex * its own response to set shard key
This commit is contained in:
@@ -229,11 +229,17 @@ impl Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Some((Command::SetShard, _)) | Some((Command::SetShardingKey, _)) => {
|
Some((Command::SetShard, _)) => {
|
||||||
custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?;
|
custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Some((Command::SetShardingKey, _)) => {
|
||||||
|
custom_protocol_response_ok(&mut self.write, &format!("SET SHARDING KEY"))
|
||||||
|
.await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
Some((Command::SetServerRole, _)) => {
|
Some((Command::SetServerRole, _)) => {
|
||||||
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
|
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -5,17 +5,17 @@ use crate::sharding::{Sharder, ShardingFunction};
|
|||||||
use bytes::{Buf, BytesMut};
|
use bytes::{Buf, BytesMut};
|
||||||
use log::{debug, error};
|
use log::{debug, error};
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use regex::RegexSet;
|
use regex::{Regex, RegexSet};
|
||||||
use sqlparser::ast::Statement::{Query, StartTransaction};
|
use sqlparser::ast::Statement::{Query, StartTransaction};
|
||||||
use sqlparser::dialect::PostgreSqlDialect;
|
use sqlparser::dialect::PostgreSqlDialect;
|
||||||
use sqlparser::parser::Parser;
|
use sqlparser::parser::Parser;
|
||||||
|
|
||||||
const CUSTOM_SQL_REGEXES: [&str; 5] = [
|
const CUSTOM_SQL_REGEXES: [&str; 5] = [
|
||||||
r"(?i)SET SHARDING KEY TO '[0-9]+'",
|
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
|
||||||
r"(?i)SET SHARD TO '[0-9]+'",
|
r"(?i)^ *SET SHARD TO '?([0-9]+)'? *;? *$",
|
||||||
r"(?i)SHOW SHARD",
|
r"(?i)^ *SHOW SHARD *;? *$",
|
||||||
r"(?i)SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)'",
|
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
|
||||||
r"(?i)SHOW SERVER ROLE",
|
r"(?i)^ *SHOW SERVER ROLE *;? *$",
|
||||||
];
|
];
|
||||||
|
|
||||||
#[derive(PartialEq, Debug)]
|
#[derive(PartialEq, Debug)]
|
||||||
@@ -27,8 +27,12 @@ pub enum Command {
|
|||||||
ShowServerRole,
|
ShowServerRole,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quick test
|
||||||
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
||||||
|
|
||||||
|
// Capture value
|
||||||
|
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
|
||||||
|
|
||||||
pub struct QueryRouter {
|
pub struct QueryRouter {
|
||||||
// By default, queries go here, unless we have better information
|
// By default, queries go here, unless we have better information
|
||||||
// about what the client wants.
|
// about what the client wants.
|
||||||
@@ -63,6 +67,21 @@ impl QueryRouter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let list: Vec<_> = CUSTOM_SQL_REGEXES
|
||||||
|
.iter()
|
||||||
|
.map(|rgx| Regex::new(rgx).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Impossible
|
||||||
|
if list.len() != set.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
match CUSTOM_SQL_REGEX_LIST.set(list) {
|
||||||
|
Ok(_) => true,
|
||||||
|
Err(_) => return false,
|
||||||
|
};
|
||||||
|
|
||||||
match CUSTOM_SQL_REGEX_SET.set(set) {
|
match CUSTOM_SQL_REGEX_SET.set(set) {
|
||||||
Ok(_) => true,
|
Ok(_) => true,
|
||||||
Err(_) => false,
|
Err(_) => false,
|
||||||
@@ -113,6 +132,11 @@ impl QueryRouter {
|
|||||||
None => return None,
|
None => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let regex_list = match CUSTOM_SQL_REGEX_LIST.get() {
|
||||||
|
Some(regex_list) => regex_list,
|
||||||
|
None => return None,
|
||||||
|
};
|
||||||
|
|
||||||
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
|
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
|
||||||
|
|
||||||
if matches.len() != 1 {
|
if matches.len() != 1 {
|
||||||
@@ -130,7 +154,19 @@ impl QueryRouter {
|
|||||||
|
|
||||||
let mut value = match command {
|
let mut value = match command {
|
||||||
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => {
|
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => {
|
||||||
query.split("'").collect::<Vec<&str>>()[1].to_string()
|
// Capture value. I know this re-runs the regex engine, but I haven't
|
||||||
|
// figured out a better way just yet. I think I can write a single Regex
|
||||||
|
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
|
||||||
|
//
|
||||||
|
// I think this is faster than running the Regex engine 5 times, so
|
||||||
|
// this is a strong maybe for me so far.
|
||||||
|
match regex_list[matches[0]].captures(&query) {
|
||||||
|
Some(captures) => match captures.get(1) {
|
||||||
|
Some(value) => value.as_str().to_string(),
|
||||||
|
None => return None,
|
||||||
|
},
|
||||||
|
None => return None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Command::ShowShard => self.shard().to_string(),
|
Command::ShowShard => self.shard().to_string(),
|
||||||
@@ -411,14 +447,38 @@ mod test {
|
|||||||
"set server role to 'any'",
|
"set server role to 'any'",
|
||||||
"set server role to 'auto'",
|
"set server role to 'auto'",
|
||||||
"show server role",
|
"show server role",
|
||||||
|
// No quotes
|
||||||
|
"SET SHARDING KEY TO 11235",
|
||||||
|
"SET SHARD TO 15",
|
||||||
|
// Spaces and semicolon
|
||||||
|
" SET SHARDING KEY TO 11235 ; ",
|
||||||
|
" SET SHARD TO 15; ",
|
||||||
|
" SET SHARDING KEY TO 11235 ;",
|
||||||
|
" SET SERVER ROLE TO 'primary'; ",
|
||||||
|
" SET SERVER ROLE TO 'primary' ; ",
|
||||||
|
" SET SERVER ROLE TO 'primary' ;",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Which regexes it'll match to in the list
|
||||||
|
let matches = [
|
||||||
|
0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3,
|
||||||
|
];
|
||||||
|
|
||||||
|
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
|
||||||
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
|
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
|
||||||
|
|
||||||
for test in &tests {
|
for (i, test) in tests.iter().enumerate() {
|
||||||
let matches: Vec<_> = set.matches(test).into_iter().collect();
|
assert!(list[matches[i]].is_match(test));
|
||||||
|
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
assert_eq!(matches.len(), 1);
|
let bad = [
|
||||||
|
"SELECT * FROM table",
|
||||||
|
"SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query
|
||||||
|
];
|
||||||
|
|
||||||
|
for query in &bad {
|
||||||
|
assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -428,7 +488,7 @@ mod test {
|
|||||||
let mut qr = QueryRouter::new();
|
let mut qr = QueryRouter::new();
|
||||||
|
|
||||||
// SetShardingKey
|
// SetShardingKey
|
||||||
let query = simple_query("SET SHARDING KEY TO '13'");
|
let query = simple_query("SET SHARDING KEY TO 13");
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
qr.try_execute_command(query),
|
qr.try_execute_command(query),
|
||||||
Some((Command::SetShardingKey, String::from("1")))
|
Some((Command::SetShardingKey, String::from("1")))
|
||||||
|
|||||||
Reference in New Issue
Block a user