Files
pgcat/src/query_router.rs
Mohammad Dashti 49164a733c Revert "Fixed clippy issues on CI."
This reverts commit a9fa6ba189.
2023-10-06 15:13:29 -07:00

1891 lines
66 KiB
Rust

/// Route queries automatically based on explicitly requested
/// or implied query characteristics.
use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
use sqlparser::ast::{
Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
TableFactor, TableWithJoins, Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
use std::collections::BTreeSet;
use std::io::Cursor;
use std::{cmp, mem};
/// Regexes used to parse custom commands.
///
/// Every pattern is case-insensitive (`(?i)`), anchored to the whole query, and
/// tolerates surrounding spaces plus one optional trailing semicolon.
/// The position of a pattern in this array is the index that
/// `try_execute_command` maps to a `Command` variant, so entries must not be
/// reordered without updating that mapping.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
// 0: SET SHARDING KEY — captures a run of decimal digits, quotes optional.
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
// 1: SET SHARD — captures a run of decimal digits or the word ANY, quotes optional.
r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
// 2: SHOW SHARD — no capture.
r"(?i)^ *SHOW SHARD *;? *$",
// 3: SET SERVER ROLE — captures the role keyword; here the quotes are mandatory.
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
// 4: SHOW SERVER ROLE — no capture.
r"(?i)^ *SHOW SERVER ROLE *;? *$",
// 5: SET PRIMARY READS — captures on/off/default, quotes optional.
r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$",
// 6: SHOW PRIMARY READS — no capture.
r"(?i)^ *SHOW PRIMARY READS *;? *$",
];
/// Custom commands.
///
/// One variant per entry of `CUSTOM_SQL_REGEXES`, in the same order.
#[derive(PartialEq, Debug)]
pub enum Command {
/// `SET SHARDING KEY TO <n>`: pick the shard by running the key through the sharder.
SetShardingKey,
/// `SET SHARD TO <n|ANY>`: pick the shard directly.
SetShard,
/// `SHOW SHARD`: report the currently selected shard.
ShowShard,
/// `SET SERVER ROLE TO '<role>'`: pin or release the server role.
SetServerRole,
/// `SHOW SERVER ROLE`: report the currently selected role.
ShowServerRole,
/// `SET PRIMARY READS TO <on|off|default>`: override the primary-reads setting.
SetPrimaryReads,
/// `SHOW PRIMARY READS`: report the effective primary-reads setting.
ShowPrimaryReads,
}
/// A sharding key found while walking a query's expressions.
#[derive(PartialEq, Debug)]
pub enum ShardingKey {
/// An integer literal that appeared in the query text.
Value(i64),
/// The number of a `$n` placeholder; the actual value arrives later in a Bind message.
Placeholder(i16),
}
/// How the parameter values of a Bind message are encoded.
#[derive(Clone, Debug)]
enum ParameterFormat {
/// Format code 0; also used when the Bind message carries no format codes.
Text,
/// Format code 1.
Binary,
/// The Bind message carried exactly one format code, which applies to every parameter.
Uniform(Box<ParameterFormat>),
/// The Bind message carried several format codes, looked up by parameter index.
Specified(Vec<ParameterFormat>),
}
/// Quickly test for match when a query is received.
/// Populated once by `QueryRouter::setup` from `CUSTOM_SQL_REGEXES`.
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
/// Get the value inside the custom command.
/// Holds the same patterns as the set, compiled individually so that their
/// capture groups can be read; indexes line up with `CUSTOM_SQL_REGEXES`.
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router.
///
/// Keeps the routing decision (shard and role) together with the overrides a
/// client set through custom commands. An override stored as `None` means
/// "fall back to the pool settings".
pub struct QueryRouter {
/// Which shard we should be talking to right now.
active_shard: Option<usize>,
/// Which server should we be talking to.
active_role: Option<Role>,
/// Should we try to parse queries to route them to replicas or primary automatically
query_parser_enabled: Option<bool>,
/// Include the primary into the replica pool for reads.
primary_reads_enabled: Option<bool>,
/// Pool configuration.
pool_settings: PoolSettings,
/// Placeholders from prepared statement.
/// Filled while inferring the shard from a parsed statement and consumed
/// (then cleared) by `infer_shard_from_bind`.
placeholders: Vec<i16>,
}
impl QueryRouter {
/// One-time initialization of regexes
/// that parse our custom SQL protocol.
pub fn setup() -> bool {
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
return false;
}
};
let list: Vec<_> = CUSTOM_SQL_REGEXES
.iter()
.map(|rgx| Regex::new(rgx).unwrap())
.collect();
assert_eq!(list.len(), set.len());
match CUSTOM_SQL_REGEX_LIST.set(list) {
Ok(_) => true,
Err(_) => return false,
};
CUSTOM_SQL_REGEX_SET.set(set).is_ok()
}
/// Create a new instance of the query router.
/// Each client gets its own.
pub fn new() -> QueryRouter {
QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: None,
primary_reads_enabled: None,
pool_settings: PoolSettings::default(),
placeholders: Vec::new(),
}
}
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
}
pub fn pool_settings(&self) -> &PoolSettings {
&self.pool_settings
}
/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer);
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some()
|| self.pool_settings.sharding_key_regex.is_some();
// Check for any sharding regex matches in any queries
if comment_shard_routing_enabled {
match code {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();
let initial_segment = String::from_utf8_lossy(
&message_buffer[query_start_index..query_start_index + seg],
);
// Check for a shard_id included in the query
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
let shard_id = shard_id_regex.captures(&initial_segment).and_then(|cap| {
cap.get(1).and_then(|id| id.as_str().parse::<usize>().ok())
});
if let Some(shard_id) = shard_id {
debug!("Setting shard to {:?}", shard_id);
self.set_shard(Some(shard_id));
// Skip other command processing since a sharding command was found
return None;
}
}
// Check for a sharding_key included in the query
if let Some(sharding_key_regex) = &self.pool_settings.sharding_key_regex {
let sharding_key =
sharding_key_regex
.captures(&initial_segment)
.and_then(|cap| {
cap.get(1).and_then(|id| id.as_str().parse::<i64>().ok())
});
if let Some(sharding_key) = sharding_key {
debug!("Setting sharding_key to {:?}", sharding_key);
self.set_sharding_key(sharding_key);
// Skip other command processing since a sharding command was found
return None;
}
}
}
_ => {}
}
}
// Only simple protocol supported for commands processed below
if code != 'Q' {
return None;
}
let query = message_cursor.read_string().unwrap();
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
Some(regex_set) => regex_set,
None => return None,
};
let regex_list = match CUSTOM_SQL_REGEX_LIST.get() {
Some(regex_list) => regex_list,
None => return None,
};
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
// This is not a custom query, try to infer which
// server it'll go to if the query parser is enabled.
if matches.len() != 1 {
debug!("Regular query, not a command");
return None;
}
let command = match matches[0] {
0 => Command::SetShardingKey,
1 => Command::SetShard,
2 => Command::ShowShard,
3 => Command::SetServerRole,
4 => Command::ShowServerRole,
5 => Command::SetPrimaryReads,
6 => Command::ShowPrimaryReads,
_ => unreachable!(),
};
let mut value = match command {
Command::SetShardingKey
| Command::SetShard
| Command::SetServerRole
| Command::SetPrimaryReads => {
// Capture value. I know this re-runs the regex engine, but I haven't
// figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
//
// I think this is faster than running the Regex engine 5 times.
match regex_list[matches[0]].captures(&query) {
Some(captures) => match captures.get(1) {
Some(value) => value.as_str().to_string(),
None => return None,
},
None => return None,
}
}
Command::ShowShard => self
.shard()
.map_or_else(|| "unset".to_string(), |x| x.to_string()),
Command::ShowServerRole => match self.active_role {
Some(Role::Primary) => Role::Primary.to_string(),
Some(Role::Replica) => Role::Replica.to_string(),
Some(Role::Mirror) => Role::Mirror.to_string(),
None => {
if self.query_parser_enabled() {
String::from("auto")
} else {
String::from("any")
}
}
},
Command::ShowPrimaryReads => match self.primary_reads_enabled() {
true => String::from("on"),
false => String::from("off"),
},
};
match command {
Command::SetShardingKey => {
// TODO: some error handling here
value = self
.set_sharding_key(value.parse::<i64>().unwrap())
.unwrap()
.to_string();
}
Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards),
_ => Some(value.parse::<usize>().unwrap()),
};
}
Command::SetServerRole => {
self.active_role = match value.to_ascii_lowercase().as_ref() {
"primary" => {
self.query_parser_enabled = Some(false);
Some(Role::Primary)
}
"replica" => {
self.query_parser_enabled = Some(false);
Some(Role::Replica)
}
"any" => {
self.query_parser_enabled = Some(false);
None
}
"auto" => {
self.query_parser_enabled = Some(true);
None
}
"default" => {
self.active_role = self.pool_settings.default_role;
self.query_parser_enabled = None;
self.active_role
}
_ => unreachable!(),
};
}
Command::SetPrimaryReads => {
if value == "on" {
debug!("Setting primary reads to on");
self.primary_reads_enabled = Some(true);
} else if value == "off" {
debug!("Setting primary reads to off");
self.primary_reads_enabled = Some(false);
} else if value == "default" {
debug!("Setting primary reads to default");
self.primary_reads_enabled = None;
}
}
_ => (),
}
Some((command, value))
}
pub fn parse(&self, message: &BytesMut) -> Result<Vec<Statement>, Error> {
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
if let Some(max_length) = self.pool_settings.query_parser_max_length {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
};
let query = match code {
// Query
'Q' => {
let query = message_cursor.read_string().unwrap();
debug!("Query: '{}'", query);
query
}
// Parse (prepared statement)
'P' => {
// Reads statement name
let _name = message_cursor.read_string().unwrap();
// Reads query string
let query = message_cursor.read_string().unwrap();
debug!("Prepared statement: '{}'", query);
query
}
_ => return Err(Error::UnsupportedStatement),
};
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => Ok(ast),
Err(err) => {
debug!("{}: {}", err, query);
Err(Error::QueryRouterParserError(err.to_string()))
}
}
}
/// Decide which server role (and, with automatic sharding, which shard) the
/// parsed statements should go to.
///
/// Does nothing unless `query_parser_read_write_splitting` is enabled.
/// An empty statement list selects the primary and is reported as an error.
/// Otherwise the statements are visited in order:
///
/// * a transaction start selects the primary and stops the scan;
/// * a query selects a replica (or any server when primary reads are
///   enabled), unless a write was seen earlier in the same message;
/// * anything else is treated as a write and selects the primary.
///
/// Errors from shard inference (e.g. statements pointing at different shards)
/// are propagated.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
    if !self.pool_settings.query_parser_read_write_splitting {
        return Ok(()); // Nothing to do
    }

    debug!("Inferring role");

    if ast.is_empty() {
        // That's weird, no idea, let's go to primary
        self.active_role = Some(Role::Primary);
        return Err(Error::QueryRouterParserError("empty query".into()));
    }

    let mut saw_write = false;
    let mut last_shard = None;

    for statement in ast.iter() {
        // All transactions go to the primary, probably a write.
        if let StartTransaction { .. } = statement {
            self.active_role = Some(Role::Primary);
            break;
        }

        // With several statements in one message we only let them through
        // when they agree on the shard; `handle_inferred_shard` enforces that.
        let automatic_sharding = self.pool_settings.automatic_sharding_key.is_some();

        if let Query(query) = statement {
            // Likely a read-only query
            if automatic_sharding {
                let inferred = self.infer_shard(query);
                self.handle_inferred_shard(inferred, &mut last_shard)?;
            }

            // If we already visited a write statement, we stay on the primary.
            if !saw_write {
                self.active_role = if self.primary_reads_enabled() {
                    // Any server role is fine in this case.
                    None
                } else {
                    // The primary should not be receiving reads, use a replica.
                    Some(Role::Replica)
                };
            }
        } else {
            // Likely a write
            if automatic_sharding {
                let inferred = self.infer_shard_on_write(statement)?;
                self.handle_inferred_shard(inferred, &mut last_shard)?;
            }

            saw_write = true;
            self.active_role = Some(Role::Primary);
        }
    }

    Ok(())
}
/// Record a shard inferred for one statement of a message.
///
/// `prev_inferred_shard` carries the shard inferred for earlier statements of
/// the same message. A statement without an inferred shard changes nothing.
/// A shard that differs from the previous one is an error; otherwise the
/// shard becomes both the remembered and the active shard.
fn handle_inferred_shard(
    &mut self,
    inferred_shard: Option<usize>,
    prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> {
    let shard = match inferred_shard {
        Some(shard) => shard,
        None => return Ok(()),
    };

    match *prev_inferred_shard {
        Some(previous) if previous != shard => {
            debug!("Found more than one shard in the query, not supported yet");
            return Err(Error::QueryRouterParserError(
                "multiple shards in query".into(),
            ));
        }
        _ => {}
    }

    *prev_inferred_shard = Some(shard);
    self.active_shard = Some(shard);
    debug!("Automatically using shard: {:?}", self.active_shard);

    Ok(())
}
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
match q {
Insert {
or,
into: _,
table_name,
columns,
overwrite: _,
source,
partitioned,
after_columns,
table: _,
on: _,
returning: _,
} => {
// Not supported in postgres.
assert!(or.is_none());
assert!(partitioned.is_none());
assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names);
Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
}
Delete {
tables,
from,
using,
selection,
returning: _,
} => {
if let Some(expr) = selection {
exprs.push(expr.clone());
}
// Multi tables delete are not supported in postgres.
assert!(tables.is_empty());
Self::process_tables_with_join(from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join(
using_tbl_with_join,
&mut exprs,
&mut table_names,
);
}
Self::process_selection(selection, &mut exprs);
}
Update {
table,
assignments,
from,
selection,
returning: _,
} => {
Self::process_table_with_join(table, &mut exprs, &mut table_names);
if let Some(from_tbl) = from {
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
}
Self::process_selection(selection, &mut exprs);
self.assignment_parser(assignments)?;
}
_ => {
return Ok(None);
}
};
Ok(self.infer_shard_from_exprs(exprs, table_names))
}
fn process_query(
query: &sqlparser::ast::Query,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
columns: &Option<&Vec<Ident>>,
) {
match &*query.body {
SetExpr::Query(query) => {
Self::process_query(query, exprs, table_names, columns);
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
Self::process_tables_with_join(&select.from, exprs, table_names);
// Parse the actual "FROM ..."
Self::process_selection(&select.selection, exprs);
}
SetExpr::Values(values) => {
if let Some(cols) = columns {
for row in values.rows.iter() {
for (i, expr) in row.iter().enumerate() {
if cols.len() > i {
exprs.push(Expr::BinaryOp {
left: Box::new(Expr::Identifier(cols[i].clone())),
op: BinaryOperator::Eq,
right: Box::new(expr.clone()),
});
}
}
}
}
}
_ => (),
};
}
/// Append a copy of the `WHERE` expression to `exprs` when there is one.
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
    if let Some(condition) = selection {
        exprs.push(condition.clone());
    }
}
/// Run `process_table_with_join` over every entry of `tables`, in order.
fn process_tables_with_join(
    tables: &[TableWithJoins],
    exprs: &mut Vec<Expr>,
    table_names: &mut Vec<Vec<Ident>>,
) {
    tables
        .iter()
        .for_each(|entry| Self::process_table_with_join(entry, exprs, table_names));
}
/// Collect what one `FROM` entry contributes.
///
/// Table names are taken from the main relation and then from each joined
/// relation, as long as they are plain tables. The `ON` condition of every
/// inner join is pushed to `exprs`, because it can filter on the sharding key
/// just like a `WHERE` clause, e.g.
/// `SELECT * FROM t INNER JOIN B ON B.sharding_key = 5`.
fn process_table_with_join(
    table: &TableWithJoins,
    exprs: &mut Vec<Expr>,
    table_names: &mut Vec<Vec<Ident>>,
) {
    // Main relation first, then the joined relations in their written order.
    let relations =
        std::iter::once(&table.relation).chain(table.joins.iter().map(|join| &join.relation));

    for relation in relations {
        if let TableFactor::Table { name, .. } = relation {
            Self::process_table(name, table_names);
        }
    }

    // The conditions are parsed later, together with the selection criteria.
    for join in &table.joins {
        if let JoinOperator::Inner(JoinConstraint::On(condition)) = &join.join_operator {
            exprs.push(condition.clone());
        }
    }
}
/// Append the identifier parts of a table name to `table_names`.
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
    let sqlparser::ast::ObjectName(parts) = name;
    table_names.push(parts.clone());
}
/// Parse the shard number from the Bind message
/// which contains the arguments for a prepared statement.
///
/// N.B.: Only supports anonymous prepared statements since we don't
/// keep a cache of them in PgCat.
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
if !self.pool_settings.query_parser_read_write_splitting {
return false; // Nothing to do
}
debug!("Parsing bind message");
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32();
if code != 'B' {
debug!("Not a bind packet");
return false;
}
// Check message length
if message.len() != len as usize + 1 {
debug!(
"Message has wrong length, expected {}, but have {}",
len,
message.len()
);
return false;
}
// There are no shard keys in the prepared statement.
if self.placeholders.is_empty() {
debug!("There are no placeholders in the prepared statement that matched the automatic sharding key");
return false;
}
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
let mut shards = BTreeSet::new();
let _portal = message_cursor.read_string();
let _name = message_cursor.read_string();
let num_params = message_cursor.get_i16();
let parameter_format = match num_params {
0 => ParameterFormat::Text, // Text
1 => {
let param_format = message_cursor.get_i16();
ParameterFormat::Uniform(match param_format {
0 => Box::new(ParameterFormat::Text),
1 => Box::new(ParameterFormat::Binary),
_ => unreachable!(),
})
}
n => {
let mut v = Vec::with_capacity(n as usize);
for _ in 0..n {
let param_format = message_cursor.get_i16();
v.push(match param_format {
0 => ParameterFormat::Text,
1 => ParameterFormat::Binary,
_ => unreachable!(),
});
}
ParameterFormat::Specified(v)
}
};
let num_parameters = message_cursor.get_i16();
for i in 0..num_parameters {
let mut len = message_cursor.get_i32() as usize;
let format = match &parameter_format {
ParameterFormat::Text => ParameterFormat::Text,
ParameterFormat::Uniform(format) => *format.clone(),
ParameterFormat::Specified(formats) => formats[i as usize].clone(),
_ => unreachable!(),
};
debug!("Parameter {} (len: {}): {:?}", i, len, format);
// Postgres counts placeholders starting at 1
let placeholder = i + 1;
if self.placeholders.contains(&placeholder) {
let value = match format {
ParameterFormat::Text => {
let mut value = String::new();
while len > 0 {
value.push(message_cursor.get_u8() as char);
len -= 1;
}
match value.parse::<i64>() {
Ok(value) => value,
Err(_) => {
debug!("Error parsing bind value: {}", value);
continue;
}
}
}
ParameterFormat::Binary => match len {
2 => message_cursor.get_i16() as i64,
4 => message_cursor.get_i32() as i64,
8 => message_cursor.get_i64(),
_ => {
error!(
"Got wrong length for integer type parameter in bind: {}",
len
);
continue;
}
},
_ => unreachable!(),
};
shards.insert(sharder.shard(value));
}
}
self.placeholders.clear();
self.placeholders.shrink_to_fit();
// We only support querying one shard at a time.
// TODO: Support multi-shard queries some day.
if shards.len() == 1 {
debug!("Found one sharding key");
self.set_shard(Some(*shards.first().unwrap()));
true
} else {
debug!("Found no sharding keys");
false
}
}
/// Inspect the `SET` assignments of an `UPDATE` statement and reject updates
/// of the sharding key, which are not supported yet.
///
/// The configured key has the form `table.column`. An assignment is rejected
/// when the table part is `*` and the last identifier of the assigned column
/// equals the column part (compared in lower case).
///
/// NOTE(review): with a concrete table name in the key, no assignment is
/// rejected here — confirm that this is intended.
///
/// Panics when no automatic sharding key is configured or when it does not
/// consist of exactly two dot-separated parts.
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
    let sharding_key: Vec<Ident> = self
        .pool_settings
        .automatic_sharding_key
        .as_ref()
        .unwrap()
        .split('.')
        .map(|part| Ident::new(part.to_lowercase()))
        .collect();

    // Sharding key must be always fully qualified
    assert_eq!(sharding_key.len(), 2);

    let key_table = &sharding_key[0].value;
    let key_column = &sharding_key[1].value;

    let updates_sharding_key = assignments.iter().any(|assignment| {
        key_table == "*" && *key_column == assignment.id.last().unwrap().value.to_lowercase()
    });

    if updates_sharding_key {
        return Err(Error::QueryRouterParserError(
            "Sharding key cannot be updated.".into(),
        ));
    }

    Ok(())
}
/// A `selection` is the `WHERE` clause. This parses
/// the clause and extracts the sharding key, if present.
///
/// Only binary expressions are inspected. The left side decides whether the
/// expression refers to the configured sharding key (`table.column`, where
/// the table may be `*`); the right side supplies the value: an integer
/// literal or a `$n` placeholder. Nested binary expressions on either side
/// are parsed recursively. Operators other than `=`, `AND` and `OR` make the
/// whole expression yield no keys.
///
/// Both literals and placeholders are only reported when the left side
/// matched the sharding key, so placeholders compared against other columns
/// are not mistaken for sharding keys.
///
/// Panics when no automatic sharding key is configured or when it does not
/// consist of exactly two dot-separated parts.
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
    let mut result = Vec::new();
    let mut found = false;

    let sharding_key = self
        .pool_settings
        .automatic_sharding_key
        .as_ref()
        .unwrap()
        .split('.')
        .map(|ident| Ident::new(ident.to_lowercase()))
        .collect::<Vec<Ident>>();

    // Sharding key must be always fully qualified
    assert_eq!(sharding_key.len(), 2);

    // This parses `sharding_key = 5`. But it's technically
    // legal to write `5 = sharding_key`. I don't judge the people
    // who do that, but I think ORMs will still use the first variant,
    // so we can leave the second as a TODO.
    if let Expr::BinaryOp { left, op, right } = expr {
        match &**left {
            Expr::BinaryOp { .. } => result.extend(self.selection_parser(left, table_names)),
            Expr::Identifier(ident) => {
                if ident.value.to_lowercase() == sharding_key[1].value {
                    if &sharding_key[0].value == "*" {
                        // Sharding key is unique enough, don't worry about
                        // table names.
                        found = true;
                    } else if table_names.len() == 1 {
                        // Only if we're dealing with only one table
                        // and there is no ambiguity
                        let table = &table_names[0];

                        if table.len() == 1 {
                            // Table is not fully qualified, e.g.
                            //      SELECT * FROM t WHERE sharding_key = 5
                            // Make sure the table name from the sharding key matches
                            // the table name from the query.
                            found = sharding_key[0].value == table[0].value.to_lowercase();
                        } else if table.len() == 2 {
                            // Table name is fully qualified with the schema: e.g.
                            //      SELECT * FROM public.t WHERE sharding_key = 5
                            // Ignore the schema (TODO: at some point, we want schema support)
                            // and use the table name only.
                            found = sharding_key[0].value == table[1].value.to_lowercase();
                        } else {
                            debug!("Got table name with more than two idents, which is not possible");
                        }
                    }
                }
            }

            Expr::CompoundIdentifier(idents) => {
                // The key is fully qualified in the query,
                // it will exist or Postgres will throw an error.
                if idents.len() == 2 {
                    found = (&sharding_key[0].value == "*"
                        || sharding_key[0].value == idents[0].value.to_lowercase())
                        && sharding_key[1].value == idents[1].value.to_lowercase();
                }
                // TODO: key can have schema as well, e.g. public.data.id (len == 3)
            }
            _ => (),
        };

        match op {
            BinaryOperator::Eq => (),
            BinaryOperator::Or => (),
            BinaryOperator::And => (),
            _ => {
                // TODO: support other operators than equality.
                debug!("Unsupported operation: {:?}", op);
                return Vec::new();
            }
        };

        match &**right {
            Expr::BinaryOp { .. } => result.extend(self.selection_parser(right, table_names)),
            Expr::Value(Value::Number(value, ..)) => {
                if found {
                    match value.parse::<i64>() {
                        Ok(value) => result.push(ShardingKey::Value(value)),
                        Err(_) => {
                            debug!("Sharding key was not an integer: {}", value);
                        }
                    };
                }
            }

            Expr::Value(Value::Placeholder(placeholder)) => {
                // Same rule as for literals: the placeholder only counts when
                // it is compared against the sharding key.
                if found {
                    match placeholder.replace('$', "").parse::<i16>() {
                        Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
                        Err(_) => {
                            debug!(
                                "Prepared statement didn't have integer placeholders: {}",
                                placeholder
                            );
                        }
                    }
                }
            }
            _ => (),
        };
    }

    debug!("Sharding keys found: {:?}", result);

    result
}
/// Work out the shard a read query should be sent to, if it can be told.
///
/// Gathers the query's table names and filter expressions (no column list is
/// passed, so `VALUES` bodies contribute nothing) and delegates the decision
/// to `infer_shard_from_exprs`.
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
    let (mut filters, mut tables) = (Vec::new(), Vec::new());

    Self::process_query(query, &mut filters, &mut tables, &None);

    self.infer_shard_from_exprs(filters, tables)
}
/// Turn collected filter expressions into a shard number.
///
/// Every expression (join condition or selection) is searched for sharding
/// keys. Literal keys are mapped to shards right away; placeholder keys are
/// remembered in `self.placeholders` so that a later Bind message can supply
/// the value.
///
/// Returns the shard when the literal keys map to exactly one shard and
/// `None` when they map to none or to several.
fn infer_shard_from_exprs(
    &mut self,
    exprs: Vec<Expr>,
    table_names: Vec<Vec<Ident>>,
) -> Option<usize> {
    let sharder = Sharder::new(
        self.pool_settings.shards,
        self.pool_settings.sharding_function,
    );
    let mut shards = BTreeSet::new();

    for expr in &exprs {
        for key in self.selection_parser(expr, &table_names) {
            match key {
                ShardingKey::Value(value) => {
                    shards.insert(sharder.shard(value));
                }
                // Position of the value in the `B` message.
                ShardingKey::Placeholder(position) => self.placeholders.push(position),
            }
        }
    }

    // TODO: support querying multiple shards (some day...)
    if shards.len() > 1 {
        debug!("More than one sharding key found");
        return None;
    }

    let shard = shards.into_iter().next();
    if shard.is_none() {
        // Didn't find a sharding key, you're on your own.
        debug!("No sharding keys found");
    }

    shard
}
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
let plugins = match self.pool_settings.plugins {
Some(ref plugins) => plugins,
None => return Ok(PluginOutput::Allow),
};
if let Some(ref query_logger) = plugins.query_logger {
let mut query_logger = QueryLogger {
enabled: query_logger.enabled,
user: &self.pool_settings.user.username,
db: &self.pool_settings.db,
};
let _ = query_logger.run(self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: intercept,
};
let result = intercept.run(self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
}
}
if let Some(ref table_access) = plugins.table_access {
let mut table_access = TableAccess {
enabled: table_access.enabled,
tables: &table_access.tables,
};
let result = table_access.run(self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
}
}
Ok(PluginOutput::Allow)
}
/// Map `sharding_key` to a shard with the pool's shard count and sharding
/// function, make that shard the active one and return the active shard.
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
    let target = Sharder::new(
        self.pool_settings.shards,
        self.pool_settings.sharding_function,
    )
    .shard(sharding_key);

    self.set_shard(Some(target));
    self.active_shard
}
/// Get the current desired server role we should be talking to.
pub fn role(&self) -> Option<Role> {
self.active_role
}
/// Get desired shard we should be talking to.
pub fn shard(&self) -> Option<usize> {
self.active_shard
}
pub fn set_shard(&mut self, shard: Option<usize>) {
self.active_shard = shard;
}
/// Whether queries should be parsed for routing.
///
/// A per-client override wins; without one the pool setting applies.
/// The source of the answer is logged at debug level.
pub fn query_parser_enabled(&self) -> bool {
    if let Some(value) = self.query_parser_enabled {
        debug!(
            "Using query parser override, query_parser_enabled: {}",
            value
        );
        return value;
    }

    debug!(
        "Using pool settings, query_parser_enabled: {}",
        self.pool_settings.query_parser_enabled
    );
    self.pool_settings.query_parser_enabled
}
/// Whether the primary may serve reads: the per-client override when set,
/// the pool setting otherwise.
pub fn primary_reads_enabled(&self) -> bool {
    self.primary_reads_enabled
        .unwrap_or(self.pool_settings.primary_reads_enabled)
}
}
impl Default for QueryRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::config::PoolMode;
use crate::messages::simple_query;
use crate::sharding::ShardingFunction;
use bytes::BufMut;
/// A fresh router has no role selected.
#[test]
fn test_defaults() {
    QueryRouter::setup();

    let router = QueryRouter::new();

    assert!(router.role().is_none());
}
/// With the parser on and primary reads off, reads are routed to a replica.
#[test]
fn test_infer_replica() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    qr.pool_settings.query_parser_read_write_splitting = true;

    let enable_parser = simple_query("SET SERVER ROLE TO 'auto'");
    assert!(qr.try_execute_command(&enable_parser).is_some());
    assert!(qr.query_parser_enabled());

    let disable_primary_reads = simple_query("SET PRIMARY READS TO off");
    assert!(qr.try_execute_command(&disable_primary_reads).is_some());

    let read_queries = [
        "SELECT * FROM items WHERE id = 5",
        "SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id",
        "WITH t AS (SELECT * FROM items) SELECT * FROM t",
    ];

    for sql in read_queries.iter() {
        let message = simple_query(sql);
        let ast = qr.parse(&message).unwrap();

        // It's a recognized query
        assert!(qr.infer(&ast).is_ok());
        assert_eq!(qr.role(), Some(Role::Replica));
    }
}
/// Writes and transaction starts are routed to the primary.
#[test]
fn test_infer_primary() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    qr.pool_settings.query_parser_read_write_splitting = true;

    let write_queries = [
        "UPDATE items SET name = 'pumpkin' WHERE id = 5",
        "INSERT INTO items (id, name) VALUES (5, 'pumpkin')",
        "DELETE FROM items WHERE id = 5",
        "BEGIN", // Transaction start
    ];

    for sql in write_queries.iter() {
        let message = simple_query(sql);
        let ast = qr.parse(&message).unwrap();

        // It's a recognized query
        assert!(qr.infer(&ast).is_ok());
        assert_eq!(qr.role(), Some(Role::Primary));
    }
}
/// With primary reads on, a read may go to any server, so no role is selected.
#[test]
fn test_infer_primary_reads_enabled() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    // `infer` returns early unless read/write splitting is enabled; switch it
    // on so that the assertion below really exercises the inference.
    qr.pool_settings.query_parser_read_write_splitting = true;
    let query = simple_query("SELECT * FROM items WHERE id = 5");

    assert!(qr
        .try_execute_command(&simple_query("SET PRIMARY READS TO on"))
        .is_some());

    assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
    assert_eq!(qr.role(), None);
}
/// A read arriving as a Parse (`P`) message is routed like a simple query.
#[test]
fn test_infer_parse_prepared() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    qr.pool_settings.query_parser_read_write_splitting = true;

    qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
    let primary_reads_off = qr.try_execute_command(&simple_query("SET PRIMARY READS TO off"));
    assert!(primary_reads_off.is_some());

    // Null-terminated statement text.
    let statement =
        b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0";

    let mut res = BytesMut::new();
    res.put_u8(b'P');
    // Length field + empty statement name + statement text + parameter count.
    res.put_i32(statement.len() as i32 + 4 + 1 + 2);
    res.put_u8(0);
    res.put_slice(statement);
    res.put_i16(0);

    let ast = qr.parse(&res).unwrap();
    assert!(qr.infer(&ast).is_ok());
    assert_eq!(qr.role(), Some(Role::Replica));
}
/// Every sample command matches exactly one custom regex — the expected one —
/// and ordinary queries match none.
#[test]
fn test_regex_set() {
    QueryRouter::setup();

    let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
    let set = CUSTOM_SQL_REGEX_SET.get().unwrap();

    let tests = [
        // Upper case
        "SET SHARDING KEY TO '1'",
        "SET SHARD TO '1'",
        "SHOW SHARD",
        "SET SERVER ROLE TO 'replica'",
        "SET SERVER ROLE TO 'primary'",
        "SET SERVER ROLE TO 'any'",
        "SET SERVER ROLE TO 'auto'",
        "SHOW SERVER ROLE",
        "SET PRIMARY READS TO 'on'",
        "SET PRIMARY READS TO 'off'",
        "SET PRIMARY READS TO 'default'",
        "SHOW PRIMARY READS",
        // Lower case
        "set sharding key to '1'",
        "set shard to '1'",
        "show shard",
        "set server role to 'replica'",
        "set server role to 'primary'",
        "set server role to 'any'",
        "set server role to 'auto'",
        "show server role",
        "set primary reads to 'on'",
        "set primary reads to 'OFF'",
        "set primary reads to 'deFaUlt'",
        // No quotes
        "SET SHARDING KEY TO 11235",
        "SET SHARD TO 15",
        "SET PRIMARY READS TO off",
        // Spaces and semicolon
        " SET SHARDING KEY TO 11235 ; ",
        " SET SHARD TO 15; ",
        " SET SHARDING KEY TO 11235 ;",
        " SET SERVER ROLE TO 'primary'; ",
        " SET SERVER ROLE TO 'primary' ; ",
        " SET SERVER ROLE TO 'primary' ;",
        " SET PRIMARY READS TO 'off' ;",
    ];

    // Which regexes it'll match to in the list
    let matches = [
        0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0,
        3, 3, 3, 5,
    ];

    for (test, regex_index) in tests.iter().zip(matches.iter()) {
        let regex = &list[*regex_index];

        if !regex.is_match(test) {
            println!("{} does not match {}", test, regex);
            panic!();
        }

        assert_eq!(set.matches(test).into_iter().count(), 1);
    }

    let bad = [
        "SELECT * FROM table",
        "SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query
    ];

    for query in bad.iter() {
        assert_eq!(set.matches(query).into_iter().count(), 0);
    }
}
/// Walks through every custom command and checks both the returned
/// `(command, value)` pair and the resulting router state.
#[test]
fn test_try_execute_command() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();

    // SetShardingKey
    assert_eq!(
        qr.try_execute_command(&simple_query("SET SHARDING KEY TO 13")),
        Some((Command::SetShardingKey, "0".to_string()))
    );
    assert_eq!(qr.shard().unwrap(), 0);

    // SetShard
    assert_eq!(
        qr.try_execute_command(&simple_query("SET SHARD TO '1'")),
        Some((Command::SetShard, "1".to_string()))
    );
    assert_eq!(qr.shard().unwrap(), 1);

    // ShowShard
    assert_eq!(
        qr.try_execute_command(&simple_query("SHOW SHARD")),
        Some((Command::ShowShard, "1".to_string()))
    );

    // SetServerRole: (requested role, expected active role, expected parser state)
    let role_cases = [
        ("primary", Some(Role::Primary), false),
        ("replica", Some(Role::Replica), false),
        ("any", None, false),
        ("auto", None, true),
        ("primary", Some(Role::Primary), false),
    ];

    for (role, expected_role, expected_parser_enabled) in role_cases.iter() {
        let set_role = simple_query(&format!("SET SERVER ROLE TO '{}'", role));
        assert_eq!(
            qr.try_execute_command(&set_role),
            Some((Command::SetServerRole, role.to_string()))
        );
        assert_eq!(qr.role(), *expected_role);
        assert_eq!(qr.query_parser_enabled(), *expected_parser_enabled);

        // ShowServerRole
        assert_eq!(
            qr.try_execute_command(&simple_query("SHOW SERVER ROLE")),
            Some((Command::ShowServerRole, role.to_string()))
        );
    }

    // SetPrimaryReads / ShowPrimaryReads: requested value and what SHOW reports afterwards.
    let primary_reads = ["on", "off", "default"];
    let primary_reads_enabled = ["on", "off", "on"];

    for (requested, reported) in primary_reads.iter().zip(primary_reads_enabled.iter()) {
        let set_primary_reads = simple_query(&format!("SET PRIMARY READS TO {}", requested));
        assert_eq!(
            qr.try_execute_command(&set_primary_reads),
            Some((Command::SetPrimaryReads, requested.to_string()))
        );
        assert_eq!(
            qr.try_execute_command(&simple_query("SHOW PRIMARY READS")),
            Some((Command::ShowPrimaryReads, reported.to_string()))
        );
    }
}
/// Checks that `SET SERVER ROLE TO 'auto'` turns the query parser on, that
/// inferred statements then pick the role (INSERT -> primary, SELECT ->
/// replica), and that `SET SERVER ROLE TO 'default'` turns the parser off.
#[test]
fn test_enable_query_parser() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    qr.pool_settings.query_parser_read_write_splitting = true;

    // Primary reads are switched off before the role is set to 'auto'.
    let primary_reads_off = simple_query("SET PRIMARY READS TO off");
    assert!(qr.try_execute_command(&primary_reads_off).is_some());

    let role_auto = simple_query("SET SERVER ROLE TO 'auto'");
    assert!(qr.try_execute_command(&role_auto).is_some());
    assert!(qr.query_parser_enabled());
    assert_eq!(qr.role(), None);

    // A write is routed to the primary.
    let insert_ast = qr
        .parse(&simple_query("INSERT INTO test_table VALUES (1)"))
        .unwrap();
    assert!(qr.infer(&insert_ast).is_ok());
    assert_eq!(qr.role(), Some(Role::Primary));

    // A read is routed to a replica, and the parser stays enabled.
    let select_ast = qr.parse(&simple_query("SELECT * FROM test_table")).unwrap();
    assert!(qr.infer(&select_ast).is_ok());
    assert_eq!(qr.role(), Some(Role::Replica));
    assert!(qr.query_parser_enabled());

    // Going back to the default role disables the parser again.
    let role_default = simple_query("SET SERVER ROLE TO 'default'");
    assert!(qr.try_execute_command(&role_default).is_some());
    assert!(!qr.query_parser_enabled());
}
/// Checks that `update_pool_settings` only changes the defaults the router
/// falls back to, not its explicitly tracked state, and that
/// `SET SERVER ROLE TO 'default'` restores the configured default role.
#[test]
fn test_update_from_pool_settings() {
    QueryRouter::setup();

    // Timing-related fields are taken unchanged from the stock settings.
    let defaults = PoolSettings::default();
    let pool_settings = PoolSettings {
        // Pool topology.
        pool_mode: PoolMode::Transaction,
        load_balancing_mode: crate::config::LoadBalancingMode::Random,
        shards: 2,
        user: crate::config::User::default(),
        // Routing defaults under test.
        default_role: Some(Role::Replica),
        query_parser_enabled: true,
        query_parser_max_length: None,
        query_parser_read_write_splitting: true,
        primary_reads_enabled: false,
        // Sharding.
        sharding_function: ShardingFunction::PgBigintHash,
        automatic_sharding_key: Some(String::from("test.id")),
        // Health checks and banning.
        healthcheck_delay: defaults.healthcheck_delay,
        healthcheck_timeout: defaults.healthcheck_timeout,
        ban_time: defaults.ban_time,
        // Comment-based shard selection is not used in this test.
        sharding_key_regex: None,
        shard_id_regex: None,
        default_shard: crate::config::DefaultShard::Shard(0),
        regex_search_limit: 1000,
        // Authentication and miscellaneous.
        auth_query: None,
        auth_query_password: None,
        auth_query_user: None,
        db: String::from("test"),
        plugins: None,
    };

    // A new router starts with nothing set explicitly.
    let mut qr = QueryRouter::new();
    assert_eq!(qr.active_role, None);
    assert_eq!(qr.active_shard, None);
    assert_eq!(qr.query_parser_enabled, None);
    assert_eq!(qr.primary_reads_enabled, None);

    // Internal state must not be changed due to this, only defaults
    qr.update_pool_settings(pool_settings.clone());
    assert_eq!(qr.active_role, None);
    assert_eq!(qr.active_shard, None);
    assert!(qr.query_parser_enabled());
    assert!(!qr.primary_reads_enabled());

    // An explicit role is stored on the router...
    let role_primary = simple_query("SET SERVER ROLE TO 'primary'");
    assert!(qr.try_execute_command(&role_primary).is_some());
    assert_eq!(qr.active_role.unwrap(), Role::Primary);

    // ...and 'default' brings back the role configured in the pool settings.
    let role_default = simple_query("SET SERVER ROLE TO 'default'");
    assert!(qr.try_execute_command(&role_default).is_some());
    assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
}
/// Feeds several multi-statement query strings through `parse` + `infer`
/// and checks the role the router reports after each one.
#[test]
fn test_parse_multiple_queries() {
    QueryRouter::setup();
    let mut qr = QueryRouter::new();

    // Each row: (query text containing several statements, role expected afterwards).
    let cases = [
        ("BEGIN; SELECT 1; COMMIT;", Role::Primary),
        ("SELECT 1; SELECT 2;", Role::Replica),
        (
            "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;",
            Role::Primary,
        ),
    ];

    for (sql, expected_role) in cases.iter() {
        let ast = qr.parse(&simple_query(sql)).unwrap();
        assert!(qr.infer(&ast).is_ok());
        assert_eq!(qr.role(), *expected_role);
    }
}
/// Verifies that the active shard can be selected through SQL comments
/// matched by `shard_id_regex` and `sharding_key_regex`, and that messages
/// too short to contain such a comment are handled without panicking.
#[test]
fn test_regex_shard_parsing() {
    QueryRouter::setup();
    let pool_settings = PoolSettings {
        pool_mode: PoolMode::Transaction,
        load_balancing_mode: crate::config::LoadBalancingMode::Random,
        shards: 5,
        user: crate::config::User::default(),
        default_role: Some(Role::Replica),
        query_parser_enabled: true,
        query_parser_max_length: None,
        query_parser_read_write_splitting: true,
        primary_reads_enabled: false,
        sharding_function: ShardingFunction::PgBigintHash,
        automatic_sharding_key: None,
        healthcheck_delay: PoolSettings::default().healthcheck_delay,
        healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
        ban_time: PoolSettings::default().ban_time,
        // The two regexes under test: each captures the number inside a leading comment.
        sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
        shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
        default_shard: crate::config::DefaultShard::Shard(0),
        regex_search_limit: 1000,
        auth_query: None,
        auth_query_password: None,
        auth_query_user: None,
        db: "test".to_string(),
        plugins: None,
    };
    let mut qr = QueryRouter::new();
    // The settings are not read again below, so move them in instead of cloning.
    qr.update_pool_settings(pool_settings);

    // Shard should start out unset
    assert_eq!(qr.active_shard, None);

    // Don't panic when short query eg. ; is sent
    let q0 = simple_query(";");
    assert!(qr.try_execute_command(&q0).is_none());
    assert_eq!(qr.active_shard, None);

    // Make sure setting it works
    let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
    assert!(qr.try_execute_command(&q1).is_none());
    assert_eq!(qr.active_shard, Some(1));

    // And make sure changing it works
    let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
    assert!(qr.try_execute_command(&q2).is_none());
    assert_eq!(qr.active_shard, Some(0));

    // Validate setting by shard with expected shard copied from sharding.rs tests
    let q3 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
    assert!(qr.try_execute_command(&q3).is_none());
    assert_eq!(qr.active_shard, Some(2));
}
/// Checks shard inference from the configured `automatic_sharding_key`,
/// first with a table-qualified key (`data.id`) and then with a wildcard
/// table (`*.unique_enough_column_name`).
#[test]
fn test_automatic_sharding_key() {
    /// Parses `sql`, requires `infer` to succeed, and returns the shard the
    /// router reports afterwards.
    fn shard_after(qr: &mut QueryRouter, sql: &str) -> usize {
        let ast = qr.parse(&simple_query(sql)).unwrap();
        assert!(qr.infer(&ast).is_ok());
        qr.shard().unwrap()
    }

    QueryRouter::setup();
    let mut qr = QueryRouter::new();
    qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
    qr.pool_settings.shards = 3;
    qr.pool_settings.query_parser_read_write_splitting = true;

    // Unqualified column on the sharded table.
    assert_eq!(shard_after(&mut qr, "SELECT * FROM data WHERE id = 5"), 2);

    // Schema-qualified table name.
    assert_eq!(
        shard_after(
            &mut qr,
            "SELECT one, two, three FROM public.data WHERE id = 6"
        ),
        0
    );

    // Key referenced both in the join condition and in the WHERE clause.
    assert_eq!(
        shard_after(
            &mut qr,
            "SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5"
        ),
        2
    );

    // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
    // in the query.
    assert_eq!(
        shard_after(
            &mut qr,
            "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
        ),
        2
    );

    // Quoted identifiers.
    assert_eq!(
        shard_after(&mut qr, r#"SELECT * FROM "public"."data" WHERE "id" = 6"#),
        0
    );
    assert_eq!(
        shard_after(
            &mut qr,
            r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
        ),
        2
    );

    // Super unique sharding key
    qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
    assert_eq!(
        shard_after(
            &mut qr,
            "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
        ),
        0
    );

    // A query that does not mention the key leaves the shard where it was.
    assert_eq!(
        shard_after(&mut qr, "SELECT * FROM table_y WHERE another_key = 5"),
        0
    );
}
/// Builds a fresh router with 3 shards and the wildcard sharding key
/// `*.w_id`, runs `qry` through `parse` + `infer`, asserts that inference
/// succeeded or failed as dictated by `should_succeed`, and returns the
/// shard the router reports afterwards.
///
/// Does not call `QueryRouter::setup()`; the tests using it do so themselves.
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
    let mut router = QueryRouter::new();
    router.pool_settings.shards = 3;
    router.pool_settings.automatic_sharding_key = Some(String::from("*.w_id"));
    router.pool_settings.query_parser_read_write_splitting = true;

    // A brand-new router must not have a shard selected yet.
    assert_eq!(router.shard(), None);

    let ast = router.parse(&simple_query(qry)).unwrap();
    let outcome = router.infer(&ast);
    assert_eq!(outcome.is_ok(), should_succeed);

    router.shard()
}
/// Runs `qry` through `auto_shard_wrapper`, requiring inference to succeed,
/// and returns the resulting shard (if any).
fn auto_shard(qry: &str) -> Option<usize> {
    let must_succeed = true;
    auto_shard_wrapper(qry, must_succeed)
}
/// Runs `qry` through `auto_shard_wrapper`, requiring inference to fail,
/// and returns the resulting shard (if any).
fn auto_shard_fails(qry: &str) -> Option<usize> {
    let must_succeed = false;
    auto_shard_wrapper(qry, must_succeed)
}
/// Checks that an UPDATE assigning to the sharding key column is rejected
/// by inference and selects no shard, while an UPDATE that only filters on
/// the key is routed normally.
#[test]
fn test_automatic_sharding_insert_update_delete() {
    QueryRouter::setup();

    // Both statements assign to `w_id` (bare or alias-qualified): inference
    // must fail and no shard may be selected.
    let rejected = [
        "UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5",
        "UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5",
    ];
    for sql in rejected.iter() {
        assert_eq!(auto_shard_fails(sql), None);
    }

    // Updating a different column while filtering on the key is fine.
    let accepted =
        "UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5";
    assert_eq!(auto_shard(accepted), Some(2));
}
/// Runs a series of TPC-C style statements through `auto_shard` (wildcard
/// key `*.w_id`, 3 shards) and checks the shard each one resolves to.
/// `None` means the statement must not select a shard.
#[test]
fn test_automatic_sharding_key_tpcc() {
    QueryRouter::setup();

    // Each row: (statement, shard expected from `auto_shard`). Every case
    // gets its own fresh router, so rows are independent of each other.
    let cases: &[(&str, Option<usize>)] = &[
        ("SELECT * FROM my_tbl WHERE w_id = 5", Some(2)),
        // Transaction control statements carry no sharding key.
        ("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ", None),
        ("COMMIT", None),
        ("ROLLBACK", None),
        ("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3", Some(2)),
        ("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3", Some(2)),
        ("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3", Some(2)),
        ("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5", Some(2)),
        ("UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5", Some(2)),
        ("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5", Some(2)),
        ("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5", Some(2)),
        ("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5", Some(2)),
        ("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5", Some(2)),
        ("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5", Some(2)),
        ("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5", Some(2)),
        ("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3", Some(2)),
        ("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)", Some(2)),
        ("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)", Some(2)),
        // No `W_ID` predicate here, so no shard is selected.
        ("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3", None),
        ("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5", Some(2)),
        ("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5", Some(2)),
        ("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)", Some(2)),
        ("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3", Some(2)),
        ("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST", Some(2)),
        ("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3", Some(2)),
        ("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3", Some(2)),
        ("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5", Some(2)),
        ("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5", Some(2)),
        ("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3", Some(2)),
        ("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3", Some(2)),
        ("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3", Some(2)),
        ("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST", Some(2)),
        ("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3", Some(2)),
        ("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3", Some(2)),
        ("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)", Some(2)),
        ("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3", Some(2)),
        // Two tables, both constrained to the same key value.
        (
            "SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 5
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3",
            Some(2),
        ),
        // This is a distributed query and contains two shards
        (
            "SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 7
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3",
            None,
        ),
    ];

    for (sql, expected_shard) in cases {
        assert_eq!(
            auto_shard(sql),
            *expected_shard,
            "unexpected shard for query: {}",
            sql
        );
    }
}
/// Checks that a `$1` placeholder found while inferring a statement is
/// remembered by the router and later resolved to a shard from the value
/// carried in a hand-built `B` message.
#[test]
fn test_prepared_statements() {
    let mut qr = QueryRouter::new();
    qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
    qr.pool_settings.shards = 3;
    qr.pool_settings.query_parser_read_write_splitting = true;

    // Inferring the statement records one placeholder for the sharding key.
    let stmt = "SELECT * FROM data WHERE id = $1";
    let ast = qr.parse(&simple_query(stmt)).unwrap();
    assert!(qr.infer(&ast).is_ok());
    assert_eq!(qr.placeholders.len(), 1);

    // Message body. NOTE(review): the field meanings below assume the
    // PostgreSQL Bind message layout; confirm against `infer_shard_from_bind`.
    let mut payload = BytesMut::from(&b"\0\0"[..]); // two NUL bytes (presumably empty portal and statement names)
    payload.put_i16(0); // presumably: number of parameter format codes
    payload.put_i16(1); // presumably: number of parameters
    payload.put_i32(1); // presumably: byte length of the parameter value
    payload.put(&b"5"[..]); // the value itself
    payload.put_i16(0); // presumably: number of result format codes

    // Frame it: tag byte, then a length that counts its own 4 bytes plus the body.
    let framed_len = payload.len() as i32 + 4;
    let mut bind = BytesMut::from(&b"B"[..]);
    bind.put_i32(framed_len);
    bind.put(payload);

    // The bound value 5 resolves to shard 2 and the placeholder list is emptied.
    assert!(qr.infer_shard_from_bind(&bind));
    assert_eq!(qr.shard().unwrap(), 2);
    assert!(qr.placeholders.is_empty());
}
/// Enables the table-access plugin with `pg_database` in its table list and
/// checks that a query touching that table is denied by `execute_plugins`.
#[tokio::test]
async fn test_table_access_plugin() {
    use crate::config::{Plugins, TableAccess};

    QueryRouter::setup();

    // Only the table-access plugin is configured; every other plugin is off.
    let pool_settings = PoolSettings {
        query_parser_enabled: true,
        plugins: Some(Plugins {
            table_access: Some(TableAccess {
                enabled: true,
                tables: vec!["pg_database".to_string()],
            }),
            intercept: None,
            query_logger: None,
            prewarmer: None,
        }),
        ..Default::default()
    };

    let mut qr = QueryRouter::new();
    qr.update_pool_settings(pool_settings);

    let ast = qr
        .parse(&simple_query("SELECT * FROM pg_database"))
        .unwrap();
    let outcome = qr.execute_plugins(&ast).await;

    let denied = PluginOutput::Deny("permission for table \"pg_database\" denied".to_string());
    assert_eq!(outcome, Ok(denied));
}
/// Checks that a router left on its default pool settings lets a query on
/// `pg_database` through, i.e. `execute_plugins` returns `Allow`.
///
/// NOTE(review): the function name contains a typo ("defaault"); it is kept
/// as-is so existing test filters keep matching.
#[tokio::test]
async fn test_plugins_disabled_by_defaault() {
    QueryRouter::setup();
    let qr = QueryRouter::new();

    let ast = qr
        .parse(&simple_query("SELECT * FROM pg_database"))
        .unwrap();
    let outcome = qr.execute_plugins(&ast).await;

    assert_eq!(outcome, Ok(PluginOutput::Allow));
}
}