Compare commits

...

10 Commits

Author SHA1 Message Date
Lev Kokotov
9b41cc2639 bump version 2023-10-26 10:50:06 -07:00
Zain Kabani
7d3003a16a Reimplement prepared statements with LRU cache and statement deduplication (#618)
* Initial commit

* Cleanup and add stats

* Use an arc instead of full clones to store the parse packets

* Use mutex instead

* fmt

* clippy

* fmt

* fix?

* fix?

* fmt

* typo

* Update docs

* Refactor custom protocol

* fmt

* move custom protocol handling to before parsing

* Support describe

* Add LRU for server side statement cache

* rename variable

* Refactoring

* Move docs

* Fix test

* fix

* Update tests

* trigger build

* Add more tests

* Reorder handling sync

* Support when a named describe is sent along with Parse (go pgx) and expecting results

* don't talk to client if not needed when client sends Parse

* fmt :(

* refactor tests

* nit

* Reduce hashing

* Reducing work done to decode describe and parse messages

* minor refactor

* Merge branch 'main' into zain/reimplment-prepared-statements-with-global-lru-cache

* Rewrite extended and prepared protocol message handling to better support mocking response packets and close

* An attempt to better handle if there are DDL changes that might break cached plans with ideas about how to further improve it

* fix

* Minor stats fixed and cleanup

* Cosmetic fixes (#64)

* Cosmetic fixes

* fix test

* Change server drop for statement cache error to a `deallocate all`

* Updated comments and added new idea for handling DDL changes impacting cached plans

* fix test?

* Revert test change

* trigger build, flakey test

* Avoid potential race conditions by changing get_or_insert to promote for pool LRU

* remove ps enabled variable on the server in favor of using an option

* Add close to the Extended Protocol buffer

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
2023-10-25 15:11:57 -07:00
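The commit messages above describe the core mechanics: Parse packets stored behind an `Arc` (so a cache hit clones a pointer, not the packet), an LRU for the server-side statement cache, `get_or_insert` changed to promote on hit to avoid races, and an eviction counter surfaced in stats. pgcat itself uses the `lru` crate for this; the std-only sketch below is a hypothetical illustration of that hit/promote/evict flow, not pgcat's actual code.

```rust
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Minimal LRU cache for prepared-statement packets, keyed by a hash of the
/// normalized Parse message. Values sit behind `Arc` so hits are cheap clones.
struct StatementCache {
    capacity: usize,
    map: HashMap<u64, Arc<Vec<u8>>>,
    order: VecDeque<u64>, // front = most recently used
}

impl StatementCache {
    fn new(capacity: usize) -> Self {
        Self { capacity, map: HashMap::new(), order: VecDeque::new() }
    }

    /// Promote on hit; insert (evicting the LRU entry if full) on miss.
    fn get_or_insert(&mut self, key: u64, packet: impl FnOnce() -> Vec<u8>) -> Arc<Vec<u8>> {
        if let Some(hit) = self.map.get(&key) {
            let hit = hit.clone();
            // Promote to most recently used.
            self.order.retain(|k| *k != key);
            self.order.push_front(key);
            return hit;
        }
        if self.map.len() == self.capacity {
            if let Some(evicted) = self.order.pop_back() {
                // In pgcat this is where prepare_cache_eviction would be bumped.
                self.map.remove(&evicted);
            }
        }
        let entry = Arc::new(packet());
        self.map.insert(key, entry.clone());
        self.order.push_front(key);
        entry
    }
}
```

Doing the promote inside one `get_or_insert` call under a single lock is what closes the race window the "Avoid potential race conditions" commit refers to.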
Zain Kabani
d37df43a90 Reduces the amount of time the get_pool operation takes (#625)
* Reduces the amount of time the get_pool operation takes

* trigger build

* Fix admin
2023-10-19 23:49:05 -07:00
Mohammad Dashti
2c7bf52c17 Removed unnecessary clippy overrides. (#614)
Removed unnecessary clippy overrides.
2023-10-11 10:13:23 -07:00
Mohammad Dashti
de8df29ca4 Added clippy to CI and fixed all clippy warnings (#613)
* Fixed all clippy warnings.

* Added `clippy` to CI.

* Reverted an unwanted change + Applied `cargo fmt`.

* Fixed the idiom version.

* Revert "Fixed the idiom version."

This reverts commit 6f78be0d42.

* Fixed clippy issues on CI.

* Revert "Fixed clippy issues on CI."

This reverts commit a9fa6ba189.

* Revert "Reverted an unwanted change + Applied `cargo fmt`."

This reverts commit 6bd37b6479.

* Revert "Fixed all clippy warnings."

This reverts commit d1f3b847e3.

* Removed Clippy

* Removed Lint

* `admin.rs` clippy fixes.

* Applied more clippy changes.

* Even more clippy changes.

* `client.rs` clippy fixes.

* `server.rs` clippy fixes.

* Revert "Removed Lint"

This reverts commit cb5042b144.

* Revert "Removed Clippy"

This reverts commit 6dec8bffb1.

* Applied lint.

* Revert "Revert "Fixed clippy issues on CI.""

This reverts commit 49164a733c.
2023-10-10 09:18:21 -07:00
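Two of the lint classes this PR fixes show up in the file diffs below: `clippy::useless_vec` (a read-only `vec![...]` becomes a fixed-size array) and `clippy::single_char_pattern` (`split(",")` becomes `split(',')`). These hypothetical snippets illustrate the before/after shape; they are not pgcat's actual code:

```rust
fn console_usage() -> String {
    // clippy::useless_vec: a read-only list needs no heap allocation;
    // `join` works on arrays just as well as on Vec.
    let detail_msg = ["", "SHOW HELP|CONFIG|DATABASES|POOLS", "SHUTDOWN"];
    detail_msg.join("\n\t")
}

fn split_parts(input: &str) -> Vec<&str> {
    // clippy::single_char_pattern: a char pattern skips the substring search.
    input.split(',').map(|part| part.trim()).collect()
}
```

Running `cargo clippy --all --all-targets -- -Dwarnings` in CI (added in this changeset) turns every such warning into a build failure.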
Mohammad Dashti
c4fb72b9fc Added yj to dev Dockerfile (#612) 2023-10-05 18:13:22 -07:00
Mohammad Dashti
3371c01e0e Added a Plugin trait (#536)
* Improved logging

* Improved logging for more `Address` usages

* Fixed lint issues.

* Reverted the `Address` logging changes.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.
2023-10-03 13:13:21 -07:00
Mohammad Dashti
c2a483f36a Automatic sharding for INSERT, UPDATE, and DELETE statements. (#610)
Added support for INSERT, UPDATE, and DELETE for auto-sharding.
2023-10-03 09:36:13 -07:00
dependabot[bot]
51cd13b8b5 chore(deps): bump webpki from 0.22.0 to 0.22.2 in /tests/rust (#609)
Bumps [webpki](https://github.com/briansmith/webpki) from 0.22.0 to 0.22.2.
- [Commits](https://github.com/briansmith/webpki/commits)

---
updated-dependencies:
- dependency-name: webpki
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-02 15:30:42 -07:00
Nicolas Vanelslande
a054b454d2 Add psql to the container image. (#607)
It could be used to implement container health checks.
Example:
  PGPASSWORD="<some-password>" psql -U pgcat -p 6432 -h 127.0.0.1 -tA -c "show version;" -d pgcat >/dev/null
2023-09-27 09:03:39 -07:00
30 changed files with 1931 additions and 909 deletions

@@ -63,6 +63,9 @@ jobs:
       - run:
           name: "Lint"
           command: "cargo fmt --check"
+      - run:
+          name: "Clippy"
+          command: "cargo clippy --all --all-targets -- -Dwarnings"
       - run:
           name: "Tests"
           command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh"

@@ -4,7 +4,7 @@ on:
   workflow_dispatch:
     inputs:
       packageVersion:
-        default: "1.1.2-dev"
+        default: "1.1.2-dev1"
 jobs:
   build:
     strategy:

.gitignore
@@ -10,3 +10,4 @@ lcov.info
 dev/.bash_history
 dev/cache
 !dev/cache/.keepme
+.venv

@@ -259,22 +259,6 @@ Password to be used for connecting to servers to obtain the hash used for md5 au
 specified in `auth_query_user`. The connection will be established using the database configured in the pool.
 This parameter is inherited by every pool and can be redefined in pool configuration.
-### prepared_statements
-```
-path: general.prepared_statements
-default: false
-```
-Whether to use prepared statements or not.
-### prepared_statements_cache_size
-```
-path: general.prepared_statements_cache_size
-default: 500
-```
-Size of the prepared statements cache.
 ### dns_cache_enabled
 ```
 path: general.dns_cache_enabled
@@ -324,6 +308,15 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
 `replica` round-robin between replicas only without touching the primary,
 `primary` all queries go to the primary unless otherwise specified.
+### prepared_statements_cache_size
+```
+path: general.prepared_statements_cache_size
+default: 0
+```
+Size of the prepared statements cache. 0 means disabled.
+TODO: update documentation
 ### query_parser_enabled
 ```
 path: pools.<pool_name>.query_parser_enabled

@@ -2,7 +2,7 @@
 Thank you for contributing! Just a few tips here:
-1. `cargo fmt` your code before opening up a PR
+1. `cargo fmt` and `cargo clippy` your code before opening up a PR
 2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`.
 3. Performance is important, make sure there are no regressions in your branch vs. `main`.

Cargo.lock
@@ -17,6 +17,17 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
+[[package]]
+name = "ahash"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+ "version_check",
+]
 [[package]]
 name = "aho-corasick"
 version = "1.0.2"
@@ -26,6 +37,12 @@ dependencies = [
  "memchr",
 ]
+[[package]]
+name = "allocator-api2"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
 [[package]]
 name = "android-tzdata"
 version = "0.1.1"
@@ -553,6 +570,10 @@ name = "hashbrown"
 version = "0.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
+dependencies = [
+ "ahash",
+ "allocator-api2",
+]
 [[package]]
 name = "heck"
@@ -821,6 +842,15 @@ version = "0.4.19"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
+[[package]]
+name = "lru"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
+dependencies = [
+ "hashbrown 0.14.0",
+]
 [[package]]
 name = "lru-cache"
 version = "0.1.2"
@@ -990,7 +1020,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
 [[package]]
 name = "pgcat"
-version = "1.1.2-dev"
+version = "1.1.2-dev1"
 dependencies = [
  "arc-swap",
  "async-trait",
@@ -1008,6 +1038,7 @@ dependencies = [
  "itertools",
  "jemallocator",
  "log",
+ "lru",
  "md-5",
  "nix",
  "num_cpus",

@@ -1,6 +1,6 @@
 [package]
 name = "pgcat"
-version = "1.1.2-dev"
+version = "1.1.2-dev1"
 edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -48,6 +48,7 @@ itertools = "0.10"
 clap = { version = "4.3.1", features = ["derive", "env"] }
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
+lru = "0.12.0"
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 jemallocator = "0.5.0"

@@ -8,6 +8,12 @@ WORKDIR /app
 RUN cargo build --release
 FROM debian:bookworm-slim
+RUN apt-get update && apt-get install -o Dpkg::Options::=--force-confdef -yq --no-install-recommends \
+    postgresql-client \
+    # Clean up layer
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
+    && truncate -s 0 /var/log/*log
 COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
 COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
 WORKDIR /etc/pgcat

@@ -1,6 +1,8 @@
 FROM rust:1.70-bullseye
 # Dependencies
+COPY --from=sclevine/yj /bin/yj /bin/yj
+RUN /bin/yj -h
 RUN apt-get update -y \
     && apt-get install -y \
     llvm-11 psmisc postgresql-contrib postgresql-client \

@@ -60,12 +60,6 @@ tcp_keepalives_count = 5
 # Number of seconds between keepalive packets.
 tcp_keepalives_interval = 5
-# Handle prepared statements.
-prepared_statements = true
-# Prepared statements server cache size.
-prepared_statements_cache_size = 500
 # Path to TLS Certificate file to use for TLS connections
 # tls_certificate = ".circleci/server.cert"
 # Path to TLS private key file to use for TLS connections
@@ -156,6 +150,10 @@ load_balancing_mode = "random"
 # `primary` all queries go to the primary unless otherwise specified.
 default_role = "any"
+# Prepared statements cache size.
+# TODO: update documentation
+prepared_statements_cache_size = 500
 # If Query Parser is enabled, we'll attempt to parse
 # every incoming query to determine if it's a read or a write.
 # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,

@@ -283,7 +283,7 @@ where
 {
     let mut res = BytesMut::new();
-    let detail_msg = vec![
+    let detail_msg = [
         "",
         "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
         // "SHOW PEERS|PEER_POOLS", // missing
@@ -301,7 +301,6 @@ where
         // "KILL <db>",
         // "SUSPEND",
         "SHUTDOWN",
-        // "WAIT_CLOSE [<db>]", // missing
     ];
     res.put(notify("Console usage", detail_msg.join("\n\t")));
@@ -745,6 +744,7 @@ where
         ("age_seconds", DataType::Numeric),
         ("prepare_cache_hit", DataType::Numeric),
         ("prepare_cache_miss", DataType::Numeric),
+        ("prepare_cache_eviction", DataType::Numeric),
         ("prepare_cache_size", DataType::Numeric),
     ];
@@ -777,6 +777,10 @@ where
             .prepared_miss_count
             .load(Ordering::Relaxed)
             .to_string(),
+        server
+            .prepared_eviction_count
+            .load(Ordering::Relaxed)
+            .to_string(),
         server
             .prepared_cache_size
             .load(Ordering::Relaxed)
@@ -802,7 +806,7 @@ where
     T: tokio::io::AsyncWrite + std::marker::Unpin,
 {
     let parts: Vec<&str> = match tokens.len() == 2 {
-        true => tokens[1].split(",").map(|part| part.trim()).collect(),
+        true => tokens[1].split(',').map(|part| part.trim()).collect(),
         false => Vec::new(),
     };
@@ -865,7 +869,7 @@ where
     T: tokio::io::AsyncWrite + std::marker::Unpin,
 {
     let parts: Vec<&str> = match tokens.len() == 2 {
-        true => tokens[1].split(",").map(|part| part.trim()).collect(),
+        true => tokens[1].split(',').map(|part| part.trim()).collect(),
         false => Vec::new(),
     };

File diff suppressed because it is too large.

@@ -25,7 +25,7 @@ pub struct Args {
 }
 pub fn parse() -> Args {
-    return Args::parse();
+    Args::parse()
 }
 #[derive(ValueEnum, Clone, Debug)]

@@ -116,10 +116,10 @@ impl Default for Address {
             host: String::from("127.0.0.1"),
             port: 5432,
             shard: 0,
+            address_index: 0,
+            replica_number: 0,
             database: String::from("database"),
             role: Role::Replica,
-            replica_number: 0,
-            address_index: 0,
             username: String::from("username"),
             pool_name: String::from("pool_name"),
             mirrors: Vec::new(),
@@ -236,18 +236,14 @@ impl Default for User {
 impl User {
     fn validate(&self) -> Result<(), Error> {
-        match self.min_pool_size {
-            Some(min_pool_size) => {
-                if min_pool_size > self.pool_size {
-                    error!(
-                        "min_pool_size of {} cannot be larger than pool_size of {}",
-                        min_pool_size, self.pool_size
-                    );
-                    return Err(Error::BadConfig);
-                }
-            }
-            None => (),
-        };
+        if let Some(min_pool_size) = self.min_pool_size {
+            if min_pool_size > self.pool_size {
+                error!(
+                    "min_pool_size of {} cannot be larger than pool_size of {}",
+                    min_pool_size, self.pool_size
+                );
+                return Err(Error::BadConfig);
+            }
+        };
         Ok(())
@@ -341,12 +337,6 @@ pub struct General {
     pub auth_query: Option<String>,
     pub auth_query_user: Option<String>,
     pub auth_query_password: Option<String>,
-    #[serde(default)]
-    pub prepared_statements: bool,
-    #[serde(default = "General::default_prepared_statements_cache_size")]
-    pub prepared_statements_cache_size: usize,
 }
 impl General {
@@ -428,10 +418,6 @@ impl General {
     pub fn default_server_round_robin() -> bool {
         true
     }
-    pub fn default_prepared_statements_cache_size() -> usize {
-        500
-    }
 }
 impl Default for General {
@@ -443,35 +429,33 @@ impl Default for General {
             prometheus_exporter_port: 9930,
             connect_timeout: General::default_connect_timeout(),
             idle_timeout: General::default_idle_timeout(),
+            shutdown_timeout: Self::default_shutdown_timeout(),
+            healthcheck_timeout: Self::default_healthcheck_timeout(),
+            healthcheck_delay: Self::default_healthcheck_delay(),
+            ban_time: Self::default_ban_time(),
+            worker_threads: Self::default_worker_threads(),
+            idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
             tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
             tcp_keepalives_count: Self::default_tcp_keepalives_count(),
             tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
             tcp_user_timeout: Self::default_tcp_user_timeout(),
             log_client_connections: false,
             log_client_disconnections: false,
+            autoreload: None,
             dns_cache_enabled: false,
             dns_max_ttl: Self::default_dns_max_ttl(),
-            shutdown_timeout: Self::default_shutdown_timeout(),
-            healthcheck_timeout: Self::default_healthcheck_timeout(),
-            healthcheck_delay: Self::default_healthcheck_delay(),
-            ban_time: Self::default_ban_time(),
-            idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
-            server_lifetime: Self::default_server_lifetime(),
-            server_round_robin: Self::default_server_round_robin(),
-            worker_threads: Self::default_worker_threads(),
-            autoreload: None,
             tls_certificate: None,
             tls_private_key: None,
             server_tls: false,
             verify_server_certificate: false,
             admin_username: String::from("admin"),
             admin_password: String::from("admin"),
+            validate_config: true,
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
-            validate_config: true,
-            prepared_statements: false,
-            prepared_statements_cache_size: 500,
+            server_lifetime: Self::default_server_lifetime(),
+            server_round_robin: Self::default_server_round_robin(),
         }
     }
 }
@@ -572,6 +556,9 @@ pub struct Pool {
     #[serde(default)] // False
     pub log_client_parameter_status_changes: bool,
+    #[serde(default = "Pool::default_prepared_statements_cache_size")]
+    pub prepared_statements_cache_size: usize,
     pub plugins: Option<Plugins>,
     pub shards: BTreeMap<String, Shard>,
     pub users: BTreeMap<String, User>,
@@ -621,6 +608,10 @@ impl Pool {
         true
     }
+    pub fn default_prepared_statements_cache_size() -> usize {
+        0
+    }
     pub fn validate(&mut self) -> Result<(), Error> {
         match self.default_role.as_ref() {
             "any" => (),
@@ -677,9 +668,9 @@ impl Pool {
             Some(key) => {
                 // No quotes in the key so we don't have to compare quoted
                 // to unquoted idents.
-                let key = key.replace("\"", "");
-                if key.split(".").count() != 2 {
+                let key = key.replace('\"', "");
+                if key.split('.').count() != 2 {
                     error!(
                         "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
                         key, key
@@ -692,17 +683,14 @@
             None => None,
         };
-        match self.default_shard {
-            DefaultShard::Shard(shard_number) => {
-                if shard_number >= self.shards.len() {
-                    error!("Invalid shard {:?}", shard_number);
-                    return Err(Error::BadConfig);
-                }
-            }
-            _ => (),
-        }
+        if let DefaultShard::Shard(shard_number) = self.default_shard {
+            if shard_number >= self.shards.len() {
+                error!("Invalid shard {:?}", shard_number);
+                return Err(Error::BadConfig);
+            }
+        }
-        for (_, user) in &self.users {
+        for user in self.users.values() {
             user.validate()?;
         }
@@ -715,17 +703,16 @@ impl Default for Pool {
         Pool {
             pool_mode: Self::default_pool_mode(),
             load_balancing_mode: Self::default_load_balancing_mode(),
-            shards: BTreeMap::from([(String::from("1"), Shard::default())]),
-            users: BTreeMap::default(),
             default_role: String::from("any"),
             query_parser_enabled: false,
             query_parser_max_length: None,
             query_parser_read_write_splitting: false,
             primary_reads_enabled: false,
-            sharding_function: ShardingFunction::PgBigintHash,
-            automatic_sharding_key: None,
             connect_timeout: None,
             idle_timeout: None,
+            server_lifetime: None,
+            sharding_function: ShardingFunction::PgBigintHash,
+            automatic_sharding_key: None,
             sharding_key_regex: None,
             shard_id_regex: None,
             regex_search_limit: Some(1000),
@@ -733,10 +720,12 @@
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
-            server_lifetime: None,
-            plugins: None,
             cleanup_server_connections: true,
             log_client_parameter_status_changes: false,
+            prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
+            plugins: None,
+            shards: BTreeMap::from([(String::from("1"), Shard::default())]),
+            users: BTreeMap::default(),
         }
     }
 }
@@ -777,8 +766,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard {
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let s = String::deserialize(deserializer)?; let s = String::deserialize(deserializer)?;
if s.starts_with("shard_") { if let Some(s) = s.strip_prefix("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?; let shard = s.parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(DefaultShard::Shard(shard)); return Ok(DefaultShard::Shard(shard));
} }
@@ -848,13 +837,13 @@ impl Shard {
impl Default for Shard { impl Default for Shard {
fn default() -> Shard { fn default() -> Shard {
Shard { Shard {
database: String::from("postgres"),
mirrors: None,
servers: vec![ServerConfig { servers: vec![ServerConfig {
host: String::from("localhost"), host: String::from("localhost"),
port: 5432, port: 5432,
role: Role::Primary, role: Role::Primary,
}], }],
mirrors: None,
database: String::from("postgres"),
} }
} }
} }
@@ -867,15 +856,26 @@ pub struct Plugins {
pub prewarmer: Option<Prewarmer>, pub prewarmer: Option<Prewarmer>,
} }
pub trait Plugin {
fn is_enabled(&self) -> bool;
}
impl std::fmt::Display for Plugins { impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool {
if let Some(arg) = arg {
arg.is_enabled()
} else {
false
}
}
write!( write!(
f, f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}", "interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
self.intercept.is_some(), is_enabled(self.intercept.as_ref()),
self.table_access.is_some(), is_enabled(self.table_access.as_ref()),
self.query_logger.is_some(), is_enabled(self.query_logger.as_ref()),
self.prewarmer.is_some(), is_enabled(self.prewarmer.as_ref()),
) )
} }
} }
@@ -886,23 +886,47 @@ pub struct Intercept {
pub queries: BTreeMap<String, Query>, pub queries: BTreeMap<String, Query>,
} }
impl Plugin for Intercept {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct TableAccess { pub struct TableAccess {
pub enabled: bool, pub enabled: bool,
pub tables: Vec<String>, pub tables: Vec<String>,
} }
impl Plugin for TableAccess {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct QueryLogger { pub struct QueryLogger {
pub enabled: bool, pub enabled: bool,
} }
impl Plugin for QueryLogger {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer { pub struct Prewarmer {
pub enabled: bool, pub enabled: bool,
pub queries: Vec<String>, pub queries: Vec<String>,
} }
impl Plugin for Prewarmer {
fn is_enabled(&self) -> bool {
self.enabled
}
}
impl Intercept { impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) { pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() { for (_, query) in self.queries.iter_mut() {
@@ -920,6 +944,7 @@ pub struct Query {
 }
 impl Query {
+    #[allow(clippy::needless_range_loop)]
     pub fn substitute(&mut self, db: &str, user: &str) {
         for col in self.result.iter_mut() {
             for i in 0..col.len() {
@@ -989,8 +1014,8 @@ impl Default for Config {
         Config {
             path: Self::default_path(),
             general: General::default(),
-            pools: HashMap::default(),
             plugins: None,
+            pools: HashMap::default(),
         }
     }
 }
@@ -1044,8 +1069,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
             (
                 format!("pools.{:?}.users", pool_name),
                 pool.users
-                    .iter()
-                    .map(|(_username, user)| &user.username)
+                    .values()
+                    .map(|user| &user.username)
                     .cloned()
                     .collect::<Vec<String>>()
                     .join(", "),
@@ -1099,6 +1124,7 @@ impl Config {
     /// Print current configuration.
     pub fn show(&self) {
+        info!("Config path: {}", self.path);
         info!("Ban time: {}s", self.general.ban_time);
         info!(
             "Idle client in transaction timeout: {}ms",
@@ -1130,13 +1156,9 @@
             Some(tls_certificate) => {
                 info!("TLS certificate: {}", tls_certificate);
-                match self.general.tls_private_key.clone() {
-                    Some(tls_private_key) => {
-                        info!("TLS private key: {}", tls_private_key);
-                        info!("TLS support is enabled");
-                    }
-                    None => (),
+                if let Some(tls_private_key) = self.general.tls_private_key.clone() {
+                    info!("TLS private key: {}", tls_private_key);
+                    info!("TLS support is enabled");
                 }
             }
@@ -1149,13 +1171,6 @@
             "Server TLS certificate verification: {}",
             self.general.verify_server_certificate
         );
-        info!("Prepared statements: {}", self.general.prepared_statements);
-        if self.general.prepared_statements {
-            info!(
-                "Prepared statements server cache size: {}",
-                self.general.prepared_statements_cache_size
-            );
-        }
         info!(
             "Plugins: {}",
             match self.plugins {
@@ -1171,8 +1186,8 @@
                 pool_name,
                 pool_config
                     .users
-                    .iter()
-                    .map(|(_, user_cfg)| user_cfg.pool_size)
+                    .values()
+                    .map(|user_cfg| user_cfg.pool_size)
                     .sum::<u32>()
                     .to_string()
             );
@@ -1246,6 +1261,10 @@
                 "[pool: {}] Log client parameter status changes: {}",
                 pool_name, pool_config.log_client_parameter_status_changes
             );
+            info!(
+                "[pool: {}] Prepared statements server cache size: {}",
+                pool_name, pool_config.prepared_statements_cache_size
+            );
             info!(
                 "[pool: {}] Plugins: {}",
                 pool_name,
@@ -1342,34 +1361,31 @@
         }
         // Validate TLS!
-        match self.general.tls_certificate.clone() {
-            Some(tls_certificate) => {
-                match load_certs(Path::new(&tls_certificate)) {
-                    Ok(_) => {
-                        // Cert is okay, but what about the private key?
-                        match self.general.tls_private_key.clone() {
-                            Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
-                                Ok(_) => (),
-                                Err(err) => {
-                                    error!("tls_private_key is incorrectly configured: {:?}", err);
-                                    return Err(Error::BadConfig);
-                                }
-                            },
-                            None => {
-                                error!("tls_certificate is set, but the tls_private_key is not");
-                                return Err(Error::BadConfig);
-                            }
-                        };
-                    }
-                    Err(err) => {
-                        error!("tls_certificate is incorrectly configured: {:?}", err);
-                        return Err(Error::BadConfig);
-                    }
-                };
-            }
-            None => (),
-        };
+        if let Some(tls_certificate) = self.general.tls_certificate.clone() {
+            match load_certs(Path::new(&tls_certificate)) {
+                Ok(_) => {
+                    // Cert is okay, but what about the private key?
+                    match self.general.tls_private_key.clone() {
+                        Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
+                            Ok(_) => (),
+                            Err(err) => {
+                                error!("tls_private_key is incorrectly configured: {:?}", err);
+                                return Err(Error::BadConfig);
+                            }
+                        },
+                        None => {
+                            error!("tls_certificate is set, but the tls_private_key is not");
+                            return Err(Error::BadConfig);
+                        }
+                    };
+                }
+                Err(err) => {
+                    error!("tls_certificate is incorrectly configured: {:?}", err);
+                    return Err(Error::BadConfig);
+                }
+            }
+        };
         for pool in self.pools.values_mut() {
@@ -1391,14 +1407,6 @@ pub fn get_idle_client_in_transaction_timeout() -> u64 {
CONFIG.load().general.idle_client_in_transaction_timeout CONFIG.load().general.idle_client_in_transaction_timeout
} }
pub fn get_prepared_statements() -> bool {
CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}
/// Parse the configuration file located at the path. /// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> { pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new(); let mut contents = String::new();


@@ -12,13 +12,16 @@ use crate::config::get_config;
 use crate::errors::Error;
 use crate::constants::MESSAGE_TERMINATOR;
+use std::collections::hash_map::DefaultHasher;
 use std::collections::HashMap;
 use std::ffi::CString;
 use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
 use std::io::{BufRead, Cursor};
 use std::mem;
 use std::str::FromStr;
 use std::sync::atomic::Ordering;
+use std::sync::Arc;
 use std::time::Duration;
 /// Postgres data type mappings
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
 }
 /// Tell the client we're ready for another query.
-pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
+pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
 where
     S: tokio::io::AsyncWrite + std::marker::Unpin,
 {
-    let mut bytes = BytesMut::with_capacity(
-        mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
-    );
-    bytes.put_u8(b'Z');
-    bytes.put_i32(5);
-    bytes.put_u8(b'I'); // Idle
-    write_all(stream, bytes).await
+    write_all(stream, ready_for_query(false)).await
 }
 /// Send the startup packet the server. We're pretending we're a Pg client.
@@ -163,12 +158,10 @@ where
     match stream.write_all(&startup).await {
         Ok(_) => Ok(()),
-        Err(err) => {
-            return Err(Error::SocketError(format!(
-                "Error writing startup to server socket - Error: {:?}",
-                err
-            )))
-        }
+        Err(err) => Err(Error::SocketError(format!(
+            "Error writing startup to server socket - Error: {:?}",
+            err
+        ))),
     }
 }
@@ -244,8 +237,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
     let mut md5 = Md5::new();
     // First pass
-    md5.update(&password.as_bytes());
-    md5.update(&user.as_bytes());
+    md5.update(password.as_bytes());
+    md5.update(user.as_bytes());
     let output = md5.finalize_reset();
@@ -281,7 +274,7 @@ where
 {
     let password = md5_hash_password(user, password, salt);
-    let mut message = BytesMut::with_capacity(password.len() as usize + 5);
+    let mut message = BytesMut::with_capacity(password.len() + 5);
     message.put_u8(b'p');
     message.put_i32(password.len() as i32 + 4);
@@ -295,7 +288,7 @@ where
     S: tokio::io::AsyncWrite + std::marker::Unpin,
 {
     let password = md5_hash_second_pass(hash, salt);
-    let mut message = BytesMut::with_capacity(password.len() as usize + 5);
+    let mut message = BytesMut::with_capacity(password.len() + 5);
     message.put_u8(b'p');
     message.put_i32(password.len() as i32 + 4);
@@ -322,7 +315,7 @@ where
     res.put_slice(&set_complete[..]);
     write_all_half(stream, &res).await?;
-    ready_for_query(stream).await
+    send_ready_for_query(stream).await
 }
 /// Send a custom error message to the client.
@@ -333,7 +326,7 @@ where
     S: tokio::io::AsyncWrite + std::marker::Unpin,
 {
     error_response_terminal(stream, message).await?;
-    ready_for_query(stream).await
+    send_ready_for_query(stream).await
 }
 /// Send a custom error message to the client.
@@ -434,7 +427,7 @@ where
     res.put(command_complete("SELECT 1"));
     write_all_half(stream, &res).await?;
-    ready_for_query(stream).await
+    send_ready_for_query(stream).await
 }
 pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
@@ -516,7 +509,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
             data_row.put_i32(column.len() as i32);
             data_row.put_slice(column);
         } else {
-            data_row.put_i32(-1 as i32);
+            data_row.put_i32(-1_i32);
         }
     }
@@ -564,6 +557,37 @@ pub fn flush() -> BytesMut {
     bytes
 }
+pub fn sync() -> BytesMut {
+    let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
+    bytes.put_u8(b'S');
+    bytes.put_i32(4);
+    bytes
+}
+
+pub fn parse_complete() -> BytesMut {
+    let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
+    bytes.put_u8(b'1');
+    bytes.put_i32(4);
+    bytes
+}
+
+pub fn ready_for_query(in_transaction: bool) -> BytesMut {
+    let mut bytes = BytesMut::with_capacity(
+        mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
+    );
+    bytes.put_u8(b'Z');
+    bytes.put_i32(5);
+    if in_transaction {
+        bytes.put_u8(b'T');
+    } else {
+        bytes.put_u8(b'I');
+    }
+    bytes
+}
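The new `ready_for_query(bool)` helper above builds the 6-byte ReadyForQuery (`'Z'`) message. A minimal sketch of the same wire layout using only `Vec<u8>` (the real code uses `BytesMut` from the `bytes` crate):

```rust
// Sketch of the ReadyForQuery ('Z') message layout with Vec<u8>:
// tag byte, big-endian i32 length (5, which includes the length field
// itself), then a status byte: 'I' (idle) or 'T' (in transaction).
fn ready_for_query(in_transaction: bool) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(6);
    bytes.push(b'Z');
    bytes.extend_from_slice(&5i32.to_be_bytes()); // length, big-endian
    bytes.push(if in_transaction { b'T' } else { b'I' });
    bytes
}

fn main() {
    assert_eq!(ready_for_query(false), vec![b'Z', 0, 0, 0, 5, b'I']);
    assert_eq!(ready_for_query(true), vec![b'Z', 0, 0, 0, 5, b'T']);
    println!("ok");
}
```

Passing `in_transaction` through one builder is what lets the refactor collapse the old hand-rolled `ready_for_query<S>` writer into `write_all(stream, ready_for_query(false))`.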
 /// Write all data in the buffer to the TcpStream.
 pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
 where
@@ -571,12 +595,10 @@ where
 {
     match stream.write_all(&buf).await {
         Ok(_) => Ok(()),
-        Err(err) => {
-            return Err(Error::SocketError(format!(
-                "Error writing to socket - Error: {:?}",
-                err
-            )))
-        }
+        Err(err) => Err(Error::SocketError(format!(
+            "Error writing to socket - Error: {:?}",
+            err
+        ))),
     }
 }
@@ -587,12 +609,10 @@ where
 {
     match stream.write_all(buf).await {
         Ok(_) => Ok(()),
-        Err(err) => {
-            return Err(Error::SocketError(format!(
-                "Error writing to socket - Error: {:?}",
-                err
-            )))
-        }
+        Err(err) => Err(Error::SocketError(format!(
+            "Error writing to socket - Error: {:?}",
+            err
+        ))),
     }
 }
@@ -603,19 +623,15 @@ where
     match stream.write_all(buf).await {
         Ok(_) => match stream.flush().await {
             Ok(_) => Ok(()),
-            Err(err) => {
-                return Err(Error::SocketError(format!(
-                    "Error flushing socket - Error: {:?}",
-                    err
-                )))
-            }
-        },
-        Err(err) => {
-            return Err(Error::SocketError(format!(
-                "Error writing to socket - Error: {:?}",
-                err
-            )))
-        }
+            Err(err) => Err(Error::SocketError(format!(
+                "Error flushing socket - Error: {:?}",
+                err
+            ))),
+        },
+        Err(err) => Err(Error::SocketError(format!(
+            "Error writing to socket - Error: {:?}",
+            err
+        ))),
     }
 }
@@ -730,7 +746,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
         let mut buf = vec![];
         match self.read_until(b'\0', &mut buf) {
             Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
-            Err(err) => return Err(Error::ParseBytesError(err.to_string())),
+            Err(err) => Err(Error::ParseBytesError(err.to_string())),
         }
     }
 }
@@ -746,10 +762,55 @@ impl BytesMutReader for BytesMut {
                 let string_bytes = self.split_to(index + 1);
                 Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
             }
-            None => return Err(Error::ParseBytesError("Could not read string".to_string())),
+            None => Err(Error::ParseBytesError("Could not read string".to_string())),
         }
     }
 }
+
+pub enum ExtendedProtocolData {
+    Parse {
+        data: BytesMut,
+        metadata: Option<(Arc<Parse>, u64)>,
+    },
+    Bind {
+        data: BytesMut,
+        metadata: Option<String>,
+    },
+    Describe {
+        data: BytesMut,
+        metadata: Option<String>,
+    },
+    Execute {
+        data: BytesMut,
+    },
+    Close {
+        data: BytesMut,
+        close: Close,
+    },
+}
+
+impl ExtendedProtocolData {
+    pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
+        Self::Parse { data, metadata }
+    }
+
+    pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
+        Self::Bind { data, metadata }
+    }
+
+    pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
+        Self::Describe { data, metadata }
+    }
+
+    pub fn create_new_execute(data: BytesMut) -> Self {
+        Self::Execute { data }
+    }
+
+    pub fn create_new_close(data: BytesMut, close: Close) -> Self {
+        Self::Close { data, close }
+    }
+}
 /// Parse (F) message.
 /// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
 #[derive(Clone, Debug)]
@@ -758,7 +819,6 @@ pub struct Parse {
     #[allow(dead_code)]
     len: i32,
     pub name: String,
-    pub generated_name: String,
     query: String,
     num_params: i16,
     param_types: Vec<i32>,
@@ -784,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
             code,
             len,
             name,
-            generated_name: prepared_statement_name(),
             query,
             num_params,
             param_types,
@@ -833,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
 }
 impl Parse {
-    pub fn rename(mut self) -> Self {
-        self.name = self.generated_name.to_string();
+    /// Renames the prepared statement to a new name based on the global counter
+    pub fn rewrite(mut self) -> Self {
+        self.name = format!(
+            "PGCAT_{}",
+            PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
+        );
         self
     }
+
+    /// Gets the name of the prepared statement from the buffer
+    pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
+        let mut cursor = Cursor::new(buf);
+        // Skip the code and length
+        cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
+        cursor.read_string()
+    }
+
+    /// Hashes the parse statement to be used as a key in the global cache
+    pub fn get_hash(&self) -> u64 {
+        // TODO_ZAIN: Take a look at which hashing function is being used
+        let mut hasher = DefaultHasher::new();
+        let concatenated = format!(
+            "{}{}{}",
+            self.query,
+            self.num_params,
+            self.param_types
+                .iter()
+                .map(ToString::to_string)
+                .collect::<Vec<_>>()
+                .join(",")
+        );
+        concatenated.hash(&mut hasher);
+        hasher.finish()
+    }
+
     pub fn anonymous(&self) -> bool {
         self.name.is_empty()
     }
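`Parse::get_hash` above is the key to statement deduplication: the cache is keyed by what the statement *is* (query text plus parameter shape), not by the client-chosen name, so identical statements from different clients share one server-side entry. A simplified, std-only sketch (this `Parse` struct is a hypothetical subset of the real one):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Hypothetical, simplified Parse: the real struct carries more fields
// (code, len, name, etc.).
struct Parse {
    query: String,
    num_params: i16,
    param_types: Vec<i32>,
}

impl Parse {
    // Mirrors the idea in get_hash(): hash the statement's identity,
    // not its client-chosen name.
    fn get_hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        let concatenated = format!(
            "{}{}{}",
            self.query,
            self.num_params,
            self.param_types
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(",")
        );
        concatenated.hash(&mut hasher);
        hasher.finish()
    }
}

fn main() {
    let a = Parse { query: "SELECT $1".into(), num_params: 1, param_types: vec![23] };
    let b = Parse { query: "SELECT $1".into(), num_params: 1, param_types: vec![23] };
    let c = Parse { query: "SELECT $1".into(), num_params: 1, param_types: vec![25] };
    assert_eq!(a.get_hash(), b.get_hash()); // same statement, same cache key
    assert_ne!(a.get_hash(), c.get_hash()); // different param types, different key
    println!("ok");
}
```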
@@ -968,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
 }
 impl Bind {
-    pub fn reassign(mut self, parse: &Parse) -> Self {
-        self.prepared_statement = parse.name.clone();
-        self
+    /// Gets the name of the prepared statement from the buffer
+    pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
+        let mut cursor = Cursor::new(buf);
+        // Skip the code and length
+        cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
+        cursor.read_string()?;
+        cursor.read_string()
+    }
+
+    /// Renames the prepared statement to a new name
+    pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
+        let mut cursor = Cursor::new(&buf);
+        // Read basic data from the cursor
+        let code = cursor.get_u8();
+        let current_len = cursor.get_i32();
+        let portal = cursor.read_string()?;
+        let prepared_statement = cursor.read_string()?;
+
+        // Calculate new length
+        let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
+
+        // Begin building the response buffer
+        let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
+        response_buf.put_u8(code);
+        response_buf.put_i32(new_len);
+
+        // Put the portal and new name into the buffer
+        // Note: fails if the provided string contains a null byte
+        response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
+        response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
+
+        // Add the remainder of the original buffer into the response
+        response_buf.put_slice(&buf[cursor.position() as usize..]);
+
+        // Return the buffer
+        Ok(response_buf)
     }
+
     pub fn anonymous(&self) -> bool {
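In `Bind::rename` above, the only field whose size changes is the statement name, so the new message length is the old length adjusted by the name-size delta (NUL terminators cancel out). A tiny sketch of just that arithmetic, with plain strings standing in for the real `BytesMut` fields:

```rust
// Sketch: renaming the prepared-statement field inside a Bind ('B')
// message shifts the i32 length field by exactly the difference
// between the new and old name lengths.
fn renamed_len(current_len: i32, old_name: &str, new_name: &str) -> i32 {
    current_len + new_name.len() as i32 - old_name.len() as i32
}

fn main() {
    // e.g. client named it "stmt_1" (6 bytes), pooler renames to "PGCAT_42" (8 bytes)
    assert_eq!(renamed_len(30, "stmt_1", "PGCAT_42"), 32);
    // renaming to a shorter name shrinks the message
    assert_eq!(renamed_len(30, "stmt_1", "s"), 25);
    println!("ok");
}
```

This is why the rewrite can splice the rest of the original buffer in verbatim after the two renamed strings: nothing downstream of the name moves relative to itself.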
@@ -1026,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
 }
 impl Describe {
+    pub fn empty_new() -> Describe {
+        Describe {
+            code: 'D',
+            len: 4 + 1 + 1,
+            target: 'S',
+            statement_name: "".to_string(),
+        }
+    }
+
     pub fn rename(mut self, name: &str) -> Self {
         self.statement_name = name.to_string();
         self
@@ -1114,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
     bytes
 }
-
-pub fn prepared_statement_name() -> String {
-    format!(
-        "P_{}",
-        PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
-    )
-}
 // from https://www.postgresql.org/docs/12/protocol-error-fields.html
 #[derive(Debug, Default, PartialEq)]
 pub struct PgErrorMsg {
@@ -1203,7 +1330,7 @@ impl Display for PgErrorMsg {
 }
 impl PgErrorMsg {
-    pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
+    pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
         let mut out = PgErrorMsg {
             severity_localized: "".to_string(),
             severity: "".to_string(),
@@ -1311,38 +1438,38 @@ mod tests {
     fn parse_fields() {
         let mut complete_msg = vec![];
         let severity = "FATAL";
-        complete_msg.extend(field('S', &severity));
-        complete_msg.extend(field('V', &severity));
+        complete_msg.extend(field('S', severity));
+        complete_msg.extend(field('V', severity));
         let error_code = "29P02";
-        complete_msg.extend(field('C', &error_code));
+        complete_msg.extend(field('C', error_code));
         let message = "password authentication failed for user \"wrong_user\"";
-        complete_msg.extend(field('M', &message));
+        complete_msg.extend(field('M', message));
         let detail_msg = "super detailed message";
-        complete_msg.extend(field('D', &detail_msg));
+        complete_msg.extend(field('D', detail_msg));
         let hint_msg = "hint detail here";
-        complete_msg.extend(field('H', &hint_msg));
+        complete_msg.extend(field('H', hint_msg));
         complete_msg.extend(field('P', "123"));
         complete_msg.extend(field('p', "234"));
         let internal_query = "SELECT * from foo;";
-        complete_msg.extend(field('q', &internal_query));
+        complete_msg.extend(field('q', internal_query));
         let where_msg = "where goes here";
-        complete_msg.extend(field('W', &where_msg));
+        complete_msg.extend(field('W', where_msg));
         let schema_msg = "schema_name";
-        complete_msg.extend(field('s', &schema_msg));
+        complete_msg.extend(field('s', schema_msg));
         let table_msg = "table_name";
-        complete_msg.extend(field('t', &table_msg));
+        complete_msg.extend(field('t', table_msg));
         let column_msg = "column_name";
-        complete_msg.extend(field('c', &column_msg));
+        complete_msg.extend(field('c', column_msg));
         let data_type_msg = "type_name";
-        complete_msg.extend(field('d', &data_type_msg));
+        complete_msg.extend(field('d', data_type_msg));
         let constraint_msg = "constraint_name";
-        complete_msg.extend(field('n', &constraint_msg));
+        complete_msg.extend(field('n', constraint_msg));
         let file_msg = "pgcat.c";
-        complete_msg.extend(field('F', &file_msg));
+        complete_msg.extend(field('F', file_msg));
         complete_msg.extend(field('L', "335"));
         let routine_msg = "my_failing_routine";
-        complete_msg.extend(field('R', &routine_msg));
+        complete_msg.extend(field('R', routine_msg));
         tracing_subscriber::fmt()
             .with_max_level(tracing::Level::INFO)
@@ -1351,7 +1478,7 @@ mod tests {
         info!(
             "full message: {}",
-            PgErrorMsg::parse(complete_msg.clone()).unwrap()
+            PgErrorMsg::parse(&complete_msg).unwrap()
         );
         assert_eq!(
             PgErrorMsg {
@@ -1374,17 +1501,17 @@ mod tests {
                 line: Some(335),
                 routine: Some(routine_msg.to_string()),
             },
-            PgErrorMsg::parse(complete_msg).unwrap()
+            PgErrorMsg::parse(&complete_msg).unwrap()
         );
         let mut only_mandatory_msg = vec![];
-        only_mandatory_msg.extend(field('S', &severity));
-        only_mandatory_msg.extend(field('V', &severity));
-        only_mandatory_msg.extend(field('C', &error_code));
-        only_mandatory_msg.extend(field('M', &message));
-        only_mandatory_msg.extend(field('D', &detail_msg));
+        only_mandatory_msg.extend(field('S', severity));
+        only_mandatory_msg.extend(field('V', severity));
+        only_mandatory_msg.extend(field('C', error_code));
+        only_mandatory_msg.extend(field('M', message));
+        only_mandatory_msg.extend(field('D', detail_msg));
-        let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
+        let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
         info!("only mandatory fields: {}", &err_fields);
         error!(
             "server error: {}: {}",
@@ -1411,7 +1538,7 @@ mod tests {
                 line: None,
                 routine: None,
             },
-            PgErrorMsg::parse(only_mandatory_msg).unwrap()
+            PgErrorMsg::parse(&only_mandatory_msg).unwrap()
         );
     }
 }


@@ -23,14 +23,15 @@ impl MirroredClient {
     async fn create_pool(&self) -> Pool<ServerPool> {
         let config = get_config();
         let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
-        let (connection_timeout, idle_timeout, _cfg) =
+        let (connection_timeout, idle_timeout, _cfg, prepared_statement_cache_size) =
             match config.pools.get(&self.address.pool_name) {
                 Some(cfg) => (
                     cfg.connect_timeout.unwrap_or(default),
                     cfg.idle_timeout.unwrap_or(default),
                     cfg.clone(),
+                    cfg.prepared_statements_cache_size,
                 ),
-                None => (default, default, crate::config::Pool::default()),
+                None => (default, default, crate::config::Pool::default(), 0),
             };
         let manager = ServerPool::new(
@@ -42,6 +43,7 @@ impl MirroredClient {
             None,
             true,
             false,
+            prepared_statement_cache_size,
         );
         Pool::builder()
@@ -137,18 +139,18 @@ impl MirroringManager {
                 bytes_rx,
                 disconnect_rx: exit_rx,
             };
-            exit_senders.push(exit_tx.clone());
-            byte_senders.push(bytes_tx.clone());
+            exit_senders.push(exit_tx);
+            byte_senders.push(bytes_tx);
             client.start();
         });
         Self {
-            byte_senders: byte_senders,
+            byte_senders,
             disconnect_senders: exit_senders,
         }
     }
-    pub fn send(self: &mut Self, bytes: &BytesMut) {
+    pub fn send(&mut self, bytes: &BytesMut) {
         // We want to avoid performing an allocation if we won't be able to send the message
         // There is a possibility of a race here where we check the capacity and then the channel is
         // closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -170,7 +172,7 @@ impl MirroringManager {
         });
     }
-    pub fn disconnect(self: &mut Self) {
+    pub fn disconnect(&mut self) {
         self.disconnect_senders
             .iter_mut()
             .for_each(|sender| match sender.try_send(()) {


@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
             .map(|s| {
                 let s = s.as_str().to_string();
-                if s == "" {
+                if s.is_empty() {
                     None
                 } else {
                     Some(s)


@@ -33,6 +33,7 @@ pub enum PluginOutput {
 #[async_trait]
 pub trait Plugin {
     // Run before the query is sent to the server.
+    #[allow(clippy::ptr_arg)]
     async fn run(
         &mut self,
         query_router: &QueryRouter,


@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
                 self.server.address(),
                 query
            );
-            self.server.query(&query).await?;
+            self.server.query(query).await?;
        }
        Ok(())


@@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> {
         visit_relations(ast, |relation| {
             let relation = relation.to_string();
-            let parts = relation.split(".").collect::<Vec<&str>>();
+            let parts = relation.split('.').collect::<Vec<&str>>();
             let table_name = parts.last().unwrap();
             if self.tables.contains(&table_name.to_string()) {


@@ -3,6 +3,7 @@ use async_trait::async_trait;
 use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
 use chrono::naive::NaiveDateTime;
 use log::{debug, error, info, warn};
+use lru::LruCache;
 use once_cell::sync::Lazy;
 use parking_lot::{Mutex, RwLock};
 use rand::seq::SliceRandom;
@@ -10,6 +11,7 @@ use rand::thread_rng;
 use regex::Regex;
 use std::collections::HashMap;
 use std::fmt::{Display, Formatter};
+use std::num::NonZeroUsize;
 use std::sync::atomic::AtomicU64;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -24,6 +26,7 @@ use crate::config::{
 use crate::errors::Error;
 use crate::auth_passthrough::AuthPassthrough;
+use crate::messages::Parse;
 use crate::plugins::prewarmer;
 use crate::server::{Server, ServerParameters};
 use crate::sharding::ShardingFunction;
@@ -54,6 +57,57 @@ pub enum BanReason {
     AdminBan(i64),
 }
+
+pub type PreparedStatementCacheType = Arc<Mutex<PreparedStatementCache>>;
+
+// TODO: Add stats to this cache
+// TODO: Add application name to the cache value to help identify which application is using the cache
+// TODO: Create admin command to show which statements are in the cache
+#[derive(Debug)]
+pub struct PreparedStatementCache {
+    cache: LruCache<u64, Arc<Parse>>,
+}
+
+impl PreparedStatementCache {
+    pub fn new(mut size: usize) -> Self {
+        // Cannot be zero
+        if size == 0 {
+            size = 1;
+        }
+
+        PreparedStatementCache {
+            cache: LruCache::new(NonZeroUsize::new(size).unwrap()),
+        }
+    }
+
+    /// Adds the prepared statement to the cache under a new name if it isn't
+    /// already there; if it is, returns the existing rewritten parse
+    ///
+    /// Pass the hash to this so that we can do the compute before acquiring the lock
+    pub fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Arc<Parse> {
+        match self.cache.get(&hash) {
+            Some(rewritten_parse) => rewritten_parse.clone(),
+            None => {
+                let new_parse = Arc::new(parse.clone().rewrite());
+                let evicted = self.cache.push(hash, new_parse.clone());
+
+                if let Some((_, evicted_parse)) = evicted {
+                    debug!(
+                        "Evicted prepared statement {} from cache",
+                        evicted_parse.name
+                    );
+                }
+
+                new_parse
+            }
+        }
+    }
+
+    /// Marks the hash as most recently used if it exists
+    pub fn promote(&mut self, hash: &u64) {
+        self.cache.promote(hash);
+    }
+}
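`PreparedStatementCache` delegates the eviction policy to the `lru` crate. A std-only sketch of the same get-or-insert-with-eviction behavior (the `TinyLru` type and its string values are hypothetical, standing in for `LruCache<u64, Arc<Parse>>`):

```rust
use std::collections::HashMap;

// Std-only sketch of LRU get-or-insert: a hit promotes the entry,
// a miss inserts it and evicts the least recently used entry when full.
struct TinyLru {
    cap: usize,
    map: HashMap<u64, String>,
    order: Vec<u64>, // most recently used last
}

impl TinyLru {
    fn new(cap: usize) -> Self {
        // Capacity cannot be zero, mirroring PreparedStatementCache::new
        Self { cap: cap.max(1), map: HashMap::new(), order: Vec::new() }
    }

    fn get_or_insert(&mut self, hash: u64, make: impl FnOnce() -> String) -> String {
        if self.map.contains_key(&hash) {
            // Hit: promote to most recently used
            self.order.retain(|h| *h != hash);
            self.order.push(hash);
            return self.map[&hash].clone();
        }
        if self.map.len() == self.cap {
            // Full: evict the least recently used entry
            let evicted = self.order.remove(0);
            self.map.remove(&evicted);
        }
        let value = make();
        self.map.insert(hash, value.clone());
        self.order.push(hash);
        value
    }
}

fn main() {
    let mut cache = TinyLru::new(2);
    cache.get_or_insert(1, || "PGCAT_1".into());
    cache.get_or_insert(2, || "PGCAT_2".into());
    cache.get_or_insert(1, || unreachable!()); // hit: promotes 1
    cache.get_or_insert(3, || "PGCAT_3".into()); // evicts 2, not 1
    assert!(cache.map.contains_key(&1));
    assert!(!cache.map.contains_key(&2));
    assert!(cache.map.contains_key(&3));
    println!("ok");
}
```

Taking the precomputed hash as the key (rather than hashing inside) matches the comment in the real code: the expensive hashing happens before the mutex around the cache is acquired.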
 /// An identifier for a PgCat pool,
 /// a database visible to clients.
 #[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
@@ -190,11 +244,11 @@ impl Default for PoolSettings {
 #[derive(Clone, Debug, Default)]
 pub struct ConnectionPool {
     /// The pools handled internally by bb8.
-    databases: Vec<Vec<Pool<ServerPool>>>,
+    databases: Arc<Vec<Vec<Pool<ServerPool>>>>,
     /// The addresses (host, port, role) to handle
     /// failover and load balancing deterministically.
-    addresses: Vec<Vec<Address>>,
+    addresses: Arc<Vec<Vec<Address>>>,
     /// List of banned addresses (see above)
     /// that should not be queried.
@@ -206,7 +260,7 @@ pub struct ConnectionPool {
     original_server_parameters: Arc<RwLock<ServerParameters>>,
     /// Pool configuration.
-    pub settings: PoolSettings,
+    pub settings: Arc<PoolSettings>,
     /// If not validated, we need to double check the pool is available before allowing a client
     /// to use it.
@@ -223,6 +277,9 @@ pub struct ConnectionPool {
     /// AuthInfo
     pub auth_hash: Arc<RwLock<Option<String>>>,
+
+    /// Cache
+    pub prepared_statement_cache: Option<PreparedStatementCacheType>,
 }
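Wrapping the large fields (`databases`, `addresses`, `settings`) in `Arc` is what makes `get_pool` faster: cloning the pool now bumps a few reference counts instead of deep-copying vectors of bb8 pools. A minimal illustration of the design choice (the `Pool` struct here is hypothetical):

```rust
use std::sync::Arc;

// Hypothetical pool holding a large field behind an Arc: cloning the
// struct bumps a refcount instead of copying the Vec.
#[derive(Clone)]
struct Pool {
    addresses: Arc<Vec<String>>,
}

fn main() {
    let pool = Pool {
        addresses: Arc::new(vec!["replica1".into(), "replica2".into()]),
    };
    let handle = pool.clone(); // cheap: no Vec copy
    assert_eq!(Arc::strong_count(&pool.addresses), 2);
    assert!(Arc::ptr_eq(&pool.addresses, &handle.addresses)); // same allocation
    println!("ok");
}
```

The trade-off is that the shared fields become effectively immutable after construction, which is why `validate` changes from `&mut self` to `&self` later in this diff.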
impl ConnectionPool { impl ConnectionPool {
@@ -241,20 +298,17 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username); let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username);
match old_pool_ref { if let Some(pool) = old_pool_ref {
Some(pool) => { // If the pool hasn't changed, get existing reference and insert it into the new_pools.
// If the pool hasn't changed, get existing reference and insert it into the new_pools. // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). if pool.config_hash == new_pool_hash_value {
if pool.config_hash == new_pool_hash_value { info!(
info!( "[pool: {}][user: {}] has not changed",
"[pool: {}][user: {}] has not changed", pool_name, user.username
pool_name, user.username );
); new_pools.insert(identifier.clone(), pool.clone());
new_pools.insert(identifier.clone(), pool.clone()); continue;
continue;
}
} }
None => (),
} }
info!( info!(
@@ -379,6 +433,7 @@ impl ConnectionPool {
}, },
pool_config.cleanup_server_connections, pool_config.cleanup_server_connections,
pool_config.log_client_parameter_status_changes, pool_config.log_client_parameter_status_changes,
pool_config.prepared_statements_cache_size,
); );
let connect_timeout = match pool_config.connect_timeout { let connect_timeout = match pool_config.connect_timeout {
@@ -399,7 +454,7 @@ impl ConnectionPool {
}, },
}; };
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE] let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter() .iter()
.min() .min()
.unwrap(); .unwrap();
@@ -448,13 +503,13 @@ impl ConnectionPool {
} }
let pool = ConnectionPool { let pool = ConnectionPool {
databases: shards, databases: Arc::new(shards),
addresses, addresses: Arc::new(addresses),
banlist: Arc::new(RwLock::new(banlist)), banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value, config_hash: new_pool_hash_value,
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())), original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash, auth_hash: pool_auth_hash,
settings: PoolSettings { settings: Arc::new(PoolSettings {
pool_mode: match user.pool_mode { pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode, Some(pool_mode) => pool_mode,
None => pool_config.pool_mode, None => pool_config.pool_mode,
@@ -489,7 +544,7 @@ impl ConnectionPool {
.clone() .clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()), .map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
default_shard: pool_config.default_shard.clone(), default_shard: pool_config.default_shard,
auth_query: pool_config.auth_query.clone(), auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(), auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(), auth_query_password: pool_config.auth_query_password.clone(),
@@ -497,17 +552,23 @@ impl ConnectionPool {
Some(ref plugins) => Some(plugins.clone()), Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(), None => config.plugins.clone(),
}, },
}, }),
validated: Arc::new(AtomicBool::new(false)), validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()), paused_waiter: Arc::new(Notify::new()),
prepared_statement_cache: match pool_config.prepared_statements_cache_size {
0 => None,
_ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
pool_config.prepared_statements_cache_size,
)))),
},
}; };
// Connect to the servers to make sure pool configuration is valid // Connect to the servers to make sure pool configuration is valid
// before setting it globally. // before setting it globally.
// Do this async and somewhere else, we don't have to wait here. // Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config { if config.general.validate_config {
let mut validate_pool = pool.clone(); let validate_pool = pool.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
let _ = validate_pool.validate().await; let _ = validate_pool.validate().await;
}); });
@@ -528,7 +589,7 @@ impl ConnectionPool {
/// when they connect.
/// This also warms up the pool for clients that connect when
/// the pooler starts up.
-pub async fn validate(&mut self) -> Result<(), Error> {
+pub async fn validate(&self) -> Result<(), Error> {
let mut futures = Vec::new();
let validated = Arc::clone(&self.validated);
@@ -678,7 +739,7 @@ impl ConnectionPool {
let mut force_healthcheck = false;
if self.is_banned(address) {
-if self.try_unban(&address).await {
+if self.try_unban(address).await {
force_healthcheck = true;
} else {
debug!("Address {:?} is banned", address);
@@ -806,8 +867,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
-self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
-return false;
+self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
+false
}
/// Ban an address (i.e. replica). It no longer will serve
@@ -931,10 +992,10 @@ impl ConnectionPool {
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
-bans.push((address.clone(), (reason.clone(), timestamp.clone())));
+bans.push((address.clone(), (reason.clone(), *timestamp)));
}
}
-return bans;
+bans
}
/// Get the address from the host url
@@ -992,7 +1053,7 @@ impl ConnectionPool {
}
let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy);
-return busy;
+busy
}
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
@@ -1001,6 +1062,29 @@ impl ConnectionPool {
Some(shard) => shard < self.shards(),
}
}
/// Register a parse statement to the pool's cache and return the rewritten parse
///
/// Do not pass an anonymous parse statement to this function
pub fn register_parse_to_cache(&self, hash: u64, parse: &Parse) -> Option<Arc<Parse>> {
// We should only be calling this function if the cache is enabled
match self.prepared_statement_cache {
Some(ref prepared_statement_cache) => {
let mut cache = prepared_statement_cache.lock();
Some(cache.get_or_insert(parse, hash))
}
None => None,
}
}
/// Promote a prepared statement hash in the LRU
pub fn promote_prepared_statement_hash(&self, hash: &u64) {
// We should only be calling this function if the cache is enabled
if let Some(ref prepared_statement_cache) = self.prepared_statement_cache {
let mut cache = prepared_statement_cache.lock();
cache.promote(hash);
}
}
}
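The `register_parse_to_cache` and `promote_prepared_statement_hash` helpers above wrap an LRU keyed by a hash of the `Parse` message. A minimal standalone model of that cache is sketched below; the `Parse` struct and `PreparedStatementCache` internals here are illustrative simplifications, not the pooler's actual implementation (which stores `Arc<Parse>` behind a `Mutex`):

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the pooler's `Parse` message; only the
// fields needed to illustrate deduplication are modeled here.
#[derive(Clone, Debug, PartialEq)]
struct Parse {
    name: String,
    query: String,
}

// Minimal LRU keyed by the statement hash. `get_or_insert` deduplicates
// identical statements; `promote` marks a hash as most recently used;
// the least recently used entry is evicted once `capacity` is reached.
struct PreparedStatementCache {
    capacity: usize,
    map: HashMap<u64, Parse>,
    order: Vec<u64>, // index 0 = least recently used
}

impl PreparedStatementCache {
    fn new(capacity: usize) -> Self {
        Self {
            capacity,
            map: HashMap::new(),
            order: Vec::new(),
        }
    }

    fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Parse {
        if let Some(existing) = self.map.get(&hash).cloned() {
            self.promote(&hash);
            return existing;
        }
        if self.map.len() == self.capacity {
            // Evict the least recently used statement.
            let evicted = self.order.remove(0);
            self.map.remove(&evicted);
        }
        self.map.insert(hash, parse.clone());
        self.order.push(hash);
        parse.clone()
    }

    fn promote(&mut self, hash: &u64) {
        if let Some(pos) = self.order.iter().position(|h| h == hash) {
            let h = self.order.remove(pos);
            self.order.push(h);
        }
    }
}

fn main() {
    let mut cache = PreparedStatementCache::new(2);
    let p1 = Parse { name: "s1".into(), query: "SELECT 1".into() };
    let p2 = Parse { name: "s2".into(), query: "SELECT 2".into() };

    cache.get_or_insert(&p1, 1);
    cache.get_or_insert(&p2, 2);
    cache.promote(&1); // hash 2 is now least recently used

    let p3 = Parse { name: "s3".into(), query: "SELECT 3".into() };
    cache.get_or_insert(&p3, 3); // evicts hash 2

    assert!(cache.map.contains_key(&1));
    assert!(!cache.map.contains_key(&2));
    assert!(cache.map.contains_key(&3));
}
```

Promoting on every hit (rather than only on insert) is what avoids the race the later commit message refers to: a statement that is about to be reused cannot be the eviction candidate.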
/// Wrapper for the bb8 connection pool.
@@ -1028,9 +1112,13 @@ pub struct ServerPool {
/// Log client parameter status changes
log_client_parameter_status_changes: bool,
/// Prepared statement cache size
prepared_statement_cache_size: usize,
}
impl ServerPool {
#[allow(clippy::too_many_arguments)]
pub fn new(
address: Address,
user: User,
@@ -1040,16 +1128,18 @@ impl ServerPool {
plugins: Option<Plugins>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> ServerPool {
ServerPool {
address,
-user: user.clone(),
+user,
database: database.to_string(),
client_server_map,
auth_hash,
plugins,
cleanup_connections,
log_client_parameter_status_changes,
prepared_statement_cache_size,
}
}
}
@@ -1080,6 +1170,7 @@ impl ManageConnection for ServerPool {
self.auth_hash.clone(),
self.cleanup_connections,
self.log_client_parameter_status_changes,
self.prepared_statement_cache_size,
)
.await
{


@@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
-use sqlparser::ast::Statement::{Query, StartTransaction};
-use sqlparser::ast::{
-    BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
-    Value,
-};
+use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
+use sqlparser::ast::{
+    Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
+    TableFactor, TableWithJoins, Value,
+};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
@@ -91,7 +91,7 @@ impl QueryRouter {
/// One-time initialization of regexes
/// that parse our custom SQL protocol.
pub fn setup() -> bool {
-let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
+let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -128,11 +128,11 @@ impl QueryRouter {
}
/// Pool settings can change because of a config reload.
-pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
-self.pool_settings = pool_settings;
+pub fn update_pool_settings(&mut self, pool_settings: &PoolSettings) {
+self.pool_settings = pool_settings.clone();
}
-pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
+pub fn pool_settings(&self) -> &PoolSettings {
&self.pool_settings
}
@@ -148,7 +148,7 @@ impl QueryRouter {
// Check for any sharding regex matches in any queries
if comment_shard_routing_enabled {
-match code as char {
+match code {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings
@@ -344,16 +344,13 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
-match self.pool_settings.query_parser_max_length {
-    Some(max_length) => {
-        if len > max_length {
-            return Err(Error::QueryRouterParserError(format!(
-                "Query too long for parser: {} > {}",
-                len, max_length
-            )));
-        }
-    }
-    None => (),
+if let Some(max_length) = self.pool_settings.query_parser_max_length {
+    if len > max_length {
+        return Err(Error::QueryRouterParserError(format!(
+            "Query too long for parser: {} > {}",
+            len, max_length
+        )));
+    }
};
let query = match code {
@@ -403,6 +400,9 @@ impl QueryRouter {
return Err(Error::QueryRouterParserError("empty query".into()));
}
let mut visited_write_statement = false;
let mut prev_inferred_shard = None;
for q in ast {
match q {
// All transactions go to the primary, probably a write.
@@ -420,29 +420,38 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
// This is basically building a database now :)
-match self.infer_shard(query) {
-    Some(shard) => {
-        self.active_shard = Some(shard);
-        debug!("Automatically using shard: {:?}", self.active_shard);
-    }
-    None => (),
-};
+let inferred_shard = self.infer_shard(query);
+self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
-self.active_role = match self.primary_reads_enabled() {
-    false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
-    true => None, // Any server role is fine in this case.
-}
+// If we already visited a write statement, we should be going to the primary.
+if !visited_write_statement {
+    self.active_role = match self.primary_reads_enabled() {
+        false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
+        true => None, // Any server role is fine in this case.
+    }
+}
}
}
// Likely a write
_ => {
+    match &self.pool_settings.automatic_sharding_key {
+        Some(_) => {
+            // TODO: similar to the above, if we have multiple queries in the
+            // same message, we can either split them and execute them individually
+            // or discard shard selection. If they point to the same shard though,
+            // we can let them through as-is.
+            let inferred_shard = self.infer_shard_on_write(q)?;
+            self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
+        }
+        None => (),
+    };
+    visited_write_statement = true;
    self.active_role = Some(Role::Primary);
-    break;
}
};
}
@@ -450,6 +459,188 @@ impl QueryRouter {
Ok(())
}
fn handle_inferred_shard(
&mut self,
inferred_shard: Option<usize>,
prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> {
if let Some(shard) = inferred_shard {
if let Some(prev_shard) = *prev_inferred_shard {
if prev_shard != shard {
debug!("Found more than one shard in the query, not supported yet");
return Err(Error::QueryRouterParserError(
"multiple shards in query".into(),
));
}
}
*prev_inferred_shard = Some(shard);
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
};
Ok(())
}
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
match q {
Insert {
or,
into: _,
table_name,
columns,
overwrite: _,
source,
partitioned,
after_columns,
table: _,
on: _,
returning: _,
} => {
// Not supported in postgres.
assert!(or.is_none());
assert!(partitioned.is_none());
assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names);
Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
}
Delete {
tables,
from,
using,
selection,
returning: _,
} => {
if let Some(expr) = selection {
exprs.push(expr.clone());
}
// Multi tables delete are not supported in postgres.
assert!(tables.is_empty());
Self::process_tables_with_join(from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join(
using_tbl_with_join,
&mut exprs,
&mut table_names,
);
}
Self::process_selection(selection, &mut exprs);
}
Update {
table,
assignments,
from,
selection,
returning: _,
} => {
Self::process_table_with_join(table, &mut exprs, &mut table_names);
if let Some(from_tbl) = from {
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
}
Self::process_selection(selection, &mut exprs);
self.assignment_parser(assignments)?;
}
_ => {
return Ok(None);
}
};
Ok(self.infer_shard_from_exprs(exprs, table_names))
}
fn process_query(
query: &sqlparser::ast::Query,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
columns: &Option<&Vec<Ident>>,
) {
match &*query.body {
SetExpr::Query(query) => {
Self::process_query(query, exprs, table_names, columns);
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
Self::process_tables_with_join(&select.from, exprs, table_names);
// Parse the actual "FROM ..."
Self::process_selection(&select.selection, exprs);
}
SetExpr::Values(values) => {
if let Some(cols) = columns {
for row in values.rows.iter() {
for (i, expr) in row.iter().enumerate() {
if cols.len() > i {
exprs.push(Expr::BinaryOp {
left: Box::new(Expr::Identifier(cols[i].clone())),
op: BinaryOperator::Eq,
right: Box::new(expr.clone()),
});
}
}
}
}
}
_ => (),
};
}
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
match selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
}
fn process_tables_with_join(
tables: &[TableWithJoins],
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
for table in tables.iter() {
Self::process_table_with_join(table, exprs, table_names);
}
}
fn process_table_with_join(
table: &TableWithJoins,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
if let TableFactor::Table { name, .. } = &table.relation {
Self::process_table(name, table_names);
};
// Get table names from all the joins.
for join in table.joins.iter() {
if let TableFactor::Table { name, .. } = &join.relation {
Self::process_table(name, table_names);
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
// Parse the selection criteria later.
exprs.push(expr.clone());
};
}
}
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
table_names.push(name.0.clone())
}
/// Parse the shard number from the Bind message
/// which contains the arguments for a prepared statement.
///
@@ -592,6 +783,33 @@ impl QueryRouter {
}
}
/// An `assignments` exists in the `UPDATE` statements. This parses the assignments and makes
/// sure that we are not updating the sharding key. It's not supported yet.
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
let sharding_key = self
.pool_settings
.automatic_sharding_key
.as_ref()
.unwrap()
.split('.')
.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
assert_eq!(sharding_key.len(), 2);
for a in assignments {
if sharding_key[0].value == "*"
&& sharding_key[1].value == a.id.last().unwrap().value.to_lowercase()
{
return Err(Error::QueryRouterParserError(
"Sharding key cannot be updated.".into(),
));
}
}
Ok(())
}
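The assignment check above can be boiled down to a small sketch. The function below is illustrative only: it works on strings rather than sqlparser `Assignment` nodes, and — mirroring the code as written — rejects the assignment only for the wildcard (`*.column`) form of the key:

```rust
// Given a fully qualified sharding key setting (e.g. "*.w_id"), decide
// whether an UPDATE assignment target such as `o.W_ID` hits the key.
fn updates_sharding_key(automatic_sharding_key: &str, assignment_target: &str) -> bool {
    let key: Vec<String> = automatic_sharding_key
        .split('.')
        .map(|ident| ident.to_lowercase())
        .collect();
    // Sharding key must always be fully qualified.
    assert_eq!(key.len(), 2);
    // Compare against the last identifier of the assignment target,
    // e.g. `o.w_id` -> "w_id". As in the code above, only the wildcard
    // table form ("*.column") is checked.
    let column = assignment_target.split('.').last().unwrap().to_lowercase();
    key[0] == "*" && key[1] == column
}

fn main() {
    assert!(updates_sharding_key("*.w_id", "o.W_ID"));
    assert!(updates_sharding_key("*.w_id", "w_id"));
    assert!(!updates_sharding_key("*.w_id", "o.O_CARRIER_ID"));
}
```

The lowercasing on both sides matches the `to_lowercase()` normalization this PR adds throughout the selection parser, since SQL identifiers are case-insensitive unless quoted.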
/// A `selection` is the `WHERE` clause. This parses
/// the clause and extracts the sharding key, if present.
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
@@ -603,8 +821,8 @@ impl QueryRouter {
.automatic_sharding_key
.as_ref()
.unwrap()
-.split(".")
-.map(|ident| Ident::new(ident))
+.split('.')
+.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
@@ -620,7 +838,7 @@ impl QueryRouter {
Expr::Identifier(ident) => {
// Only if we're dealing with only one table
// and there is no ambiguity
-if &ident.value == &sharding_key[1].value {
+if ident.value.to_lowercase() == sharding_key[1].value {
// Sharding key is unique enough, don't worry about
// table names.
if &sharding_key[0].value == "*" {
@@ -633,13 +851,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches
// the table name from the query.
-found = &sharding_key[0].value == &table[0].value;
+found = sharding_key[0].value == table[0].value.to_lowercase();
} else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only.
-found = &sharding_key[0].value == &table[1].value;
+found = sharding_key[0].value == table[1].value.to_lowercase();
} else {
debug!("Got table name with more than two idents, which is not possible");
}
@@ -651,8 +869,9 @@ impl QueryRouter {
// The key is fully qualified in the query,
// it will exist or Postgres will throw an error.
if idents.len() == 2 {
-found = &sharding_key[0].value == &idents[0].value
-    && &sharding_key[1].value == &idents[1].value;
+found = (&sharding_key[0].value == "*"
+    || sharding_key[0].value == idents[0].value.to_lowercase())
+    && sharding_key[1].value == idents[1].value.to_lowercase();
}
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
}
@@ -684,7 +903,7 @@ impl QueryRouter {
}
Expr::Value(Value::Placeholder(placeholder)) => {
-match placeholder.replace("$", "").parse::<i16>() {
+match placeholder.replace('$', "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => {
debug!(
@@ -705,100 +924,48 @@ impl QueryRouter {
/// Try to figure out which shard the query should go to.
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
-    let mut shards = BTreeSet::new();
    let mut exprs = Vec::new();
-    match &*query.body {
-        SetExpr::Query(query) => {
-            match self.infer_shard(&*query) {
-                Some(shard) => {
-                    shards.insert(shard);
-                }
-                None => (),
-            };
-        }
-        // SELECT * FROM ...
-        // We understand that pretty well.
-        SetExpr::Select(select) => {
-            // Collect all table names from the query.
-            let mut table_names = Vec::new();
-            for table in select.from.iter() {
-                match &table.relation {
-                    TableFactor::Table { name, .. } => {
-                        table_names.push(name.0.clone());
-                    }
-                    _ => (),
-                };
-                // Get table names from all the joins.
-                for join in table.joins.iter() {
-                    match &join.relation {
-                        TableFactor::Table { name, .. } => {
-                            table_names.push(name.0.clone());
-                        }
-                        _ => (),
-                    };
-                    // We can filter results based on join conditions, e.g.
-                    // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
-                    match &join.join_operator {
-                        JoinOperator::Inner(inner_join) => match &inner_join {
-                            JoinConstraint::On(expr) => {
-                                // Parse the selection criteria later.
-                                exprs.push(expr.clone());
-                            }
-                            _ => (),
-                        },
-                        _ => (),
-                    };
-                }
-            }
-            // Parse the actual "FROM ..."
-            match &select.selection {
-                Some(selection) => {
-                    exprs.push(selection.clone());
-                }
-                None => (),
-            };
-            let sharder = Sharder::new(
-                self.pool_settings.shards,
-                self.pool_settings.sharding_function,
-            );
-            // Look for sharding keys in either the join condition
-            // or the selection.
-            for expr in exprs.iter() {
-                let sharding_keys = self.selection_parser(expr, &table_names);
-                // TODO: Add support for prepared statements here.
-                // This should just give us the position of the value in the `B` message.
-                for value in sharding_keys {
-                    match value {
-                        ShardingKey::Value(value) => {
-                            let shard = sharder.shard(value);
-                            shards.insert(shard);
-                        }
-                        ShardingKey::Placeholder(position) => {
-                            self.placeholders.push(position);
-                        }
-                    };
-                }
-            }
-        }
-        _ => (),
-    };
+    // Collect all table names from the query.
+    let mut table_names = Vec::new();
+    Self::process_query(query, &mut exprs, &mut table_names, &None);
+    self.infer_shard_from_exprs(exprs, table_names)
+}
+
+fn infer_shard_from_exprs(
+    &mut self,
+    exprs: Vec<Expr>,
+    table_names: Vec<Vec<Ident>>,
+) -> Option<usize> {
+    let mut shards = BTreeSet::new();
+    let sharder = Sharder::new(
+        self.pool_settings.shards,
+        self.pool_settings.sharding_function,
+    );
+    // Look for sharding keys in either the join condition
+    // or the selection.
+    for expr in exprs.iter() {
+        let sharding_keys = self.selection_parser(expr, &table_names);
+        // TODO: Add support for prepared statements here.
+        // This should just give us the position of the value in the `B` message.
+        for value in sharding_keys {
+            match value {
+                ShardingKey::Value(value) => {
+                    let shard = sharder.shard(value);
+                    shards.insert(shard);
+                }
+                ShardingKey::Placeholder(position) => {
+                    self.placeholders.push(position);
+                }
+            };
+        }
+    }
    match shards.len() {
    // Didn't find a sharding key, you're on your own.
    0 => {
@@ -830,16 +997,16 @@ impl QueryRouter {
db: &self.pool_settings.db,
};
-let _ = query_logger.run(&self, ast).await;
+let _ = query_logger.run(self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
-config: &intercept,
+config: intercept,
};
-let result = intercept.run(&self, ast).await;
+let result = intercept.run(self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
@@ -852,7 +1019,7 @@ impl QueryRouter {
tables: &table_access.tables,
};
-let result = table_access.run(&self, ast).await;
+let result = table_access.run(self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
@@ -888,7 +1055,7 @@ impl QueryRouter {
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
-let enabled = match self.query_parser_enabled {
+match self.query_parser_enabled {
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
@@ -904,9 +1071,7 @@ impl QueryRouter {
);
value
}
-};
-enabled
+}
}
pub fn primary_reads_enabled(&self) -> bool {
@@ -917,6 +1082,12 @@ impl QueryRouter {
}
}
impl Default for QueryRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
use super::*;
@@ -938,10 +1109,14 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
-assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
+assert!(qr
+    .try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
+    .is_some());
assert!(qr.query_parser_enabled());
-assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
+assert!(qr
+    .try_execute_command(&simple_query("SET PRIMARY READS TO off"))
+    .is_some());
let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
@@ -983,7 +1158,9 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
-assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
+assert!(qr
+    .try_execute_command(&simple_query("SET PRIMARY READS TO on"))
+    .is_some());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
@@ -996,7 +1173,9 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
-assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
+assert!(qr
+    .try_execute_command(&simple_query("SET PRIMARY READS TO off"))
+    .is_some());
let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1166,9 +1345,11 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'");
-assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
+assert!(qr
+    .try_execute_command(&simple_query("SET PRIMARY READS TO off"))
+    .is_some());
-assert!(qr.try_execute_command(&query) != None);
+assert!(qr.try_execute_command(&query).is_some());
assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None);
@@ -1182,7 +1363,7 @@ mod test {
assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
-assert!(qr.try_execute_command(&query) != None);
+assert!(qr.try_execute_command(&query).is_some());
assert!(!qr.query_parser_enabled());
}
@@ -1222,7 +1403,7 @@ mod test {
assert_eq!(qr.primary_reads_enabled, None);
// Internal state must not be changed due to this, only defaults
-qr.update_pool_settings(pool_settings.clone());
+qr.update_pool_settings(&pool_settings);
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
@@ -1230,11 +1411,11 @@ mod test {
assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
-assert!(qr.try_execute_command(&q1) != None);
+assert!(qr.try_execute_command(&q1).is_some());
assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'");
-assert!(qr.try_execute_command(&q2) != None);
+assert!(qr.try_execute_command(&q2).is_some());
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
}
@@ -1295,29 +1476,29 @@ mod test {
};
let mut qr = QueryRouter::new();
-qr.update_pool_settings(pool_settings.clone());
+qr.update_pool_settings(&pool_settings);
// Shard should start out unset
assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
-assert!(qr.try_execute_command(&q0) == None);
+assert!(qr.try_execute_command(&q0).is_none());
assert_eq!(qr.active_shard, None);
// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
-assert!(qr.try_execute_command(&q1) == None);
+assert!(qr.try_execute_command(&q1).is_none());
assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
-assert!(qr.try_execute_command(&q2) == None);
+assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
-assert!(qr.try_execute_command(&q2) == None);
+assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(2));
}
@@ -1414,6 +1595,221 @@ mod test {
assert_eq!(qr.shard().unwrap(), 0);
}
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("*.w_id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert_eq!(qr.shard(), None);
let infer_res = qr.infer(&qr.parse(&simple_query(qry)).unwrap());
assert_eq!(infer_res.is_ok(), should_succeed);
qr.shard()
}
fn auto_shard(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, true)
}
fn auto_shard_fails(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, false)
}
#[test]
fn test_automatic_sharding_insert_update_delete() {
QueryRouter::setup();
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
None
);
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
None
);
assert_eq!(
auto_shard(
"UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
Some(2)
);
}
#[test]
fn test_automatic_sharding_key_tpcc() {
QueryRouter::setup();
assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2));
assert_eq!(
auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"),
None
);
assert_eq!(auto_shard("COMMIT"), None);
assert_eq!(auto_shard("ROLLBACK"), None);
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(
auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard(
"UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
Some(2)
);
assert_eq!(
auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"),
None
);
assert_eq!(
auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2));
assert_eq!(
auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 5
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
Some(2)
);
// This is a distributed query that spans two shards, so no single shard can be inferred
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 7
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
None
);
}
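The consistency rule these assertions exercise can be sketched with a small, std-only model. This is a hypothetical illustration, not pgcat's API: `infer_shard`, its predicate-list input, and the modulo reduction (standing in for the configured sharding function) are all invented names; the grounded behavior is only that every equality predicate on the sharding key must agree on one shard, and a cross-shard query yields `None`.

```rust
use std::collections::HashSet;

// Hypothetical model of the router's decision: collect the shard implied
// by every equality predicate on the sharding key ("w_id" here) and
// succeed only if they all agree on a single shard.
fn infer_shard(predicates: &[(&str, i64)], shards: usize) -> Option<usize> {
    // Stand-in for the real sharding function (assumption: simple modulo).
    let shard_for = |value: i64| value as usize % shards;

    let candidates: HashSet<usize> = predicates
        .iter()
        .filter(|(column, _)| column.eq_ignore_ascii_case("w_id"))
        .map(|(_, value)| shard_for(*value))
        .collect();

    match candidates.len() {
        1 => candidates.into_iter().next(), // all predicates agree on one shard
        _ => None, // no sharding key found, or a cross-shard (distributed) query
    }
}

fn main() {
    // Both tables filter on the same warehouse id: routable to one shard.
    assert_eq!(infer_shard(&[("W_ID", 5), ("W_ID", 5)], 3), Some(2));
    // Two different warehouse ids: distributed query, no single shard.
    assert_eq!(infer_shard(&[("W_ID", 5), ("W_ID", 7)], 3), None);
}
```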
    #[test]
    fn test_prepared_statements() {
        let stmt = "SELECT * FROM data WHERE id = $1";
@@ -1458,12 +1854,13 @@ mod test {
        };

        QueryRouter::setup();
-       let mut pool_settings = PoolSettings::default();
-       pool_settings.query_parser_enabled = true;
-       pool_settings.plugins = Some(plugins);
+       let pool_settings = PoolSettings {
+           query_parser_enabled: true,
+           plugins: Some(plugins),
+           ..Default::default()
+       };
        let mut qr = QueryRouter::new();
-       qr.update_pool_settings(pool_settings);
+       qr.update_pool_settings(&pool_settings);

        let query = simple_query("SELECT * FROM pg_database");
        let ast = qr.parse(&query).unwrap();


@@ -79,12 +79,12 @@ impl ScramSha256 {
        let server_message = Message::parse(message)?;

        if !server_message.nonce.starts_with(&self.nonce) {
-           return Err(Error::ProtocolSyncError(format!("SCRAM")));
+           return Err(Error::ProtocolSyncError("SCRAM".to_string()));
        }

        let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
            Ok(salt) => salt,
-           Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
+           Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
        };

        let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
    pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
        let final_message = FinalMessage::parse(message)?;

-       let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
+       let verifier = match general_purpose::STANDARD.decode(final_message.value) {
            Ok(verifier) => verifier,
-           Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
+           Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
        };

        let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
            .collect::<Vec<String>>();

        if parts.len() != 3 {
-           return Err(Error::ProtocolSyncError(format!("SCRAM")));
+           return Err(Error::ProtocolSyncError("SCRAM".to_string()));
        }

        let nonce = str::replace(&parts[0], "r=", "");
        let salt = str::replace(&parts[1], "s=", "");
        let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
            Ok(iterations) => iterations,
-           Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
+           Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
        };

        Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
    /// Parse the server final validation message.
    pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
        if !message.starts_with(b"v=") || message.len() < 4 {
-           return Err(Error::ProtocolSyncError(format!("SCRAM")));
+           return Err(Error::ProtocolSyncError("SCRAM".to_string()));
        }

        Ok(FinalMessage {
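The parsing shape in these hunks (split on commas, require three `r=`/`s=`/`i=` fields, fail with a protocol error otherwise) can be illustrated with a std-only sketch. The function below is a simplified stand-in: plain `&str` input and `String` errors replace the real `BytesMut` and `Error::ProtocolSyncError` types.

```rust
// Simplified stand-in for Message::parse: extract the nonce, base64 salt,
// and iteration count from a SCRAM server-first-message.
fn parse_server_first(message: &str) -> Result<(String, String, u32), String> {
    let parts: Vec<&str> = message.split(',').collect();

    if parts.len() != 3 {
        return Err("SCRAM".to_string());
    }

    let nonce = parts[0].strip_prefix("r=").ok_or("SCRAM")?.to_string();
    let salt = parts[1].strip_prefix("s=").ok_or("SCRAM")?.to_string();
    let iterations = parts[2]
        .strip_prefix("i=")
        .ok_or("SCRAM")?
        .parse::<u32>()
        .map_err(|_| "SCRAM")?;

    Ok((nonce, salt, iterations))
}

fn main() {
    let parsed = parse_server_first("r=clientnonce+servernonce,s=c2FsdA==,i=4096").unwrap();
    assert_eq!(parsed.2, 4096);
    assert!(parse_server_first("not-a-scram-message").is_err());
}
```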


@@ -3,12 +3,14 @@
 use bytes::{Buf, BufMut, BytesMut};
 use fallible_iterator::FallibleIterator;
 use log::{debug, error, info, trace, warn};
+use lru::LruCache;
 use once_cell::sync::Lazy;
 use parking_lot::{Mutex, RwLock};
 use postgres_protocol::message;
-use std::collections::{BTreeSet, HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
 use std::mem;
 use std::net::IpAddr;
+use std::num::NonZeroUsize;
 use std::sync::Arc;
 use std::time::SystemTime;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
@@ -16,7 +18,7 @@ use tokio::net::TcpStream;
 use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
 use tokio_rustls::{client::TlsStream, TlsConnector};

-use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
+use crate::config::{get_config, Address, User};
 use crate::constants::*;
 use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
 use crate::errors::{Error, ServerIdentifier};
@@ -197,12 +199,8 @@ impl ServerParameters {
            key = "DateStyle".to_string();
        };

-       if TRACKED_PARAMETERS.contains(&key) {
+       if TRACKED_PARAMETERS.contains(&key) || startup {
            self.parameters.insert(key, value);
-       } else {
-           if startup {
-               self.parameters.insert(key, value);
-           }
        }
    }
@@ -326,12 +324,13 @@ pub struct Server {
    log_client_parameter_status_changes: bool,

    /// Prepared statements
-   prepared_statements: BTreeSet<String>,
+   prepared_statement_cache: Option<LruCache<String, ()>>,
}

impl Server {
    /// Pretend to be the Postgres client and connect to the server given host, port and credentials.
    /// Perform the authentication and return the server in a ready for query state.
+   #[allow(clippy::too_many_arguments)]
    pub async fn startup(
        address: &Address,
        user: &User,
@@ -341,6 +340,7 @@ impl Server {
        auth_hash: Arc<RwLock<Option<String>>>,
        cleanup_connections: bool,
        log_client_parameter_status_changes: bool,
+       prepared_statement_cache_size: usize,
    ) -> Result<Server, Error> {
        let cached_resolver = CACHED_RESOLVER.load();
        let mut addr_set: Option<AddrSet> = None;
@@ -440,10 +440,7 @@
                // Something else?
                m => {
-                   return Err(Error::SocketError(format!(
-                       "Unknown message: {}",
-                       m as char
-                   )));
+                   return Err(Error::SocketError(format!("Unknown message: {}", { m })));
                }
            }
        } else {
@@ -461,26 +458,20 @@
            None => &user.username,
        };

-       let password = match user.server_password {
-           Some(ref server_password) => Some(server_password),
-           None => match user.password {
-               Some(ref password) => Some(password),
-               None => None,
-           },
+       let password = match user.server_password.as_ref() {
+           Some(server_password) => Some(server_password),
+           None => user.password.as_ref(),
        };

        startup(&mut stream, username, database).await?;

        let mut process_id: i32 = 0;
        let mut secret_key: i32 = 0;
-       let server_identifier = ServerIdentifier::new(username, &database);
+       let server_identifier = ServerIdentifier::new(username, database);

        // We'll be handling multiple packets, but they will all be structured the same.
        // We'll loop here until this exchange is complete.
-       let mut scram: Option<ScramSha256> = match password {
-           Some(password) => Some(ScramSha256::new(password)),
-           None => None,
-       };
+       let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));

        let mut server_parameters = ServerParameters::new();
@@ -725,7 +716,7 @@
            }
        };

-       let fields = match PgErrorMsg::parse(error) {
+       let fields = match PgErrorMsg::parse(&error) {
            Ok(f) => f,
            Err(err) => {
                return Err(err);
@@ -830,7 +821,12 @@
            },
            cleanup_connections,
            log_client_parameter_status_changes,
-           prepared_statements: BTreeSet::new(),
+           prepared_statement_cache: match prepared_statement_cache_size {
+               0 => None,
+               _ => Some(LruCache::new(
+                   NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
+               )),
+           },
        };

        return Ok(server);
@@ -882,7 +878,7 @@
        self.mirror_send(messages);
        self.stats().data_sent(messages.len());

-       match write_all_flush(&mut self.stream, &messages).await {
+       match write_all_flush(&mut self.stream, messages).await {
            Ok(_) => {
                // Successfully sent to server
                self.last_activity = SystemTime::now();
@@ -969,6 +965,20 @@
                        if self.in_copy_mode {
                            self.in_copy_mode = false;
                        }
+
+                       if self.prepared_statement_cache.is_some() {
+                           let error_message = PgErrorMsg::parse(&message)?;
+                           if error_message.message == "cached plan must not change result type" {
+                               warn!("Server {:?} changed schema, dropping connection to clean up prepared statements", self.address);
+                               // This will still result in an error to the client, but this server connection will drop all cached prepared statements
+                               // so that any new queries will be re-prepared
+                               // TODO: Other ideas to solve errors when there are DDL changes after a statement has been prepared
+                               // - Recreate entire connection pool to force recreation of all server connections
+                               // - Clear the ConnectionPool's statement cache so that new statement names are generated
+                               // - Implement a retry (re-prepare) so the client doesn't see an error
+                               self.cleanup_state.needs_cleanup_prepare = true;
+                           }
+                       }
                    }

                    // CommandComplete
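The detection added in the hunk above reduces to a string match on the server's error text plus a flag read later by the reset path. A minimal sketch with simplified stand-in types (this `CleanupState` is a one-field struct for illustration, not pgcat's real one):

```rust
// Stand-in for the connection's cleanup bookkeeping.
struct CleanupState {
    needs_cleanup_prepare: bool,
}

// When the prepared statement cache is active and PostgreSQL reports that
// a DDL change invalidated a cached plan, flag the connection so its next
// reset issues DEALLOCATE ALL (and clears the local statement cache).
fn handle_server_error(message: &str, cache_enabled: bool, cleanup: &mut CleanupState) {
    if cache_enabled && message == "cached plan must not change result type" {
        cleanup.needs_cleanup_prepare = true;
    }
}

fn main() {
    let mut state = CleanupState { needs_cleanup_prepare: false };
    handle_server_error("cached plan must not change result type", true, &mut state);
    assert!(state.needs_cleanup_prepare);
}
```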
@@ -1079,115 +1089,92 @@
        Ok(bytes)
    }

-   /// Add the prepared statement to being tracked by this server.
-   /// The client is processing data that will create a prepared statement on this server.
-   pub fn will_prepare(&mut self, name: &str) {
-       debug!("Will prepare `{}`", name);
-
-       self.prepared_statements.insert(name.to_string());
-       self.stats.prepared_cache_add();
-   }
-
-   /// Check if we should prepare a statement on the server.
-   pub fn should_prepare(&self, name: &str) -> bool {
-       let should_prepare = !self.prepared_statements.contains(name);
-
-       debug!("Should prepare `{}`: {}", name, should_prepare);
-
-       if should_prepare {
-           self.stats.prepared_cache_miss();
-       } else {
-           self.stats.prepared_cache_hit();
-       }
-
-       should_prepare
-   }
-
-   /// Create a prepared statement on the server.
-   pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
-       debug!("Preparing `{}`", parse.name);
-
-       let bytes: BytesMut = parse.try_into()?;
-       self.send(&bytes).await?;
-       self.send(&flush()).await?;
-
-       // Read and discard ParseComplete (B)
-       match read_message(&mut self.stream).await {
-           Ok(_) => (),
-           Err(err) => {
-               self.bad = true;
-               return Err(err);
-           }
-       }
-
-       self.prepared_statements.insert(parse.name.to_string());
-       self.stats.prepared_cache_add();
-
-       debug!("Prepared `{}`", parse.name);
-
-       Ok(())
-   }
-
-   /// Maintain adequate cache size on the server.
-   pub async fn maintain_cache(&mut self) -> Result<(), Error> {
-       debug!("Cache maintenance run");
-
-       let max_cache_size = get_prepared_statements_cache_size();
-       let mut names = Vec::new();
-
-       while self.prepared_statements.len() >= max_cache_size {
-           // The prepared statmeents are alphanumerically sorted by the BTree.
-           // FIFO.
-           if let Some(name) = self.prepared_statements.pop_last() {
-               names.push(name);
-           }
-       }
-
-       if !names.is_empty() {
-           self.deallocate(names).await?;
-       }
-
-       Ok(())
-   }
-
-   /// Remove the prepared statement from being tracked by this server.
-   /// The client is processing data that will cause the server to close the prepared statement.
-   pub fn will_close(&mut self, name: &str) {
-       debug!("Will close `{}`", name);
-
-       self.prepared_statements.remove(name);
-   }
-
-   /// Close a prepared statement on the server.
-   pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
-       for name in &names {
-           debug!("Deallocating prepared statement `{}`", name);
-
-           let close = Close::new(name);
-           let bytes: BytesMut = close.try_into()?;
-
-           self.send(&bytes).await?;
-       }
-
-       if !names.is_empty() {
-           self.send(&flush()).await?;
-       }
-
-       // Read and discard CloseComplete (3)
-       for name in &names {
-           match read_message(&mut self.stream).await {
-               Ok(_) => {
-                   self.prepared_statements.remove(name);
-                   self.stats.prepared_cache_remove();
-                   debug!("Closed `{}`", name);
-               }
-               Err(err) => {
-                   self.bad = true;
-                   return Err(err);
-               }
-           };
-       }
-
-       Ok(())
-   }
+   // Determines if the server already has a prepared statement with the given name
+   // Increments the prepared statement cache hit counter
+   pub fn has_prepared_statement(&mut self, name: &str) -> bool {
+       let cache = match &mut self.prepared_statement_cache {
+           Some(cache) => cache,
+           None => return false,
+       };
+
+       let has_it = cache.get(name).is_some();
+
+       if has_it {
+           self.stats.prepared_cache_hit();
+       } else {
+           self.stats.prepared_cache_miss();
+       }
+
+       has_it
+   }
+
+   pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
+       let cache = match &mut self.prepared_statement_cache {
+           Some(cache) => cache,
+           None => return None,
+       };
+
+       self.stats.prepared_cache_add();
+
+       // If we evict something, we need to close it on the server
+       if let Some((evicted_name, _)) = cache.push(name.to_string(), ()) {
+           if evicted_name != name {
+               debug!(
+                   "Evicted prepared statement {} from cache, replaced with {}",
+                   evicted_name, name
+               );
+               return Some(evicted_name);
+           }
+       };
+
+       None
+   }
+
+   pub fn remove_prepared_statement_from_cache(&mut self, name: &str) {
+       let cache = match &mut self.prepared_statement_cache {
+           Some(cache) => cache,
+           None => return,
+       };
+
+       self.stats.prepared_cache_remove();
+       cache.pop(name);
+   }
+
+   pub async fn register_prepared_statement(
+       &mut self,
+       parse: &Parse,
+       should_send_parse_to_server: bool,
+   ) -> Result<(), Error> {
+       if !self.has_prepared_statement(&parse.name) {
+           let mut bytes = BytesMut::new();
+
+           if should_send_parse_to_server {
+               let parse_bytes: BytesMut = parse.try_into()?;
+               bytes.extend_from_slice(&parse_bytes);
+           }
+
+           // If we evict something, we need to close it on the server
+           // We do this by adding it to the messages we're sending to the server before the sync
+           if let Some(evicted_name) = self.add_prepared_statement_to_cache(&parse.name) {
+               self.remove_prepared_statement_from_cache(&evicted_name);
+               let close_bytes: BytesMut = Close::new(&evicted_name).try_into()?;
+               bytes.extend_from_slice(&close_bytes);
+           };
+
+           // If we have a parse or close we need to send to the server, send them and sync
+           if !bytes.is_empty() {
+               bytes.extend_from_slice(&sync());
+               self.send(&bytes).await?;
+               loop {
+                   self.recv(None).await?;
+                   if !self.is_data_available() {
+                       break;
+                   }
+               }
+           }
+       };

        Ok(())
    }
@@ -1324,6 +1311,10 @@
        if self.cleanup_state.needs_cleanup_prepare {
            reset_string.push_str("DEALLOCATE ALL;");
+           // Since we deallocated all prepared statements, we need to clear the cache
+           if let Some(cache) = &mut self.prepared_statement_cache {
+               cache.clear();
+           }
        };

        self.query(&reset_string).await?;
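The `lru` crate calls used above (`get` promotes an entry to most-recently-used; `push` returns whatever it evicted) follow standard LRU semantics. A std-only sketch of that contract, with illustrative statement names, shows why the caller must send a `Close` for anything `push` evicts:

```rust
use std::collections::VecDeque;

// Minimal LRU over statement names: front = most recently used.
struct StatementLru {
    capacity: usize,
    order: VecDeque<String>,
}

impl StatementLru {
    fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "a size of 0 disables the cache instead");
        Self { capacity, order: VecDeque::new() }
    }

    // Like LruCache::get: report whether the statement is cached and
    // promote it to most-recently-used if so.
    fn touch(&mut self, name: &str) -> bool {
        if let Some(pos) = self.order.iter().position(|n| n == name) {
            let hit = self.order.remove(pos).unwrap();
            self.order.push_front(hit);
            true
        } else {
            false
        }
    }

    // Like LruCache::push: insert the statement and return the evicted
    // one, if any, so it can be closed on the server.
    fn push(&mut self, name: &str) -> Option<String> {
        if self.touch(name) {
            return None; // already cached, only promoted
        }
        self.order.push_front(name.to_string());
        if self.order.len() > self.capacity {
            self.order.pop_back() // least recently used falls out
        } else {
            None
        }
    }
}

fn main() {
    let mut cache = StatementLru::new(2);
    assert_eq!(cache.push("PGCAT_1"), None);
    assert_eq!(cache.push("PGCAT_2"), None);
    assert!(cache.touch("PGCAT_1")); // promote: PGCAT_2 is now least recent
    assert_eq!(cache.push("PGCAT_3"), Some("PGCAT_2".to_string()));
}
```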
@@ -1359,16 +1350,14 @@
    }

    pub fn mirror_send(&mut self, bytes: &BytesMut) {
-       match self.mirror_manager.as_mut() {
-           Some(manager) => manager.send(bytes),
-           None => (),
+       if let Some(manager) = self.mirror_manager.as_mut() {
+           manager.send(bytes)
        }
    }

    pub fn mirror_disconnect(&mut self) {
-       match self.mirror_manager.as_mut() {
-           Some(manager) => manager.disconnect(),
-           None => (),
+       if let Some(manager) = self.mirror_manager.as_mut() {
+           manager.disconnect()
        }
    }
@@ -1391,13 +1380,14 @@
            Arc::new(RwLock::new(None)),
            true,
            false,
+           0,
        )
        .await?;

        debug!("Connected!, sending query.");

        server.send(&simple_query(query)).await?;

        let mut message = server.recv(None).await?;
-       Ok(parse_query_message(&mut message).await?)
+       parse_query_message(&mut message).await
    }
}


@@ -64,7 +64,7 @@ impl Sharder {
    fn sha1(&self, key: i64) -> usize {
        let mut hasher = Sha1::new();

-       hasher.update(&key.to_string().as_bytes());
+       hasher.update(key.to_string().as_bytes());

        let result = hasher.finalize();
@@ -202,10 +202,10 @@ mod test {
    #[test]
    fn test_sha1_hash() {
        let sharder = Sharder::new(12, ShardingFunction::Sha1);
-       let ids = vec![
+       let ids = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
        ];
-       let shards = vec![
+       let shards = [
            4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3,
        ];


@@ -86,11 +86,11 @@ impl PoolStats {
            }
        }

-       return map;
+       map
    }

    pub fn generate_header() -> Vec<(&'static str, DataType)> {
-       return vec![
+       vec![
            ("database", DataType::Text),
            ("user", DataType::Text),
            ("pool_mode", DataType::Text),
@@ -105,11 +105,11 @@ impl PoolStats {
            ("sv_login", DataType::Numeric),
            ("maxwait", DataType::Numeric),
            ("maxwait_us", DataType::Numeric),
-       ];
+       ]
    }

    pub fn generate_row(&self) -> Vec<String> {
-       return vec![
+       vec![
            self.identifier.db.clone(),
            self.identifier.user.clone(),
            self.mode.to_string(),
@@ -124,7 +124,7 @@ impl PoolStats {
            self.sv_login.to_string(),
            (self.maxwait / 1_000_000).to_string(),
            (self.maxwait % 1_000_000).to_string(),
-       ];
+       ]
    }
}


@@ -49,6 +49,7 @@ pub struct ServerStats {
    pub error_count: Arc<AtomicU64>,
    pub prepared_hit_count: Arc<AtomicU64>,
    pub prepared_miss_count: Arc<AtomicU64>,
+   pub prepared_eviction_count: Arc<AtomicU64>,
    pub prepared_cache_size: Arc<AtomicU64>,
}
@@ -68,6 +69,7 @@ impl Default for ServerStats {
            reporter: get_reporter(),
            prepared_hit_count: Arc::new(AtomicU64::new(0)),
            prepared_miss_count: Arc::new(AtomicU64::new(0)),
+           prepared_eviction_count: Arc::new(AtomicU64::new(0)),
            prepared_cache_size: Arc::new(AtomicU64::new(0)),
        }
    }
@@ -221,6 +223,7 @@ impl ServerStats {
    }

    pub fn prepared_cache_remove(&self) {
+       self.prepared_eviction_count.fetch_add(1, Ordering::Relaxed);
        self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
    }
}


@@ -36,4 +36,4 @@ SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';

-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;


@@ -1,29 +1,214 @@
require_relative 'spec_helper'

describe 'Prepared statements' do
- let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
+ let(:pool_size) { 5 }
+ let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", pool_size) }
+ let(:prepared_statements_cache_size) { 100 }
+ let(:server_round_robin) { false }

- context 'enabled' do
-   it 'will work over the same connection' do
+ before do
+   new_configs = processes.pgcat.current_config
+   new_configs["general"]["server_round_robin"] = server_round_robin
+   new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size
+   new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = pool_size
+   processes.pgcat.update_config(new_configs)
+   processes.pgcat.reload_config
+ end
context 'when trying prepared statements' do
it 'allows unparameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT 1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1')
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2')
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3')
end
ensure
conn1.close if conn1
conn2.close if conn2
end
it 'allows parameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT $1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1', [1])
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2', [1])
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3', [1])
end
ensure
conn1.close if conn1
conn2.close if conn2
end
end
context 'when trying large packets' do
it "works with large parse" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT '#{long_string}'"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1')
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
it "works with large bind" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT $1::text"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1', [long_string])
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
end
context 'when statement cache is smaller than set of unique statements' do
let(:prepared_statements_cache_size) { 1 }
let(:pool_size) { 1 }
it "evicts all but 1 statement from the server cache" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when statement cache is larger than set of unique statements' do
let(:pool_size) { 1 }
it "does not evict any of the statements from the cache" do
# default cache size (100) is larger than the 5 unique statements
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 5)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(5)
end
end
context 'when preparing the same query' do
let(:prepared_statements_cache_size) { 5 }
let(:pool_size) { 5 }
it "reuses statement cache when there are different statement names on the same connection" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
- conn.describe_prepared(statement_name)
end
+ # Check number of prepared statements (expected: 1)
+ n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
+ expect(n_statements).to eq(1)
end
- it 'will work with new connections' do
-   10.times do
+ it "reuses statement cache when there are different statement names on different connections" do
+   10.times do |i|
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
- statement_name = 'statement1'
- conn.prepare('statement1', 'SELECT $1::int')
- conn.exec_prepared('statement1', [1])
- conn.describe_prepared('statement1')
+ statement_name = "statement_#{i}"
+ conn.prepare(statement_name, 'SELECT $1::int')
+ conn.exec_prepared(statement_name, [1])
end
+ # Check number of prepared statements (expected: 1)
+ conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
+ n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
+ expect(n_statements).to eq(1)
end
end
context 'when reloading config' do
let(:pool_size) { 1 }
it "test_reload_config" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
# prepare query
conn.prepare('statement1', 'SELECT 1')
conn.exec_prepared('statement1')
# Reload config which triggers pool recreation
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size + 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
# check that we're starting with no prepared statements on the server
conn_check = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn_check.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(0)
# still able to run prepared query
conn.exec_prepared('statement1')
end
end
end

tests/rust/Cargo.lock

@@ -1206,9 +1206,9 @@
 [[package]]
 name = "webpki"
-version = "0.22.0"
+version = "0.22.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
+checksum = "07ecc0cd7cac091bf682ec5efa18b1cff79d617b84181f38b3951dbe135f607f"
 dependencies = [
    "ring",
    "untrusted",