Fixes try_execute_command message parsing bug (#560)

* Fixes try_execute_command message parsing bug

* Fix initial segment logic

* Add test
This commit is contained in:
Zain Kabani
2023-08-24 14:25:43 -04:00
committed by GitHub
parent 4301ab0606
commit be549f3faa

View File

@@ -19,9 +19,9 @@ use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
use std::cmp;
use std::collections::BTreeSet;
use std::io::Cursor;
use std::{cmp, mem};
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
@@ -141,6 +141,7 @@ impl QueryRouter {
let mut message_cursor = Cursor::new(message_buffer);
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
// Check for any sharding regex matches in any queries
match code as char {
@@ -150,9 +151,13 @@ impl QueryRouter {
|| self.pool_settings.sharding_key_regex.is_some()
{
// Check only the first block of bytes configured by the pool settings
let len = message_cursor.get_i32() as usize;
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);
let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();
let initial_segment = String::from_utf8_lossy(
&message_buffer[query_start_index..query_start_index + seg],
);
// Check for a shard_id included in the query
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
@@ -192,7 +197,6 @@ impl QueryRouter {
return None;
}
let _len = message_cursor.get_i32() as usize;
let query = message_cursor.read_string().unwrap();
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
@@ -1291,6 +1295,11 @@ mod test {
// Shard should start out unset
assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0) == None);
assert_eq!(qr.active_shard, None);
// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None);