mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-24 17:56:29 +00:00
Compare commits
32 Commits
kczimm-mea
...
levkk-tls-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fa17bb5cc6 | ||
|
|
04e9814770 | ||
|
|
037d232fcd | ||
|
|
b2933762e7 | ||
|
|
df8aa888f9 | ||
|
|
7f5639c94a | ||
|
|
c0112f6f12 | ||
|
|
b7ceee2ddf | ||
|
|
0b01d70b55 | ||
|
|
33db0dffa8 | ||
|
|
7994a661d9 | ||
|
|
9937193332 | ||
|
|
baa00ff546 | ||
|
|
ffe820497f | ||
|
|
be549f3faa | ||
|
|
4301ab0606 | ||
|
|
5143500c9a | ||
|
|
3255323bff | ||
|
|
bb27586758 | ||
|
|
4f0f45b576 | ||
|
|
f94ce97ebc | ||
|
|
9ab128579d | ||
|
|
1cde74f05e | ||
|
|
a4de6c1eb6 | ||
|
|
e14b283f0c | ||
|
|
7c3c90c38e | ||
|
|
2ca21b2bec | ||
|
|
3986eaa4b2 | ||
|
|
1f2c6507f7 | ||
|
|
aefcf4281c | ||
|
|
9d1c46a3e9 | ||
|
|
328108aeb5 |
@@ -9,7 +9,7 @@ jobs:
|
||||
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
|
||||
docker:
|
||||
- image: ghcr.io/levkk/pgcat-ci:1.67
|
||||
- image: ghcr.io/postgresml/pgcat-ci:latest
|
||||
environment:
|
||||
RUST_LOG: info
|
||||
LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
|
||||
|
||||
@@ -74,6 +74,10 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitely selected with our custom protocol.
|
||||
@@ -134,6 +138,7 @@ database = "shard2"
|
||||
pool_mode = "session"
|
||||
default_role = "primary"
|
||||
query_parser_enabled = true
|
||||
query_parser_read_write_splitting = true
|
||||
primary_reads_enabled = true
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
|
||||
1
.github/workflows/build-and-push.yaml
vendored
1
.github/workflows/build-and-push.yaml
vendored
@@ -34,6 +34,7 @@ jobs:
|
||||
tags: |
|
||||
type=sha,prefix=,format=long
|
||||
type=schedule
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=raw,value=latest,enable={{ is_default_branch }}
|
||||
|
||||
48
.github/workflows/publish-deb-package.yml
vendored
Normal file
48
.github/workflows/publish-deb-package.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
name: pgcat package (deb)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
packageVersion:
|
||||
default: "1.1.2-dev"
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
max-parallel: 1
|
||||
fail-fast: false # Let the other job finish, or they can lock each other out
|
||||
matrix:
|
||||
os: ["buildjet-4vcpu-ubuntu-2204", "buildjet-4vcpu-ubuntu-2204-arm"]
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
- name: Install dependencies
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
TZ: Etc/UTC
|
||||
run: |
|
||||
curl -sLO https://github.com/deb-s3/deb-s3/releases/download/0.11.4/deb-s3-0.11.4.gem
|
||||
sudo gem install deb-s3-0.11.4.gem
|
||||
dpkg-deb --version
|
||||
- name: Build and release package
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ vars.AWS_DEFAULT_REGION }}
|
||||
run: |
|
||||
if [[ $(arch) == "x86_64" ]]; then
|
||||
export ARCH=amd64
|
||||
else
|
||||
export ARCH=arm64
|
||||
fi
|
||||
|
||||
bash utilities/deb.sh ${{ inputs.packageVersion }}
|
||||
|
||||
deb-s3 upload \
|
||||
--lock \
|
||||
--bucket apt.postgresml.org \
|
||||
pgcat-${{ inputs.packageVersion }}-ubuntu22.04-${ARCH}.deb \
|
||||
--codename $(lsb_release -cs)
|
||||
83
CONFIG.md
83
CONFIG.md
@@ -57,6 +57,38 @@ default: 86400000 # 24 hours
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### server_round_robin
|
||||
```
|
||||
path: general.server_round_robin
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use round robin for server selection or not.
|
||||
|
||||
### server_tls
|
||||
```
|
||||
path: general.server_tls
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use TLS for server connections or not.
|
||||
|
||||
### verify_server_certificate
|
||||
```
|
||||
path: general.verify_server_certificate
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to verify server certificate or not.
|
||||
|
||||
### verify_config
|
||||
```
|
||||
path: general.verify_config
|
||||
default: true
|
||||
```
|
||||
|
||||
Whether to verify config or not.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
@@ -194,6 +226,55 @@ default: "admin_pass"
|
||||
|
||||
Password to access the virtual administrative database
|
||||
|
||||
### auth_query
|
||||
```
|
||||
path: general.auth_query
|
||||
default: <UNSET>
|
||||
example: "SELECT $1"
|
||||
```
|
||||
|
||||
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
established using the database configured in the pool. This parameter is inherited by every pool
|
||||
and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_user
|
||||
```
|
||||
path: general.auth_query_user
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_password
|
||||
```
|
||||
path: general.auth_query_password
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### prepared_statements
|
||||
```
|
||||
path: general.prepared_statements
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use prepared statements or not.
|
||||
|
||||
### prepared_statements_cache_size
|
||||
```
|
||||
path: general.prepared_statements_cache_size
|
||||
default: 500
|
||||
```
|
||||
|
||||
Size of the prepared statements cache.
|
||||
|
||||
### dns_cache_enabled
|
||||
```
|
||||
path: general.dns_cache_enabled
|
||||
@@ -230,7 +311,7 @@ default: "random"
|
||||
|
||||
Load balancing mode
|
||||
`random` selects the server at random
|
||||
`loc` selects the server with the least outstanding busy conncetions
|
||||
`loc` selects the server with the least outstanding busy connections
|
||||
|
||||
### default_role
|
||||
```
|
||||
|
||||
80
Cargo.lock
generated
80
Cargo.lock
generated
@@ -353,19 +353,6 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0"
|
||||
dependencies = [
|
||||
"humantime",
|
||||
"is-terminal",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@@ -633,12 +620,6 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.27"
|
||||
@@ -855,6 +836,15 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
|
||||
dependencies = [
|
||||
"regex-automata 0.1.10",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matches"
|
||||
version = "0.1.10"
|
||||
@@ -1000,7 +990,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "1.1.0"
|
||||
version = "1.1.2-dev"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
@@ -1010,7 +1000,6 @@ dependencies = [
|
||||
"bytes",
|
||||
"chrono",
|
||||
"clap",
|
||||
"env_logger",
|
||||
"exitcode",
|
||||
"fallible-iterator",
|
||||
"futures",
|
||||
@@ -1218,8 +1207,17 @@ checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
"regex-automata 0.3.3",
|
||||
"regex-syntax 0.7.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
|
||||
dependencies = [
|
||||
"regex-syntax 0.6.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1230,9 +1228,15 @@ checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
"regex-syntax 0.7.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.4"
|
||||
@@ -1306,9 +1310,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.1"
|
||||
version = "0.100.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
|
||||
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
@@ -1544,15 +1548,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.43"
|
||||
@@ -1788,12 +1783,16 @@ version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77"
|
||||
dependencies = [
|
||||
"matchers",
|
||||
"nu-ansi-term",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sharded-slab",
|
||||
"smallvec",
|
||||
"thread_local",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tracing-log",
|
||||
"tracing-serde",
|
||||
@@ -2003,7 +2002,7 @@ version = "0.23.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
|
||||
dependencies = [
|
||||
"rustls-webpki 0.100.1",
|
||||
"rustls-webpki 0.100.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2028,15 +2027,6 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.1.0"
|
||||
version = "1.1.2-dev"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -22,7 +22,6 @@ once_cell = "1"
|
||||
sqlparser = {version = "0.34", features = ["visitor"] }
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
parking_lot = "0.12.1"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
@@ -48,7 +47,7 @@ serde_json = "1"
|
||||
itertools = "0.10"
|
||||
clap = { version = "4.3.1", features = ["derive", "env"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json"]}
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
FROM rust:1 AS builder
|
||||
FROM rust:1-slim-bookworm AS builder
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
RUN cargo build --release
|
||||
|
||||
FROM debian:bullseye-slim
|
||||
FROM debian:bookworm-slim
|
||||
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
|
||||
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
|
||||
WORKDIR /etc/pgcat
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
FROM cimg/rust:1.67.1
|
||||
COPY --from=sclevine/yj /bin/yj /bin/yj
|
||||
RUN /bin/yj -h
|
||||
RUN sudo apt-get update && \
|
||||
sudo apt-get install -y \
|
||||
psmisc postgresql-contrib-14 postgresql-client-14 libpq-dev \
|
||||
|
||||
25
Dockerfile.dev
Normal file
25
Dockerfile.dev
Normal file
@@ -0,0 +1,25 @@
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
COPY --from=planner /app/recipe.json recipe.json
|
||||
# Build dependencies - this is the caching Docker layer!
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
# Build application
|
||||
COPY . .
|
||||
RUN cargo build
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
|
||||
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
|
||||
WORKDIR /etc/pgcat
|
||||
ENV RUST_LOG=info
|
||||
CMD ["pgcat"]
|
||||
9
control
Normal file
9
control
Normal file
@@ -0,0 +1,9 @@
|
||||
Package: pgcat
|
||||
Version: ${PACKAGE_VERSION}
|
||||
Section: database
|
||||
Priority: optional
|
||||
Architecture: ${ARCH}
|
||||
Maintainer: PostgresML <team@postgresml.org>
|
||||
Homepage: https://postgresml.org
|
||||
Description: PgCat - NextGen PostgreSQL Pooler
|
||||
PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load balancing, failover and mirroring.
|
||||
@@ -71,6 +71,10 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
|
||||
16
pgcat.service
Normal file
16
pgcat.service
Normal file
@@ -0,0 +1,16 @@
|
||||
[Unit]
|
||||
Description=PgCat pooler
|
||||
After=network.target
|
||||
StartLimitIntervalSec=0
|
||||
|
||||
[Service]
|
||||
User=pgcat
|
||||
Type=simple
|
||||
Restart=always
|
||||
RestartSec=1
|
||||
Environment=RUST_LOG=info
|
||||
LimitNOFILE=65536
|
||||
ExecStart=/usr/bin/pgcat /etc/pgcat.toml
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
14
pgcat.toml
14
pgcat.toml
@@ -162,6 +162,10 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
@@ -173,6 +177,12 @@ primary_reads_enabled = true
|
||||
# shard_id_regex = '/\* shard_id: (\d+) \*/'
|
||||
# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
|
||||
|
||||
# Defines the behavior when no shard is selected in a sharded system.
|
||||
# `random`: picks a shard at random
|
||||
# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
|
||||
# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
|
||||
# no_shard_specified_behavior = "shard_0"
|
||||
|
||||
# So what if you wanted to implement a different hashing function,
|
||||
# or you've already built one and you want this pooler to use it?
|
||||
# Current options:
|
||||
@@ -183,7 +193,7 @@ sharding_function = "pg_bigint_hash"
|
||||
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
# established using the database configured in the pool. This parameter is inherited by every pool
|
||||
# and can be redefined in pool configuration.
|
||||
# auth_query = "SELECT $1"
|
||||
# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
|
||||
|
||||
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
@@ -270,7 +280,7 @@ username = "sharding_user"
|
||||
# if `server_password` is not set.
|
||||
password = "sharding_user"
|
||||
|
||||
pool_mode = "session"
|
||||
pool_mode = "transaction"
|
||||
|
||||
# PostgreSQL username used to connect to the server.
|
||||
# server_username = "another_user"
|
||||
|
||||
9
postinst
Normal file
9
postinst
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
systemctl daemon-reload
|
||||
systemctl enable pgcat
|
||||
|
||||
if ! id pgcat 2> /dev/null; then
|
||||
useradd -s /usr/bin/false pgcat
|
||||
fi
|
||||
5
prerm
Normal file
5
prerm
Normal file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
systemctl stop pgcat
|
||||
systemctl disable pgcat
|
||||
189
src/admin.rs
189
src/admin.rs
@@ -1,4 +1,5 @@
|
||||
use crate::pool::BanReason;
|
||||
use crate::server::ServerParameters;
|
||||
use crate::stats::pool::PoolStats;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{error, info, trace};
|
||||
@@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
|
||||
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
pub fn generate_server_parameters_for_admin() -> ServerParameters {
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
server_info.put(server_parameter_message("application_name", ""));
|
||||
server_info.put(server_parameter_message("client_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_version", VERSION));
|
||||
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
|
||||
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
|
||||
|
||||
server_info
|
||||
server_parameters
|
||||
}
|
||||
|
||||
/// Handle admin client.
|
||||
@@ -73,11 +74,11 @@ where
|
||||
}
|
||||
"PAUSE" => {
|
||||
trace!("PAUSE");
|
||||
pause(stream, query_parts[1]).await
|
||||
pause(stream, query_parts).await
|
||||
}
|
||||
"RESUME" => {
|
||||
trace!("RESUME");
|
||||
resume(stream, query_parts[1]).await
|
||||
resume(stream, query_parts).await
|
||||
}
|
||||
"SHUTDOWN" => {
|
||||
trace!("SHUTDOWN");
|
||||
@@ -796,96 +797,128 @@ where
|
||||
}
|
||||
|
||||
/// Pause a pool. It won't pass any more queries to the backends.
|
||||
async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
async fn pause<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
stream,
|
||||
"PAUSE requires a database and a user, e.g. PAUSE my_db,my_user",
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
match parts.len() {
|
||||
0 => {
|
||||
for (_, pool) in get_all_pools() {
|
||||
pool.pause();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete("PAUSE"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
2 => {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.pause();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => error_response(stream, "usage: PAUSE [db, user]").await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resume a pool. Queries are allowed again.
|
||||
async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
async fn resume<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
stream,
|
||||
"RESUME requires a database and a user, e.g. RESUME my_db,my_user",
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
match parts.len() {
|
||||
0 => {
|
||||
for (_, pool) in get_all_pools() {
|
||||
pool.resume();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("RESUME {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete("RESUME"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
2 => {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.resume();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("RESUME {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => error_response(stream, "usage: RESUME [db, user]").await,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ pub struct AuthPassthrough {
|
||||
|
||||
impl AuthPassthrough {
|
||||
/// Initializes an AuthPassthrough.
|
||||
pub fn new<S: ToString>(query: S, user: S, password: S) -> Self {
|
||||
pub fn new(query: &str, user: &str, password: &str) -> Self {
|
||||
AuthPassthrough {
|
||||
password: password.to_string(),
|
||||
query: query.to_string(),
|
||||
|
||||
357
src/client.rs
357
src/client.rs
@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
|
||||
@@ -22,7 +22,7 @@ use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::stats::{ClientStats, ServerStats};
|
||||
use crate::tls::Tls;
|
||||
|
||||
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
|
||||
/// Postgres user for this client (This comes from the user in the connection string)
|
||||
username: String,
|
||||
|
||||
/// Application name for this client (defaults to pgcat)
|
||||
application_name: String,
|
||||
/// Server startup and session parameters that we're going to track
|
||||
server_parameters: ServerParameters,
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown: Receiver<()>,
|
||||
@@ -117,13 +117,21 @@ pub async fn client_entrypoint(
|
||||
log_client_connections: bool,
|
||||
) -> Result<(), Error> {
|
||||
// Figure out if the client wants TLS or not.
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Failed to get peer address: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
match get_startup::<TcpStream>(&mut stream).await {
|
||||
// Client requested a TLS connection.
|
||||
Ok((ClientConnectionType::Tls, _)) => {
|
||||
// TLS settings are configured, will setup TLS now.
|
||||
if tls_certificate.is_some() {
|
||||
if tls_certificate != None {
|
||||
debug!("Accepting TLS request");
|
||||
|
||||
let mut yes = BytesMut::new();
|
||||
@@ -147,10 +155,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
|
||||
result
|
||||
@@ -199,10 +207,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
|
||||
result
|
||||
@@ -253,10 +261,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
|
||||
result
|
||||
@@ -282,11 +290,12 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
@@ -348,7 +357,15 @@ pub async fn startup_tls(
|
||||
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
|
||||
// Negotiate TLS.
|
||||
let tls = Tls::new()?;
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Failed to get peer address: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream = match tls.acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -431,7 +448,7 @@ where
|
||||
None => "pgcat",
|
||||
};
|
||||
|
||||
let client_identifier = ClientIdentifier::new(application_name, username, pool_name);
|
||||
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
@@ -502,7 +519,7 @@ where
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let (transaction_mode, mut server_parameters) = if admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
@@ -521,7 +538,7 @@ where
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
(false, generate_server_parameters_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
else {
|
||||
@@ -654,13 +671,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
(transaction_mode, pool.server_info())
|
||||
(transaction_mode, pool.server_parameters())
|
||||
};
|
||||
|
||||
// Update the parameters to merge what the application sent and what's originally on the server
|
||||
server_parameters.set_from_hashmap(¶meters, false);
|
||||
|
||||
debug!("Password authentication successful");
|
||||
|
||||
auth_ok(&mut write).await?;
|
||||
write_all(&mut write, server_info).await?;
|
||||
write_all(&mut write, (&server_parameters).into()).await?;
|
||||
backend_key_data(&mut write, process_id, secret_key).await?;
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
@@ -690,7 +710,7 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
application_name: application_name.to_string(),
|
||||
server_parameters,
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
@@ -725,7 +745,7 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: String::from("undefined"),
|
||||
username: String::from("undefined"),
|
||||
application_name: String::from("undefined"),
|
||||
server_parameters: ServerParameters::new(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
@@ -774,6 +794,12 @@ where
|
||||
let mut prepared_statement = None;
|
||||
let mut will_prepare = false;
|
||||
|
||||
let client_identifier = ClientIdentifier::new(
|
||||
&self.server_parameters.get_application_name(),
|
||||
&self.username,
|
||||
&self.pool_name,
|
||||
);
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
// or issue commands for our sharding and server selection protocol.
|
||||
@@ -812,6 +838,29 @@ where
|
||||
message_result = read_message(&mut self.read) => message_result?
|
||||
};
|
||||
|
||||
if message[0] as char == 'X' {
|
||||
debug!("Client disconnecting");
|
||||
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Handle admin database queries.
|
||||
if self.admin {
|
||||
debug!("Handling admin command");
|
||||
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
let mut pool = self.get_pool().await?;
|
||||
query_router.update_pool_settings(pool.settings.clone());
|
||||
|
||||
let mut initial_parsed_ast = None;
|
||||
|
||||
match message[0] as char {
|
||||
// Buffer extended protocol messages even if we do not have
|
||||
// a server connection yet. Hopefully, when we get the S message
|
||||
@@ -841,24 +890,34 @@ where
|
||||
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
match query_router.parse(&message) {
|
||||
Ok(ast) => {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
let _ = query_router.infer(&ast);
|
||||
|
||||
initial_parsed_ast = Some(ast);
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -872,13 +931,21 @@ where
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
match query_router.parse(&message) {
|
||||
Ok(ast) => {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
continue;
|
||||
@@ -898,14 +965,6 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
'X' => {
|
||||
debug!("Client disconnecting");
|
||||
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Close (F)
|
||||
'C' => {
|
||||
if prepared_statements_enabled {
|
||||
@@ -922,25 +981,17 @@ where
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// Handle admin database queries.
|
||||
if self.admin {
|
||||
debug!("Handling admin command");
|
||||
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check on plugin results.
|
||||
if let Some(PluginOutput::Deny(error)) = plugin_output {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
let mut pool = self.get_pool().await?;
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Check if the pool is paused and wait until it's resumed.
|
||||
if pool.wait_paused().await {
|
||||
@@ -959,23 +1010,27 @@ where
|
||||
|
||||
// SET SHARD TO
|
||||
Some((Command::SetShard, _)) => {
|
||||
// Selected shard is not configured.
|
||||
if query_router.shard() >= pool.shards() {
|
||||
// Set the shard back to what it was.
|
||||
query_router.set_shard(current_shard);
|
||||
match query_router.shard() {
|
||||
None => (),
|
||||
Some(selected_shard) => {
|
||||
if selected_shard >= pool.shards() {
|
||||
// Bad shard number, send error message to client.
|
||||
query_router.set_shard(current_shard);
|
||||
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)",
|
||||
query_router.shard(),
|
||||
pool.shards(),
|
||||
current_shard,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)",
|
||||
selected_shard,
|
||||
pool.shards(),
|
||||
current_shard,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -1043,8 +1098,11 @@ where
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
error_response(&mut self.write, "could not get connection from the pool")
|
||||
.await?;
|
||||
error_response(
|
||||
&mut self.write,
|
||||
format!("could not get connection from the pool - {}", err).as_str(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
error!(
|
||||
"Could not get connection from pool: \
|
||||
@@ -1087,10 +1145,7 @@ where
|
||||
server.address()
|
||||
);
|
||||
|
||||
// TODO: investigate other parameters and set them too.
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
server.sync_parameters(&self.server_parameters).await?;
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1161,6 +1216,9 @@ where
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
// This is not an initial message so discard the initial_parsed_ast
|
||||
initial_parsed_ast.take();
|
||||
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
@@ -1184,7 +1242,7 @@ where
|
||||
{{ \
|
||||
pool_name: {}, \
|
||||
username: {}, \
|
||||
shard: {}, \
|
||||
shard: {:?}, \
|
||||
role: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
@@ -1209,7 +1267,7 @@ where
|
||||
|
||||
// Safe to unwrap because we know this message has a certain length and has the code
|
||||
// This reads the first byte without advancing the internal pointer and mutating the bytes
|
||||
let code = *message.first().unwrap() as char;
|
||||
let code = *message.get(0).unwrap() as char;
|
||||
|
||||
trace!("Message: {}", code);
|
||||
|
||||
@@ -1217,7 +1275,22 @@ where
|
||||
// Query
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
// We don't want to parse again if we already parsed it as the initial message
|
||||
let ast = match initial_parsed_ast {
|
||||
Some(_) => Some(initial_parsed_ast.take().unwrap()),
|
||||
None => match query_router.parse(&message) {
|
||||
Ok(ast) => Some(ast),
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(ast) = ast {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
@@ -1233,8 +1306,6 @@ where
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
debug!("Sending query to server");
|
||||
@@ -1252,7 +1323,9 @@ where
|
||||
if !server.in_transaction() {
|
||||
// Report transaction executed statistics.
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1286,7 +1359,7 @@ where
|
||||
}
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(ast) = query_router.parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
@@ -1327,11 +1400,14 @@ where
|
||||
let close: Close = (&message).try_into()?;
|
||||
|
||||
if close.is_prepared_statement() && !close.anonymous() {
|
||||
if let Some(parse) = self.prepared_statements.get(&close.name) {
|
||||
server.will_close(&parse.generated_name);
|
||||
} else {
|
||||
match self.prepared_statements.get(&close.name) {
|
||||
Some(parse) => {
|
||||
server.will_close(&parse.generated_name);
|
||||
}
|
||||
|
||||
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1369,7 +1445,7 @@ where
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char;
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
|
||||
// Almost certainly true
|
||||
if first_message_code == 'P' && !prepared_statements_enabled {
|
||||
@@ -1399,7 +1475,9 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1448,7 +1526,9 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1500,7 +1580,9 @@ where
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
|
||||
self.pool_name, self.username, self.application_name
|
||||
self.pool_name,
|
||||
self.username,
|
||||
self.server_parameters.get_application_name()
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -1657,7 +1739,7 @@ where
|
||||
client_stats.query();
|
||||
server.stats().query(
|
||||
Instant::now().duration_since(query_start).as_millis() as u64,
|
||||
&self.application_name,
|
||||
&self.server_parameters.get_application_name(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -1686,38 +1768,18 @@ where
|
||||
pool: &ConnectionPool,
|
||||
client_stats: &ClientStats,
|
||||
) -> Result<BytesMut, Error> {
|
||||
if pool.settings.user.statement_timeout > 0 {
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
|
||||
server.recv(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
)
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match server.recv().await {
|
||||
let statement_timeout_duration = match pool.settings.user.statement_timeout {
|
||||
0 => tokio::time::Duration::MAX,
|
||||
timeout => tokio::time::Duration::from_millis(timeout),
|
||||
};
|
||||
|
||||
match tokio::time::timeout(
|
||||
statement_timeout_duration,
|
||||
server.recv(Some(&mut self.server_parameters)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
@@ -1728,6 +1790,16 @@ where
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1740,7 +1812,6 @@ impl<S, T> Drop for Client<S, T> {
|
||||
|
||||
// Dirty shutdown
|
||||
// TODO: refactor, this is not the best way to handle state management.
|
||||
|
||||
if self.connected_to_server && self.last_server_stats.is_some() {
|
||||
self.last_server_stats.as_ref().unwrap().idle();
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ pub struct Args {
|
||||
}
|
||||
|
||||
pub fn parse() -> Args {
|
||||
Args::parse()
|
||||
return Args::parse();
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Clone, Debug)]
|
||||
|
||||
231
src/config.rs
231
src/config.rs
@@ -1,13 +1,16 @@
|
||||
/// Parse the configuration file.
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{error, info};
|
||||
use log::{error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use serde::{Deserializer, Serializer};
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
@@ -101,6 +104,9 @@ pub struct Address {
|
||||
|
||||
/// Address stats
|
||||
pub stats: Arc<AddressStats>,
|
||||
|
||||
/// Number of errors encountered since last successful checkout
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for Address {
|
||||
@@ -118,6 +124,7 @@ impl Default for Address {
|
||||
pool_name: String::from("pool_name"),
|
||||
mirrors: Vec::new(),
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,6 +189,18 @@ impl Address {
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error_count(&self) -> u64 {
|
||||
self.error_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn increment_error_count(&self) {
|
||||
self.error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn reset_error_count(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgreSQL user.
|
||||
@@ -217,15 +236,19 @@ impl Default for User {
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
if let Some(min_pool_size) = self.min_pool_size {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -507,6 +530,11 @@ pub struct Pool {
|
||||
#[serde(default)] // False
|
||||
pub query_parser_enabled: bool,
|
||||
|
||||
pub query_parser_max_length: Option<usize>,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub query_parser_read_write_splitting: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub primary_reads_enabled: bool,
|
||||
|
||||
@@ -531,6 +559,9 @@ pub struct Pool {
|
||||
pub shard_id_regex: Option<String>,
|
||||
pub regex_search_limit: Option<usize>,
|
||||
|
||||
#[serde(default = "Pool::default_default_shard")]
|
||||
pub default_shard: DefaultShard,
|
||||
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
@@ -538,6 +569,9 @@ pub struct Pool {
|
||||
#[serde(default = "Pool::default_cleanup_server_connections")]
|
||||
pub cleanup_server_connections: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub log_client_parameter_status_changes: bool,
|
||||
|
||||
pub plugins: Option<Plugins>,
|
||||
pub shards: BTreeMap<String, Shard>,
|
||||
pub users: BTreeMap<String, User>,
|
||||
@@ -563,6 +597,10 @@ impl Pool {
|
||||
PoolMode::Transaction
|
||||
}
|
||||
|
||||
pub fn default_default_shard() -> DefaultShard {
|
||||
DefaultShard::default()
|
||||
}
|
||||
|
||||
pub fn default_load_balancing_mode() -> LoadBalancingMode {
|
||||
LoadBalancingMode::Random
|
||||
}
|
||||
@@ -623,13 +661,25 @@ impl Pool {
|
||||
}
|
||||
}
|
||||
|
||||
if self.query_parser_read_write_splitting && !self.query_parser_enabled {
|
||||
error!(
|
||||
"query_parser_read_write_splitting is only valid when query_parser_enabled is true"
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
if self.plugins.is_some() && !self.query_parser_enabled {
|
||||
error!("plugins are only valid when query_parser_enabled is true");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
self.automatic_sharding_key = match &self.automatic_sharding_key {
|
||||
Some(key) => {
|
||||
// No quotes in the key so we don't have to compare quoted
|
||||
// to unquoted idents.
|
||||
let key = key.replace('\"', "");
|
||||
let key = key.replace("\"", "");
|
||||
|
||||
if key.split('.').count() != 2 {
|
||||
if key.split(".").count() != 2 {
|
||||
error!(
|
||||
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
|
||||
key, key
|
||||
@@ -642,7 +692,17 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for user in self.users.values() {
|
||||
match self.default_shard {
|
||||
DefaultShard::Shard(shard_number) => {
|
||||
if shard_number >= self.shards.len() {
|
||||
error!("Invalid shard {:?}", shard_number);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
@@ -659,6 +719,8 @@ impl Default for Pool {
|
||||
users: BTreeMap::default(),
|
||||
default_role: String::from("any"),
|
||||
query_parser_enabled: false,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: false,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -667,12 +729,14 @@ impl Default for Pool {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: Some(1000),
|
||||
default_shard: Self::default_default_shard(),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
plugins: None,
|
||||
cleanup_server_connections: true,
|
||||
log_client_parameter_status_changes: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -684,6 +748,50 @@ pub struct ServerConfig {
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
// No Shard Specified handling.
|
||||
#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)]
|
||||
pub enum DefaultShard {
|
||||
Shard(usize),
|
||||
Random,
|
||||
RandomHealthy,
|
||||
}
|
||||
impl Default for DefaultShard {
|
||||
fn default() -> Self {
|
||||
DefaultShard::Shard(0)
|
||||
}
|
||||
}
|
||||
impl serde::Serialize for DefaultShard {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
match self {
|
||||
DefaultShard::Shard(shard) => {
|
||||
serializer.serialize_str(&format!("shard_{}", &shard.to_string()))
|
||||
}
|
||||
DefaultShard::Random => serializer.serialize_str("random"),
|
||||
DefaultShard::RandomHealthy => serializer.serialize_str("random_healthy"),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'de> serde::Deserialize<'de> for DefaultShard {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
if s.starts_with("shard_") {
|
||||
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
|
||||
return Ok(DefaultShard::Shard(shard));
|
||||
}
|
||||
|
||||
match s.as_str() {
|
||||
"random" => Ok(DefaultShard::Random),
|
||||
"random_healthy" => Ok(DefaultShard::RandomHealthy),
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"invalid value for no_shard_specified_behavior",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)]
|
||||
pub struct MirrorServerConfig {
|
||||
pub host: String,
|
||||
@@ -814,8 +922,8 @@ pub struct Query {
|
||||
impl Query {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for col in self.result.iter_mut() {
|
||||
for c in col {
|
||||
*c = c.replace("${USER}", user).replace("${DATABASE}", db);
|
||||
for i in 0..col.len() {
|
||||
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -910,6 +1018,17 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
format!("pools.{}.query_parser_enabled", pool_name),
|
||||
pool.query_parser_enabled.to_string(),
|
||||
),
|
||||
(
|
||||
format!("pools.{}.query_parser_max_length", pool_name),
|
||||
match pool.query_parser_max_length {
|
||||
Some(max_length) => max_length.to_string(),
|
||||
None => String::from("unlimited"),
|
||||
},
|
||||
),
|
||||
(
|
||||
format!("pools.{}.query_parser_read_write_splitting", pool_name),
|
||||
pool.query_parser_read_write_splitting.to_string(),
|
||||
),
|
||||
(
|
||||
format!("pools.{}.default_role", pool_name),
|
||||
pool.default_role.clone(),
|
||||
@@ -925,8 +1044,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
(
|
||||
format!("pools.{:?}.users", pool_name),
|
||||
pool.users
|
||||
.values()
|
||||
.map(|user| &user.username)
|
||||
.iter()
|
||||
.map(|(_username, user)| &user.username)
|
||||
.cloned()
|
||||
.collect::<Vec<String>>()
|
||||
.join(", "),
|
||||
@@ -1011,9 +1130,13 @@ impl Config {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
|
||||
if let Some(tls_private_key) = self.general.tls_private_key.clone() {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
}
|
||||
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1048,8 +1171,8 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config
|
||||
.users
|
||||
.values()
|
||||
.map(|user_cfg| user_cfg.pool_size)
|
||||
.iter()
|
||||
.map(|(_, user_cfg)| user_cfg.pool_size)
|
||||
.sum::<u32>()
|
||||
.to_string()
|
||||
);
|
||||
@@ -1088,6 +1211,15 @@ impl Config {
|
||||
"[pool: {}] Query router: {}",
|
||||
pool_name, pool_config.query_parser_enabled
|
||||
);
|
||||
|
||||
info!(
|
||||
"[pool: {}] Query parser max length: {:?}",
|
||||
pool_name, pool_config.query_parser_max_length
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Infer role from query: {}",
|
||||
pool_name, pool_config.query_parser_read_write_splitting
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Number of shards: {}",
|
||||
pool_name,
|
||||
@@ -1110,6 +1242,10 @@ impl Config {
|
||||
"[pool: {}] Cleanup server connections: {}",
|
||||
pool_name, pool_config.cleanup_server_connections
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Log client parameter status changes: {}",
|
||||
pool_name, pool_config.log_client_parameter_status_changes
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Plugins: {}",
|
||||
pool_name,
|
||||
@@ -1206,32 +1342,43 @@ impl Config {
|
||||
}
|
||||
|
||||
// Validate TLS!
|
||||
if let Some(tls_certificate) = self.general.tls_certificate.clone() {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
match self.general.tls_certificate {
|
||||
Some(ref mut tls_certificate) => {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key {
|
||||
Some(ref tls_private_key) => {
|
||||
match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"tls_private_key is incorrectly configured: {:?}",
|
||||
err
|
||||
);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
None => {
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
}
|
||||
None => {
|
||||
warn!("tls_certificate is set, but the tls_private_key is not");
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
error!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
Err(err) => {
|
||||
warn!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
for pool in self.pools.values_mut() {
|
||||
pool.validate()?;
|
||||
|
||||
@@ -12,6 +12,7 @@ pub enum Error {
|
||||
ProtocolSyncError(String),
|
||||
BadQuery(String),
|
||||
ServerError,
|
||||
ServerMessageParserError(String),
|
||||
ServerStartupError(String, ServerIdentifier),
|
||||
ServerAuthError(String, ServerIdentifier),
|
||||
BadConfig,
|
||||
@@ -27,6 +28,7 @@ pub enum Error {
|
||||
UnsupportedStatement,
|
||||
QueryRouterParserError(String),
|
||||
QueryRouterError(String),
|
||||
InvalidShardId(usize),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@@ -37,11 +39,11 @@ pub struct ClientIdentifier {
|
||||
}
|
||||
|
||||
impl ClientIdentifier {
|
||||
pub fn new<S: ToString>(application_name: S, username: S, pool_name: S) -> ClientIdentifier {
|
||||
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
|
||||
ClientIdentifier {
|
||||
application_name: application_name.to_string(),
|
||||
username: username.to_string(),
|
||||
pool_name: pool_name.to_string(),
|
||||
application_name: application_name.into(),
|
||||
username: username.into(),
|
||||
pool_name: pool_name.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,10 +65,10 @@ pub struct ServerIdentifier {
|
||||
}
|
||||
|
||||
impl ServerIdentifier {
|
||||
pub fn new<S: ToString>(username: S, database: S) -> ServerIdentifier {
|
||||
pub fn new(username: &str, database: &str) -> ServerIdentifier {
|
||||
ServerIdentifier {
|
||||
username: username.to_string(),
|
||||
database: database.to_string(),
|
||||
username: username.into(),
|
||||
database: database.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -84,36 +86,41 @@ impl std::fmt::Display for ServerIdentifier {
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match &self {
|
||||
Error::ClientSocketError(error, client_identifier) => {
|
||||
write!(f, "Error reading {error} from client {client_identifier}",)
|
||||
&Error::ClientSocketError(error, client_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} from client {}",
|
||||
error, client_identifier
|
||||
),
|
||||
&Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{} {}", error, client_identifier)
|
||||
}
|
||||
Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{error} {client_identifier}")
|
||||
}
|
||||
Error::ClientAuthImpossible(username) => write!(
|
||||
&Error::ClientAuthImpossible(username) => write!(
|
||||
f,
|
||||
"Client auth not possible, \
|
||||
no cleartext password set for username: {username} \
|
||||
no cleartext password set for username: {} \
|
||||
in config and auth passthrough (query_auth) \
|
||||
is not set up."
|
||||
is not set up.",
|
||||
username
|
||||
),
|
||||
Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
f,
|
||||
"No cleartext password set, \
|
||||
and no auth passthrough could not \
|
||||
obtain the hash from server for {client_identifier}, \
|
||||
the error was: {error}",
|
||||
obtain the hash from server for {}, \
|
||||
the error was: {}",
|
||||
client_identifier, error
|
||||
),
|
||||
Error::ServerStartupError(error, server_identifier) => write!(
|
||||
&Error::ServerStartupError(error, server_identifier) => write!(
|
||||
f,
|
||||
"Error reading {error} on server startup {server_identifier}",
|
||||
"Error reading {} on server startup {}",
|
||||
error, server_identifier,
|
||||
),
|
||||
Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{error} for {server_identifier}")
|
||||
&Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{} for {}", error, server_identifier,)
|
||||
}
|
||||
|
||||
// The rest can use Debug.
|
||||
err => write!(f, "{err:?}"),
|
||||
err => write!(f, "{:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
19
src/lib.rs
19
src/lib.rs
@@ -25,11 +25,18 @@ pub mod tls;
|
||||
///
|
||||
/// * `duration` - A duration of time
|
||||
pub fn format_duration(duration: &chrono::Duration) -> String {
|
||||
let milliseconds = duration.num_milliseconds() % 1000;
|
||||
let seconds = duration.num_seconds() % 60;
|
||||
let minutes = duration.num_minutes() % 60;
|
||||
let hours = duration.num_hours() % 24;
|
||||
let days = duration.num_days();
|
||||
let milliseconds = format!("{:0>3}", duration.num_milliseconds() % 1000);
|
||||
|
||||
format!("{days}d {hours:0>2}:{minutes:0>2}:{seconds:0>2}.{milliseconds:0>3}")
|
||||
let seconds = format!("{:0>2}", duration.num_seconds() % 60);
|
||||
|
||||
let minutes = format!("{:0>2}", duration.num_minutes() % 60);
|
||||
|
||||
let hours = format!("{:0>2}", duration.num_hours() % 24);
|
||||
|
||||
let days = duration.num_days().to_string();
|
||||
|
||||
format!(
|
||||
"{}d {}:{}:{}.{}",
|
||||
days, hours, minutes, seconds, milliseconds
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
use crate::cmd_args::{Args, LogFormat};
|
||||
use tracing_subscriber;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub fn init(args: &Args) {
|
||||
// Iniitalize a default filter, and then override the builtin default "warning" with our
|
||||
// commandline, (default: "info")
|
||||
let filter = EnvFilter::from_default_env().add_directive(args.log_level.into());
|
||||
|
||||
let trace_sub = tracing_subscriber::fmt()
|
||||
.with_max_level(args.log_level)
|
||||
.with_thread_ids(true)
|
||||
.with_env_filter(filter)
|
||||
.with_ansi(!args.no_color);
|
||||
|
||||
match args.log_format {
|
||||
|
||||
@@ -23,7 +23,6 @@ extern crate arc_swap;
|
||||
extern crate async_trait;
|
||||
extern crate bb8;
|
||||
extern crate bytes;
|
||||
extern crate env_logger;
|
||||
extern crate exitcode;
|
||||
extern crate log;
|
||||
extern crate md5;
|
||||
@@ -160,7 +159,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
};
|
||||
|
||||
Collector::collect();
|
||||
tokio::task::spawn(async move {
|
||||
let mut stats_collector = Collector::default();
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
|
||||
info!("Config autoreloader: {}", match config.general.autoreload {
|
||||
Some(interval) => format!("{} ms", interval),
|
||||
|
||||
379
src/messages.rs
379
src/messages.rs
@@ -11,10 +11,13 @@ use crate::client::PREPARED_STATEMENT_COUNTER;
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::constants::MESSAGE_TERMINATOR;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CString;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::mem;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -141,6 +144,10 @@ where
|
||||
bytes.put_slice(user.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
|
||||
// Application name
|
||||
bytes.put(&b"application_name\0"[..]);
|
||||
bytes.put_slice(&b"pgcat\0"[..]);
|
||||
|
||||
// Database
|
||||
bytes.put(&b"database\0"[..]);
|
||||
bytes.put_slice(database.as_bytes());
|
||||
@@ -156,10 +163,12 @@ where
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,8 +244,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// First pass
|
||||
md5.update(password.as_bytes());
|
||||
md5.update(user.as_bytes());
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user.as_bytes());
|
||||
|
||||
let output = md5.finalize_reset();
|
||||
|
||||
@@ -272,7 +281,7 @@ where
|
||||
{
|
||||
let password = md5_hash_password(user, password, salt);
|
||||
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -286,7 +295,7 @@ where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let password = md5_hash_second_pass(hash, salt);
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -507,7 +516,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(column);
|
||||
} else {
|
||||
data_row.put_i32(-1_i32);
|
||||
data_row.put_i32(-1 as i32);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -562,10 +571,12 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -576,10 +587,12 @@ where
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -590,15 +603,19 @@ where
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -713,11 +730,26 @@ impl BytesMutReader for Cursor<&BytesMut> {
|
||||
let mut buf = vec![];
|
||||
match self.read_until(b'\0', &mut buf) {
|
||||
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
|
||||
Err(err) => Err(Error::ParseBytesError(err.to_string())),
|
||||
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BytesMutReader for BytesMut {
|
||||
/// Should only be used when reading strings from the message protocol.
|
||||
/// Can be used to read multiple strings from the same message which are separated by the null byte
|
||||
fn read_string(&mut self) -> Result<String, Error> {
|
||||
let null_index = self.iter().position(|&byte| byte == b'\0');
|
||||
|
||||
match null_index {
|
||||
Some(index) => {
|
||||
let string_bytes = self.split_to(index + 1);
|
||||
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
|
||||
}
|
||||
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -1088,3 +1120,298 @@ pub fn prepared_statement_name() -> String {
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
pub struct PgErrorMsg {
|
||||
pub severity_localized: String, // S
|
||||
pub severity: String, // V
|
||||
pub code: String, // C
|
||||
pub message: String, // M
|
||||
pub detail: Option<String>, // D
|
||||
pub hint: Option<String>, // H
|
||||
pub position: Option<u32>, // P
|
||||
pub internal_position: Option<u32>, // p
|
||||
pub internal_query: Option<String>, // q
|
||||
pub where_context: Option<String>, // W
|
||||
pub schema_name: Option<String>, // s
|
||||
pub table_name: Option<String>, // t
|
||||
pub column_name: Option<String>, // c
|
||||
pub data_type_name: Option<String>, // d
|
||||
pub constraint_name: Option<String>, // n
|
||||
pub file_name: Option<String>, // F
|
||||
pub line: Option<u32>, // L
|
||||
pub routine: Option<String>, // R
|
||||
}
|
||||
|
||||
// TODO: implement with https://docs.rs/derive_more/latest/derive_more/
|
||||
impl Display for PgErrorMsg {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[severity: {}]", self.severity)?;
|
||||
write!(f, "[code: {}]", self.code)?;
|
||||
write!(f, "[message: {}]", self.message)?;
|
||||
if let Some(val) = &self.detail {
|
||||
write!(f, "[detail: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.hint {
|
||||
write!(f, "[hint: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.position {
|
||||
write!(f, "[position: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_position {
|
||||
write!(f, "[internal_position: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_query {
|
||||
write!(f, "[internal_query: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_query {
|
||||
write!(f, "[internal_query: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.where_context {
|
||||
write!(f, "[where: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.schema_name {
|
||||
write!(f, "[schema_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.table_name {
|
||||
write!(f, "[table_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.column_name {
|
||||
write!(f, "[column_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.data_type_name {
|
||||
write!(f, "[data_type_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.constraint_name {
|
||||
write!(f, "[constraint_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.file_name {
|
||||
write!(f, "[file_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.line {
|
||||
write!(f, "[line: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.routine {
|
||||
write!(f, "[routine: {val}]")?;
|
||||
}
|
||||
|
||||
write!(f, " ")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PgErrorMsg {
|
||||
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
|
||||
let mut out = PgErrorMsg {
|
||||
severity_localized: "".to_string(),
|
||||
severity: "".to_string(),
|
||||
code: "".to_string(),
|
||||
message: "".to_string(),
|
||||
detail: None,
|
||||
hint: None,
|
||||
position: None,
|
||||
internal_position: None,
|
||||
internal_query: None,
|
||||
where_context: None,
|
||||
schema_name: None,
|
||||
table_name: None,
|
||||
column_name: None,
|
||||
data_type_name: None,
|
||||
constraint_name: None,
|
||||
file_name: None,
|
||||
line: None,
|
||||
routine: None,
|
||||
};
|
||||
for msg_part in error_msg.split(|v| *v == MESSAGE_TERMINATOR) {
|
||||
if msg_part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
return Err(Error::ServerMessageParserError(format!(
|
||||
"could not parse server message field. err {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match &msg_part[0] {
|
||||
b'S' => {
|
||||
out.severity_localized = msg_content;
|
||||
}
|
||||
b'V' => {
|
||||
out.severity = msg_content;
|
||||
}
|
||||
b'C' => {
|
||||
out.code = msg_content;
|
||||
}
|
||||
b'M' => {
|
||||
out.message = msg_content;
|
||||
}
|
||||
b'D' => {
|
||||
out.detail = Some(msg_content);
|
||||
}
|
||||
b'H' => {
|
||||
out.hint = Some(msg_content);
|
||||
}
|
||||
b'P' => out.position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
||||
b'p' => {
|
||||
out.internal_position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0))
|
||||
}
|
||||
b'q' => {
|
||||
out.internal_query = Some(msg_content);
|
||||
}
|
||||
b'W' => {
|
||||
out.where_context = Some(msg_content);
|
||||
}
|
||||
b's' => {
|
||||
out.schema_name = Some(msg_content);
|
||||
}
|
||||
b't' => {
|
||||
out.table_name = Some(msg_content);
|
||||
}
|
||||
b'c' => {
|
||||
out.column_name = Some(msg_content);
|
||||
}
|
||||
b'd' => {
|
||||
out.data_type_name = Some(msg_content);
|
||||
}
|
||||
b'n' => {
|
||||
out.constraint_name = Some(msg_content);
|
||||
}
|
||||
b'F' => {
|
||||
out.file_name = Some(msg_content);
|
||||
}
|
||||
b'L' => out.line = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
||||
b'R' => {
|
||||
out.routine = Some(msg_content);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::messages::PgErrorMsg;
|
||||
use log::{error, info};
|
||||
|
||||
fn field(kind: char, content: &str) -> Vec<u8> {
|
||||
format!("{kind}{content}\0").as_bytes().to_vec()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fields() {
|
||||
let mut complete_msg = vec![];
|
||||
let severity = "FATAL";
|
||||
complete_msg.extend(field('S', &severity));
|
||||
complete_msg.extend(field('V', &severity));
|
||||
|
||||
let error_code = "29P02";
|
||||
complete_msg.extend(field('C', &error_code));
|
||||
let message = "password authentication failed for user \"wrong_user\"";
|
||||
complete_msg.extend(field('M', &message));
|
||||
let detail_msg = "super detailed message";
|
||||
complete_msg.extend(field('D', &detail_msg));
|
||||
let hint_msg = "hint detail here";
|
||||
complete_msg.extend(field('H', &hint_msg));
|
||||
complete_msg.extend(field('P', "123"));
|
||||
complete_msg.extend(field('p', "234"));
|
||||
let internal_query = "SELECT * from foo;";
|
||||
complete_msg.extend(field('q', &internal_query));
|
||||
let where_msg = "where goes here";
|
||||
complete_msg.extend(field('W', &where_msg));
|
||||
let schema_msg = "schema_name";
|
||||
complete_msg.extend(field('s', &schema_msg));
|
||||
let table_msg = "table_name";
|
||||
complete_msg.extend(field('t', &table_msg));
|
||||
let column_msg = "column_name";
|
||||
complete_msg.extend(field('c', &column_msg));
|
||||
let data_type_msg = "type_name";
|
||||
complete_msg.extend(field('d', &data_type_msg));
|
||||
let constraint_msg = "constraint_name";
|
||||
complete_msg.extend(field('n', &constraint_msg));
|
||||
let file_msg = "pgcat.c";
|
||||
complete_msg.extend(field('F', &file_msg));
|
||||
complete_msg.extend(field('L', "335"));
|
||||
let routine_msg = "my_failing_routine";
|
||||
complete_msg.extend(field('R', &routine_msg));
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_ansi(true)
|
||||
.init();
|
||||
|
||||
info!(
|
||||
"full message: {}",
|
||||
PgErrorMsg::parse(complete_msg.clone()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
severity_localized: severity.to_string(),
|
||||
severity: severity.to_string(),
|
||||
code: error_code.to_string(),
|
||||
message: message.to_string(),
|
||||
detail: Some(detail_msg.to_string()),
|
||||
hint: Some(hint_msg.to_string()),
|
||||
position: Some(123),
|
||||
internal_position: Some(234),
|
||||
internal_query: Some(internal_query.to_string()),
|
||||
where_context: Some(where_msg.to_string()),
|
||||
schema_name: Some(schema_msg.to_string()),
|
||||
table_name: Some(table_msg.to_string()),
|
||||
column_name: Some(column_msg.to_string()),
|
||||
data_type_name: Some(data_type_msg.to_string()),
|
||||
constraint_name: Some(constraint_msg.to_string()),
|
||||
file_name: Some(file_msg.to_string()),
|
||||
line: Some(335),
|
||||
routine: Some(routine_msg.to_string()),
|
||||
},
|
||||
PgErrorMsg::parse(complete_msg).unwrap()
|
||||
);
|
||||
|
||||
let mut only_mandatory_msg = vec![];
|
||||
only_mandatory_msg.extend(field('S', &severity));
|
||||
only_mandatory_msg.extend(field('V', &severity));
|
||||
only_mandatory_msg.extend(field('C', &error_code));
|
||||
only_mandatory_msg.extend(field('M', &message));
|
||||
only_mandatory_msg.extend(field('D', &detail_msg));
|
||||
|
||||
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
|
||||
info!("only mandatory fields: {}", &err_fields);
|
||||
error!(
|
||||
"server error: {}: {}",
|
||||
err_fields.severity, err_fields.message
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
severity_localized: severity.to_string(),
|
||||
severity: severity.to_string(),
|
||||
code: error_code.to_string(),
|
||||
message: message.to_string(),
|
||||
detail: Some(detail_msg.to_string()),
|
||||
hint: None,
|
||||
position: None,
|
||||
internal_position: None,
|
||||
internal_query: None,
|
||||
where_context: None,
|
||||
schema_name: None,
|
||||
table_name: None,
|
||||
column_name: None,
|
||||
data_type_name: None,
|
||||
constraint_name: None,
|
||||
file_name: None,
|
||||
line: None,
|
||||
routine: None,
|
||||
},
|
||||
PgErrorMsg::parse(only_mandatory_msg).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ impl MirroredClient {
|
||||
Arc::new(RwLock::new(None)),
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
);
|
||||
|
||||
Pool::builder()
|
||||
@@ -78,7 +79,7 @@ impl MirroredClient {
|
||||
}
|
||||
|
||||
// Incoming data from server (we read to clear the socket buffer and discard the data)
|
||||
recv_result = server.recv() => {
|
||||
recv_result = server.recv(None) => {
|
||||
match recv_result {
|
||||
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
|
||||
Err(err) => {
|
||||
@@ -142,12 +143,12 @@ impl MirroringManager {
|
||||
});
|
||||
|
||||
Self {
|
||||
byte_senders,
|
||||
byte_senders: byte_senders,
|
||||
disconnect_senders: exit_senders,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(&mut self, bytes: &BytesMut) {
|
||||
pub fn send(self: &mut Self, bytes: &BytesMut) {
|
||||
// We want to avoid performing an allocation if we won't be able to send the message
|
||||
// There is a possibility of a race here where we check the capacity and then the channel is
|
||||
// closed or the capacity is reduced to 0, but mirroring is best effort anyway
|
||||
@@ -169,7 +170,7 @@ impl MirroringManager {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn disconnect(&mut self) {
|
||||
pub fn disconnect(self: &mut Self) {
|
||||
self.disconnect_senders
|
||||
.iter_mut()
|
||||
.for_each(|sender| match sender.try_send(()) {
|
||||
|
||||
@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
|
||||
.map(|s| {
|
||||
let s = s.as_str().to_string();
|
||||
|
||||
if s.is_empty() {
|
||||
if s == "" {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
|
||||
@@ -30,7 +30,6 @@ pub enum PluginOutput {
|
||||
Intercept(BytesMut),
|
||||
}
|
||||
|
||||
#[allow(clippy::ptr_arg)]
|
||||
#[async_trait]
|
||||
pub trait Plugin {
|
||||
// Run before the query is sent to the server.
|
||||
|
||||
@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
|
||||
self.server.address(),
|
||||
query
|
||||
);
|
||||
self.server.query(query).await?;
|
||||
self.server.query(&query).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -30,22 +30,27 @@ impl<'a> Plugin for TableAccess<'a> {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let control_flow = visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let table_name = relation.split('.').last().unwrap().to_string();
|
||||
let mut found = None;
|
||||
|
||||
if self.tables.contains(&table_name) {
|
||||
ControlFlow::Break(table_name)
|
||||
visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
|
||||
if self.tables.contains(&table_name.to_string()) {
|
||||
found = Some(table_name.to_string());
|
||||
ControlFlow::<()>::Break(())
|
||||
} else {
|
||||
ControlFlow::Continue(())
|
||||
ControlFlow::<()>::Continue(())
|
||||
}
|
||||
});
|
||||
|
||||
if let ControlFlow::Break(found) = control_flow {
|
||||
debug!("Blocking access to table \"{found}\"");
|
||||
if let Some(found) = found {
|
||||
debug!("Blocking access to table \"{}\"", found);
|
||||
|
||||
Ok(PluginOutput::Deny(format!(
|
||||
"permission for table \"{found}\" denied",
|
||||
"permission for table \"{}\" denied",
|
||||
found
|
||||
)))
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
|
||||
171
src/pool.rs
171
src/pool.rs
@@ -1,7 +1,6 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
@@ -11,6 +10,7 @@ use rand::thread_rng;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@@ -19,13 +19,13 @@ use std::time::Instant;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::config::{
|
||||
get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
|
||||
get_config, Address, DefaultShard, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
|
||||
};
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::plugins::prewarmer;
|
||||
use crate::server::Server;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{AddressStats, ClientStats, ServerStats};
|
||||
|
||||
@@ -111,6 +111,12 @@ pub struct PoolSettings {
|
||||
// Enable/disable query parser.
|
||||
pub query_parser_enabled: bool,
|
||||
|
||||
// Max length of query the parser will parse.
|
||||
pub query_parser_max_length: Option<usize>,
|
||||
|
||||
// Infer role
|
||||
pub query_parser_read_write_splitting: bool,
|
||||
|
||||
// Read from the primary as well or not.
|
||||
pub primary_reads_enabled: bool,
|
||||
|
||||
@@ -135,6 +141,9 @@ pub struct PoolSettings {
|
||||
// Regex for searching for the shard id in SQL statements
|
||||
pub shard_id_regex: Option<Regex>,
|
||||
|
||||
// What to do when no shard is selected in a sharded system
|
||||
pub default_shard: DefaultShard,
|
||||
|
||||
// Limit how much of each query is searched for a potential shard regex match
|
||||
pub regex_search_limit: usize,
|
||||
|
||||
@@ -157,6 +166,8 @@ impl Default for PoolSettings {
|
||||
db: String::default(),
|
||||
default_role: None,
|
||||
query_parser_enabled: false,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: false,
|
||||
primary_reads_enabled: true,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -166,6 +177,7 @@ impl Default for PoolSettings {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: 1000,
|
||||
default_shard: DefaultShard::Shard(0),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
@@ -188,10 +200,10 @@ pub struct ConnectionPool {
|
||||
/// that should not be queried.
|
||||
banlist: BanList,
|
||||
|
||||
/// The server information (K messages) have to be passed to the
|
||||
/// The server information has to be passed to the
|
||||
/// clients on startup. We pre-connect to all shards and replicas
|
||||
/// on pool creation and save the K messages here.
|
||||
server_info: Arc<RwLock<BytesMut>>,
|
||||
/// on pool creation and save the startup parameters here.
|
||||
original_server_parameters: Arc<RwLock<ServerParameters>>,
|
||||
|
||||
/// Pool configuration.
|
||||
pub settings: PoolSettings,
|
||||
@@ -229,17 +241,20 @@ impl ConnectionPool {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
|
||||
if let Some(pool) = old_pool_ref {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -289,6 +304,7 @@ impl ConnectionPool {
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
});
|
||||
address_id += 1;
|
||||
}
|
||||
@@ -307,6 +323,7 @@ impl ConnectionPool {
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
@@ -361,6 +378,7 @@ impl ConnectionPool {
|
||||
None => config.plugins.clone(),
|
||||
},
|
||||
pool_config.cleanup_server_connections,
|
||||
pool_config.log_client_parameter_status_changes,
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
@@ -434,7 +452,7 @@ impl ConnectionPool {
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: match user.pool_mode {
|
||||
@@ -453,6 +471,9 @@ impl ConnectionPool {
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
query_parser_max_length: pool_config.query_parser_max_length,
|
||||
query_parser_read_write_splitting: pool_config
|
||||
.query_parser_read_write_splitting,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
@@ -468,6 +489,7 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
default_shard: pool_config.default_shard.clone(),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
@@ -514,7 +536,7 @@ impl ConnectionPool {
|
||||
for server in 0..self.servers(shard) {
|
||||
let databases = self.databases.clone();
|
||||
let validated = Arc::clone(&validated);
|
||||
let pool_server_info = Arc::clone(&self.server_info);
|
||||
let pool_server_parameters = Arc::clone(&self.original_server_parameters);
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
let connection = match databases[shard][server].get().await {
|
||||
@@ -527,11 +549,10 @@ impl ConnectionPool {
|
||||
|
||||
let proxy = connection;
|
||||
let server = &*proxy;
|
||||
let server_info = server.server_info();
|
||||
let server_parameters: ServerParameters = server.server_parameters();
|
||||
|
||||
let mut guard = pool_server_info.write();
|
||||
guard.clear();
|
||||
guard.put(server_info.clone());
|
||||
let mut guard = pool_server_parameters.write();
|
||||
*guard = server_parameters;
|
||||
validated.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
@@ -543,7 +564,7 @@ impl ConnectionPool {
|
||||
|
||||
// TODO: compare server information to make sure
|
||||
// all shards are running identical configurations.
|
||||
if self.server_info.read().is_empty() {
|
||||
if !self.validated() {
|
||||
error!("Could not validate connection pool");
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
@@ -590,19 +611,51 @@ impl ConnectionPool {
|
||||
/// Get a connection from the pool.
|
||||
pub async fn get(
|
||||
&self,
|
||||
shard: usize, // shard number
|
||||
shard: Option<usize>, // shard number
|
||||
role: Option<Role>, // primary or replica
|
||||
client_stats: &ClientStats, // client id
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
let mut candidates: Vec<&Address> = self.addresses[shard]
|
||||
.iter()
|
||||
.filter(|address| address.role == role)
|
||||
.collect();
|
||||
let effective_shard_id = if self.shards() == 1 {
|
||||
// The base, unsharded case
|
||||
Some(0)
|
||||
} else {
|
||||
if !self.valid_shard_id(shard) {
|
||||
// None is valid shard ID so it is safe to unwrap here
|
||||
return Err(Error::InvalidShardId(shard.unwrap()));
|
||||
}
|
||||
shard
|
||||
};
|
||||
|
||||
// We shuffle even if least_outstanding_queries is used to avoid imbalance
|
||||
// in cases where all candidates have more or less the same number of outstanding
|
||||
// queries
|
||||
let mut candidates = self
|
||||
.addresses
|
||||
.iter()
|
||||
.flatten()
|
||||
.filter(|address| address.role == role)
|
||||
.collect::<Vec<&Address>>();
|
||||
|
||||
// We start with a shuffled list of addresses even if we end up resorting
|
||||
// this is meant to avoid hitting instance 0 everytime if the sorting metric
|
||||
// ends up being the same for all instances
|
||||
candidates.shuffle(&mut thread_rng());
|
||||
|
||||
match effective_shard_id {
|
||||
Some(shard_id) => candidates.retain(|address| address.shard == shard_id),
|
||||
None => match self.settings.default_shard {
|
||||
DefaultShard::Shard(shard_id) => {
|
||||
candidates.retain(|address| address.shard == shard_id)
|
||||
}
|
||||
DefaultShard::Random => (),
|
||||
DefaultShard::RandomHealthy => {
|
||||
candidates.sort_by(|a, b| {
|
||||
b.error_count
|
||||
.load(Ordering::Relaxed)
|
||||
.partial_cmp(&a.error_count.load(Ordering::Relaxed))
|
||||
.unwrap()
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections {
|
||||
candidates.sort_by(|a, b| {
|
||||
self.busy_connection_count(b)
|
||||
@@ -625,7 +678,7 @@ impl ConnectionPool {
|
||||
let mut force_healthcheck = false;
|
||||
|
||||
if self.is_banned(address) {
|
||||
if self.try_unban(address).await {
|
||||
if self.try_unban(&address).await {
|
||||
force_healthcheck = true;
|
||||
} else {
|
||||
debug!("Address {:?} is banned", address);
|
||||
@@ -638,7 +691,10 @@ impl ConnectionPool {
|
||||
.get()
|
||||
.await
|
||||
{
|
||||
Ok(conn) => conn,
|
||||
Ok(conn) => {
|
||||
address.reset_error_count();
|
||||
conn
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Connection checkout error for instance {:?}, error: {:?}",
|
||||
@@ -664,7 +720,7 @@ impl ConnectionPool {
|
||||
// since we last checked the server is ok.
|
||||
// Health checks are pretty expensive.
|
||||
if !require_healthcheck {
|
||||
let checkout_time: u64 = now.elapsed().as_micros() as u64;
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
@@ -678,7 +734,7 @@ impl ConnectionPool {
|
||||
.run_health_check(address, server, now, client_stats)
|
||||
.await
|
||||
{
|
||||
let checkout_time: u64 = now.elapsed().as_micros() as u64;
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
@@ -690,7 +746,12 @@ impl ConnectionPool {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
client_stats.idle();
|
||||
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
|
||||
Err(Error::AllServersDown)
|
||||
}
|
||||
|
||||
@@ -745,14 +806,26 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
false
|
||||
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
/// traffic for any new transactions. Existing transactions on that replica
|
||||
/// will finish successfully or error out to the clients.
|
||||
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
|
||||
// Count the number of errors since the last successful checkout
|
||||
// This is used to determine if the shard is down
|
||||
match reason {
|
||||
BanReason::FailedHealthCheck
|
||||
| BanReason::FailedCheckout
|
||||
| BanReason::MessageSendFailed
|
||||
| BanReason::MessageReceiveFailed => {
|
||||
address.increment_error_count();
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Primary can never be banned
|
||||
if address.role == Role::Primary {
|
||||
return;
|
||||
@@ -858,10 +931,10 @@ impl ConnectionPool {
|
||||
let guard = self.banlist.read();
|
||||
for banlist in guard.iter() {
|
||||
for (address, (reason, timestamp)) in banlist.iter() {
|
||||
bans.push((address.clone(), (reason.clone(), *timestamp)));
|
||||
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
|
||||
}
|
||||
}
|
||||
bans
|
||||
return bans;
|
||||
}
|
||||
|
||||
/// Get the address from the host url
|
||||
@@ -903,10 +976,11 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.original_server_parameters.read().clone()
|
||||
}
|
||||
|
||||
/// Get the number of checked out connection for an address
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
let state = self.pool_state(address.shard, address.address_index);
|
||||
let idle = state.idle_connections;
|
||||
@@ -918,7 +992,14 @@ impl ConnectionPool {
|
||||
}
|
||||
let busy = provisioned - idle;
|
||||
debug!("{:?} has {:?} busy connections", address, busy);
|
||||
busy
|
||||
return busy;
|
||||
}
|
||||
|
||||
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
|
||||
match shard {
|
||||
None => true,
|
||||
Some(shard) => shard < self.shards(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -944,6 +1025,9 @@ pub struct ServerPool {
|
||||
|
||||
/// Should we clean up dirty connections before putting them into the pool?
|
||||
cleanup_connections: bool,
|
||||
|
||||
/// Log client parameter status changes
|
||||
log_client_parameter_status_changes: bool,
|
||||
}
|
||||
|
||||
impl ServerPool {
|
||||
@@ -955,6 +1039,7 @@ impl ServerPool {
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
plugins: Option<Plugins>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
) -> ServerPool {
|
||||
ServerPool {
|
||||
address,
|
||||
@@ -964,6 +1049,7 @@ impl ServerPool {
|
||||
auth_hash,
|
||||
plugins,
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -993,6 +1079,7 @@ impl ManageConnection for ServerPool {
|
||||
stats.clone(),
|
||||
self.auth_hash.clone(),
|
||||
self.cleanup_connections,
|
||||
self.log_client_parameter_status_changes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -19,9 +19,9 @@ use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
|
||||
use crate::pool::PoolSettings;
|
||||
use crate::sharding::Sharder;
|
||||
|
||||
use std::cmp;
|
||||
use std::collections::BTreeSet;
|
||||
use std::io::Cursor;
|
||||
use std::{cmp, mem};
|
||||
|
||||
/// Regexes used to parse custom commands.
|
||||
const CUSTOM_SQL_REGEXES: [&str; 7] = [
|
||||
@@ -67,7 +67,6 @@ static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
||||
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
|
||||
|
||||
/// The query router.
|
||||
#[derive(Default)]
|
||||
pub struct QueryRouter {
|
||||
/// Which shard we should be talking to right now.
|
||||
active_shard: Option<usize>,
|
||||
@@ -92,7 +91,7 @@ impl QueryRouter {
|
||||
/// One-time initialization of regexes
|
||||
/// that parse our custom SQL protocol.
|
||||
pub fn setup() -> bool {
|
||||
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
|
||||
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
|
||||
Ok(rgx) => rgx,
|
||||
Err(err) => {
|
||||
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
|
||||
@@ -117,8 +116,15 @@ impl QueryRouter {
|
||||
|
||||
/// Create a new instance of the query router.
|
||||
/// Each client gets its own.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
pub fn new() -> QueryRouter {
|
||||
QueryRouter {
|
||||
active_shard: None,
|
||||
active_role: None,
|
||||
query_parser_enabled: None,
|
||||
primary_reads_enabled: None,
|
||||
pool_settings: PoolSettings::default(),
|
||||
placeholders: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Pool settings can change because of a config reload.
|
||||
@@ -126,7 +132,7 @@ impl QueryRouter {
|
||||
self.pool_settings = pool_settings;
|
||||
}
|
||||
|
||||
pub fn pool_settings(&self) -> &PoolSettings {
|
||||
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
|
||||
&self.pool_settings
|
||||
}
|
||||
|
||||
@@ -135,18 +141,24 @@ impl QueryRouter {
|
||||
let mut message_cursor = Cursor::new(message_buffer);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
|
||||
let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some()
|
||||
|| self.pool_settings.sharding_key_regex.is_some();
|
||||
|
||||
// Check for any sharding regex matches in any queries
|
||||
match code {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
if self.pool_settings.shard_id_regex.is_some()
|
||||
|| self.pool_settings.sharding_key_regex.is_some()
|
||||
{
|
||||
if comment_shard_routing_enabled {
|
||||
match code as char {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
// Check only the first block of bytes configured by the pool settings
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
|
||||
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);
|
||||
|
||||
let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();
|
||||
|
||||
let initial_segment = String::from_utf8_lossy(
|
||||
&message_buffer[query_start_index..query_start_index + seg],
|
||||
);
|
||||
|
||||
// Check for a shard_id included in the query
|
||||
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
|
||||
@@ -155,7 +167,7 @@ impl QueryRouter {
|
||||
});
|
||||
if let Some(shard_id) = shard_id {
|
||||
debug!("Setting shard to {:?}", shard_id);
|
||||
self.set_shard(shard_id);
|
||||
self.set_shard(Some(shard_id));
|
||||
// Skip other command processing since a sharding command was found
|
||||
return None;
|
||||
}
|
||||
@@ -177,8 +189,8 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Only simple protocol supported for commands processed below
|
||||
@@ -186,7 +198,6 @@ impl QueryRouter {
|
||||
return None;
|
||||
}
|
||||
|
||||
let _len = message_cursor.get_i32() as usize;
|
||||
let query = message_cursor.read_string().unwrap();
|
||||
|
||||
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
|
||||
@@ -238,7 +249,9 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
Command::ShowShard => self.shard().to_string(),
|
||||
Command::ShowShard => self
|
||||
.shard()
|
||||
.map_or_else(|| "unset".to_string(), |x| x.to_string()),
|
||||
Command::ShowServerRole => match self.active_role {
|
||||
Some(Role::Primary) => Role::Primary.to_string(),
|
||||
Some(Role::Replica) => Role::Replica.to_string(),
|
||||
@@ -325,11 +338,23 @@ impl QueryRouter {
|
||||
Some((command, value))
|
||||
}
|
||||
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<Statement>, Error> {
|
||||
pub fn parse(&self, message: &BytesMut) -> Result<Vec<Statement>, Error> {
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let _len = message_cursor.get_i32() as usize;
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
|
||||
match self.pool_settings.query_parser_max_length {
|
||||
Some(max_length) => {
|
||||
if len > max_length {
|
||||
return Err(Error::QueryRouterParserError(format!(
|
||||
"Query too long for parser: {} > {}",
|
||||
len, max_length
|
||||
)));
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
let query = match code {
|
||||
// Query
|
||||
@@ -366,6 +391,10 @@ impl QueryRouter {
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
||||
if !self.pool_settings.query_parser_read_write_splitting {
|
||||
return Ok(()); // Nothing to do
|
||||
}
|
||||
|
||||
debug!("Inferring role");
|
||||
|
||||
if ast.is_empty() {
|
||||
@@ -391,10 +420,14 @@ impl QueryRouter {
|
||||
// or discard shard selection. If they point to the same shard though,
|
||||
// we can let them through as-is.
|
||||
// This is basically building a database now :)
|
||||
if let Some(shard) = self.infer_shard(query) {
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
}
|
||||
match self.infer_shard(query) {
|
||||
Some(shard) => {
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
|
||||
None => (),
|
||||
@@ -423,6 +456,10 @@ impl QueryRouter {
|
||||
/// N.B.: Only supports anonymous prepared statements since we don't
|
||||
/// keep a cache of them in PgCat.
|
||||
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
|
||||
if !self.pool_settings.query_parser_read_write_splitting {
|
||||
return false; // Nothing to do
|
||||
}
|
||||
|
||||
debug!("Parsing bind message");
|
||||
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
@@ -547,7 +584,7 @@ impl QueryRouter {
|
||||
// TODO: Support multi-shard queries some day.
|
||||
if shards.len() == 1 {
|
||||
debug!("Found one sharding key");
|
||||
self.set_shard(*shards.first().unwrap());
|
||||
self.set_shard(Some(*shards.first().unwrap()));
|
||||
true
|
||||
} else {
|
||||
debug!("Found no sharding keys");
|
||||
@@ -566,8 +603,8 @@ impl QueryRouter {
|
||||
.automatic_sharding_key
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.split('.')
|
||||
.map(Ident::new)
|
||||
.split(".")
|
||||
.map(|ident| Ident::new(ident))
|
||||
.collect::<Vec<Ident>>();
|
||||
|
||||
// Sharding key must be always fully qualified
|
||||
@@ -583,7 +620,7 @@ impl QueryRouter {
|
||||
Expr::Identifier(ident) => {
|
||||
// Only if we're dealing with only one table
|
||||
// and there is no ambiguity
|
||||
if ident.value == sharding_key[1].value {
|
||||
if &ident.value == &sharding_key[1].value {
|
||||
// Sharding key is unique enough, don't worry about
|
||||
// table names.
|
||||
if &sharding_key[0].value == "*" {
|
||||
@@ -596,13 +633,13 @@ impl QueryRouter {
|
||||
// SELECT * FROM t WHERE sharding_key = 5
|
||||
// Make sure the table name from the sharding key matches
|
||||
// the table name from the query.
|
||||
found = sharding_key[0].value == table[0].value;
|
||||
found = &sharding_key[0].value == &table[0].value;
|
||||
} else if table.len() == 2 {
|
||||
// Table name is fully qualified with the schema: e.g.
|
||||
// SELECT * FROM public.t WHERE sharding_key = 5
|
||||
// Ignore the schema (TODO: at some point, we want schema support)
|
||||
// and use the table name only.
|
||||
found = sharding_key[0].value == table[1].value;
|
||||
found = &sharding_key[0].value == &table[1].value;
|
||||
} else {
|
||||
debug!("Got table name with more than two idents, which is not possible");
|
||||
}
|
||||
@@ -614,8 +651,8 @@ impl QueryRouter {
|
||||
// The key is fully qualified in the query,
|
||||
// it will exist or Postgres will throw an error.
|
||||
if idents.len() == 2 {
|
||||
found = sharding_key[0].value == idents[0].value
|
||||
&& sharding_key[1].value == idents[1].value;
|
||||
found = &sharding_key[0].value == &idents[0].value
|
||||
&& &sharding_key[1].value == &idents[1].value;
|
||||
}
|
||||
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
|
||||
}
|
||||
@@ -647,7 +684,7 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
Expr::Value(Value::Placeholder(placeholder)) => {
|
||||
match placeholder.replace('$', "").parse::<i16>() {
|
||||
match placeholder.replace("$", "").parse::<i16>() {
|
||||
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
|
||||
Err(_) => {
|
||||
debug!(
|
||||
@@ -673,9 +710,12 @@ impl QueryRouter {
|
||||
|
||||
match &*query.body {
|
||||
SetExpr::Query(query) => {
|
||||
if let Some(shard) = self.infer_shard(query) {
|
||||
shards.insert(shard);
|
||||
}
|
||||
match self.infer_shard(&*query) {
|
||||
Some(shard) => {
|
||||
shards.insert(shard);
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
|
||||
// SELECT * FROM ...
|
||||
@@ -685,22 +725,38 @@ impl QueryRouter {
|
||||
let mut table_names = Vec::new();
|
||||
|
||||
for table in select.from.iter() {
|
||||
if let TableFactor::Table { name, .. } = &table.relation {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
// Get table names from all the joins.
|
||||
for join in table.joins.iter() {
|
||||
if let TableFactor::Table { name, .. } = &join.relation {
|
||||
match &table.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Get table names from all the joins.
|
||||
for join in table.joins.iter() {
|
||||
match &join.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// We can filter results based on join conditions, e.g.
|
||||
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
|
||||
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
match &join.join_operator {
|
||||
JoinOperator::Inner(inner_join) => match &inner_join {
|
||||
JoinConstraint::On(expr) => {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
},
|
||||
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -774,16 +830,16 @@ impl QueryRouter {
|
||||
db: &self.pool_settings.db,
|
||||
};
|
||||
|
||||
let _ = query_logger.run(self, ast).await;
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
}
|
||||
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
let mut intercept = Intercept {
|
||||
enabled: intercept.enabled,
|
||||
config: intercept,
|
||||
config: &intercept,
|
||||
};
|
||||
|
||||
let result = intercept.run(self, ast).await;
|
||||
let result = intercept.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
return Ok(PluginOutput::Intercept(output));
|
||||
@@ -796,7 +852,7 @@ impl QueryRouter {
|
||||
tables: &table_access.tables,
|
||||
};
|
||||
|
||||
let result = table_access.run(self, ast).await;
|
||||
let result = table_access.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
return Ok(PluginOutput::Deny(error));
|
||||
@@ -812,7 +868,7 @@ impl QueryRouter {
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
let shard = sharder.shard(sharding_key);
|
||||
self.set_shard(shard);
|
||||
self.set_shard(Some(shard));
|
||||
self.active_shard
|
||||
}
|
||||
|
||||
@@ -822,17 +878,17 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
/// Get desired shard we should be talking to.
|
||||
pub fn shard(&self) -> usize {
|
||||
self.active_shard.unwrap_or(0)
|
||||
pub fn shard(&self) -> Option<usize> {
|
||||
self.active_shard
|
||||
}
|
||||
|
||||
pub fn set_shard(&mut self, shard: usize) {
|
||||
self.active_shard = Some(shard);
|
||||
pub fn set_shard(&mut self, shard: Option<usize>) {
|
||||
self.active_shard = shard;
|
||||
}
|
||||
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
match self.query_parser_enabled {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
None => {
|
||||
debug!(
|
||||
"Using pool settings, query_parser_enabled: {}",
|
||||
@@ -848,7 +904,9 @@ impl QueryRouter {
|
||||
);
|
||||
value
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enabled
|
||||
}
|
||||
|
||||
pub fn primary_reads_enabled(&self) -> bool {
|
||||
@@ -879,14 +937,11 @@ mod test {
|
||||
fn test_infer_replica() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
|
||||
.is_some());
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
||||
assert!(qr.query_parser_enabled());
|
||||
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
let queries = vec![
|
||||
simple_query("SELECT * FROM items WHERE id = 5"),
|
||||
@@ -898,7 +953,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
}
|
||||
@@ -907,6 +962,7 @@ mod test {
|
||||
fn test_infer_primary() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
let queries = vec![
|
||||
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
|
||||
@@ -917,7 +973,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
}
|
||||
}
|
||||
@@ -927,11 +983,9 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
|
||||
.is_some());
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
}
|
||||
|
||||
@@ -939,10 +993,10 @@ mod test {
|
||||
fn test_infer_parse_prepared() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
let prepared_stmt = BytesMut::from(
|
||||
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
|
||||
@@ -953,7 +1007,7 @@ mod test {
|
||||
res.put(prepared_stmt);
|
||||
res.put_i16(0);
|
||||
|
||||
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&res).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
|
||||
@@ -1039,7 +1093,7 @@ mod test {
|
||||
qr.try_execute_command(&query),
|
||||
Some((Command::SetShardingKey, String::from("0")))
|
||||
);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
|
||||
// SetShard
|
||||
let query = simple_query("SET SHARD TO '1'");
|
||||
@@ -1047,7 +1101,7 @@ mod test {
|
||||
qr.try_execute_command(&query),
|
||||
Some((Command::SetShard, String::from("1")))
|
||||
);
|
||||
assert_eq!(qr.shard(), 1);
|
||||
assert_eq!(qr.shard().unwrap(), 1);
|
||||
|
||||
// ShowShard
|
||||
let query = simple_query("SHOW SHARD");
|
||||
@@ -1109,26 +1163,26 @@ mod test {
|
||||
fn test_enable_query_parser() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
let query = simple_query("SET SERVER ROLE TO 'auto'");
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
let query = simple_query("SET SERVER ROLE TO 'auto'");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(qr.query_parser_enabled());
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
let query = simple_query("INSERT INTO test_table VALUES (1)");
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
|
||||
let query = simple_query("SELECT * FROM test_table");
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
let query = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(!qr.query_parser_enabled());
|
||||
}
|
||||
|
||||
@@ -1143,6 +1197,8 @@ mod test {
|
||||
user: crate::config::User::default(),
|
||||
default_role: Some(Role::Replica),
|
||||
query_parser_enabled: true,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: true,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: Some(String::from("test.id")),
|
||||
@@ -1151,6 +1207,7 @@ mod test {
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
@@ -1173,11 +1230,11 @@ mod test {
|
||||
assert!(!qr.primary_reads_enabled());
|
||||
|
||||
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
|
||||
assert!(qr.try_execute_command(&q1).is_some());
|
||||
assert!(qr.try_execute_command(&q1) != None);
|
||||
assert_eq!(qr.active_role.unwrap(), Role::Primary);
|
||||
|
||||
let q2 = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&q2).is_some());
|
||||
assert!(qr.try_execute_command(&q2) != None);
|
||||
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
|
||||
}
|
||||
|
||||
@@ -1187,18 +1244,18 @@ mod test {
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.infer(&qr.parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.infer(&qr.parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Replica);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
))
|
||||
.unwrap()
|
||||
@@ -1218,6 +1275,8 @@ mod test {
|
||||
user: crate::config::User::default(),
|
||||
default_role: Some(Role::Replica),
|
||||
query_parser_enabled: true,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: true,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -1226,6 +1285,7 @@ mod test {
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
|
||||
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
@@ -1240,19 +1300,24 @@ mod test {
|
||||
// Shard should start out unset
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Don't panic when short query eg. ; is sent
|
||||
let q0 = simple_query(";");
|
||||
assert!(qr.try_execute_command(&q0) == None);
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Make sure setting it works
|
||||
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q1).is_none());
|
||||
assert!(qr.try_execute_command(&q1) == None);
|
||||
assert_eq!(qr.active_shard, Some(1));
|
||||
|
||||
// And make sure changing it works
|
||||
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert_eq!(qr.active_shard, Some(0));
|
||||
|
||||
// Validate setting by shard with expected shard copied from sharding.rs tests
|
||||
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert_eq!(qr.active_shard, Some(2));
|
||||
}
|
||||
|
||||
@@ -1263,25 +1328,29 @@ mod test {
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query("SELECT * FROM data WHERE id = 5"))
|
||||
.unwrap(),
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
"SELECT * FROM data
|
||||
INNER JOIN t2 ON data.id = 5
|
||||
AND t2.data_id = data.id
|
||||
@@ -1290,59 +1359,59 @@ mod test {
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
|
||||
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
|
||||
// in the query.
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
|
||||
// Super unique sharding key
|
||||
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
&qr.parse(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
&qr.parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1364,14 +1433,13 @@ mod test {
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
|
||||
.is_ok());
|
||||
assert!(qr.infer(&qr.parse(&simple_query(stmt)).unwrap()).is_ok());
|
||||
assert_eq!(qr.placeholders.len(), 1);
|
||||
|
||||
assert!(qr.infer_shard_from_bind(&bind));
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert!(qr.placeholders.is_empty());
|
||||
}
|
||||
|
||||
@@ -1390,17 +1458,15 @@ mod test {
|
||||
};
|
||||
|
||||
QueryRouter::setup();
|
||||
let pool_settings = PoolSettings {
|
||||
query_parser_enabled: true,
|
||||
plugins: Some(plugins),
|
||||
..Default::default()
|
||||
};
|
||||
let mut pool_settings = PoolSettings::default();
|
||||
pool_settings.query_parser_enabled = true;
|
||||
pool_settings.plugins = Some(plugins);
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings);
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = QueryRouter::parse(&query).unwrap();
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
@@ -1418,7 +1484,7 @@ mod test {
|
||||
let qr = QueryRouter::new();
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = QueryRouter::parse(&query).unwrap();
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
|
||||
14
src/scram.rs
14
src/scram.rs
@@ -79,12 +79,12 @@ impl ScramSha256 {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
let salted_password = Self::hi(
|
||||
@@ -166,9 +166,9 @@ impl ScramSha256 {
|
||||
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
|
||||
let final_message = FinalMessage::parse(message)?;
|
||||
|
||||
let verifier = match general_purpose::STANDARD.decode(final_message.value) {
|
||||
let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
@@ -230,14 +230,14 @@ impl Message {
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
Ok(Message {
|
||||
@@ -257,7 +257,7 @@ impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
Ok(FinalMessage {
|
||||
|
||||
321
src/server.rs
321
src/server.rs
@@ -3,10 +3,11 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::io::Read;
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::mem;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
@@ -19,6 +20,7 @@ use crate::config::{get_config, get_prepared_statements_cache_size, Address, Use
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
@@ -105,10 +107,10 @@ impl StreamInner {
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct CleanupState {
|
||||
/// If server connection requires DISCARD ALL before checkin because of set statement
|
||||
/// If server connection requires RESET ALL before checkin because of set statement
|
||||
needs_cleanup_set: bool,
|
||||
|
||||
/// If server connection requires DISCARD ALL before checkin because of prepare statement
|
||||
/// If server connection requires DEALLOCATE ALL before checkin because of prepare statement
|
||||
needs_cleanup_prepare: bool,
|
||||
}
|
||||
|
||||
@@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState {
|
||||
}
|
||||
}
|
||||
|
||||
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
|
||||
let mut set = HashSet::new();
|
||||
set.insert("client_encoding".to_string());
|
||||
set.insert("DateStyle".to_string());
|
||||
set.insert("TimeZone".to_string());
|
||||
set.insert("standard_conforming_strings".to_string());
|
||||
set.insert("application_name".to_string());
|
||||
set
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerParameters {
|
||||
parameters: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for ServerParameters {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameters {
|
||||
pub fn new() -> Self {
|
||||
let mut server_parameters = ServerParameters {
|
||||
parameters: HashMap::new(),
|
||||
};
|
||||
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);
|
||||
server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false);
|
||||
server_parameters.set_param(
|
||||
"standard_conforming_strings".to_string(),
|
||||
"on".to_string(),
|
||||
false,
|
||||
);
|
||||
server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false);
|
||||
|
||||
server_parameters
|
||||
}
|
||||
|
||||
/// returns true if a tracked parameter was set, false if it was a non-tracked parameter
|
||||
/// if startup is false, then then only tracked parameters will be set
|
||||
pub fn set_param(&mut self, mut key: String, value: String, startup: bool) {
|
||||
// The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys
|
||||
if key == "timezone" {
|
||||
key = "TimeZone".to_string();
|
||||
} else if key == "datestyle" {
|
||||
key = "DateStyle".to_string();
|
||||
};
|
||||
|
||||
if TRACKED_PARAMETERS.contains(&key) {
|
||||
self.parameters.insert(key, value);
|
||||
} else {
|
||||
if startup {
|
||||
self.parameters.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
|
||||
// iterate through each and call set_param
|
||||
for (key, value) in parameters {
|
||||
self.set_param(key.to_string(), value.to_string(), startup);
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the diff of the parameters
|
||||
fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap<String, String> {
|
||||
let mut diff = HashMap::new();
|
||||
|
||||
// iterate through tracked parameters
|
||||
for key in TRACKED_PARAMETERS.iter() {
|
||||
if let Some(incoming_value) = incoming_parameters.parameters.get(key) {
|
||||
if let Some(value) = self.parameters.get(key) {
|
||||
if value != incoming_value {
|
||||
diff.insert(key.to_string(), incoming_value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
diff
|
||||
}
|
||||
|
||||
pub fn get_application_name(&self) -> &String {
|
||||
// Can unwrap because we set it in the constructor
|
||||
self.parameters.get("application_name").unwrap()
|
||||
}
|
||||
|
||||
fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) {
|
||||
buffer.put_u8(b'S');
|
||||
|
||||
// 4 is len of i32, the plus for the null terminator
|
||||
let len = 4 + key.len() + 1 + value.len() + 1;
|
||||
|
||||
buffer.put_i32(len as i32);
|
||||
|
||||
buffer.put_slice(key.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
buffer.put_slice(value.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ServerParameters> for BytesMut {
|
||||
fn from(server_parameters: &ServerParameters) -> Self {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
for (key, value) in &server_parameters.parameters {
|
||||
ServerParameters::add_parameter_message(key, value, &mut bytes);
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn compare
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
/// Server host, e.g. localhost,
|
||||
@@ -158,7 +278,7 @@ pub struct Server {
|
||||
buffer: BytesMut,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_info: BytesMut,
|
||||
server_parameters: ServerParameters,
|
||||
|
||||
/// Backend id and secret key used for query cancellation.
|
||||
process_id: i32,
|
||||
@@ -176,7 +296,7 @@ pub struct Server {
|
||||
/// Is the server broken? We'll remote it from the pool if so.
|
||||
bad: bool,
|
||||
|
||||
/// If server connection requires DISCARD ALL before checkin
|
||||
/// If server connection requires reset statements before checkin
|
||||
cleanup_state: CleanupState,
|
||||
|
||||
/// Mapping of clients and servers used for query cancellation.
|
||||
@@ -202,6 +322,9 @@ pub struct Server {
|
||||
/// Should clean up dirty connections?
|
||||
cleanup_connections: bool,
|
||||
|
||||
/// Log client parameter status changes
|
||||
log_client_parameter_status_changes: bool,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: BTreeSet<String>,
|
||||
}
|
||||
@@ -217,6 +340,7 @@ impl Server {
|
||||
stats: Arc<ServerStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
) -> Result<Server, Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let mut addr_set: Option<AddrSet> = None;
|
||||
@@ -316,7 +440,10 @@ impl Server {
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!("Unknown message: {}", { m })));
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -334,18 +461,28 @@ impl Server {
|
||||
None => &user.username,
|
||||
};
|
||||
|
||||
let password = user.server_password.as_ref();
|
||||
let password = match user.server_password {
|
||||
Some(ref server_password) => Some(server_password),
|
||||
None => match user.password {
|
||||
Some(ref password) => Some(password),
|
||||
None => None,
|
||||
},
|
||||
};
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, database);
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
|
||||
// We'll be handling multiple packets, but they will all be structured the same.
|
||||
// We'll loop here until this exchange is complete.
|
||||
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
|
||||
let mut scram: Option<ScramSha256> = match password {
|
||||
Some(password) => Some(ScramSha256::new(password)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
@@ -576,8 +713,7 @@ impl Server {
|
||||
|
||||
// An error message will be present.
|
||||
_ => {
|
||||
// Read the error message without the terminating null character.
|
||||
let mut error = vec![0u8; len as usize - 4 - 1];
|
||||
let mut error = vec![0u8; len as usize];
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
@@ -589,10 +725,14 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: the error message contains multiple fields; we can decode them and
|
||||
// present a prettier message to the user.
|
||||
// See: https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
error!("Server error: {}", String::from_utf8_lossy(&error));
|
||||
let fields = match PgErrorMsg::parse(error) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
trace!("error fields: {}", &fields);
|
||||
error!("server error: {}: {}", fields.severity, fields.message);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -601,9 +741,10 @@ impl Server {
|
||||
|
||||
// ParameterStatus
|
||||
'S' => {
|
||||
let mut param = vec![0u8; len as usize - 4];
|
||||
let mut bytes = BytesMut::with_capacity(len as usize - 4);
|
||||
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
match stream.read_exact(&mut bytes[..]).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -613,12 +754,13 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let key = bytes.read_string().unwrap();
|
||||
let value = bytes.read_string().unwrap();
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
// These can be server_encoding, client_encoding, server timezone, Postgres version,
|
||||
// and many more interesting things we should know about the Postgres server we are talking to.
|
||||
server_info.put_u8(b'S');
|
||||
server_info.put_i32(len);
|
||||
server_info.put_slice(¶m[..]);
|
||||
server_parameters.set_param(key, value, true);
|
||||
}
|
||||
|
||||
// BackendKeyData
|
||||
@@ -660,11 +802,11 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let mut server = Server {
|
||||
let server = Server {
|
||||
address: address.clone(),
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
server_parameters,
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
@@ -676,7 +818,7 @@ impl Server {
|
||||
addr_set,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats,
|
||||
application_name: String::new(),
|
||||
application_name: "pgcat".to_string(),
|
||||
last_activity: SystemTime::now(),
|
||||
mirror_manager: match address.mirrors.len() {
|
||||
0 => None,
|
||||
@@ -687,11 +829,10 @@ impl Server {
|
||||
)),
|
||||
},
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
prepared_statements: BTreeSet::new(),
|
||||
};
|
||||
|
||||
server.set_name("pgcat").await?;
|
||||
|
||||
return Ok(server);
|
||||
}
|
||||
|
||||
@@ -741,7 +882,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_flush(&mut self.stream, messages).await {
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -761,7 +902,10 @@ impl Server {
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
pub async fn recv(
|
||||
&mut self,
|
||||
mut client_server_parameters: Option<&mut ServerParameters>,
|
||||
) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
@@ -833,24 +977,24 @@ impl Server {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
|
||||
let mut command_tag = String::new();
|
||||
match message.reader().read_to_string(&mut command_tag) {
|
||||
Ok(_) => {
|
||||
match message.read_string() {
|
||||
Ok(command) => {
|
||||
// Non-exhaustive list of commands that are likely to change session variables/resources
|
||||
// which can leak between clients. This is a best effort to block bad clients
|
||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||
match command_tag.as_str() {
|
||||
"SET\0" => {
|
||||
match command.as_str() {
|
||||
"SET" => {
|
||||
// We don't detect set statements in transactions
|
||||
// No great way to differentiate between set and set local
|
||||
// As a result, we will miss cases when set statements are used in transactions
|
||||
// This will reduce amount of discard statements sent
|
||||
// This will reduce amount of reset statements sent
|
||||
if !self.in_transaction {
|
||||
debug!("Server connection marked for clean up");
|
||||
self.cleanup_state.needs_cleanup_set = true;
|
||||
}
|
||||
}
|
||||
"PREPARE\0" => {
|
||||
|
||||
"PREPARE" => {
|
||||
debug!("Server connection marked for clean up");
|
||||
self.cleanup_state.needs_cleanup_prepare = true;
|
||||
}
|
||||
@@ -864,6 +1008,20 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
'S' => {
|
||||
let key = message.read_string().unwrap();
|
||||
let value = message.read_string().unwrap();
|
||||
|
||||
if let Some(client_server_parameters) = client_server_parameters.as_mut() {
|
||||
client_server_parameters.set_param(key.clone(), value.clone(), false);
|
||||
if self.log_client_parameter_status_changes {
|
||||
info!("Client parameter status change: {} = {}", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
self.server_parameters.set_param(key, value, false);
|
||||
}
|
||||
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
@@ -985,7 +1143,9 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
self.deallocate(names).await?;
|
||||
if !names.is_empty() {
|
||||
self.deallocate(names).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1009,7 +1169,9 @@ impl Server {
|
||||
self.send(&bytes).await?;
|
||||
}
|
||||
|
||||
self.send(&flush()).await?;
|
||||
if !names.is_empty() {
|
||||
self.send(&flush()).await?;
|
||||
}
|
||||
|
||||
// Read and discard CloseComplete (3)
|
||||
for name in &names {
|
||||
@@ -1070,9 +1232,28 @@ impl Server {
|
||||
}
|
||||
|
||||
/// Get server startup information to forward it to the client.
|
||||
/// Not used at the moment.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.clone()
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.server_parameters.clone()
|
||||
}
|
||||
|
||||
pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> {
|
||||
let parameter_diff = self.server_parameters.compare_params(parameters);
|
||||
|
||||
if parameter_diff.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut query = String::from("");
|
||||
|
||||
for (key, value) in parameter_diff {
|
||||
query.push_str(&format!("SET {} TO '{}';", key, value));
|
||||
}
|
||||
|
||||
let res = self.query(&query).await;
|
||||
|
||||
self.cleanup_state.reset();
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Indicate that this server connection cannot be re-used and must be discarded.
|
||||
@@ -1106,7 +1287,7 @@ impl Server {
|
||||
self.send(&query).await?;
|
||||
|
||||
loop {
|
||||
let _ = self.recv().await?;
|
||||
let _ = self.recv(None).await?;
|
||||
|
||||
if !self.data_available {
|
||||
break;
|
||||
@@ -1124,47 +1305,38 @@ impl Server {
|
||||
// server connection thrashing if clients repeatedly do this.
|
||||
// Instead, we ROLLBACK that transaction before putting the connection back in the pool
|
||||
if self.in_transaction() {
|
||||
warn!("Server returned while still in transaction, rolling back transaction");
|
||||
warn!(target: "pgcat::server::cleanup", "Server returned while still in transaction, rolling back transaction");
|
||||
self.query("ROLLBACK").await?;
|
||||
}
|
||||
|
||||
// Client disconnected but it performed session-altering operations such as
|
||||
// SET statement_timeout to 1 or create a prepared statement. We clear that
|
||||
// to avoid leaking state between clients. For performance reasons we only
|
||||
// send `DISCARD ALL` if we think the session is altered instead of just sending
|
||||
// send `RESET ALL` if we think the session is altered instead of just sending
|
||||
// it before each checkin.
|
||||
if self.cleanup_state.needs_cleanup() && self.cleanup_connections {
|
||||
warn!("Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name);
|
||||
self.query("DISCARD ALL").await?;
|
||||
self.query("RESET ROLE").await?;
|
||||
info!(target: "pgcat::server::cleanup", "Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name);
|
||||
let mut reset_string = String::from("RESET ROLE;");
|
||||
|
||||
if self.cleanup_state.needs_cleanup_set {
|
||||
reset_string.push_str("RESET ALL;");
|
||||
};
|
||||
|
||||
if self.cleanup_state.needs_cleanup_prepare {
|
||||
reset_string.push_str("DEALLOCATE ALL;");
|
||||
};
|
||||
|
||||
self.query(&reset_string).await?;
|
||||
self.cleanup_state.reset();
|
||||
}
|
||||
|
||||
if self.in_copy_mode() {
|
||||
warn!("Server returned while still in copy-mode");
|
||||
warn!(target: "pgcat::server::cleanup", "Server returned while still in copy-mode");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A shorthand for `SET application_name = $1`.
|
||||
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
|
||||
if self.application_name != name {
|
||||
self.application_name = name.to_string();
|
||||
// We don't want `SET application_name` to mark the server connection
|
||||
// as needing cleanup
|
||||
let needs_cleanup_before = self.cleanup_state;
|
||||
|
||||
let result = Ok(self
|
||||
.query(&format!("SET application_name = '{}'", name))
|
||||
.await?);
|
||||
self.cleanup_state = needs_cleanup_before;
|
||||
result
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// get Server stats
|
||||
pub fn stats(&self) -> Arc<ServerStats> {
|
||||
self.stats.clone()
|
||||
@@ -1181,20 +1353,22 @@ impl Server {
|
||||
self.last_activity
|
||||
}
|
||||
|
||||
// Marks a connection as needing DISCARD ALL at checkin
|
||||
// Marks a connection as needing cleanup at checkin
|
||||
pub fn mark_dirty(&mut self) {
|
||||
self.cleanup_state.set_true();
|
||||
}
|
||||
|
||||
pub fn mirror_send(&mut self, bytes: &BytesMut) {
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.send(bytes);
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.send(bytes),
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mirror_disconnect(&mut self) {
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.disconnect();
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.disconnect(),
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1216,13 +1390,14 @@ impl Server {
|
||||
Arc::new(ServerStats::default()),
|
||||
Arc::new(RwLock::new(None)),
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
debug!("Connected!, sending query.");
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv().await?;
|
||||
let mut message = server.recv(None).await?;
|
||||
|
||||
parse_query_message(&mut message).await
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Sharder {
|
||||
fn sha1(&self, key: i64) -> usize {
|
||||
let mut hasher = Sha1::new();
|
||||
|
||||
hasher.update(key.to_string().as_bytes());
|
||||
hasher.update(&key.to_string().as_bytes());
|
||||
|
||||
let result = hasher.finalize();
|
||||
|
||||
|
||||
@@ -77,12 +77,13 @@ impl Reporter {
|
||||
/// The statistics collector which used for calculating averages
|
||||
/// There is only one collector (kind of like a singleton)
|
||||
/// it updates averages every 15 seconds.
|
||||
pub struct Collector;
|
||||
#[derive(Default)]
|
||||
pub struct Collector {}
|
||||
|
||||
impl Collector {
|
||||
/// The statistics collection handler. It will collect statistics
|
||||
/// for `address_id`s starting at 0 up to `addresses`.
|
||||
pub fn collect() {
|
||||
pub async fn collect(&mut self) {
|
||||
info!("Events reporter started");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
|
||||
@@ -86,11 +86,11 @@ impl PoolStats {
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
return map;
|
||||
}
|
||||
|
||||
pub fn generate_header() -> Vec<(&'static str, DataType)> {
|
||||
vec![
|
||||
return vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
@@ -105,11 +105,11 @@ impl PoolStats {
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
]
|
||||
];
|
||||
}
|
||||
|
||||
pub fn generate_row(&self) -> Vec<String> {
|
||||
vec![
|
||||
return vec![
|
||||
self.identifier.db.clone(),
|
||||
self.identifier.user.clone(),
|
||||
self.mode.to_string(),
|
||||
@@ -124,7 +124,7 @@ impl PoolStats {
|
||||
self.sv_login.to_string(),
|
||||
(self.maxwait / 1_000_000).to_string(),
|
||||
(self.maxwait % 1_000_000).to_string(),
|
||||
]
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
24
src/tls.rs
24
src/tls.rs
@@ -44,17 +44,25 @@ impl Tls {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let config = get_config();
|
||||
|
||||
let certs = load_certs(Path::new(&config.general.tls_certificate.unwrap()))
|
||||
.map_err(|_| Error::TlsError)?;
|
||||
let key_der = load_keys(Path::new(&config.general.tls_private_key.unwrap()))
|
||||
.map_err(|_| Error::TlsError)?
|
||||
.remove(0);
|
||||
let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
Ok(certs) => certs,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let config = rustls::ServerConfig::builder()
|
||||
let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
Ok(keys) => keys,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let config = match rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key_der)
|
||||
.map_err(|_| Error::TlsError)?;
|
||||
.with_single_cert(certs, keys.remove(0))
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
Ok(Tls {
|
||||
acceptor: TlsAcceptor::from(Arc::new(config)),
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
FROM rust:bullseye
|
||||
|
||||
COPY --from=sclevine/yj /bin/yj /bin/yj
|
||||
RUN /bin/yj -h
|
||||
RUN apt-get update && apt-get install llvm-11 psmisc postgresql-contrib postgresql-client ruby ruby-dev libpq-dev python3 python3-pip lcov curl sudo iproute2 -y
|
||||
RUN cargo install cargo-binutils rustfilt
|
||||
RUN rustup component add llvm-tools-preview
|
||||
|
||||
@@ -90,4 +90,28 @@ describe "Admin" do
|
||||
expect(results["pool_mode"]).to eq("transaction")
|
||||
end
|
||||
end
|
||||
|
||||
describe "PAUSE" do
|
||||
it "pauses all pools" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
|
||||
|
||||
admin_conn.async_exec("PAUSE")
|
||||
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["1"])
|
||||
|
||||
admin_conn.async_exec("RESUME")
|
||||
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
|
||||
end
|
||||
|
||||
it "handles errors" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
expect { admin_conn.async_exec("PAUSE foo").to_a }.to raise_error(PG::SystemError)
|
||||
expect { admin_conn.async_exec("PAUSE foo,bar").to_a }.to raise_error(PG::SystemError)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -185,7 +185,7 @@ describe "Auth Query" do
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
context 'and with cleartext passwords set' do
|
||||
it 'it uses local passwords' do
|
||||
|
||||
@@ -33,18 +33,18 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica.port.to_i, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
}
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat_cfg["general"]["port"] = pgcat.port.to_i
|
||||
pgcat.update_config(pgcat_cfg)
|
||||
pgcat.start
|
||||
|
||||
|
||||
pgcat.wait_until_ready(
|
||||
pgcat.connection_string(
|
||||
"sharded_db",
|
||||
@@ -92,13 +92,13 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => database,
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica.port.to_i, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
}
|
||||
end
|
||||
# Main proxy configs
|
||||
pgcat_cfg["pools"] = {
|
||||
@@ -109,7 +109,7 @@ module Helpers
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
|
||||
pgcat.start
|
||||
|
||||
|
||||
pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
|
||||
|
||||
OpenStruct.new.tap do |struct|
|
||||
|
||||
@@ -7,10 +7,24 @@ class PgInstance
|
||||
attr_reader :password
|
||||
attr_reader :database_name
|
||||
|
||||
def self.mass_takedown(databases)
|
||||
raise StandardError "block missing" unless block_given?
|
||||
|
||||
databases.each do |database|
|
||||
database.toxiproxy.toxic(:limit_data, bytes: 1).toxics.each(&:save)
|
||||
end
|
||||
sleep 0.1
|
||||
yield
|
||||
ensure
|
||||
databases.each do |database|
|
||||
database.toxiproxy.toxics.each(&:destroy)
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(port, username, password, database_name)
|
||||
@original_port = port
|
||||
@original_port = port.to_i
|
||||
@toxiproxy_port = 10000 + port.to_i
|
||||
@port = @toxiproxy_port
|
||||
@port = @toxiproxy_port.to_i
|
||||
|
||||
@username = username
|
||||
@password = password
|
||||
@@ -48,9 +62,9 @@ class PgInstance
|
||||
|
||||
def take_down
|
||||
if block_given?
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).apply { yield }
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).apply { yield }
|
||||
else
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).toxics.each(&:save)
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).toxics.each(&:save)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -89,6 +103,6 @@ class PgInstance
|
||||
end
|
||||
|
||||
def count_select_1_plus_2
|
||||
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query = 'SELECT $1 + $2'")[0]["sum"].to_i }
|
||||
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query LIKE '%SELECT $1 + $2%'")[0]["sum"].to_i }
|
||||
end
|
||||
end
|
||||
|
||||
@@ -34,12 +34,13 @@ module Helpers
|
||||
"load_balancing_mode" => lb_mode,
|
||||
"primary_reads_enabled" => true,
|
||||
"query_parser_enabled" => true,
|
||||
"query_parser_read_write_splitting" => true,
|
||||
"automatic_sharding_key" => "data.id",
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] },
|
||||
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
|
||||
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
|
||||
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_i, "primary"]] },
|
||||
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_i, "primary"]] },
|
||||
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_i, "primary"]] },
|
||||
},
|
||||
"users" => { "0" => user },
|
||||
"plugins" => {
|
||||
@@ -99,7 +100,7 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"]
|
||||
["localhost", primary.port.to_i, "primary"]
|
||||
]
|
||||
},
|
||||
},
|
||||
@@ -145,10 +146,10 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica0.port.to_s, "replica"],
|
||||
["localhost", replica1.port.to_s, "replica"],
|
||||
["localhost", replica2.port.to_s, "replica"]
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica0.port.to_i, "replica"],
|
||||
["localhost", replica1.port.to_i, "replica"],
|
||||
["localhost", replica2.port.to_i, "replica"]
|
||||
]
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
require 'pg'
|
||||
require 'toml'
|
||||
require 'json'
|
||||
require 'tempfile'
|
||||
require 'fileutils'
|
||||
require 'securerandom'
|
||||
|
||||
class ConfigReloadFailed < StandardError; end
|
||||
class PgcatProcess
|
||||
attr_reader :port
|
||||
attr_reader :pid
|
||||
@@ -18,7 +20,7 @@ class PgcatProcess
|
||||
end
|
||||
|
||||
def initialize(log_level)
|
||||
@env = {"RUST_LOG" => log_level}
|
||||
@env = {}
|
||||
@port = rand(20000..32760)
|
||||
@log_level = log_level
|
||||
@log_filename = "/tmp/pgcat_log_#{SecureRandom.urlsafe_base64}.log"
|
||||
@@ -30,7 +32,7 @@ class PgcatProcess
|
||||
'../../target/debug/pgcat'
|
||||
end
|
||||
|
||||
@command = "#{command_path} #{@config_filename}"
|
||||
@command = "#{command_path} #{@config_filename} --log-level #{@log_level}"
|
||||
|
||||
FileUtils.cp("../../pgcat.toml", @config_filename)
|
||||
cfg = current_config
|
||||
@@ -46,22 +48,34 @@ class PgcatProcess
|
||||
|
||||
def update_config(config_hash)
|
||||
@original_config = current_config
|
||||
output_to_write = TOML::Generator.new(config_hash).body
|
||||
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*,/, ',\1,')
|
||||
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*\]/, ',\1]')
|
||||
File.write(@config_filename, output_to_write)
|
||||
Tempfile.create('json_out', '/tmp') do |f|
|
||||
f.write(config_hash.to_json)
|
||||
f.flush
|
||||
`cat #{f.path} | yj -jt > #{@config_filename}`
|
||||
end
|
||||
end
|
||||
|
||||
def current_config
|
||||
loadable_string = File.read(@config_filename)
|
||||
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*,/, ', "\1",')
|
||||
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*\]/, ', "\1"]')
|
||||
TOML.load(loadable_string)
|
||||
JSON.parse(`cat #{@config_filename} | yj -tj`)
|
||||
end
|
||||
|
||||
def raw_config_file
|
||||
File.read(@config_filename)
|
||||
end
|
||||
|
||||
def reload_config
|
||||
`kill -s HUP #{@pid}`
|
||||
sleep 0.5
|
||||
conn = PG.connect(admin_connection_string)
|
||||
|
||||
conn.async_exec("RELOAD")
|
||||
rescue PG::ConnectionBad => e
|
||||
errors = logs.split("Reloading config").last
|
||||
errors = errors.gsub(/\e\[([;\d]+)?m/, '') # Remove color codes
|
||||
errors = errors.
|
||||
split("\n").select{|line| line.include?("ERROR") }.
|
||||
map { |line| line.split("pgcat::config: ").last }
|
||||
raise ConfigReloadFailed, errors.join("\n")
|
||||
ensure
|
||||
conn&.close
|
||||
end
|
||||
|
||||
def start
|
||||
@@ -112,10 +126,16 @@ class PgcatProcess
|
||||
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
|
||||
end
|
||||
|
||||
def connection_string(pool_name, username, password = nil)
|
||||
def connection_string(pool_name, username, password = nil, parameters: {})
|
||||
cfg = current_config
|
||||
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
|
||||
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
|
||||
# Add the additional parameters to the connection string
|
||||
parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
|
||||
connection_string += "?#{parameter_string}" unless parameter_string.empty?
|
||||
|
||||
connection_string
|
||||
end
|
||||
|
||||
def example_connection_string
|
||||
|
||||
@@ -11,9 +11,9 @@ describe "Query Mirroing" do
|
||||
before do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
]
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
@@ -31,7 +31,8 @@ describe "Query Mirroing" do
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
sleep 0.5
|
||||
expect(processes.all_databases.first.count_select_1_plus_2).to eq(runs)
|
||||
expect(mirror_pg.count_select_1_plus_2).to eq(runs * 3)
|
||||
# Allow some slack in mirroring successes
|
||||
expect(mirror_pg.count_select_1_plus_2).to be > ((runs - 5) * 3)
|
||||
end
|
||||
|
||||
context "when main server connection is closed" do
|
||||
@@ -42,9 +43,9 @@ describe "Query Mirroing" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["pools"]["sharded_db"]["idle_timeout"] = 5000 + i
|
||||
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
]
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
@@ -221,7 +221,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Does not send DISCARD ALL unless necessary" do
|
||||
it "Does not send RESET ALL unless necessary" do
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SET SERVER ROLE to 'primary'")
|
||||
@@ -229,7 +229,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -239,7 +239,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(10)
|
||||
end
|
||||
|
||||
it "Resets server roles correctly" do
|
||||
@@ -252,7 +252,7 @@ describe "Miscellaneous" do
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "transaction mode" do
|
||||
@@ -273,7 +273,7 @@ describe "Miscellaneous" do
|
||||
end
|
||||
end
|
||||
|
||||
it "Does not send DISCARD ALL unless necessary" do
|
||||
it "Does not send RESET ALL unless necessary" do
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SET SERVER ROLE to 'primary'")
|
||||
@@ -282,7 +282,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -292,8 +292,32 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(10)
|
||||
end
|
||||
|
||||
it "Respects tracked parameters on startup" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" }))
|
||||
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Respect tracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
conn.async_exec("SET application_name to 'my_pgcat_test'")
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
end
|
||||
|
||||
|
||||
it "Ignore untracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]
|
||||
|
||||
conn.async_exec("SET statement_timeout to 1500")
|
||||
expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
context "transaction mode with transactions" do
|
||||
@@ -307,7 +331,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -317,7 +341,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -330,8 +354,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("SET statement_timeout TO 1000")
|
||||
conn.close
|
||||
|
||||
puts processes.pgcat.logs
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
end
|
||||
|
||||
it "will not clean up prepared statements" do
|
||||
@@ -341,8 +364,7 @@ describe "Miscellaneous" do
|
||||
|
||||
conn.close
|
||||
|
||||
puts processes.pgcat.logs
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -352,10 +374,9 @@ describe "Miscellaneous" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
puts(current_configs["general"]["idle_client_in_transaction_timeout"])
|
||||
|
||||
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 0
|
||||
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 0
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
@@ -373,9 +394,9 @@ describe "Miscellaneous" do
|
||||
context "idle transaction timeout set to 500ms" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 500
|
||||
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 500
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
@@ -394,7 +415,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("BEGIN")
|
||||
conn.async_exec("SELECT 1")
|
||||
sleep(1) # above 500ms
|
||||
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
|
||||
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
|
||||
conn.async_exec("SELECT 1") # should be able to send another query
|
||||
conn.close
|
||||
end
|
||||
|
||||
@@ -7,11 +7,11 @@ describe "Sharding" do
|
||||
|
||||
before do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
# Setup the sharding data
|
||||
3.times do |i|
|
||||
conn.exec("SET SHARD TO '#{i}'")
|
||||
conn.exec("DELETE FROM data WHERE id > 0")
|
||||
|
||||
conn.exec("DELETE FROM data WHERE id > 0") rescue nil
|
||||
end
|
||||
|
||||
18.times do |i|
|
||||
@@ -19,10 +19,11 @@ describe "Sharding" do
|
||||
conn.exec("SET SHARDING KEY TO '#{i}'")
|
||||
conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')")
|
||||
end
|
||||
|
||||
conn.close
|
||||
end
|
||||
|
||||
after do
|
||||
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
@@ -48,4 +49,148 @@ describe "Sharding" do
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "no_shard_specified_behavior config" do
|
||||
context "when default shard number is invalid" do
|
||||
it "prevents config reload" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
current_configs = processes.pgcat.current_config
|
||||
current_configs["pools"]["sharded_db"]["default_shard"] = "shard_99"
|
||||
|
||||
processes.pgcat.update_config(current_configs)
|
||||
|
||||
expect { processes.pgcat.reload_config }.to raise_error(ConfigReloadFailed, /Invalid shard 99/)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "comment-based routing" do
|
||||
context "when no configs are set" do
|
||||
it "routes queries with a shard_id comment to the default shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
10.times { conn.async_exec("/* shard_id: 2 */ SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([10, 0, 0])
|
||||
end
|
||||
|
||||
it "does not honor no_shard_specified_behavior directives" do
|
||||
end
|
||||
end
|
||||
|
||||
[
|
||||
["shard_id_regex", "/\\* the_shard_id: (\\d+) \\*/", "/* the_shard_id: 1 */"],
|
||||
["sharding_key_regex", "/\\* the_sharding_key: (\\d+) \\*/", "/* the_sharding_key: 3 */"],
|
||||
].each do |config_name, config_value, comment_to_use|
|
||||
context "when #{config_name} config is set" do
|
||||
let(:no_shard_specified_behavior) { nil }
|
||||
|
||||
before do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
current_configs = processes.pgcat.current_config
|
||||
current_configs["pools"]["sharded_db"][config_name] = config_value
|
||||
if no_shard_specified_behavior
|
||||
current_configs["pools"]["sharded_db"]["default_shard"] = no_shard_specified_behavior
|
||||
else
|
||||
current_configs["pools"]["sharded_db"].delete("default_shard")
|
||||
end
|
||||
|
||||
processes.pgcat.update_config(current_configs)
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
|
||||
it "routes queries with a shard_id comment to the correct shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to random" do
|
||||
let(:no_shard_specified_behavior) { "random" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to random shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to random_healthy" do
|
||||
let(:no_shard_specified_behavior) { "random_healthy" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to random healthy shard" do
|
||||
|
||||
good_databases = [processes.all_databases[0], processes.all_databases[2]]
|
||||
bad_database = processes.all_databases[1]
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
250.times { conn.async_exec("SELECT 99") }
|
||||
bad_database.take_down do
|
||||
250.times do
|
||||
conn.async_exec("SELECT 99")
|
||||
rescue PG::ConnectionBad => e
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
end
|
||||
end
|
||||
|
||||
# Routes traffic away from bad shard
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
expect(good_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
|
||||
expect(bad_database.count_select_1_plus_2).to eq(0)
|
||||
|
||||
# Routes traffic to the bad shard if the shard_id is specified
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
bad_database = processes.all_databases[1]
|
||||
expect(bad_database.count_select_1_plus_2).to eq(25)
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to shard_x" do
|
||||
let(:no_shard_specified_behavior) { "shard_2" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to the specified shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 0, 25])
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
40
utilities/deb.sh
Normal file
40
utilities/deb.sh
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build an Ubuntu deb.
|
||||
#
|
||||
script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
deb_dir="/tmp/pgcat-build"
|
||||
export PACKAGE_VERSION=${1:-"1.1.1"}
|
||||
if [[ $(arch) == "x86_64" ]]; then
|
||||
export ARCH=amd64
|
||||
else
|
||||
export ARCH=arm64
|
||||
fi
|
||||
|
||||
cd "$script_dir/.."
|
||||
cargo build --release
|
||||
|
||||
rm -rf "$deb_dir"
|
||||
mkdir -p "$deb_dir/DEBIAN"
|
||||
mkdir -p "$deb_dir/usr/bin"
|
||||
mkdir -p "$deb_dir/etc/systemd/system"
|
||||
|
||||
cp target/release/pgcat "$deb_dir/usr/bin/pgcat"
|
||||
chmod +x "$deb_dir/usr/bin/pgcat"
|
||||
|
||||
cp pgcat.toml "$deb_dir/etc/pgcat.toml"
|
||||
cp pgcat.service "$deb_dir/etc/systemd/system/pgcat.service"
|
||||
|
||||
(cat control | envsubst) > "$deb_dir/DEBIAN/control"
|
||||
cp postinst "$deb_dir/DEBIAN/postinst"
|
||||
cp postrm "$deb_dir/DEBIAN/postrm"
|
||||
cp prerm "$deb_dir/DEBIAN/prerm"
|
||||
|
||||
chmod +x ${deb_dir}/DEBIAN/post*
|
||||
chmod +x ${deb_dir}/DEBIAN/pre*
|
||||
|
||||
dpkg-deb \
|
||||
--root-owner-group \
|
||||
-z1 \
|
||||
--build "$deb_dir" \
|
||||
pgcat-${PACKAGE_VERSION}-ubuntu22.04-${ARCH}.deb
|
||||
Reference in New Issue
Block a user