Compare commits

..

1 Commits

Author SHA1 Message Date
Lev
fa17bb5cc6 TLS misconfiguration demoted to warning 2023-09-26 10:14:42 -07:00
30 changed files with 915 additions and 1929 deletions

View File

@@ -63,9 +63,6 @@ jobs:
- run:
name: "Lint"
command: "cargo fmt --check"
- run:
name: "Clippy"
command: "cargo clippy --all --all-targets -- -Dwarnings"
- run:
name: "Tests"
command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh"

View File

@@ -4,7 +4,7 @@ on:
workflow_dispatch:
inputs:
packageVersion:
default: "1.1.2-dev1"
default: "1.1.2-dev"
jobs:
build:
strategy:

1
.gitignore vendored
View File

@@ -10,4 +10,3 @@ lcov.info
dev/.bash_history
dev/cache
!dev/cache/.keepme
.venv

View File

@@ -259,6 +259,22 @@ Password to be used for connecting to servers to obtain the hash used for md5 au
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### prepared_statements
```
path: general.prepared_statements
default: false
```
Whether to use prepared statements or not.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 500
```
Size of the prepared statements cache.
### dns_cache_enabled
```
path: general.dns_cache_enabled
@@ -308,15 +324,6 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 0
```
Size of the prepared statements cache. 0 means disabled.
TODO: update documentation
### query_parser_enabled
```
path: pools.<pool_name>.query_parser_enabled

View File

