Automatic sharding for INSERT, UPDATE, and DELETE statements. (#610)

Added support for INSERT, UPDATE, and DELETE for auto-sharding.
This commit is contained in:
Mohammad Dashti
2023-10-03 09:36:13 -07:00
committed by GitHub
parent 51cd13b8b5
commit c2a483f36a

View File

@@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
TableFactor, TableWithJoins, Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
@@ -403,6 +403,9 @@ impl QueryRouter {
return Err(Error::QueryRouterParserError("empty query".into()));
}
let mut visited_write_statement = false;
let mut prev_inferred_shard = None;
for q in ast {
match q {
// All transactions go to the primary, probably a write.
@@ -420,29 +423,38 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
// This is basically building a database now :)
match self.infer_shard(query) {
Some(shard) => {
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
}
None => (),
};
let inferred_shard = self.infer_shard(query);
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
// If we already visited a write statement, we should be going to the primary.
if !visited_write_statement {
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
}
}
}
// Likely a write
_ => {
match &self.pool_settings.automatic_sharding_key {
Some(_) => {
// TODO: similar to the above, if we have multiple queries in the
// same message, we can either split them and execute them individually
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
let inferred_shard = self.infer_shard_on_write(q)?;
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
visited_write_statement = true;
self.active_role = Some(Role::Primary);
break;
}
};
}
@@ -450,6 +462,208 @@ impl QueryRouter {
Ok(())
}
/// Records the shard inferred for one statement of the current message.
///
/// `prev_inferred_shard` carries the shard inferred from earlier statements of the
/// same message. When `inferred_shard` is `None`, nothing changes. When it names a
/// shard different from the previously inferred one, a `QueryRouterParserError` is
/// returned. Otherwise the shard is stored in both `prev_inferred_shard` and
/// `self.active_shard`.
fn handle_inferred_shard(
    &mut self,
    inferred_shard: Option<usize>,
    prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> {
    // Nothing inferred for this statement: leave the routing state untouched.
    let shard = match inferred_shard {
        Some(shard) => shard,
        None => return Ok(()),
    };

    // Two statements of the same message resolved to different shards.
    let conflicts = matches!(*prev_inferred_shard, Some(previous) if previous != shard);
    if conflicts {
        debug!("Found more than one shard in the query, not supported yet");
        return Err(Error::QueryRouterParserError(
            "multiple shards in query".into(),
        ));
    }

    *prev_inferred_shard = Some(shard);
    self.active_shard = Some(shard);
    debug!("Automatically using shard: {:?}", self.active_shard);
    Ok(())
}
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
match q {
Insert {
or,
into: _,
table_name,
columns,
overwrite: _,
source,
partitioned,
after_columns,
table: _,
on: _,
returning: _,
} => {
// Not supported in postgres.
assert!(or.is_none());
assert!(partitioned.is_none());
assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names);
Self::process_query(&*source, &mut exprs, &mut table_names, &Some(columns));
}
Delete {
tables,
from,
using,
selection,
returning: _,
} => {
if let Some(expr) = selection {
exprs.push(expr.clone());
}
// Multi tables delete are not supported in postgres.
assert!(tables.is_empty());
Self::process_tables_with_join(&from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join(
using_tbl_with_join,
&mut exprs,
&mut table_names,
);
}
Self::process_selection(selection, &mut exprs);
}
Update {
table,
assignments,
from,
selection,
returning: _,
} => {
Self::process_table_with_join(table, &mut exprs, &mut table_names);
if let Some(from_tbl) = from {
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
}
Self::process_selection(selection, &mut exprs);
self.assignment_parser(assignments)?;
}
_ => {
return Ok(None);
}
};
Ok(self.infer_shard_from_exprs(exprs, table_names))
}
/// Walks a query body and gathers the inputs needed for shard inference.
///
/// * A nested query is walked recursively with the same `columns`.
/// * For a `SELECT`, the tables and join conditions of the `FROM` clause and the
///   `WHERE` clause are collected.
/// * For `VALUES`, when `columns` is provided, every value is paired positionally
///   with its column and pushed as a synthetic `column = value` expression; values
///   without a matching column are ignored.
///
/// Any other kind of query body is ignored.
fn process_query(
    query: &sqlparser::ast::Query,
    exprs: &mut Vec<Expr>,
    table_names: &mut Vec<Vec<Ident>>,
    columns: &Option<&Vec<Ident>>,
) {
    match &*query.body {
        SetExpr::Query(inner) => Self::process_query(inner, exprs, table_names, columns),

        // SELECT ... FROM ... WHERE ...
        SetExpr::Select(select) => {
            // Tables and join conditions from the FROM clause.
            Self::process_tables_with_join(&select.from, exprs, table_names);
            // The WHERE clause.
            Self::process_selection(&select.selection, exprs);
        }

        SetExpr::Values(values) => {
            let cols = match columns {
                Some(cols) => cols,
                None => return,
            };

            for row in &values.rows {
                // `zip` stops at the shorter side, so surplus values are skipped.
                exprs.extend(cols.iter().zip(row.iter()).map(|(column, value)| {
                    Expr::BinaryOp {
                        left: Box::new(Expr::Identifier(column.clone())),
                        op: BinaryOperator::Eq,
                        right: Box::new(value.clone()),
                    }
                }));
            }
        }

        _ => (),
    }
}
/// Appends a copy of the `WHERE` clause, when one is present, to `exprs`.
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
    // An `Option` iterates over zero or one item.
    exprs.extend(selection.iter().cloned());
}
/// Applies `process_table_with_join` to every entry of `tables`, in order.
fn process_tables_with_join(
    tables: &Vec<TableWithJoins>,
    exprs: &mut Vec<Expr>,
    table_names: &mut Vec<Vec<Ident>>,
) {
    tables
        .iter()
        .for_each(|entry| Self::process_table_with_join(entry, exprs, table_names));
}
/// Collects table names and inner-join `ON` conditions from one `FROM` entry.
///
/// The name of the main relation and of every joined relation is appended to
/// `table_names` when the relation is a plain table. For each inner join with an
/// `ON` constraint, the condition is appended to `exprs` so it can be inspected for
/// the sharding key later, e.g.
/// `SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;`
fn process_table_with_join(
    table: &TableWithJoins,
    exprs: &mut Vec<Expr>,
    table_names: &mut Vec<Vec<Ident>>,
) {
    if let TableFactor::Table { name, .. } = &table.relation {
        Self::process_table(name, table_names);
    }

    for join in &table.joins {
        // Table name of the joined relation.
        if let TableFactor::Table { name, .. } = &join.relation {
            Self::process_table(name, table_names);
        }

        // Only `INNER JOIN ... ON <expr>` contributes a condition.
        if let JoinOperator::Inner(JoinConstraint::On(condition)) = &join.join_operator {
            exprs.push(condition.clone());
        }
    }
}
/// Appends a copy of the identifier parts of `name` to `table_names`.
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
    let sqlparser::ast::ObjectName(parts) = name;
    table_names.push(parts.clone());
}
/// Parse the shard number from the Bind message
/// which contains the arguments for a prepared statement.
///
@@ -592,6 +806,33 @@ impl QueryRouter {
}
}
/// Inspects the `SET` assignments of an `UPDATE` statement and rejects the statement
/// when it assigns the sharding key, which is not supported yet.
///
/// The check only applies when the configured key uses the wildcard table
/// (`*.column`): the last identifier of every assignment target is compared,
/// case-insensitively, with the key column. With a table-qualified key every
/// assignment is accepted.
///
/// # Panics
///
/// Panics when `automatic_sharding_key` is not set or does not consist of exactly
/// two dot-separated parts.
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
    let configured_key = self.pool_settings.automatic_sharding_key.as_ref().unwrap();
    let sharding_key: Vec<Ident> = configured_key
        .split(".")
        .map(|part| Ident::new(part.to_lowercase()))
        .collect();

    // Sharding key must be always fully qualified
    assert_eq!(sharding_key.len(), 2);

    // The table part does not depend on the assignment, so test it once.
    if sharding_key[0].value != "*" {
        return Ok(());
    }

    let key_column = &sharding_key[1].value;
    let updates_key = assignments
        .iter()
        .any(|assignment| assignment.id.last().unwrap().value.to_lowercase() == *key_column);

    if updates_key {
        return Err(Error::QueryRouterParserError(
            "Sharding key cannot be updated.".into(),
        ));
    }

    Ok(())
}
/// A `selection` is the `WHERE` clause. This parses
/// the clause and extracts the sharding key, if present.
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
@@ -604,7 +845,7 @@ impl QueryRouter {
.as_ref()
.unwrap()
.split(".")
.map(|ident| Ident::new(ident))
.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
@@ -620,7 +861,7 @@ impl QueryRouter {
Expr::Identifier(ident) => {
// Only if we're dealing with only one table
// and there is no ambiguity
if &ident.value == &sharding_key[1].value {
if &ident.value.to_lowercase() == &sharding_key[1].value {
// Sharding key is unique enough, don't worry about
// table names.
if &sharding_key[0].value == "*" {
@@ -633,13 +874,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches
// the table name from the query.
found = &sharding_key[0].value == &table[0].value;
found = &sharding_key[0].value == &table[0].value.to_lowercase();
} else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only.
found = &sharding_key[0].value == &table[1].value;
found = &sharding_key[0].value == &table[1].value.to_lowercase();
} else {
debug!("Got table name with more than two idents, which is not possible");
}
@@ -651,8 +892,9 @@ impl QueryRouter {
// The key is fully qualified in the query,
// it will exist or Postgres will throw an error.
if idents.len() == 2 {
found = &sharding_key[0].value == &idents[0].value
&& &sharding_key[1].value == &idents[1].value;
found = (&sharding_key[0].value == "*"
|| &sharding_key[0].value == &idents[0].value.to_lowercase())
&& &sharding_key[1].value == &idents[1].value.to_lowercase();
}
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
}
@@ -705,100 +947,48 @@ impl QueryRouter {
/// Try to figure out which shard the query should go to.
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
let mut shards = BTreeSet::new();
let mut exprs = Vec::new();
match &*query.body {
SetExpr::Query(query) => {
match self.infer_shard(&*query) {
Some(shard) => {
// Collect all table names from the query.
let mut table_names = Vec::new();
Self::process_query(query, &mut exprs, &mut table_names, &None);
self.infer_shard_from_exprs(exprs, table_names)
}
fn infer_shard_from_exprs(
&mut self,
exprs: Vec<Expr>,
table_names: Vec<Vec<Ident>>,
) -> Option<usize> {
let mut shards = BTreeSet::new();
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
shards.insert(shard);
}
None => (),
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
};
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
// Collect all table names from the query.
let mut table_names = Vec::new();
for table in select.from.iter() {
match &table.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// Get table names from all the joins.
for join in table.joins.iter() {
match &join.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
match &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join {
JoinConstraint::On(expr) => {
// Parse the selection criteria later.
exprs.push(expr.clone());
}
_ => (),
},
_ => (),
};
}
}
// Parse the actual "FROM ..."
match &select.selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
shards.insert(shard);
}
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
};
}
}
}
_ => (),
};
}
match shards.len() {
// Didn't find a sharding key, you're on your own.
0 => {
@@ -1414,6 +1604,221 @@ mod test {
assert_eq!(qr.shard().unwrap(), 0);
}
/// Test helper: routes `qry` through a fresh `QueryRouter` configured with the
/// sharding key `*.w_id`, 3 shards and read/write splitting enabled.
///
/// Asserts that no shard is selected beforehand and that `infer` succeeds exactly
/// when `should_succeed` is true, then returns the shard selected by the router.
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
    let mut router = QueryRouter::new();
    router.pool_settings.automatic_sharding_key = Some(String::from("*.w_id"));
    router.pool_settings.shards = 3;
    router.pool_settings.query_parser_read_write_splitting = true;
    assert_eq!(router.shard(), None);

    let ast = router.parse(&simple_query(qry)).unwrap();
    let outcome = router.infer(&ast);
    assert_eq!(outcome.is_ok(), should_succeed);

    router.shard()
}
/// Test helper: runs `qry` through `auto_shard_wrapper`, requiring inference to succeed.
fn auto_shard(qry: &str) -> Option<usize> {
    let should_succeed = true;
    auto_shard_wrapper(qry, should_succeed)
}
/// Test helper: runs `qry` through `auto_shard_wrapper`, requiring inference to fail.
fn auto_shard_fails(qry: &str) -> Option<usize> {
    let should_succeed = false;
    auto_shard_wrapper(qry, should_succeed)
}
#[test]
fn test_automatic_sharding_insert_update_delete() {
    QueryRouter::setup();

    // Assigning the sharding key (bare or alias-qualified, any case) must be
    // rejected, leaving no shard selected.
    let rejected = [
        "UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5",
        "UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5",
    ];
    for qry in rejected.iter() {
        assert_eq!(auto_shard_fails(qry), None);
    }

    // Updating another column is accepted and routed by the key in the WHERE clause.
    let accepted =
        "UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5";
    assert_eq!(auto_shard(accepted), Some(2));
}
// Runs a series of TPC-C style statements through `auto_shard`, which requires
// inference to succeed and returns the shard the router selected. Every statement
// that filters on (or inserts) W_ID = 5 is expected to resolve to shard 2.
#[test]
fn test_automatic_sharding_key_tpcc() {
QueryRouter::setup();
assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2));
// Transaction control statements are expected to select no shard.
assert_eq!(
auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"),
None
);
assert_eq!(auto_shard("COMMIT"), None);
assert_eq!(auto_shard("ROLLBACK"), None);
// Upper-case column names, with and without a table alias.
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(
auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard(
"UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
Some(2)
);
assert_eq!(
auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
// INSERT: the W_ID column is listed and its value in the VALUES row is 5.
assert_eq!(
auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"),
Some(2)
);
// No W_ID column is referenced here, so no shard is expected.
assert_eq!(
auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"),
None
);
assert_eq!(
auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2));
assert_eq!(
auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
// Two tables, both filtered on the same W_ID value: a single shard is expected.
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 5
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
Some(2)
);
// This is a distributed query and contains two shards
// (W_ID = 5 and W_ID = 7): inference still succeeds, but no shard is selected.
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 7
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
None
);
}
#[test]
fn test_prepared_statements() {
let stmt = "SELECT * FROM data WHERE id = $1";