Compare commits

25 Commits

Author SHA1 Message Date
Lev
7a419f40ea Revert "Require a reason when marking a server bad (#654)"
This reverts commit 4dbef49ec9.
2023-12-04 19:59:53 -08:00
Lev
54c4ad140d Revert "Not sure how this sneaked past CI"
This reverts commit 4c5498b915.
2023-12-04 19:59:42 -08:00
Lev
190e32ae85 Revert "Reset wait times when checked out successfully (#656)"
This reverts commit ec3920d60f.
2023-12-04 19:59:34 -08:00
Lev Kokotov
ec3920d60f Reset wait times when checked out successfully (#656) 2023-12-04 18:33:08 -08:00
Lev
4c5498b915 Not sure how this sneaked past CI 2023-12-04 18:30:03 -08:00
Daniel Babiak
0e8064b049 only report wait times from clients currently waiting to match behavior of pgbouncer (#655)
* Change maxwait to only report wait times from clients currently waiting to match behavior of pgbouncer

* Fix tests
2023-12-04 18:19:51 -08:00
Alec
4dbef49ec9 Require a reason when marking a server bad (#654)
When calling mark_bad require a reason so it can be logged rather than
the generic message
2023-12-04 16:09:41 -08:00
Lev Kokotov
bc07dc9c81 Broken blog link 2 2023-12-03 21:01:23 -08:00
Lev Kokotov
9b8166b313 Broken blog link (#652)
Update README.md
2023-12-03 20:58:39 -08:00
Lev Kokotov
e58d69f3de Fix deb build overwriting config (#651) 2023-12-03 20:27:44 -08:00
Lev Kokotov
e76d720ffb Don't cache prepared statement with errors (#647)
* Fix prepared statement not found when prepared stmt has error

* cleanup debug

* remove more debug msgs

* sure debugged this..

* version bump

* add rust tests
2023-11-28 21:13:30 -08:00
Calvin Hughes
998cc16a3c Expose clients maxwait time in SHOW CLIENTS response via admin (#639)
* Expose clients maxwait time in SHOW CLIENTS response via PgCat admin
Adds maxwait_seconds and maxwait_us columns for each client, which help track down per-client wait time when the overall pool stats show waiting. As in the pool stats, maxwait_us is the microsecond remainder displayed alongside maxwait_seconds.

* Use maxwait instead of maxwait_seconds to match pools column name

---------

Co-authored-by: Calvin Hughes <9379992+calvinhughes@users.noreply.github.com>
2023-11-13 11:24:39 -08:00
Jakob Schultz-Falk
7c37da2fad Support unnamed prepared statements (#635)
* Add golang test suite to reproduce issue with unnamed parameterized prepared statements

* Allow caching of unnamed prepared statements

* Passthrough describe on portals

* Remove unneeded kill

* Update Dockerfile.ci with golang

* Move out update of Dockerfiles to separate PR
2023-11-08 16:36:45 -08:00
Jakob Schultz-Falk
b45c6b1d23 Update Dockerfile.ci with golang (#637) 2023-11-08 08:25:49 -08:00
Lev Kokotov
dae240d30c Add connect_timeout and idle_timeout to the user (#634)
* Add connect_timeout to the user

* Allow user to override connect timeout

* version

* lock

* Add both timeouts to the user
2023-11-06 12:18:52 -08:00
Lev Kokotov
b52ea8e7f1 bump version (#629) 2023-10-26 10:50:45 -07:00
Zain Kabani
7d3003a16a Reimplement prepared statements with LRU cache and statement deduplication (#618)
* Initial commit

* Cleanup and add stats

* Use an arc instead of full clones to store the parse packets

* Use mutex instead

* fmt

* clippy

* fmt

* fix?

* fix?

* fmt

* typo

* Update docs

* Refactor custom protocol

* fmt

* move custom protocol handling to before parsing

* Support describe

* Add LRU for server side statement cache

* rename variable

* Refactoring

* Move docs

* Fix test

* fix

* Update tests

* trigger build

* Add more tests

* Reorder handling sync

* Support when a named describe is sent along with Parse (go pgx) and expecting results

* don't talk to client if not needed when client sends Parse

* fmt :(

* refactor tests

* nit

* Reduce hashing

* Reducing work done to decode describe and parse messages

* minor refactor

* Merge branch 'main' into zain/reimplment-prepared-statements-with-global-lru-cache

* Rewrite extended and prepared protocol message handling to better support mocking response packets and close

* An attempt to better handle if there are DDL changes that might break cached plans with ideas about how to further improve it

* fix

* Minor stats fixed and cleanup

* Cosmetic fixes (#64)

* Cosmetic fixes

* fix test

* Change server drop for statement cache error to a `deallocate all`

* Updated comments and added new idea for handling DDL changes impacting cached plans

* fix test?

* Revert test change

* trigger build, flakey test

* Avoid potential race conditions by changing get_or_insert to promote for pool LRU

* remove ps enabled variable on the server in favor of using an option

* Add close to the Extended Protocol buffer

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
2023-10-25 15:11:57 -07:00
Zain Kabani
d37df43a90 Reduces the amount of time the get_pool operation takes (#625)
* Reduces the amount of time the get_pool operation takes

* trigger build

* Fix admin
2023-10-19 23:49:05 -07:00
Mohammad Dashti
2c7bf52c17 Removed unnecessary clippy overrides. (#614)
Removed unnecessary clippy overrides.
2023-10-11 10:13:23 -07:00
Mohammad Dashti
de8df29ca4 Added clippy to CI and fixed all clippy warnings (#613)
* Fixed all clippy warnings.

* Added `clippy` to CI.

* Reverted an unwanted change + Applied `cargo fmt`.

* Fixed the idiom version.

* Revert "Fixed the idiom version."

This reverts commit 6f78be0d42.

* Fixed clippy issues on CI.

* Revert "Fixed clippy issues on CI."

This reverts commit a9fa6ba189.

* Revert "Reverted an unwanted change + Applied `cargo fmt`."

This reverts commit 6bd37b6479.

* Revert "Fixed all clippy warnings."

This reverts commit d1f3b847e3.

* Removed Clippy

* Removed Lint

* `admin.rs` clippy fixes.

* Applied more clippy changes.

* Even more clippy changes.

* `client.rs` clippy fixes.

* `server.rs` clippy fixes.

* Revert "Removed Lint"

This reverts commit cb5042b144.

* Revert "Removed Clippy"

This reverts commit 6dec8bffb1.

* Applied lint.

* Revert "Revert "Fixed clippy issues on CI.""

This reverts commit 49164a733c.
2023-10-10 09:18:21 -07:00
Mohammad Dashti
c4fb72b9fc Added yj to dev Dockerfile (#612) 2023-10-05 18:13:22 -07:00
Mohammad Dashti
3371c01e0e Added a Plugin trait (#536)
* Improved logging

* Improved logging for more `Address` usages

* Fixed lint issues.

* Reverted the `Address` logging changes.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.

* Applied the PR comment by @levkk.
2023-10-03 13:13:21 -07:00
Mohammad Dashti
c2a483f36a Automatic sharding for INSERT, UPDATE, and DELETE statements. (#610)
Added support for INSERT, UPDATE, and DELETE for auto-sharding.
2023-10-03 09:36:13 -07:00
dependabot[bot]
51cd13b8b5 chore(deps): bump webpki from 0.22.0 to 0.22.2 in /tests/rust (#609)
Bumps [webpki](https://github.com/briansmith/webpki) from 0.22.0 to 0.22.2.
- [Commits](https://github.com/briansmith/webpki/commits)

---
updated-dependencies:
- dependency-name: webpki
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-10-02 15:30:42 -07:00
Nicolas Vanelslande
a054b454d2 Add psql to the container image. (#607)
It could be used to implement container health checks.
Example:
  PGPASSWORD="<some-password>" psql -U pgcat -p 6432 -h 127.0.0.1 -tA -c "show version;" -d pgcat >/dev/null
2023-09-27 09:03:39 -07:00
45 changed files with 2439 additions and 944 deletions

View File

@@ -63,6 +63,9 @@ jobs:
- run:
name: "Lint"
command: "cargo fmt --check"
- run:
name: "Clippy"
command: "cargo clippy --all --all-targets -- -Dwarnings"
- run:
name: "Tests"
command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh"

View File

@@ -108,8 +108,24 @@ cd ../..
pip3 install -r tests/python/requirements.txt
python3 tests/python/tests.py || exit 1
#
# Go tests
# Starts its own pgcat server
#
pushd tests/go
/usr/local/go/bin/go test || exit 1
popd
start_pgcat "info"
#
# Rust tests
#
cd tests/rust
cargo run
cd ../../
# Admin tests
export PGPASSWORD=admin_pass
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null

View File

@@ -4,7 +4,7 @@ on:
workflow_dispatch:
inputs:
packageVersion:
default: "1.1.2-dev"
default: "1.1.2-dev1"
jobs:
build:
strategy:

1
.gitignore vendored
View File

@@ -10,3 +10,4 @@ lcov.info
dev/.bash_history
dev/cache
!dev/cache/.keepme
.venv

View File

@@ -259,22 +259,6 @@ Password to be used for connecting to servers to obtain the hash used for md5 au
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### prepared_statements
```
path: general.prepared_statements
default: false
```
Whether to use prepared statements or not.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 500
```
Size of the prepared statements cache.
### dns_cache_enabled
```
path: general.dns_cache_enabled
@@ -324,6 +308,15 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 0
```
Size of the prepared statements cache. 0 means disabled.
TODO: update documentation
### query_parser_enabled
```
path: pools.<pool_name>.query_parser_enabled

View File

@@ -2,7 +2,7 @@
Thank you for contributing! Just a few tips here:
1. `cargo fmt` your code before opening up a PR
1. `cargo fmt` and `cargo clippy` your code before opening up a PR
2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`.
3. Performance is important, make sure there are no regressions in your branch vs. `main`.

33
Cargo.lock generated
View File

@@ -17,6 +17,17 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "1.0.2"
@@ -26,6 +37,12 @@ dependencies = [
"memchr",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
@@ -553,6 +570,10 @@ name = "hashbrown"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
dependencies = [
"ahash",
"allocator-api2",
]
[[package]]
name = "heck"
@@ -821,6 +842,15 @@ version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "lru"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60"
dependencies = [
"hashbrown 0.14.0",
]
[[package]]
name = "lru-cache"
version = "0.1.2"
@@ -990,7 +1020,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "pgcat"
version = "1.1.2-dev"
version = "1.1.2-dev4"
dependencies = [
"arc-swap",
"async-trait",
@@ -1008,6 +1038,7 @@ dependencies = [
"itertools",
"jemallocator",
"log",
"lru",
"md-5",
"nix",
"num_cpus",

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.1.2-dev"
version = "1.1.2-dev4"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -48,6 +48,7 @@ itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
lru = "0.12.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"
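
The `lru` dependency added above presumably backs the prepared statement cache reimplemented in #618 ("LRU cache and statement deduplication"). The pool and server code is not visible in this view, so the following is only a minimal standard-library sketch of the idea, not PgCat's implementation. `StatementCache`, its fields, and `get_or_insert` are hypothetical names, and hash collisions are deliberately ignored.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Hypothetical cache: deduplicates statements by a hash of their query text
/// and evicts the least recently used entry once `capacity` is reached.
struct StatementCache {
    capacity: usize,
    entries: HashMap<u64, Arc<String>>,
    order: VecDeque<u64>, // front = least recently used
}

impl StatementCache {
    fn new(capacity: usize) -> Self {
        StatementCache {
            capacity,
            entries: HashMap::new(),
            order: VecDeque::new(),
        }
    }

    /// Hashes the query text; identical text always maps to the same key.
    fn key(query: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        query.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the shared statement for `query`, inserting it on a miss.
    /// With a capacity of 0 nothing is stored (assumed to mean "disabled").
    fn get_or_insert(&mut self, query: &str) -> Arc<String> {
        let key = Self::key(query);
        if let Some(existing) = self.entries.get(&key).cloned() {
            // Hit: move the key to the most recently used end.
            self.order.retain(|k| *k != key);
            self.order.push_back(key);
            return existing;
        }
        let statement = Arc::new(query.to_string());
        if self.capacity == 0 {
            return statement;
        }
        if self.entries.len() >= self.capacity {
            // Full: drop the least recently used entry first.
            if let Some(oldest) = self.order.pop_front() {
                self.entries.remove(&oldest);
            }
        }
        self.entries.insert(key, Arc::clone(&statement));
        self.order.push_back(key);
        statement
    }

    fn len(&self) -> usize {
        self.entries.len()
    }
}
```

Two clients sending the same query text get the same `Arc`, so the text is stored once. The linear `retain` on every hit keeps the sketch short; a real LRU would use a linked structure to make promotion constant-time.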

View File

@@ -8,6 +8,12 @@ WORKDIR /app
RUN cargo build --release
FROM debian:bookworm-slim
RUN apt-get update && apt-get install -o Dpkg::Options::=--force-confdef -yq --no-install-recommends \
postgresql-client \
# Clean up layer
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& truncate -s 0 /var/log/*log
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
WORKDIR /etc/pgcat

View File

@@ -9,6 +9,9 @@ RUN sudo apt-get update && \
sudo apt-get upgrade curl && \
cargo install cargo-binutils rustfilt && \
rustup component add llvm-tools-preview && \
pip3 install psycopg2 && sudo gem install bundler && \
pip3 install psycopg2 && sudo gem install bundler && \
wget -O /tmp/toxiproxy-2.4.0.deb https://github.com/Shopify/toxiproxy/releases/download/v2.4.0/toxiproxy_2.4.0_linux_$(dpkg --print-architecture).deb && \
sudo dpkg -i /tmp/toxiproxy-2.4.0.deb
RUN wget -O /tmp/go1.21.3.linux-$(dpkg --print-architecture).tar.gz https://go.dev/dl/go1.21.3.linux-$(dpkg --print-architecture).tar.gz && \
sudo tar -C /usr/local -xzf /tmp/go1.21.3.linux-$(dpkg --print-architecture).tar.gz && \
rm /tmp/go1.21.3.linux-$(dpkg --print-architecture).tar.gz

View File

@@ -40,7 +40,7 @@ PgCat is stable and used in production to serve hundreds of thousands of queries
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
<a href="https://postgresml.org/blog/scaling-postgresml-to-1-million-requests-per-second">
<img src="./images/postgresml.webp" height="70" width="auto">
</a>
</td>
@@ -57,7 +57,7 @@ PgCat is stable and used in production to serve hundreds of thousands of queries
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
<a href="https://postgresml.org/blog/scaling-postgresml-to-1-million-requests-per-second">
PostgresML
</a>
</td>

View File

@@ -1,6 +1,8 @@
FROM rust:1.70-bullseye
# Dependencies
COPY --from=sclevine/yj /bin/yj /bin/yj
RUN /bin/yj -h
RUN apt-get update -y \
&& apt-get install -y \
llvm-11 psmisc postgresql-contrib postgresql-client \

View File

@@ -60,12 +60,6 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5
# Handle prepared statements.
prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
@@ -156,6 +150,10 @@ load_balancing_mode = "random"
# `primary` all queries go to the primary unless otherwise specified.
default_role = "any"
# Prepared statements cache size.
# TODO: update documentation
prepared_statements_cache_size = 500
# If Query Parser is enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
@@ -303,6 +301,8 @@ username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 15000
connect_timeout = 1000
idle_timeout = 1000
# Shard configs are structured as pool.<pool_name>.shards.<shard_id>
# Each shard config contains a list of servers that make up the shard

View File

@@ -283,7 +283,7 @@ where
{
let mut res = BytesMut::new();
let detail_msg = vec![
let detail_msg = [
"",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
@@ -301,7 +301,6 @@ where
// "KILL <db>",
// "SUSPEND",
"SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
];
res.put(notify("Console usage", detail_msg.join("\n\t")));
@@ -691,6 +690,8 @@ where
("query_count", DataType::Numeric),
("error_count", DataType::Numeric),
("age_seconds", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
];
let new_map = get_client_stats();
@@ -698,6 +699,7 @@ where
res.put(row_description(&columns));
for (_, client) in new_map {
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
let row = vec![
format!("{:#010X}", client.client_id()),
client.pool_name(),
@@ -711,6 +713,8 @@ where
.duration_since(client.connect_time())
.as_secs()
.to_string(),
(max_wait / 1_000_000).to_string(),
(max_wait % 1_000_000).to_string(),
];
res.put(data_row(&row));
@@ -745,6 +749,7 @@ where
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_eviction", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];
@@ -777,6 +782,10 @@ where
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_eviction_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
@@ -802,7 +811,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(),
true => tokens[1].split(',').map(|part| part.trim()).collect(),
false => Vec::new(),
};
@@ -865,7 +874,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(),
true => tokens[1].split(',').map(|part| part.trim()).collect(),
false => Vec::new(),
};
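
The `SHOW CLIENTS` hunk above reports a client's maximum wait as two columns: whole seconds in `maxwait` and the leftover microseconds in `maxwait_us`. A minimal sketch of that split, assuming the stored value is a microsecond count; the helper names `split_max_wait` and `max_wait_columns` are hypothetical:

```rust
/// Splits a wait time in microseconds into whole seconds and the
/// leftover microseconds, mirroring the `maxwait` / `maxwait_us` pair.
fn split_max_wait(max_wait_us: u64) -> (u64, u64) {
    (max_wait_us / 1_000_000, max_wait_us % 1_000_000)
}

/// Formats both values as strings, the way an admin row would carry them.
fn max_wait_columns(max_wait_us: u64) -> [String; 2] {
    let (secs, rem) = split_max_wait(max_wait_us);
    [secs.to_string(), rem.to_string()]
}
```

For example, a wait of 2,500,000 µs is reported as `maxwait = 2` and `maxwait_us = 500000`; the total is recovered as `maxwait * 1_000_000 + maxwait_us`.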

View File

@@ -79,6 +79,8 @@ impl AuthPassthrough {
pool_mode: None,
server_lifetime: None,
min_pool_size: None,
connect_timeout: None,
idle_timeout: None,
};
let user = &address.username;

File diff suppressed because it is too large

View File

@@ -25,7 +25,7 @@ pub struct Args {
}
pub fn parse() -> Args {
return Args::parse();
Args::parse()
}
#[derive(ValueEnum, Clone, Debug)]

View File

@@ -1,6 +1,6 @@
/// Parse the configuration file.
use arc_swap::ArcSwap;
use log::{error, info, warn};
use log::{error, info};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserializer, Serializer};
@@ -116,10 +116,10 @@ impl Default for Address {
host: String::from("127.0.0.1"),
port: 5432,
shard: 0,
address_index: 0,
replica_number: 0,
database: String::from("database"),
role: Role::Replica,
replica_number: 0,
address_index: 0,
username: String::from("username"),
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
@@ -216,6 +216,8 @@ pub struct User {
pub server_lifetime: Option<u64>,
#[serde(default)] // 0
pub statement_timeout: u64,
pub connect_timeout: Option<u64>,
pub idle_timeout: Option<u64>,
}
impl Default for User {
@@ -230,24 +232,22 @@ impl Default for User {
statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
connect_timeout: None,
idle_timeout: None,
}
}
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
if let Some(min_pool_size) = self.min_pool_size {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
None => (),
};
Ok(())
@@ -341,12 +341,6 @@ pub struct General {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default)]
pub prepared_statements: bool,
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
}
impl General {
@@ -428,10 +422,6 @@ impl General {
pub fn default_server_round_robin() -> bool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
500
}
}
impl Default for General {
@@ -443,35 +433,33 @@ impl Default for General {
prometheus_exporter_port: 9930,
connect_timeout: General::default_connect_timeout(),
idle_timeout: General::default_idle_timeout(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
worker_threads: Self::default_worker_threads(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
log_client_connections: false,
log_client_disconnections: false,
autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
worker_threads: Self::default_worker_threads(),
autoreload: None,
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
validate_config: true,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
}
}
}
@@ -572,6 +560,9 @@ pub struct Pool {
#[serde(default)] // False
pub log_client_parameter_status_changes: bool,
#[serde(default = "Pool::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
@@ -621,6 +612,10 @@ impl Pool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
0
}
pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() {
"any" => (),
@@ -677,9 +672,9 @@ impl Pool {
Some(key) => {
// No quotes in the key so we don't have to compare quoted
// to unquoted idents.
let key = key.replace("\"", "");
let key = key.replace('\"', "");
if key.split(".").count() != 2 {
if key.split('.').count() != 2 {
error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key
@@ -692,17 +687,14 @@ impl Pool {
None => None,
};
match self.default_shard {
DefaultShard::Shard(shard_number) => {
if shard_number >= self.shards.len() {
error!("Invalid shard {:?}", shard_number);
return Err(Error::BadConfig);
}
if let DefaultShard::Shard(shard_number) = self.default_shard {
if shard_number >= self.shards.len() {
error!("Invalid shard {:?}", shard_number);
return Err(Error::BadConfig);
}
_ => (),
}
for (_, user) in &self.users {
for user in self.users.values() {
user.validate()?;
}
@@ -715,17 +707,16 @@ impl Default for Pool {
Pool {
pool_mode: Self::default_pool_mode(),
load_balancing_mode: Self::default_load_balancing_mode(),
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
connect_timeout: None,
idle_timeout: None,
server_lifetime: None,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
@@ -733,10 +724,12 @@ impl Default for Pool {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
plugins: None,
cleanup_server_connections: true,
log_client_parameter_status_changes: false,
prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
plugins: None,
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
}
}
}
@@ -777,8 +770,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard {
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if s.starts_with("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
if let Some(s) = s.strip_prefix("shard_") {
let shard = s.parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(DefaultShard::Shard(shard));
}
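
The hunk above replaces `starts_with` plus a hard-coded slice index (`s[6..]`) with `str::strip_prefix`, which checks for the prefix and returns the remainder in one step. A small standalone sketch of the same pattern; `parse_shard` is a hypothetical helper that returns an `Option` rather than the serde error used in the diff:

```rust
/// Hypothetical helper: parses "shard_<n>" into `n`.
/// Returns `None` if the prefix is missing or the rest is not a number.
fn parse_shard(s: &str) -> Option<usize> {
    s.strip_prefix("shard_")?.parse::<usize>().ok()
}
```

Because the remainder comes from `strip_prefix` itself, the prefix length never has to be kept in sync with a separate magic number.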
@@ -848,13 +841,13 @@ impl Shard {
impl Default for Shard {
fn default() -> Shard {
Shard {
database: String::from("postgres"),
mirrors: None,
servers: vec![ServerConfig {
host: String::from("localhost"),
port: 5432,
role: Role::Primary,
}],
mirrors: None,
database: String::from("postgres"),
}
}
}
@@ -867,15 +860,26 @@ pub struct Plugins {
pub prewarmer: Option<Prewarmer>,
}
pub trait Plugin {
fn is_enabled(&self) -> bool;
}
impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool {
if let Some(arg) = arg {
arg.is_enabled()
} else {
false
}
}
write!(
f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
self.intercept.is_some(),
self.table_access.is_some(),
self.query_logger.is_some(),
self.prewarmer.is_some(),
is_enabled(self.intercept.as_ref()),
is_enabled(self.table_access.as_ref()),
is_enabled(self.query_logger.as_ref()),
is_enabled(self.prewarmer.as_ref()),
)
}
}
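
With the `Plugin` trait from the hunk above, the summary line reports a plugin as enabled only when its section is present and its `enabled` flag is true, instead of merely checking `is_some()`. A minimal standalone sketch of that behavior; the `QueryLogger` here is a stripped-down stand-in for illustration, and `map_or` is used in place of the diff's `if let` with the same result:

```rust
/// Same shape as the `Plugin` trait in the diff.
trait Plugin {
    fn is_enabled(&self) -> bool;
}

/// Stripped-down stand-in for a plugin config section.
struct QueryLogger {
    enabled: bool,
}

impl Plugin for QueryLogger {
    fn is_enabled(&self) -> bool {
        self.enabled
    }
}

/// A plugin counts as enabled only if it is configured and switched on.
fn is_enabled<T: Plugin>(plugin: Option<&T>) -> bool {
    plugin.map_or(false, |p| p.is_enabled())
}
```

A section that exists with `enabled = false` now shows as `false`, whereas the earlier `is_some()` check would have reported it as `true`.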
@@ -886,23 +890,47 @@ pub struct Intercept {
pub queries: BTreeMap<String, Query>,
}
impl Plugin for Intercept {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
impl Plugin for TableAccess {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct QueryLogger {
pub enabled: bool,
}
impl Plugin for QueryLogger {
fn is_enabled(&self) -> bool {
self.enabled
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
}
impl Plugin for Prewarmer {
fn is_enabled(&self) -> bool {
self.enabled
}
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
@@ -920,6 +948,7 @@ pub struct Query {
}
impl Query {
#[allow(clippy::needless_range_loop)]
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
@@ -989,8 +1018,8 @@ impl Default for Config {
Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
pools: HashMap::default(),
}
}
}
@@ -1044,8 +1073,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
(
format!("pools.{:?}.users", pool_name),
pool.users
.iter()
.map(|(_username, user)| &user.username)
.values()
.map(|user| &user.username)
.cloned()
.collect::<Vec<String>>()
.join(", "),
@@ -1099,6 +1128,7 @@ impl From<&Config> for std::collections::HashMap<String, String> {
impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Config path: {}", self.path);
info!("Ban time: {}s", self.general.ban_time);
info!(
"Idle client in transaction timeout: {}ms",
@@ -1130,13 +1160,9 @@ impl Config {
Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate);
match self.general.tls_private_key.clone() {
Some(tls_private_key) => {
info!("TLS private key: {}", tls_private_key);
info!("TLS support is enabled");
}
None => (),
if let Some(tls_private_key) = self.general.tls_private_key.clone() {
info!("TLS private key: {}", tls_private_key);
info!("TLS support is enabled");
}
}
@@ -1149,13 +1175,6 @@ impl Config {
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!(
"Plugins: {}",
match self.plugins {
@@ -1171,8 +1190,8 @@ impl Config {
pool_name,
pool_config
.users
.iter()
.map(|(_, user_cfg)| user_cfg.pool_size)
.values()
.map(|user_cfg| user_cfg.pool_size)
.sum::<u32>()
.to_string()
);
@@ -1246,6 +1265,10 @@ impl Config {
"[pool: {}] Log client parameter status changes: {}",
pool_name, pool_config.log_client_parameter_status_changes
);
info!(
"[pool: {}] Prepared statements server cache size: {}",
pool_name, pool_config.prepared_statements_cache_size
);
info!(
"[pool: {}] Plugins: {}",
pool_name,
@@ -1288,6 +1311,24 @@ impl Config {
None => "default".to_string(),
}
);
info!(
"[pool: {}][user: {}] Connection timeout: {}",
pool_name,
user.1.username,
match user.1.connect_timeout {
Some(connect_timeout) => format!("{}ms", connect_timeout),
None => "not set".to_string(),
}
);
info!(
"[pool: {}][user: {}] Idle timeout: {}",
pool_name,
user.1.username,
match user.1.idle_timeout {
Some(idle_timeout) => format!("{}ms", idle_timeout),
None => "not set".to_string(),
}
);
}
}
}
@@ -1342,42 +1383,31 @@ impl Config {
}
// Validate TLS!
match self.general.tls_certificate {
Some(ref mut tls_certificate) => {
match load_certs(Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match self.general.tls_private_key {
Some(ref tls_private_key) => {
match load_keys(Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
warn!(
"tls_private_key is incorrectly configured: {:?}",
err
);
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
}
if let Some(tls_certificate) = self.general.tls_certificate.clone() {
match load_certs(Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match self.general.tls_private_key.clone() {
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
error!("tls_private_key is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
},
None => {
warn!("tls_certificate is set, but the tls_private_key is not");
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
};
}
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig);
}
};
}
Err(err) => {
warn!("tls_certificate is incorrectly configured: {:?}", err);
self.general.tls_private_key = None;
self.general.tls_certificate = None;
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
}
None => (),
};
for pool in self.pools.values_mut() {
@@ -1399,14 +1429,6 @@ pub fn get_idle_client_in_transaction_timeout() -> u64 {
CONFIG.load().general.idle_client_in_transaction_timeout
}
pub fn get_prepared_statements() -> bool {
CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}
/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();

View File

@@ -29,6 +29,7 @@ pub enum Error {
QueryRouterParserError(String),
QueryRouterError(String),
InvalidShardId(usize),
PreparedStatementError,
}
#[derive(Clone, PartialEq, Debug)]

View File

@@ -12,13 +12,16 @@ use crate::config::get_config;
use crate::errors::Error;
use crate::constants::MESSAGE_TERMINATOR;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::io::{BufRead, Cursor};
use std::mem;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
/// Postgres data type mappings
@@ -114,19 +117,11 @@ pub fn simple_query(query: &str) -> BytesMut {
}
/// Tell the client we're ready for another query.
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
write_all(stream, bytes).await
write_all(stream, ready_for_query(false)).await
}
/// Send the startup packet to the server. We're pretending we're a Pg client.
@@ -163,12 +158,10 @@ where
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}",
err
)))
}
Err(err) => Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}",
err
))),
}
}
@@ -244,8 +237,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// First pass
md5.update(&password.as_bytes());
md5.update(&user.as_bytes());
md5.update(password.as_bytes());
md5.update(user.as_bytes());
let output = md5.finalize_reset();
@@ -281,7 +274,7 @@ where
{
let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
@@ -295,7 +288,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
@@ -322,7 +315,7 @@ where
res.put_slice(&set_complete[..]);
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -333,7 +326,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
/// Send a custom error message to the client.
@@ -434,7 +427,7 @@ where
res.put(command_complete("SELECT 1"));
write_all_half(stream, &res).await?;
ready_for_query(stream).await
send_ready_for_query(stream).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
@@ -516,7 +509,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1 as i32);
data_row.put_i32(-1_i32);
}
}
@@ -564,6 +557,37 @@ pub fn flush() -> BytesMut {
bytes
}
pub fn sync() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'S');
bytes.put_i32(4);
bytes
}
pub fn parse_complete() -> BytesMut {
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
bytes.put_u8(b'1');
bytes.put_i32(4);
bytes
}
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
if in_transaction {
bytes.put_u8(b'T');
} else {
bytes.put_u8(b'I');
}
bytes
}
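The ReadyForQuery builder above can be sketched without the `bytes` crate. This is a minimal illustration only: it uses a plain `Vec<u8>` where the diff uses `BytesMut`, and the name `ready_for_query_bytes` is hypothetical.

```rust
/// Builds a ReadyForQuery ('Z') message: a 1-byte tag, a 4-byte big-endian
/// length (which counts itself plus the status byte), and a 1-byte status.
/// Hypothetical std-only stand-in for the `BytesMut` version in the diff.
fn ready_for_query_bytes(in_transaction: bool) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(6);
    bytes.push(b'Z');
    // The length field is 4 (itself) + 1 (status byte) = 5.
    bytes.extend_from_slice(&5_i32.to_be_bytes());
    // 'T' = inside a transaction block, 'I' = idle.
    bytes.push(if in_transaction { b'T' } else { b'I' });
    bytes
}
```

Returning the buffer instead of writing it directly lets callers batch it with other messages before a single socket write, which appears to be the motivation for splitting `send_ready_for_query` from the builder.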
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -571,12 +595,10 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
}
}
@@ -587,12 +609,10 @@ where
{
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
}
}
@@ -603,19 +623,15 @@ where
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
Err(err) => Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
))),
},
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
}
}
@@ -730,7 +746,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
let mut buf = vec![];
match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
Err(err) => Err(Error::ParseBytesError(err.to_string())),
}
}
}
@@ -746,10 +762,55 @@ impl BytesMutReader for BytesMut {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
None => Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
pub enum ExtendedProtocolData {
Parse {
data: BytesMut,
metadata: Option<(Arc<Parse>, u64)>,
},
Bind {
data: BytesMut,
metadata: Option<String>,
},
Describe {
data: BytesMut,
metadata: Option<String>,
},
Execute {
data: BytesMut,
},
Close {
data: BytesMut,
close: Close,
},
}
impl ExtendedProtocolData {
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
Self::Parse { data, metadata }
}
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
Self::Bind { data, metadata }
}
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
Self::Describe { data, metadata }
}
pub fn create_new_execute(data: BytesMut) -> Self {
Self::Execute { data }
}
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
Self::Close { data, close }
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
@@ -758,7 +819,6 @@ pub struct Parse {
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
@@ -784,7 +844,6 @@ impl TryFrom<&BytesMut> for Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
@@ -833,11 +892,44 @@ impl TryFrom<&Parse> for BytesMut {
}
impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
/// Renames the prepared statement to a new name based on the global counter
pub fn rewrite(mut self) -> Self {
self.name = format!(
"PGCAT_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
);
self
}
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()
}
/// Hashes the parse statement to be used as a key in the global cache
pub fn get_hash(&self) -> u64 {
// TODO_ZAIN: Take a look at which hashing function is being used
let mut hasher = DefaultHasher::new();
let concatenated = format!(
"{}{}{}",
self.query,
self.num_params,
self.param_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
concatenated.hash(&mut hasher);
hasher.finish()
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
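The cache key computed by `get_hash` can be illustrated in isolation. The sketch below assumes a free function (`statement_hash`, a hypothetical name) that takes the query text and parameter metadata directly rather than a `Parse` value; it mirrors the concatenate-then-hash approach shown above.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hashes the parts of a Parse message that identify a statement:
/// query text, parameter count, and parameter type OIDs.
/// Hypothetical helper; the diff implements this as a method on `Parse`.
fn statement_hash(query: &str, num_params: i16, param_types: &[i32]) -> u64 {
    let mut hasher = DefaultHasher::new();
    let joined = param_types
        .iter()
        .map(ToString::to_string)
        .collect::<Vec<_>>()
        .join(",");
    format!("{}{}{}", query, num_params, joined).hash(&mut hasher);
    hasher.finish()
}
```

Two clients that prepare the same query with the same parameter types get the same key, so they can share one rewritten statement. `DefaultHasher`'s algorithm is not guaranteed to stay the same across Rust releases, which is acceptable for an in-process cache but would matter if the hash were ever persisted.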
@@ -968,9 +1060,42 @@ impl TryFrom<Bind> for BytesMut {
}
impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
/// Gets the name of the prepared statement from the buffer
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.read_string()?;
cursor.read_string()
}
/// Renames the prepared statement to a new name
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor
let code = cursor.get_u8();
let current_len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
// Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
// Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code);
response_buf.put_i32(new_len);
// Put the portal and new name into the buffer
// Note: returns an error if the provided string contains a null byte
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
// Add the remainder of the original buffer into the response
response_buf.put_slice(&buf[cursor.position() as usize..]);
// Return the buffer
Ok(response_buf)
}
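The length arithmetic in `Bind::rename` is easier to follow on raw bytes. The following is a std-only sketch under stated assumptions: it works on `&[u8]`/`Vec<u8>` instead of `BytesMut`, returns `Option` instead of the crate's `Error`, and the helper names `read_cstr` and `rename_bind` are hypothetical.

```rust
/// Reads a NUL-terminated string starting at `start`.
/// Returns the bytes (without the NUL) and the position just past the NUL.
fn read_cstr(buf: &[u8], start: usize) -> Option<(&[u8], usize)> {
    let rest = buf.get(start..)?;
    let end = rest.iter().position(|&b| b == 0)?;
    Some((&rest[..end], start + end + 1))
}

/// Rewrites the prepared statement name inside a Bind ('B') message.
/// Hypothetical stand-in for `Bind::rename` in the diff.
fn rename_bind(buf: &[u8], new_name: &str) -> Option<Vec<u8>> {
    let code = *buf.first()?;
    let len_bytes: [u8; 4] = buf.get(1..5)?.try_into().ok()?;
    let current_len = i32::from_be_bytes(len_bytes);
    let (portal, pos) = read_cstr(buf, 5)?;
    let (old_name, pos) = read_cstr(buf, pos)?;
    // A name with an embedded NUL would corrupt the message framing.
    if new_name.as_bytes().contains(&0) {
        return None;
    }
    // Only the statement name changes size, so adjust the length by the delta.
    let new_len = current_len + new_name.len() as i32 - old_name.len() as i32;
    let mut out = Vec::with_capacity(buf.len() + new_name.len());
    out.push(code);
    out.extend_from_slice(&new_len.to_be_bytes());
    out.extend_from_slice(portal);
    out.push(0);
    out.extend_from_slice(new_name.as_bytes());
    out.push(0);
    // Everything after the two strings (format codes, parameters) is copied as-is.
    out.extend_from_slice(&buf[pos..]);
    Some(out)
}
```

Because the remainder of the message is copied verbatim, the rename never has to parse parameter values, which keeps the hot path cheap.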
pub fn anonymous(&self) -> bool {
@@ -984,7 +1109,7 @@ pub struct Describe {
#[allow(dead_code)]
len: i32,
target: char,
pub target: char,
pub statement_name: String,
}
@@ -1026,6 +1151,15 @@ impl TryFrom<Describe> for BytesMut {
}
impl Describe {
pub fn empty_new() -> Describe {
Describe {
code: 'D',
len: 4 + 1 + 1,
target: 'S',
statement_name: "".to_string(),
}
}
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
@@ -1114,13 +1248,6 @@ pub fn close_complete() -> BytesMut {
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
@@ -1203,7 +1330,7 @@ impl Display for PgErrorMsg {
}
impl PgErrorMsg {
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
let mut out = PgErrorMsg {
severity_localized: "".to_string(),
severity: "".to_string(),
@@ -1311,38 +1438,38 @@ mod tests {
fn parse_fields() {
let mut complete_msg = vec![];
let severity = "FATAL";
complete_msg.extend(field('S', &severity));
complete_msg.extend(field('V', &severity));
complete_msg.extend(field('S', severity));
complete_msg.extend(field('V', severity));
let error_code = "29P02";
complete_msg.extend(field('C', &error_code));
complete_msg.extend(field('C', error_code));
let message = "password authentication failed for user \"wrong_user\"";
complete_msg.extend(field('M', &message));
complete_msg.extend(field('M', message));
let detail_msg = "super detailed message";
complete_msg.extend(field('D', &detail_msg));
complete_msg.extend(field('D', detail_msg));
let hint_msg = "hint detail here";
complete_msg.extend(field('H', &hint_msg));
complete_msg.extend(field('H', hint_msg));
complete_msg.extend(field('P', "123"));
complete_msg.extend(field('p', "234"));
let internal_query = "SELECT * from foo;";
complete_msg.extend(field('q', &internal_query));
complete_msg.extend(field('q', internal_query));
let where_msg = "where goes here";
complete_msg.extend(field('W', &where_msg));
complete_msg.extend(field('W', where_msg));
let schema_msg = "schema_name";
complete_msg.extend(field('s', &schema_msg));
complete_msg.extend(field('s', schema_msg));
let table_msg = "table_name";
complete_msg.extend(field('t', &table_msg));
complete_msg.extend(field('t', table_msg));
let column_msg = "column_name";
complete_msg.extend(field('c', &column_msg));
complete_msg.extend(field('c', column_msg));
let data_type_msg = "type_name";
complete_msg.extend(field('d', &data_type_msg));
complete_msg.extend(field('d', data_type_msg));
let constraint_msg = "constraint_name";
complete_msg.extend(field('n', &constraint_msg));
complete_msg.extend(field('n', constraint_msg));
let file_msg = "pgcat.c";
complete_msg.extend(field('F', &file_msg));
complete_msg.extend(field('F', file_msg));
complete_msg.extend(field('L', "335"));
let routine_msg = "my_failing_routine";
complete_msg.extend(field('R', &routine_msg));
complete_msg.extend(field('R', routine_msg));
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
@@ -1351,7 +1478,7 @@ mod tests {
info!(
"full message: {}",
PgErrorMsg::parse(complete_msg.clone()).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
assert_eq!(
PgErrorMsg {
@@ -1374,17 +1501,17 @@ mod tests {
line: Some(335),
routine: Some(routine_msg.to_string()),
},
PgErrorMsg::parse(complete_msg).unwrap()
PgErrorMsg::parse(&complete_msg).unwrap()
);
let mut only_mandatory_msg = vec![];
only_mandatory_msg.extend(field('S', &severity));
only_mandatory_msg.extend(field('V', &severity));
only_mandatory_msg.extend(field('C', &error_code));
only_mandatory_msg.extend(field('M', &message));
only_mandatory_msg.extend(field('D', &detail_msg));
only_mandatory_msg.extend(field('S', severity));
only_mandatory_msg.extend(field('V', severity));
only_mandatory_msg.extend(field('C', error_code));
only_mandatory_msg.extend(field('M', message));
only_mandatory_msg.extend(field('D', detail_msg));
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
info!("only mandatory fields: {}", &err_fields);
error!(
"server error: {}: {}",
@@ -1411,7 +1538,7 @@ mod tests {
line: None,
routine: None,
},
PgErrorMsg::parse(only_mandatory_msg).unwrap()
PgErrorMsg::parse(&only_mandatory_msg).unwrap()
);
}
}


@@ -23,14 +23,15 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout, _cfg) =
let (connection_timeout, idle_timeout, _cfg, prepared_statement_cache_size) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
cfg.clone(),
cfg.prepared_statements_cache_size,
),
None => (default, default, crate::config::Pool::default()),
None => (default, default, crate::config::Pool::default(), 0),
};
let manager = ServerPool::new(
@@ -42,6 +43,7 @@ impl MirroredClient {
None,
true,
false,
prepared_statement_cache_size,
);
Pool::builder()
@@ -137,18 +139,18 @@ impl MirroringManager {
bytes_rx,
disconnect_rx: exit_rx,
};
exit_senders.push(exit_tx.clone());
byte_senders.push(bytes_tx.clone());
exit_senders.push(exit_tx);
byte_senders.push(bytes_tx);
client.start();
});
Self {
byte_senders: byte_senders,
byte_senders,
disconnect_senders: exit_senders,
}
}
pub fn send(self: &mut Self, bytes: &BytesMut) {
pub fn send(&mut self, bytes: &BytesMut) {
// We want to avoid performing an allocation if we won't be able to send the message
// There is a possibility of a race here where we check the capacity and then the channel is
// closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -170,7 +172,7 @@ impl MirroringManager {
});
}
pub fn disconnect(self: &mut Self) {
pub fn disconnect(&mut self) {
self.disconnect_senders
.iter_mut()
.for_each(|sender| match sender.try_send(()) {


@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
.map(|s| {
let s = s.as_str().to_string();
if s == "" {
if s.is_empty() {
None
} else {
Some(s)


@@ -33,6 +33,7 @@ pub enum PluginOutput {
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
#[allow(clippy::ptr_arg)]
async fn run(
&mut self,
query_router: &QueryRouter,


@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
self.server.address(),
query
);
self.server.query(&query).await?;
self.server.query(query).await?;
}
Ok(())


@@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> {
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let parts = relation.split('.').collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {


@@ -3,6 +3,7 @@ use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
@@ -10,6 +11,7 @@ use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::num::NonZeroUsize;
use std::sync::atomic::AtomicU64;
use std::sync::{
atomic::{AtomicBool, Ordering},
@@ -24,6 +26,7 @@ use crate::config::{
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::messages::Parse;
use crate::plugins::prewarmer;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
@@ -54,6 +57,57 @@ pub enum BanReason {
AdminBan(i64),
}
pub type PreparedStatementCacheType = Arc<Mutex<PreparedStatementCache>>;
// TODO: Add stats to this cache
// TODO: Add application name to the cache value to help identify which application is using the cache
// TODO: Create admin command to show which statements are in the cache
#[derive(Debug)]
pub struct PreparedStatementCache {
cache: LruCache<u64, Arc<Parse>>,
}
impl PreparedStatementCache {
pub fn new(mut size: usize) -> Self {
// Size cannot be zero
if size == 0 {
size = 1;
}
PreparedStatementCache {
cache: LruCache::new(NonZeroUsize::new(size).unwrap()),
}
}
/// Adds the prepared statement to the cache under a newly generated name if it is not
/// already cached; otherwise returns the existing rewritten parse.
///
/// Pass the hash to this so that we can do the compute before acquiring the lock
pub fn get_or_insert(&mut self, parse: &Parse, hash: u64) -> Arc<Parse> {
match self.cache.get(&hash) {
Some(rewritten_parse) => rewritten_parse.clone(),
None => {
let new_parse = Arc::new(parse.clone().rewrite());
let evicted = self.cache.push(hash, new_parse.clone());
if let Some((_, evicted_parse)) = evicted {
debug!(
"Evicted prepared statement {} from cache",
evicted_parse.name
);
}
new_parse
}
}
}
/// Marks the hash as most recently used if it exists
pub fn promote(&mut self, hash: &u64) {
self.cache.promote(hash);
}
}
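The get-or-insert behavior of the cache can be sketched with the standard library alone. This is not the `lru` crate used in the diff: `TinyLru` is a hypothetical, deliberately simple stand-in (promotion here is O(n)) that stores a `String` name where the real cache stores an `Arc<Parse>`.

```rust
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

/// Hypothetical minimal LRU keyed by statement hash.
struct TinyLru {
    capacity: usize,
    order: VecDeque<u64>, // front = least recently used
    entries: HashMap<u64, Arc<String>>,
}

impl TinyLru {
    fn new(capacity: usize) -> Self {
        // Mirror the diff: a capacity of zero is bumped to one.
        TinyLru {
            capacity: capacity.max(1),
            order: VecDeque::new(),
            entries: HashMap::new(),
        }
    }

    /// Marks the hash as most recently used if it exists.
    fn promote(&mut self, hash: u64) {
        if self.entries.contains_key(&hash) {
            self.order.retain(|h| *h != hash);
            self.order.push_back(hash);
        }
    }

    /// Returns the cached value, or builds one, evicting the oldest entry if full.
    fn get_or_insert(&mut self, hash: u64, make: impl FnOnce() -> String) -> Arc<String> {
        if let Some(existing) = self.entries.get(&hash).cloned() {
            self.promote(hash);
            return existing;
        }
        if self.entries.len() >= self.capacity {
            if let Some(evicted) = self.order.pop_front() {
                self.entries.remove(&evicted);
            }
        }
        let value = Arc::new(make());
        self.entries.insert(hash, Arc::clone(&value));
        self.order.push_back(hash);
        value
    }

    fn contains(&self, hash: u64) -> bool {
        self.entries.contains_key(&hash)
    }
}
```

Taking the hash as an argument, as the diff does, means the caller can compute it before acquiring the mutex around the cache, so the lock is held only for the lookup and insert.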
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
@@ -190,11 +244,11 @@ impl Default for PoolSettings {
#[derive(Clone, Debug, Default)]
pub struct ConnectionPool {
/// The pools handled internally by bb8.
databases: Vec<Vec<Pool<ServerPool>>>,
databases: Arc<Vec<Vec<Pool<ServerPool>>>>,
/// The addresses (host, port, role) to handle
/// failover and load balancing deterministically.
addresses: Vec<Vec<Address>>,
addresses: Arc<Vec<Vec<Address>>>,
/// List of banned addresses (see above)
/// that should not be queried.
@@ -206,7 +260,7 @@ pub struct ConnectionPool {
original_server_parameters: Arc<RwLock<ServerParameters>>,
/// Pool configuration.
pub settings: PoolSettings,
pub settings: Arc<PoolSettings>,
/// If not validated, we need to double check the pool is available before allowing a client
/// to use it.
@@ -223,6 +277,9 @@ pub struct ConnectionPool {
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
/// Cache
pub prepared_statement_cache: Option<PreparedStatementCacheType>,
}
impl ConnectionPool {
@@ -241,20 +298,17 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username);
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
if let Some(pool) = old_pool_ref {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
None => (),
}
info!(
@@ -379,16 +433,23 @@ impl ConnectionPool {
},
pool_config.cleanup_server_connections,
pool_config.log_client_parameter_status_changes,
pool_config.prepared_statements_cache_size,
);
let connect_timeout = match pool_config.connect_timeout {
let connect_timeout = match user.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
None => match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
},
};
let idle_timeout = match pool_config.idle_timeout {
let idle_timeout = match user.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
None => match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
},
};
let server_lifetime = match user.server_lifetime {
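The nested `match` expressions above implement a three-level fallback: the user setting wins, then the pool setting, then the general default. A compact sketch of the same precedence, assuming plain `u64` millisecond values and a hypothetical helper name `resolve_timeout`:

```rust
/// Picks the most specific timeout that is set: user, then pool, then general.
/// Hypothetical helper; the diff spells this out with nested `match` blocks.
fn resolve_timeout(user: Option<u64>, pool: Option<u64>, general: u64) -> u64 {
    user.or(pool).unwrap_or(general)
}
```

`Option::or` keeps the first `Some`, so a user-level override takes effect even when the pool also sets a value.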
@@ -399,7 +460,7 @@ impl ConnectionPool {
},
};
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
@@ -448,13 +509,13 @@ impl ConnectionPool {
}
let pool = ConnectionPool {
databases: shards,
addresses,
databases: Arc::new(shards),
addresses: Arc::new(addresses),
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
settings: Arc::new(PoolSettings {
pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode,
None => pool_config.pool_mode,
@@ -489,7 +550,7 @@ impl ConnectionPool {
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
default_shard: pool_config.default_shard.clone(),
default_shard: pool_config.default_shard,
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
@@ -497,17 +558,23 @@ impl ConnectionPool {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
},
}),
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
prepared_statement_cache: match pool_config.prepared_statements_cache_size {
0 => None,
_ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
pool_config.prepared_statements_cache_size,
)))),
},
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config {
let mut validate_pool = pool.clone();
let validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
@@ -528,7 +595,7 @@ impl ConnectionPool {
/// when they connect.
/// This also warms up the pool for clients that connect when
/// the pooler starts up.
pub async fn validate(&mut self) -> Result<(), Error> {
pub async fn validate(&self) -> Result<(), Error> {
let mut futures = Vec::new();
let validated = Arc::clone(&self.validated);
@@ -678,7 +745,7 @@ impl ConnectionPool {
let mut force_healthcheck = false;
if self.is_banned(address) {
if self.try_unban(&address).await {
if self.try_unban(address).await {
force_healthcheck = true;
} else {
debug!("Address {:?} is banned", address);
@@ -806,8 +873,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
return false;
self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
false
}
/// Ban an address (i.e. replica). It no longer will serve
@@ -931,10 +998,10 @@ impl ConnectionPool {
let guard = self.banlist.read();
for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
bans.push((address.clone(), (reason.clone(), *timestamp)));
}
}
return bans;
bans
}
/// Get the address from the host url
@@ -992,7 +1059,7 @@ impl ConnectionPool {
}
let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy);
return busy;
busy
}
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
@@ -1001,6 +1068,29 @@ impl ConnectionPool {
Some(shard) => shard < self.shards(),
}
}
/// Register a parse statement to the pool's cache and return the rewritten parse
///
/// Do not pass an anonymous parse statement to this function
pub fn register_parse_to_cache(&self, hash: u64, parse: &Parse) -> Option<Arc<Parse>> {
// We should only be calling this function if the cache is enabled
match self.prepared_statement_cache {
Some(ref prepared_statement_cache) => {
let mut cache = prepared_statement_cache.lock();
Some(cache.get_or_insert(parse, hash))
}
None => None,
}
}
/// Promote a prepared statement hash in the LRU
pub fn promote_prepared_statement_hash(&self, hash: &u64) {
// We should only be calling this function if the cache is enabled
if let Some(ref prepared_statement_cache) = self.prepared_statement_cache {
let mut cache = prepared_statement_cache.lock();
cache.promote(hash);
}
}
}
/// Wrapper for the bb8 connection pool.
@@ -1028,9 +1118,13 @@ pub struct ServerPool {
/// Log client parameter status changes
log_client_parameter_status_changes: bool,
/// Prepared statement cache size
prepared_statement_cache_size: usize,
}
impl ServerPool {
#[allow(clippy::too_many_arguments)]
pub fn new(
address: Address,
user: User,
@@ -1040,16 +1134,18 @@ impl ServerPool {
plugins: Option<Plugins>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> ServerPool {
ServerPool {
address,
user: user.clone(),
user,
database: database.to_string(),
client_server_map,
auth_hash,
plugins,
cleanup_connections,
log_client_parameter_status_changes,
prepared_statement_cache_size,
}
}
}
@@ -1080,6 +1176,7 @@ impl ManageConnection for ServerPool {
self.auth_hash.clone(),
self.cleanup_connections,
self.log_client_parameter_status_changes,
self.prepared_statement_cache_size,
)
.await
{


@@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut};
use log::{debug, error};
use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update};
use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement,
TableFactor, TableWithJoins, Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
@@ -91,7 +91,7 @@ impl QueryRouter {
/// One-time initialization of regexes
/// that parse our custom SQL protocol.
pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -128,11 +128,11 @@ impl QueryRouter {
}
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
pub fn update_pool_settings(&mut self, pool_settings: &PoolSettings) {
self.pool_settings = pool_settings.clone();
}
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
pub fn pool_settings(&self) -> &PoolSettings {
&self.pool_settings
}
@@ -148,7 +148,7 @@ impl QueryRouter {
// Check for any sharding regex matches in any queries
if comment_shard_routing_enabled {
match code as char {
match code {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings
@@ -344,16 +344,13 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
match self.pool_settings.query_parser_max_length {
Some(max_length) => {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
if let Some(max_length) = self.pool_settings.query_parser_max_length {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
None => (),
};
let query = match code {
@@ -403,6 +400,9 @@ impl QueryRouter {
return Err(Error::QueryRouterParserError("empty query".into()));
}
let mut visited_write_statement = false;
let mut prev_inferred_shard = None;
for q in ast {
match q {
// All transactions go to the primary, probably a write.
@@ -420,29 +420,38 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
// This is basically building a database now :)
match self.infer_shard(query) {
Some(shard) => {
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
}
None => (),
};
let inferred_shard = self.infer_shard(query);
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
// If we already visited a write statement, we should be going to the primary.
if !visited_write_statement {
self.active_role = match self.primary_reads_enabled() {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
}
}
}
// Likely a write
_ => {
match &self.pool_settings.automatic_sharding_key {
Some(_) => {
// TODO: similar to the above, if we have multiple queries in the
// same message, we can either split them and execute them individually
// or discard shard selection. If they point to the same shard though,
// we can let them through as-is.
let inferred_shard = self.infer_shard_on_write(q)?;
self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?;
}
None => (),
};
visited_write_statement = true;
self.active_role = Some(Role::Primary);
break;
}
};
}
@@ -450,6 +459,188 @@ impl QueryRouter {
Ok(())
}
fn handle_inferred_shard(
&mut self,
inferred_shard: Option<usize>,
prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> {
if let Some(shard) = inferred_shard {
if let Some(prev_shard) = *prev_inferred_shard {
if prev_shard != shard {
debug!("Found more than one shard in the query, not supported yet");
return Err(Error::QueryRouterParserError(
"multiple shards in query".into(),
));
}
}
*prev_inferred_shard = Some(shard);
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
};
Ok(())
}
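The consistency rule in `handle_inferred_shard` can be shown on its own. In this sketch the function is free-standing, returns the shard to activate instead of mutating a router, and uses `String` for the error; the name `check_inferred_shard` and those simplifications are assumptions for illustration.

```rust
/// Accepts an inferred shard only if it agrees with every shard seen so far
/// in the same message. Returns the shard to activate, if any.
/// Hypothetical free-standing version of `handle_inferred_shard`.
fn check_inferred_shard(
    inferred: Option<usize>,
    prev: &mut Option<usize>,
) -> Result<Option<usize>, String> {
    if let Some(shard) = inferred {
        if let Some(prev_shard) = *prev {
            if prev_shard != shard {
                // Statements in one message point at different shards.
                return Err("multiple shards in query".to_string());
            }
        }
        *prev = Some(shard);
        return Ok(Some(shard));
    }
    // No shard could be inferred; leave the previous choice untouched.
    Ok(None)
}
```

A statement with no inferable shard does not reset the previous choice, so a multi-statement message is rejected only when two statements disagree.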
fn infer_shard_on_write(&mut self, q: &Statement) -> Result<Option<usize>, Error> {
let mut exprs = Vec::new();
// Collect all table names from the query.
let mut table_names = Vec::new();
match q {
Insert {
or,
into: _,
table_name,
columns,
overwrite: _,
source,
partitioned,
after_columns,
table: _,
on: _,
returning: _,
} => {
// Not supported in postgres.
assert!(or.is_none());
assert!(partitioned.is_none());
assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names);
Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
}
Delete {
tables,
from,
using,
selection,
returning: _,
} => {
if let Some(expr) = selection {
exprs.push(expr.clone());
}
// Multi-table deletes are not supported in Postgres.
assert!(tables.is_empty());
Self::process_tables_with_join(from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join(
using_tbl_with_join,
&mut exprs,
&mut table_names,
);
}
Self::process_selection(selection, &mut exprs);
}
Update {
table,
assignments,
from,
selection,
returning: _,
} => {
Self::process_table_with_join(table, &mut exprs, &mut table_names);
if let Some(from_tbl) = from {
Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names);
}
Self::process_selection(selection, &mut exprs);
self.assignment_parser(assignments)?;
}
_ => {
return Ok(None);
}
};
Ok(self.infer_shard_from_exprs(exprs, table_names))
}
fn process_query(
query: &sqlparser::ast::Query,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
columns: &Option<&Vec<Ident>>,
) {
match &*query.body {
SetExpr::Query(query) => {
Self::process_query(query, exprs, table_names, columns);
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
Self::process_tables_with_join(&select.from, exprs, table_names);
// Parse the "WHERE ..." clause (the selection).
Self::process_selection(&select.selection, exprs);
}
SetExpr::Values(values) => {
if let Some(cols) = columns {
for row in values.rows.iter() {
for (i, expr) in row.iter().enumerate() {
if cols.len() > i {
exprs.push(Expr::BinaryOp {
left: Box::new(Expr::Identifier(cols[i].clone())),
op: BinaryOperator::Eq,
right: Box::new(expr.clone()),
});
}
}
}
}
}
_ => (),
};
}
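The `SetExpr::Values` arm turns each `INSERT ... VALUES` row into `column = value` comparisons so the same WHERE-clause parser can find the sharding key. A simplified sketch of that pairing, assuming plain string column names and integer values instead of `sqlparser` AST nodes (`values_to_equalities` is a hypothetical helper):

```rust
/// Pair each value in every VALUES row with the column at the same position.
/// Values beyond the declared column list are skipped, as in the diff above.
fn values_to_equalities<'a>(columns: &[&'a str], rows: &[Vec<i64>]) -> Vec<(&'a str, i64)> {
    let mut pairs = Vec::new();
    for row in rows {
        for (i, value) in row.iter().enumerate() {
            if i < columns.len() {
                // Equivalent to pushing the expression `columns[i] = value`.
                pairs.push((columns[i], *value));
            }
        }
    }
    pairs
}
```

For `INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)` this yields three pairs, one of which is `("W_ID", 5)`, the comparison the shard inference needs.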
fn process_selection(selection: &Option<Expr>, exprs: &mut Vec<Expr>) {
match selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
}
fn process_tables_with_join(
tables: &[TableWithJoins],
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
for table in tables.iter() {
Self::process_table_with_join(table, exprs, table_names);
}
}
fn process_table_with_join(
table: &TableWithJoins,
exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>,
) {
if let TableFactor::Table { name, .. } = &table.relation {
Self::process_table(name, table_names);
};
// Get table names from all the joins.
for join in table.joins.iter() {
if let TableFactor::Table { name, .. } = &join.relation {
Self::process_table(name, table_names);
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
// Parse the selection criteria later.
exprs.push(expr.clone());
};
}
}
fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec<Vec<Ident>>) {
table_names.push(name.0.clone())
}
/// Parse the shard number from the Bind message
/// which contains the arguments for a prepared statement.
///
@@ -592,6 +783,33 @@ impl QueryRouter {
}
}
/// `assignments` are the `SET` clauses of an `UPDATE` statement. This parses them and makes
/// sure that the sharding key is not being updated, which is not supported yet.
fn assignment_parser(&self, assignments: &Vec<Assignment>) -> Result<(), Error> {
let sharding_key = self
.pool_settings
.automatic_sharding_key
.as_ref()
.unwrap()
.split('.')
.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must always be fully qualified
assert_eq!(sharding_key.len(), 2);
for a in assignments {
if sharding_key[0].value == "*"
&& sharding_key[1].value == a.id.last().unwrap().value.to_lowercase()
{
return Err(Error::QueryRouterParserError(
"Sharding key cannot be updated.".into(),
));
}
}
Ok(())
}
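A standalone sketch of the check performed by `assignment_parser`, assuming the `SET` targets arrive as plain strings such as `o.W_ID`. The helper names are hypothetical, and unlike the original (which asserts), a malformed key is reported as an error here. The comparison mirrors the hunk above: it only fires when the table part of the key is the `*` wildcard.

```rust
/// Split a fully qualified `table.column` sharding key into lowercase parts.
fn split_sharding_key(key: &str) -> Option<(String, String)> {
    let parts: Vec<String> = key.split('.').map(|p| p.to_lowercase()).collect();
    match parts.as_slice() {
        [table, column] => Some((table.clone(), column.clone())),
        _ => None,
    }
}

/// Reject an UPDATE whose SET targets include the sharding key column.
fn check_assignments(key: &str, targets: &[&str]) -> Result<(), String> {
    let (table, column) = split_sharding_key(key)
        .ok_or_else(|| "sharding key must be fully qualified".to_string())?;
    for &target in targets {
        // Compare only the last identifier, case-insensitively.
        let last = target.rsplit('.').next().unwrap_or(target).to_lowercase();
        if table == "*" && last == column {
            return Err("Sharding key cannot be updated.".to_string());
        }
    }
    Ok(())
}
```

With the key `*.w_id`, a target of `o.W_ID` is rejected and `o.O_CARRIER_ID` is accepted, which matches the expectations in `test_automatic_sharding_insert_update_delete`.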
/// A `selection` is the `WHERE` clause. This parses
/// the clause and extracts the sharding key, if present.
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<ShardingKey> {
@@ -603,8 +821,8 @@ impl QueryRouter {
.automatic_sharding_key
.as_ref()
.unwrap()
.split(".")
.map(|ident| Ident::new(ident))
.split('.')
.map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>();
// Sharding key must be always fully qualified
@@ -620,7 +838,7 @@ impl QueryRouter {
Expr::Identifier(ident) => {
// Only if we're dealing with only one table
// and there is no ambiguity
if &ident.value == &sharding_key[1].value {
if ident.value.to_lowercase() == sharding_key[1].value {
// Sharding key is unique enough, don't worry about
// table names.
if &sharding_key[0].value == "*" {
@@ -633,13 +851,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches
// the table name from the query.
found = &sharding_key[0].value == &table[0].value;
found = sharding_key[0].value == table[0].value.to_lowercase();
} else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only.
found = &sharding_key[0].value == &table[1].value;
found = sharding_key[0].value == table[1].value.to_lowercase();
} else {
debug!("Got table name with more than two idents, which is not possible");
}
@@ -651,8 +869,9 @@ impl QueryRouter {
// The key is fully qualified in the query,
// it will exist or Postgres will throw an error.
if idents.len() == 2 {
found = &sharding_key[0].value == &idents[0].value
&& &sharding_key[1].value == &idents[1].value;
found = (&sharding_key[0].value == "*"
|| sharding_key[0].value == idents[0].value.to_lowercase())
&& sharding_key[1].value == idents[1].value.to_lowercase();
}
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
}
@@ -684,7 +903,7 @@ impl QueryRouter {
}
Expr::Value(Value::Placeholder(placeholder)) => {
match placeholder.replace("$", "").parse::<i16>() {
match placeholder.replace('$', "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => {
debug!(
@@ -705,100 +924,48 @@ impl QueryRouter {
/// Try to figure out which shard the query should go to.
fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option<usize> {
let mut shards = BTreeSet::new();
let mut exprs = Vec::new();
match &*query.body {
SetExpr::Query(query) => {
match self.infer_shard(&*query) {
Some(shard) => {
// Collect all table names from the query.
let mut table_names = Vec::new();
Self::process_query(query, &mut exprs, &mut table_names, &None);
self.infer_shard_from_exprs(exprs, table_names)
}
fn infer_shard_from_exprs(
&mut self,
exprs: Vec<Expr>,
table_names: Vec<Vec<Ident>>,
) -> Option<usize> {
let mut shards = BTreeSet::new();
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
shards.insert(shard);
}
None => (),
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
};
}
// SELECT * FROM ...
// We understand that pretty well.
SetExpr::Select(select) => {
// Collect all table names from the query.
let mut table_names = Vec::new();
for table in select.from.iter() {
match &table.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// Get table names from all the joins.
for join in table.joins.iter() {
match &join.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone());
}
_ => (),
};
// We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
match &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join {
JoinConstraint::On(expr) => {
// Parse the selection criteria later.
exprs.push(expr.clone());
}
_ => (),
},
_ => (),
};
}
}
// Parse the actual "FROM ..."
match &select.selection {
Some(selection) => {
exprs.push(selection.clone());
}
None => (),
};
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
// Look for sharding keys in either the join condition
// or the selection.
for expr in exprs.iter() {
let sharding_keys = self.selection_parser(expr, &table_names);
// TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message.
for value in sharding_keys {
match value {
ShardingKey::Value(value) => {
let shard = sharder.shard(value);
shards.insert(shard);
}
ShardingKey::Placeholder(position) => {
self.placeholders.push(position);
}
};
}
}
}
_ => (),
};
}
match shards.len() {
// Didn't find a sharding key, you're on your own.
0 => {
@@ -830,16 +997,16 @@ impl QueryRouter {
db: &self.pool_settings.db,
};
let _ = query_logger.run(&self, ast).await;
let _ = query_logger.run(self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: &intercept,
config: intercept,
};
let result = intercept.run(&self, ast).await;
let result = intercept.run(self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
@@ -852,7 +1019,7 @@ impl QueryRouter {
tables: &table_access.tables,
};
let result = table_access.run(&self, ast).await;
let result = table_access.run(self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
@@ -888,7 +1055,7 @@ impl QueryRouter {
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled {
match self.query_parser_enabled {
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
@@ -904,9 +1071,7 @@ impl QueryRouter {
);
value
}
};
enabled
}
}
pub fn primary_reads_enabled(&self) -> bool {
@@ -917,6 +1082,12 @@ impl QueryRouter {
}
}
impl Default for QueryRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test {
use super::*;
@@ -938,10 +1109,14 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
.is_some());
assert!(qr.query_parser_enabled());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
@@ -983,7 +1158,9 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
.is_some());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
@@ -996,7 +1173,9 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1166,9 +1345,11 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&query) != None);
assert!(qr.try_execute_command(&query).is_some());
assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None);
@@ -1182,7 +1363,7 @@ mod test {
assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&query) != None);
assert!(qr.try_execute_command(&query).is_some());
assert!(!qr.query_parser_enabled());
}
@@ -1222,7 +1403,7 @@ mod test {
assert_eq!(qr.primary_reads_enabled, None);
// Internal state must not be changed due to this, only defaults
qr.update_pool_settings(pool_settings.clone());
qr.update_pool_settings(&pool_settings);
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
@@ -1230,11 +1411,11 @@ mod test {
assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(&q1) != None);
assert!(qr.try_execute_command(&q1).is_some());
assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&q2) != None);
assert!(qr.try_execute_command(&q2).is_some());
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
}
@@ -1295,29 +1476,29 @@ mod test {
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());
qr.update_pool_settings(&pool_settings);
// Shard should start out unset
assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0) == None);
assert!(qr.try_execute_command(&q0).is_none());
assert_eq!(qr.active_shard, None);
// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None);
assert!(qr.try_execute_command(&q1).is_none());
assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None);
assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None);
assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(2));
}
@@ -1414,6 +1595,221 @@ mod test {
assert_eq!(qr.shard().unwrap(), 0);
}
fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option<usize> {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("*.w_id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert_eq!(qr.shard(), None);
let infer_res = qr.infer(&qr.parse(&simple_query(qry)).unwrap());
assert_eq!(infer_res.is_ok(), should_succeed);
qr.shard()
}
fn auto_shard(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, true)
}
fn auto_shard_fails(qry: &str) -> Option<usize> {
auto_shard_wrapper(qry, false)
}
#[test]
fn test_automatic_sharding_insert_update_delete() {
QueryRouter::setup();
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
None
);
assert_eq!(
auto_shard_fails(
"UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
None
);
assert_eq!(
auto_shard(
"UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5"
),
Some(2)
);
}
#[test]
fn test_automatic_sharding_key_tpcc() {
QueryRouter::setup();
assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2));
assert_eq!(
auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"),
None
);
assert_eq!(auto_shard("COMMIT"), None);
assert_eq!(auto_shard("ROLLBACK"), None);
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2));
assert_eq!(
auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard(
"UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"
),
Some(2)
);
assert_eq!(
auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"),
None
);
assert_eq!(
auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"),
Some(2)
);
assert_eq!(
auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"),
Some(2)
);
assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2));
assert_eq!(
auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"),
Some(2)
);
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 5
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
Some(2)
);
// This is a distributed query and contains two shards
assert_eq!(
auto_shard(
"SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK
WHERE ORDER_LINE.W_ID = 5
AND OL_D_ID = 3
AND OL_O_ID < 3
AND OL_O_ID >= 3
AND STOCK.W_ID = 7
AND S_I_ID = OL_I_ID
AND S_QUANTITY < 3"
),
None
);
}
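The last two assertions show the rule the router applies once all sharding key values are collected: one distinct shard is usable, several are not. A minimal sketch of that decision, assuming a plain modulo as the sharding function (`shard_for` and `pick_shard` are hypothetical; the real mapping depends on the pool's configured sharding function):

```rust
use std::collections::BTreeSet;

/// Hypothetical stand-in for the pool's sharding function.
/// Assumes `shards` is greater than zero.
fn shard_for(key: i64, shards: usize) -> usize {
    key.rem_euclid(shards as i64) as usize
}

/// Map every sharding key value to a shard and keep the distinct results.
/// Exactly one shard is routable; zero or several shards yield `None`.
fn pick_shard(keys: &[i64], shards: usize) -> Option<usize> {
    let found: BTreeSet<usize> = keys.iter().map(|k| shard_for(*k, shards)).collect();
    match found.len() {
        1 => found.into_iter().next(),
        _ => None,
    }
}
```

Two predicates on the same key value collapse into a single set entry, so a join that filters both tables on `W_ID = 5` still routes to one shard, while `W_ID = 5` combined with `W_ID = 7` does not.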
#[test]
fn test_prepared_statements() {
let stmt = "SELECT * FROM data WHERE id = $1";
@@ -1458,12 +1854,13 @@ mod test {
};
QueryRouter::setup();
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let pool_settings = PoolSettings {
query_parser_enabled: true,
plugins: Some(plugins),
..Default::default()
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings);
qr.update_pool_settings(&pool_settings);
let query = simple_query("SELECT * FROM pg_database");
let ast = qr.parse(&query).unwrap();


@@ -79,12 +79,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError(format!("SCRAM")));
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
}
let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
};
let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?;
let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
let verifier = match general_purpose::STANDARD.decode(final_message.value) {
Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
};
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
.collect::<Vec<String>>();
if parts.len() != 3 {
return Err(Error::ProtocolSyncError(format!("SCRAM")));
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
}
let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
};
Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
/// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError(format!("SCRAM")));
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
}
Ok(FinalMessage {


@@ -3,12 +3,14 @@
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use lru::LruCache;
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::collections::{HashMap, HashSet, VecDeque};
use std::mem;
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
@@ -16,7 +18,7 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
use crate::config::{get_config, Address, User};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
@@ -197,12 +199,8 @@ impl ServerParameters {
key = "DateStyle".to_string();
};
if TRACKED_PARAMETERS.contains(&key) {
if TRACKED_PARAMETERS.contains(&key) || startup {
self.parameters.insert(key, value);
} else {
if startup {
self.parameters.insert(key, value);
}
}
}
@@ -326,12 +324,16 @@ pub struct Server {
log_client_parameter_status_changes: bool,
/// Prepared statements
prepared_statements: BTreeSet<String>,
prepared_statement_cache: Option<LruCache<String, ()>>,
/// Prepared statement being currently registered on the server.
registering_prepared_statement: VecDeque<String>,
}
impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready for query state.
#[allow(clippy::too_many_arguments)]
pub async fn startup(
address: &Address,
user: &User,
@@ -341,6 +343,7 @@ impl Server {
auth_hash: Arc<RwLock<Option<String>>>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
prepared_statement_cache_size: usize,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
@@ -440,10 +443,7 @@ impl Server {
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
return Err(Error::SocketError(format!("Unknown message: {}", { m })));
}
}
} else {
@@ -461,26 +461,20 @@ impl Server {
None => &user.username,
};
let password = match user.server_password {
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
let password = match user.server_password.as_ref() {
Some(server_password) => Some(server_password),
None => user.password.as_ref(),
};
startup(&mut stream, username, database).await?;
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database);
let server_identifier = ServerIdentifier::new(username, database);
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = match password {
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
let mut server_parameters = ServerParameters::new();
@@ -725,7 +719,7 @@ impl Server {
}
};
let fields = match PgErrorMsg::parse(error) {
let fields = match PgErrorMsg::parse(&error) {
Ok(f) => f,
Err(err) => {
return Err(err);
@@ -830,7 +824,13 @@ impl Server {
},
cleanup_connections,
log_client_parameter_status_changes,
prepared_statements: BTreeSet::new(),
prepared_statement_cache: match prepared_statement_cache_size {
0 => None,
_ => Some(LruCache::new(
NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
)),
},
registering_prepared_statement: VecDeque::new(),
};
return Ok(server);
@@ -882,7 +882,7 @@ impl Server {
self.mirror_send(messages);
self.stats().data_sent(messages.len());
match write_all_flush(&mut self.stream, &messages).await {
match write_all_flush(&mut self.stream, messages).await {
Ok(_) => {
// Successfully sent to server
self.last_activity = SystemTime::now();
@@ -960,7 +960,6 @@ impl Server {
// There is no more data available from the server.
self.data_available = false;
break;
}
@@ -969,6 +968,37 @@ impl Server {
if self.in_copy_mode {
self.in_copy_mode = false;
}
// Remove the prepared statement from the cache, it has a syntax error or something else bad happened.
if let Some(prepared_stmt_name) =
self.registering_prepared_statement.pop_front()
{
if let Some(ref mut cache) = self.prepared_statement_cache {
if let Some(_removed) = cache.pop(&prepared_stmt_name) {
debug!(
"Removed {} from prepared statement cache",
prepared_stmt_name
);
} else {
// Shouldn't happen.
debug!("Prepared statement {} was not cached", prepared_stmt_name);
}
}
}
if self.prepared_statement_cache.is_some() {
let error_message = PgErrorMsg::parse(&message)?;
if error_message.message == "cached plan must not change result type" {
warn!("Server {:?} changed schema, dropping connection to clean up prepared statements", self.address);
// This will still result in an error to the client, but this server connection will drop all cached prepared statements
// so that any new queries will be re-prepared
// TODO: Other ideas to solve errors when there are DDL changes after a statement has been prepared
// - Recreate entire connection pool to force recreation of all server connections
// - Clear the ConnectionPool's statement cache so that new statement names are generated
// - Implement a retry (re-prepare) so the client doesn't see an error
self.cleanup_state.needs_cleanup_prepare = true;
}
}
}
// CommandComplete
@@ -1058,6 +1088,11 @@ impl Server {
// Buffer until ReadyForQuery shows up, so don't exit the loop yet.
'c' => (),
// ParseComplete: the statement was prepared successfully
'1' => {
self.registering_prepared_statement.pop_front();
}
// Anything else, e.g. errors, notices, etc.
// Keep buffering until ReadyForQuery shows up.
_ => (),
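The two hunks above pair each pending `Parse` with the server's reply: `ParseComplete` (`'1'`) confirms the oldest pending statement, and an error response drops it from the cache. A simplified sketch of that bookkeeping, assuming names are cached when the `Parse` is queued and that `'E'` is the error message code; `Registrations` and its methods are hypothetical, and a `HashSet` stands in for the LRU cache:

```rust
use std::collections::{HashSet, VecDeque};

/// Hypothetical model of the server-side prepared statement bookkeeping.
#[derive(Default)]
struct Registrations {
    /// Statements sent to the server but not yet acknowledged, oldest first.
    in_flight: VecDeque<String>,
    /// Names assumed to exist on the server.
    cached: HashSet<String>,
}

impl Registrations {
    /// Called when a Parse for `name` is queued for the server.
    fn parse_sent(&mut self, name: &str) {
        self.cached.insert(name.to_string());
        self.in_flight.push_back(name.to_string());
    }

    /// Handle one backend message code.
    fn on_message(&mut self, code: char) {
        match code {
            // ParseComplete: the oldest pending statement is now prepared.
            '1' => {
                self.in_flight.pop_front();
            }
            // Error: the oldest pending statement failed, so forget it.
            'E' => {
                if let Some(name) = self.in_flight.pop_front() {
                    self.cached.remove(&name);
                }
            }
            _ => (),
        }
    }
}
```

The queue matters because replies arrive in the order the statements were sent, so the front entry is always the one a reply refers to.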
@@ -1079,117 +1114,103 @@ impl Server {
Ok(bytes)
}
/// Add the prepared statement to being tracked by this server.
/// The client is processing data that will create a prepared statement on this server.
pub fn will_prepare(&mut self, name: &str) {
debug!("Will prepare `{}`", name);
/// Determines whether the server already has a prepared statement with the given name,
/// and records a prepared statement cache hit or miss in the stats.
pub fn has_prepared_statement(&mut self, name: &str) -> bool {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return false,
};
self.prepared_statements.insert(name.to_string());
self.stats.prepared_cache_add();
}
/// Check if we should prepare a statement on the server.
pub fn should_prepare(&self, name: &str) -> bool {
let should_prepare = !self.prepared_statements.contains(name);
debug!("Should prepare `{}`: {}", name, should_prepare);
if should_prepare {
self.stats.prepared_cache_miss();
} else {
let has_it = cache.get(name).is_some();
if has_it {
self.stats.prepared_cache_hit();
} else {
self.stats.prepared_cache_miss();
}
should_prepare
has_it
}
/// Create a prepared statement on the server.
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
debug!("Preparing `{}`", parse.name);
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return None,
};
let bytes: BytesMut = parse.try_into()?;
self.send(&bytes).await?;
self.send(&flush()).await?;
// Read and discard ParseComplete (B)
match read_message(&mut self.stream).await {
Ok(_) => (),
Err(err) => {
self.bad = true;
return Err(err);
}
}
self.prepared_statements.insert(parse.name.to_string());
self.stats.prepared_cache_add();
debug!("Prepared `{}`", parse.name);
Ok(())
}
/// Maintain adequate cache size on the server.
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
debug!("Cache maintenance run");
let max_cache_size = get_prepared_statements_cache_size();
let mut names = Vec::new();
while self.prepared_statements.len() >= max_cache_size {
// The prepared statements are alphanumerically sorted by the BTree.
// FIFO.
if let Some(name) = self.prepared_statements.pop_last() {
names.push(name);
// If we evict something, we need to close it on the server
if let Some((evicted_name, _)) = cache.push(name.to_string(), ()) {
if evicted_name != name {
debug!(
"Evicted prepared statement {} from cache, replaced with {}",
evicted_name, name
);
return Some(evicted_name);
}
}
};
if !names.is_empty() {
self.deallocate(names).await?;
}
Ok(())
None
}
/// Remove the prepared statement from being tracked by this server.
/// The client is processing data that will cause the server to close the prepared statement.
pub fn will_close(&mut self, name: &str) {
debug!("Will close `{}`", name);
fn remove_prepared_statement_from_cache(&mut self, name: &str) {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,
None => return,
};
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
cache.pop(name);
}
/// Close a prepared statement on the server.
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
for name in &names {
debug!("Deallocating prepared statement `{}`", name);
pub async fn register_prepared_statement(
&mut self,
parse: &Parse,
should_send_parse_to_server: bool,
) -> Result<(), Error> {
if !self.has_prepared_statement(&parse.name) {
self.registering_prepared_statement
.push_back(parse.name.clone());
let close = Close::new(name);
let bytes: BytesMut = close.try_into()?;
let mut bytes = BytesMut::new();
self.send(&bytes).await?;
}
if should_send_parse_to_server {
let parse_bytes: BytesMut = parse.try_into()?;
bytes.extend_from_slice(&parse_bytes);
}
if !names.is_empty() {
self.send(&flush()).await?;
}
// Read and discard CloseComplete (3)
for name in &names {
match read_message(&mut self.stream).await {
Ok(_) => {
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
debug!("Closed `{}`", name);
}
Err(err) => {
self.bad = true;
return Err(err);
}
// If we evict something, we need to close it on the server
// We do this by adding it to the messages we're sending to the server before the sync
if let Some(evicted_name) = self.add_prepared_statement_to_cache(&parse.name) {
self.remove_prepared_statement_from_cache(&evicted_name);
let close_bytes: BytesMut = Close::new(&evicted_name).try_into()?;
bytes.extend_from_slice(&close_bytes);
};
}
Ok(())
// If we have a parse or close we need to send to the server, send them and sync
if !bytes.is_empty() {
bytes.extend_from_slice(&sync());
self.send(&bytes).await?;
loop {
self.recv(None).await?;
if !self.is_data_available() {
break;
}
}
}
};
// If it's not there, something went bad, I'm guessing bad syntax or permissions error
// on the server.
if !self.has_prepared_statement(&parse.name) {
Err(Error::PreparedStatementError)
} else {
Ok(())
}
}
/// If the server is still inside a transaction.
@@ -1199,6 +1220,7 @@ impl Server {
self.in_transaction
}
/// Currently copying data from client to server or vice-versa.
pub fn in_copy_mode(&self) -> bool {
self.in_copy_mode
}
@@ -1324,6 +1346,10 @@ impl Server {
if self.cleanup_state.needs_cleanup_prepare {
reset_string.push_str("DEALLOCATE ALL;");
// Since we deallocated all prepared statements, we need to clear the cache
if let Some(cache) = &mut self.prepared_statement_cache {
cache.clear();
}
};
self.query(&reset_string).await?;
@@ -1359,16 +1385,14 @@ impl Server {
}
pub fn mirror_send(&mut self, bytes: &BytesMut) {
match self.mirror_manager.as_mut() {
Some(manager) => manager.send(bytes),
None => (),
if let Some(manager) = self.mirror_manager.as_mut() {
manager.send(bytes)
}
}
pub fn mirror_disconnect(&mut self) {
match self.mirror_manager.as_mut() {
Some(manager) => manager.disconnect(),
None => (),
if let Some(manager) = self.mirror_manager.as_mut() {
manager.disconnect()
}
}
@@ -1391,13 +1415,14 @@ impl Server {
Arc::new(RwLock::new(None)),
true,
false,
0,
)
.await?;
debug!("Connected!, sending query.");
server.send(&simple_query(query)).await?;
let mut message = server.recv(None).await?;
Ok(parse_query_message(&mut message).await?)
parse_query_message(&mut message).await
}
}
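Review note: the new `add_prepared_statement_to_cache` / `register_prepared_statement` pair above relies on the cache handing back the entry it evicted, so that a `Close` for that name can be queued to the server. A minimal, std-only sketch of that contract follows. `StatementCache`, `touch`, and `push` are hypothetical names for illustration, not PgCat's types, and the sketch assumes a capacity of at least 1.

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the per-server statement cache (not PgCat's type).
struct StatementCache {
    capacity: usize,
    // Front = least recently used, back = most recently used.
    names: VecDeque<String>,
}

impl StatementCache {
    fn new(capacity: usize) -> Self {
        Self { capacity, names: VecDeque::new() }
    }

    /// Returns true on a cache hit and marks the name as most recently used.
    fn touch(&mut self, name: &str) -> bool {
        if let Some(pos) = self.names.iter().position(|n| n == name) {
            let hit = self.names.remove(pos).unwrap();
            self.names.push_back(hit);
            true
        } else {
            false
        }
    }

    /// Inserts a name. Returns the evicted name, if any, so the caller
    /// can close that statement on the server.
    fn push(&mut self, name: &str) -> Option<String> {
        if self.touch(name) {
            return None; // already cached, nothing evicted
        }
        let evicted = if self.names.len() >= self.capacity {
            self.names.pop_front()
        } else {
            None
        };
        self.names.push_back(name.to_string());
        evicted
    }
}
```

The linear scan is only for brevity; a real LRU would pair a hash map with a linked list to keep lookups constant time.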


@@ -64,7 +64,7 @@ impl Sharder {
fn sha1(&self, key: i64) -> usize {
let mut hasher = Sha1::new();
hasher.update(&key.to_string().as_bytes());
hasher.update(key.to_string().as_bytes());
let result = hasher.finalize();
@@ -202,10 +202,10 @@ mod test {
#[test]
fn test_sha1_hash() {
let sharder = Sharder::new(12, ShardingFunction::Sha1);
let ids = vec![
let ids = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
];
let shards = vec![
let shards = [
4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3,
];


@@ -38,8 +38,10 @@ pub struct ClientStats {
/// Total time spent waiting for a connection from pool, measured in microseconds
pub total_wait_time: Arc<AtomicU64>,
/// Maximum time spent waiting for a connection from pool, measured in microseconds
pub max_wait_time: Arc<AtomicU64>,
/// When this client started waiting.
/// Stored as microseconds since connect_time so it can fit in an AtomicU64 instead
/// of us using an "AtomicInstant"
pub wait_start: Arc<AtomicU64>,
/// Current state of the client
pub state: Arc<AtomicClientState>,
@@ -63,7 +65,7 @@ impl Default for ClientStats {
username: String::new(),
pool_name: String::new(),
total_wait_time: Arc::new(AtomicU64::new(0)),
max_wait_time: Arc::new(AtomicU64::new(0)),
wait_start: Arc::new(AtomicU64::new(0)),
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
@@ -111,6 +113,11 @@ impl ClientStats {
/// Reports a client is waiting for a connection
pub fn waiting(&self) {
// safe to truncate, we only lose info if duration is greater than ~585,000 years
self.wait_start.store(
Instant::now().duration_since(self.connect_time).as_micros() as u64,
Ordering::Relaxed,
);
self.state.store(ClientState::Waiting, Ordering::Relaxed);
}
@@ -134,8 +141,6 @@ impl ClientStats {
pub fn checkout_time(&self, microseconds: u64) {
self.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.max_wait_time
.fetch_max(microseconds, Ordering::Relaxed);
}
/// Report a query executed by a client against a server
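Review note: the `wait_start` change stores an offset from `connect_time` rather than an absolute instant, and the pool stats hunk later subtracts it from "now minus connect_time". A rough sketch of that arithmetic, using `std::time::Instant` instead of the `tokio::time::Instant` seen in the diff; `ClientWait`, `waiting_at`, and `current_wait_micros` are hypothetical names, and the `saturating_sub` is an assumption of this sketch (the diff uses plain subtraction).

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

/// Hypothetical, trimmed-down client stats; field names mirror the diff.
struct ClientWait {
    connect_time: Instant,
    // Microseconds since `connect_time` at which the client began waiting.
    wait_start: AtomicU64,
}

impl ClientWait {
    fn new(connect_time: Instant) -> Self {
        Self { connect_time, wait_start: AtomicU64::new(0) }
    }

    /// Record that the client started waiting at `now`.
    fn waiting_at(&self, now: Instant) {
        let micros = now.duration_since(self.connect_time).as_micros() as u64;
        self.wait_start.store(micros, Ordering::Relaxed);
    }

    /// Microseconds the client has been waiting as of `now`:
    /// (now - connect_time) - (wait_start - connect_time).
    fn current_wait_micros(&self, now: Instant) -> u64 {
        let since_connect = now.duration_since(self.connect_time).as_micros() as u64;
        // Assumption: saturate instead of underflowing if the clock reading is stale.
        since_connect.saturating_sub(self.wait_start.load(Ordering::Relaxed))
    }
}
```

Taking `now` as a parameter keeps the arithmetic testable without sleeping.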


@@ -4,6 +4,7 @@ use super::{ClientState, ServerState};
use crate::{config::PoolMode, messages::DataType, pool::PoolIdentifier};
use std::collections::HashMap;
use std::sync::atomic::*;
use tokio::time::Instant;
use crate::pool::get_all_pools;
@@ -53,6 +54,7 @@ impl PoolStats {
);
}
let now = Instant::now();
for client in client_map.values() {
match map.get_mut(&PoolIdentifier {
db: client.pool_name(),
@@ -62,10 +64,16 @@ impl PoolStats {
match client.state.load(Ordering::Relaxed) {
ClientState::Active => pool_stats.cl_active += 1,
ClientState::Idle => pool_stats.cl_idle += 1,
ClientState::Waiting => pool_stats.cl_waiting += 1,
ClientState::Waiting => {
pool_stats.cl_waiting += 1;
// wait_start is measured as microseconds since connect_time
// so compute wait_time as (now() - connect_time) - (wait_start - connect_time)
let duration_since_connect = now.duration_since(client.connect_time());
let wait_time = (duration_since_connect.as_micros() as u64)
- client.wait_start.load(Ordering::Relaxed);
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, wait_time);
}
}
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, max_wait);
}
None => debug!("Client from an obsolete pool"),
}
@@ -86,11 +94,11 @@ impl PoolStats {
}
}
return map;
map
}
pub fn generate_header() -> Vec<(&'static str, DataType)> {
return vec![
vec![
("database", DataType::Text),
("user", DataType::Text),
("pool_mode", DataType::Text),
@@ -105,11 +113,11 @@ impl PoolStats {
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
];
]
}
pub fn generate_row(&self) -> Vec<String> {
return vec![
vec![
self.identifier.db.clone(),
self.identifier.user.clone(),
self.mode.to_string(),
@@ -124,7 +132,7 @@ impl PoolStats {
self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(),
];
]
}
}
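Review note: `generate_row` reports one microsecond counter as two columns, `maxwait` (whole seconds) and `maxwait_us` (the remainder). A small sketch of that split; `split_maxwait` is a hypothetical helper, not a function in the diff.

```rust
/// Splits a wait time in microseconds into (whole seconds, remaining microseconds),
/// matching how the `maxwait` and `maxwait_us` columns are derived in the diff.
fn split_maxwait(maxwait_micros: u64) -> (u64, u64) {
    (maxwait_micros / 1_000_000, maxwait_micros % 1_000_000)
}
```

For a client that has waited 1.5 seconds, this yields `maxwait = 1` and `maxwait_us = 500000`, which is the shape the Ruby stats spec asserts on.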


@@ -49,6 +49,7 @@ pub struct ServerStats {
pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_eviction_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
}
@@ -68,6 +69,7 @@ impl Default for ServerStats {
reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_eviction_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
}
}
@@ -221,6 +223,7 @@ impl ServerStats {
}
pub fn prepared_cache_remove(&self) {
self.prepared_eviction_count.fetch_add(1, Ordering::Relaxed);
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
}
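Review note: with this hunk, `prepared_cache_remove` bumps an eviction counter and shrinks the tracked cache size in the same call. A minimal sketch of how the two counters move together; `PreparedCacheStats` is a hypothetical struct, and the body of `cache_add` is an assumption, since `prepared_cache_add` is not shown in this diff.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hypothetical subset of the server stats touched by this hunk.
#[derive(Default)]
struct PreparedCacheStats {
    eviction_count: AtomicU64,
    cache_size: AtomicU64,
}

impl PreparedCacheStats {
    /// Assumed counterpart of `prepared_cache_add`: one more cached statement.
    fn cache_add(&self) {
        self.cache_size.fetch_add(1, Ordering::Relaxed);
    }

    /// Mirrors `prepared_cache_remove` from the diff: count the eviction
    /// and shrink the tracked size.
    fn cache_remove(&self) {
        self.eviction_count.fetch_add(1, Ordering::Relaxed);
        // Note: fetch_sub wraps around on underflow, so removes must be paired with adds.
        self.cache_size.fetch_sub(1, Ordering::Relaxed);
    }
}
```

Because the subtraction wraps, calling remove for a name that was never added would leave the size gauge near `u64::MAX`; callers need to keep adds and removes paired.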


@@ -8,3 +8,6 @@ RUN rustup component add llvm-tools-preview
RUN sudo gem install bundler
RUN wget -O toxiproxy-2.4.0.deb https://github.com/Shopify/toxiproxy/releases/download/v2.4.0/toxiproxy_2.4.0_linux_$(dpkg --print-architecture).deb && \
sudo dpkg -i toxiproxy-2.4.0.deb
RUN wget -O go1.21.3.linux-$(dpkg --print-architecture).tar.gz https://go.dev/dl/go1.21.3.linux-$(dpkg --print-architecture).tar.gz && \
sudo tar -C /usr/local -xzf go1.21.3.linux-$(dpkg --print-architecture).tar.gz && \
rm go1.21.3.linux-$(dpkg --print-architecture).tar.gz

5
tests/go/go.mod Normal file

@@ -0,0 +1,5 @@
module pgcat
go 1.21
require github.com/lib/pq v1.10.9

2
tests/go/go.sum Normal file

@@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=

162
tests/go/pgcat.toml Normal file

@@ -0,0 +1,162 @@
#
# PgCat config example.
#
#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example.
port = "${PORT}"
# Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true
# Port at which prometheus exporter listens on.
prometheus_exporter_port = 9930
# How long to wait before aborting a server connection (ms).
connect_timeout = 1000
# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000
# How long to keep connection available for immediate re-use, without running a healthcheck query on it
healthcheck_delay = 30000
# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000
# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds
# If we should log client connections
log_client_connections = false
# If we should log client disconnections
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = 15000
server_round_robin = false
# TLS
tls_certificate = "../../.circleci/server.cert"
tls_private_key = "../../.circleci/server.key"
# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"
# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# If the client doesn't specify, route traffic to
# this role by default.
#
# any: round-robin between primary and replicas,
# replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"
# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true
# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
#
# Current options:
#
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"
# Prepared statements cache size.
prepared_statements_cache_size = 500
# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connections from a single PgCat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 5
statement_timeout = 0
[pools.sharded_db.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 30000
# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"
[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"
[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
statement_timeout = 30000
[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"

52
tests/go/prepared_test.go Normal file

@@ -0,0 +1,52 @@
package pgcat
import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"testing"
)
func Test(t *testing.T) {
t.Cleanup(setup(t))
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
}
func namedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}
stmt, err := db.Prepare("SELECT $1")
if err != nil {
t.Fatalf("could not prepare: %+v", err)
}
for i := 0; i < 100; i++ {
rows, err := stmt.Query(1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}
func unnamedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}
for i := 0; i < 100; i++ {
// Under the hood QueryContext generates an unnamed parameterized prepared statement
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}

81
tests/go/setup.go Normal file

@@ -0,0 +1,81 @@
package pgcat
import (
"context"
"database/sql"
_ "embed"
"fmt"
"math/rand"
"os"
"os/exec"
"strings"
"testing"
"time"
)
//go:embed pgcat.toml
var pgcatCfg string
var port = rand.Intn(32760-20000) + 20000
func setup(t *testing.T) func() {
cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml")
if err != nil {
t.Fatalf("could not create temp file: %+v", err)
}
pgcatCfg = strings.Replace(pgcatCfg, "\"${PORT}\"", fmt.Sprintf("%d", port), 1)
_, err = cfg.Write([]byte(pgcatCfg))
if err != nil {
t.Fatalf("could not write temp file: %+v", err)
}
commandPath := "../../target/debug/pgcat"
if os.Getenv("CARGO_TARGET_DIR") != "" {
commandPath = os.Getenv("CARGO_TARGET_DIR") + "/debug/pgcat"
}
cmd := exec.Command(commandPath, cfg.Name())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
go func() {
err = cmd.Run()
if err != nil {
t.Errorf("could not run pgcat: %+v", err)
}
}()
deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cancelFunc()
for {
select {
case <-deadline.Done():
break
case <-time.After(50 * time.Millisecond):
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=disable", port))
if err != nil {
continue
}
rows, err := db.QueryContext(deadline, "SHOW STATS")
if err != nil {
continue
}
_ = rows.Close()
_ = db.Close()
break
}
break
}
return func() {
err := cmd.Process.Signal(os.Interrupt)
if err != nil {
t.Fatalf("could not interrupt pgcat: %+v", err)
}
err = os.Remove(cfg.Name())
if err != nil {
t.Fatalf("could not remove temp file: %+v", err)
}
}
}


@@ -36,4 +36,4 @@ SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;


@@ -1,29 +1,214 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
let(:pool_size) { 5 }
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", pool_size) }
let(:prepared_statements_cache_size) { 100 }
let(:server_round_robin) { false }
context 'enabled' do
it 'will work over the same connection' do
before do
new_configs = processes.pgcat.current_config
new_configs["general"]["server_round_robin"] = server_round_robin
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = pool_size
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
end
context 'when trying prepared statements' do
it 'it allows unparameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT 1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1')
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2')
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3')
end
ensure
conn1.close if conn1
conn2.close if conn2
end
it 'it allows parameterized statements to succeed' do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
conn2 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
prepared_query = "SELECT $1"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
conn1.exec_prepared('statement1', [1])
conn2.transaction do
# Claim server 1 with client 2
conn2.exec("SELECT 2")
# Client 1 now runs the prepared query, and it's automatically
# prepared on server 2
conn1.prepare('statement2', prepared_query)
conn1.exec_prepared('statement2', [1])
# Client 2 now prepares the same query that was already
# prepared on server 1. And PgBouncer reuses that already
# prepared query for this different client.
conn2.prepare('statement3', prepared_query)
conn2.exec_prepared('statement3', [1])
end
ensure
conn1.close if conn1
conn2.close if conn2
end
end
context 'when trying large packets' do
it "works with large parse" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT '#{long_string}'"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1')
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
it "works with large bind" do
conn1 = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
long_string = "1" * 4096 * 10
prepared_query = "SELECT $1::text"
# prepare query on server 1 and client 1
conn1.prepare('statement1', prepared_query)
result = conn1.exec_prepared('statement1', [long_string])
# assert result matches long_string
expect(result.getvalue(0, 0)).to eq(long_string)
ensure
conn1.close if conn1
end
end
context 'when statement cache is smaller than set of unique statements' do
let(:prepared_statements_cache_size) { 1 }
let(:pool_size) { 1 }
it "evicts all but 1 statement from the server cache" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when statement cache is larger than set of unique statements' do
let(:pool_size) { 1 }
it "does not evict any of the statements from the cache" do
# cache size is 100 (the default set at the top of this spec), larger than the 5 statements prepared below
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
5.times do |i|
prepared_query = "SELECT '#{i}'"
conn.prepare("statement#{i}", prepared_query)
result = conn.exec_prepared("statement#{i}")
expect(result.getvalue(0, 0)).to eq(i.to_s)
end
# Check number of prepared statements (expected: 5)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(5)
end
end
context 'when preparing the same query' do
let(:prepared_statements_cache_size) { 5 }
let(:pool_size) { 5 }
it "reuses statement cache when there are different statement names on the same connection" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
conn.describe_prepared(statement_name)
end
# Check number of prepared statements (expected: 1)
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
it 'will work with new connections' do
10.times do
it "reuses statement cache when there are different statement names on different connections" do
10.times do |i|
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
statement_name = 'statement1'
conn.prepare('statement1', 'SELECT $1::int')
conn.exec_prepared('statement1', [1])
conn.describe_prepared('statement1')
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
end
# Check number of prepared statements (expected: 1)
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(1)
end
end
context 'when reloading config' do
let(:pool_size) { 1 }
it "test_reload_config" do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
# prepare query
conn.prepare('statement1', 'SELECT 1')
conn.exec_prepared('statement1')
# Reload config which triggers pool recreation
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = prepared_statements_cache_size + 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
# check that we're starting with no prepared statements on the server
conn_check = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
n_statements = conn_check.exec("SELECT count(*) FROM pg_prepared_statements").getvalue(0, 0).to_i
expect(n_statements).to eq(0)
# still able to run prepared query
conn.exec_prepared('statement1')
end
end
end


@@ -233,7 +233,7 @@ describe "Stats" do
sleep(1.1) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
@@ -260,12 +260,20 @@ describe "Stats" do
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
# two connections waiting => they report wait time
sleep(1.1) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(100_000)
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
# no connections waiting => no reported wait time
expect(results["maxwait"]).to eq("0")
expect(results["maxwait_us"]).to eq("0")
connections.map(&:close)
sleep(4.5) # Allow time for stats to update
@@ -329,6 +337,40 @@ describe "Stats" do
admin_conn.close
connections.map(&:close)
end
context "when client has waited for a server" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it "shows correct maxwait" do
threads = []
connections = Array.new(3) { |i| PG::connect("#{pgcat_conn_str}?application_name=app#{i}") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW CLIENTS")
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
non_waiting_clients = normal_client_results.select { |c| c["maxwait"] == "0" }
waiting_clients = normal_client_results.select { |c| c["maxwait"].to_i > 0 }
expect(non_waiting_clients.count).to eq(2)
non_waiting_clients.each do |client|
expect(client["maxwait_us"].to_i).to be_between(0, 50_000)
end
expect(waiting_clients.count).to eq(1)
waiting_clients.each do |client|
expect(client["maxwait_us"].to_i).to be_within(200_000).of(500_000)
end
admin_conn.close
connections.map(&:close)
end
end
end

4
tests/rust/Cargo.lock generated

@@ -1206,9 +1206,9 @@ dependencies = [
[[package]]
name = "webpki"
version = "0.22.0"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
checksum = "07ecc0cd7cac091bf682ec5efa18b1cff79d617b84181f38b3951dbe135f607f"
dependencies = [
"ring",
"untrusted",


@@ -16,7 +16,14 @@ async fn test_prepared_statements() {
let pool = pool.clone();
let handle = tokio::task::spawn(async move {
for _ in 0..1000 {
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
match sqlx::query("SELECT one").fetch_all(&pool).await {
Ok(_) => (),
Err(err) => {
if err.to_string().contains("prepared statement") {
panic!("prepared statement error: {}", err);
}
}
}
}
});


@@ -22,7 +22,7 @@ mkdir -p "$deb_dir/etc/systemd/system"
cp target/release/pgcat "$deb_dir/usr/bin/pgcat"
chmod +x "$deb_dir/usr/bin/pgcat"
cp pgcat.toml "$deb_dir/etc/pgcat.toml"
cp pgcat.toml "$deb_dir/etc/pgcat.example.toml"
cp pgcat.service "$deb_dir/etc/systemd/system/pgcat.service"
(cat control | envsubst) > "$deb_dir/DEBIAN/control"