@@ -2,7 +2,7 @@
Thank you for contributing! Just a few tips here:
1. `cargo fmt` and `cargo clippy` your code before opening up a PR
1. `cargo fmt` your code before opening up a PR
2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`.
3. Performance is important, make sure there are no regressions in your branch vs. `main`.

33
Cargo.lock generated
View File

@@ -17,17 +17,6 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "1.0.2"
@@ -37,12 +26,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
@@ -570,10 +553,6 @@ name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]]
name = "heck"
@@ -842,15 +821,6 @@ version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "lru"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
dependencies = [
"hashbrown 0.14.0",
]
[[package]]
name = "lru-cache"
version = "0.1.2"
@@ -1020,7 +990,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "pgcat"
version = "1.1.2-dev1"
version = "1.1.2-dev"
dependencies = [
"arc-swap",
"async-trait",
@@ -1038,7 +1008,6 @@ dependencies = [
"itertools",
"jemallocator",
"log",
"lru",
"md-5",
"nix",
"num_cpus",

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.1.2-dev1"
version = "1.1.2-dev"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -48,7 +48,6 @@ itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
lru = "0.12.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -8,12 +8,6 @@ WORKDIR /app
RUN cargo build --release
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -o Dpkg::Options::=--force-confdef -yq --no-install-recommends \
postgresql-client \
# Clean up layer
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& truncate -s 0 /var/log/*log
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
WORKDIR /etc/pgcat

View File

@@ -1,8 +1,6 @@
FROM rust:1.70-bullseye
# Dependencies
COPY --from=sclevine/yj /bin/yj /bin/yj
RUN /bin/yj -h
RUN apt-get update -y \
&& apt-get install -y \
llvm-11 psmisc postgresql-contrib postgresql-client \

View File

@@ -60,6 +60,12 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5
# Handle prepared statements.
prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
@@ -150,10 +156,6 @@ load_balancing_mode = "random"
# `primary` all queries go to the primary unless otherwise specified.
default_role = "any"
# Prepared statements cache size.
# TODO: update documentation
prepared_statements_cache_size = 500
# If Query Parser is enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,

View File

@@ -283,7 +283,7 @@ where
{
let mut res = BytesMut::new();
let detail_msg = [
let detail_msg = vec![
"",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
@@ -301,6 +301,7 @@ where
// "KILL <db>",
// "SUSPEND",
"SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
];
res.put(notify("Console usage", detail_msg.join("\n\t")));
@@ -744,7 +745,6 @@ where
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_eviction", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];
@@ -777,10 +777,6 @@ where
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_eviction_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
@@ -806,7 +802,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(',').map(|part| part.trim()).collect(),
true => tokens[1].split(",").map(|part| part.trim()).collect(),
false => Vec::new(),
};
@@ -869,7 +865,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(',').map(|part| part.trim()).collect(),
true => tokens[1].split(",").map(|part| part.trim()).collect(),
false => Vec::new(),
};

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ pub struct Args {
}
pub fn parse() -> Args {
Args::parse()
return Args::parse();
}
#[derive(ValueEnum, Clone, Debug)]

View File

@@ -1,6 +1,6 @@
/// Parse the configuration file.
use arc_swap::ArcSwap;
use log::{error, info};
use log::{error, info, warn};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserializer, Serializer};
@@ -116,10 +116,10 @@ impl Default for Address {
host: String::from("127.0.0.1"),
port: 5432,
shard: 0,
address_index: 0,
replica_number: 0,
database: String::from("database"),
role: Role::Replica,
replica_number: 0,
address_index: 0,
username: String::from("username"),
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
@@ -236,14 +236,18 @@ impl Default for User {
impl User {
fn validate(&self) -> Result<(), Error> {
if let Some(min_pool_size) = self.min_pool_size {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
}
None => (),
};
Ok(())
@@ -337,6 +341,12 @@ pub struct General {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default)]
pub prepared_statements: bool,
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
}
impl General {
@@ -418,6 +428,10 @@ impl General {
pub fn default_server_round_robin() -> bool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
500
}
}
impl Default for General {
@@ -429,33 +443,35 @@ impl Default for General {
prometheus_exporter_port: 9930,
connect_timeout: General::default_connect_timeout(),
idle_timeout: General::default_idle_timeout(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
worker_threads: Self::default_worker_threads(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
log_client_connections: false,
log_client_disconnections: false,
autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
worker_threads: Self::default_worker_threads(),
autoreload: None,
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
validate_config: true,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
}
}
}
@@ -556,9 +572,6 @@ pub struct Pool {
#[serde(default)] // False
pub log_client_parameter_status_changes: bool,
#[serde(default = "Pool::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
@@ -608,10 +621,6 @@ impl Pool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
0
}
pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() {
"any" => (),
@@ -668,9 +677,9 @@ impl Pool {
Some(key) => {
// No quotes in the key so we don't have to compare quoted
// to unquoted idents.
let key = key.replace('\"', "");
let key = key.replace("\"", "");
if key.split('.').count() != 2 {
if key.split(".").count() != 2 {
error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key
@@ -683,14 +692,17 @@ impl Pool {
None => None,
};
if let DefaultShard::Shard(shard_number) = self.default_shard {
if shard_number >= self.shards.len() {
error!("Invalid shard {:?}", shard_number);
return Err(Error::BadConfig);
match self.default_shard {
DefaultShard::Shard(shard_number) => {
if shard_number >= self.shards.len() {
error!("Invalid shard {:?}", shard_number);
return Err(Error::BadConfig);
}
}
_ => (),
}
for user in self.users.values() {
for (_, user) in &self.users {
user.validate()?;
}
@@ -703,16 +715,17 @@ impl Default for Pool {
Pool {
pool_mode: Self::default_pool_mode(),
load_balancing_mode: Self::default_load_balancing_mode(),
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: false,
connect_timeout: None,
idle_timeout: None,
server_lifetime: None,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
connect_timeout: None,
idle_timeout: None,
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
@@ -720,12 +733,10 @@ impl Default for Pool {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
plugins: None,
cleanup_server_connections: true,
log_client_parameter_status_changes: false,
prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
plugins: None,
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
}
}
}
@@ -766,8 +777,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard {
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(s) = s.strip_prefix("shard_") {
let shard = s.parse::<usize>().map_err(serde::de::Error::custom)?;
if s.starts_with("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(DefaultShard::Shard(shard));
}
@@ -837,13 +848,13 @@ impl Shard {
impl Default for Shard {
fn default() -> Shard {
Shard {
database: String::from("postgres"),
mirrors: None,
servers: vec![ServerConfig {
host: String::from("localhost"),
port: 5432,
role: Role::Primary,
}],
mirrors: None,
database: String::from("postgres"),
}
}
}
@@ -856,26 +867,15 @@ pub struct Plugins {
pub prewarmer: Option<Prewarmer>,
}
pub trait Plugin {
fn is_enabled(&self) -> bool;
}
impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool {
if let Some(arg) = arg {
arg.is_enabled()
} else {
false
}
}
write!(
f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
is_enabled(self.intercept.as_ref()),
is_enabled(self.table_access.as_ref()),
is_enabled(self.query_logger.as_ref()),
is_enabled(self.prewarmer.as_ref()),
self.intercept.is_some(),
self.table_access.is_some(),
self.query_logger.is_some(),
self.prewarmer.is_some(),
)
}
}
@@ -886,47 +886,23 @@ pub struct Intercept {
pub queries: BTreeMap<String, Query>,
}
impl Plugin for Intercept {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
impl Plugin for TableAccess {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct QueryLogger {
pub enabled: bool,
}
impl Plugin for QueryLogger {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
}
impl Plugin for Prewarmer {
fn is_enabled(&self) -> bool {
self.enabled
}
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
@@ -944,7 +920,6 @@ pub struct Query {
}
impl Query {
#[allow(clippy::needless_range_loop)]
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
@@ -1014,8 +989,8 @@ impl Default for Config {
Config {
path: Self::default_path(),
general: General::default(),
plugins: None,
pools: HashMap::default(),
plugins: None,
}
}
}
@@ -1069,8 +1044,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
(
format!("pools.{:?}.users", pool_name),
pool.users
.values()
.map(|user| &user.username)
.iter()
.map(|(_username, user)| &user.username)
.cloned()
.collect::<Vec<String>>()
.join(", "),
@@ -1124,7 +1099,6 @@ impl From<&Config> for std::collections::HashMap<String, String> {
impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Config path: {}", self.path);
info!("Ban time: {}s", self.general.ban_time);
info!(
"Idle client in transaction timeout: {}ms",
@@ -1156,9 +1130,13 @@ impl Config {
Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate);
if let Some(tls_private_key) = self.general.tls_private_key.clone() {
info!("TLS private key: {}", tls_private_key);
info!("TLS support is enabled");
match self.general.tls_private_key.clone() {
Some(tls_private_key) => {
info!("TLS private key: {}", tls_private_key);
info!("TLS support is enabled");
}
None => (),
}
}
@@ -1171,6 +1149,13 @@ impl Config {
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!(
"Plugins: {}",
match self.plugins {
@@ -1186,8 +1171,8 @@ impl Config {
pool_name,
pool_config
.users
.values()
.map(|user_cfg| user_cfg.pool_size)
.iter()
.map(|(_, user_cfg)| user_cfg.pool_size)
.sum::<u32>()
.to_string()
);
@@ -1261,10 +1246,6 @@ impl Config {
"[pool: {}] Log client parameter status changes: {}",
pool_name, pool_config.log_client_parameter_status_changes
);
info!(
"[pool: {}] Prepared statements server cache size: {}",
pool_name, pool_config.prepared_statements_cache_size
);
info!(
"[pool: {}] Plugins: {}",
pool_name,
@@ -1361,31 +1342,42 @@ impl Config {
}
// Validate TLS!
if let Some(tls_certificate) = self.general.tls_certificate.clone() {
match load_certs(Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match self.general.tls_private_key.clone() {
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
error!("tls_private_key is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
match self.general.tls_certificate {
Some(ref mut tls_certificate) => {
match load_certs(Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match self.general.tls_private_key {
Some(ref tls_private_key) => {
match load_keys(Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
warn!(
"tls_private_key is incorrectly configured: {:?}",
err
);
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
}
}
},
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig);
}
};
}
None => {
warn!("tls_certificate is set, but the tls_private_key is not");
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
};
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
Err(err) => {
warn!("tls_certificate is incorrectly configured: {:?}", err);
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
}
}
None => (),
};
for pool in self.pools.values_mut() {
@@ -1407,6 +1399,14 @@ pub fn get_idle_client_in_transaction_timeout() -> u64 {
CONFIG.load().general.idle_client_in_transaction_timeout
}
pub fn get_prepared_statements() -> bool {
CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}
/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();

View File

@@ -12,16 +12,13 @@ use crate::config::get_config;
use crate::errors::Error;
use crate::constants::MESSAGE_TERMINATOR;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, Cursor};
use std::mem;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
/// Postgres data type mappings
@@ -117,11 +114,19 @@ pub fn simple_query(query: &str) -> BytesMut {
}
/// Tell the client we're ready for another query.
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
write_all(stream, ready_for_query(false)).await
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
write_all(stream, bytes).await
}
/// Send the startup packet the server. We're pretending we're a Pg client.
@@ -158,10 +163,12 @@ where
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}",
err
))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}",
err
)))
}
}
}
@@ -237,8 +244,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// First pass
md5.update(password.as_bytes());
md5.update(user.as_bytes());
md5.update(&password.as_bytes());
md5.update(&user.as_bytes());
let output = md5.finalize_reset();
@@ -274,7 +281,7 @@ where
{
let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() + 5);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
@@ -288,7 +295,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() + 5);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
@@ -315,7 +322,7 @@ where
res.put_slice(&set_complete[..]);
write_all_half(stream, &res).await?;
send_ready_for_query(stream).await
ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -326,7 +333,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
send_ready_for_query(stream).await
ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -427,7 +434,7 @@ where
res.put(command_complete("SELECT 1"));
write_all_half(stream, &res).await?;
send_ready_for_query(stream).await
ready_for_query(stream).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
@@ -509,7 +516,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1_i32);
data_row.put_i32(-1 as i32);
}
}
@@ -557,37 +564,6 @@ pub fn flush() -> BytesMut {
bytes
}
pub fn sync() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'S');
bytes.put_i32(4);
bytes
}
pub fn parse_complete() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'1');
bytes.put_i32(4);
bytes
}
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
if in_transaction {
bytes.put_u8(b'T');
} else {
bytes.put_u8(b'I');
}
bytes
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -595,10 +571,12 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
@@ -609,10 +587,12 @@ where
{
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
@@ -623,15 +603,19 @@ where
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
@@ -746,7 +730,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
let mut buf = vec![];
match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => Err(Error::ParseBytesError(err.to_string())),
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
}
}
}
@@ -762,55 +746,10 @@ impl BytesMutReader for BytesMut {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => Err(Error::ParseBytesError("Could not read string".to_string())),
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
pub enum ExtendedProtocolData {
Parse {
data: BytesMut,
metadata: Option<(Arc<Parse>, u64)>,
},
Bind {
data: BytesMut,
metadata: Option<String>,
},
Describe {
data: BytesMut,
metadata: Option<String>,
},
Execute {
data: BytesMut,
},
Close {
data: BytesMut,
close: Close,
},
}
impl ExtendedProtocolData {
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
Self::Parse { data, metadata }
}
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
Self::Bind { data, metadata }
}
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
Self::Describe { data, metadata }
}
pub fn create_new_execute(data: BytesMut) -> Self {
Self::Execute { data }
}
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
Self::Close { data, close }
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
@@ -819,6 +758,7 @@ pub struct Parse {
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
@@ -844,6 +784,7 @@ impl TryFrom<&BytesMut> for Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
@@ -892,44 +833,11 @@ impl TryFrom<&Parse> for BytesMut {
}
impl Parse {
/// Renames the prepared statement to a new name based on the global counter
pub fn rewrite(mut self) -> Self {
self.name = format!(
"PGCAT_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
);
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
self
}
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()
}
/// Hashes the parse statement to be used as a key in the global cache
pub fn get_hash(&self) -> u64 {
// TODO_ZAIN: Take a look at which hashing function is being used
let mut hasher = DefaultHasher::new();
let concatenated = format!(
"{}{}{}",
self.query,
self.num_params,
self.param_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
concatenated.hash(&mut hasher);
hasher.finish()
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
@@ -1060,42 +968,9 @@ impl TryFrom<Bind> for BytesMut {
}
impl Bind {
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()?;
cursor.read_string()
}
/// Renames the prepared statement to a new name
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor
let code = cursor.get_u8();
let current_len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
// Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
// Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code);
response_buf.put_i32(new_len);
// Put the portal and new name into the buffer
// Note: panic if the provided string contains null byte
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
// Add the remainder of the original buffer into the response
response_buf.put_slice(&buf[cursor.position() as usize..]);
// Return the buffer
Ok(response_buf)
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
}
pub fn anonymous(&self) -> bool {
@@ -1151,15 +1026,6 @@ impl TryFrom<Describe> for BytesMut {
}
impl Describe {
pub fn empty_new() -> Describe {
Describe {
code: 'D',
len: 4 + 1 + 1,
target: 'S',
statement_name: "".to_string(),
}
}
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
@@ -1248,6 +1114,13 @@ pub fn close_complete() -> BytesMut {
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
@@ -1330,7 +1203,7 @@ impl Display for PgErrorMsg {
}
impl PgErrorMsg {
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
let mut out = PgErrorMsg {
severity_localized: "".to_string(),
severity: "".to_string(),
@@ -1438,38 +1311,38 @@ mod tests {
fn parse_fields() {
let mut complete_msg = vec![];
let severity = "FATAL";
complete_msg.extend(field('S', severity));
complete_msg.extend(field('V', severity));
complete_msg.extend(field('S', &severity));
complete_msg.extend(field('V', &severity));
let error_code = "29P02";
complete_msg.extend(field('C', error_code));
complete_msg.extend(field('C', &error_code));
let message = "password authentication failed for user \"wrong_user\"";
complete_msg.extend(field('M', message));
complete_msg.extend(field('M', &message));
let detail_msg = "super detailed message";
complete_msg.extend(field('D', detail_msg));
complete_msg.extend(field('D', &detail_msg));
let hint_msg = "hint detail here";
complete_msg.extend(field('H', hint_msg));
complete_msg.extend(field('H', &hint_msg));
complete_msg.extend(field('P', "123"));
complete_msg.extend(field('p', "234"));
let internal_query = "SELECT * from foo;";
complete_msg.extend(field('q', internal_query));
complete_msg.extend(field('q', &internal_query));
let where_msg = "where goes here";
complete_msg.extend(field('W', where_msg));
complete_msg.extend(field('W', &where_msg));
let schema_msg = "schema_name";
complete_msg.extend(field('s', schema_msg));
complete_msg.extend(field('s', &schema_msg));
let table_msg = "table_name";
complete_msg.extend(field('t', table_msg));
complete_msg.extend(field('t', &table_msg));
let column_msg = "column_name";
complete_msg.extend(field('c', column_msg));
complete_msg.extend(field('c', &column_msg));
let data_type_msg = "type_name";
complete_msg.extend(field('d', data_type_msg));
complete_msg.extend(field('d', &data_type_msg));
let constraint_msg = "constraint_name";
complete_msg.extend(field('n', constraint_msg));
complete_msg.extend(field('n', &constraint_msg));
let file_msg = "pgcat.c";
complete_msg.extend(field('F', file_msg));
complete_msg.extend(field('F', &file_msg));
complete_msg.extend(field('L', "335"));
let routine_msg = "my_failing_routine";
complete_msg.extend(field('R', routine_msg));
complete_msg.extend(field('R', &routine_msg));
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
@@ -1478,7 +1351,7 @@ mod tests {
info!(
"full message: {}",
PgErrorMsg::parse(&complete_msg).unwrap()
PgErrorMsg::parse(complete_msg.clone()).unwrap()
);
assert_eq!(
PgErrorMsg {
@@ -1501,17 +1374,17 @@ mod tests {
line: Some(335),
routine: Some(routine_msg.to_string()),
},
PgErrorMsg::parse(&complete_msg).unwrap()
PgErrorMsg::parse(complete_msg).unwrap()
);
let mut only_mandatory_msg = vec![];
only_mandatory_msg.extend(field('S', severity));
only_mandatory_msg.extend(field('V', severity));
only_mandatory_msg.extend(field('C', error_code));
only_mandatory_msg.extend(field('M', message));
only_mandatory_msg.extend(field('D', detail_msg));
only_mandatory_msg.extend(field('S', &severity));
only_mandatory_msg.extend(field('V', &severity));
only_mandatory_msg.extend(field('C', &error_code));
only_mandatory_msg.extend(field('M', &message));
only_mandatory_msg.extend(field('D', &detail_msg));
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
info!("only mandatory fields: {}", &err_fields);
error!(
"server error: {}: {}",
@@ -1538,7 +1411,7 @@ mod tests {
line: None,
routine: None,
},
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
PgErrorMsg::parse(only_mandatory_msg).unwrap()
);
}
}

View File

@@ -23,15 +23,14 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout, _cfg, prepared_statement_cache_size) =
let (connection_timeout, idle_timeout, _cfg) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
cfg.clone(),
cfg.prepared_statements_cache_size,
),
None => (default, default, crate::config::Pool::default(), 0),
None => (default, default, crate::config::Pool::default()),
};
let manager = ServerPool::new(
@@ -43,7 +42,6 @@ impl MirroredClient {
None,
true,
false,
prepared_statement_cache_size,
);
Pool::builder()
@@ -139,18 +137,18 @@ impl MirroringManager {
bytes_rx,
disconnect_rx: exit_rx,
};
exit_senders.push(exit_tx);
byte_senders.push(bytes_tx);
exit_senders.push(exit_tx.clone());
byte_senders.push(bytes_tx.clone());
client.start();
});
Self {
byte_senders,
byte_senders: byte_senders,
disconnect_senders: exit_senders,
}
}
pub fn send(&mut self, bytes: &BytesMut) {
pub fn send(self: &mut Self, bytes: &BytesMut) {
// We want to avoid performing an allocation if we won't be able to send the message
// There is a possibility of a race here where we check the capacity and then the channel is
// closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -172,7 +170,7 @@ impl MirroringManager {
});
}
pub fn disconnect(&mut self) {
pub fn disconnect(self: &mut Self) {
self.disconnect_senders
.iter_mut()
.for_each(|sender| match sender.try_send(()) {

View File

@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
.map(|s| {
let s = s.as_str().to_string();
if s.is_empty() {
if s == "" {
None
} else {
Some(s)

View File

@@ -33,7 +33,6 @@ pub enum PluginOutput {
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
#[allow(clippy::ptr_arg)]
async fn run(
&mut self,
query_router: &QueryRouter,

View File

@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
self.server.address(),
query
);
self.server.query(query).await?;
self.server.query(&query).await?;
}
Ok(())

View File

@@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> {
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split('.').collect::<Vec<&str>>();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {

View File

@@ -3,7 +3,6 @@ use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
@@ -11,7 +10,6 @@ use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::num::NonZeroUsize;
use std::sync::atomic::AtomicU64;
use std::sync::{
atomic::{AtomicBool, Ordering},
@@ -26,7 +24,6 @@ use crate::config::{
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::messages::Parse;
use crate::plugins::prewarmer;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
@@ -57,57 +54,6 @@ pub enum BanReason {
AdminBan(i64),
}
pub type PreparedStatementCacheType = Arc<Mutex<PreparedStatementCache>>;
// TODO: Add stats the this cache
// TODO: Add application name to the cache value to help identify which application is using the cache
// TODO: Create admin command to show which statements are in the cache
#[derive(Debug)]
pub struct PreparedStatementCache {
cache: LruCache<u64, Arc<Parse>>,
}
impl PreparedStatementCache {
pub fn new(mut size: usize) -> Self {
// Cannot be zeros
if size == 0 {
size = 1;
}
PreparedStatementCache {
cache: LruCache::new(NonZeroUsize::new(size).unwrap()),
}
}
/// Adds the prepared statement to the cache if it doesn't exist with a new name
/// if it already exists will give you the existing parse
///
/// Pass the hash to this so that we can do the compute before acquiring the lock
pub fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Arc<Parse> {
match self.cache.get(&hash) {
Some(rewritten_parse) => rewritten_parse.clone(),
None => {
let new_parse = Arc::new(parse.clone().rewrite());
let evicted = self.cache.push(hash, new_parse.clone());
if let Some((_, evicted_parse)) = evicted {
debug!(
"Evicted prepared statement {} from cache",
evicted_parse.name
);
}
new_parse
}
}
}
/// Marks the hash as most recently used if it exists
pub fn promote(&mut self, hash: &u64) {
self.cache.promote(hash);
}
}
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
@@ -244,11 +190,11 @@ impl Default for PoolSettings {
#[derive(Clone, Debug, Default)]
pub struct ConnectionPool {
/// The pools handled internally by bb8.
databases: Arc<Vec<Vec<Pool<ServerPool>>>>,
databases: Vec<Vec<Pool<ServerPool>>>,
/// The addresses (host, port, role) to handle
/// failover and load balancing deterministically.
addresses: Arc<Vec<Vec<Address>>>,
addresses: Vec<Vec<Address>>,
/// List of banned addresses (see above)
/// that should not be queried.
@@ -260,7 +206,7 @@ pub struct ConnectionPool {
original_server_parameters: Arc<RwLock<ServerParameters>>,
/// Pool configuration.
pub settings: Arc<PoolSettings>,
pub settings: PoolSettings,
/// If not validated, we need to double check the pool is available before allowing a client
/// to use it.
@@ -277,9 +223,6 @@ pub struct ConnectionPool {
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
/// Cache
pub prepared_statement_cache: Option<PreparedStatementCacheType>,
}
impl ConnectionPool {
@@ -298,17 +241,20 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username);
if let Some(pool) = old_pool_ref {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
info!(
@@ -433,7 +379,6 @@ impl ConnectionPool {
},
pool_config.cleanup_server_connections,
pool_config.log_client_parameter_status_changes,
pool_config.prepared_statements_cache_size,
);
let connect_timeout = match pool_config.connect_timeout {
@@ -454,7 +399,7 @@ impl ConnectionPool {
},
};
let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
@@ -503,13 +448,13 @@ impl ConnectionPool {
}
let pool = ConnectionPool {
databases: Arc::new(shards),
addresses: Arc::new(addresses),
databases: shards,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash,
settings: Arc::new(PoolSettings {
settings: PoolSettings {
pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode,
None => pool_config.pool_mode,
@@ -544,7 +489,7 @@ impl ConnectionPool {
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
default_shard: pool_config.default_shard,
default_shard: pool_config.default_shard.clone(),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
@@ -552,23 +497,17 @@ impl ConnectionPool {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
}),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
prepared_statement_cache: match pool_config.prepared_statements_cache_size {
0 => None,
_ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
pool_config.prepared_statements_cache_size,
)))),
},
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config {
let validate_pool = pool.clone();
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
@@ -589,7 +528,7 @@ impl ConnectionPool {
/// when they connect.
/// This also warms up the pool for clients that connect when
/// the pooler starts up.
pub async fn validate(&self) -> Result<(), Error> {
pub async fn validate(&mut self) -> Result<(), Error> {
let mut futures = Vec::new();
let validated = Arc::clone(&self.validated);
@@ -739,7 +678,7 @@ impl ConnectionPool {
let mut force_healthcheck = false;
if self.is_banned(address) {
if self.try_unban(address).await {
if self.try_unban(&address).await {
force_healthcheck = true;
} else {
debug!("Address {:?} is banned", address);
@@ -867,8 +806,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
false
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
return false;
}
/// Ban an address (i.e. replica). It no longer will serve
@@ -992,10 +931,10 @@ impl ConnectionPool {
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), *timestamp)));
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
}
}
bans
return bans;
}
/// Get the address from the host url
@@ -1053,7 +992,7 @@ impl ConnectionPool {
}
let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy);
busy
return busy;
}
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
@@ -1062,29 +1001,6 @@ impl ConnectionPool {
Some(shard) => shard < self.shards(),
}
}
/// Register a parse statement to the pool's cache and return the rewritten parse
///
/// Do not pass an anonymous parse statement to this function
pub fn register_parse_to_cache(&self, hash: u64, parse: &Parse) -> Option<Arc<Parse>> {
// We should only be calling this function if the cache is enabled
match self.prepared_statement_cache {
Some(ref prepared_statement_cache) => {
let mut cache = prepared_statement_cache.lock();
Some(cache.get_or_insert(parse, hash))
}
None => None,
}
}
/// Promote a prepared statement hash in the LRU
pub fn promote_prepared_statement_hash(&self, hash: &u64) {
// We should only be calling this function if the cache is enabled
if let Some(ref prepared_statement_cache) = self.prepared_statement_cache {
let mut cache = prepared_statement_cache.lock();
cache.promote(hash);
}
}
}
/// Wrapper for the bb8 connection pool.
@@ -1112,13 +1028,9 @@ pub struct ServerPool {
/// Log client parameter status changes
log_client_parameter_status_changes: bool,
/// Prepared statement cache size
prepared_statement_cache_size: usize,
}
impl ServerPool {
#[allow(clippy::too_many_arguments)]
pub fn new(
address: Address,
user: User,
@@ -1128,18 +1040,16 @@ impl ServerPool {
plugins: Option<Plugins>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> ServerPool {
ServerPool {
address,
user,
user: user.clone(),
database: database.to_string(),
client_server_map,
auth_hash,
plugins,
cleanup_connections,
log_client_parameter_status_changes,
prepared_statement_cache_size,
}
}
}
@@ -1170,7 +1080,6 @@ impl ManageConnection for ServerPool {
self.auth_hash.clone(),
self.cleanup_connections,
self.log_client_parameter_status_changes,
self.prepared_statement_cache_size,
)
.await
{

View File

@@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::{
Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
TableFactor, TableWithJoins, Value,
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
@@ -91,7 +91,7 @@ impl QueryRouter {
/// One-time initialization of regexes
/// that parse our custom SQL protocol.
pub fn setup() -> bool {
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -128,11 +128,11 @@ impl QueryRouter {
}
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: &PoolSettings) {
self.pool_settings = pool_settings.clone();
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
}
pub fn pool_settings(&self) -> &PoolSettings {
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings
}
@@ -148,7 +148,7 @@ impl QueryRouter {
// Check for any sharding regex matches in any queries
if comment_shard_routing_enabled {
match code {
match code as char {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings
@@ -344,13 +344,16 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
if let Some(max_length) = self.pool_settings.query_parser_max_length {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
match self.pool_settings.query_parser_max_length {
Some(max_length) => {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
}
None => (),
};
let query = match code {
@@ -400,9 +403,6 @@ impl QueryRouter {
return Err(Error::QueryRouterParserError("empty query".into()));
}
let mut visited_write_statement = false;
let mut prev_inferred_shard = None;
for q in ast {
match q {
// All transactions go to the primary, probably a write.
@@ -420,38 +420,29 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
// This is basically building a database now :)
let inferred_shard = self.infer_shard(query);
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
match self.infer_shard(query) {
Some(shard) => {
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
}
None => (),
};
}
None => (),
};
// If we already visited a write statement, we should be going to the primary.
if !visited_write_statement {
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
}
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
}
}
// Likely a write
_ => {
match &self.pool_settings.automatic_sharding_key {
Some(_) => {
// TODO: similar to the above, if we have multiple queries in the
// same message, we can either split them and execute them individually
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
let inferred_shard = self.infer_shard_on_write(q)?;
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
visited_write_statement = true;
self.active_role = Some(Role::Primary);
break;
}
};
}
@@ -459,188 +450,6 @@ impl QueryRouter {
Ok(())
}
fn handle_inferred_shard(
&mut self,
inferred_shard: Option<usize>,
prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> {
if let Some(shard) = inferred_shard {
if let Some(prev_shard) = *prev_inferred_shard {
if prev_shard != shard {
debug!("Found more than one shard in the query, not supported yet");
return Err(Error::QueryRouterParserError(
"multiple shards in query".into(),
));
}
}
*prev_inferred_shard = Some(shard);
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
};
Ok(())
}
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
match q {
Insert {
or,
into: _,
table_name,
columns,
overwrite: _,
source,
partitioned,
after_columns,
table: _,
on: _,
returning: _,
} => {
// Not supported in postgres.
assert!(or.is_none());
assert!(partitioned.is_none());
assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names);
Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
}
Delete {
tables,
from,
using,
selection,
returning: _,
} => {
if let Some(expr) = selection {
exprs.push(expr.clone());
}
// Multi tables delete are not supported in postgres.
assert!(tables.is_empty());
Self::process_tables_with_join(from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join(
using_tbl_with_join,
&mut exprs,
&mut table_names,
);
}
Self::process_selection(selection, &mut exprs);
}
Update {
table,
assignments,
from,
selection,
returning: _,
} => {
Self::process_table_with_join(table, &mut exprs, &mut table_names);
if let Some(from_tbl) = from {
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
}
Self::process_selection(selection, &mut exprs);
self.assignment_parser(assignments)?;
}
_ => {
return Ok(None);
}
};
Ok(self.infer_shard_from_exprs(exprs, table_names))
}
fn process_query(
query: &sqlparser::ast::Query,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
columns: &Option<&Vec<Ident>>,
) {
match &*query.body {
SetExpr::Query(query) => {
Self::process_query(query, exprs, table_names, columns);
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
Self::process_tables_with_join(&select.from, exprs, table_names);
// Parse the actual "FROM ..."
Self::process_selection(&select.selection, exprs);
}
SetExpr::Values(values) => {
if let Some(cols) = columns {
for row in values.rows.iter() {
for (i, expr) in row.iter().enumerate() {
if cols.len() > i {
exprs.push(Expr::BinaryOp {
left: Box::new(Expr::Identifier(cols[i].clone())),
op: BinaryOperator::Eq,
right: Box::new(expr.clone()),
});
}
}
}
}
}
_ => (),
};
}
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
match selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
}
fn process_tables_with_join(
tables: &[TableWithJoins],
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
for table in tables.iter() {
Self::process_table_with_join(table, exprs, table_names);
}
}
fn process_table_with_join(
table: &TableWithJoins,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
if let TableFactor::Table { name, .. } = &table.relation {
Self::process_table(name, table_names);
};
// Get table names from all the joins.
for join in table.joins.iter() {
if let TableFactor::Table { name, .. } = &join.relation {
Self::process_table(name, table_names);
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
// Parse the selection criteria later.
exprs.push(expr.clone());
};
}
}
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
table_names.push(name.0.clone())
}
/// Parse the shard number from the Bind message
/// which contains the arguments for a prepared statement.
///
@@ -783,33 +592,6 @@ impl QueryRouter {
}
}
/// An `assignments` exists in the `UPDATE` statements. This parses the assignments and makes
/// sure that we are not updating the sharding key. It's not supported yet.
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
let sharding_key = self
.pool_settings
.automatic_sharding_key
.as_ref()
.unwrap()
.split('.')
.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
assert_eq!(sharding_key.len(), 2);
for a in assignments {
if sharding_key[0].value == "*"
&& sharding_key[1].value == a.id.last().unwrap().value.to_lowercase()
{
return Err(Error::QueryRouterParserError(
"Sharding key cannot be updated.".into(),
));
}
}
Ok(())
}
/// A `selection` is the `WHERE` clause. This parses
/// the clause and extracts the sharding key, if present.
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
@@ -821,8 +603,8 @@ impl QueryRouter {
.automatic_sharding_key
.as_ref()
.unwrap()
.split('.')
.map(|ident| Ident::new(ident.to_lowercase()))
.split(".")
.map(|ident| Ident::new(ident))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
@@ -838,7 +620,7 @@ impl QueryRouter {
Expr::Identifier(ident) => {
// Only if we're dealing with only one table
// and there is no ambiguity
if ident.value.to_lowercase() == sharding_key[1].value {
if &ident.value == &sharding_key[1].value {
// Sharding key is unique enough, don't worry about
// table names.
if &sharding_key[0].value == "*" {
@@ -851,13 +633,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches
// the table name from the query.
found = sharding_key[0].value == table[0].value.to_lowercase();
found = &sharding_key[0].value == &table[0].value;
} else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only.
found = sharding_key[0].value == table[1].value.to_lowercase();
found = &sharding_key[0].value == &table[1].value;
} else {
debug!("Got table name with more than two idents, which is not possible");
}
@@ -869,9 +651,8 @@ impl QueryRouter {
// The key is fully qualified in the query,
// it will exist or Postgres will throw an error.
if idents.len() == 2 {
found = (&sharding_key[0].value == "*"
|| sharding_key[0].value == idents[0].value.to_lowercase())
&& sharding_key[1].value == idents[1].value.to_lowercase();
found = &sharding_key[0].value == &idents[0].value
&& &sharding_key[1].value == &idents[1].value;
}
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
}
@@ -903,7 +684,7 @@ impl QueryRouter {
}
Expr::Value(Value::Placeholder(placeholder)) => {
match placeholder.replace('$', "").parse::<i16>() {
match placeholder.replace("$", "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => {
debug!(
@@ -924,48 +705,100 @@ impl QueryRouter {
/// Try to figure out which shard the query should go to.
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
let mut shards = BTreeSet::new();
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
Self::process_query(query, &mut exprs, &mut table_names, &None);
self.infer_shard_from_exprs(exprs, table_names)
}
fn infer_shard_from_exprs(
&mut self,
exprs: Vec<Expr>,
table_names: Vec<Vec<Ident>>,
) -> Option<usize> {
let mut shards = BTreeSet::new();
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
match &*query.body {
SetExpr::Query(query) => {
match self.infer_shard(&*query) {
Some(shard) => {
shards.insert(shard);
}
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
None => (),
};
}
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
// Collect all table names from the query.
let mut table_names = Vec::new();
for table in select.from.iter() {
match &table.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// Get table names from all the joins.
for join in table.joins.iter() {
match &join.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
match &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join {
JoinConstraint::On(expr) => {
// Parse the selection criteria later.
exprs.push(expr.clone());
}
_ => (),
},
_ => (),
};
}
}
// Parse the actual "FROM ..."
match &select.selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
shards.insert(shard);
}
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
};
}
}
}
_ => (),
};
match shards.len() {
// Didn't find a sharding key, you're on your own.
0 => {
@@ -997,16 +830,16 @@ impl QueryRouter {
db: &self.pool_settings.db,
};
let _ = query_logger.run(self, ast).await;
let _ = query_logger.run(&self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: intercept,
config: &intercept,
};
let result = intercept.run(self, ast).await;
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
@@ -1019,7 +852,7 @@ impl QueryRouter {
tables: &table_access.tables,
};
let result = table_access.run(self, ast).await;
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
@@ -1055,7 +888,7 @@ impl QueryRouter {
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
match self.query_parser_enabled {
let enabled = match self.query_parser_enabled {
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
@@ -1071,7 +904,9 @@ impl QueryRouter {
);
value
}
}
};
enabled
}
pub fn primary_reads_enabled(&self) -> bool {
@@ -1082,12 +917,6 @@ impl QueryRouter {
}
}
impl Default for QueryRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
use super::*;
@@ -1109,14 +938,10 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
.is_some());
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr.query_parser_enabled());
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
@@ -1158,9 +983,7 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
.is_some());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
@@ -1173,9 +996,7 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1345,11 +1166,9 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr.try_execute_command(&query).is_some());
assert!(qr.try_execute_command(&query) != None);
assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None);
@@ -1363,7 +1182,7 @@ mod test {
assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&query).is_some());
assert!(qr.try_execute_command(&query) != None);
assert!(!qr.query_parser_enabled());
}
@@ -1403,7 +1222,7 @@ mod test {
assert_eq!(qr.primary_reads_enabled, None);
// Internal state must not be changed due to this, only defaults
qr.update_pool_settings(&pool_settings);
qr.update_pool_settings(pool_settings.clone());
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
@@ -1411,11 +1230,11 @@ mod test {
assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(&q1).is_some());
assert!(qr.try_execute_command(&q1) != None);
assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&q2).is_some());
assert!(qr.try_execute_command(&q2) != None);
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
}
@@ -1476,29 +1295,29 @@ mod test {
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(&pool_settings);
qr.update_pool_settings(pool_settings.clone());
// Shard should start out unset
assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0).is_none());
assert!(qr.try_execute_command(&q0) == None);
assert_eq!(qr.active_shard, None);
// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1).is_none());
assert!(qr.try_execute_command(&q1) == None);
assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2).is_none());
assert!(qr.try_execute_command(&q2) == None);
assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2).is_none());
assert!(qr.try_execute_command(&q2) == None);
assert_eq!(qr.active_shard, Some(2));
}
@@ -1595,221 +1414,6 @@ mod test {
assert_eq!(qr.shard().unwrap(), 0);
}
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("*.w_id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert_eq!(qr.shard(), None);
let infer_res = qr.infer(&qr.parse(&simple_query(qry)).unwrap());
assert_eq!(infer_res.is_ok(), should_succeed);
qr.shard()
}
fn auto_shard(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, true)
}
fn auto_shard_fails(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, false)
}
#[test]
fn test_automatic_sharding_insert_update_delete() {
QueryRouter::setup();
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
None
);
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
None
);
assert_eq!(
auto_shard(
"UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
Some(2)
);
}
#[test]
fn test_automatic_sharding_key_tpcc() {
QueryRouter::setup();
assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2));
assert_eq!(
auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"),
None
);
assert_eq!(auto_shard("COMMIT"), None);
assert_eq!(auto_shard("ROLLBACK"), None);
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(
auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard(
"UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
Some(2)
);
assert_eq!(
auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"),
None
);
assert_eq!(
auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2));
assert_eq!(
auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 5
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
Some(2)
);
// This is a distributed query and contains two shards
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 7
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
None
);
}
#[test]
fn test_prepared_statements() {
let stmt = "SELECT * FROM data WHERE id = $1";
@@ -1854,13 +1458,12 @@ mod test {
};
QueryRouter::setup();
let pool_settings = PoolSettings {
query_parser_enabled: true,
plugins: Some(plugins),
..Default::default()
};
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let mut qr = QueryRouter::new();
qr.update_pool_settings(&pool_settings);
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database");
let ast = qr.parse(&query).unwrap();

View File

@@ -79,12 +79,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?;
let verifier = match general_purpose::STANDARD.decode(final_message.value) {
let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
.collect::<Vec<String>>();
if parts.len() != 3 {
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
/// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
Ok(FinalMessage {

View File

@@ -3,14 +3,12 @@
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::mem;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
@@ -18,7 +16,7 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{get_config, Address, User};
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
@@ -199,8 +197,12 @@ impl ServerParameters {
key = "DateStyle".to_string();
};
if TRACKED_PARAMETERS.contains(&key) || startup {
if TRACKED_PARAMETERS.contains(&key) {
self.parameters.insert(key, value);
} else {
if startup {
self.parameters.insert(key, value);
}
}
}
@@ -324,13 +326,12 @@ pub struct Server {
log_client_parameter_status_changes: bool,
/// Prepared statements
prepared_statement_cache: Option<LruCache<String, ()>>,
prepared_statements: BTreeSet<String>,
}
impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready for query state.
#[allow(clippy::too_many_arguments)]
pub async fn startup(
address: &Address,
user: &User,
@@ -340,7 +341,6 @@ impl Server {
auth_hash: Arc<RwLock<Option<String>>>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
@@ -440,7 +440,10 @@ impl Server {
// Something else?
m => {
return Err(Error::SocketError(format!("Unknown message: {}", { m })));
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
@@ -458,20 +461,26 @@ impl Server {
None => &user.username,
};
let password = match user.server_password.as_ref() {
Some(server_password) => Some(server_password),
None => user.password.as_ref(),
let password = match user.server_password {
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
};
startup(&mut stream, username, database).await?;
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, database);
let server_identifier = ServerIdentifier::new(username, &database);
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
let mut scram: Option<ScramSha256> = match password {
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
let mut server_parameters = ServerParameters::new();
@@ -716,7 +725,7 @@ impl Server {
}
};
let fields = match PgErrorMsg::parse(&error) {
let fields = match PgErrorMsg::parse(error) {
Ok(f) => f,
Err(err) => {
return Err(err);
@@ -821,12 +830,7 @@ impl Server {
},
cleanup_connections,
log_client_parameter_status_changes,
prepared_statement_cache: match prepared_statement_cache_size {
0 => None,
_ => Some(LruCache::new(
NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
)),
},
prepared_statements: BTreeSet::new(),
};
return Ok(server);
@@ -878,7 +882,7 @@ impl Server {
self.mirror_send(messages);
self.stats().data_sent(messages.len());
match write_all_flush(&mut self.stream, messages).await {
match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => {
// Successfully sent to server
self.last_activity = SystemTime::now();
@@ -965,20 +969,6 @@ impl Server {
if self.in_copy_mode {
self.in_copy_mode = false;
}
if self.prepared_statement_cache.is_some() {
let error_message = PgErrorMsg::parse(&message)?;
if error_message.message == "cached plan must not change result type" {
warn!("Server {:?} changed schema, dropping connection to clean up prepared statements", self.address);
// This will still result in an error to the client, but this server connection will drop all cached prepared statements
// so that any new queries will be re-prepared
// TODO: Other ideas to solve errors when there are DDL changes after a statement has been prepared
// - Recreate entire connection pool to force recreation of all server connections
// - Clear the ConnectionPool's statement cache so that new statement names are generated
// - Implement a retry (re-prepare) so the client doesn't see an error
self.cleanup_state.needs_cleanup_prepare = true;
}
}
}
// CommandComplete
@@ -1089,92 +1079,115 @@ impl Server {
Ok(bytes)
}
// Determines if the server already has a prepared statement with the given name
// Increments the prepared statement cache hit counter
pub fn has_prepared_statement(&mut self, name: &str) -> bool {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return false,
};
/// Add the prepared statement to being tracked by this server.
/// The client is processing data that will create a prepared statement on this server.
pub fn will_prepare(&mut self, name: &str) {
debug!("Will prepare `{}`", name);
let has_it = cache.get(name).is_some();
if has_it {
self.stats.prepared_cache_hit();
} else {
self.prepared_statements.insert(name.to_string());
self.stats.prepared_cache_add();
}
/// Check if we should prepare a statement on the server.
pub fn should_prepare(&self, name: &str) -> bool {
let should_prepare = !self.prepared_statements.contains(name);
debug!("Should prepare `{}`: {}", name, should_prepare);
if should_prepare {
self.stats.prepared_cache_miss();
} else {
self.stats.prepared_cache_hit();
}
has_it
should_prepare
}
pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return None,
};
/// Create a prepared statement on the server.
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
debug!("Preparing `{}`", parse.name);
let bytes: BytesMut = parse.try_into()?;
self.send(&bytes).await?;
self.send(&flush()).await?;
// Read and discard ParseComplete (B)
match read_message(&mut self.stream).await {
Ok(_) => (),
Err(err) => {
self.bad = true;
return Err(err);
}
}
self.prepared_statements.insert(parse.name.to_string());
self.stats.prepared_cache_add();
// If we evict something, we need to close it on the server
if let Some((evicted_name, _)) = cache.push(name.to_string(), ()) {
if evicted_name != name {
debug!(
"Evicted prepared statement {} from cache, replaced with {}",
evicted_name, name
);
return Some(evicted_name);
}
};
debug!("Prepared `{}`", parse.name);
None
Ok(())
}
pub fn remove_prepared_statement_from_cache(&mut self, name: &str) {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return,
};
/// Maintain adequate cache size on the server.
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
debug!("Cache maintenance run");
self.stats.prepared_cache_remove();
cache.pop(name);
let max_cache_size = get_prepared_statements_cache_size();
let mut names = Vec::new();
while self.prepared_statements.len() >= max_cache_size {
// The prepared statmeents are alphanumerically sorted by the BTree.
// FIFO.
if let Some(name) = self.prepared_statements.pop_last() {
names.push(name);
}
}
if !names.is_empty() {
self.deallocate(names).await?;
}
Ok(())
}
pub async fn register_prepared_statement(
&mut self,
parse: &Parse,
should_send_parse_to_server: bool,
) -> Result<(), Error> {
if !self.has_prepared_statement(&parse.name) {
let mut bytes = BytesMut::new();
/// Remove the prepared statement from being tracked by this server.
/// The client is processing data that will cause the server to close the prepared statement.
pub fn will_close(&mut self, name: &str) {
debug!("Will close `{}`", name);
if should_send_parse_to_server {
let parse_bytes: BytesMut = parse.try_into()?;
bytes.extend_from_slice(&parse_bytes);
}
self.prepared_statements.remove(name);
}
// If we evict something, we need to close it on the server
// We do this by adding it to the messages we're sending to the server before the sync
if let Some(evicted_name) = self.add_prepared_statement_to_cache(&parse.name) {
self.remove_prepared_statement_from_cache(&evicted_name);
let close_bytes: BytesMut = Close::new(&evicted_name).try_into()?;
bytes.extend_from_slice(&close_bytes);
};
/// Close a prepared statement on the server.
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
for name in &names {
debug!("Deallocating prepared statement `{}`", name);
// If we have a parse or close we need to send to the server, send them and sync
if !bytes.is_empty() {
bytes.extend_from_slice(&sync());
let close = Close::new(name);
let bytes: BytesMut = close.try_into()?;
self.send(&bytes).await?;
self.send(&bytes).await?;
}
loop {
self.recv(None).await?;
if !names.is_empty() {
self.send(&flush()).await?;
}
if !self.is_data_available() {
break;
}
// Read and discard CloseComplete (3)
for name in &names {
match read_message(&mut self.stream).await {
Ok(_) => {
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
debug!("Closed `{}`", name);
}
}
};
Err(err) => {
self.bad = true;
return Err(err);
}
};
}
Ok(())
}
@@ -1311,10 +1324,6 @@ impl Server {
if self.cleanup_state.needs_cleanup_prepare {
reset_string.push_str("DEALLOCATE ALL;");
// Since we deallocated all prepared statements, we need to clear the cache
if let Some(cache) = &mut self.prepared_statement_cache {
cache.clear();
}
};
self.query(&reset_string).await?;
@@ -1350,14 +1359,16 @@ impl Server {
}
pub fn mirror_send(&mut self, bytes: &BytesMut) {
if let Some(manager) = self.mirror_manager.as_mut() {
manager.send(bytes)
match self.mirror_manager.as_mut() {
Some(manager) => manager.send(bytes),
None => (),
}
}
pub fn mirror_disconnect(&mut self) {
if let Some(manager) = self.mirror_manager.as_mut() {
manager.disconnect()
match self.mirror_manager.as_mut() {
Some(manager) => manager.disconnect(),
None => (),
}
}
@@ -1380,14 +1391,13 @@ impl Server {
Arc::new(RwLock::new(None)),
true,
false,
0,
)
.await?;
debug!("Connected!, sending query.");
server.send(&simple_query(query)).await?;
let mut message = server.recv(None).await?;
parse_query_message(&mut message).await
Ok(parse_query_message(&mut message).await?)
}
}

View File

@@ -64,7 +64,7 @@ impl Sharder {
fn sha1(&self, key: i64) -> usize {
let mut hasher = Sha1::new();
hasher.update(key.to_string().as_bytes());
hasher.update(&key.to_string().as_bytes());
let result = hasher.finalize();
@@ -202,10 +202,10 @@ mod test {
#[test]
fn test_sha1_hash() {
let sharder = Sharder::new(12, ShardingFunction::Sha1);
let ids = [
let ids = vec![
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
];
let shards = [
let shards = vec![
4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3,
];

View File

@@ -86,11 +86,11 @@ impl PoolStats {
}
}
map
return map;
}
pub fn generate_header() -> Vec<(&'static str, DataType)> {
vec![
return vec![
("database", DataType::Text),
("user", DataType::Text),
("pool_mode", DataType::Text),
@@ -105,11 +105,11 @@ impl PoolStats {
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
]
];
}
pub fn generate_row(&self) -> Vec<String> {
vec![
return vec![
self.identifier.db.clone(),
self.identifier.user.clone(),
self.mode.to_string(),
@@ -124,7 +124,7 @@ impl PoolStats {
self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(),
]
];
}
}

View File

@@ -49,7 +49,6 @@ pub struct ServerStats {
pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_eviction_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
}
@@ -69,7 +68,6 @@ impl Default for ServerStats {
reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_eviction_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
}
}
@@ -223,7 +221,6 @@ impl ServerStats {
}
pub fn prepared_cache_remove(&self) {
self.prepared_eviction_count.fetch_add(1, Ordering::Relaxed);
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
}

View File

@@ -36,4 +36,4 @@ SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;

View File

@@ -1,214 +1,29 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
let(:pool_size) { 5 }
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", pool_size) }
let(:prepared_statements_cache_size) { 100 }
let(:server_round_robin) { false }
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
before do
new_configs = processes.pgcat.current_config
new_configs["general"]["server_round_robin"] = server_round_robin
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = pool_size
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
end
context 'when trying prepared statements' do
it 'it allows unparameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT 1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1')
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2')
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3')
end
ensure
conn1.close if conn1
conn2.close if conn2
end
it 'it allows parameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT $1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1', [1])
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2', [1])
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3', [1])
end
ensure
conn1.close if conn1
conn2.close if conn2
end
end
context 'when trying large packets' do
it "works with large parse" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT '#{long_string}'"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1')
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
it "works with large bind" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT $1::text"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1', [long_string])
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
end
context 'when statement cache is smaller than set of unqiue statements' do
let(:prepared_statements_cache_size) { 1 }
let(:pool_size) { 1 }
it "evicts all but 1 statement from the server cache" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when statement cache is larger than set of unqiue statements' do
let(:pool_size) { 1 }
it "does not evict any of the statements from the cache" do
# cache size 5
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(5)
end
end
context 'when preparing the same query' do
let(:prepared_statements_cache_size) { 5 }
let(:pool_size) { 5 }
it "reuses statement cache when there are different statement names on the same connection" do
context 'enabled' do
it 'will work over the same connection' do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
conn.describe_prepared(statement_name)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
it "reuses statement cache when there are different statement names on different connections" do
10.times do |i|
it 'will work with new connections' do
10.times do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
statement_name = 'statement1'
conn.prepare('statement1', 'SELECT $1::int')
conn.exec_prepared('statement1', [1])
conn.describe_prepared('statement1')
end
# Check number of prepared statements (expected: 1)
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when reloading config' do
let(:pool_size) { 1 }
it "test_reload_config" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
# prepare query
conn.prepare('statement1', 'SELECT 1')
conn.exec_prepared('statement1')
# Reload config which triggers pool recreation
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size + 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
# check that we're starting with no prepared statements on the server
conn_check = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn_check.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(0)
# still able to run prepared query
conn.exec_prepared('statement1')
end
end
end

4
tests/rust/Cargo.lock generated
View File

@@ -1206,9 +1206,9 @@ dependencies = [
[[package]]
name = "webpki"
version = "0.22.2"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07ecc0cd7cac091bf682ec5efa18b1cff79d617b84181f38b3951dbe135f607f"
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
dependencies = [
"ring",
"untrusted",