mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-26 10:26:30 +00:00
Reimplement prepared statements with LRU cache and statement deduplication (#618)
* Initial commit * Cleanup and add stats * Use an arc instead of full clones to store the parse packets * Use mutex instead * fmt * clippy * fmt * fix? * fix? * fmt * typo * Update docs * Refactor custom protocol * fmt * move custom protocol handling to before parsing * Support describe * Add LRU for server side statement cache * rename variable * Refactoring * Move docs * Fix test * fix * Update tests * trigger build * Add more tests * Reorder handling sync * Support when a named describe is sent along with Parse (go pgx) and expecting results * don't talk to client if not needed when client sends Parse * fmt :( * refactor tests * nit * Reduce hashing * Reducing work done to decode describe and parse messages * minor refactor * Merge branch 'main' into zain/reimplment-prepared-statements-with-global-lru-cache * Rewrite extended and prepared protocol message handling to better support mocking response packets and close * An attempt to better handle if there are DDL changes that might break cached plans with ideas about how to further improve it * fix * Minor stats fixed and cleanup * Cosmetic fixes (#64) * Cosmetic fixes * fix test * Change server drop for statement cache error to a `deallocate all` * Updated comments and added new idea for handling DDL changes impacting cached plans * fix test? * Revert test change * trigger build, flakey test * Avoid potential race conditions by changing get_or_insert to promote for pool LRU * remove ps enabled variable on the server in favor of using an option * Add close to the Extended Protocol buffer --------- Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
This commit is contained in:
201
src/messages.rs
201
src/messages.rs
@@ -12,13 +12,16 @@ use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::constants::MESSAGE_TERMINATOR;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CString;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::mem;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Postgres data type mappings
|
||||
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
|
||||
}
|
||||
|
||||
/// Tell the client we're ready for another query.
|
||||
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
||||
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(
|
||||
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
|
||||
);
|
||||
|
||||
bytes.put_u8(b'Z');
|
||||
bytes.put_i32(5);
|
||||
bytes.put_u8(b'I'); // Idle
|
||||
|
||||
write_all(stream, bytes).await
|
||||
write_all(stream, ready_for_query(false)).await
|
||||
}
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
@@ -320,7 +315,7 @@ where
|
||||
res.put_slice(&set_complete[..]);
|
||||
|
||||
write_all_half(stream, &res).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
/// Send a custom error message to the client.
|
||||
@@ -331,7 +326,7 @@ where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
error_response_terminal(stream, message).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
/// Send a custom error message to the client.
|
||||
@@ -432,7 +427,7 @@ where
|
||||
res.put(command_complete("SELECT 1"));
|
||||
|
||||
write_all_half(stream, &res).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
@@ -562,6 +557,37 @@ pub fn flush() -> BytesMut {
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn sync() -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
bytes.put_u8(b'S');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn parse_complete() -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
|
||||
bytes.put_u8(b'1');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(
|
||||
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
|
||||
);
|
||||
|
||||
bytes.put_u8(b'Z');
|
||||
bytes.put_i32(5);
|
||||
if in_transaction {
|
||||
bytes.put_u8(b'T');
|
||||
} else {
|
||||
bytes.put_u8(b'I');
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Write all data in the buffer to the TcpStream.
|
||||
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||
where
|
||||
@@ -740,6 +766,51 @@ impl BytesMutReader for BytesMut {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ExtendedProtocolData {
|
||||
Parse {
|
||||
data: BytesMut,
|
||||
metadata: Option<(Arc<Parse>, u64)>,
|
||||
},
|
||||
Bind {
|
||||
data: BytesMut,
|
||||
metadata: Option<String>,
|
||||
},
|
||||
Describe {
|
||||
data: BytesMut,
|
||||
metadata: Option<String>,
|
||||
},
|
||||
Execute {
|
||||
data: BytesMut,
|
||||
},
|
||||
Close {
|
||||
data: BytesMut,
|
||||
close: Close,
|
||||
},
|
||||
}
|
||||
|
||||
impl ExtendedProtocolData {
|
||||
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
|
||||
Self::Parse { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
|
||||
Self::Bind { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
|
||||
Self::Describe { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_execute(data: BytesMut) -> Self {
|
||||
Self::Execute { data }
|
||||
}
|
||||
|
||||
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
|
||||
Self::Close { data, close }
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -748,7 +819,6 @@ pub struct Parse {
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
pub name: String,
|
||||
pub generated_name: String,
|
||||
query: String,
|
||||
num_params: i16,
|
||||
param_types: Vec<i32>,
|
||||
@@ -774,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
|
||||
code,
|
||||
len,
|
||||
name,
|
||||
generated_name: prepared_statement_name(),
|
||||
query,
|
||||
num_params,
|
||||
param_types,
|
||||
@@ -823,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
|
||||
}
|
||||
|
||||
impl Parse {
|
||||
pub fn rename(mut self) -> Self {
|
||||
self.name = self.generated_name.to_string();
|
||||
/// Renames the prepared statement to a new name based on the global counter
|
||||
pub fn rewrite(mut self) -> Self {
|
||||
self.name = format!(
|
||||
"PGCAT_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the prepared statement from the buffer
|
||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
// Skip the code and length
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
cursor.read_string()
|
||||
}
|
||||
|
||||
/// Hashes the parse statement to be used as a key in the global cache
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
// TODO_ZAIN: Take a look at which hashing function is being used
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
||||
let concatenated = format!(
|
||||
"{}{}{}",
|
||||
self.query,
|
||||
self.num_params,
|
||||
self.param_types
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
concatenated.hash(&mut hasher);
|
||||
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
@@ -958,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
|
||||
}
|
||||
|
||||
impl Bind {
|
||||
pub fn reassign(mut self, parse: &Parse) -> Self {
|
||||
self.prepared_statement = parse.name.clone();
|
||||
self
|
||||
/// Gets the name of the prepared statement from the buffer
|
||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
// Skip the code and length
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
cursor.read_string()?;
|
||||
cursor.read_string()
|
||||
}
|
||||
|
||||
/// Renames the prepared statement to a new name
|
||||
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
|
||||
let mut cursor = Cursor::new(&buf);
|
||||
// Read basic data from the cursor
|
||||
let code = cursor.get_u8();
|
||||
let current_len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
|
||||
// Calculate new length
|
||||
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
||||
|
||||
// Begin building the response buffer
|
||||
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
||||
response_buf.put_u8(code);
|
||||
response_buf.put_i32(new_len);
|
||||
|
||||
// Put the portal and new name into the buffer
|
||||
// Note: panic if the provided string contains null byte
|
||||
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
|
||||
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
|
||||
|
||||
// Add the remainder of the original buffer into the response
|
||||
response_buf.put_slice(&buf[cursor.position() as usize..]);
|
||||
|
||||
// Return the buffer
|
||||
Ok(response_buf)
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
@@ -1016,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
|
||||
}
|
||||
|
||||
impl Describe {
|
||||
pub fn empty_new() -> Describe {
|
||||
Describe {
|
||||
code: 'D',
|
||||
len: 4 + 1 + 1,
|
||||
target: 'S',
|
||||
statement_name: "".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rename(mut self, name: &str) -> Self {
|
||||
self.statement_name = name.to_string();
|
||||
self
|
||||
@@ -1104,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn prepared_statement_name() -> String {
|
||||
format!(
|
||||
"P_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
pub struct PgErrorMsg {
|
||||
@@ -1193,7 +1330,7 @@ impl Display for PgErrorMsg {
|
||||
}
|
||||
|
||||
impl PgErrorMsg {
|
||||
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
|
||||
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
|
||||
let mut out = PgErrorMsg {
|
||||
severity_localized: "".to_string(),
|
||||
severity: "".to_string(),
|
||||
@@ -1341,7 +1478,7 @@ mod tests {
|
||||
|
||||
info!(
|
||||
"full message: {}",
|
||||
PgErrorMsg::parse(complete_msg.clone()).unwrap()
|
||||
PgErrorMsg::parse(&complete_msg).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
@@ -1364,7 +1501,7 @@ mod tests {
|
||||
line: Some(335),
|
||||
routine: Some(routine_msg.to_string()),
|
||||
},
|
||||
PgErrorMsg::parse(complete_msg).unwrap()
|
||||
PgErrorMsg::parse(&complete_msg).unwrap()
|
||||
);
|
||||
|
||||
let mut only_mandatory_msg = vec![];
|
||||
@@ -1374,7 +1511,7 @@ mod tests {
|
||||
only_mandatory_msg.extend(field('M', message));
|
||||
only_mandatory_msg.extend(field('D', detail_msg));
|
||||
|
||||
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
|
||||
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
|
||||
info!("only mandatory fields: {}", &err_fields);
|
||||
error!(
|
||||
"server error: {}: {}",
|
||||
@@ -1401,7 +1538,7 @@ mod tests {
|
||||
line: None,
|
||||
routine: None,
|
||||
},
|
||||
PgErrorMsg::parse(only_mandatory_msg).unwrap()
|
||||
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user