mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-25 10:06:28 +00:00
Fixes
This commit is contained in:
@@ -1729,14 +1729,13 @@ where
|
|||||||
/// and also the pool's statement cache. Add it to extended protocol data.
|
/// and also the pool's statement cache. Add it to extended protocol data.
|
||||||
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
||||||
// Avoid parsing if prepared statements not enabled
|
// Avoid parsing if prepared statements not enabled
|
||||||
if !self.prepared_statements_enabled {
|
let client_given_name = Parse::get_name(&message)?;
|
||||||
|
if !self.prepared_statements_enabled || client_given_name.len() == 0 {
|
||||||
debug!("Anonymous parse message");
|
debug!("Anonymous parse message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let client_given_name = Parse::get_name(&message)?;
|
|
||||||
let parse: Parse = (&message).try_into()?;
|
let parse: Parse = (&message).try_into()?;
|
||||||
|
|
||||||
// Compute the hash of the parse statement
|
// Compute the hash of the parse statement
|
||||||
@@ -1774,15 +1773,14 @@ where
|
|||||||
/// saved in the client cache.
|
/// saved in the client cache.
|
||||||
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
||||||
// Avoid parsing if prepared statements not enabled
|
// Avoid parsing if prepared statements not enabled
|
||||||
if !self.prepared_statements_enabled {
|
let client_given_name = Bind::get_name(&message)?;
|
||||||
|
if !self.prepared_statements_enabled || client_given_name.len() == 0 {
|
||||||
debug!("Anonymous bind message");
|
debug!("Anonymous bind message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let client_given_name = Bind::get_name(&message)?;
|
|
||||||
|
|
||||||
match self.prepared_statements.get(&client_given_name) {
|
match self.prepared_statements.get(&client_given_name) {
|
||||||
Some((rewritten_parse, _)) => {
|
Some((rewritten_parse, _)) => {
|
||||||
let message = Bind::rename(message, &rewritten_parse.name)?;
|
let message = Bind::rename(message, &rewritten_parse.name)?;
|
||||||
@@ -1834,7 +1832,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let describe: Describe = (&message).try_into()?;
|
let describe: Describe = (&message).try_into()?;
|
||||||
if describe.target == 'P' {
|
let client_given_name = describe.statement_name.clone();
|
||||||
|
if describe.target == 'P' || client_given_name.len() == 0 {
|
||||||
debug!("Portal describe message");
|
debug!("Portal describe message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
||||||
@@ -1842,8 +1841,6 @@ where
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let client_given_name = describe.statement_name.clone();
|
|
||||||
|
|
||||||
match self.prepared_statements.get(&client_given_name) {
|
match self.prepared_statements.get(&client_given_name) {
|
||||||
Some((rewritten_parse, _)) => {
|
Some((rewritten_parse, _)) => {
|
||||||
let describe = describe.rename(&rewritten_parse.name);
|
let describe = describe.rename(&rewritten_parse.name);
|
||||||
|
|||||||
@@ -821,10 +821,10 @@ impl ExtendedProtocolData {
|
|||||||
pub struct Parse {
|
pub struct Parse {
|
||||||
code: char,
|
code: char,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: i32,
|
len: u32,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
query: String,
|
query: String,
|
||||||
num_params: i16,
|
num_params: u16,
|
||||||
param_types: Vec<i32>,
|
param_types: Vec<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -834,12 +834,11 @@ impl TryFrom<&BytesMut> for Parse {
|
|||||||
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_i32();
|
let len = cursor.get_u32();
|
||||||
let name = cursor.read_string()?;
|
let name = cursor.read_string()?;
|
||||||
let query = cursor.read_string()?;
|
let query = cursor.read_string()?;
|
||||||
let num_params = cursor.get_i16();
|
let num_params = cursor.get_u16();
|
||||||
let mut param_types = Vec::new();
|
let mut param_types = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_params {
|
for _ in 0..num_params {
|
||||||
param_types.push(cursor.get_i32());
|
param_types.push(cursor.get_i32());
|
||||||
}
|
}
|
||||||
@@ -875,10 +874,10 @@ impl TryFrom<Parse> for BytesMut {
|
|||||||
+ 4 * parse.num_params as usize;
|
+ 4 * parse.num_params as usize;
|
||||||
|
|
||||||
bytes.put_u8(parse.code as u8);
|
bytes.put_u8(parse.code as u8);
|
||||||
bytes.put_i32(len as i32);
|
bytes.put_u32(len as u32);
|
||||||
bytes.put_slice(name);
|
bytes.put_slice(name);
|
||||||
bytes.put_slice(query);
|
bytes.put_slice(query);
|
||||||
bytes.put_i16(parse.num_params);
|
bytes.put_u16(parse.num_params);
|
||||||
for param in parse.param_types {
|
for param in parse.param_types {
|
||||||
bytes.put_i32(param);
|
bytes.put_i32(param);
|
||||||
}
|
}
|
||||||
@@ -945,15 +944,15 @@ impl Parse {
|
|||||||
pub struct Bind {
|
pub struct Bind {
|
||||||
code: char,
|
code: char,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: i64,
|
len: u64,
|
||||||
portal: String,
|
portal: String,
|
||||||
pub prepared_statement: String,
|
pub prepared_statement: String,
|
||||||
num_param_format_codes: i16,
|
num_param_format_codes: u16,
|
||||||
param_format_codes: Vec<i16>,
|
param_format_codes: Vec<i16>,
|
||||||
num_param_values: i16,
|
num_param_values: u16,
|
||||||
param_values: Vec<(i32, BytesMut)>,
|
param_values: Vec<(i32, BytesMut)>,
|
||||||
num_result_column_format_codes: i16,
|
num_result_column_format_codes: u16,
|
||||||
result_columns_format_codes: Vec<i16>,
|
result_columns_format_codes: Vec<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<&BytesMut> for Bind {
|
impl TryFrom<&BytesMut> for Bind {
|
||||||
@@ -962,21 +961,21 @@ impl TryFrom<&BytesMut> for Bind {
|
|||||||
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_i32();
|
let len = cursor.get_u32();
|
||||||
let portal = cursor.read_string()?;
|
let portal = cursor.read_string()?;
|
||||||
let prepared_statement = cursor.read_string()?;
|
let prepared_statement = cursor.read_string()?;
|
||||||
let num_param_format_codes = cursor.get_i16();
|
let num_param_format_codes = cursor.get_u16();
|
||||||
let mut param_format_codes = Vec::new();
|
let mut param_format_codes = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_param_format_codes {
|
for _ in 0..num_param_format_codes {
|
||||||
param_format_codes.push(cursor.get_i16());
|
param_format_codes.push(cursor.get_u16());
|
||||||
}
|
}
|
||||||
|
|
||||||
let num_param_values = cursor.get_i16();
|
let num_param_values = cursor.get_u16();
|
||||||
let mut param_values = Vec::new();
|
let mut param_values = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_param_values {
|
for _ in 0..num_param_values {
|
||||||
let param_len = cursor.get_i32();
|
let param_len = cursor.get_u32();
|
||||||
// There is special occasion when the parameter is NULL
|
// There is special occasion when the parameter is NULL
|
||||||
// In that case, param length is defined as -1
|
// In that case, param length is defined as -1
|
||||||
// So if the passed parameter len is over 0
|
// So if the passed parameter len is over 0
|
||||||
@@ -994,16 +993,16 @@ impl TryFrom<&BytesMut> for Bind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let num_result_column_format_codes = cursor.get_i16();
|
let num_result_column_format_codes = cursor.get_u16();
|
||||||
let mut result_columns_format_codes = Vec::new();
|
let mut result_columns_format_codes = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_result_column_format_codes {
|
for _ in 0..num_result_column_format_codes {
|
||||||
result_columns_format_codes.push(cursor.get_i16());
|
result_columns_format_codes.push(cursor.get_u16());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Bind {
|
Ok(Bind {
|
||||||
code,
|
code,
|
||||||
len: len as i64,
|
len: len as u64,
|
||||||
portal,
|
portal,
|
||||||
prepared_statement,
|
prepared_statement,
|
||||||
num_param_format_codes,
|
num_param_format_codes,
|
||||||
@@ -1042,21 +1041,21 @@ impl TryFrom<Bind> for BytesMut {
|
|||||||
len += 2 * bind.num_result_column_format_codes as usize;
|
len += 2 * bind.num_result_column_format_codes as usize;
|
||||||
|
|
||||||
bytes.put_u8(bind.code as u8);
|
bytes.put_u8(bind.code as u8);
|
||||||
bytes.put_i32(len as i32);
|
bytes.put_u32(len as u32);
|
||||||
bytes.put_slice(portal);
|
bytes.put_slice(portal);
|
||||||
bytes.put_slice(prepared_statement);
|
bytes.put_slice(prepared_statement);
|
||||||
bytes.put_i16(bind.num_param_format_codes);
|
bytes.put_u16(bind.num_param_format_codes);
|
||||||
for param_format_code in bind.param_format_codes {
|
for param_format_code in bind.param_format_codes {
|
||||||
bytes.put_i16(param_format_code);
|
bytes.put_u16(param_format_code);
|
||||||
}
|
}
|
||||||
bytes.put_i16(bind.num_param_values);
|
bytes.put_u16(bind.num_param_values);
|
||||||
for (param_len, param) in bind.param_values {
|
for (param_len, param) in bind.param_values {
|
||||||
bytes.put_i32(param_len);
|
bytes.put_u32(param_len);
|
||||||
bytes.put_slice(¶m);
|
bytes.put_slice(¶m);
|
||||||
}
|
}
|
||||||
bytes.put_i16(bind.num_result_column_format_codes);
|
bytes.put_u16(bind.num_result_column_format_codes);
|
||||||
for result_column_format_code in bind.result_columns_format_codes {
|
for result_column_format_code in bind.result_columns_format_codes {
|
||||||
bytes.put_i16(result_column_format_code);
|
bytes.put_u16(result_column_format_code);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(bytes)
|
Ok(bytes)
|
||||||
@@ -1068,7 +1067,7 @@ impl Bind {
|
|||||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
// Skip the code and length
|
// Skip the code and length
|
||||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
cursor.advance(mem::size_of::<u8>() + mem::size_of::<u32>());
|
||||||
cursor.read_string()?;
|
cursor.read_string()?;
|
||||||
cursor.read_string()
|
cursor.read_string()
|
||||||
}
|
}
|
||||||
@@ -1078,17 +1077,17 @@ impl Bind {
|
|||||||
let mut cursor = Cursor::new(&buf);
|
let mut cursor = Cursor::new(&buf);
|
||||||
// Read basic data from the cursor
|
// Read basic data from the cursor
|
||||||
let code = cursor.get_u8();
|
let code = cursor.get_u8();
|
||||||
let current_len = cursor.get_i32();
|
let current_len = cursor.get_u32();
|
||||||
let portal = cursor.read_string()?;
|
let portal = cursor.read_string()?;
|
||||||
let prepared_statement = cursor.read_string()?;
|
let prepared_statement = cursor.read_string()?;
|
||||||
|
|
||||||
// Calculate new length
|
// Calculate new length
|
||||||
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32;
|
||||||
|
|
||||||
// Begin building the response buffer
|
// Begin building the response buffer
|
||||||
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
||||||
response_buf.put_u8(code);
|
response_buf.put_u8(code);
|
||||||
response_buf.put_i32(new_len);
|
response_buf.put_u32(new_len);
|
||||||
|
|
||||||
// Put the portal and new name into the buffer
|
// Put the portal and new name into the buffer
|
||||||
// Note: panic if the provided string contains null byte
|
// Note: panic if the provided string contains null byte
|
||||||
@@ -1112,7 +1111,7 @@ pub struct Describe {
|
|||||||
code: char,
|
code: char,
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: i32,
|
len: u32,
|
||||||
pub target: char,
|
pub target: char,
|
||||||
pub statement_name: String,
|
pub statement_name: String,
|
||||||
}
|
}
|
||||||
@@ -1123,7 +1122,7 @@ impl TryFrom<&BytesMut> for Describe {
|
|||||||
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
||||||
let mut cursor = Cursor::new(bytes);
|
let mut cursor = Cursor::new(bytes);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_i32();
|
let len = cursor.get_u32();
|
||||||
let target = cursor.get_u8() as char;
|
let target = cursor.get_u8() as char;
|
||||||
let statement_name = cursor.read_string()?;
|
let statement_name = cursor.read_string()?;
|
||||||
|
|
||||||
@@ -1146,7 +1145,7 @@ impl TryFrom<Describe> for BytesMut {
|
|||||||
let len = 4 + 1 + statement_name.len();
|
let len = 4 + 1 + statement_name.len();
|
||||||
|
|
||||||
bytes.put_u8(describe.code as u8);
|
bytes.put_u8(describe.code as u8);
|
||||||
bytes.put_i32(len as i32);
|
bytes.put_u32(len as u32);
|
||||||
bytes.put_u8(describe.target as u8);
|
bytes.put_u8(describe.target as u8);
|
||||||
bytes.put_slice(statement_name);
|
bytes.put_slice(statement_name);
|
||||||
|
|
||||||
|
|||||||
@@ -698,7 +698,6 @@ impl Server {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("Error: {}", error_code);
|
trace!("Error: {}", error_code);
|
||||||
|
|
||||||
match error_code {
|
match error_code {
|
||||||
@@ -1013,6 +1012,12 @@ impl Server {
|
|||||||
// which can leak between clients. This is a best effort to block bad clients
|
// which can leak between clients. This is a best effort to block bad clients
|
||||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||||
match command.as_str() {
|
match command.as_str() {
|
||||||
|
"DISCARD ALL" => {
|
||||||
|
self.clear_prepared_statement_cache();
|
||||||
|
}
|
||||||
|
"DEALLOCATE ALL" => {
|
||||||
|
self.clear_prepared_statement_cache();
|
||||||
|
}
|
||||||
"SET" => {
|
"SET" => {
|
||||||
// We don't detect set statements in transactions
|
// We don't detect set statements in transactions
|
||||||
// No great way to differentiate between set and set local
|
// No great way to differentiate between set and set local
|
||||||
@@ -1132,6 +1137,12 @@ impl Server {
|
|||||||
has_it
|
has_it
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn clear_prepared_statement_cache(&mut self) {
|
||||||
|
if let Some(cache) = &mut self.prepared_statement_cache {
|
||||||
|
cache.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
||||||
let cache = match &mut self.prepared_statement_cache {
|
let cache = match &mut self.prepared_statement_cache {
|
||||||
Some(cache) => cache,
|
Some(cache) => cache,
|
||||||
|
|||||||
Reference in New Issue
Block a user