Make infer role configurable and fix double parse bug (#533)

* Make infer role configurable and fix double parse bug

* Fix tests

* Enable infer_role_from query in toml for tests

* Fix test

* Add max length config, add logging for which application is failing to parse, and change config name

* fmt

* Update src/config.rs

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
This commit is contained in:
Zain Kabani
2023-08-08 16:10:03 -04:00
committed by GitHub
parent 7c3c90c38e
commit e14b283f0c
8 changed files with 197 additions and 60 deletions

View File

@@ -74,6 +74,10 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
@@ -134,6 +138,7 @@ database = "shard2"
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"

View File

@@ -71,6 +71,10 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitly selected with our custom protocol.

View File

@@ -162,6 +162,10 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitly selected with our custom protocol.

View File

@@ -774,6 +774,9 @@ where
let mut prepared_statement = None;
let mut will_prepare = false;
let client_identifier =
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
@@ -812,6 +815,21 @@ where
message_result = read_message(&mut self.read) => message_result?
};
// Handle admin database queries.
if self.admin {
debug!("Handling admin command");
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
continue;
}
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.get_pool().await?;
query_router.update_pool_settings(pool.settings.clone());
let mut initial_parsed_ast = None;
match message[0] as char {
// Buffer extended protocol messages even if we do not have
// a server connection yet. Hopefully, when we get the S message
@@ -841,7 +859,8 @@ where
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
match query_router.parse(&message) {
Ok(ast) => {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
@@ -859,6 +878,15 @@ where
};
let _ = query_router.infer(&ast);
initial_parsed_ast = Some(ast);
}
Err(error) => {
warn!(
"Query parsing error: {} (client: {})",
error, client_identifier
);
}
}
}
}
@@ -872,13 +900,21 @@ where
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
match query_router.parse(&message) {
Ok(ast) => {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
let _ = query_router.infer(&ast);
}
Err(error) => {
warn!(
"Query parsing error: {} (client: {})",
error, client_identifier
);
}
};
}
continue;
@@ -922,13 +958,6 @@ where
_ => (),
}
// Handle admin database queries.
if self.admin {
debug!("Handling admin command");
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
continue;
}
// Check on plugin results.
match plugin_output {
Some(PluginOutput::Deny(error)) => {
@@ -941,11 +970,6 @@ where
_ => (),
};
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.get_pool().await?;
// Check if the pool is paused and wait until it's resumed.
if pool.wait_paused().await {
// Refresh pool information, something might have changed.
@@ -1165,6 +1189,9 @@ where
None => {
trace!("Waiting for message inside transaction or in session mode");
// This is not an initial message so discard the initial_parsed_ast
initial_parsed_ast.take();
match tokio::time::timeout(
idle_client_timeout_duration,
read_message(&mut self.read),
@@ -1221,7 +1248,22 @@ where
// Query
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
// We don't want to parse again if we already parsed it as the initial message
let ast = match initial_parsed_ast {
Some(_) => Some(initial_parsed_ast.take().unwrap()),
None => match query_router.parse(&message) {
Ok(ast) => Some(ast),
Err(error) => {
warn!(
"Query parsing error: {} (client: {})",
error, client_identifier
);
None
}
},
};
if let Some(ast) = ast {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
@@ -1237,8 +1279,6 @@ where
_ => (),
};
let _ = query_router.infer(&ast);
}
}
debug!("Sending query to server");
@@ -1290,7 +1330,7 @@ where
}
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(ast) = query_router.parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}

View File

