Reimplement prepared statements with LRU cache and statement deduplication (#618)

* Initial commit

* Cleanup and add stats

* Use an arc instead of full clones to store the parse packets

* Use mutex instead

* fmt

* clippy

* fmt

* fix?

* fix?

* fmt

* typo

* Update docs

* Refactor custom protocol

* fmt

* move custom protocol handling to before parsing

* Support describe

* Add LRU for server side statement cache

* rename variable

* Refactoring

* Move docs

* Fix test

* fix

* Update tests

* trigger build

* Add more tests

* Reorder handling sync

* Support when a named describe is sent along with Parse (go pgx) and expecting results

* don't talk to client if not needed when client sends Parse

* fmt :(

* refactor tests

* nit

* Reduce hashing

* Reducing work done to decode describe and parse messages

* minor refactor

* Merge branch 'main' into zain/reimplment-prepared-statements-with-global-lru-cache

* Rewrite extended and prepared protocol message handling to better support mocking response packets and close

* An attempt to better handle if there are DDL changes that might break cached plans with ideas about how to further improve it

* fix

* Minor stats fixed and cleanup

* Cosmetic fixes (#64)

* Cosmetic fixes

* fix test

* Change server drop for statement cache error to a `deallocate all`

* Updated comments and added new idea for handling DDL changes impacting cached plans

* fix test?

* Revert test change

* trigger build, flakey test

* Avoid potential race conditions by changing get_or_insert to promote for pool LRU

* remove ps enabled variable on the server in favor of using an option

* Add close to the Extended Protocol buffer

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
This commit is contained in:
Zain Kabani
2023-10-25 18:11:57 -04:00
committed by GitHub
parent d37df43a90
commit 7d3003a16a
14 changed files with 1135 additions and 516 deletions

View File

@@ -12,13 +12,16 @@ use crate::config::get_config;
use crate::errors::Error;
use crate::constants::MESSAGE_TERMINATOR;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, Cursor};
use std::mem;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
/// Postgres data type mappings
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
}
/// Tell the client we're ready for another query.
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
write_all(stream, bytes).await
write_all(stream, ready_for_query(false)).await
}
/// Send the startup packet the server. We're pretending we're a Pg client.
@@ -320,7 +315,7 @@ where
res.put_slice(&set_complete[..]);
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -331,7 +326,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -432,7 +427,7 @@ where
res.put(command_complete("SELECT 1"));
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
@@ -562,6 +557,37 @@ pub fn flush() -> BytesMut {
bytes
}
pub fn sync() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'S');
bytes.put_i32(4);
bytes
}
pub fn parse_complete() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'1');
bytes.put_i32(4);
bytes
}
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
if in_transaction {
bytes.put_u8(b'T');
} else {
bytes.put_u8(b'I');
}
bytes
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -740,6 +766,51 @@ impl BytesMutReader for BytesMut {
}
}
}
pub enum ExtendedProtocolData {
Parse {
data: BytesMut,
metadata: Option<(Arc<Parse>, u64)>,
},
Bind {
data: BytesMut,
metadata: Option<String>,
},
Describe {
data: BytesMut,
metadata: Option<String>,
},
Execute {
data: BytesMut,
},
Close {
data: BytesMut,
close: Close,
},
}
impl ExtendedProtocolData {
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
Self::Parse { data, metadata }
}
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
Self::Bind { data, metadata }
}
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
Self::Describe { data, metadata }
}
pub fn create_new_execute(data: BytesMut) -> Self {
Self::Execute { data }
}
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
Self::Close { data, close }
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
@@ -748,7 +819,6 @@ pub struct Parse {
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
@@ -774,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
@@ -823,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
}
impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
/// Renames the prepared statement to a new name based on the global counter
pub fn rewrite(mut self) -> Self {
self.name = format!(
"PGCAT_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
);
self
}
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()
}
/// Hashes the parse statement to be used as a key in the global cache
pub fn get_hash(&self) -> u64 {
// TODO_ZAIN: Take a look at which hashing function is being used
let mut hasher = DefaultHasher::new();
let concatenated = format!(
"{}{}{}",
self.query,
self.num_params,
self.param_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
concatenated.hash(&mut hasher);
hasher.finish()
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
@@ -958,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
}
impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()?;
cursor.read_string()
}
/// Renames the prepared statement to a new name
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor
let code = cursor.get_u8();
let current_len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
// Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
// Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code);
response_buf.put_i32(new_len);
// Put the portal and new name into the buffer
// Note: panic if the provided string contains null byte
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
// Add the remainder of the original buffer into the response
response_buf.put_slice(&buf[cursor.position() as usize..]);
// Return the buffer
Ok(response_buf)
}
pub fn anonymous(&self) -> bool {
@@ -1016,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
}
impl Describe {
pub fn empty_new() -> Describe {
Describe {
code: 'D',
len: 4 + 1 + 1,
target: 'S',
statement_name: "".to_string(),
}
}
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
@@ -1104,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
@@ -1193,7 +1330,7 @@ impl Display for PgErrorMsg {
}
impl PgErrorMsg {
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
let mut out = PgErrorMsg {
severity_localized: "".to_string(),
severity: "".to_string(),
@@ -1341,7 +1478,7 @@ mod tests {
info!(
"full message: {}",
PgErrorMsg::parse(complete_msg.clone()).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
assert_eq!(
PgErrorMsg {
@@ -1364,7 +1501,7 @@ mod tests {
line: Some(335),
routine: Some(routine_msg.to_string()),
},
PgErrorMsg::parse(complete_msg).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
let mut only_mandatory_msg = vec![];
@@ -1374,7 +1511,7 @@ mod tests {
only_mandatory_msg.extend(field('M', message));
only_mandatory_msg.extend(field('D', detail_msg));
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
info!("only mandatory fields: {}", &err_fields);
error!(
"server error: {}: {}",
@@ -1401,7 +1538,7 @@ mod tests {
line: None,
routine: None,
},
PgErrorMsg::parse(only_mandatory_msg).unwrap()
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
);
}
}