mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-25 02:06:28 +00:00
Extended query protocol sharding (#339)
* Prepared stmt sharding s tests * len check * remove python test * latest rust * move that to debug for sure * Add the actual tests * latest image * Update tests/ruby/sharding_spec.rb
This commit is contained in:
@@ -675,14 +675,42 @@ where
|
||||
// allocate a connection, we wouldn't be able to send back an error message
|
||||
// to the client so we buffer them and defer the decision to error out or not
|
||||
// to when we get the S message
|
||||
'P' | 'B' | 'D' | 'E' => {
|
||||
'D' | 'E' => {
|
||||
self.buffer.put(&message[..]);
|
||||
continue;
|
||||
}
|
||||
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
}
|
||||
}
|
||||
|
||||
'P' => {
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
'B' => {
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer_shard_from_bind(&message);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
'X' => {
|
||||
debug!("Client disconnecting");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
}
|
||||
|
||||
@@ -711,11 +739,7 @@ where
|
||||
// Handle all custom protocol commands, if any.
|
||||
match query_router.try_execute_command(&message) {
|
||||
// Normal query, not a custom command.
|
||||
None => {
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
|
||||
// SET SHARD TO
|
||||
Some((Command::SetShard, _)) => {
|
||||
@@ -727,7 +751,7 @@ where
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"shard {} is more than configured {}, staying on shard {}",
|
||||
"shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)",
|
||||
query_router.shard(),
|
||||
pool.shards(),
|
||||
current_shard,
|
||||
|
||||
@@ -43,6 +43,20 @@ pub enum Command {
|
||||
ShowPrimaryReads,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub enum ShardingKey {
|
||||
Value(i64),
|
||||
Placeholder(i16),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum ParameterFormat {
|
||||
Text,
|
||||
Binary,
|
||||
Uniform(Box<ParameterFormat>),
|
||||
Specified(Vec<ParameterFormat>),
|
||||
}
|
||||
|
||||
/// Quickly test for match when a query is received.
|
||||
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
||||
|
||||
@@ -65,6 +79,9 @@ pub struct QueryRouter {
|
||||
|
||||
/// Pool configuration.
|
||||
pool_settings: PoolSettings,
|
||||
|
||||
// Placeholders from prepared statement.
|
||||
placeholders: Vec<i16>,
|
||||
}
|
||||
|
||||
impl QueryRouter {
|
||||
@@ -103,6 +120,7 @@ impl QueryRouter {
|
||||
query_parser_enabled: None,
|
||||
primary_reads_enabled: None,
|
||||
pool_settings: PoolSettings::default(),
|
||||
placeholders: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -307,10 +325,10 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, message_buffer: &BytesMut) -> bool {
|
||||
pub fn infer(&mut self, message: &BytesMut) -> bool {
|
||||
debug!("Inferring role");
|
||||
|
||||
let mut message_cursor = Cursor::new(message_buffer);
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let _len = message_cursor.get_i32() as usize;
|
||||
@@ -332,8 +350,7 @@ impl QueryRouter {
|
||||
let query = message_cursor.read_string().unwrap();
|
||||
|
||||
debug!("Prepared statement: '{}'", query);
|
||||
|
||||
query.replace('$', "") // Remove placeholders turning them into "values"
|
||||
query
|
||||
}
|
||||
|
||||
_ => return false,
|
||||
@@ -343,7 +360,7 @@ impl QueryRouter {
|
||||
Ok(ast) => ast,
|
||||
Err(err) => {
|
||||
// SELECT ... FOR UPDATE won't get parsed correctly.
|
||||
error!("{}: {}", err, query);
|
||||
debug!("{}: {}", err, query);
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
}
|
||||
@@ -404,9 +421,147 @@ impl QueryRouter {
|
||||
true
|
||||
}
|
||||
|
||||
/// Parse the shard number from the Bind message
|
||||
/// which contains the arguments for a prepared statement.
|
||||
///
|
||||
/// N.B.: Only supports anonymous prepared statements since we don't
|
||||
/// keep a cache of them in PgCat.
|
||||
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
|
||||
debug!("Parsing bind message");
|
||||
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let len = message_cursor.get_i32();
|
||||
|
||||
if code != 'B' {
|
||||
debug!("Not a bind packet");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check message length
|
||||
if message.len() != len as usize + 1 {
|
||||
debug!(
|
||||
"Message has wrong length, expected {}, but have {}",
|
||||
len,
|
||||
message.len()
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// There are no shard keys in the prepared statement.
|
||||
if self.placeholders.is_empty() {
|
||||
debug!("There are no placeholders in the prepared statement that matched the automatic sharding key");
|
||||
return false;
|
||||
}
|
||||
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
|
||||
let mut shards = BTreeSet::new();
|
||||
|
||||
let _portal = message_cursor.read_string();
|
||||
let _name = message_cursor.read_string();
|
||||
|
||||
let num_params = message_cursor.get_i16();
|
||||
let parameter_format = match num_params {
|
||||
0 => ParameterFormat::Text, // Text
|
||||
1 => {
|
||||
let param_format = message_cursor.get_i16();
|
||||
ParameterFormat::Uniform(match param_format {
|
||||
0 => Box::new(ParameterFormat::Text),
|
||||
1 => Box::new(ParameterFormat::Binary),
|
||||
_ => unreachable!(),
|
||||
})
|
||||
}
|
||||
n => {
|
||||
let mut v = Vec::with_capacity(n as usize);
|
||||
for _ in 0..n {
|
||||
let param_format = message_cursor.get_i16();
|
||||
v.push(match param_format {
|
||||
0 => ParameterFormat::Text,
|
||||
1 => ParameterFormat::Binary,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
ParameterFormat::Specified(v)
|
||||
}
|
||||
};
|
||||
|
||||
let num_parameters = message_cursor.get_i16();
|
||||
|
||||
for i in 0..num_parameters {
|
||||
let mut len = message_cursor.get_i32() as usize;
|
||||
let format = match ¶meter_format {
|
||||
ParameterFormat::Text => ParameterFormat::Text,
|
||||
ParameterFormat::Uniform(format) => *format.clone(),
|
||||
ParameterFormat::Specified(formats) => formats[i as usize].clone(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
debug!("Parameter {} (len: {}): {:?}", i, len, format);
|
||||
|
||||
// Postgres counts placeholders starting at 1
|
||||
let placeholder = i + 1;
|
||||
|
||||
if self.placeholders.contains(&placeholder) {
|
||||
let value = match format {
|
||||
ParameterFormat::Text => {
|
||||
let mut value = String::new();
|
||||
while len > 0 {
|
||||
value.push(message_cursor.get_u8() as char);
|
||||
len -= 1;
|
||||
}
|
||||
|
||||
match value.parse::<i64>() {
|
||||
Ok(value) => value,
|
||||
Err(_) => {
|
||||
debug!("Error parsing bind value: {}", value);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ParameterFormat::Binary => match len {
|
||||
2 => message_cursor.get_i16() as i64,
|
||||
4 => message_cursor.get_i32() as i64,
|
||||
8 => message_cursor.get_i64(),
|
||||
_ => {
|
||||
error!(
|
||||
"Got wrong length for integer type parameter in bind: {}",
|
||||
len
|
||||
);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
shards.insert(sharder.shard(value));
|
||||
}
|
||||
}
|
||||
|
||||
self.placeholders.clear();
|
||||
self.placeholders.shrink_to_fit();
|
||||
|
||||
// We only support querying one shard at a time.
|
||||
// TODO: Support multi-shard queries some day.
|
||||
if shards.len() == 1 {
|
||||
debug!("Found one sharding key");
|
||||
self.set_shard(*shards.first().unwrap());
|
||||
true
|
||||
} else {
|
||||
debug!("Found no sharding keys");
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// A `selection` is the `WHERE` clause. This parses
|
||||
/// the clause and extracts the sharding key, if present.
|
||||
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<i64> {
|
||||
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
|
||||
let mut result = Vec::new();
|
||||
let mut found = false;
|
||||
|
||||
@@ -487,13 +642,25 @@ impl QueryRouter {
|
||||
Expr::Value(Value::Number(value, ..)) => {
|
||||
if found {
|
||||
match value.parse::<i64>() {
|
||||
Ok(value) => result.push(value),
|
||||
Ok(value) => result.push(ShardingKey::Value(value)),
|
||||
Err(_) => {
|
||||
debug!("Sharding key was not an integer: {}", value);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Expr::Value(Value::Placeholder(placeholder)) => {
|
||||
match placeholder.replace("$", "").parse::<i16>() {
|
||||
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
|
||||
Err(_) => {
|
||||
debug!(
|
||||
"Prepared statement didn't have integer placeholders: {}",
|
||||
placeholder
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
@@ -504,7 +671,7 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
/// Try to figure out which shard the query should go to.
|
||||
fn infer_shard(&self, query: &sqlparser::ast::Query) -> Option<usize> {
|
||||
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
|
||||
let mut shards = BTreeSet::new();
|
||||
let mut exprs = Vec::new();
|
||||
|
||||
@@ -569,6 +736,11 @@ impl QueryRouter {
|
||||
None => (),
|
||||
};
|
||||
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
|
||||
// Look for sharding keys in either the join condition
|
||||
// or the selection.
|
||||
for expr in exprs.iter() {
|
||||
@@ -577,14 +749,17 @@ impl QueryRouter {
|
||||
// TODO: Add support for prepared statements here.
|
||||
// This should just give us the position of the value in the `B` message.
|
||||
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
|
||||
for value in sharding_keys {
|
||||
let shard = sharder.shard(value);
|
||||
shards.insert(shard);
|
||||
match value {
|
||||
ShardingKey::Value(value) => {
|
||||
let shard = sharder.shard(value);
|
||||
shards.insert(shard);
|
||||
}
|
||||
|
||||
ShardingKey::Placeholder(position) => {
|
||||
self.placeholders.push(position);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -634,10 +809,14 @@ impl QueryRouter {
|
||||
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
match self.query_parser_enabled {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
None => self.pool_settings.query_parser_enabled,
|
||||
Some(value) => value,
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Query parser enabled: {}", enabled);
|
||||
|
||||
enabled
|
||||
}
|
||||
|
||||
pub fn primary_reads_enabled(&self) -> bool {
|
||||
@@ -1066,4 +1245,32 @@ mod test {
|
||||
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
|
||||
assert_eq!(qr.shard(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepared_statements() {
|
||||
let stmt = "SELECT * FROM data WHERE id = $1";
|
||||
|
||||
let mut bind = BytesMut::from(&b"B"[..]);
|
||||
|
||||
let mut payload = BytesMut::from(&b"\0\0"[..]);
|
||||
payload.put_i16(0);
|
||||
payload.put_i16(1);
|
||||
payload.put_i32(1);
|
||||
payload.put(&b"5"[..]);
|
||||
payload.put_i16(0);
|
||||
|
||||
bind.put_i32(payload.len() as i32 + 4);
|
||||
bind.put(payload);
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr.infer(&simple_query(stmt)));
|
||||
assert_eq!(qr.placeholders.len(), 1);
|
||||
|
||||
assert!(qr.infer_shard_from_bind(&bind));
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert!(qr.placeholders.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user