mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
1407 lines
49 KiB
Rust
1407 lines
49 KiB
Rust
/// Route queries automatically based on explicitly requested
|
|
/// or implied query characteristics.
|
|
use bytes::{Buf, BytesMut};
|
|
use log::{debug, error};
|
|
use once_cell::sync::OnceCell;
|
|
use regex::{Regex, RegexSet};
|
|
use sqlparser::ast::Statement::{Query, StartTransaction};
|
|
use sqlparser::ast::{
|
|
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
|
|
Value,
|
|
};
|
|
use sqlparser::dialect::PostgreSqlDialect;
|
|
use sqlparser::parser::Parser;
|
|
|
|
use crate::config::Role;
|
|
use crate::errors::Error;
|
|
use crate::messages::BytesMutReader;
|
|
use crate::plugins::{
|
|
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
|
|
TableAccess,
|
|
};
|
|
use crate::pool::PoolSettings;
|
|
use crate::sharding::Sharder;
|
|
|
|
use std::cmp;
|
|
use std::collections::BTreeSet;
|
|
use std::io::Cursor;
|
|
|
|
/// Patterns that identify PgCat's custom SQL commands.
///
/// Order matters: the index of the matching pattern is what
/// `QueryRouter::try_execute_command` translates into a `Command` variant.
/// All patterns are case-insensitive and tolerate surrounding spaces and an
/// optional trailing semicolon.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
    // 0: SET SHARDING KEY TO <integer>
    r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
    // 1: SET SHARD TO <integer> | ANY
    r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
    // 2: SHOW SHARD
    r"(?i)^ *SHOW SHARD *;? *$",
    // 3: SET SERVER ROLE TO '<role>' (quotes are mandatory here)
    r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
    // 4: SHOW SERVER ROLE
    r"(?i)^ *SHOW SERVER ROLE *;? *$",
    // 5: SET PRIMARY READS TO on | off | default
    r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$",
    // 6: SHOW PRIMARY READS
    r"(?i)^ *SHOW PRIMARY READS *;? *$",
];
|
|
|
|
/// Custom commands understood by the query router.
///
/// The declaration order mirrors the order of the patterns in
/// `CUSTOM_SQL_REGEXES`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Command {
    /// `SET SHARDING KEY TO <integer>`
    SetShardingKey,
    /// `SET SHARD TO <integer> | ANY`
    SetShard,
    /// `SHOW SHARD`
    ShowShard,
    /// `SET SERVER ROLE TO '<role>'`
    SetServerRole,
    /// `SHOW SERVER ROLE`
    ShowServerRole,
    /// `SET PRIMARY READS TO on | off | default`
    SetPrimaryReads,
    /// `SHOW PRIMARY READS`
    ShowPrimaryReads,
}
|
|
|
|
/// A sharding key found while walking a query's conditions.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ShardingKey {
    /// A literal integer written directly in the query.
    Value(i64),
    /// The number of a `$n` placeholder whose value arrives later.
    Placeholder(i16),
}
|
|
|
|
/// How the parameter values of a Bind message are encoded.
#[derive(Clone, Debug)]
enum ParameterFormat {
    /// Format code 0.
    Text,
    /// Format code 1.
    Binary,
    /// A single format code that applies to every parameter.
    Uniform(Box<ParameterFormat>),
    /// One format code per parameter, in parameter order.
    Specified(Vec<ParameterFormat>),
}
|
|
|
|
/// Quickly test for match when a query is received.
|
|
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
|
|
|
/// Used to capture the value inside a custom command once it has matched.
|
|
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
|
|
|
|
/// The query router.
|
|
pub struct QueryRouter {
|
|
/// Which shard we should be talking to right now.
|
|
active_shard: Option<usize>,
|
|
|
|
/// Which server should we be talking to.
|
|
active_role: Option<Role>,
|
|
|
|
/// Should we try to parse queries to route them to replicas or primary automatically
|
|
query_parser_enabled: Option<bool>,
|
|
|
|
/// Include the primary into the replica pool for reads.
|
|
primary_reads_enabled: Option<bool>,
|
|
|
|
/// Pool configuration.
|
|
pool_settings: PoolSettings,
|
|
|
|
// Placeholders from prepared statement.
|
|
placeholders: Vec<i16>,
|
|
}
|
|
|
|
impl QueryRouter {
|
|
/// One-time initialization of regexes
|
|
/// that parse our custom SQL protocol.
|
|
pub fn setup() -> bool {
|
|
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
|
|
Ok(rgx) => rgx,
|
|
Err(err) => {
|
|
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
|
|
return false;
|
|
}
|
|
};
|
|
|
|
let list: Vec<_> = CUSTOM_SQL_REGEXES
|
|
.iter()
|
|
.map(|rgx| Regex::new(rgx).unwrap())
|
|
.collect();
|
|
|
|
assert_eq!(list.len(), set.len());
|
|
|
|
match CUSTOM_SQL_REGEX_LIST.set(list) {
|
|
Ok(_) => true,
|
|
Err(_) => return false,
|
|
};
|
|
|
|
CUSTOM_SQL_REGEX_SET.set(set).is_ok()
|
|
}
|
|
|
|
/// Create a new instance of the query router.
|
|
/// Each client gets its own.
|
|
pub fn new() -> QueryRouter {
|
|
QueryRouter {
|
|
active_shard: None,
|
|
active_role: None,
|
|
query_parser_enabled: None,
|
|
primary_reads_enabled: None,
|
|
pool_settings: PoolSettings::default(),
|
|
placeholders: Vec::new(),
|
|
}
|
|
}
|
|
|
|
/// Pool settings can change because of a config reload.
|
|
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
|
|
self.pool_settings = pool_settings;
|
|
}
|
|
|
|
/// Borrow the pool settings this router currently works with.
pub fn pool_settings(&self) -> &PoolSettings {
    &self.pool_settings
}
|
|
|
|
/// Try to parse a command and execute it.
|
|
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
|
|
let mut message_cursor = Cursor::new(message_buffer);
|
|
|
|
let code = message_cursor.get_u8() as char;
|
|
|
|
// Check for any sharding regex matches in any queries
|
|
match code as char {
|
|
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
|
'P' | 'Q' => {
|
|
if self.pool_settings.shard_id_regex.is_some()
|
|
|| self.pool_settings.sharding_key_regex.is_some()
|
|
{
|
|
// Check only the first block of bytes configured by the pool settings
|
|
let len = message_cursor.get_i32() as usize;
|
|
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
|
|
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);
|
|
|
|
// Check for a shard_id included in the query
|
|
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
|
|
let shard_id = shard_id_regex.captures(&initial_segment).and_then(|cap| {
|
|
cap.get(1).and_then(|id| id.as_str().parse::<usize>().ok())
|
|
});
|
|
if let Some(shard_id) = shard_id {
|
|
debug!("Setting shard to {:?}", shard_id);
|
|
self.set_shard(shard_id);
|
|
// Skip other command processing since a sharding command was found
|
|
return None;
|
|
}
|
|
}
|
|
|
|
// Check for a sharding_key included in the query
|
|
if let Some(sharding_key_regex) = &self.pool_settings.sharding_key_regex {
|
|
let sharding_key =
|
|
sharding_key_regex
|
|
.captures(&initial_segment)
|
|
.and_then(|cap| {
|
|
cap.get(1).and_then(|id| id.as_str().parse::<i64>().ok())
|
|
});
|
|
if let Some(sharding_key) = sharding_key {
|
|
debug!("Setting sharding_key to {:?}", sharding_key);
|
|
self.set_sharding_key(sharding_key);
|
|
// Skip other command processing since a sharding command was found
|
|
return None;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Only simple protocol supported for commands processed below
|
|
if code != 'Q' {
|
|
return None;
|
|
}
|
|
|
|
let _len = message_cursor.get_i32() as usize;
|
|
let query = message_cursor.read_string().unwrap();
|
|
|
|
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
|
|
Some(regex_set) => regex_set,
|
|
None => return None,
|
|
};
|
|
|
|
let regex_list = match CUSTOM_SQL_REGEX_LIST.get() {
|
|
Some(regex_list) => regex_list,
|
|
None => return None,
|
|
};
|
|
|
|
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
|
|
|
|
// This is not a custom query, try to infer which
|
|
// server it'll go to if the query parser is enabled.
|
|
if matches.len() != 1 {
|
|
debug!("Regular query, not a command");
|
|
return None;
|
|
}
|
|
|
|
let command = match matches[0] {
|
|
0 => Command::SetShardingKey,
|
|
1 => Command::SetShard,
|
|
2 => Command::ShowShard,
|
|
3 => Command::SetServerRole,
|
|
4 => Command::ShowServerRole,
|
|
5 => Command::SetPrimaryReads,
|
|
6 => Command::ShowPrimaryReads,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
let mut value = match command {
|
|
Command::SetShardingKey
|
|
| Command::SetShard
|
|
| Command::SetServerRole
|
|
| Command::SetPrimaryReads => {
|
|
// Capture value. I know this re-runs the regex engine, but I haven't
|
|
// figured out a better way just yet. I think I can write a single Regex
|
|
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
|
|
//
|
|
// I think this is faster than running the Regex engine 5 times.
|
|
match regex_list[matches[0]].captures(&query) {
|
|
Some(captures) => match captures.get(1) {
|
|
Some(value) => value.as_str().to_string(),
|
|
None => return None,
|
|
},
|
|
None => return None,
|
|
}
|
|
}
|
|
|
|
Command::ShowShard => self.shard().to_string(),
|
|
Command::ShowServerRole => match self.active_role {
|
|
Some(Role::Primary) => Role::Primary.to_string(),
|
|
Some(Role::Replica) => Role::Replica.to_string(),
|
|
Some(Role::Mirror) => Role::Mirror.to_string(),
|
|
None => {
|
|
if self.query_parser_enabled() {
|
|
String::from("auto")
|
|
} else {
|
|
String::from("any")
|
|
}
|
|
}
|
|
},
|
|
|
|
Command::ShowPrimaryReads => match self.primary_reads_enabled() {
|
|
true => String::from("on"),
|
|
false => String::from("off"),
|
|
},
|
|
};
|
|
|
|
match command {
|
|
Command::SetShardingKey => {
|
|
// TODO: some error handling here
|
|
value = self
|
|
.set_sharding_key(value.parse::<i64>().unwrap())
|
|
.unwrap()
|
|
.to_string();
|
|
}
|
|
|
|
Command::SetShard => {
|
|
self.active_shard = match value.to_ascii_uppercase().as_ref() {
|
|
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards),
|
|
_ => Some(value.parse::<usize>().unwrap()),
|
|
};
|
|
}
|
|
|
|
Command::SetServerRole => {
|
|
self.active_role = match value.to_ascii_lowercase().as_ref() {
|
|
"primary" => {
|
|
self.query_parser_enabled = Some(false);
|
|
Some(Role::Primary)
|
|
}
|
|
|
|
"replica" => {
|
|
self.query_parser_enabled = Some(false);
|
|
Some(Role::Replica)
|
|
}
|
|
|
|
"any" => {
|
|
self.query_parser_enabled = Some(false);
|
|
None
|
|
}
|
|
|
|
"auto" => {
|
|
self.query_parser_enabled = Some(true);
|
|
None
|
|
}
|
|
|
|
"default" => {
|
|
self.active_role = self.pool_settings.default_role;
|
|
self.query_parser_enabled = None;
|
|
self.active_role
|
|
}
|
|
|
|
_ => unreachable!(),
|
|
};
|
|
}
|
|
|
|
Command::SetPrimaryReads => {
|
|
if value == "on" {
|
|
debug!("Setting primary reads to on");
|
|
self.primary_reads_enabled = Some(true);
|
|
} else if value == "off" {
|
|
debug!("Setting primary reads to off");
|
|
self.primary_reads_enabled = Some(false);
|
|
} else if value == "default" {
|
|
debug!("Setting primary reads to default");
|
|
self.primary_reads_enabled = None;
|
|
}
|
|
}
|
|
|
|
_ => (),
|
|
}
|
|
|
|
Some((command, value))
|
|
}
|
|
|
|
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
|
|
let mut message_cursor = Cursor::new(message);
|
|
|
|
let code = message_cursor.get_u8() as char;
|
|
let _len = message_cursor.get_i32() as usize;
|
|
|
|
let query = match code {
|
|
// Query
|
|
'Q' => {
|
|
let query = message_cursor.read_string().unwrap();
|
|
debug!("Query: '{}'", query);
|
|
query
|
|
}
|
|
|
|
// Parse (prepared statement)
|
|
'P' => {
|
|
// Reads statement name
|
|
message_cursor.read_string().unwrap();
|
|
|
|
// Reads query string
|
|
let query = message_cursor.read_string().unwrap();
|
|
|
|
debug!("Prepared statement: '{}'", query);
|
|
query
|
|
}
|
|
|
|
_ => return Err(Error::UnsupportedStatement),
|
|
};
|
|
|
|
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
|
|
Ok(ast) => Ok(ast),
|
|
Err(err) => {
|
|
debug!("{}: {}", err, query);
|
|
Err(Error::QueryRouterParserError(err.to_string()))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Try to infer which server to connect to based on the contents of the query.
|
|
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
|
debug!("Inferring role");
|
|
|
|
if ast.is_empty() {
|
|
// That's weird, no idea, let's go to primary
|
|
self.active_role = Some(Role::Primary);
|
|
return Err(Error::QueryRouterParserError("empty query".into()));
|
|
}
|
|
|
|
for q in ast {
|
|
match q {
|
|
// All transactions go to the primary, probably a write.
|
|
StartTransaction { .. } => {
|
|
self.active_role = Some(Role::Primary);
|
|
break;
|
|
}
|
|
|
|
// Likely a read-only query
|
|
Query(query) => {
|
|
match &self.pool_settings.automatic_sharding_key {
|
|
Some(_) => {
|
|
// TODO: if we have multiple queries in the same message,
|
|
// we can either split them and execute them individually
|
|
// or discard shard selection. If they point to the same shard though,
|
|
// we can let them through as-is.
|
|
// This is basically building a database now :)
|
|
match self.infer_shard(query) {
|
|
Some(shard) => {
|
|
self.active_shard = Some(shard);
|
|
debug!("Automatically using shard: {:?}", self.active_shard);
|
|
}
|
|
|
|
None => (),
|
|
};
|
|
}
|
|
|
|
None => (),
|
|
};
|
|
|
|
self.active_role = match self.primary_reads_enabled() {
|
|
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
|
|
true => None, // Any server role is fine in this case.
|
|
}
|
|
}
|
|
|
|
// Likely a write
|
|
_ => {
|
|
self.active_role = Some(Role::Primary);
|
|
break;
|
|
}
|
|
};
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Parse the shard number from the Bind message
|
|
/// which contains the arguments for a prepared statement.
|
|
///
|
|
/// N.B.: Only supports anonymous prepared statements since we don't
|
|
/// keep a cache of them in PgCat.
|
|
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
|
|
debug!("Parsing bind message");
|
|
|
|
let mut message_cursor = Cursor::new(message);
|
|
|
|
let code = message_cursor.get_u8() as char;
|
|
let len = message_cursor.get_i32();
|
|
|
|
if code != 'B' {
|
|
debug!("Not a bind packet");
|
|
return false;
|
|
}
|
|
|
|
// Check message length
|
|
if message.len() != len as usize + 1 {
|
|
debug!(
|
|
"Message has wrong length, expected {}, but have {}",
|
|
len,
|
|
message.len()
|
|
);
|
|
return false;
|
|
}
|
|
|
|
// There are no shard keys in the prepared statement.
|
|
if self.placeholders.is_empty() {
|
|
debug!("There are no placeholders in the prepared statement that matched the automatic sharding key");
|
|
return false;
|
|
}
|
|
|
|
let sharder = Sharder::new(
|
|
self.pool_settings.shards,
|
|
self.pool_settings.sharding_function,
|
|
);
|
|
|
|
let mut shards = BTreeSet::new();
|
|
|
|
let _portal = message_cursor.read_string();
|
|
let _name = message_cursor.read_string();
|
|
|
|
let num_params = message_cursor.get_i16();
|
|
let parameter_format = match num_params {
|
|
0 => ParameterFormat::Text, // Text
|
|
1 => {
|
|
let param_format = message_cursor.get_i16();
|
|
ParameterFormat::Uniform(match param_format {
|
|
0 => Box::new(ParameterFormat::Text),
|
|
1 => Box::new(ParameterFormat::Binary),
|
|
_ => unreachable!(),
|
|
})
|
|
}
|
|
n => {
|
|
let mut v = Vec::with_capacity(n as usize);
|
|
for _ in 0..n {
|
|
let param_format = message_cursor.get_i16();
|
|
v.push(match param_format {
|
|
0 => ParameterFormat::Text,
|
|
1 => ParameterFormat::Binary,
|
|
_ => unreachable!(),
|
|
});
|
|
}
|
|
ParameterFormat::Specified(v)
|
|
}
|
|
};
|
|
|
|
let num_parameters = message_cursor.get_i16();
|
|
|
|
for i in 0..num_parameters {
|
|
let mut len = message_cursor.get_i32() as usize;
|
|
let format = match ¶meter_format {
|
|
ParameterFormat::Text => ParameterFormat::Text,
|
|
ParameterFormat::Uniform(format) => *format.clone(),
|
|
ParameterFormat::Specified(formats) => formats[i as usize].clone(),
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
debug!("Parameter {} (len: {}): {:?}", i, len, format);
|
|
|
|
// Postgres counts placeholders starting at 1
|
|
let placeholder = i + 1;
|
|
|
|
if self.placeholders.contains(&placeholder) {
|
|
let value = match format {
|
|
ParameterFormat::Text => {
|
|
let mut value = String::new();
|
|
while len > 0 {
|
|
value.push(message_cursor.get_u8() as char);
|
|
len -= 1;
|
|
}
|
|
|
|
match value.parse::<i64>() {
|
|
Ok(value) => value,
|
|
Err(_) => {
|
|
debug!("Error parsing bind value: {}", value);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
ParameterFormat::Binary => match len {
|
|
2 => message_cursor.get_i16() as i64,
|
|
4 => message_cursor.get_i32() as i64,
|
|
8 => message_cursor.get_i64(),
|
|
_ => {
|
|
error!(
|
|
"Got wrong length for integer type parameter in bind: {}",
|
|
len
|
|
);
|
|
continue;
|
|
}
|
|
},
|
|
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
shards.insert(sharder.shard(value));
|
|
}
|
|
}
|
|
|
|
self.placeholders.clear();
|
|
self.placeholders.shrink_to_fit();
|
|
|
|
// We only support querying one shard at a time.
|
|
// TODO: Support multi-shard queries some day.
|
|
if shards.len() == 1 {
|
|
debug!("Found one sharding key");
|
|
self.set_shard(*shards.first().unwrap());
|
|
true
|
|
} else {
|
|
debug!("Found no sharding keys");
|
|
false
|
|
}
|
|
}
|
|
|
|
/// A `selection` is the `WHERE` clause. This parses
|
|
/// the clause and extracts the sharding key, if present.
|
|
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
|
|
let mut result = Vec::new();
|
|
let mut found = false;
|
|
|
|
let sharding_key = self
|
|
.pool_settings
|
|
.automatic_sharding_key
|
|
.as_ref()
|
|
.unwrap()
|
|
.split(".")
|
|
.map(|ident| Ident::new(ident))
|
|
.collect::<Vec<Ident>>();
|
|
|
|
// Sharding key must be always fully qualified
|
|
assert_eq!(sharding_key.len(), 2);
|
|
|
|
// This parses `sharding_key = 5`. But it's technically
|
|
// legal to write `5 = sharding_key`. I don't judge the people
|
|
// who do that, but I think ORMs will still use the first variant,
|
|
// so we can leave the second as a TODO.
|
|
if let Expr::BinaryOp { left, op, right } = expr {
|
|
match &**left {
|
|
Expr::BinaryOp { .. } => result.extend(self.selection_parser(left, table_names)),
|
|
Expr::Identifier(ident) => {
|
|
// Only if we're dealing with only one table
|
|
// and there is no ambiguity
|
|
if &ident.value == &sharding_key[1].value {
|
|
// Sharding key is unique enough, don't worry about
|
|
// table names.
|
|
if &sharding_key[0].value == "*" {
|
|
found = true;
|
|
} else if table_names.len() == 1 {
|
|
let table = &table_names[0];
|
|
|
|
if table.len() == 1 {
|
|
// Table is not fully qualified, e.g.
|
|
// SELECT * FROM t WHERE sharding_key = 5
|
|
// Make sure the table name from the sharding key matches
|
|
// the table name from the query.
|
|
found = &sharding_key[0].value == &table[0].value;
|
|
} else if table.len() == 2 {
|
|
// Table name is fully qualified with the schema: e.g.
|
|
// SELECT * FROM public.t WHERE sharding_key = 5
|
|
// Ignore the schema (TODO: at some point, we want schema support)
|
|
// and use the table name only.
|
|
found = &sharding_key[0].value == &table[1].value;
|
|
} else {
|
|
debug!("Got table name with more than two idents, which is not possible");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Expr::CompoundIdentifier(idents) => {
|
|
// The key is fully qualified in the query,
|
|
// it will exist or Postgres will throw an error.
|
|
if idents.len() == 2 {
|
|
found = &sharding_key[0].value == &idents[0].value
|
|
&& &sharding_key[1].value == &idents[1].value;
|
|
}
|
|
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
|
|
}
|
|
_ => (),
|
|
};
|
|
|
|
match op {
|
|
BinaryOperator::Eq => (),
|
|
BinaryOperator::Or => (),
|
|
BinaryOperator::And => (),
|
|
_ => {
|
|
// TODO: support other operators than equality.
|
|
debug!("Unsupported operation: {:?}", op);
|
|
return Vec::new();
|
|
}
|
|
};
|
|
|
|
match &**right {
|
|
Expr::BinaryOp { .. } => result.extend(self.selection_parser(right, table_names)),
|
|
Expr::Value(Value::Number(value, ..)) => {
|
|
if found {
|
|
match value.parse::<i64>() {
|
|
Ok(value) => result.push(ShardingKey::Value(value)),
|
|
Err(_) => {
|
|
debug!("Sharding key was not an integer: {}", value);
|
|
}
|
|
};
|
|
}
|
|
}
|
|
|
|
Expr::Value(Value::Placeholder(placeholder)) => {
|
|
match placeholder.replace("$", "").parse::<i16>() {
|
|
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
|
|
Err(_) => {
|
|
debug!(
|
|
"Prepared statement didn't have integer placeholders: {}",
|
|
placeholder
|
|
);
|
|
}
|
|
}
|
|
}
|
|
_ => (),
|
|
};
|
|
}
|
|
|
|
debug!("Sharding keys found: {:?}", result);
|
|
|
|
result
|
|
}
|
|
|
|
/// Try to figure out which shard the query should go to.
|
|
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
|
|
let mut shards = BTreeSet::new();
|
|
let mut exprs = Vec::new();
|
|
|
|
match &*query.body {
|
|
SetExpr::Query(query) => {
|
|
match self.infer_shard(&*query) {
|
|
Some(shard) => {
|
|
shards.insert(shard);
|
|
}
|
|
None => (),
|
|
};
|
|
}
|
|
|
|
// SELECT * FROM ...
|
|
// We understand that pretty well.
|
|
SetExpr::Select(select) => {
|
|
// Collect all table names from the query.
|
|
let mut table_names = Vec::new();
|
|
|
|
for table in select.from.iter() {
|
|
match &table.relation {
|
|
TableFactor::Table { name, .. } => {
|
|
table_names.push(name.0.clone());
|
|
}
|
|
|
|
_ => (),
|
|
};
|
|
|
|
// Get table names from all the joins.
|
|
for join in table.joins.iter() {
|
|
match &join.relation {
|
|
TableFactor::Table { name, .. } => {
|
|
table_names.push(name.0.clone());
|
|
}
|
|
|
|
_ => (),
|
|
};
|
|
|
|
// We can filter results based on join conditions, e.g.
|
|
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
|
|
match &join.join_operator {
|
|
JoinOperator::Inner(inner_join) => match &inner_join {
|
|
JoinConstraint::On(expr) => {
|
|
// Parse the selection criteria later.
|
|
exprs.push(expr.clone());
|
|
}
|
|
|
|
_ => (),
|
|
},
|
|
|
|
_ => (),
|
|
};
|
|
}
|
|
}
|
|
|
|
// Parse the actual "FROM ..."
|
|
match &select.selection {
|
|
Some(selection) => {
|
|
exprs.push(selection.clone());
|
|
}
|
|
|
|
None => (),
|
|
};
|
|
|
|
let sharder = Sharder::new(
|
|
self.pool_settings.shards,
|
|
self.pool_settings.sharding_function,
|
|
);
|
|
|
|
// Look for sharding keys in either the join condition
|
|
// or the selection.
|
|
for expr in exprs.iter() {
|
|
let sharding_keys = self.selection_parser(expr, &table_names);
|
|
|
|
// TODO: Add support for prepared statements here.
|
|
// This should just give us the position of the value in the `B` message.
|
|
|
|
for value in sharding_keys {
|
|
match value {
|
|
ShardingKey::Value(value) => {
|
|
let shard = sharder.shard(value);
|
|
shards.insert(shard);
|
|
}
|
|
|
|
ShardingKey::Placeholder(position) => {
|
|
self.placeholders.push(position);
|
|
}
|
|
};
|
|
}
|
|
}
|
|
}
|
|
_ => (),
|
|
};
|
|
|
|
match shards.len() {
|
|
// Didn't find a sharding key, you're on your own.
|
|
0 => {
|
|
debug!("No sharding keys found");
|
|
None
|
|
}
|
|
|
|
1 => Some(shards.into_iter().last().unwrap()),
|
|
|
|
// TODO: support querying multiple shards (some day...)
|
|
_ => {
|
|
debug!("More than one sharding key found");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Add your plugins here and execute them.
|
|
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
|
|
if query_logger::enabled() {
|
|
let mut query_logger = QueryLogger {};
|
|
let _ = query_logger.run(&self, ast).await;
|
|
}
|
|
|
|
if intercept::enabled() {
|
|
let mut intercept = Intercept {};
|
|
let result = intercept.run(&self, ast).await;
|
|
|
|
if let Ok(PluginOutput::Intercept(output)) = result {
|
|
return Ok(PluginOutput::Intercept(output));
|
|
}
|
|
}
|
|
|
|
if table_access::enabled() {
|
|
let mut table_access = TableAccess {};
|
|
let result = table_access.run(&self, ast).await;
|
|
|
|
if let Ok(PluginOutput::Deny(error)) = result {
|
|
return Ok(PluginOutput::Deny(error));
|
|
}
|
|
}
|
|
|
|
Ok(PluginOutput::Allow)
|
|
}
|
|
|
|
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
|
|
let sharder = Sharder::new(
|
|
self.pool_settings.shards,
|
|
self.pool_settings.sharding_function,
|
|
);
|
|
let shard = sharder.shard(sharding_key);
|
|
self.set_shard(shard);
|
|
self.active_shard
|
|
}
|
|
|
|
/// Get the current desired server role we should be talking to.
|
|
pub fn role(&self) -> Option<Role> {
|
|
self.active_role
|
|
}
|
|
|
|
/// Get desired shard we should be talking to.
|
|
pub fn shard(&self) -> usize {
|
|
self.active_shard.unwrap_or(0)
|
|
}
|
|
|
|
pub fn set_shard(&mut self, shard: usize) {
|
|
self.active_shard = Some(shard);
|
|
}
|
|
|
|
/// Should we attempt to parse queries?
|
|
pub fn query_parser_enabled(&self) -> bool {
|
|
let enabled = match self.query_parser_enabled {
|
|
None => {
|
|
debug!(
|
|
"Using pool settings, query_parser_enabled: {}",
|
|
self.pool_settings.query_parser_enabled
|
|
);
|
|
self.pool_settings.query_parser_enabled
|
|
}
|
|
|
|
Some(value) => {
|
|
debug!(
|
|
"Using query parser override, query_parser_enabled: {}",
|
|
value
|
|
);
|
|
value
|
|
}
|
|
};
|
|
|
|
enabled
|
|
}
|
|
|
|
pub fn primary_reads_enabled(&self) -> bool {
|
|
match self.primary_reads_enabled {
|
|
None => self.pool_settings.primary_reads_enabled,
|
|
Some(value) => value,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod test {
|
|
use super::*;
|
|
use crate::config::PoolMode;
|
|
use crate::messages::simple_query;
|
|
use crate::sharding::ShardingFunction;
|
|
use bytes::BufMut;
|
|
|
|
#[test]
|
|
fn test_defaults() {
|
|
QueryRouter::setup();
|
|
let qr = QueryRouter::new();
|
|
|
|
assert_eq!(qr.role(), None);
|
|
}
|
|
|
|
#[test]
|
|
fn test_infer_replica() {
|
|
QueryRouter::setup();
|
|
let mut qr = QueryRouter::new();
|
|
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
|
assert!(qr.query_parser_enabled());
|
|
|
|
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
|
|
|
let queries = vec![
|
|
simple_query("SELECT * FROM items WHERE id = 5"),
|
|
simple_query(
|
|
"SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id",
|
|
),
|
|
simple_query("WITH t AS (SELECT * FROM items) SELECT * FROM t"),
|
|
];
|
|
|
|
for query in queries {
|
|
// It's a recognized query
|
|
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
|
assert_eq!(qr.role(), Some(Role::Replica));
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_infer_primary() {
|
|
QueryRouter::setup();
|
|
let mut qr = QueryRouter::new();
|
|
|
|
let queries = vec![
|
|
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
|
|
simple_query("INSERT INTO items (id, name) VALUES (5, 'pumpkin')"),
|
|
simple_query("DELETE FROM items WHERE id = 5"),
|
|
simple_query("BEGIN"), // Transaction start
|
|
];
|
|
|
|
for query in queries {
|
|
// It's a recognized query
|
|
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
|
assert_eq!(qr.role(), Some(Role::Primary));
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_infer_primary_reads_enabled() {
|
|
QueryRouter::setup();
|
|
let mut qr = QueryRouter::new();
|
|
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
|
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
|
|
|
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
|
assert_eq!(qr.role(), None);
|
|
}
|
|
|
|
#[test]
|
|
fn test_infer_parse_prepared() {
|
|
QueryRouter::setup();
|
|
let mut qr = QueryRouter::new();
|
|
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
|
|
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
|
|
|
let prepared_stmt = BytesMut::from(
|
|
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
|
|
);
|
|
let mut res = BytesMut::from(&b"P"[..]);
|
|
res.put_i32(prepared_stmt.len() as i32 + 4 + 1 + 2);
|
|
res.put_u8(0);
|
|
res.put(prepared_stmt);
|
|
res.put_i16(0);
|
|
|
|
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
|
|
assert_eq!(qr.role(), Some(Role::Replica));
|
|
}
|
|
|
|
#[test]
|
|
fn test_regex_set() {
|
|
QueryRouter::setup();
|
|
|
|
let tests = [
|
|
// Upper case
|
|
"SET SHARDING KEY TO '1'",
|
|
"SET SHARD TO '1'",
|
|
"SHOW SHARD",
|
|
"SET SERVER ROLE TO 'replica'",
|
|
"SET SERVER ROLE TO 'primary'",
|
|
"SET SERVER ROLE TO 'any'",
|
|
"SET SERVER ROLE TO 'auto'",
|
|
"SHOW SERVER ROLE",
|
|
"SET PRIMARY READS TO 'on'",
|
|
"SET PRIMARY READS TO 'off'",
|
|
"SET PRIMARY READS TO 'default'",
|
|
"SHOW PRIMARY READS",
|
|
// Lower case
|
|
"set sharding key to '1'",
|
|
"set shard to '1'",
|
|
"show shard",
|
|
"set server role to 'replica'",
|
|
"set server role to 'primary'",
|
|
"set server role to 'any'",
|
|
"set server role to 'auto'",
|
|
"show server role",
|
|
"set primary reads to 'on'",
|
|
"set primary reads to 'OFF'",
|
|
"set primary reads to 'deFaUlt'",
|
|
// No quotes
|
|
"SET SHARDING KEY TO 11235",
|
|
"SET SHARD TO 15",
|
|
"SET PRIMARY READS TO off",
|
|
// Spaces and semicolon
|
|
" SET SHARDING KEY TO 11235 ; ",
|
|
" SET SHARD TO 15; ",
|
|
" SET SHARDING KEY TO 11235 ;",
|
|
" SET SERVER ROLE TO 'primary'; ",
|
|
" SET SERVER ROLE TO 'primary' ; ",
|
|
" SET SERVER ROLE TO 'primary' ;",
|
|
" SET PRIMARY READS TO 'off' ;",
|
|
];
|
|
|
|
// Which regexes it'll match to in the list
|
|
let matches = [
|
|
0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0,
|
|
3, 3, 3, 5,
|
|
];
|
|
|
|
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
|
|
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
|
|
|
|
for (i, test) in tests.iter().enumerate() {
|
|
if !list[matches[i]].is_match(test) {
|
|
println!("{} does not match {}", test, list[matches[i]]);
|
|
panic!();
|
|
}
|
|
assert_eq!(set.matches(test).into_iter().count(), 1);
|
|
}
|
|
|
|
let bad = [
|
|
"SELECT * FROM table",
|
|
"SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query
|
|
];
|
|
|
|
for query in &bad {
|
|
assert_eq!(set.matches(query).into_iter().count(), 0);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_try_execute_command() {
|
|
QueryRouter::setup();
|
|
let mut qr = QueryRouter::new();
|
|
|
|
// SetShardingKey
|
|
let query = simple_query("SET SHARDING KEY TO 13");
|
|
assert_eq!(
|
|
qr.try_execute_command(&query),
|
|
Some((Command::SetShardingKey, String::from("0")))
|
|
);
|
|
assert_eq!(qr.shard(), 0);
|
|
|
|
// SetShard
|
|
let query = simple_query("SET SHARD TO '1'");
|
|
assert_eq!(
|
|
qr.try_execute_command(&query),
|
|
Some((Command::SetShard, String::from("1")))
|
|
);
|
|
assert_eq!(qr.shard(), 1);
|
|
|
|
// ShowShard
|
|
let query = simple_query("SHOW SHARD");
|
|
assert_eq!(
|
|
qr.try_execute_command(&query),
|
|
Some((Command::ShowShard, String::from("1")))
|
|
);
|
|
|
|
// SetServerRole
|
|
let roles = ["primary", "replica", "any", "auto", "primary"];
|
|
let verify_roles = [
|
|
Some(Role::Primary),
|
|
Some(Role::Replica),
|
|
None,
|
|
None,
|
|
Some(Role::Primary),
|
|
];
|
|
let query_parser_enabled = [false, false, false, true, false];
|
|
|
|
for (idx, role) in roles.iter().enumerate() {
|
|
let query = simple_query(&format!("SET SERVER ROLE TO '{}'", role));
|
|
assert_eq!(
|
|
qr.try_execute_command(&query),
|
|
Some((Command::SetServerRole, String::from(*role)))
|
|
);
|
|
assert_eq!(qr.role(), verify_roles[idx],);
|
|
assert_eq!(qr.query_parser_enabled(), query_parser_enabled[idx],);
|
|
|
|
// ShowServerRole
|
|
let query = simple_query("SHOW SERVER ROLE");
|
|
assert_eq!(
|
|
qr.try_execute_command(&query),
|
|
Some((Command::ShowServerRole, String::from(*role)))
|
|
);
|
|
}
|
|
|
|
let primary_reads = ["on", "off", "default"];
|
|
let primary_reads_enabled = ["on", "off", "on"];
|
|
|
|
for (idx, primary_reads) in primary_reads.iter().enumerate() {
|
|
assert_eq!(
|
|
qr.try_execute_command(&simple_query(&format!(
|
|
"SET PRIMARY READS TO {}",
|
|
primary_reads
|
|
))),
|
|
Some((Command::SetPrimaryReads, String::from(*primary_reads)))
|
|
);
|
|
assert_eq!(
|
|
qr.try_execute_command(&simple_query("SHOW PRIMARY READS")),
|
|
Some((
|
|
Command::ShowPrimaryReads,
|
|
String::from(primary_reads_enabled[idx])
|
|
))
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Exercises the query-parser toggle driven by `SET SERVER ROLE`.
///
/// After `SET SERVER ROLE TO 'auto'` the parser must report enabled and no
/// role may be pinned; parsed statements then select the role (an INSERT
/// yields the primary, a SELECT yields a replica once primary reads are off).
/// `SET SERVER ROLE TO 'default'` must leave the parser reporting disabled.
#[test]
fn test_enable_query_parser() {
    QueryRouter::setup();
    let mut router = QueryRouter::new();

    // Primary reads are switched off before the role is set to 'auto'.
    let primary_reads_off = simple_query("SET PRIMARY READS TO off");
    assert!(router.try_execute_command(&primary_reads_off).is_some());

    let role_auto = simple_query("SET SERVER ROLE TO 'auto'");
    assert!(router.try_execute_command(&role_auto).is_some());
    assert!(router.query_parser_enabled());
    assert_eq!(router.role(), None);

    // Each statement is parsed and inferred; the role must follow the statement kind.
    let statements = [
        ("INSERT INTO test_table VALUES (1)", Role::Primary),
        ("SELECT * FROM test_table", Role::Replica),
    ];
    for (sql, expected_role) in statements {
        let ast = QueryRouter::parse(&simple_query(sql)).unwrap();
        assert!(router.infer(&ast).is_ok());
        assert_eq!(router.role(), Some(expected_role));
    }

    // Inference alone must not have disabled the parser...
    assert!(router.query_parser_enabled());

    // ...but going back to the default role does.
    let role_default = simple_query("SET SERVER ROLE TO 'default'");
    assert!(router.try_execute_command(&role_default).is_some());
    assert!(!router.query_parser_enabled());
}
|
|
|
|
/// Verifies that `update_pool_settings` only changes the router's defaults.
///
/// A fresh router has no active role, no active shard and no explicit
/// parser / primary-reads overrides. After the pool settings are applied the
/// active role and shard are still unset, while the accessor methods reflect
/// the configured values. Finally, `SET SERVER ROLE TO 'default'` must bring
/// the active role back to the pool's `default_role`.
#[test]
fn test_update_from_pool_settings() {
    QueryRouter::setup();

    let defaults = PoolSettings::default();
    let settings = PoolSettings {
        pool_mode: PoolMode::Transaction,
        load_balancing_mode: crate::config::LoadBalancingMode::Random,
        shards: 2,
        user: crate::config::User::default(),
        default_role: Some(Role::Replica),
        query_parser_enabled: true,
        primary_reads_enabled: false,
        sharding_function: ShardingFunction::PgBigintHash,
        automatic_sharding_key: Some(String::from("test.id")),
        healthcheck_delay: defaults.healthcheck_delay,
        healthcheck_timeout: defaults.healthcheck_timeout,
        ban_time: defaults.ban_time,
        sharding_key_regex: None,
        shard_id_regex: None,
        regex_search_limit: 1000,
        auth_query: None,
        auth_query_password: None,
        auth_query_user: None,
        db: "test".to_string(),
    };

    // Freshly constructed router: nothing is set yet.
    let mut router = QueryRouter::new();
    assert_eq!(router.active_role, None);
    assert_eq!(router.active_shard, None);
    assert_eq!(router.query_parser_enabled, None);
    assert_eq!(router.primary_reads_enabled, None);

    // Applying pool settings must only affect defaults, not the internal state.
    router.update_pool_settings(settings.clone());

    assert_eq!(router.active_role, None);
    assert_eq!(router.active_shard, None);
    assert!(router.query_parser_enabled());
    assert!(!router.primary_reads_enabled());

    // An explicit role request is recorded as the active role.
    let set_primary = simple_query("SET SERVER ROLE TO 'primary'");
    assert!(router.try_execute_command(&set_primary).is_some());
    assert_eq!(router.active_role.unwrap(), Role::Primary);

    // 'default' falls back to the role configured on the pool.
    // NOTE(review): this compares a `Role` against an `Option<Role>`, so it
    // depends on a comparison impl defined outside this test.
    let set_default = simple_query("SET SERVER ROLE TO 'default'");
    assert!(router.try_execute_command(&set_default).is_some());
    assert_eq!(router.active_role.unwrap(), settings.default_role);
}
|
|
|
|
/// Checks role inference for messages that carry several statements at once.
///
/// Each case is parsed and inferred on the same router, in order, and the
/// resulting role is compared with the expected one.
#[test]
fn test_parse_multiple_queries() {
    QueryRouter::setup();

    let mut router = QueryRouter::new();

    let cases = [
        // Explicit transaction block.
        ("BEGIN; SELECT 1; COMMIT;", Role::Primary),
        // Reads only.
        ("SELECT 1; SELECT 2;", Role::Replica),
        // A write in the middle of reads.
        (
            "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;",
            Role::Primary,
        ),
    ];

    for (sql, expected_role) in cases {
        let statements = QueryRouter::parse(&simple_query(sql)).unwrap();
        assert!(router.infer(&statements).is_ok());
        assert_eq!(router.role(), expected_role);
    }
}
|
|
|
|
/// Verifies shard selection through the `shard_id` / `sharding_key` comment
/// regexes configured in the pool settings.
///
/// None of the queries is a custom command, so `try_execute_command` must
/// return `None` each time while still updating `active_shard`.
#[test]
fn test_regex_shard_parsing() {
    QueryRouter::setup();

    let defaults = PoolSettings::default();
    let settings = PoolSettings {
        pool_mode: PoolMode::Transaction,
        load_balancing_mode: crate::config::LoadBalancingMode::Random,
        shards: 5,
        user: crate::config::User::default(),
        default_role: Some(Role::Replica),
        query_parser_enabled: true,
        primary_reads_enabled: false,
        sharding_function: ShardingFunction::PgBigintHash,
        automatic_sharding_key: None,
        healthcheck_delay: defaults.healthcheck_delay,
        healthcheck_timeout: defaults.healthcheck_timeout,
        ban_time: defaults.ban_time,
        sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
        shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
        regex_search_limit: 1000,
        auth_query: None,
        auth_query_password: None,
        auth_query_user: None,
        db: "test".to_string(),
    };

    let mut router = QueryRouter::new();
    router.update_pool_settings(settings);

    // Shard should start out unset
    assert_eq!(router.active_shard, None);

    let cases = [
        // Make sure setting it works
        ("/* shard_id: 1 */ select 1 from foo;", 1),
        // And make sure changing it works
        ("/* shard_id: 0 */ select 1 from foo;", 0),
        // Validate setting by shard with expected shard copied from sharding.rs tests
        ("/* sharding_key: 6 */ select 1 from foo;", 2),
    ];

    for (sql, expected_shard) in cases {
        let message = simple_query(sql);
        assert!(router.try_execute_command(&message).is_none());
        assert_eq!(router.active_shard, Some(expected_shard));
    }
}
|
|
|
|
/// Verifies shard inference from the automatic sharding key.
///
/// The router is configured with three shards. The first batch of queries
/// runs with the key `data.id`, the second with the wildcard key
/// `*.unique_enough_column_name`. All queries run on the same router in
/// order, so a query that does not move the shard leaves the previous value.
#[test]
fn test_automatic_sharding_key() {
    QueryRouter::setup();

    let mut router = QueryRouter::new();
    router.pool_settings.automatic_sharding_key = Some("data.id".to_string());
    router.pool_settings.shards = 3;

    let table_qualified_cases = [
        ("SELECT * FROM data WHERE id = 5", 2),
        ("SELECT one, two, three FROM public.data WHERE id = 6", 0),
        // Multi-line literal: the embedded newlines are part of the SQL text.
        (
            "SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5",
            2,
        ),
        // Shard did not move because we couldn't determine the sharding key since it could be
        // ambiguous in the query.
        (
            "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id",
            2,
        ),
        // Quoted identifiers.
        (r#"SELECT * FROM "public"."data" WHERE "id" = 6"#, 0),
        (r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#, 2),
    ];

    for (sql, expected_shard) in table_qualified_cases {
        let statements = QueryRouter::parse(&simple_query(sql)).unwrap();
        assert!(router.infer(&statements).is_ok());
        assert_eq!(router.shard(), expected_shard);
    }

    // Super unique sharding key
    router.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());

    let wildcard_cases = [
        ("SELECT * FROM table_x WHERE unique_enough_column_name = 6", 0),
        // A different column: the shard is expected to still read 0 afterwards.
        ("SELECT * FROM table_y WHERE another_key = 5", 0),
    ];

    for (sql, expected_shard) in wildcard_cases {
        let statements = QueryRouter::parse(&simple_query(sql)).unwrap();
        assert!(router.infer(&statements).is_ok());
        assert_eq!(router.shard(), expected_shard);
    }
}
|
|
|
|
/// Verifies shard inference for a parameterised statement.
///
/// Inferring the statement records one placeholder. A hand-built `B`
/// message carrying the single text value `5` is then passed to
/// `infer_shard_from_bind`, which must pick shard 2 and clear the recorded
/// placeholders.
#[test]
fn test_prepared_statements() {
    let sql = "SELECT * FROM data WHERE id = $1";

    // Message body; field meanings follow the PostgreSQL Bind message layout.
    let mut body = BytesMut::new();
    body.put_u8(0); // portal name: empty, NUL-terminated
    body.put_u8(0); // statement name: empty, NUL-terminated
    body.put_i16(0); // number of parameter format codes
    body.put_i16(1); // number of parameter values
    body.put_i32(1); // length of the first value in bytes
    body.put(&b"5"[..]); // the value itself
    body.put_i16(0); // number of result-column format codes

    // Tag byte, then the length (which counts its own four bytes), then the body.
    let mut message = BytesMut::from(&b"B"[..]);
    message.put_i32(body.len() as i32 + 4);
    message.put(body);

    let mut router = QueryRouter::new();
    router.pool_settings.automatic_sharding_key = Some("data.id".to_string());
    router.pool_settings.shards = 3;

    let statements = QueryRouter::parse(&simple_query(sql)).unwrap();
    assert!(router.infer(&statements).is_ok());
    assert_eq!(router.placeholders.len(), 1);

    assert!(router.infer_shard_from_bind(&message));
    assert_eq!(router.shard(), 2);
    assert!(router.placeholders.is_empty());
}
|
|
|
|
#[tokio::test]
|
|
async fn test_table_access_plugin() {
|
|
use crate::config::TableAccess;
|
|
let ta = TableAccess {
|
|
enabled: true,
|
|
tables: vec![String::from("pg_database")],
|
|
};
|
|
|
|
crate::plugins::table_access::setup(&ta);
|
|
|
|
QueryRouter::setup();
|
|
|
|
let qr = QueryRouter::new();
|
|
|
|
let query = simple_query("SELECT * FROM pg_database");
|
|
let ast = QueryRouter::parse(&query).unwrap();
|
|
|
|
let res = qr.execute_plugins(&ast).await;
|
|
|
|
assert_eq!(
|
|
res,
|
|
Ok(PluginOutput::Deny(
|
|
"permission for table \"pg_database\" denied".to_string()
|
|
))
|
|
);
|
|
}
|
|
}
|