@@ -511,6 +511,11 @@ pub struct Pool {
#[serde(default)] // False
pub query_parser_enabled: bool,
pub query_parser_max_length: Option<usize>,
#[serde(default)] // False
pub query_parser_read_write_splitting: bool,
#[serde(default)] // False
pub primary_reads_enabled: bool,
@@ -627,6 +632,18 @@ impl Pool {
}
}
if self.query_parser_read_write_splitting && !self.query_parser_enabled {
error!(
"query_parser_read_write_splitting is only valid when query_parser_enabled is true"
);
return Err(Error::BadConfig);
}
if self.plugins.is_some() && !self.query_parser_enabled {
error!("plugins are only valid when query_parser_enabled is true");
return Err(Error::BadConfig);
}
self.automatic_sharding_key = match &self.automatic_sharding_key {
Some(key) => {
// No quotes in the key so we don't have to compare quoted
@@ -663,6 +680,8 @@ impl Default for Pool {
users: BTreeMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
@@ -914,6 +933,17 @@ impl From<&Config> for std::collections::HashMap<String, String> {
format!("pools.{}.query_parser_enabled", pool_name),
pool.query_parser_enabled.to_string(),
),
(
format!("pools.{}.query_parser_max_length", pool_name),
match pool.query_parser_max_length {
Some(max_length) => max_length.to_string(),
None => String::from("unlimited"),
},
),
(
format!("pools.{}.query_parser_read_write_splitting", pool_name),
pool.query_parser_read_write_splitting.to_string(),
),
(
format!("pools.{}.default_role", pool_name),
pool.default_role.clone(),
@@ -1096,6 +1126,15 @@ impl Config {
"[pool: {}] Query router: {}",
pool_name, pool_config.query_parser_enabled
);
info!(
"[pool: {}] Query parser max length: {:?}",
pool_name, pool_config.query_parser_max_length
);
info!(
"[pool: {}] Infer role from query: {}",
pool_name, pool_config.query_parser_read_write_splitting
);
info!(
"[pool: {}] Number of shards: {}",
pool_name,

View File

@@ -111,6 +111,12 @@ pub struct PoolSettings {
// Enable/disable query parser.
pub query_parser_enabled: bool,
// Max length of query the parser will parse.
pub query_parser_max_length: Option<usize>,
// Infer role
pub query_parser_read_write_splitting: bool,
// Read from the primary as well or not.
pub primary_reads_enabled: bool,
@@ -157,6 +163,8 @@ impl Default for PoolSettings {
db: String::default(),
default_role: None,
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: true,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
@@ -456,6 +464,9 @@ impl ConnectionPool {
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
query_parser_max_length: pool_config.query_parser_max_length,
query_parser_read_write_splitting: pool_config
.query_parser_read_write_splitting,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),

View File

@@ -331,11 +331,23 @@ impl QueryRouter {
Some((command, value))
}
pub fn parse(message: &BytesMut) -> Result<Vec<Statement>, Error> {
pub fn parse(&self, message: &BytesMut) -> Result<Vec<Statement>, Error> {
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
let _len = message_cursor.get_i32() as usize;
let len = message_cursor.get_i32() as usize;
match self.pool_settings.query_parser_max_length {
Some(max_length) => {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
}
None => (),
};
let query = match code {
// Query
@@ -372,6 +384,10 @@ impl QueryRouter {
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
if !self.pool_settings.query_parser_read_write_splitting {
return Ok(()); // Nothing to do
}
debug!("Inferring role");
if ast.is_empty() {
@@ -433,6 +449,10 @@ impl QueryRouter {
/// N.B.: Only supports anonymous prepared statements since we don't
/// keep a cache of them in PgCat.
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
if !self.pool_settings.query_parser_read_write_splitting {
return false; // Nothing to do
}
debug!("Parsing bind message");
let mut message_cursor = Cursor::new(message);
@@ -910,6 +930,7 @@ mod test {
fn test_infer_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr.query_parser_enabled());
@@ -925,7 +946,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
}
@@ -934,6 +955,7 @@ mod test {
fn test_infer_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
@@ -944,7 +966,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
}
}
@@ -956,7 +978,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
}
@@ -964,6 +986,8 @@ mod test {
fn test_infer_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
@@ -976,7 +1000,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
@@ -1132,6 +1156,8 @@ mod test {
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
@@ -1140,11 +1166,11 @@ mod test {
assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled());
@@ -1164,6 +1190,8 @@ mod test {
user: crate::config::User::default(),
default_role: Some(Role::Replica),
query_parser_enabled: true,
query_parser_max_length: None,
query_parser_read_write_splitting: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: Some(String::from("test.id")),
@@ -1208,18 +1236,18 @@ mod test {
let mut qr = QueryRouter::new();
assert!(qr
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.infer(&qr.parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary);
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.infer(&qr.parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica);
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
@@ -1239,6 +1267,8 @@ mod test {
user: crate::config::User::default(),
default_role: Some(Role::Replica),
query_parser_enabled: true,
query_parser_max_length: None,
query_parser_read_write_splitting: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
@@ -1284,15 +1314,19 @@ mod test {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.infer(
&qr.parse(&simple_query("SELECT * FROM data WHERE id = 5"))
.unwrap(),
)
.is_ok());
assert_eq!(qr.shard(), 2);
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
@@ -1302,7 +1336,7 @@ mod test {
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
@@ -1317,7 +1351,7 @@ mod test {
// in the query.
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
@@ -1327,7 +1361,7 @@ mod test {
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
@@ -1337,7 +1371,7 @@ mod test {
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
@@ -1349,7 +1383,7 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
&qr.parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
@@ -1359,7 +1393,7 @@ mod test {
assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
&qr.parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
@@ -1385,10 +1419,9 @@ mod test {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
.is_ok());
assert!(qr.infer(&qr.parse(&simple_query(stmt)).unwrap()).is_ok());
assert_eq!(qr.placeholders.len(), 1);
assert!(qr.infer_shard_from_bind(&bind));
@@ -1419,7 +1452,7 @@ mod test {
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let ast = qr.parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
@@ -1437,7 +1470,7 @@ mod test {
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let ast = qr.parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;

View File

@@ -34,6 +34,7 @@ module Helpers
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => true,
"query_parser_enabled" => true,
"query_parser_read_write_splitting" => true,
"automatic_sharding_key" => "data.id",
"sharding_function" => "pg_bigint_hash",
"shards" => {