From 686b7ca7c52f8d4e2b8cfed1e6adc8a1ea87c153 Mon Sep 17 00:00:00 2001 From: Mostafa Date: Sun, 1 Sep 2024 15:31:27 -0500 Subject: [PATCH] Fixes --- src/client.rs | 15 +++++------ src/messages.rs | 67 ++++++++++++++++++++++++------------------------- src/server.rs | 13 +++++++++- 3 files changed, 51 insertions(+), 44 deletions(-) diff --git a/src/client.rs b/src/client.rs index 23392b7..a5d543a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1729,14 +1729,13 @@ where /// and also the pool's statement cache. Add it to extended protocol data. fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> { // Avoid parsing if prepared statements not enabled - if !self.prepared_statements_enabled { + let client_given_name = Parse::get_name(&message)?; + if !self.prepared_statements_enabled || client_given_name.len() == 0 { debug!("Anonymous parse message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_parse(message, None)); return Ok(()); } - - let client_given_name = Parse::get_name(&message)?; let parse: Parse = (&message).try_into()?; // Compute the hash of the parse statement @@ -1774,15 +1773,14 @@ where /// saved in the client cache. async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> { // Avoid parsing if prepared statements not enabled - if !self.prepared_statements_enabled { + let client_given_name = Bind::get_name(&message)?; + if !self.prepared_statements_enabled || client_given_name.len() == 0 { debug!("Anonymous bind message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_bind(message, None)); return Ok(()); } - let client_given_name = Bind::get_name(&message)?; - match self.prepared_statements.get(&client_given_name) { Some((rewritten_parse, _)) => { let message = Bind::rename(message, &rewritten_parse.name)?; @@ -1834,7 +1832,8 @@ where } let describe: Describe = (&message).try_into()?; - if describe.target == 'P' { + let client_given_name = describe.statement_name.clone(); + if describe.target == 'P' || client_given_name.len() == 0 { debug!("Portal describe message"); self.extended_protocol_data_buffer .push_back(ExtendedProtocolData::create_new_describe(message, None)); @@ -1842,8 +1841,6 @@ where return Ok(()); } - let client_given_name = describe.statement_name.clone(); - match self.prepared_statements.get(&client_given_name) { Some((rewritten_parse, _)) => { let describe = describe.rename(&rewritten_parse.name); diff --git a/src/messages.rs b/src/messages.rs index 6a114e1..042df0e 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -821,10 +821,10 @@ impl ExtendedProtocolData { pub struct Parse { code: char, #[allow(dead_code)] - len: i32, + len: u32, pub name: String, query: String, - num_params: i16, + num_params: u16, param_types: Vec, } @@ -834,12 +834,11 @@ impl TryFrom<&BytesMut> for Parse { fn try_from(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); let code = cursor.get_u8() as char; - let len = cursor.get_i32(); + let len = cursor.get_u32(); let name = cursor.read_string()?; let query = cursor.read_string()?; - let num_params = cursor.get_i16(); + let num_params = cursor.get_u16(); let mut param_types = Vec::new(); - for _ in 0..num_params { param_types.push(cursor.get_i32()); } @@ -875,10 +874,10 @@ impl TryFrom for BytesMut { + 4 * parse.num_params as usize; bytes.put_u8(parse.code as u8); - bytes.put_i32(len as i32); + bytes.put_u32(len as u32); bytes.put_slice(name); bytes.put_slice(query); - bytes.put_i16(parse.num_params); + bytes.put_u16(parse.num_params); for param in parse.param_types { bytes.put_i32(param); } @@ -945,15 +944,15 @@ impl Parse { pub struct Bind { code: char, #[allow(dead_code)] - len: i64, + len: u64, portal: String, pub prepared_statement: String, - num_param_format_codes: i16, + num_param_format_codes: u16, param_format_codes: Vec, - num_param_values: i16, + num_param_values: u16, param_values: Vec<(i32, BytesMut)>, - num_result_column_format_codes: i16, - result_columns_format_codes: Vec, + num_result_column_format_codes: u16, + result_columns_format_codes: Vec, } impl TryFrom<&BytesMut> for Bind { @@ -962,21 +961,21 @@ impl TryFrom<&BytesMut> for Bind { fn try_from(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); let code = cursor.get_u8() as char; - let len = cursor.get_i32(); + let len = cursor.get_u32(); let portal = cursor.read_string()?; let prepared_statement = cursor.read_string()?; - let num_param_format_codes = cursor.get_i16(); + let num_param_format_codes = cursor.get_u16(); let mut param_format_codes = Vec::new(); for _ in 0..num_param_format_codes { - param_format_codes.push(cursor.get_i16()); + param_format_codes.push(cursor.get_u16()); } - let num_param_values = cursor.get_i16(); + let num_param_values = cursor.get_u16(); let mut param_values = Vec::new(); for _ in 0..num_param_values { - let param_len = cursor.get_i32(); + let param_len = cursor.get_u32(); // There is special occasion when the parameter is NULL // In that case, param length is defined as -1 // So if the passed parameter len is over 0 @@ -994,16 +993,16 @@ impl TryFrom<&BytesMut> for Bind { } } - let num_result_column_format_codes = cursor.get_i16(); + let num_result_column_format_codes = cursor.get_u16(); let mut result_columns_format_codes = Vec::new(); for _ in 0..num_result_column_format_codes { - result_columns_format_codes.push(cursor.get_i16()); + result_columns_format_codes.push(cursor.get_u16()); } Ok(Bind { code, - len: len as i64, + len: len as u64, portal, prepared_statement, num_param_format_codes, @@ -1042,21 +1041,21 @@ impl TryFrom for BytesMut { len += 2 * bind.num_result_column_format_codes as usize; bytes.put_u8(bind.code as u8); - bytes.put_i32(len as i32); + bytes.put_u32(len as u32); bytes.put_slice(portal); bytes.put_slice(prepared_statement); - bytes.put_i16(bind.num_param_format_codes); + bytes.put_u16(bind.num_param_format_codes); for param_format_code in bind.param_format_codes { - bytes.put_i16(param_format_code); + bytes.put_u16(param_format_code); } - bytes.put_i16(bind.num_param_values); + bytes.put_u16(bind.num_param_values); for (param_len, param) in bind.param_values { - bytes.put_i32(param_len); + bytes.put_u32(param_len); bytes.put_slice(¶m); } - bytes.put_i16(bind.num_result_column_format_codes); + bytes.put_u16(bind.num_result_column_format_codes); for result_column_format_code in bind.result_columns_format_codes { - bytes.put_i16(result_column_format_code); + bytes.put_u16(result_column_format_code); } Ok(bytes) @@ -1068,7 +1067,7 @@ impl Bind { pub fn get_name(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); // Skip the code and length - cursor.advance(mem::size_of::() + mem::size_of::()); + cursor.advance(mem::size_of::() + mem::size_of::()); cursor.read_string()?; cursor.read_string() } @@ -1078,17 +1077,17 @@ impl Bind { let mut cursor = Cursor::new(&buf); // Read basic data from the cursor let code = cursor.get_u8(); - let current_len = cursor.get_i32(); + let current_len = cursor.get_u32(); let portal = cursor.read_string()?; let prepared_statement = cursor.read_string()?; // Calculate new length - let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32; + let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32; // Begin building the response buffer let mut response_buf = BytesMut::with_capacity(new_len as usize + 1); response_buf.put_u8(code); - response_buf.put_i32(new_len); + response_buf.put_u32(new_len); // Put the portal and new name into the buffer // Note: panic if the provided string contains null byte @@ -1112,7 +1111,7 @@ pub struct Describe { code: char, #[allow(dead_code)] - len: i32, + len: u32, pub target: char, pub statement_name: String, } @@ -1123,7 +1122,7 @@ impl TryFrom<&BytesMut> for Describe { fn try_from(bytes: &BytesMut) -> Result { let mut cursor = Cursor::new(bytes); let code = cursor.get_u8() as char; - let len = cursor.get_i32(); + let len = cursor.get_u32(); let target = cursor.get_u8() as char; let statement_name = cursor.read_string()?; @@ -1146,7 +1145,7 @@ impl TryFrom for BytesMut { let len = 4 + 1 + statement_name.len(); bytes.put_u8(describe.code as u8); - bytes.put_i32(len as i32); + bytes.put_u32(len as u32); bytes.put_u8(describe.target as u8); bytes.put_slice(statement_name); diff --git a/src/server.rs b/src/server.rs index 882450e..1d0510a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -698,7 +698,6 @@ impl Server { )) } }; - trace!("Error: {}", error_code); match error_code { @@ -1013,6 +1012,12 @@ impl Server { // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables match command.as_str() { + "DISCARD ALL" => { + self.clear_prepared_statement_cache(); + } + "DEALLOCATE ALL" => { + self.clear_prepared_statement_cache(); + } "SET" => { // We don't detect set statements in transactions // No great way to differentiate between set and set local @@ -1132,6 +1137,12 @@ impl Server { has_it } + fn clear_prepared_statement_cache(&mut self) { + if let Some(cache) = &mut self.prepared_statement_cache { + cache.clear(); + } + } + fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option { let cache = match &mut self.prepared_statement_cache { Some(cache) => cache,