diff --git a/pgcat.toml b/pgcat.toml index 26f9d7d..b5328b6 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -102,7 +102,7 @@ primary_reads_enabled = true sharding_function = "pg_bigint_hash" # Automatically parse this from queries and route queries to the right shard! -automatic_sharding_key = "id" +automatic_sharding_key = "data.id" # Idle timeout can be overwritten in the pool idle_timeout = 40000 diff --git a/src/config.rs b/src/config.rs index 2ed7aeb..517cabc 100644 --- a/src/config.rs +++ b/src/config.rs @@ -374,7 +374,7 @@ impl Pool { None } - pub fn validate(&self) -> Result<(), Error> { + pub fn validate(&mut self) -> Result<(), Error> { match self.default_role.as_ref() { "any" => (), "primary" => (), @@ -414,6 +414,25 @@ impl Pool { } } + self.automatic_sharding_key = match &self.automatic_sharding_key { + Some(key) => { + // No quotes in the key so we don't have to compare quoted + // to unquoted idents. + let key = key.replace("\"", ""); + + if key.split(".").count() != 2 { + error!( + "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", + key, key + ); + return Err(Error::BadConfig); + } + + Some(key) + } + None => None, + }; + Ok(()) } } diff --git a/src/query_router.rs b/src/query_router.rs index bf07db7..fff5bba 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -5,7 +5,9 @@ use log::{debug, error}; use once_cell::sync::OnceCell; use regex::{Regex, RegexSet}; use sqlparser::ast::Statement::{Query, StartTransaction}; -use sqlparser::ast::{BinaryOperator, Expr, SetExpr, Value}; +use sqlparser::ast::{ + BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value, +}; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; @@ -403,20 +405,67 @@ impl QueryRouter { /// A `selection` is the `WHERE` clause. This parses /// the clause and extracts the sharding key, if present. - fn selection_parser(&self, expr: &Expr) -> Vec { + fn selection_parser(&self, expr: &Expr, table_names: &Vec>) -> Vec { let mut result = Vec::new(); let mut found = false; + let sharding_key = self + .pool_settings + .automatic_sharding_key + .as_ref() + .unwrap() + .split(".") + .map(|ident| Ident::new(ident)) + .collect::>(); + + // Sharding key must be always fully qualified + assert_eq!(sharding_key.len(), 2); + // This parses `sharding_key = 5`. But it's technically // legal to write `5 = sharding_key`. I don't judge the people // who do that, but I think ORMs will still use the first variant, // so we can leave the second as a TODO. if let Expr::BinaryOp { left, op, right } = expr { match &**left { - Expr::BinaryOp { .. } => result.extend(self.selection_parser(left)), + Expr::BinaryOp { .. } => result.extend(self.selection_parser(left, table_names)), Expr::Identifier(ident) => { - found = - ident.value == *self.pool_settings.automatic_sharding_key.as_ref().unwrap(); + // Only if we're dealing with only one table + // and there is no ambiguity + if &ident.value == &sharding_key[1].value { + // Sharding key is unique enough, don't worry about + // table names. + if &sharding_key[0].value == "*" { + found = true; + } else if table_names.len() == 1 { + let table = &table_names[0]; + + if table.len() == 1 { + // Table is not fully qualified, e.g. + // SELECT * FROM t WHERE sharding_key = 5 + // Make sure the table name from the sharding key matches + // the table name from the query. + found = &sharding_key[0].value == &table[0].value; + } else if table.len() == 2 { + // Table name is fully qualified with the schema: e.g. + // SELECT * FROM public.t WHERE sharding_key = 5 + // Ignore the schema (TODO: at some point, we want schema support) + // and use the table name only. + found = &sharding_key[0].value == &table[1].value; + } else { + debug!("Got table name with more than two idents, which is not possible"); + } + } + } + } + + Expr::CompoundIdentifier(idents) => { + // The key is fully qualified in the query, + // it will exist or Postgres will throw an error. + if idents.len() == 2 { + found = &sharding_key[0].value == &idents[0].value + && &sharding_key[1].value == &idents[1].value; + } + // TODO: key can have schema as well, e.g. public.data.id (len == 3) } _ => (), }; @@ -433,7 +482,7 @@ impl QueryRouter { }; match &**right { - Expr::BinaryOp { .. } => result.extend(self.selection_parser(right)), + Expr::BinaryOp { .. } => result.extend(self.selection_parser(right, table_names)), Expr::Value(Value::Number(value, ..)) => { if found { match value.parse::() { @@ -456,6 +505,7 @@ impl QueryRouter { /// Try to figure out which shard the query should go to. fn infer_shard(&self, query: &sqlparser::ast::Query) -> Option { let mut shards = BTreeSet::new(); + let mut exprs = Vec::new(); match &*query.body { SetExpr::Query(query) => { @@ -467,27 +517,75 @@ impl QueryRouter { }; } + // SELECT * FROM ... + // We understand that pretty well. SetExpr::Select(select) => { + // Collect all table names from the query. + let mut table_names = Vec::new(); + + for table in select.from.iter() { + match &table.relation { + TableFactor::Table { name, .. } => { + table_names.push(name.0.clone()); + } + + _ => (), + }; + + // Get table names from all the joins. + for join in table.joins.iter() { + match &join.relation { + TableFactor::Table { name, .. } => { + table_names.push(name.0.clone()); + } + + _ => (), + }; + + // We can filter results based on join conditions, e.g. + // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; + match &join.join_operator { + JoinOperator::Inner(inner_join) => match &inner_join { + JoinConstraint::On(expr) => { + // Parse the selection criteria later. + exprs.push(expr.clone()); + } + + _ => (), + }, + + _ => (), + }; + } + } + + // Parse the actual "FROM ..." match &select.selection { Some(selection) => { - let sharding_keys = self.selection_parser(selection); - - // TODO: Add support for prepared statements here. - // This should just give us the position of the value in the `B` message. - - let sharder = Sharder::new( - self.pool_settings.shards, - self.pool_settings.sharding_function, - ); - - for value in sharding_keys { - let shard = sharder.shard(value); - shards.insert(shard); - } + exprs.push(selection.clone()); } None => (), }; + + // Look for sharding keys in either the join condition + // or the selection. + for expr in exprs.iter() { + let sharding_keys = self.selection_parser(expr, &table_names); + + // TODO: Add support for prepared statements here. + // This should just give us the position of the value in the `B` message. + + let sharder = Sharder::new( + self.pool_settings.shards, + self.pool_settings.sharding_function, + ); + + for value in sharding_keys { + let shard = sharder.shard(value); + shards.insert(shard); + } + } } _ => (), }; @@ -825,7 +923,7 @@ mod test { query_parser_enabled: true, primary_reads_enabled: false, sharding_function: ShardingFunction::PgBigintHash, - automatic_sharding_key: Some(String::from("id")), + automatic_sharding_key: Some(String::from("test.id")), healthcheck_delay: PoolSettings::default().healthcheck_delay, healthcheck_timeout: PoolSettings::default().healthcheck_timeout, ban_time: PoolSettings::default().ban_time, @@ -854,11 +952,6 @@ mod test { let q2 = simple_query("SET SERVER ROLE TO 'default'"); assert!(qr.try_execute_command(&q2) != None); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); - - // Here we go :) - let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); - assert!(qr.infer(&q3)); - assert_eq!(qr.shard(), 1); } #[test] @@ -891,7 +984,7 @@ mod test { query_parser_enabled: true, primary_reads_enabled: false, sharding_function: ShardingFunction::PgBigintHash, - automatic_sharding_key: Some(String::from("id")), + automatic_sharding_key: None, healthcheck_delay: PoolSettings::default().healthcheck_delay, healthcheck_timeout: PoolSettings::default().healthcheck_timeout, ban_time: PoolSettings::default().ban_time, @@ -920,4 +1013,56 @@ mod test { assert!(qr.try_execute_command(&q2) == None); assert_eq!(qr.active_shard, Some(2)); } + + #[test] + fn test_automatic_sharding_key() { + QueryRouter::setup(); + + let mut qr = QueryRouter::new(); + qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); + qr.pool_settings.shards = 3; + + assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5"))); + assert_eq!(qr.shard(), 2); + + assert!(qr.infer(&simple_query( + "SELECT one, two, three FROM public.data WHERE id = 6" + ))); + assert_eq!(qr.shard(), 0); + + assert!(qr.infer(&simple_query( + "SELECT * FROM data + INNER JOIN t2 ON data.id = 5 + AND t2.data_id = data.id + WHERE data.id = 5" + ))); + assert_eq!(qr.shard(), 2); + + // Shard did not move because we couldn't determine the sharding key since it could be ambiguous + // in the query. + assert!(qr.infer(&simple_query( + "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" + ))); + assert_eq!(qr.shard(), 2); + + assert!(qr.infer(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "id" = 6"# + ))); + assert_eq!(qr.shard(), 0); + + assert!(qr.infer(&simple_query( + r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# + ))); + assert_eq!(qr.shard(), 2); + + // Super unique sharding key + qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); + assert!(qr.infer(&simple_query( + "SELECT * FROM table_x WHERE unique_enough_column_name = 6" + ))); + assert_eq!(qr.shard(), 0); + + assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))); + assert_eq!(qr.shard(), 0); + } }