Reimplement prepared statements with LRU cache and statement deduplication (#618)

* Initial commit

* Cleanup and add stats

* Use an arc instead of full clones to store the parse packets

* Use mutex instead

* fmt

* clippy

* fmt

* fix?

* fix?

* fmt

* typo

* Update docs

* Refactor custom protocol

* fmt

* move custom protocol handling to before parsing

* Support describe

* Add LRU for server side statement cache

* rename variable

* Refactoring

* Move docs

* Fix test

* fix

* Update tests

* trigger build

* Add more tests

* Reorder handling sync

* Support when a named describe is sent along with Parse (go pgx) and expecting results

* don't talk to client if not needed when client sends Parse

* fmt :(

* refactor tests

* nit

* Reduce hashing

* Reducing work done to decode describe and parse messages

* minor refactor

* Merge branch 'main' into zain/reimplment-prepared-statements-with-global-lru-cache

* Rewrite extended and prepared protocol message handling to better support mocking response packets and close

* An attempt to better handle if there are DDL changes that might break cached plans with ideas about how to further improve it

* fix

* Minor stats fixed and cleanup

* Cosmetic fixes (#64)

* Cosmetic fixes

* fix test

* Change server drop for statement cache error to a `deallocate all`

* Updated comments and added new idea for handling DDL changes impacting cached plans

* fix test?

* Revert test change

* trigger build, flakey test

* Avoid potential race conditions by changing get_or_insert to promote for pool LRU

* remove ps enabled variable on the server in favor of using an option

* Add close to the Extended Protocol buffer

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
Commit 7d3003a16a by Zain Kabani, committed by GitHub on 2023-10-25 18:11:57 -04:00 (parent d37df43a90).
14 changed files with 1135 additions and 516 deletions


@@ -259,22 +259,6 @@ Password to be used for connecting to servers to obtain the hash used for md5 au
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### prepared_statements
```
path: general.prepared_statements
default: false
```
Whether to use prepared statements or not.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 500
```
Size of the prepared statements cache.
### dns_cache_enabled
```
path: general.dns_cache_enabled
@@ -324,6 +308,15 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 0
```
Size of the prepared statements cache. 0 means disabled.
TODO: update documentation
### query_parser_enabled
```
path: pools.<pool_name>.query_parser_enabled
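Per the docs above, this commit moves the cache setting from the global `general` section to the pool level. A hypothetical pool section might look like this (the pool name is illustrative):

```toml
[pools.my_db]
# 0 (the default) disables prepared statement support for this pool;
# any positive value enables it and bounds the server-side cache.
prepared_statements_cache_size = 500
```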

Cargo.lock (generated)

@@ -17,6 +17,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "1.0.2"
@@ -26,6 +37,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
@@ -553,6 +570,10 @@ name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]]
name = "heck"
@@ -821,6 +842,15 @@ version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "lru"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
dependencies = [
"hashbrown 0.14.0",
]
[[package]]
name = "lru-cache"
version = "0.1.2"
@@ -1008,6 +1038,7 @@ dependencies = [
"itertools",
"jemallocator",
"log",
"lru",
"md-5",
"nix",
"num_cpus",


@@ -48,6 +48,7 @@ itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
lru = "0.12.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"


@@ -60,12 +60,6 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5
# Handle prepared statements.
prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
@@ -156,6 +150,10 @@ load_balancing_mode = "random"
# `primary` all queries go to the primary unless otherwise specified.
default_role = "any"
# Prepared statements cache size.
# TODO: update documentation
prepared_statements_cache_size = 500
# If Query Parser is enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,


@@ -744,6 +744,7 @@ where
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_eviction", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];
@@ -776,6 +777,10 @@ where
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_eviction_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)

File diff suppressed because it is too large.


@@ -116,10 +116,10 @@ impl Default for Address {
host: String::from("127.0.0.1"),
port: 5432,
shard: 0,
address_index: 0,
replica_number: 0,
database: String::from("database"),
role: Role::Replica,
replica_number: 0,
address_index: 0,
username: String::from("username"),
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
@@ -337,12 +337,6 @@ pub struct General {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default)]
pub prepared_statements: bool,
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
}
impl General {
@@ -424,10 +418,6 @@ impl General {
pub fn default_server_round_robin() -> bool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
500
}
}
impl Default for General {
@@ -439,35 +429,33 @@ impl Default for General {
prometheus_exporter_port: 9930,
connect_timeout: General::default_connect_timeout(),
idle_timeout: General::default_idle_timeout(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
worker_threads: Self::default_worker_threads(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
log_client_connections: false,
log_client_disconnections: false,
autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
worker_threads: Self::default_worker_threads(),
autoreload: None,
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
validate_config: true,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
}
}
}
@@ -568,6 +556,9 @@ pub struct Pool {
#[serde(default)] // False
pub log_client_parameter_status_changes: bool,
#[serde(default = "Pool::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
@@ -617,6 +608,10 @@ impl Pool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
0
}
pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() {
"any" => (),
@@ -708,17 +703,16 @@ impl Default for Pool {
Pool {
pool_mode: Self::default_pool_mode(),
load_balancing_mode: Self::default_load_balancing_mode(),
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
connect_timeout: None,
idle_timeout: None,
server_lifetime: None,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
@@ -726,10 +720,12 @@ impl Default for Pool {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
plugins: None,
cleanup_server_connections: true,
log_client_parameter_status_changes: false,
prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
plugins: None,
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
}
}
}
@@ -841,13 +837,13 @@ impl Shard {
impl Default for Shard {
fn default() -> Shard {
Shard {
database: String::from("postgres"),
mirrors: None,
servers: vec![ServerConfig {
host: String::from("localhost"),
port: 5432,
role: Role::Primary,
}],
mirrors: None,
database: String::from("postgres"),
}
}
}
@@ -1018,8 +1014,8 @@ impl Default for Config {
Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
pools: HashMap::default(),
}
}
}
@@ -1128,6 +1124,7 @@ impl From<&Config> for std::collections::HashMap<String, String> {
impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Config path: {}", self.path);
info!("Ban time: {}s", self.general.ban_time);
info!(
"Idle client in transaction timeout: {}ms",
@@ -1174,13 +1171,6 @@ impl Config {
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!(
"Plugins: {}",
match self.plugins {
@@ -1271,6 +1261,10 @@ impl Config {
"[pool: {}] Log client parameter status changes: {}",
pool_name, pool_config.log_client_parameter_status_changes
);
info!(
"[pool: {}] Prepared statements server cache size: {}",
pool_name, pool_config.prepared_statements_cache_size
);
info!(
"[pool: {}] Plugins: {}",
pool_name,
@@ -1413,14 +1407,6 @@ pub fn get_idle_client_in_transaction_timeout() -> u64 {
CONFIG.load().general.idle_client_in_transaction_timeout
}
pub fn get_prepared_statements() -> bool {
CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}
/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();


@@ -12,13 +12,16 @@ use crate::config::get_config;
use crate::errors::Error;
use crate::constants::MESSAGE_TERMINATOR;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, Cursor};
use std::mem;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
/// Postgres data type mappings
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
}
/// Tell the client we're ready for another query.
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
write_all(stream, bytes).await
write_all(stream, ready_for_query(false)).await
}
/// Send the startup packet the server. We're pretending we're a Pg client.
@@ -320,7 +315,7 @@ where
res.put_slice(&set_complete[..]);
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -331,7 +326,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -432,7 +427,7 @@ where
res.put(command_complete("SELECT 1"));
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
@@ -562,6 +557,37 @@ pub fn flush() -> BytesMut {
bytes
}
pub fn sync() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'S');
bytes.put_i32(4);
bytes
}
pub fn parse_complete() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'1');
bytes.put_i32(4);
bytes
}
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
if in_transaction {
bytes.put_u8(b'T');
} else {
bytes.put_u8(b'I');
}
bytes
}
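The new `ready_for_query` builder returns the raw ReadyForQuery packet instead of writing it to the stream, so `send_ready_for_query` and the mocked extended-protocol responses can reuse it. A minimal std-only sketch of the same wire format (using `Vec<u8>` in place of `BytesMut`):

```rust
// Std-only sketch of the ReadyForQuery ('Z') message: a 1-byte code,
// a big-endian i32 length (which counts itself plus the status byte),
// then a transaction status byte: b'I' for idle, b'T' for in-transaction.
fn ready_for_query(in_transaction: bool) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(6);
    bytes.push(b'Z');
    bytes.extend_from_slice(&5i32.to_be_bytes()); // length = 4 (self) + 1 (status)
    bytes.push(if in_transaction { b'T' } else { b'I' });
    bytes
}
```

Returning bytes rather than writing them lets the pooler splice the packet into a larger buffered response before a single flush.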
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -740,6 +766,51 @@ impl BytesMutReader for BytesMut {
}
}
}
pub enum ExtendedProtocolData {
Parse {
data: BytesMut,
metadata: Option<(Arc<Parse>, u64)>,
},
Bind {
data: BytesMut,
metadata: Option<String>,
},
Describe {
data: BytesMut,
metadata: Option<String>,
},
Execute {
data: BytesMut,
},
Close {
data: BytesMut,
close: Close,
},
}
impl ExtendedProtocolData {
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
Self::Parse { data, metadata }
}
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
Self::Bind { data, metadata }
}
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
Self::Describe { data, metadata }
}
pub fn create_new_execute(data: BytesMut) -> Self {
Self::Execute { data }
}
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
Self::Close { data, close }
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
@@ -748,7 +819,6 @@ pub struct Parse {
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
@@ -774,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
@@ -823,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
}
impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
/// Renames the prepared statement to a new name based on the global counter
pub fn rewrite(mut self) -> Self {
self.name = format!(
"PGCAT_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
);
self
}
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()
}
/// Hashes the parse statement to be used as a key in the global cache
pub fn get_hash(&self) -> u64 {
// TODO_ZAIN: Take a look at which hashing function is being used
let mut hasher = DefaultHasher::new();
let concatenated = format!(
"{}{}{}",
self.query,
self.num_params,
self.param_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
concatenated.hash(&mut hasher);
hasher.finish()
}
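`get_hash` above keys the global cache on the query text plus parameter info. A standalone sketch of the same idea (`DefaultHasher` is std's SipHash-based hasher; the `TODO_ZAIN` note suggests the choice of hash function may change later):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Build the dedup key the cache uses: two clients preparing the same query
// text with the same parameter count and types produce the same u64, so
// they can share one server-side prepared statement.
fn statement_hash(query: &str, num_params: i16, param_types: &[i32]) -> u64 {
    let mut hasher = DefaultHasher::new();
    let concatenated = format!(
        "{}{}{}",
        query,
        num_params,
        param_types
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(",")
    );
    concatenated.hash(&mut hasher);
    hasher.finish()
}
```

Hashing before taking the cache lock keeps the critical section short, which is why the pool's `get_or_insert` accepts the precomputed hash.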
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
@@ -958,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
}
impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()?;
cursor.read_string()
}
/// Renames the prepared statement to a new name
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor
let code = cursor.get_u8();
let current_len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
// Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
// Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code);
response_buf.put_i32(new_len);
// Put the portal and new name into the buffer
// Note: panic if the provided string contains null byte
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
// Add the remainder of the original buffer into the response
response_buf.put_slice(&buf[cursor.position() as usize..]);
// Return the buffer
Ok(response_buf)
}
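`Bind::rename` above splices the rewritten statement name into the raw Bind message and patches the length field. A hypothetical std-only version of the same byte surgery (plain slices instead of the `bytes` crate):

```rust
// Read a NUL-terminated string starting at *pos and advance past the NUL.
fn read_cstr(buf: &[u8], pos: &mut usize) -> String {
    let start = *pos;
    while buf[*pos] != 0 {
        *pos += 1;
    }
    let s = String::from_utf8(buf[start..*pos].to_vec()).unwrap();
    *pos += 1; // skip the NUL terminator
    s
}

// Rewrite the prepared-statement name inside a raw Bind ('B') message.
fn rename_bind(buf: &[u8], new_name: &str) -> Vec<u8> {
    let code = buf[0]; // b'B'
    let old_len = i32::from_be_bytes(buf[1..5].try_into().unwrap());
    let mut pos = 5;
    let portal = read_cstr(buf, &mut pos);
    let old_name = read_cstr(buf, &mut pos);
    // The length field counts everything after the code byte, so it changes
    // by exactly the difference between the two name lengths.
    let new_len = old_len + new_name.len() as i32 - old_name.len() as i32;
    let mut out = Vec::with_capacity(new_len as usize + 1);
    out.push(code);
    out.extend_from_slice(&new_len.to_be_bytes());
    out.extend_from_slice(portal.as_bytes());
    out.push(0);
    out.extend_from_slice(new_name.as_bytes());
    out.push(0);
    out.extend_from_slice(&buf[pos..]); // parameter formats/values, untouched
    out
}
```

Only the name and length change; the parameter formats and values after the name are copied through verbatim, which is what lets the pooler rewrite Bind without decoding the whole message.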
pub fn anonymous(&self) -> bool {
@@ -1016,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
}
impl Describe {
pub fn empty_new() -> Describe {
Describe {
code: 'D',
len: 4 + 1 + 1,
target: 'S',
statement_name: "".to_string(),
}
}
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
@@ -1104,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
@@ -1193,7 +1330,7 @@ impl Display for PgErrorMsg {
}
impl PgErrorMsg {
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
let mut out = PgErrorMsg {
severity_localized: "".to_string(),
severity: "".to_string(),
@@ -1341,7 +1478,7 @@ mod tests {
info!(
"full message: {}",
PgErrorMsg::parse(complete_msg.clone()).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
assert_eq!(
PgErrorMsg {
@@ -1364,7 +1501,7 @@ mod tests {
line: Some(335),
routine: Some(routine_msg.to_string()),
},
PgErrorMsg::parse(complete_msg).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
let mut only_mandatory_msg = vec![];
@@ -1374,7 +1511,7 @@ mod tests {
only_mandatory_msg.extend(field('M', message));
only_mandatory_msg.extend(field('D', detail_msg));
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
info!("only mandatory fields: {}", &err_fields);
error!(
"server error: {}: {}",
@@ -1401,7 +1538,7 @@ mod tests {
line: None,
routine: None,
},
PgErrorMsg::parse(only_mandatory_msg).unwrap()
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
);
}
}


@@ -23,14 +23,15 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout, _cfg) =
let (connection_timeout, idle_timeout, _cfg, prepared_statement_cache_size) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
cfg.clone(),
cfg.prepared_statements_cache_size,
),
None => (default, default, crate::config::Pool::default()),
None => (default, default, crate::config::Pool::default(), 0),
};
let manager = ServerPool::new(
@@ -42,6 +43,7 @@ impl MirroredClient {
None,
true,
false,
prepared_statement_cache_size,
);
Pool::builder()


@@ -3,6 +3,7 @@ use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
@@ -10,6 +11,7 @@ use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::num::NonZeroUsize;
use std::sync::atomic::AtomicU64;
use std::sync::{
atomic::{AtomicBool, Ordering},
@@ -24,6 +26,7 @@ use crate::config::{
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::messages::Parse;
use crate::plugins::prewarmer;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
@@ -54,6 +57,57 @@ pub enum BanReason {
AdminBan(i64),
}
pub type PreparedStatementCacheType = Arc<Mutex<PreparedStatementCache>>;
// TODO: Add stats to this cache
// TODO: Add application name to the cache value to help identify which application is using the cache
// TODO: Create admin command to show which statements are in the cache
#[derive(Debug)]
pub struct PreparedStatementCache {
cache: LruCache<u64, Arc<Parse>>,
}
impl PreparedStatementCache {
pub fn new(mut size: usize) -> Self {
// Cannot be zero
if size == 0 {
size = 1;
}
PreparedStatementCache {
cache: LruCache::new(NonZeroUsize::new(size).unwrap()),
}
}
/// Adds the prepared statement to the cache under a newly generated name if it isn't cached yet;
/// if it is, returns the existing rewritten parse
///
/// Pass the hash to this so that we can do the compute before acquiring the lock
pub fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Arc<Parse> {
match self.cache.get(&hash) {
Some(rewritten_parse) => rewritten_parse.clone(),
None => {
let new_parse = Arc::new(parse.clone().rewrite());
let evicted = self.cache.push(hash, new_parse.clone());
if let Some((_, evicted_parse)) = evicted {
debug!(
"Evicted prepared statement {} from cache",
evicted_parse.name
);
}
new_parse
}
}
}
/// Marks the hash as most recently used if it exists
pub fn promote(&mut self, hash: &u64) {
self.cache.promote(hash);
}
}
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
@@ -223,6 +277,9 @@ pub struct ConnectionPool {
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
/// Cache
pub prepared_statement_cache: Option<PreparedStatementCacheType>,
}
impl ConnectionPool {
@@ -376,6 +433,7 @@ impl ConnectionPool {
},
pool_config.cleanup_server_connections,
pool_config.log_client_parameter_status_changes,
pool_config.prepared_statements_cache_size,
);
let connect_timeout = match pool_config.connect_timeout {
@@ -498,6 +556,12 @@ impl ConnectionPool {
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
prepared_statement_cache: match pool_config.prepared_statements_cache_size {
0 => None,
_ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
pool_config.prepared_statements_cache_size,
)))),
},
};
// Connect to the servers to make sure pool configuration is valid
@@ -998,6 +1062,29 @@ impl ConnectionPool {
Some(shard) => shard < self.shards(),
}
}
/// Register a parse statement to the pool's cache and return the rewritten parse
///
/// Do not pass an anonymous parse statement to this function
pub fn register_parse_to_cache(&self, hash: u64, parse: &Parse) -> Option<Arc<Parse>> {
// We should only be calling this function if the cache is enabled
match self.prepared_statement_cache {
Some(ref prepared_statement_cache) => {
let mut cache = prepared_statement_cache.lock();
Some(cache.get_or_insert(parse, hash))
}
None => None,
}
}
/// Promote a prepared statement hash in the LRU
pub fn promote_prepared_statement_hash(&self, hash: &u64) {
// We should only be calling this function if the cache is enabled
if let Some(ref prepared_statement_cache) = self.prepared_statement_cache {
let mut cache = prepared_statement_cache.lock();
cache.promote(hash);
}
}
}
/// Wrapper for the bb8 connection pool.
@@ -1025,6 +1112,9 @@ pub struct ServerPool {
/// Log client parameter status changes
log_client_parameter_status_changes: bool,
/// Prepared statement cache size
prepared_statement_cache_size: usize,
}
impl ServerPool {
@@ -1038,6 +1128,7 @@ impl ServerPool {
plugins: Option<Plugins>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> ServerPool {
ServerPool {
address,
@@ -1048,6 +1139,7 @@ impl ServerPool {
plugins,
cleanup_connections,
log_client_parameter_status_changes,
prepared_statement_cache_size,
}
}
}
@@ -1078,6 +1170,7 @@ impl ManageConnection for ServerPool {
self.auth_hash.clone(),
self.cleanup_connections,
self.log_client_parameter_status_changes,
self.prepared_statement_cache_size,
)
.await
{
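The `PreparedStatementCache` introduced in this file keys rewritten `Parse` messages by a precomputed hash so that identical statements from different clients collapse to one shared entry, and `promote` marks a hash as most recently used. The following std-only sketch mimics those `get_or_insert`/`promote` semantics; it is an illustration, not the real implementation (the actual code uses the `lru` crate and shares `Arc<Parse>`; `MiniParse`, the `VecDeque` bookkeeping, and the `pgcat_<hash>` naming are assumptions):

```rust
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

// Illustrative stand-in for the real Parse message.
#[derive(Clone, Debug)]
struct MiniParse {
    name: String,
    query: String,
}

// Bounded dedup cache: hash of the statement -> shared rewritten parse.
struct MiniStatementCache {
    capacity: usize,
    map: HashMap<u64, Arc<MiniParse>>,
    order: VecDeque<u64>, // front = least recently used
}

impl MiniStatementCache {
    fn new(capacity: usize) -> Self {
        // A capacity of zero disables the cache elsewhere; clamp to 1 here.
        Self { capacity: capacity.max(1), map: HashMap::new(), order: VecDeque::new() }
    }

    // Returns the shared rewritten parse; inserts (possibly evicting) on miss.
    fn get_or_insert(&mut self, parse: &MiniParse, hash: u64) -> Arc<MiniParse> {
        if let Some(existing) = self.map.get(&hash).cloned() {
            self.promote(&hash);
            return existing;
        }
        // Rewrite under a pooler-owned name so clients can share the statement.
        let rewritten = Arc::new(MiniParse {
            name: format!("pgcat_{}", hash),
            query: parse.query.clone(),
        });
        if self.map.len() == self.capacity {
            if let Some(lru) = self.order.pop_front() {
                self.map.remove(&lru); // evicted entry
            }
        }
        self.map.insert(hash, rewritten.clone());
        self.order.push_back(hash);
        rewritten
    }

    // Mark the hash as most recently used if present.
    fn promote(&mut self, hash: &u64) {
        if let Some(pos) = self.order.iter().position(|h| h == hash) {
            let h = self.order.remove(pos).unwrap();
            self.order.push_back(h);
        }
    }
}

fn main() {
    let mut cache = MiniStatementCache::new(2);
    let p1 = MiniParse { name: "s1".into(), query: "SELECT 1".into() };
    // Two clients sending the same statement get the same rewritten entry.
    let a = cache.get_or_insert(&p1, 11);
    let b = cache.get_or_insert(&p1, 11);
    assert_eq!(a.name, b.name);
    println!("shared name: {}", a.name);
}
```

Hashing outside the cache, as the doc comment on `get_or_insert` notes, lets the caller compute the key before taking the mutex that guards the shared cache.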


@@ -3,12 +3,14 @@
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
-use std::collections::{BTreeSet, HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
use std::mem;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
@@ -16,7 +18,7 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};
-use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
+use crate::config::{get_config, Address, User};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
@@ -322,7 +324,7 @@ pub struct Server {
log_client_parameter_status_changes: bool,
/// Prepared statements
-prepared_statements: BTreeSet<String>,
+prepared_statement_cache: Option<LruCache<String, ()>>,
}
impl Server {
@@ -338,6 +340,7 @@ impl Server {
auth_hash: Arc<RwLock<Option<String>>>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
@@ -713,7 +716,7 @@ impl Server {
}
};
-let fields = match PgErrorMsg::parse(error) {
+let fields = match PgErrorMsg::parse(&error) {
Ok(f) => f,
Err(err) => {
return Err(err);
@@ -818,7 +821,12 @@ impl Server {
},
cleanup_connections,
log_client_parameter_status_changes,
-prepared_statements: BTreeSet::new(),
+prepared_statement_cache: match prepared_statement_cache_size {
0 => None,
_ => Some(LruCache::new(
NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
)),
},
};
return Ok(server);
@@ -957,6 +965,20 @@ impl Server {
if self.in_copy_mode {
self.in_copy_mode = false;
}
if self.prepared_statement_cache.is_some() {
let error_message = PgErrorMsg::parse(&message)?;
if error_message.message == "cached plan must not change result type" {
warn!("Server {:?} changed schema, dropping connection to clean up prepared statements", self.address);
// This will still result in an error to the client, but this server connection will drop all cached prepared statements
// so that any new queries will be re-prepared
// TODO: Other ideas to solve errors when there are DDL changes after a statement has been prepared
// - Recreate entire connection pool to force recreation of all server connections
// - Clear the ConnectionPool's statement cache so that new statement names are generated
// - Implement a retry (re-prepare) so the client doesn't see an error
self.cleanup_state.needs_cleanup_prepare = true;
}
}
}
// CommandComplete
@@ -1067,115 +1089,92 @@ impl Server {
Ok(bytes)
}
-    /// Add the prepared statement to being tracked by this server.
-    /// The client is processing data that will create a prepared statement on this server.
-    pub fn will_prepare(&mut self, name: &str) {
-        debug!("Will prepare `{}`", name);
-
-        self.prepared_statements.insert(name.to_string());
-        self.stats.prepared_cache_add();
-    }
-
-    /// Check if we should prepare a statement on the server.
-    pub fn should_prepare(&self, name: &str) -> bool {
-        let should_prepare = !self.prepared_statements.contains(name);
-        debug!("Should prepare `{}`: {}", name, should_prepare);
-        if should_prepare {
-            self.stats.prepared_cache_miss();
-        } else {
-            self.stats.prepared_cache_hit();
-        }
-        should_prepare
-    }
-
-    /// Create a prepared statement on the server.
-    pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
-        debug!("Preparing `{}`", parse.name);
-
-        let bytes: BytesMut = parse.try_into()?;
-        self.send(&bytes).await?;
-        self.send(&flush()).await?;
-
-        // Read and discard ParseComplete (B)
-        match read_message(&mut self.stream).await {
-            Ok(_) => (),
-            Err(err) => {
-                self.bad = true;
-                return Err(err);
-            }
-        }
-
-        self.prepared_statements.insert(parse.name.to_string());
-        self.stats.prepared_cache_add();
-
-        debug!("Prepared `{}`", parse.name);
-
-        Ok(())
-    }
-
-    /// Maintain adequate cache size on the server.
-    pub async fn maintain_cache(&mut self) -> Result<(), Error> {
-        debug!("Cache maintenance run");
-
-        let max_cache_size = get_prepared_statements_cache_size();
-        let mut names = Vec::new();
-
-        while self.prepared_statements.len() >= max_cache_size {
-            // The prepared statmeents are alphanumerically sorted by the BTree.
-            // FIFO.
-            if let Some(name) = self.prepared_statements.pop_last() {
-                names.push(name);
-            }
-        }
-
-        if !names.is_empty() {
-            self.deallocate(names).await?;
-        }
-
-        Ok(())
-    }
-
-    /// Remove the prepared statement from being tracked by this server.
-    /// The client is processing data that will cause the server to close the prepared statement.
-    pub fn will_close(&mut self, name: &str) {
-        debug!("Will close `{}`", name);
-
-        self.prepared_statements.remove(name);
-    }
-
-    /// Close a prepared statement on the server.
-    pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
-        for name in &names {
-            debug!("Deallocating prepared statement `{}`", name);
-
-            let close = Close::new(name);
-            let bytes: BytesMut = close.try_into()?;
-
-            self.send(&bytes).await?;
-        }
-
-        if !names.is_empty() {
-            self.send(&flush()).await?;
-        }
-
-        // Read and discard CloseComplete (3)
-        for name in &names {
-            match read_message(&mut self.stream).await {
-                Ok(_) => {
-                    self.prepared_statements.remove(name);
-                    self.stats.prepared_cache_remove();
-                    debug!("Closed `{}`", name);
-                }
-                Err(err) => {
-                    self.bad = true;
-                    return Err(err);
-                }
-            };
-        }
-
-        Ok(())
-    }
+    // Determines if the server already has a prepared statement with the given name
+    // Increments the prepared statement cache hit counter
+    pub fn has_prepared_statement(&mut self, name: &str) -> bool {
+        let cache = match &mut self.prepared_statement_cache {
+            Some(cache) => cache,
+            None => return false,
+        };
+
+        let has_it = cache.get(name).is_some();
+        if has_it {
+            self.stats.prepared_cache_hit();
+        } else {
+            self.stats.prepared_cache_miss();
+        }
+
+        has_it
+    }
+
+    pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
+        let cache = match &mut self.prepared_statement_cache {
+            Some(cache) => cache,
+            None => return None,
+        };
+
+        self.stats.prepared_cache_add();
+
+        // If we evict something, we need to close it on the server
+        if let Some((evicted_name, _)) = cache.push(name.to_string(), ()) {
+            if evicted_name != name {
+                debug!(
+                    "Evicted prepared statement {} from cache, replaced with {}",
+                    evicted_name, name
+                );
+                return Some(evicted_name);
+            }
+        };
+
+        None
+    }
+
+    pub fn remove_prepared_statement_from_cache(&mut self, name: &str) {
+        let cache = match &mut self.prepared_statement_cache {
+            Some(cache) => cache,
+            None => return,
+        };
+
+        self.stats.prepared_cache_remove();
+        cache.pop(name);
+    }
+
+    pub async fn register_prepared_statement(
+        &mut self,
+        parse: &Parse,
+        should_send_parse_to_server: bool,
+    ) -> Result<(), Error> {
+        if !self.has_prepared_statement(&parse.name) {
+            let mut bytes = BytesMut::new();
+
+            if should_send_parse_to_server {
+                let parse_bytes: BytesMut = parse.try_into()?;
+                bytes.extend_from_slice(&parse_bytes);
+            }
+
+            // If we evict something, we need to close it on the server
+            // We do this by adding it to the messages we're sending to the server before the sync
+            if let Some(evicted_name) = self.add_prepared_statement_to_cache(&parse.name) {
+                self.remove_prepared_statement_from_cache(&evicted_name);
+                let close_bytes: BytesMut = Close::new(&evicted_name).try_into()?;
+                bytes.extend_from_slice(&close_bytes);
+            };
+
+            // If we have a parse or close we need to send to the server, send them and sync
+            if !bytes.is_empty() {
+                bytes.extend_from_slice(&sync());
+                self.send(&bytes).await?;
+                loop {
+                    self.recv(None).await?;
+                    if !self.is_data_available() {
+                        break;
+                    }
+                }
+            }
+        };
+
+        Ok(())
+    }
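The server-side bookkeeping in this hunk boils down to: a bounded per-connection LRU of statement names, where a hit promotes the name and inserting a new name may evict the least recently used one, whose `Close` must then be sent to the server before the next `Sync`. A std-only sketch of that eviction contract (the real code uses `lru::LruCache<String, ()>`; `ServerStatementCache` is an illustrative name):

```rust
use std::collections::VecDeque;

// Bounded set of statement names prepared on one server connection.
struct ServerStatementCache {
    capacity: usize,
    names: VecDeque<String>, // front = least recently used
}

impl ServerStatementCache {
    fn new(capacity: usize) -> Self {
        Self { capacity: capacity.max(1), names: VecDeque::new() }
    }

    // A hit promotes the name to most recently used.
    fn has_prepared_statement(&mut self, name: &str) -> bool {
        match self.names.iter().position(|n| n == name) {
            Some(pos) => {
                let n = self.names.remove(pos).unwrap();
                self.names.push_back(n);
                true
            }
            None => false,
        }
    }

    // Returns the evicted name, if any, which the caller must Close on the server.
    fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
        if self.has_prepared_statement(name) {
            return None;
        }
        let evicted = if self.names.len() == self.capacity {
            self.names.pop_front()
        } else {
            None
        };
        self.names.push_back(name.to_string());
        evicted
    }
}

fn main() {
    let mut cache = ServerStatementCache::new(2);
    assert_eq!(cache.add_prepared_statement_to_cache("s1"), None);
    assert_eq!(cache.add_prepared_statement_to_cache("s2"), None);
    // Third statement evicts the least recently used one; in the real code a
    // Close for it is appended to the outgoing buffer before the Sync.
    assert_eq!(cache.add_prepared_statement_to_cache("s3"), Some("s1".to_string()));
    assert!(cache.has_prepared_statement("s2"));
    assert!(!cache.has_prepared_statement("s1"));
}
```

Batching the Parse and any eviction Close into one buffer, then syncing once, keeps the round trips to the server to a minimum, which matches the intent of `register_prepared_statement` above.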
@@ -1312,6 +1311,10 @@ impl Server {
if self.cleanup_state.needs_cleanup_prepare {
reset_string.push_str("DEALLOCATE ALL;");
// Since we deallocated all prepared statements, we need to clear the cache
if let Some(cache) = &mut self.prepared_statement_cache {
cache.clear();
}
};
self.query(&reset_string).await?;
@@ -1377,6 +1380,7 @@ impl Server {
Arc::new(RwLock::new(None)),
true,
false,
0,
)
.await?;
debug!("Connected!, sending query.");
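The ErrorResponse handling earlier in this file matches the decoded message text ("cached plan must not change result type") to decide that the connection's prepared statements must be deallocated. That text travels in the 'M' field of the standard Postgres ErrorResponse body: repeated one-byte field codes, each followed by a NUL-terminated string, with a final zero byte ending the list. A small sketch of pulling that field out (`error_message_field` is a hypothetical helper; the real code goes through `PgErrorMsg::parse`):

```rust
// Extract the human-readable message ('M' field) from a Postgres
// ErrorResponse body, so callers can match errors such as
// "cached plan must not change result type".
fn error_message_field(body: &[u8]) -> Option<String> {
    let mut rest = body;
    while let Some((&code, tail)) = rest.split_first() {
        if code == 0 {
            break; // zero byte: end of the field list
        }
        // Each field value runs up to the next NUL byte.
        let end = tail.iter().position(|b| *b == 0)?;
        if code == b'M' {
            return Some(String::from_utf8_lossy(&tail[..end]).into_owned());
        }
        rest = &tail[end + 1..];
    }
    None
}

fn main() {
    let body = b"SERROR\0Mcached plan must not change result type\0\0";
    let msg = error_message_field(body).unwrap();
    assert_eq!(msg, "cached plan must not change result type");
    println!("{}", msg);
}
```

Matching on the message text is fragile across server locales, which is one reason the comments above list alternatives (pool recreation, cache clearing, transparent re-prepare) as future improvements.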


@@ -49,6 +49,7 @@ pub struct ServerStats {
pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_eviction_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
}
@@ -68,6 +69,7 @@ impl Default for ServerStats {
reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_eviction_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
}
}
@@ -221,6 +223,7 @@ impl ServerStats {
}
pub fn prepared_cache_remove(&self) {
self.prepared_eviction_count.fetch_add(1, Ordering::Relaxed);
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
}
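The stats change above adds an eviction counter and makes `prepared_cache_remove` bump it alongside decrementing the tracked cache size. A minimal sketch of the counter semantics (field and method names abbreviated from the real `ServerStats`):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Per-server prepared statement counters, updated with relaxed atomics since
// they are monotonic metrics, not synchronization points.
#[derive(Default)]
struct PreparedStats {
    hit_count: AtomicU64,
    miss_count: AtomicU64,
    eviction_count: AtomicU64,
    cache_size: AtomicU64,
}

impl PreparedStats {
    fn prepared_cache_add(&self) {
        self.cache_size.fetch_add(1, Ordering::Relaxed);
    }
    // Removal counts as an eviction and shrinks the tracked cache size.
    fn prepared_cache_remove(&self) {
        self.eviction_count.fetch_add(1, Ordering::Relaxed);
        self.cache_size.fetch_sub(1, Ordering::Relaxed);
    }
    fn prepared_cache_hit(&self) {
        self.hit_count.fetch_add(1, Ordering::Relaxed);
    }
    fn prepared_cache_miss(&self) {
        self.miss_count.fetch_add(1, Ordering::Relaxed);
    }
}

fn main() {
    let stats = PreparedStats::default();
    stats.prepared_cache_add();
    stats.prepared_cache_add();
    stats.prepared_cache_hit();
    stats.prepared_cache_remove();
    assert_eq!(stats.cache_size.load(Ordering::Relaxed), 1);
    assert_eq!(stats.eviction_count.load(Ordering::Relaxed), 1);
}
```

Taking `&self` with interior mutability lets the counters be shared via `Arc` and updated from any task without a lock, which is why the real struct wraps each counter in `Arc<AtomicU64>`.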


@@ -36,4 +36,4 @@ SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;


@@ -1,29 +1,214 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
-  let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
+  let(:pool_size) { 5 }
+  let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", pool_size) }
+  let(:prepared_statements_cache_size) { 100 }
+  let(:server_round_robin) { false }
-  context 'enabled' do
-    it 'will work over the same connection' do
+  before do
+    new_configs = processes.pgcat.current_config
new_configs["general"]["server_round_robin"] = server_round_robin
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = pool_size
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
end
context 'when trying prepared statements' do
it 'allows unparameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT 1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1')
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2')
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgCat reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3')
end
ensure
conn1.close if conn1
conn2.close if conn2
end
it 'allows parameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT $1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1', [1])
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2', [1])
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgCat reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3', [1])
end
ensure
conn1.close if conn1
conn2.close if conn2
end
end
context 'when trying large packets' do
it "works with large parse" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT '#{long_string}'"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1')
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
it "works with large bind" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT $1::text"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1', [long_string])
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
end
context 'when statement cache is smaller than set of unique statements' do
let(:prepared_statements_cache_size) { 1 }
let(:pool_size) { 1 }
it "evicts all but 1 statement from the server cache" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when statement cache is larger than set of unique statements' do
let(:pool_size) { 1 }
it "does not evict any of the statements from the cache" do
# default cache size (100) is larger than the 5 statements prepared below
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 5)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(5)
end
end
context 'when preparing the same query' do
let(:prepared_statements_cache_size) { 5 }
let(:pool_size) { 5 }
    it "reuses statement cache when there are different statement names on the same connection" do
      conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
      10.times do |i|
        statement_name = "statement_#{i}"
        conn.prepare(statement_name, 'SELECT $1::int')
        conn.exec_prepared(statement_name, [1])
-       conn.describe_prepared(statement_name)
      end
      # Check number of prepared statements (expected: 1)
      n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
      expect(n_statements).to eq(1)
    end

-   it 'will work with new connections' do
-     10.times do
-       conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
-       statement_name = 'statement1'
-       conn.prepare('statement1', 'SELECT $1::int')
-       conn.exec_prepared('statement1', [1])
-       conn.describe_prepared('statement1')
-     end
-   end
+   it "reuses statement cache when there are different statement names on different connections" do
+     10.times do |i|
+       conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
+       statement_name = "statement_#{i}"
+       conn.prepare(statement_name, 'SELECT $1::int')
+       conn.exec_prepared(statement_name, [1])
+     end
+     # Check number of prepared statements (expected: 1)
+     conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
+     n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
+     expect(n_statements).to eq(1)
+   end
end
context 'when reloading config' do
let(:pool_size) { 1 }
it "test_reload_config" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
# prepare query
conn.prepare('statement1', 'SELECT 1')
conn.exec_prepared('statement1')
# Reload config which triggers pool recreation
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size + 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
# check that we're starting with no prepared statements on the server
conn_check = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn_check.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(0)
# still able to run prepared query
conn.exec_prepared('statement1')
    end
  end
end