mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
10 Commits
levkk-tls-
...
levkk-bump
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b41cc2639 | ||
|
|
7d3003a16a | ||
|
|
d37df43a90 | ||
|
|
2c7bf52c17 | ||
|
|
de8df29ca4 | ||
|
|
c4fb72b9fc | ||
|
|
3371c01e0e | ||
|
|
c2a483f36a | ||
|
|
51cd13b8b5 | ||
|
|
a054b454d2 |
@@ -63,6 +63,9 @@ jobs:
|
||||
- run:
|
||||
name: "Lint"
|
||||
command: "cargo fmt --check"
|
||||
- run:
|
||||
name: "Clippy"
|
||||
command: "cargo clippy --all --all-targets -- -Dwarnings"
|
||||
- run:
|
||||
name: "Tests"
|
||||
command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh"
|
||||
|
||||
2
.github/workflows/publish-deb-package.yml
vendored
2
.github/workflows/publish-deb-package.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
packageVersion:
|
||||
default: "1.1.2-dev"
|
||||
default: "1.1.2-dev1"
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,3 +10,4 @@ lcov.info
|
||||
dev/.bash_history
|
||||
dev/cache
|
||||
!dev/cache/.keepme
|
||||
.venv
|
||||
25
CONFIG.md
25
CONFIG.md
@@ -259,22 +259,6 @@ Password to be used for connecting to servers to obtain the hash used for md5 au
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### prepared_statements
|
||||
```
|
||||
path: general.prepared_statements
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use prepared statements or not.
|
||||
|
||||
### prepared_statements_cache_size
|
||||
```
|
||||
path: general.prepared_statements_cache_size
|
||||
default: 500
|
||||
```
|
||||
|
||||
Size of the prepared statements cache.
|
||||
|
||||
### dns_cache_enabled
|
||||
```
|
||||
path: general.dns_cache_enabled
|
||||
@@ -324,6 +308,15 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
|
||||
`replica` round-robin between replicas only without touching the primary,
|
||||
`primary` all queries go to the primary unless otherwise specified.
|
||||
|
||||
### prepared_statements_cache_size
|
||||
```
|
||||
path: general.prepared_statements_cache_size
|
||||
default: 0
|
||||
```
|
||||
|
||||
Size of the prepared statements cache. 0 means disabled.
|
||||
TODO: update documentation
|
||||
|
||||
### query_parser_enabled
|
||||
```
|
||||
path: pools.<pool_name>.query_parser_enabled
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Thank you for contributing! Just a few tips here:
|
||||
|
||||
1. `cargo fmt` your code before opening up a PR
|
||||
1. `cargo fmt` and `cargo clippy` your code before opening up a PR
|
||||
2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`.
|
||||
3. Performance is important, make sure there are no regressions in your branch vs. `main`.
|
||||
|
||||
|
||||
33
Cargo.lock
generated
33
Cargo.lock
generated
@@ -17,6 +17,17 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.0.2"
|
||||
@@ -26,6 +37,12 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
@@ -553,6 +570,10 @@ name = "hashbrown"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"allocator-api2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
@@ -821,6 +842,15 @@ version = "0.4.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
|
||||
|
||||
[[package]]
|
||||
name = "lru"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
|
||||
dependencies = [
|
||||
"hashbrown 0.14.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-cache"
|
||||
version = "0.1.2"
|
||||
@@ -990,7 +1020,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "1.1.2-dev"
|
||||
version = "1.1.2-dev1"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
@@ -1008,6 +1038,7 @@ dependencies = [
|
||||
"itertools",
|
||||
"jemallocator",
|
||||
"log",
|
||||
"lru",
|
||||
"md-5",
|
||||
"nix",
|
||||
"num_cpus",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.1.2-dev"
|
||||
version = "1.1.2-dev1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -48,6 +48,7 @@ itertools = "0.10"
|
||||
clap = { version = "4.3.1", features = ["derive", "env"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
|
||||
lru = "0.12.0"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
@@ -8,6 +8,12 @@ WORKDIR /app
|
||||
RUN cargo build --release
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
RUN apt-get update && apt-get install -o Dpkg::Options::=--force-confdef -yq --no-install-recommends \
|
||||
postgresql-client \
|
||||
# Clean up layer
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
|
||||
&& truncate -s 0 /var/log/*log
|
||||
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
|
||||
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
|
||||
WORKDIR /etc/pgcat
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
FROM rust:1.70-bullseye
|
||||
|
||||
# Dependencies
|
||||
COPY --from=sclevine/yj /bin/yj /bin/yj
|
||||
RUN /bin/yj -h
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y \
|
||||
llvm-11 psmisc postgresql-contrib postgresql-client \
|
||||
|
||||
10
pgcat.toml
10
pgcat.toml
@@ -60,12 +60,6 @@ tcp_keepalives_count = 5
|
||||
# Number of seconds between keepalive packets.
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Handle prepared statements.
|
||||
prepared_statements = true
|
||||
|
||||
# Prepared statements server cache size.
|
||||
prepared_statements_cache_size = 500
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
@@ -156,6 +150,10 @@ load_balancing_mode = "random"
|
||||
# `primary` all queries go to the primary unless otherwise specified.
|
||||
default_role = "any"
|
||||
|
||||
# Prepared statements cache size.
|
||||
# TODO: update documentation
|
||||
prepared_statements_cache_size = 500
|
||||
|
||||
# If Query Parser is enabled, we'll attempt to parse
|
||||
# every incoming query to determine if it's a read or a write.
|
||||
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
|
||||
|
||||
12
src/admin.rs
12
src/admin.rs
@@ -283,7 +283,7 @@ where
|
||||
{
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
let detail_msg = vec![
|
||||
let detail_msg = [
|
||||
"",
|
||||
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
|
||||
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
|
||||
@@ -301,7 +301,6 @@ where
|
||||
// "KILL <db>",
|
||||
// "SUSPEND",
|
||||
"SHUTDOWN",
|
||||
// "WAIT_CLOSE [<db>]", // missing
|
||||
];
|
||||
|
||||
res.put(notify("Console usage", detail_msg.join("\n\t")));
|
||||
@@ -745,6 +744,7 @@ where
|
||||
("age_seconds", DataType::Numeric),
|
||||
("prepare_cache_hit", DataType::Numeric),
|
||||
("prepare_cache_miss", DataType::Numeric),
|
||||
("prepare_cache_eviction", DataType::Numeric),
|
||||
("prepare_cache_size", DataType::Numeric),
|
||||
];
|
||||
|
||||
@@ -777,6 +777,10 @@ where
|
||||
.prepared_miss_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_eviction_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_cache_size
|
||||
.load(Ordering::Relaxed)
|
||||
@@ -802,7 +806,7 @@ where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
true => tokens[1].split(',').map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
|
||||
@@ -865,7 +869,7 @@ where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
true => tokens[1].split(',').map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
|
||||
|
||||
847
src/client.rs
847
src/client.rs
File diff suppressed because it is too large
Load Diff
@@ -25,7 +25,7 @@ pub struct Args {
|
||||
}
|
||||
|
||||
pub fn parse() -> Args {
|
||||
return Args::parse();
|
||||
Args::parse()
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Clone, Debug)]
|
||||
|
||||
236
src/config.rs
236
src/config.rs
@@ -1,6 +1,6 @@
|
||||
/// Parse the configuration file.
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{error, info, warn};
|
||||
use log::{error, info};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use serde::{Deserializer, Serializer};
|
||||
@@ -116,10 +116,10 @@ impl Default for Address {
|
||||
host: String::from("127.0.0.1"),
|
||||
port: 5432,
|
||||
shard: 0,
|
||||
address_index: 0,
|
||||
replica_number: 0,
|
||||
database: String::from("database"),
|
||||
role: Role::Replica,
|
||||
replica_number: 0,
|
||||
address_index: 0,
|
||||
username: String::from("username"),
|
||||
pool_name: String::from("pool_name"),
|
||||
mirrors: Vec::new(),
|
||||
@@ -236,18 +236,14 @@ impl Default for User {
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
if let Some(min_pool_size) = self.min_pool_size {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
@@ -341,12 +337,6 @@ pub struct General {
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub prepared_statements: bool,
|
||||
|
||||
#[serde(default = "General::default_prepared_statements_cache_size")]
|
||||
pub prepared_statements_cache_size: usize,
|
||||
}
|
||||
|
||||
impl General {
|
||||
@@ -428,10 +418,6 @@ impl General {
|
||||
pub fn default_server_round_robin() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn default_prepared_statements_cache_size() -> usize {
|
||||
500
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for General {
|
||||
@@ -443,35 +429,33 @@ impl Default for General {
|
||||
prometheus_exporter_port: 9930,
|
||||
connect_timeout: General::default_connect_timeout(),
|
||||
idle_timeout: General::default_idle_timeout(),
|
||||
shutdown_timeout: Self::default_shutdown_timeout(),
|
||||
healthcheck_timeout: Self::default_healthcheck_timeout(),
|
||||
healthcheck_delay: Self::default_healthcheck_delay(),
|
||||
ban_time: Self::default_ban_time(),
|
||||
worker_threads: Self::default_worker_threads(),
|
||||
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
|
||||
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
|
||||
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
|
||||
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
|
||||
tcp_user_timeout: Self::default_tcp_user_timeout(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: None,
|
||||
dns_cache_enabled: false,
|
||||
dns_max_ttl: Self::default_dns_max_ttl(),
|
||||
shutdown_timeout: Self::default_shutdown_timeout(),
|
||||
healthcheck_timeout: Self::default_healthcheck_timeout(),
|
||||
healthcheck_delay: Self::default_healthcheck_delay(),
|
||||
ban_time: Self::default_ban_time(),
|
||||
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
|
||||
server_lifetime: Self::default_server_lifetime(),
|
||||
server_round_robin: Self::default_server_round_robin(),
|
||||
worker_threads: Self::default_worker_threads(),
|
||||
autoreload: None,
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
validate_config: true,
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: Self::default_server_lifetime(),
|
||||
server_round_robin: Self::default_server_round_robin(),
|
||||
validate_config: true,
|
||||
prepared_statements: false,
|
||||
prepared_statements_cache_size: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -572,6 +556,9 @@ pub struct Pool {
|
||||
#[serde(default)] // False
|
||||
pub log_client_parameter_status_changes: bool,
|
||||
|
||||
#[serde(default = "Pool::default_prepared_statements_cache_size")]
|
||||
pub prepared_statements_cache_size: usize,
|
||||
|
||||
pub plugins: Option<Plugins>,
|
||||
pub shards: BTreeMap<String, Shard>,
|
||||
pub users: BTreeMap<String, User>,
|
||||
@@ -621,6 +608,10 @@ impl Pool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn default_prepared_statements_cache_size() -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn validate(&mut self) -> Result<(), Error> {
|
||||
match self.default_role.as_ref() {
|
||||
"any" => (),
|
||||
@@ -677,9 +668,9 @@ impl Pool {
|
||||
Some(key) => {
|
||||
// No quotes in the key so we don't have to compare quoted
|
||||
// to unquoted idents.
|
||||
let key = key.replace("\"", "");
|
||||
let key = key.replace('\"', "");
|
||||
|
||||
if key.split(".").count() != 2 {
|
||||
if key.split('.').count() != 2 {
|
||||
error!(
|
||||
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
|
||||
key, key
|
||||
@@ -692,17 +683,14 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.default_shard {
|
||||
DefaultShard::Shard(shard_number) => {
|
||||
if shard_number >= self.shards.len() {
|
||||
error!("Invalid shard {:?}", shard_number);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
if let DefaultShard::Shard(shard_number) = self.default_shard {
|
||||
if shard_number >= self.shards.len() {
|
||||
error!("Invalid shard {:?}", shard_number);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
for (_, user) in &self.users {
|
||||
for user in self.users.values() {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
@@ -715,17 +703,16 @@ impl Default for Pool {
|
||||
Pool {
|
||||
pool_mode: Self::default_pool_mode(),
|
||||
load_balancing_mode: Self::default_load_balancing_mode(),
|
||||
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
|
||||
users: BTreeMap::default(),
|
||||
default_role: String::from("any"),
|
||||
query_parser_enabled: false,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: false,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
connect_timeout: None,
|
||||
idle_timeout: None,
|
||||
server_lifetime: None,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: Some(1000),
|
||||
@@ -733,10 +720,12 @@ impl Default for Pool {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
plugins: None,
|
||||
cleanup_server_connections: true,
|
||||
log_client_parameter_status_changes: false,
|
||||
prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
|
||||
plugins: None,
|
||||
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
|
||||
users: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -777,8 +766,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard {
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
if s.starts_with("shard_") {
|
||||
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
|
||||
if let Some(s) = s.strip_prefix("shard_") {
|
||||
let shard = s.parse::<usize>().map_err(serde::de::Error::custom)?;
|
||||
return Ok(DefaultShard::Shard(shard));
|
||||
}
|
||||
|
||||
@@ -848,13 +837,13 @@ impl Shard {
|
||||
impl Default for Shard {
|
||||
fn default() -> Shard {
|
||||
Shard {
|
||||
database: String::from("postgres"),
|
||||
mirrors: None,
|
||||
servers: vec![ServerConfig {
|
||||
host: String::from("localhost"),
|
||||
port: 5432,
|
||||
role: Role::Primary,
|
||||
}],
|
||||
mirrors: None,
|
||||
database: String::from("postgres"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -867,15 +856,26 @@ pub struct Plugins {
|
||||
pub prewarmer: Option<Prewarmer>,
|
||||
}
|
||||
|
||||
pub trait Plugin {
|
||||
fn is_enabled(&self) -> bool;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Plugins {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool {
|
||||
if let Some(arg) = arg {
|
||||
arg.is_enabled()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
write!(
|
||||
f,
|
||||
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
|
||||
self.intercept.is_some(),
|
||||
self.table_access.is_some(),
|
||||
self.query_logger.is_some(),
|
||||
self.prewarmer.is_some(),
|
||||
is_enabled(self.intercept.as_ref()),
|
||||
is_enabled(self.table_access.as_ref()),
|
||||
is_enabled(self.query_logger.as_ref()),
|
||||
is_enabled(self.prewarmer.as_ref()),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -886,23 +886,47 @@ pub struct Intercept {
|
||||
pub queries: BTreeMap<String, Query>,
|
||||
}
|
||||
|
||||
impl Plugin for Intercept {
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
pub struct TableAccess {
|
||||
pub enabled: bool,
|
||||
pub tables: Vec<String>,
|
||||
}
|
||||
|
||||
impl Plugin for TableAccess {
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
pub struct QueryLogger {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
impl Plugin for QueryLogger {
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
pub struct Prewarmer {
|
||||
pub enabled: bool,
|
||||
pub queries: Vec<String>,
|
||||
}
|
||||
|
||||
impl Plugin for Prewarmer {
|
||||
fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
}
|
||||
|
||||
impl Intercept {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for (_, query) in self.queries.iter_mut() {
|
||||
@@ -920,6 +944,7 @@ pub struct Query {
|
||||
}
|
||||
|
||||
impl Query {
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for col in self.result.iter_mut() {
|
||||
for i in 0..col.len() {
|
||||
@@ -989,8 +1014,8 @@ impl Default for Config {
|
||||
Config {
|
||||
path: Self::default_path(),
|
||||
general: General::default(),
|
||||
pools: HashMap::default(),
|
||||
plugins: None,
|
||||
pools: HashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1044,8 +1069,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
(
|
||||
format!("pools.{:?}.users", pool_name),
|
||||
pool.users
|
||||
.iter()
|
||||
.map(|(_username, user)| &user.username)
|
||||
.values()
|
||||
.map(|user| &user.username)
|
||||
.cloned()
|
||||
.collect::<Vec<String>>()
|
||||
.join(", "),
|
||||
@@ -1099,6 +1124,7 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
impl Config {
|
||||
/// Print current configuration.
|
||||
pub fn show(&self) {
|
||||
info!("Config path: {}", self.path);
|
||||
info!("Ban time: {}s", self.general.ban_time);
|
||||
info!(
|
||||
"Idle client in transaction timeout: {}ms",
|
||||
@@ -1130,13 +1156,9 @@ impl Config {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
}
|
||||
|
||||
None => (),
|
||||
if let Some(tls_private_key) = self.general.tls_private_key.clone() {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1149,13 +1171,6 @@ impl Config {
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
info!("Prepared statements: {}", self.general.prepared_statements);
|
||||
if self.general.prepared_statements {
|
||||
info!(
|
||||
"Prepared statements server cache size: {}",
|
||||
self.general.prepared_statements_cache_size
|
||||
);
|
||||
}
|
||||
info!(
|
||||
"Plugins: {}",
|
||||
match self.plugins {
|
||||
@@ -1171,8 +1186,8 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config
|
||||
.users
|
||||
.iter()
|
||||
.map(|(_, user_cfg)| user_cfg.pool_size)
|
||||
.values()
|
||||
.map(|user_cfg| user_cfg.pool_size)
|
||||
.sum::<u32>()
|
||||
.to_string()
|
||||
);
|
||||
@@ -1246,6 +1261,10 @@ impl Config {
|
||||
"[pool: {}] Log client parameter status changes: {}",
|
||||
pool_name, pool_config.log_client_parameter_status_changes
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Prepared statements server cache size: {}",
|
||||
pool_name, pool_config.prepared_statements_cache_size
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Plugins: {}",
|
||||
pool_name,
|
||||
@@ -1342,42 +1361,31 @@ impl Config {
|
||||
}
|
||||
|
||||
// Validate TLS!
|
||||
match self.general.tls_certificate {
|
||||
Some(ref mut tls_certificate) => {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key {
|
||||
Some(ref tls_private_key) => {
|
||||
match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"tls_private_key is incorrectly configured: {:?}",
|
||||
err
|
||||
);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
}
|
||||
if let Some(tls_certificate) = self.general.tls_certificate.clone() {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
},
|
||||
|
||||
None => {
|
||||
warn!("tls_certificate is set, but the tls_private_key is not");
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
};
|
||||
}
|
||||
None => {
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
warn!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
Err(err) => {
|
||||
error!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
for pool in self.pools.values_mut() {
|
||||
@@ -1399,14 +1407,6 @@ pub fn get_idle_client_in_transaction_timeout() -> u64 {
|
||||
CONFIG.load().general.idle_client_in_transaction_timeout
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements() -> bool {
|
||||
CONFIG.load().general.prepared_statements
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements_cache_size() -> usize {
|
||||
CONFIG.load().general.prepared_statements_cache_size
|
||||
}
|
||||
|
||||
/// Parse the configuration file located at the path.
|
||||
pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
let mut contents = String::new();
|
||||
|
||||
305
src/messages.rs
305
src/messages.rs
@@ -12,13 +12,16 @@ use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::constants::MESSAGE_TERMINATOR;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CString;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::mem;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Postgres data type mappings
|
||||
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
|
||||
}
|
||||
|
||||
/// Tell the client we're ready for another query.
|
||||
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
||||
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(
|
||||
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
|
||||
);
|
||||
|
||||
bytes.put_u8(b'Z');
|
||||
bytes.put_i32(5);
|
||||
bytes.put_u8(b'I'); // Idle
|
||||
|
||||
write_all(stream, bytes).await
|
||||
write_all(stream, ready_for_query(false)).await
|
||||
}
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
@@ -163,12 +158,10 @@ where
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,8 +237,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// First pass
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user.as_bytes());
|
||||
md5.update(password.as_bytes());
|
||||
md5.update(user.as_bytes());
|
||||
|
||||
let output = md5.finalize_reset();
|
||||
|
||||
@@ -281,7 +274,7 @@ where
|
||||
{
|
||||
let password = md5_hash_password(user, password, salt);
|
||||
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -295,7 +288,7 @@ where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let password = md5_hash_second_pass(hash, salt);
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -322,7 +315,7 @@ where
|
||||
res.put_slice(&set_complete[..]);
|
||||
|
||||
write_all_half(stream, &res).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
/// Send a custom error message to the client.
|
||||
@@ -333,7 +326,7 @@ where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
error_response_terminal(stream, message).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
/// Send a custom error message to the client.
|
||||
@@ -434,7 +427,7 @@ where
|
||||
res.put(command_complete("SELECT 1"));
|
||||
|
||||
write_all_half(stream, &res).await?;
|
||||
ready_for_query(stream).await
|
||||
send_ready_for_query(stream).await
|
||||
}
|
||||
|
||||
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
@@ -516,7 +509,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(column);
|
||||
} else {
|
||||
data_row.put_i32(-1 as i32);
|
||||
data_row.put_i32(-1_i32);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,6 +557,37 @@ pub fn flush() -> BytesMut {
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn sync() -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
bytes.put_u8(b'S');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn parse_complete() -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
|
||||
bytes.put_u8(b'1');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
|
||||
let mut bytes = BytesMut::with_capacity(
|
||||
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
|
||||
);
|
||||
|
||||
bytes.put_u8(b'Z');
|
||||
bytes.put_i32(5);
|
||||
if in_transaction {
|
||||
bytes.put_u8(b'T');
|
||||
} else {
|
||||
bytes.put_u8(b'I');
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Write all data in the buffer to the TcpStream.
|
||||
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||
where
|
||||
@@ -571,12 +595,10 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -587,12 +609,10 @@ where
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -603,19 +623,15 @@ where
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
))),
|
||||
},
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -730,7 +746,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
|
||||
let mut buf = vec![];
|
||||
match self.read_until(b'\0', &mut buf) {
|
||||
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
|
||||
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
|
||||
Err(err) => Err(Error::ParseBytesError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -746,10 +762,55 @@ impl BytesMutReader for BytesMut {
|
||||
let string_bytes = self.split_to(index + 1);
|
||||
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
|
||||
}
|
||||
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
|
||||
None => Err(Error::ParseBytesError("Could not read string".to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ExtendedProtocolData {
|
||||
Parse {
|
||||
data: BytesMut,
|
||||
metadata: Option<(Arc<Parse>, u64)>,
|
||||
},
|
||||
Bind {
|
||||
data: BytesMut,
|
||||
metadata: Option<String>,
|
||||
},
|
||||
Describe {
|
||||
data: BytesMut,
|
||||
metadata: Option<String>,
|
||||
},
|
||||
Execute {
|
||||
data: BytesMut,
|
||||
},
|
||||
Close {
|
||||
data: BytesMut,
|
||||
close: Close,
|
||||
},
|
||||
}
|
||||
|
||||
impl ExtendedProtocolData {
|
||||
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
|
||||
Self::Parse { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
|
||||
Self::Bind { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
|
||||
Self::Describe { data, metadata }
|
||||
}
|
||||
|
||||
pub fn create_new_execute(data: BytesMut) -> Self {
|
||||
Self::Execute { data }
|
||||
}
|
||||
|
||||
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
|
||||
Self::Close { data, close }
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -758,7 +819,6 @@ pub struct Parse {
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
pub name: String,
|
||||
pub generated_name: String,
|
||||
query: String,
|
||||
num_params: i16,
|
||||
param_types: Vec<i32>,
|
||||
@@ -784,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
|
||||
code,
|
||||
len,
|
||||
name,
|
||||
generated_name: prepared_statement_name(),
|
||||
query,
|
||||
num_params,
|
||||
param_types,
|
||||
@@ -833,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
|
||||
}
|
||||
|
||||
impl Parse {
|
||||
pub fn rename(mut self) -> Self {
|
||||
self.name = self.generated_name.to_string();
|
||||
/// Renames the prepared statement to a new name based on the global counter
|
||||
pub fn rewrite(mut self) -> Self {
|
||||
self.name = format!(
|
||||
"PGCAT_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the prepared statement from the buffer
|
||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
// Skip the code and length
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
cursor.read_string()
|
||||
}
|
||||
|
||||
/// Hashes the parse statement to be used as a key in the global cache
|
||||
pub fn get_hash(&self) -> u64 {
|
||||
// TODO_ZAIN: Take a look at which hashing function is being used
|
||||
let mut hasher = DefaultHasher::new();
|
||||
|
||||
let concatenated = format!(
|
||||
"{}{}{}",
|
||||
self.query,
|
||||
self.num_params,
|
||||
self.param_types
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
concatenated.hash(&mut hasher);
|
||||
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
@@ -968,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
|
||||
}
|
||||
|
||||
impl Bind {
|
||||
pub fn reassign(mut self, parse: &Parse) -> Self {
|
||||
self.prepared_statement = parse.name.clone();
|
||||
self
|
||||
/// Gets the name of the prepared statement from the buffer
|
||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
// Skip the code and length
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
cursor.read_string()?;
|
||||
cursor.read_string()
|
||||
}
|
||||
|
||||
/// Renames the prepared statement to a new name
|
||||
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
|
||||
let mut cursor = Cursor::new(&buf);
|
||||
// Read basic data from the cursor
|
||||
let code = cursor.get_u8();
|
||||
let current_len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
|
||||
// Calculate new length
|
||||
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
||||
|
||||
// Begin building the response buffer
|
||||
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
||||
response_buf.put_u8(code);
|
||||
response_buf.put_i32(new_len);
|
||||
|
||||
// Put the portal and new name into the buffer
|
||||
// Note: panic if the provided string contains null byte
|
||||
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
|
||||
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
|
||||
|
||||
// Add the remainder of the original buffer into the response
|
||||
response_buf.put_slice(&buf[cursor.position() as usize..]);
|
||||
|
||||
// Return the buffer
|
||||
Ok(response_buf)
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
@@ -1026,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
|
||||
}
|
||||
|
||||
impl Describe {
|
||||
pub fn empty_new() -> Describe {
|
||||
Describe {
|
||||
code: 'D',
|
||||
len: 4 + 1 + 1,
|
||||
target: 'S',
|
||||
statement_name: "".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rename(mut self, name: &str) -> Self {
|
||||
self.statement_name = name.to_string();
|
||||
self
|
||||
@@ -1114,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn prepared_statement_name() -> String {
|
||||
format!(
|
||||
"P_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
pub struct PgErrorMsg {
|
||||
@@ -1203,7 +1330,7 @@ impl Display for PgErrorMsg {
|
||||
}
|
||||
|
||||
impl PgErrorMsg {
|
||||
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
|
||||
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
|
||||
let mut out = PgErrorMsg {
|
||||
severity_localized: "".to_string(),
|
||||
severity: "".to_string(),
|
||||
@@ -1311,38 +1438,38 @@ mod tests {
|
||||
fn parse_fields() {
|
||||
let mut complete_msg = vec![];
|
||||
let severity = "FATAL";
|
||||
complete_msg.extend(field('S', &severity));
|
||||
complete_msg.extend(field('V', &severity));
|
||||
complete_msg.extend(field('S', severity));
|
||||
complete_msg.extend(field('V', severity));
|
||||
|
||||
let error_code = "29P02";
|
||||
complete_msg.extend(field('C', &error_code));
|
||||
complete_msg.extend(field('C', error_code));
|
||||
let message = "password authentication failed for user \"wrong_user\"";
|
||||
complete_msg.extend(field('M', &message));
|
||||
complete_msg.extend(field('M', message));
|
||||
let detail_msg = "super detailed message";
|
||||
complete_msg.extend(field('D', &detail_msg));
|
||||
complete_msg.extend(field('D', detail_msg));
|
||||
let hint_msg = "hint detail here";
|
||||
complete_msg.extend(field('H', &hint_msg));
|
||||
complete_msg.extend(field('H', hint_msg));
|
||||
complete_msg.extend(field('P', "123"));
|
||||
complete_msg.extend(field('p', "234"));
|
||||
let internal_query = "SELECT * from foo;";
|
||||
complete_msg.extend(field('q', &internal_query));
|
||||
complete_msg.extend(field('q', internal_query));
|
||||
let where_msg = "where goes here";
|
||||
complete_msg.extend(field('W', &where_msg));
|
||||
complete_msg.extend(field('W', where_msg));
|
||||
let schema_msg = "schema_name";
|
||||
complete_msg.extend(field('s', &schema_msg));
|
||||
complete_msg.extend(field('s', schema_msg));
|
||||
let table_msg = "table_name";
|
||||
complete_msg.extend(field('t', &table_msg));
|
||||
complete_msg.extend(field('t', table_msg));
|
||||
let column_msg = "column_name";
|
||||
complete_msg.extend(field('c', &column_msg));
|
||||
complete_msg.extend(field('c', column_msg));
|
||||
let data_type_msg = "type_name";
|
||||
complete_msg.extend(field('d', &data_type_msg));
|
||||
complete_msg.extend(field('d', data_type_msg));
|
||||
let constraint_msg = "constraint_name";
|
||||
complete_msg.extend(field('n', &constraint_msg));
|
||||
complete_msg.extend(field('n', constraint_msg));
|
||||
let file_msg = "pgcat.c";
|
||||
complete_msg.extend(field('F', &file_msg));
|
||||
complete_msg.extend(field('F', file_msg));
|
||||
complete_msg.extend(field('L', "335"));
|
||||
let routine_msg = "my_failing_routine";
|
||||
complete_msg.extend(field('R', &routine_msg));
|
||||
complete_msg.extend(field('R', routine_msg));
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
@@ -1351,7 +1478,7 @@ mod tests {
|
||||
|
||||
info!(
|
||||
"full message: {}",
|
||||
PgErrorMsg::parse(complete_msg.clone()).unwrap()
|
||||
PgErrorMsg::parse(&complete_msg).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
@@ -1374,17 +1501,17 @@ mod tests {
|
||||
line: Some(335),
|
||||
routine: Some(routine_msg.to_string()),
|
||||
},
|
||||
PgErrorMsg::parse(complete_msg).unwrap()
|
||||
PgErrorMsg::parse(&complete_msg).unwrap()
|
||||
);
|
||||
|
||||
let mut only_mandatory_msg = vec![];
|
||||
only_mandatory_msg.extend(field('S', &severity));
|
||||
only_mandatory_msg.extend(field('V', &severity));
|
||||
only_mandatory_msg.extend(field('C', &error_code));
|
||||
only_mandatory_msg.extend(field('M', &message));
|
||||
only_mandatory_msg.extend(field('D', &detail_msg));
|
||||
only_mandatory_msg.extend(field('S', severity));
|
||||
only_mandatory_msg.extend(field('V', severity));
|
||||
only_mandatory_msg.extend(field('C', error_code));
|
||||
only_mandatory_msg.extend(field('M', message));
|
||||
only_mandatory_msg.extend(field('D', detail_msg));
|
||||
|
||||
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
|
||||
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
|
||||
info!("only mandatory fields: {}", &err_fields);
|
||||
error!(
|
||||
"server error: {}: {}",
|
||||
@@ -1411,7 +1538,7 @@ mod tests {
|
||||
line: None,
|
||||
routine: None,
|
||||
},
|
||||
PgErrorMsg::parse(only_mandatory_msg).unwrap()
|
||||
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,14 +23,15 @@ impl MirroredClient {
|
||||
async fn create_pool(&self) -> Pool<ServerPool> {
|
||||
let config = get_config();
|
||||
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
|
||||
let (connection_timeout, idle_timeout, _cfg) =
|
||||
let (connection_timeout, idle_timeout, _cfg, prepared_statement_cache_size) =
|
||||
match config.pools.get(&self.address.pool_name) {
|
||||
Some(cfg) => (
|
||||
cfg.connect_timeout.unwrap_or(default),
|
||||
cfg.idle_timeout.unwrap_or(default),
|
||||
cfg.clone(),
|
||||
cfg.prepared_statements_cache_size,
|
||||
),
|
||||
None => (default, default, crate::config::Pool::default()),
|
||||
None => (default, default, crate::config::Pool::default(), 0),
|
||||
};
|
||||
|
||||
let manager = ServerPool::new(
|
||||
@@ -42,6 +43,7 @@ impl MirroredClient {
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
prepared_statement_cache_size,
|
||||
);
|
||||
|
||||
Pool::builder()
|
||||
@@ -137,18 +139,18 @@ impl MirroringManager {
|
||||
bytes_rx,
|
||||
disconnect_rx: exit_rx,
|
||||
};
|
||||
exit_senders.push(exit_tx.clone());
|
||||
byte_senders.push(bytes_tx.clone());
|
||||
exit_senders.push(exit_tx);
|
||||
byte_senders.push(bytes_tx);
|
||||
client.start();
|
||||
});
|
||||
|
||||
Self {
|
||||
byte_senders: byte_senders,
|
||||
byte_senders,
|
||||
disconnect_senders: exit_senders,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(self: &mut Self, bytes: &BytesMut) {
|
||||
pub fn send(&mut self, bytes: &BytesMut) {
|
||||
// We want to avoid performing an allocation if we won't be able to send the message
|
||||
// There is a possibility of a race here where we check the capacity and then the channel is
|
||||
// closed or the capacity is reduced to 0, but mirroring is best effort anyway
|
||||
@@ -170,7 +172,7 @@ impl MirroringManager {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn disconnect(self: &mut Self) {
|
||||
pub fn disconnect(&mut self) {
|
||||
self.disconnect_senders
|
||||
.iter_mut()
|
||||
.for_each(|sender| match sender.try_send(()) {
|
||||
|
||||
@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
|
||||
.map(|s| {
|
||||
let s = s.as_str().to_string();
|
||||
|
||||
if s == "" {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
|
||||
@@ -33,6 +33,7 @@ pub enum PluginOutput {
|
||||
#[async_trait]
|
||||
pub trait Plugin {
|
||||
// Run before the query is sent to the server.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
|
||||
@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
|
||||
self.server.address(),
|
||||
query
|
||||
);
|
||||
self.server.query(&query).await?;
|
||||
self.server.query(query).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> {
|
||||
|
||||
visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let parts = relation.split('.').collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
|
||||
if self.tables.contains(&table_name.to_string()) {
|
||||
|
||||
153
src/pool.rs
153
src/pool.rs
@@ -3,6 +3,7 @@ use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
use lru::LruCache;
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -10,6 +11,7 @@ use rand::thread_rng;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
@@ -24,6 +26,7 @@ use crate::config::{
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::messages::Parse;
|
||||
use crate::plugins::prewarmer;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::sharding::ShardingFunction;
|
||||
@@ -54,6 +57,57 @@ pub enum BanReason {
|
||||
AdminBan(i64),
|
||||
}
|
||||
|
||||
pub type PreparedStatementCacheType = Arc<Mutex<PreparedStatementCache>>;
|
||||
|
||||
// TODO: Add stats the this cache
|
||||
// TODO: Add application name to the cache value to help identify which application is using the cache
|
||||
// TODO: Create admin command to show which statements are in the cache
|
||||
#[derive(Debug)]
|
||||
pub struct PreparedStatementCache {
|
||||
cache: LruCache<u64, Arc<Parse>>,
|
||||
}
|
||||
|
||||
impl PreparedStatementCache {
|
||||
pub fn new(mut size: usize) -> Self {
|
||||
// Cannot be zeros
|
||||
if size == 0 {
|
||||
size = 1;
|
||||
}
|
||||
|
||||
PreparedStatementCache {
|
||||
cache: LruCache::new(NonZeroUsize::new(size).unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds the prepared statement to the cache if it doesn't exist with a new name
|
||||
/// if it already exists will give you the existing parse
|
||||
///
|
||||
/// Pass the hash to this so that we can do the compute before acquiring the lock
|
||||
pub fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Arc<Parse> {
|
||||
match self.cache.get(&hash) {
|
||||
Some(rewritten_parse) => rewritten_parse.clone(),
|
||||
None => {
|
||||
let new_parse = Arc::new(parse.clone().rewrite());
|
||||
let evicted = self.cache.push(hash, new_parse.clone());
|
||||
|
||||
if let Some((_, evicted_parse)) = evicted {
|
||||
debug!(
|
||||
"Evicted prepared statement {} from cache",
|
||||
evicted_parse.name
|
||||
);
|
||||
}
|
||||
|
||||
new_parse
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Marks the hash as most recently used if it exists
|
||||
pub fn promote(&mut self, hash: &u64) {
|
||||
self.cache.promote(hash);
|
||||
}
|
||||
}
|
||||
|
||||
/// An identifier for a PgCat pool,
|
||||
/// a database visible to clients.
|
||||
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
|
||||
@@ -190,11 +244,11 @@ impl Default for PoolSettings {
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ConnectionPool {
|
||||
/// The pools handled internally by bb8.
|
||||
databases: Vec<Vec<Pool<ServerPool>>>,
|
||||
databases: Arc<Vec<Vec<Pool<ServerPool>>>>,
|
||||
|
||||
/// The addresses (host, port, role) to handle
|
||||
/// failover and load balancing deterministically.
|
||||
addresses: Vec<Vec<Address>>,
|
||||
addresses: Arc<Vec<Vec<Address>>>,
|
||||
|
||||
/// List of banned addresses (see above)
|
||||
/// that should not be queried.
|
||||
@@ -206,7 +260,7 @@ pub struct ConnectionPool {
|
||||
original_server_parameters: Arc<RwLock<ServerParameters>>,
|
||||
|
||||
/// Pool configuration.
|
||||
pub settings: PoolSettings,
|
||||
pub settings: Arc<PoolSettings>,
|
||||
|
||||
/// If not validated, we need to double check the pool is available before allowing a client
|
||||
/// to use it.
|
||||
@@ -223,6 +277,9 @@ pub struct ConnectionPool {
|
||||
|
||||
/// AuthInfo
|
||||
pub auth_hash: Arc<RwLock<Option<String>>>,
|
||||
|
||||
/// Cache
|
||||
pub prepared_statement_cache: Option<PreparedStatementCacheType>,
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
@@ -241,20 +298,17 @@ impl ConnectionPool {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
if let Some(pool) = old_pool_ref {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -379,6 +433,7 @@ impl ConnectionPool {
|
||||
},
|
||||
pool_config.cleanup_server_connections,
|
||||
pool_config.log_client_parameter_status_changes,
|
||||
pool_config.prepared_statements_cache_size,
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
@@ -399,7 +454,7 @@ impl ConnectionPool {
|
||||
},
|
||||
};
|
||||
|
||||
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
|
||||
let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
|
||||
.iter()
|
||||
.min()
|
||||
.unwrap();
|
||||
@@ -448,13 +503,13 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
addresses,
|
||||
databases: Arc::new(shards),
|
||||
addresses: Arc::new(addresses),
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
settings: Arc::new(PoolSettings {
|
||||
pool_mode: match user.pool_mode {
|
||||
Some(pool_mode) => pool_mode,
|
||||
None => pool_config.pool_mode,
|
||||
@@ -489,7 +544,7 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
default_shard: pool_config.default_shard.clone(),
|
||||
default_shard: pool_config.default_shard,
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
@@ -497,17 +552,23 @@ impl ConnectionPool {
|
||||
Some(ref plugins) => Some(plugins.clone()),
|
||||
None => config.plugins.clone(),
|
||||
},
|
||||
},
|
||||
}),
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
paused_waiter: Arc::new(Notify::new()),
|
||||
prepared_statement_cache: match pool_config.prepared_statements_cache_size {
|
||||
0 => None,
|
||||
_ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
|
||||
pool_config.prepared_statements_cache_size,
|
||||
)))),
|
||||
},
|
||||
};
|
||||
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
if config.general.validate_config {
|
||||
let mut validate_pool = pool.clone();
|
||||
let validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
@@ -528,7 +589,7 @@ impl ConnectionPool {
|
||||
/// when they connect.
|
||||
/// This also warms up the pool for clients that connect when
|
||||
/// the pooler starts up.
|
||||
pub async fn validate(&mut self) -> Result<(), Error> {
|
||||
pub async fn validate(&self) -> Result<(), Error> {
|
||||
let mut futures = Vec::new();
|
||||
let validated = Arc::clone(&self.validated);
|
||||
|
||||
@@ -678,7 +739,7 @@ impl ConnectionPool {
|
||||
let mut force_healthcheck = false;
|
||||
|
||||
if self.is_banned(address) {
|
||||
if self.try_unban(&address).await {
|
||||
if self.try_unban(address).await {
|
||||
force_healthcheck = true;
|
||||
} else {
|
||||
debug!("Address {:?} is banned", address);
|
||||
@@ -806,8 +867,8 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
return false;
|
||||
self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
false
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
@@ -931,10 +992,10 @@ impl ConnectionPool {
|
||||
let guard = self.banlist.read();
|
||||
for banlist in guard.iter() {
|
||||
for (address, (reason, timestamp)) in banlist.iter() {
|
||||
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
|
||||
bans.push((address.clone(), (reason.clone(), *timestamp)));
|
||||
}
|
||||
}
|
||||
return bans;
|
||||
bans
|
||||
}
|
||||
|
||||
/// Get the address from the host url
|
||||
@@ -992,7 +1053,7 @@ impl ConnectionPool {
|
||||
}
|
||||
let busy = provisioned - idle;
|
||||
debug!("{:?} has {:?} busy connections", address, busy);
|
||||
return busy;
|
||||
busy
|
||||
}
|
||||
|
||||
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
|
||||
@@ -1001,6 +1062,29 @@ impl ConnectionPool {
|
||||
Some(shard) => shard < self.shards(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a parse statement to the pool's cache and return the rewritten parse
|
||||
///
|
||||
/// Do not pass an anonymous parse statement to this function
|
||||
pub fn register_parse_to_cache(&self, hash: u64, parse: &Parse) -> Option<Arc<Parse>> {
|
||||
// We should only be calling this function if the cache is enabled
|
||||
match self.prepared_statement_cache {
|
||||
Some(ref prepared_statement_cache) => {
|
||||
let mut cache = prepared_statement_cache.lock();
|
||||
Some(cache.get_or_insert(parse, hash))
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Promote a prepared statement hash in the LRU
|
||||
pub fn promote_prepared_statement_hash(&self, hash: &u64) {
|
||||
// We should only be calling this function if the cache is enabled
|
||||
if let Some(ref prepared_statement_cache) = self.prepared_statement_cache {
|
||||
let mut cache = prepared_statement_cache.lock();
|
||||
cache.promote(hash);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for the bb8 connection pool.
|
||||
@@ -1028,9 +1112,13 @@ pub struct ServerPool {
|
||||
|
||||
/// Log client parameter status changes
|
||||
log_client_parameter_status_changes: bool,
|
||||
|
||||
/// Prepared statement cache size
|
||||
prepared_statement_cache_size: usize,
|
||||
}
|
||||
|
||||
impl ServerPool {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
address: Address,
|
||||
user: User,
|
||||
@@ -1040,16 +1128,18 @@ impl ServerPool {
|
||||
plugins: Option<Plugins>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
prepared_statement_cache_size: usize,
|
||||
) -> ServerPool {
|
||||
ServerPool {
|
||||
address,
|
||||
user: user.clone(),
|
||||
user,
|
||||
database: database.to_string(),
|
||||
client_server_map,
|
||||
auth_hash,
|
||||
plugins,
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
prepared_statement_cache_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1080,6 +1170,7 @@ impl ManageConnection for ServerPool {
|
||||
self.auth_hash.clone(),
|
||||
self.cleanup_connections,
|
||||
self.log_client_parameter_status_changes,
|
||||
self.prepared_statement_cache_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut};
|
||||
use log::{debug, error};
|
||||
use once_cell::sync::OnceCell;
|
||||
use regex::{Regex, RegexSet};
|
||||
use sqlparser::ast::Statement::{Query, StartTransaction};
|
||||
use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
|
||||
use sqlparser::ast::{
|
||||
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
|
||||
Value,
|
||||
Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
|
||||
TableFactor, TableWithJoins, Value,
|
||||
};
|
||||
use sqlparser::dialect::PostgreSqlDialect;
|
||||
use sqlparser::parser::Parser;
|
||||
@@ -91,7 +91,7 @@ impl QueryRouter {
|
||||
/// One-time initialization of regexes
|
||||
/// that parse our custom SQL protocol.
|
||||
pub fn setup() -> bool {
|
||||
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
|
||||
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
|
||||
Ok(rgx) => rgx,
|
||||
Err(err) => {
|
||||
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
|
||||
@@ -128,11 +128,11 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
/// Pool settings can change because of a config reload.
|
||||
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
|
||||
self.pool_settings = pool_settings;
|
||||
pub fn update_pool_settings(&mut self, pool_settings: &PoolSettings) {
|
||||
self.pool_settings = pool_settings.clone();
|
||||
}
|
||||
|
||||
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
|
||||
pub fn pool_settings(&self) -> &PoolSettings {
|
||||
&self.pool_settings
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ impl QueryRouter {
|
||||
|
||||
// Check for any sharding regex matches in any queries
|
||||
if comment_shard_routing_enabled {
|
||||
match code as char {
|
||||
match code {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
// Check only the first block of bytes configured by the pool settings
|
||||
@@ -344,16 +344,13 @@ impl QueryRouter {
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
|
||||
match self.pool_settings.query_parser_max_length {
|
||||
Some(max_length) => {
|
||||
if len > max_length {
|
||||
return Err(Error::QueryRouterParserError(format!(
|
||||
"Query too long for parser: {} > {}",
|
||||
len, max_length
|
||||
)));
|
||||
}
|
||||
if let Some(max_length) = self.pool_settings.query_parser_max_length {
|
||||
if len > max_length {
|
||||
return Err(Error::QueryRouterParserError(format!(
|
||||
"Query too long for parser: {} > {}",
|
||||
len, max_length
|
||||
)));
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
let query = match code {
|
||||
@@ -403,6 +400,9 @@ impl QueryRouter {
|
||||
return Err(Error::QueryRouterParserError("empty query".into()));
|
||||
}
|
||||
|
||||
let mut visited_write_statement = false;
|
||||
let mut prev_inferred_shard = None;
|
||||
|
||||
for q in ast {
|
||||
match q {
|
||||
// All transactions go to the primary, probably a write.
|
||||
@@ -420,29 +420,38 @@ impl QueryRouter {
|
||||
// or discard shard selection. If they point to the same shard though,
|
||||
// we can let them through as-is.
|
||||
// This is basically building a database now :)
|
||||
match self.infer_shard(query) {
|
||||
Some(shard) => {
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
let inferred_shard = self.infer_shard(query);
|
||||
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
self.active_role = match self.primary_reads_enabled() {
|
||||
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
|
||||
true => None, // Any server role is fine in this case.
|
||||
// If we already visited a write statement, we should be going to the primary.
|
||||
if !visited_write_statement {
|
||||
self.active_role = match self.primary_reads_enabled() {
|
||||
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
|
||||
true => None, // Any server role is fine in this case.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Likely a write
|
||||
_ => {
|
||||
match &self.pool_settings.automatic_sharding_key {
|
||||
Some(_) => {
|
||||
// TODO: similar to the above, if we have multiple queries in the
|
||||
// same message, we can either split them and execute them individually
|
||||
// or discard shard selection. If they point to the same shard though,
|
||||
// we can let them through as-is.
|
||||
let inferred_shard = self.infer_shard_on_write(q)?;
|
||||
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
visited_write_statement = true;
|
||||
self.active_role = Some(Role::Primary);
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -450,6 +459,188 @@ impl QueryRouter {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_inferred_shard(
|
||||
&mut self,
|
||||
inferred_shard: Option<usize>,
|
||||
prev_inferred_shard: &mut Option<usize>,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(shard) = inferred_shard {
|
||||
if let Some(prev_shard) = *prev_inferred_shard {
|
||||
if prev_shard != shard {
|
||||
debug!("Found more than one shard in the query, not supported yet");
|
||||
return Err(Error::QueryRouterParserError(
|
||||
"multiple shards in query".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
*prev_inferred_shard = Some(shard);
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
|
||||
let mut exprs = Vec::new();
|
||||
|
||||
// Collect all table names from the query.
|
||||
let mut table_names = Vec::new();
|
||||
|
||||
match q {
|
||||
Insert {
|
||||
or,
|
||||
into: _,
|
||||
table_name,
|
||||
columns,
|
||||
overwrite: _,
|
||||
source,
|
||||
partitioned,
|
||||
after_columns,
|
||||
table: _,
|
||||
on: _,
|
||||
returning: _,
|
||||
} => {
|
||||
// Not supported in postgres.
|
||||
assert!(or.is_none());
|
||||
assert!(partitioned.is_none());
|
||||
assert!(after_columns.is_empty());
|
||||
|
||||
Self::process_table(table_name, &mut table_names);
|
||||
Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
|
||||
}
|
||||
Delete {
|
||||
tables,
|
||||
from,
|
||||
using,
|
||||
selection,
|
||||
returning: _,
|
||||
} => {
|
||||
if let Some(expr) = selection {
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
|
||||
// Multi tables delete are not supported in postgres.
|
||||
assert!(tables.is_empty());
|
||||
|
||||
Self::process_tables_with_join(from, &mut exprs, &mut table_names);
|
||||
if let Some(using_tbl_with_join) = using {
|
||||
Self::process_tables_with_join(
|
||||
using_tbl_with_join,
|
||||
&mut exprs,
|
||||
&mut table_names,
|
||||
);
|
||||
}
|
||||
Self::process_selection(selection, &mut exprs);
|
||||
}
|
||||
Update {
|
||||
table,
|
||||
assignments,
|
||||
from,
|
||||
selection,
|
||||
returning: _,
|
||||
} => {
|
||||
Self::process_table_with_join(table, &mut exprs, &mut table_names);
|
||||
if let Some(from_tbl) = from {
|
||||
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
|
||||
}
|
||||
Self::process_selection(selection, &mut exprs);
|
||||
self.assignment_parser(assignments)?;
|
||||
}
|
||||
_ => {
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(self.infer_shard_from_exprs(exprs, table_names))
|
||||
}
|
||||
|
||||
fn process_query(
|
||||
query: &sqlparser::ast::Query,
|
||||
exprs: &mut Vec<Expr>,
|
||||
table_names: &mut Vec<Vec<Ident>>,
|
||||
columns: &Option<&Vec<Ident>>,
|
||||
) {
|
||||
match &*query.body {
|
||||
SetExpr::Query(query) => {
|
||||
Self::process_query(query, exprs, table_names, columns);
|
||||
}
|
||||
|
||||
// SELECT * FROM ...
|
||||
// We understand that pretty well.
|
||||
SetExpr::Select(select) => {
|
||||
Self::process_tables_with_join(&select.from, exprs, table_names);
|
||||
|
||||
// Parse the actual "FROM ..."
|
||||
Self::process_selection(&select.selection, exprs);
|
||||
}
|
||||
|
||||
SetExpr::Values(values) => {
|
||||
if let Some(cols) = columns {
|
||||
for row in values.rows.iter() {
|
||||
for (i, expr) in row.iter().enumerate() {
|
||||
if cols.len() > i {
|
||||
exprs.push(Expr::BinaryOp {
|
||||
left: Box::new(Expr::Identifier(cols[i].clone())),
|
||||
op: BinaryOperator::Eq,
|
||||
right: Box::new(expr.clone()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
|
||||
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
|
||||
match selection {
|
||||
Some(selection) => {
|
||||
exprs.push(selection.clone());
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
|
||||
fn process_tables_with_join(
|
||||
tables: &[TableWithJoins],
|
||||
exprs: &mut Vec<Expr>,
|
||||
table_names: &mut Vec<Vec<Ident>>,
|
||||
) {
|
||||
for table in tables.iter() {
|
||||
Self::process_table_with_join(table, exprs, table_names);
|
||||
}
|
||||
}
|
||||
|
||||
fn process_table_with_join(
|
||||
table: &TableWithJoins,
|
||||
exprs: &mut Vec<Expr>,
|
||||
table_names: &mut Vec<Vec<Ident>>,
|
||||
) {
|
||||
if let TableFactor::Table { name, .. } = &table.relation {
|
||||
Self::process_table(name, table_names);
|
||||
};
|
||||
|
||||
// Get table names from all the joins.
|
||||
for join in table.joins.iter() {
|
||||
if let TableFactor::Table { name, .. } = &join.relation {
|
||||
Self::process_table(name, table_names);
|
||||
};
|
||||
|
||||
// We can filter results based on join conditions, e.g.
|
||||
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
|
||||
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
|
||||
table_names.push(name.0.clone())
|
||||
}
|
||||
|
||||
/// Parse the shard number from the Bind message
|
||||
/// which contains the arguments for a prepared statement.
|
||||
///
|
||||
@@ -592,6 +783,33 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// An `assignments` exists in the `UPDATE` statements. This parses the assignments and makes
|
||||
/// sure that we are not updating the sharding key. It's not supported yet.
|
||||
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
|
||||
let sharding_key = self
|
||||
.pool_settings
|
||||
.automatic_sharding_key
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.split('.')
|
||||
.map(|ident| Ident::new(ident.to_lowercase()))
|
||||
.collect::<Vec<Ident>>();
|
||||
|
||||
// Sharding key must be always fully qualified
|
||||
assert_eq!(sharding_key.len(), 2);
|
||||
|
||||
for a in assignments {
|
||||
if sharding_key[0].value == "*"
|
||||
&& sharding_key[1].value == a.id.last().unwrap().value.to_lowercase()
|
||||
{
|
||||
return Err(Error::QueryRouterParserError(
|
||||
"Sharding key cannot be updated.".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A `selection` is the `WHERE` clause. This parses
|
||||
/// the clause and extracts the sharding key, if present.
|
||||
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
|
||||
@@ -603,8 +821,8 @@ impl QueryRouter {
|
||||
.automatic_sharding_key
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.split(".")
|
||||
.map(|ident| Ident::new(ident))
|
||||
.split('.')
|
||||
.map(|ident| Ident::new(ident.to_lowercase()))
|
||||
.collect::<Vec<Ident>>();
|
||||
|
||||
// Sharding key must be always fully qualified
|
||||
@@ -620,7 +838,7 @@ impl QueryRouter {
|
||||
Expr::Identifier(ident) => {
|
||||
// Only if we're dealing with only one table
|
||||
// and there is no ambiguity
|
||||
if &ident.value == &sharding_key[1].value {
|
||||
if ident.value.to_lowercase() == sharding_key[1].value {
|
||||
// Sharding key is unique enough, don't worry about
|
||||
// table names.
|
||||
if &sharding_key[0].value == "*" {
|
||||
@@ -633,13 +851,13 @@ impl QueryRouter {
|
||||
// SELECT * FROM t WHERE sharding_key = 5
|
||||
// Make sure the table name from the sharding key matches
|
||||
// the table name from the query.
|
||||
found = &sharding_key[0].value == &table[0].value;
|
||||
found = sharding_key[0].value == table[0].value.to_lowercase();
|
||||
} else if table.len() == 2 {
|
||||
// Table name is fully qualified with the schema: e.g.
|
||||
// SELECT * FROM public.t WHERE sharding_key = 5
|
||||
// Ignore the schema (TODO: at some point, we want schema support)
|
||||
// and use the table name only.
|
||||
found = &sharding_key[0].value == &table[1].value;
|
||||
found = sharding_key[0].value == table[1].value.to_lowercase();
|
||||
} else {
|
||||
debug!("Got table name with more than two idents, which is not possible");
|
||||
}
|
||||
@@ -651,8 +869,9 @@ impl QueryRouter {
|
||||
// The key is fully qualified in the query,
|
||||
// it will exist or Postgres will throw an error.
|
||||
if idents.len() == 2 {
|
||||
found = &sharding_key[0].value == &idents[0].value
|
||||
&& &sharding_key[1].value == &idents[1].value;
|
||||
found = (&sharding_key[0].value == "*"
|
||||
|| sharding_key[0].value == idents[0].value.to_lowercase())
|
||||
&& sharding_key[1].value == idents[1].value.to_lowercase();
|
||||
}
|
||||
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
|
||||
}
|
||||
@@ -684,7 +903,7 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
Expr::Value(Value::Placeholder(placeholder)) => {
|
||||
match placeholder.replace("$", "").parse::<i16>() {
|
||||
match placeholder.replace('$', "").parse::<i16>() {
|
||||
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
|
||||
Err(_) => {
|
||||
debug!(
|
||||
@@ -705,100 +924,48 @@ impl QueryRouter {
|
||||
|
||||
/// Try to figure out which shard the query should go to.
|
||||
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
|
||||
let mut shards = BTreeSet::new();
|
||||
let mut exprs = Vec::new();
|
||||
|
||||
match &*query.body {
|
||||
SetExpr::Query(query) => {
|
||||
match self.infer_shard(&*query) {
|
||||
Some(shard) => {
|
||||
// Collect all table names from the query.
|
||||
let mut table_names = Vec::new();
|
||||
|
||||
Self::process_query(query, &mut exprs, &mut table_names, &None);
|
||||
self.infer_shard_from_exprs(exprs, table_names)
|
||||
}
|
||||
|
||||
fn infer_shard_from_exprs(
|
||||
&mut self,
|
||||
exprs: Vec<Expr>,
|
||||
table_names: Vec<Vec<Ident>>,
|
||||
) -> Option<usize> {
|
||||
let mut shards = BTreeSet::new();
|
||||
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
|
||||
// Look for sharding keys in either the join condition
|
||||
// or the selection.
|
||||
for expr in exprs.iter() {
|
||||
let sharding_keys = self.selection_parser(expr, &table_names);
|
||||
|
||||
// TODO: Add support for prepared statements here.
|
||||
// This should just give us the position of the value in the `B` message.
|
||||
|
||||
for value in sharding_keys {
|
||||
match value {
|
||||
ShardingKey::Value(value) => {
|
||||
let shard = sharder.shard(value);
|
||||
shards.insert(shard);
|
||||
}
|
||||
None => (),
|
||||
|
||||
ShardingKey::Placeholder(position) => {
|
||||
self.placeholders.push(position);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// SELECT * FROM ...
|
||||
// We understand that pretty well.
|
||||
SetExpr::Select(select) => {
|
||||
// Collect all table names from the query.
|
||||
let mut table_names = Vec::new();
|
||||
|
||||
for table in select.from.iter() {
|
||||
match &table.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Get table names from all the joins.
|
||||
for join in table.joins.iter() {
|
||||
match &join.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// We can filter results based on join conditions, e.g.
|
||||
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
|
||||
match &join.join_operator {
|
||||
JoinOperator::Inner(inner_join) => match &inner_join {
|
||||
JoinConstraint::On(expr) => {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
},
|
||||
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the actual "FROM ..."
|
||||
match &select.selection {
|
||||
Some(selection) => {
|
||||
exprs.push(selection.clone());
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
|
||||
// Look for sharding keys in either the join condition
|
||||
// or the selection.
|
||||
for expr in exprs.iter() {
|
||||
let sharding_keys = self.selection_parser(expr, &table_names);
|
||||
|
||||
// TODO: Add support for prepared statements here.
|
||||
// This should just give us the position of the value in the `B` message.
|
||||
|
||||
for value in sharding_keys {
|
||||
match value {
|
||||
ShardingKey::Value(value) => {
|
||||
let shard = sharder.shard(value);
|
||||
shards.insert(shard);
|
||||
}
|
||||
|
||||
ShardingKey::Placeholder(position) => {
|
||||
self.placeholders.push(position);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
}
|
||||
match shards.len() {
|
||||
// Didn't find a sharding key, you're on your own.
|
||||
0 => {
|
||||
@@ -830,16 +997,16 @@ impl QueryRouter {
|
||||
db: &self.pool_settings.db,
|
||||
};
|
||||
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
let _ = query_logger.run(self, ast).await;
|
||||
}
|
||||
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
let mut intercept = Intercept {
|
||||
enabled: intercept.enabled,
|
||||
config: &intercept,
|
||||
config: intercept,
|
||||
};
|
||||
|
||||
let result = intercept.run(&self, ast).await;
|
||||
let result = intercept.run(self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
return Ok(PluginOutput::Intercept(output));
|
||||
@@ -852,7 +1019,7 @@ impl QueryRouter {
|
||||
tables: &table_access.tables,
|
||||
};
|
||||
|
||||
let result = table_access.run(&self, ast).await;
|
||||
let result = table_access.run(self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
return Ok(PluginOutput::Deny(error));
|
||||
@@ -888,7 +1055,7 @@ impl QueryRouter {
|
||||
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
match self.query_parser_enabled {
|
||||
None => {
|
||||
debug!(
|
||||
"Using pool settings, query_parser_enabled: {}",
|
||||
@@ -904,9 +1071,7 @@ impl QueryRouter {
|
||||
);
|
||||
value
|
||||
}
|
||||
};
|
||||
|
||||
enabled
|
||||
}
|
||||
}
|
||||
|
||||
pub fn primary_reads_enabled(&self) -> bool {
|
||||
@@ -917,6 +1082,12 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QueryRouter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -938,10 +1109,14 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
|
||||
.is_some());
|
||||
assert!(qr.query_parser_enabled());
|
||||
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
let queries = vec![
|
||||
simple_query("SELECT * FROM items WHERE id = 5"),
|
||||
@@ -983,7 +1158,9 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
|
||||
.is_some());
|
||||
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
@@ -996,7 +1173,9 @@ mod test {
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
let prepared_stmt = BytesMut::from(
|
||||
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
|
||||
@@ -1166,9 +1345,11 @@ mod test {
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
let query = simple_query("SET SERVER ROLE TO 'auto'");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
assert!(qr.query_parser_enabled());
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
@@ -1182,7 +1363,7 @@ mod test {
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
let query = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
assert!(!qr.query_parser_enabled());
|
||||
}
|
||||
|
||||
@@ -1222,7 +1403,7 @@ mod test {
|
||||
assert_eq!(qr.primary_reads_enabled, None);
|
||||
|
||||
// Internal state must not be changed due to this, only defaults
|
||||
qr.update_pool_settings(pool_settings.clone());
|
||||
qr.update_pool_settings(&pool_settings);
|
||||
|
||||
assert_eq!(qr.active_role, None);
|
||||
assert_eq!(qr.active_shard, None);
|
||||
@@ -1230,11 +1411,11 @@ mod test {
|
||||
assert!(!qr.primary_reads_enabled());
|
||||
|
||||
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
|
||||
assert!(qr.try_execute_command(&q1) != None);
|
||||
assert!(qr.try_execute_command(&q1).is_some());
|
||||
assert_eq!(qr.active_role.unwrap(), Role::Primary);
|
||||
|
||||
let q2 = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&q2) != None);
|
||||
assert!(qr.try_execute_command(&q2).is_some());
|
||||
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
|
||||
}
|
||||
|
||||
@@ -1295,29 +1476,29 @@ mod test {
|
||||
};
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings.clone());
|
||||
qr.update_pool_settings(&pool_settings);
|
||||
|
||||
// Shard should start out unset
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Don't panic when short query eg. ; is sent
|
||||
let q0 = simple_query(";");
|
||||
assert!(qr.try_execute_command(&q0) == None);
|
||||
assert!(qr.try_execute_command(&q0).is_none());
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Make sure setting it works
|
||||
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q1) == None);
|
||||
assert!(qr.try_execute_command(&q1).is_none());
|
||||
assert_eq!(qr.active_shard, Some(1));
|
||||
|
||||
// And make sure changing it works
|
||||
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert_eq!(qr.active_shard, Some(0));
|
||||
|
||||
// Validate setting by shard with expected shard copied from sharding.rs tests
|
||||
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert_eq!(qr.active_shard, Some(2));
|
||||
}
|
||||
|
||||
@@ -1414,6 +1595,221 @@ mod test {
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
}
|
||||
|
||||
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("*.w_id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
assert_eq!(qr.shard(), None);
|
||||
let infer_res = qr.infer(&qr.parse(&simple_query(qry)).unwrap());
|
||||
assert_eq!(infer_res.is_ok(), should_succeed);
|
||||
qr.shard()
|
||||
}
|
||||
|
||||
fn auto_shard(qry: &str) -> Option<usize> {
|
||||
auto_shard_wrapper(qry, true)
|
||||
}
|
||||
|
||||
fn auto_shard_fails(qry: &str) -> Option<usize> {
|
||||
auto_shard_wrapper(qry, false)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_automatic_sharding_insert_update_delete() {
|
||||
QueryRouter::setup();
|
||||
|
||||
assert_eq!(
|
||||
auto_shard_fails(
|
||||
"UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
|
||||
),
|
||||
None
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard_fails(
|
||||
"UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
|
||||
),
|
||||
None
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard(
|
||||
"UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
|
||||
),
|
||||
Some(2)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_automatic_sharding_key_tpcc() {
|
||||
QueryRouter::setup();
|
||||
|
||||
assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2));
|
||||
assert_eq!(
|
||||
auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"),
|
||||
None
|
||||
);
|
||||
assert_eq!(auto_shard("COMMIT"), None);
|
||||
assert_eq!(auto_shard("ROLLBACK"), None);
|
||||
|
||||
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2));
|
||||
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2));
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard(
|
||||
"UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
|
||||
),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2));
|
||||
assert_eq!(
|
||||
auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
|
||||
Some(2)
|
||||
);
|
||||
assert_eq!(
|
||||
auto_shard(
|
||||
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
|
||||
WHERE ORDER_LINE.W_ID = 5
|
||||
AND OL_D_ID = 3
|
||||
AND OL_O_ID < 3
|
||||
AND OL_O_ID >= 3
|
||||
AND STOCK.W_ID = 5
|
||||
AND S_I_ID = OL_I_ID
|
||||
AND S_QUANTITY < 3"
|
||||
),
|
||||
Some(2)
|
||||
);
|
||||
|
||||
// This is a distributed query and contains two shards
|
||||
assert_eq!(
|
||||
auto_shard(
|
||||
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
|
||||
WHERE ORDER_LINE.W_ID = 5
|
||||
AND OL_D_ID = 3
|
||||
AND OL_O_ID < 3
|
||||
AND OL_O_ID >= 3
|
||||
AND STOCK.W_ID = 7
|
||||
AND S_I_ID = OL_I_ID
|
||||
AND S_QUANTITY < 3"
|
||||
),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_prepared_statements() {
|
||||
let stmt = "SELECT * FROM data WHERE id = $1";
|
||||
@@ -1458,12 +1854,13 @@ mod test {
|
||||
};
|
||||
|
||||
QueryRouter::setup();
|
||||
let mut pool_settings = PoolSettings::default();
|
||||
pool_settings.query_parser_enabled = true;
|
||||
pool_settings.plugins = Some(plugins);
|
||||
|
||||
let pool_settings = PoolSettings {
|
||||
query_parser_enabled: true,
|
||||
plugins: Some(plugins),
|
||||
..Default::default()
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings);
|
||||
qr.update_pool_settings(&pool_settings);
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
|
||||
14
src/scram.rs
14
src/scram.rs
@@ -79,12 +79,12 @@ impl ScramSha256 {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
let salted_password = Self::hi(
|
||||
@@ -166,9 +166,9 @@ impl ScramSha256 {
|
||||
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
|
||||
let final_message = FinalMessage::parse(message)?;
|
||||
|
||||
let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
|
||||
let verifier = match general_purpose::STANDARD.decode(final_message.value) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
@@ -230,14 +230,14 @@ impl Message {
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
Ok(Message {
|
||||
@@ -257,7 +257,7 @@ impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
Ok(FinalMessage {
|
||||
|
||||
236
src/server.rs
236
src/server.rs
@@ -3,12 +3,14 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use lru::LruCache;
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::mem;
|
||||
use std::net::IpAddr;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
@@ -16,7 +18,7 @@ use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
@@ -197,12 +199,8 @@ impl ServerParameters {
|
||||
key = "DateStyle".to_string();
|
||||
};
|
||||
|
||||
if TRACKED_PARAMETERS.contains(&key) {
|
||||
if TRACKED_PARAMETERS.contains(&key) || startup {
|
||||
self.parameters.insert(key, value);
|
||||
} else {
|
||||
if startup {
|
||||
self.parameters.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,12 +324,13 @@ pub struct Server {
|
||||
log_client_parameter_status_changes: bool,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: BTreeSet<String>,
|
||||
prepared_statement_cache: Option<LruCache<String, ()>>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
|
||||
/// Perform the authentication and return the server in a ready for query state.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn startup(
|
||||
address: &Address,
|
||||
user: &User,
|
||||
@@ -341,6 +340,7 @@ impl Server {
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
prepared_statement_cache_size: usize,
|
||||
) -> Result<Server, Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let mut addr_set: Option<AddrSet> = None;
|
||||
@@ -440,10 +440,7 @@ impl Server {
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
return Err(Error::SocketError(format!("Unknown message: {}", { m })));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -461,26 +458,20 @@ impl Server {
|
||||
None => &user.username,
|
||||
};
|
||||
|
||||
let password = match user.server_password {
|
||||
Some(ref server_password) => Some(server_password),
|
||||
None => match user.password {
|
||||
Some(ref password) => Some(password),
|
||||
None => None,
|
||||
},
|
||||
let password = match user.server_password.as_ref() {
|
||||
Some(server_password) => Some(server_password),
|
||||
None => user.password.as_ref(),
|
||||
};
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
let server_identifier = ServerIdentifier::new(username, database);
|
||||
|
||||
// We'll be handling multiple packets, but they will all be structured the same.
|
||||
// We'll loop here until this exchange is complete.
|
||||
let mut scram: Option<ScramSha256> = match password {
|
||||
Some(password) => Some(ScramSha256::new(password)),
|
||||
None => None,
|
||||
};
|
||||
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
|
||||
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
@@ -725,7 +716,7 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let fields = match PgErrorMsg::parse(error) {
|
||||
let fields = match PgErrorMsg::parse(&error) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
@@ -830,7 +821,12 @@ impl Server {
|
||||
},
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
prepared_statements: BTreeSet::new(),
|
||||
prepared_statement_cache: match prepared_statement_cache_size {
|
||||
0 => None,
|
||||
_ => Some(LruCache::new(
|
||||
NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
|
||||
)),
|
||||
},
|
||||
};
|
||||
|
||||
return Ok(server);
|
||||
@@ -882,7 +878,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
match write_all_flush(&mut self.stream, messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -969,6 +965,20 @@ impl Server {
|
||||
if self.in_copy_mode {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
|
||||
if self.prepared_statement_cache.is_some() {
|
||||
let error_message = PgErrorMsg::parse(&message)?;
|
||||
if error_message.message == "cached plan must not change result type" {
|
||||
warn!("Server {:?} changed schema, dropping connection to clean up prepared statements", self.address);
|
||||
// This will still result in an error to the client, but this server connection will drop all cached prepared statements
|
||||
// so that any new queries will be re-prepared
|
||||
// TODO: Other ideas to solve errors when there are DDL changes after a statement has been prepared
|
||||
// - Recreate entire connection pool to force recreation of all server connections
|
||||
// - Clear the ConnectionPool's statement cache so that new statement names are generated
|
||||
// - Implement a retry (re-prepare) so the client doesn't see an error
|
||||
self.cleanup_state.needs_cleanup_prepare = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CommandComplete
|
||||
@@ -1079,115 +1089,92 @@ impl Server {
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Add the prepared statement to being tracked by this server.
|
||||
/// The client is processing data that will create a prepared statement on this server.
|
||||
pub fn will_prepare(&mut self, name: &str) {
|
||||
debug!("Will prepare `{}`", name);
|
||||
// Determines if the server already has a prepared statement with the given name
|
||||
// Increments the prepared statement cache hit counter
|
||||
pub fn has_prepared_statement(&mut self, name: &str) -> bool {
|
||||
let cache = match &mut self.prepared_statement_cache {
|
||||
Some(cache) => cache,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
self.prepared_statements.insert(name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
}
|
||||
|
||||
/// Check if we should prepare a statement on the server.
|
||||
pub fn should_prepare(&self, name: &str) -> bool {
|
||||
let should_prepare = !self.prepared_statements.contains(name);
|
||||
|
||||
debug!("Should prepare `{}`: {}", name, should_prepare);
|
||||
|
||||
if should_prepare {
|
||||
self.stats.prepared_cache_miss();
|
||||
} else {
|
||||
let has_it = cache.get(name).is_some();
|
||||
if has_it {
|
||||
self.stats.prepared_cache_hit();
|
||||
} else {
|
||||
self.stats.prepared_cache_miss();
|
||||
}
|
||||
|
||||
should_prepare
|
||||
has_it
|
||||
}
|
||||
|
||||
/// Create a prepared statement on the server.
|
||||
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
|
||||
debug!("Preparing `{}`", parse.name);
|
||||
pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
||||
let cache = match &mut self.prepared_statement_cache {
|
||||
Some(cache) => cache,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
let bytes: BytesMut = parse.try_into()?;
|
||||
self.send(&bytes).await?;
|
||||
self.send(&flush()).await?;
|
||||
|
||||
// Read and discard ParseComplete (B)
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
self.prepared_statements.insert(parse.name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
|
||||
debug!("Prepared `{}`", parse.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Maintain adequate cache size on the server.
|
||||
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
|
||||
debug!("Cache maintenance run");
|
||||
|
||||
let max_cache_size = get_prepared_statements_cache_size();
|
||||
let mut names = Vec::new();
|
||||
|
||||
while self.prepared_statements.len() >= max_cache_size {
|
||||
// The prepared statmeents are alphanumerically sorted by the BTree.
|
||||
// FIFO.
|
||||
if let Some(name) = self.prepared_statements.pop_last() {
|
||||
names.push(name);
|
||||
// If we evict something, we need to close it on the server
|
||||
if let Some((evicted_name, _)) = cache.push(name.to_string(), ()) {
|
||||
if evicted_name != name {
|
||||
debug!(
|
||||
"Evicted prepared statement {} from cache, replaced with {}",
|
||||
evicted_name, name
|
||||
);
|
||||
return Some(evicted_name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !names.is_empty() {
|
||||
self.deallocate(names).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
None
|
||||
}
|
||||
|
||||
/// Remove the prepared statement from being tracked by this server.
|
||||
/// The client is processing data that will cause the server to close the prepared statement.
|
||||
pub fn will_close(&mut self, name: &str) {
|
||||
debug!("Will close `{}`", name);
|
||||
pub fn remove_prepared_statement_from_cache(&mut self, name: &str) {
|
||||
let cache = match &mut self.prepared_statement_cache {
|
||||
Some(cache) => cache,
|
||||
None => return,
|
||||
};
|
||||
|
||||
self.prepared_statements.remove(name);
|
||||
self.stats.prepared_cache_remove();
|
||||
cache.pop(name);
|
||||
}
|
||||
|
||||
/// Close a prepared statement on the server.
|
||||
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
|
||||
for name in &names {
|
||||
debug!("Deallocating prepared statement `{}`", name);
|
||||
pub async fn register_prepared_statement(
|
||||
&mut self,
|
||||
parse: &Parse,
|
||||
should_send_parse_to_server: bool,
|
||||
) -> Result<(), Error> {
|
||||
if !self.has_prepared_statement(&parse.name) {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
let close = Close::new(name);
|
||||
let bytes: BytesMut = close.try_into()?;
|
||||
if should_send_parse_to_server {
|
||||
let parse_bytes: BytesMut = parse.try_into()?;
|
||||
bytes.extend_from_slice(&parse_bytes);
|
||||
}
|
||||
|
||||
self.send(&bytes).await?;
|
||||
}
|
||||
|
||||
if !names.is_empty() {
|
||||
self.send(&flush()).await?;
|
||||
}
|
||||
|
||||
// Read and discard CloseComplete (3)
|
||||
for name in &names {
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => {
|
||||
self.prepared_statements.remove(name);
|
||||
self.stats.prepared_cache_remove();
|
||||
debug!("Closed `{}`", name);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
// If we evict something, we need to close it on the server
|
||||
// We do this by adding it to the messages we're sending to the server before the sync
|
||||
if let Some(evicted_name) = self.add_prepared_statement_to_cache(&parse.name) {
|
||||
self.remove_prepared_statement_from_cache(&evicted_name);
|
||||
let close_bytes: BytesMut = Close::new(&evicted_name).try_into()?;
|
||||
bytes.extend_from_slice(&close_bytes);
|
||||
};
|
||||
}
|
||||
|
||||
// If we have a parse or close we need to send to the server, send them and sync
|
||||
if !bytes.is_empty() {
|
||||
bytes.extend_from_slice(&sync());
|
||||
|
||||
self.send(&bytes).await?;
|
||||
|
||||
loop {
|
||||
self.recv(None).await?;
|
||||
|
||||
if !self.is_data_available() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1324,6 +1311,10 @@ impl Server {
|
||||
|
||||
if self.cleanup_state.needs_cleanup_prepare {
|
||||
reset_string.push_str("DEALLOCATE ALL;");
|
||||
// Since we deallocated all prepared statements, we need to clear the cache
|
||||
if let Some(cache) = &mut self.prepared_statement_cache {
|
||||
cache.clear();
|
||||
}
|
||||
};
|
||||
|
||||
self.query(&reset_string).await?;
|
||||
@@ -1359,16 +1350,14 @@ impl Server {
|
||||
}
|
||||
|
||||
pub fn mirror_send(&mut self, bytes: &BytesMut) {
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.send(bytes),
|
||||
None => (),
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.send(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mirror_disconnect(&mut self) {
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.disconnect(),
|
||||
None => (),
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1391,13 +1380,14 @@ impl Server {
|
||||
Arc::new(RwLock::new(None)),
|
||||
true,
|
||||
false,
|
||||
0,
|
||||
)
|
||||
.await?;
|
||||
debug!("Connected!, sending query.");
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv(None).await?;
|
||||
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
parse_query_message(&mut message).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Sharder {
|
||||
fn sha1(&self, key: i64) -> usize {
|
||||
let mut hasher = Sha1::new();
|
||||
|
||||
hasher.update(&key.to_string().as_bytes());
|
||||
hasher.update(key.to_string().as_bytes());
|
||||
|
||||
let result = hasher.finalize();
|
||||
|
||||
@@ -202,10 +202,10 @@ mod test {
|
||||
#[test]
|
||||
fn test_sha1_hash() {
|
||||
let sharder = Sharder::new(12, ShardingFunction::Sha1);
|
||||
let ids = vec![
|
||||
let ids = [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||
];
|
||||
let shards = vec![
|
||||
let shards = [
|
||||
4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3,
|
||||
];
|
||||
|
||||
|
||||
@@ -86,11 +86,11 @@ impl PoolStats {
|
||||
}
|
||||
}
|
||||
|
||||
return map;
|
||||
map
|
||||
}
|
||||
|
||||
pub fn generate_header() -> Vec<(&'static str, DataType)> {
|
||||
return vec![
|
||||
vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
@@ -105,11 +105,11 @@ impl PoolStats {
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
];
|
||||
]
|
||||
}
|
||||
|
||||
pub fn generate_row(&self) -> Vec<String> {
|
||||
return vec![
|
||||
vec![
|
||||
self.identifier.db.clone(),
|
||||
self.identifier.user.clone(),
|
||||
self.mode.to_string(),
|
||||
@@ -124,7 +124,7 @@ impl PoolStats {
|
||||
self.sv_login.to_string(),
|
||||
(self.maxwait / 1_000_000).to_string(),
|
||||
(self.maxwait % 1_000_000).to_string(),
|
||||
];
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ pub struct ServerStats {
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
pub prepared_hit_count: Arc<AtomicU64>,
|
||||
pub prepared_miss_count: Arc<AtomicU64>,
|
||||
pub prepared_eviction_count: Arc<AtomicU64>,
|
||||
pub prepared_cache_size: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
@@ -68,6 +69,7 @@ impl Default for ServerStats {
|
||||
reporter: get_reporter(),
|
||||
prepared_hit_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_miss_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_eviction_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_cache_size: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
@@ -221,6 +223,7 @@ impl ServerStats {
|
||||
}
|
||||
|
||||
pub fn prepared_cache_remove(&self) {
|
||||
self.prepared_eviction_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,4 +36,4 @@ SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
|
||||
SET SERVER ROLE TO 'replica';
|
||||
|
||||
-- Read load balancing
|
||||
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
|
||||
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
|
||||
@@ -1,29 +1,214 @@
|
||||
require_relative 'spec_helper'
|
||||
|
||||
describe 'Prepared statements' do
|
||||
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
|
||||
let(:pool_size) { 5 }
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", pool_size) }
|
||||
let(:prepared_statements_cache_size) { 100 }
|
||||
let(:server_round_robin) { false }
|
||||
|
||||
context 'enabled' do
|
||||
it 'will work over the same connection' do
|
||||
before do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["general"]["server_round_robin"] = server_round_robin
|
||||
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = pool_size
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
|
||||
context 'when trying prepared statements' do
|
||||
it 'it allows unparameterized statements to succeed' do
|
||||
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
prepared_query = "SELECT 1"
|
||||
|
||||
# prepare query on server 1 and client 1
|
||||
conn1.prepare('statement1', prepared_query)
|
||||
conn1.exec_prepared('statement1')
|
||||
|
||||
conn2.transaction do
|
||||
# Claim server 1 with client 2
|
||||
conn2.exec("SELECT 2")
|
||||
|
||||
# Client 1 now runs the prepared query, and it's automatically
|
||||
# prepared on server 2
|
||||
conn1.prepare('statement2', prepared_query)
|
||||
conn1.exec_prepared('statement2')
|
||||
|
||||
# Client 2 now prepares the same query that was already
|
||||
# prepared on server 1. And PgBouncer reuses that already
|
||||
# prepared query for this different client.
|
||||
conn2.prepare('statement3', prepared_query)
|
||||
conn2.exec_prepared('statement3')
|
||||
end
|
||||
ensure
|
||||
conn1.close if conn1
|
||||
conn2.close if conn2
|
||||
end
|
||||
|
||||
it 'it allows parameterized statements to succeed' do
|
||||
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
prepared_query = "SELECT $1"
|
||||
|
||||
# prepare query on server 1 and client 1
|
||||
conn1.prepare('statement1', prepared_query)
|
||||
conn1.exec_prepared('statement1', [1])
|
||||
|
||||
conn2.transaction do
|
||||
# Claim server 1 with client 2
|
||||
conn2.exec("SELECT 2")
|
||||
|
||||
# Client 1 now runs the prepared query, and it's automatically
|
||||
# prepared on server 2
|
||||
conn1.prepare('statement2', prepared_query)
|
||||
conn1.exec_prepared('statement2', [1])
|
||||
|
||||
# Client 2 now prepares the same query that was already
|
||||
# prepared on server 1. And PgBouncer reuses that already
|
||||
# prepared query for this different client.
|
||||
conn2.prepare('statement3', prepared_query)
|
||||
conn2.exec_prepared('statement3', [1])
|
||||
end
|
||||
ensure
|
||||
conn1.close if conn1
|
||||
conn2.close if conn2
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
context 'when trying large packets' do
|
||||
it "works with large parse" do
|
||||
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
long_string = "1" * 4096 * 10
|
||||
prepared_query = "SELECT '#{long_string}'"
|
||||
|
||||
|
||||
# prepare query on server 1 and client 1
|
||||
conn1.prepare('statement1', prepared_query)
|
||||
result = conn1.exec_prepared('statement1')
|
||||
|
||||
# assert result matches long_string
|
||||
expect(result.getvalue(0, 0)).to eq(long_string)
|
||||
ensure
|
||||
conn1.close if conn1
|
||||
end
|
||||
|
||||
it "works with large bind" do
|
||||
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
long_string = "1" * 4096 * 10
|
||||
prepared_query = "SELECT $1::text"
|
||||
|
||||
# prepare query on server 1 and client 1
|
||||
conn1.prepare('statement1', prepared_query)
|
||||
result = conn1.exec_prepared('statement1', [long_string])
|
||||
|
||||
# assert result matches long_string
|
||||
expect(result.getvalue(0, 0)).to eq(long_string)
|
||||
ensure
|
||||
conn1.close if conn1
|
||||
end
|
||||
end
|
||||
|
||||
context 'when statement cache is smaller than set of unqiue statements' do
|
||||
let(:prepared_statements_cache_size) { 1 }
|
||||
let(:pool_size) { 1 }
|
||||
|
||||
it "evicts all but 1 statement from the server cache" do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
5.times do |i|
|
||||
prepared_query = "SELECT '#{i}'"
|
||||
conn.prepare("statement#{i}", prepared_query)
|
||||
result = conn.exec_prepared("statement#{i}")
|
||||
expect(result.getvalue(0, 0)).to eq(i.to_s)
|
||||
end
|
||||
|
||||
# Check number of prepared statements (expected: 1)
|
||||
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
|
||||
expect(n_statements).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
context 'when statement cache is larger than set of unqiue statements' do
|
||||
let(:pool_size) { 1 }
|
||||
|
||||
it "does not evict any of the statements from the cache" do
|
||||
# cache size 5
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
5.times do |i|
|
||||
prepared_query = "SELECT '#{i}'"
|
||||
conn.prepare("statement#{i}", prepared_query)
|
||||
result = conn.exec_prepared("statement#{i}")
|
||||
expect(result.getvalue(0, 0)).to eq(i.to_s)
|
||||
end
|
||||
|
||||
# Check number of prepared statements (expected: 1)
|
||||
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
|
||||
expect(n_statements).to eq(5)
|
||||
end
|
||||
end
|
||||
|
||||
context 'when preparing the same query' do
|
||||
let(:prepared_statements_cache_size) { 5 }
|
||||
let(:pool_size) { 5 }
|
||||
|
||||
it "reuses statement cache when there are different statement names on the same connection" do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
10.times do |i|
|
||||
statement_name = "statement_#{i}"
|
||||
conn.prepare(statement_name, 'SELECT $1::int')
|
||||
conn.exec_prepared(statement_name, [1])
|
||||
conn.describe_prepared(statement_name)
|
||||
end
|
||||
|
||||
# Check number of prepared statements (expected: 1)
|
||||
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
|
||||
expect(n_statements).to eq(1)
|
||||
end
|
||||
|
||||
it 'will work with new connections' do
|
||||
10.times do
|
||||
it "reuses statement cache when there are different statement names on different connections" do
|
||||
10.times do |i|
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
statement_name = 'statement1'
|
||||
conn.prepare('statement1', 'SELECT $1::int')
|
||||
conn.exec_prepared('statement1', [1])
|
||||
conn.describe_prepared('statement1')
|
||||
statement_name = "statement_#{i}"
|
||||
conn.prepare(statement_name, 'SELECT $1::int')
|
||||
conn.exec_prepared(statement_name, [1])
|
||||
end
|
||||
|
||||
# Check number of prepared statements (expected: 1)
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
|
||||
expect(n_statements).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
context 'when reloading config' do
|
||||
let(:pool_size) { 1 }
|
||||
|
||||
it "test_reload_config" do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
# prepare query
|
||||
conn.prepare('statement1', 'SELECT 1')
|
||||
conn.exec_prepared('statement1')
|
||||
|
||||
# Reload config which triggers pool recreation
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size + 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
# check that we're starting with no prepared statements on the server
|
||||
conn_check = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
n_statements = conn_check.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
|
||||
expect(n_statements).to eq(0)
|
||||
|
||||
# still able to run prepared query
|
||||
conn.exec_prepared('statement1')
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
4
tests/rust/Cargo.lock
generated
4
tests/rust/Cargo.lock
generated
@@ -1206,9 +1206,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.22.0"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
|
||||
checksum = "07ecc0cd7cac091bf682ec5efa18b1cff79d617b84181f38b3951dbe135f607f"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
|
||||
Reference in New Issue
Block a user