mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
4 Commits
levkk-tls-
...
levkk-more
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7265cbf91 | ||
|
|
d738ba28b6 | ||
|
|
ff80bb75cc | ||
|
|
374a6b138b |
@@ -9,7 +9,7 @@ jobs:
|
||||
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
|
||||
docker:
|
||||
- image: ghcr.io/postgresml/pgcat-ci:latest
|
||||
- image: ghcr.io/levkk/pgcat-ci:1.67
|
||||
environment:
|
||||
RUST_LOG: info
|
||||
LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
|
||||
|
||||
@@ -74,10 +74,6 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitely selected with our custom protocol.
|
||||
@@ -138,7 +134,6 @@ database = "shard2"
|
||||
pool_mode = "session"
|
||||
default_role = "primary"
|
||||
query_parser_enabled = true
|
||||
query_parser_read_write_splitting = true
|
||||
primary_reads_enabled = true
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
|
||||
8
.github/workflows/build-and-push.yaml
vendored
8
.github/workflows/build-and-push.yaml
vendored
@@ -1,11 +1,6 @@
|
||||
name: Build and Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- v*
|
||||
on: push
|
||||
|
||||
env:
|
||||
registry: ghcr.io
|
||||
@@ -34,7 +29,6 @@ jobs:
|
||||
tags: |
|
||||
type=sha,prefix=,format=long
|
||||
type=schedule
|
||||
type=ref,event=tag
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=raw,value=latest,enable={{ is_default_branch }}
|
||||
|
||||
48
.github/workflows/publish-deb-package.yml
vendored
48
.github/workflows/publish-deb-package.yml
vendored
@@ -1,48 +0,0 @@
|
||||
name: pgcat package (deb)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
packageVersion:
|
||||
default: "1.1.2-dev"
|
||||
jobs:
|
||||
build:
|
||||
strategy:
|
||||
max-parallel: 1
|
||||
fail-fast: false # Let the other job finish, or they can lock each other out
|
||||
matrix:
|
||||
os: ["buildjet-4vcpu-ubuntu-2204", "buildjet-4vcpu-ubuntu-2204-arm"]
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
- name: Install dependencies
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
TZ: Etc/UTC
|
||||
run: |
|
||||
curl -sLO https://github.com/deb-s3/deb-s3/releases/download/0.11.4/deb-s3-0.11.4.gem
|
||||
sudo gem install deb-s3-0.11.4.gem
|
||||
dpkg-deb --version
|
||||
- name: Build and release package
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_DEFAULT_REGION: ${{ vars.AWS_DEFAULT_REGION }}
|
||||
run: |
|
||||
if [[ $(arch) == "x86_64" ]]; then
|
||||
export ARCH=amd64
|
||||
else
|
||||
export ARCH=arm64
|
||||
fi
|
||||
|
||||
bash utilities/deb.sh ${{ inputs.packageVersion }}
|
||||
|
||||
deb-s3 upload \
|
||||
--lock \
|
||||
--bucket apt.postgresml.org \
|
||||
pgcat-${{ inputs.packageVersion }}-ubuntu22.04-${ARCH}.deb \
|
||||
--codename $(lsb_release -cs)
|
||||
97
CONFIG.md
97
CONFIG.md
@@ -1,4 +1,4 @@
|
||||
# PgCat Configurations
|
||||
# PgCat Configurations
|
||||
## `general` Section
|
||||
|
||||
### host
|
||||
@@ -57,38 +57,6 @@ default: 86400000 # 24 hours
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### server_round_robin
|
||||
```
|
||||
path: general.server_round_robin
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use round robin for server selection or not.
|
||||
|
||||
### server_tls
|
||||
```
|
||||
path: general.server_tls
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use TLS for server connections or not.
|
||||
|
||||
### verify_server_certificate
|
||||
```
|
||||
path: general.verify_server_certificate
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to verify server certificate or not.
|
||||
|
||||
### verify_config
|
||||
```
|
||||
path: general.verify_config
|
||||
default: true
|
||||
```
|
||||
|
||||
Whether to verify config or not.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
@@ -148,10 +116,10 @@ If we should log client disconnections
|
||||
### autoreload
|
||||
```
|
||||
path: general.autoreload
|
||||
default: 15000 # milliseconds
|
||||
default: 15000
|
||||
```
|
||||
|
||||
When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
|
||||
When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
|
||||
### worker_threads
|
||||
```
|
||||
@@ -183,13 +151,7 @@ path: general.tcp_keepalives_interval
|
||||
default: 5
|
||||
```
|
||||
|
||||
### tcp_user_timeout
|
||||
```
|
||||
path: general.tcp_user_timeout
|
||||
default: 10000
|
||||
```
|
||||
A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
|
||||
|
||||
Number of seconds between keepalive packets.
|
||||
|
||||
### tls_certificate
|
||||
```
|
||||
@@ -226,55 +188,6 @@ default: "admin_pass"
|
||||
|
||||
Password to access the virtual administrative database
|
||||
|
||||
### auth_query
|
||||
```
|
||||
path: general.auth_query
|
||||
default: <UNSET>
|
||||
example: "SELECT $1"
|
||||
```
|
||||
|
||||
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
established using the database configured in the pool. This parameter is inherited by every pool
|
||||
and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_user
|
||||
```
|
||||
path: general.auth_query_user
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_password
|
||||
```
|
||||
path: general.auth_query_password
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### prepared_statements
|
||||
```
|
||||
path: general.prepared_statements
|
||||
default: false
|
||||
```
|
||||
|
||||
Whether to use prepared statements or not.
|
||||
|
||||
### prepared_statements_cache_size
|
||||
```
|
||||
path: general.prepared_statements_cache_size
|
||||
default: 500
|
||||
```
|
||||
|
||||
Size of the prepared statements cache.
|
||||
|
||||
### dns_cache_enabled
|
||||
```
|
||||
path: general.dns_cache_enabled
|
||||
@@ -311,7 +224,7 @@ default: "random"
|
||||
|
||||
Load balancing mode
|
||||
`random` selects the server at random
|
||||
`loc` selects the server with the least outstanding busy connections
|
||||
`loc` selects the server with the least outstanding busy conncetions
|
||||
|
||||
### default_role
|
||||
```
|
||||
|
||||
980
Cargo.lock
generated
980
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
11
Cargo.toml
11
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.1.2-dev"
|
||||
version = "1.0.2-alpha1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -8,7 +8,7 @@ edition = "2021"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = "1"
|
||||
md-5 = "0.10"
|
||||
bb8 = "0.8.1"
|
||||
bb8 = "0.8.0"
|
||||
async-trait = "0.1"
|
||||
rand = "0.8"
|
||||
chrono = "0.4"
|
||||
@@ -19,9 +19,10 @@ serde_derive = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
sqlparser = {version = "0.34", features = ["visitor"] }
|
||||
sqlparser = {version = "0.33", features = ["visitor"] }
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
parking_lot = "0.12.1"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
@@ -44,10 +45,6 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
trust-dns-resolver = "0.22.0"
|
||||
tokio-test = "0.4.2"
|
||||
serde_json = "1"
|
||||
itertools = "0.10"
|
||||
clap = { version = "4.3.1", features = ["derive", "env"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
FROM rust:1-slim-bookworm AS builder
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
FROM rust:1 AS builder
|
||||
COPY . /app
|
||||
WORKDIR /app
|
||||
RUN cargo build --release
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
FROM debian:bullseye-slim
|
||||
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
|
||||
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
|
||||
WORKDIR /etc/pgcat
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
FROM cimg/rust:1.67.1
|
||||
COPY --from=sclevine/yj /bin/yj /bin/yj
|
||||
RUN /bin/yj -h
|
||||
RUN sudo apt-get update && \
|
||||
sudo apt-get install -y \
|
||||
psmisc postgresql-contrib-14 postgresql-client-14 libpq-dev \
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
COPY --from=planner /app/recipe.json recipe.json
|
||||
# Build dependencies - this is the caching Docker layer!
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
# Build application
|
||||
COPY . .
|
||||
RUN cargo build
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
|
||||
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
|
||||
WORKDIR /etc/pgcat
|
||||
ENV RUST_LOG=info
|
||||
CMD ["pgcat"]
|
||||
9
control
9
control
@@ -1,9 +0,0 @@
|
||||
Package: pgcat
|
||||
Version: ${PACKAGE_VERSION}
|
||||
Section: database
|
||||
Priority: optional
|
||||
Architecture: ${ARCH}
|
||||
Maintainer: PostgresML <team@postgresml.org>
|
||||
Homepage: https://postgresml.org
|
||||
Description: PgCat - NextGen PostgreSQL Pooler
|
||||
PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load balancing, failover and mirroring.
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM rust:1.70-bullseye
|
||||
FROM rust:bullseye
|
||||
|
||||
# Dependencies
|
||||
RUN apt-get update -y \
|
||||
|
||||
@@ -25,7 +25,7 @@ x-common-env-pg:
|
||||
|
||||
services:
|
||||
main:
|
||||
image: gcr.io/google_containers/pause:3.2
|
||||
image: kubernetes/pause
|
||||
ports:
|
||||
- 6432
|
||||
|
||||
@@ -64,7 +64,7 @@ services:
|
||||
<<: *common-env-pg
|
||||
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
|
||||
PGPORT: 10432
|
||||
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
|
||||
command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
|
||||
|
||||
toxiproxy:
|
||||
build: .
|
||||
|
||||
@@ -71,10 +71,6 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
# This is an example of the most basic config
|
||||
# that will mimic what PgBouncer does in transaction mode with one server.
|
||||
|
||||
[general]
|
||||
|
||||
host = "0.0.0.0"
|
||||
port = 6433
|
||||
admin_username = "pgcat"
|
||||
admin_password = "pgcat"
|
||||
|
||||
[pools.pgml.users.0]
|
||||
username = "postgres"
|
||||
password = "postgres"
|
||||
pool_size = 10
|
||||
min_pool_size = 1
|
||||
pool_mode = "transaction"
|
||||
|
||||
[pools.pgml.shards.0]
|
||||
servers = [
|
||||
["127.0.0.1", 28815, "primary"]
|
||||
]
|
||||
database = "postgres"
|
||||
@@ -1,16 +0,0 @@
|
||||
[Unit]
|
||||
Description=PgCat pooler
|
||||
After=network.target
|
||||
StartLimitIntervalSec=0
|
||||
|
||||
[Service]
|
||||
User=pgcat
|
||||
Type=simple
|
||||
Restart=always
|
||||
RestartSec=1
|
||||
Environment=RUST_LOG=info
|
||||
LimitNOFILE=65536
|
||||
ExecStart=/usr/bin/pgcat /etc/pgcat.toml
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
92
pgcat.toml
92
pgcat.toml
@@ -60,12 +60,6 @@ tcp_keepalives_count = 5
|
||||
# Number of seconds between keepalive packets.
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Handle prepared statements.
|
||||
prepared_statements = true
|
||||
|
||||
# Prepared statements server cache size.
|
||||
prepared_statements_cache_size = 500
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
@@ -83,58 +77,6 @@ admin_username = "admin_user"
|
||||
# Password to access the virtual administrative database
|
||||
admin_password = "admin_pass"
|
||||
|
||||
# Default plugins that are configured on all pools.
|
||||
[plugins]
|
||||
|
||||
# Prewarmer plugin that runs queries on server startup, before giving the connection
|
||||
# to the client.
|
||||
[plugins.prewarmer]
|
||||
enabled = false
|
||||
queries = [
|
||||
"SELECT pg_prewarm('pgbench_accounts')",
|
||||
]
|
||||
|
||||
# Log all queries to stdout.
|
||||
[plugins.query_logger]
|
||||
enabled = false
|
||||
|
||||
# Block access to tables that Postgres does not allow us to control.
|
||||
[plugins.table_access]
|
||||
enabled = false
|
||||
tables = [
|
||||
"pg_user",
|
||||
"pg_roles",
|
||||
"pg_database",
|
||||
]
|
||||
|
||||
# Intercept user queries and give a fake reply.
|
||||
[plugins.intercept]
|
||||
enabled = true
|
||||
|
||||
[plugins.intercept.queries.0]
|
||||
|
||||
query = "select current_database() as a, current_schemas(false) as b"
|
||||
schema = [
|
||||
["a", "text"],
|
||||
["b", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "{public}"],
|
||||
]
|
||||
|
||||
[plugins.intercept.queries.1]
|
||||
|
||||
query = "select current_database(), current_schema(), current_user"
|
||||
schema = [
|
||||
["current_database", "text"],
|
||||
["current_schema", "text"],
|
||||
["current_user", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "public", "${USER}"],
|
||||
]
|
||||
|
||||
|
||||
# pool configs are structured as pool.<pool_name>
|
||||
# the pool_name is what clients use as database name when connecting.
|
||||
# For a pool named `sharded_db`, clients access that pool using connection string like
|
||||
@@ -162,10 +104,6 @@ default_role = "any"
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
@@ -177,12 +115,6 @@ primary_reads_enabled = true
|
||||
# shard_id_regex = '/\* shard_id: (\d+) \*/'
|
||||
# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
|
||||
|
||||
# Defines the behavior when no shard is selected in a sharded system.
|
||||
# `random`: picks a shard at random
|
||||
# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
|
||||
# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
|
||||
# no_shard_specified_behavior = "shard_0"
|
||||
|
||||
# So what if you wanted to implement a different hashing function,
|
||||
# or you've already built one and you want this pooler to use it?
|
||||
# Current options:
|
||||
@@ -193,7 +125,7 @@ sharding_function = "pg_bigint_hash"
|
||||
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
# established using the database configured in the pool. This parameter is inherited by every pool
|
||||
# and can be redefined in pool configuration.
|
||||
# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
|
||||
# auth_query = "SELECT $1"
|
||||
|
||||
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
@@ -222,20 +154,12 @@ connect_timeout = 3000
|
||||
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
|
||||
# dns_max_ttl = 30
|
||||
|
||||
# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
|
||||
# so all plugins have to be configured here again.
|
||||
[pool.sharded_db.plugins]
|
||||
[plugins]
|
||||
|
||||
[pools.sharded_db.plugins.prewarmer]
|
||||
enabled = true
|
||||
queries = [
|
||||
"SELECT pg_prewarm('pgbench_accounts')",
|
||||
]
|
||||
|
||||
[pools.sharded_db.plugins.query_logger]
|
||||
[plugins.query_logger]
|
||||
enabled = false
|
||||
|
||||
[pools.sharded_db.plugins.table_access]
|
||||
[plugins.table_access]
|
||||
enabled = false
|
||||
tables = [
|
||||
"pg_user",
|
||||
@@ -243,10 +167,10 @@ tables = [
|
||||
"pg_database",
|
||||
]
|
||||
|
||||
[pools.sharded_db.plugins.intercept]
|
||||
[plugins.intercept]
|
||||
enabled = true
|
||||
|
||||
[pools.sharded_db.plugins.intercept.queries.0]
|
||||
[plugins.intercept.queries.0]
|
||||
|
||||
query = "select current_database() as a, current_schemas(false) as b"
|
||||
schema = [
|
||||
@@ -257,7 +181,7 @@ result = [
|
||||
["${DATABASE}", "{public}"],
|
||||
]
|
||||
|
||||
[pools.sharded_db.plugins.intercept.queries.1]
|
||||
[plugins.intercept.queries.1]
|
||||
|
||||
query = "select current_database(), current_schema(), current_user"
|
||||
schema = [
|
||||
@@ -280,7 +204,7 @@ username = "sharding_user"
|
||||
# if `server_password` is not set.
|
||||
password = "sharding_user"
|
||||
|
||||
pool_mode = "transaction"
|
||||
pool_mode = "session"
|
||||
|
||||
# PostgreSQL username used to connect to the server.
|
||||
# server_username = "another_user"
|
||||
|
||||
9
postinst
9
postinst
@@ -1,9 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
systemctl daemon-reload
|
||||
systemctl enable pgcat
|
||||
|
||||
if ! id pgcat 2> /dev/null; then
|
||||
useradd -s /usr/bin/false pgcat
|
||||
fi
|
||||
305
src/admin.rs
305
src/admin.rs
@@ -1,6 +1,4 @@
|
||||
use crate::pool::BanReason;
|
||||
use crate::server::ServerParameters;
|
||||
use crate::stats::pool::PoolStats;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{error, info, trace};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
@@ -16,18 +14,18 @@ use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
|
||||
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
|
||||
|
||||
pub fn generate_server_parameters_for_admin() -> ServerParameters {
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
|
||||
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
|
||||
server_info.put(server_parameter_message("application_name", ""));
|
||||
server_info.put(server_parameter_message("client_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_version", VERSION));
|
||||
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
|
||||
|
||||
server_parameters
|
||||
server_info
|
||||
}
|
||||
|
||||
/// Handle admin client.
|
||||
@@ -74,21 +72,17 @@ where
|
||||
}
|
||||
"PAUSE" => {
|
||||
trace!("PAUSE");
|
||||
pause(stream, query_parts).await
|
||||
pause(stream, query_parts[1]).await
|
||||
}
|
||||
"RESUME" => {
|
||||
trace!("RESUME");
|
||||
resume(stream, query_parts).await
|
||||
resume(stream, query_parts[1]).await
|
||||
}
|
||||
"SHUTDOWN" => {
|
||||
trace!("SHUTDOWN");
|
||||
shutdown(stream).await
|
||||
}
|
||||
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
|
||||
"HELP" => {
|
||||
trace!("SHOW HELP");
|
||||
show_help(stream).await
|
||||
}
|
||||
"BANS" => {
|
||||
trace!("SHOW BANS");
|
||||
show_bans(stream).await
|
||||
@@ -260,51 +254,39 @@ async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let pool_lookup = PoolStats::construct_pool_lookup();
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&PoolStats::generate_header()));
|
||||
pool_lookup.iter().for_each(|(_identifier, pool_stats)| {
|
||||
res.put(data_row(&pool_stats.generate_row()));
|
||||
});
|
||||
res.put(command_complete("SHOW"));
|
||||
let all_pool_stats = get_pool_stats();
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
/// Show all available options.
|
||||
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
let detail_msg = vec![
|
||||
"",
|
||||
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
|
||||
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
|
||||
// "SHOW FDS|SOCKETS|ACTIVE_SOCKETS|LISTS|MEM|STATE", // missing FDS|SOCKETS|ACTIVE_SOCKETS|MEM|STATE
|
||||
"SHOW LISTS",
|
||||
// "SHOW DNS_HOSTS|DNS_ZONES", // missing DNS_HOSTS|DNS_ZONES
|
||||
"SHOW STATS", // missing STATS_TOTALS|STATS_AVERAGES|TOTALS
|
||||
"SET key = arg",
|
||||
"RELOAD",
|
||||
"PAUSE [<db>, <user>]",
|
||||
"RESUME [<db>, <user>]",
|
||||
// "DISABLE <db>", // missing
|
||||
// "ENABLE <db>", // missing
|
||||
// "RECONNECT [<db>]", missing
|
||||
// "KILL <db>",
|
||||
// "SUSPEND",
|
||||
"SHUTDOWN",
|
||||
// "WAIT_CLOSE [<db>]", // missing
|
||||
let columns = vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
("cl_idle", DataType::Numeric),
|
||||
("cl_active", DataType::Numeric),
|
||||
("cl_waiting", DataType::Numeric),
|
||||
("cl_cancel_req", DataType::Numeric),
|
||||
("sv_active", DataType::Numeric),
|
||||
("sv_idle", DataType::Numeric),
|
||||
("sv_used", DataType::Numeric),
|
||||
("sv_tested", DataType::Numeric),
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
];
|
||||
|
||||
res.put(notify("Console usage", detail_msg.join("\n\t")));
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
for ((_user_pool, _pool), pool_stats) in all_pool_stats {
|
||||
let mut row = vec![
|
||||
pool_stats.database(),
|
||||
pool_stats.user(),
|
||||
pool_stats.pool_mode().to_string(),
|
||||
];
|
||||
pool_stats.populate_row(&mut row);
|
||||
pool_stats.clear_maxwait();
|
||||
res.put(data_row(&row));
|
||||
}
|
||||
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
// ReadyForQuery
|
||||
@@ -352,17 +334,17 @@ where
|
||||
let paused = pool.paused();
|
||||
|
||||
res.put(data_row(&vec![
|
||||
address.name(), // name
|
||||
address.host.to_string(), // host
|
||||
address.port.to_string(), // port
|
||||
database_name.to_string(), // database
|
||||
pool_config.user.username.to_string(), // force_user
|
||||
pool_config.user.pool_size.to_string(), // pool_size
|
||||
pool_config.user.min_pool_size.unwrap_or(0).to_string(), // min_pool_size
|
||||
"0".to_string(), // reserve_pool
|
||||
pool_config.pool_mode.to_string(), // pool_mode
|
||||
pool_config.user.pool_size.to_string(), // max_connections
|
||||
pool_state.connections.to_string(), // current_connections
|
||||
address.name(), // name
|
||||
address.host.to_string(), // host
|
||||
address.port.to_string(), // port
|
||||
database_name.to_string(), // database
|
||||
pool_config.user.username.to_string(), // force_user
|
||||
pool_config.user.pool_size.to_string(), // pool_size
|
||||
"0".to_string(), // min_pool_size
|
||||
"0".to_string(), // reserve_pool
|
||||
pool_config.pool_mode.to_string(), // pool_mode
|
||||
pool_config.user.pool_size.to_string(), // max_connections
|
||||
pool_state.connections.to_string(), // current_connections
|
||||
match paused {
|
||||
// paused
|
||||
true => "1".to_string(),
|
||||
@@ -743,9 +725,6 @@ where
|
||||
("bytes_sent", DataType::Numeric),
|
||||
("bytes_received", DataType::Numeric),
|
||||
("age_seconds", DataType::Numeric),
|
||||
("prepare_cache_hit", DataType::Numeric),
|
||||
("prepare_cache_miss", DataType::Numeric),
|
||||
("prepare_cache_size", DataType::Numeric),
|
||||
];
|
||||
|
||||
let new_map = get_server_stats();
|
||||
@@ -769,18 +748,6 @@ where
|
||||
.duration_since(server.connect_time())
|
||||
.as_secs()
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_hit_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_miss_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_cache_size
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
];
|
||||
|
||||
res.put(data_row(&row));
|
||||
@@ -797,128 +764,96 @@ where
|
||||
}
|
||||
|
||||
/// Pause a pool. It won't pass any more queries to the backends.
|
||||
async fn pause<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
|
||||
async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
|
||||
|
||||
match parts.len() {
|
||||
0 => {
|
||||
for (_, pool) in get_all_pools() {
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
stream,
|
||||
"PAUSE requires a database and a user, e.g. PAUSE my_db,my_user",
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.pause();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete("PAUSE"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
2 => {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.pause();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
_ => error_response(stream, "usage: PAUSE [db, user]").await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resume a pool. Queries are allowed again.
|
||||
async fn resume<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
|
||||
async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = match tokens.len() == 2 {
|
||||
true => tokens[1].split(",").map(|part| part.trim()).collect(),
|
||||
false => Vec::new(),
|
||||
};
|
||||
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
|
||||
|
||||
match parts.len() {
|
||||
0 => {
|
||||
for (_, pool) in get_all_pools() {
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
stream,
|
||||
"RESUME requires a database and a user, e.g. RESUME my_db,my_user",
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.resume();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("RESUME {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete("RESUME"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
2 => {
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
Some(pool) => {
|
||||
pool.resume();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(command_complete(&format!("RESUME {},{}", database, user)));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
None => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No pool configured for database: {}, user: {}",
|
||||
database, user
|
||||
),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
_ => error_response(stream, "usage: RESUME [db, user]").await,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
601
src/client.rs
601
src/client.rs
@@ -3,36 +3,28 @@ use crate::pool::BanReason;
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{atomic::AtomicUsize, Arc};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
|
||||
};
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::stats::{ClientStats, ServerStats};
|
||||
use crate::server::Server;
|
||||
use crate::stats::{ClientStats, PoolStats, ServerStats};
|
||||
use crate::tls::Tls;
|
||||
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// Incrementally count prepared statements
|
||||
/// to avoid random conflicts in places where the random number generator is weak.
|
||||
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
|
||||
Lazy::new(|| Arc::new(AtomicUsize::new(0)));
|
||||
|
||||
/// Type of connection received from client.
|
||||
enum ClientConnectionType {
|
||||
Startup,
|
||||
@@ -96,14 +88,11 @@ pub struct Client<S, T> {
|
||||
/// Postgres user for this client (This comes from the user in the connection string)
|
||||
username: String,
|
||||
|
||||
/// Server startup and session parameters that we're going to track
|
||||
server_parameters: ServerParameters,
|
||||
/// Application name for this client (defaults to pgcat)
|
||||
application_name: String,
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown: Receiver<()>,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: HashMap<String, Parse>,
|
||||
}
|
||||
|
||||
/// Client entrypoint.
|
||||
@@ -117,15 +106,7 @@ pub async fn client_entrypoint(
|
||||
log_client_connections: bool,
|
||||
) -> Result<(), Error> {
|
||||
// Figure out if the client wants TLS or not.
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Failed to get peer address: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
|
||||
match get_startup::<TcpStream>(&mut stream).await {
|
||||
// Client requested a TLS connection.
|
||||
@@ -155,10 +136,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -207,10 +188,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -261,10 +242,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -290,12 +271,11 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
@@ -357,15 +337,7 @@ pub async fn startup_tls(
|
||||
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
|
||||
// Negotiate TLS.
|
||||
let tls = Tls::new()?;
|
||||
let addr = match stream.peer_addr() {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Failed to get peer address: {:?}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
|
||||
let mut stream = match tls.acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -519,7 +491,7 @@ where
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, mut server_parameters) = if admin {
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
@@ -538,7 +510,7 @@ where
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
(false, generate_server_parameters_for_admin())
|
||||
(false, generate_server_info_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
else {
|
||||
@@ -671,26 +643,35 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
(transaction_mode, pool.server_parameters())
|
||||
(transaction_mode, pool.server_info())
|
||||
};
|
||||
|
||||
// Update the parameters to merge what the application sent and what's originally on the server
|
||||
server_parameters.set_from_hashmap(¶meters, false);
|
||||
|
||||
debug!("Password authentication successful");
|
||||
|
||||
auth_ok(&mut write).await?;
|
||||
write_all(&mut write, (&server_parameters).into()).await?;
|
||||
write_all(&mut write, server_info).await?;
|
||||
backend_key_data(&mut write, process_id, secret_key).await?;
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
trace!("Startup OK");
|
||||
let pool_stats = match get_pool(pool_name, username) {
|
||||
Some(pool) => {
|
||||
if !admin {
|
||||
pool.stats
|
||||
} else {
|
||||
Arc::new(PoolStats::default())
|
||||
}
|
||||
}
|
||||
None => Arc::new(PoolStats::default()),
|
||||
};
|
||||
|
||||
let stats = Arc::new(ClientStats::new(
|
||||
process_id,
|
||||
application_name,
|
||||
username,
|
||||
pool_name,
|
||||
tokio::time::Instant::now(),
|
||||
pool_stats,
|
||||
));
|
||||
|
||||
Ok(Client {
|
||||
@@ -710,10 +691,9 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
server_parameters,
|
||||
application_name: application_name.to_string(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -745,10 +725,9 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: String::from("undefined"),
|
||||
username: String::from("undefined"),
|
||||
server_parameters: ServerParameters::new(),
|
||||
application_name: String::from("undefined"),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -790,16 +769,6 @@ where
|
||||
// Result returned by one of the plugins.
|
||||
let mut plugin_output = None;
|
||||
|
||||
// Prepared statement being executed
|
||||
let mut prepared_statement = None;
|
||||
let mut will_prepare = false;
|
||||
|
||||
let client_identifier = ClientIdentifier::new(
|
||||
&self.server_parameters.get_application_name(),
|
||||
&self.username,
|
||||
&self.pool_name,
|
||||
);
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
// or issue commands for our sharding and server selection protocol.
|
||||
@@ -809,16 +778,13 @@ where
|
||||
self.transaction_mode
|
||||
);
|
||||
|
||||
// Should we rewrite prepared statements and bind messages?
|
||||
let mut prepared_statements_enabled = get_prepared_statements();
|
||||
|
||||
// Read a complete message from the client, which normally would be
|
||||
// either a `Q` (query) or `P` (prepare, extended protocol).
|
||||
// We can parse it here before grabbing a server from the pool,
|
||||
// in case the client is sending some custom protocol messages, e.g.
|
||||
// SET SHARDING KEY TO 'bigint';
|
||||
|
||||
let mut message = tokio::select! {
|
||||
let message = tokio::select! {
|
||||
_ = self.shutdown.recv() => {
|
||||
if !self.admin {
|
||||
error_response_terminal(
|
||||
@@ -838,29 +804,6 @@ where
|
||||
message_result = read_message(&mut self.read) => message_result?
|
||||
};
|
||||
|
||||
if message[0] as char == 'X' {
|
||||
debug!("Client disconnecting");
|
||||
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Handle admin database queries.
|
||||
if self.admin {
|
||||
debug!("Handling admin command");
|
||||
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
let mut pool = self.get_pool().await?;
|
||||
query_router.update_pool_settings(pool.settings.clone());
|
||||
|
||||
let mut initial_parsed_ast = None;
|
||||
|
||||
match message[0] as char {
|
||||
// Buffer extended protocol messages even if we do not have
|
||||
// a server connection yet. Hopefully, when we get the S message
|
||||
@@ -869,93 +812,52 @@ where
|
||||
// allocate a connection, we wouldn't be able to send back an error message
|
||||
// to the client so we buffer them and defer the decision to error out or not
|
||||
// to when we get the S message
|
||||
'D' => {
|
||||
if prepared_statements_enabled {
|
||||
let name;
|
||||
(name, message) = self.rewrite_describe(message).await?;
|
||||
|
||||
if let Some(name) = name {
|
||||
prepared_statement = Some(name);
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
continue;
|
||||
}
|
||||
|
||||
'E' => {
|
||||
'D' | 'E' => {
|
||||
self.buffer.put(&message[..]);
|
||||
continue;
|
||||
}
|
||||
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
match query_router.parse(&message) {
|
||||
Ok(ast) => {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
|
||||
initial_parsed_ast = Some(ast);
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
}
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
'P' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_parse(message)?;
|
||||
will_prepare = true;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
match query_router.parse(&message) {
|
||||
Ok(ast) => {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
}
|
||||
};
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
'B' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_bind(message).await?;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
@@ -965,22 +867,24 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
// Close (F)
|
||||
'C' => {
|
||||
if prepared_statements_enabled {
|
||||
let close: Close = (&message).try_into()?;
|
||||
'X' => {
|
||||
debug!("Client disconnecting");
|
||||
|
||||
if close.is_prepared_statement() && !close.anonymous() {
|
||||
self.prepared_statements.remove(&close.name);
|
||||
write_all_flush(&mut self.write, &close_complete()).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// Handle admin database queries.
|
||||
if self.admin {
|
||||
debug!("Handling admin command");
|
||||
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check on plugin results.
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
@@ -993,6 +897,11 @@ where
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
let mut pool = self.get_pool().await?;
|
||||
|
||||
// Check if the pool is paused and wait until it's resumed.
|
||||
if pool.wait_paused().await {
|
||||
// Refresh pool information, something might have changed.
|
||||
@@ -1010,27 +919,23 @@ where
|
||||
|
||||
// SET SHARD TO
|
||||
Some((Command::SetShard, _)) => {
|
||||
match query_router.shard() {
|
||||
None => (),
|
||||
Some(selected_shard) => {
|
||||
if selected_shard >= pool.shards() {
|
||||
// Bad shard number, send error message to client.
|
||||
query_router.set_shard(current_shard);
|
||||
// Selected shard is not configured.
|
||||
if query_router.shard() >= pool.shards() {
|
||||
// Set the shard back to what it was.
|
||||
query_router.set_shard(current_shard);
|
||||
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)",
|
||||
selected_shard,
|
||||
pool.shards(),
|
||||
current_shard,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
|
||||
}
|
||||
}
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)",
|
||||
query_router.shard(),
|
||||
pool.shards(),
|
||||
current_shard,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -1098,11 +1003,8 @@ where
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
error_response(
|
||||
&mut self.write,
|
||||
format!("could not get connection from the pool - {}", err).as_str(),
|
||||
)
|
||||
.await?;
|
||||
error_response(&mut self.write, "could not get connection from the pool")
|
||||
.await?;
|
||||
|
||||
error!(
|
||||
"Could not get connection from pool: \
|
||||
@@ -1145,7 +1047,10 @@ where
|
||||
server.address()
|
||||
);
|
||||
|
||||
server.sync_parameters(&self.server_parameters).await?;
|
||||
// TODO: investigate other parameters and set them too.
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1161,64 +1066,10 @@ where
|
||||
// If the client is in session mode, no more custom protocol
|
||||
// commands will be accepted.
|
||||
loop {
|
||||
// Only check if we should rewrite prepared statements
|
||||
// in session mode. In transaction mode, we check at the beginning of
|
||||
// each transaction.
|
||||
if !self.transaction_mode {
|
||||
prepared_statements_enabled = get_prepared_statements();
|
||||
}
|
||||
|
||||
debug!("Prepared statement active: {:?}", prepared_statement);
|
||||
|
||||
// We are processing a prepared statement.
|
||||
if let Some(ref name) = prepared_statement {
|
||||
debug!("Checking prepared statement is on server");
|
||||
// Get the prepared statement the server expects to see.
|
||||
let statement = match self.prepared_statements.get(name) {
|
||||
Some(statement) => {
|
||||
debug!("Prepared statement `{}` found in cache", name);
|
||||
statement
|
||||
}
|
||||
None => {
|
||||
return Err(Error::ClientError(format!(
|
||||
"prepared statement `{}` not found",
|
||||
name
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Since it's already in the buffer, we don't need to prepare it on this server.
|
||||
if will_prepare {
|
||||
server.will_prepare(&statement.name);
|
||||
will_prepare = false;
|
||||
} else {
|
||||
// The statement is not prepared on the server, so we need to prepare it.
|
||||
if server.should_prepare(&statement.name) {
|
||||
match server.prepare(statement).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
pool.ban(
|
||||
&address,
|
||||
BanReason::MessageSendFailed,
|
||||
Some(&self.stats),
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Done processing the prepared statement.
|
||||
prepared_statement = None;
|
||||
}
|
||||
|
||||
let mut message = match initial_message {
|
||||
let message = match initial_message {
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
// This is not an initial message so discard the initial_parsed_ast
|
||||
initial_parsed_ast.take();
|
||||
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
@@ -1242,7 +1093,7 @@ where
|
||||
{{ \
|
||||
pool_name: {}, \
|
||||
username: {}, \
|
||||
shard: {:?}, \
|
||||
shard: {}, \
|
||||
role: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
@@ -1275,22 +1126,7 @@ where
|
||||
// Query
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
// We don't want to parse again if we already parsed it as the initial message
|
||||
let ast = match initial_parsed_ast {
|
||||
Some(_) => Some(initial_parsed_ast.take().unwrap()),
|
||||
None => match query_router.parse(&message) {
|
||||
Ok(ast) => Some(ast),
|
||||
Err(error) => {
|
||||
warn!(
|
||||
"Query parsing error: {} (client: {})",
|
||||
error, client_identifier
|
||||
);
|
||||
None
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(ast) = ast {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
@@ -1306,6 +1142,8 @@ where
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
debug!("Sending query to server");
|
||||
@@ -1323,13 +1161,11 @@ where
|
||||
if !server.in_transaction() {
|
||||
// Report transaction executed statistics.
|
||||
self.stats.transaction();
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
if self.transaction_mode && !server.in_copy_mode() {
|
||||
if self.transaction_mode {
|
||||
self.stats.idle();
|
||||
|
||||
break;
|
||||
@@ -1343,23 +1179,14 @@ where
|
||||
self.stats.disconnect();
|
||||
self.release();
|
||||
|
||||
if prepared_statements_enabled {
|
||||
server.maintain_cache().await?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Parse
|
||||
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
|
||||
'P' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_parse(message)?;
|
||||
will_prepare = true;
|
||||
}
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = query_router.parse(&message) {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
@@ -1372,45 +1199,17 @@ where
|
||||
// Bind
|
||||
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
|
||||
'B' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_bind(message).await?;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Describe
|
||||
// Command a client can issue to describe a previously prepared named statement.
|
||||
'D' => {
|
||||
if prepared_statements_enabled {
|
||||
let name;
|
||||
(name, message) = self.rewrite_describe(message).await?;
|
||||
|
||||
if let Some(name) = name {
|
||||
prepared_statement = Some(name);
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Close the prepared statement.
|
||||
'C' => {
|
||||
if prepared_statements_enabled {
|
||||
let close: Close = (&message).try_into()?;
|
||||
|
||||
if close.is_prepared_statement() && !close.anonymous() {
|
||||
match self.prepared_statements.get(&close.name) {
|
||||
Some(parse) => {
|
||||
server.will_close(&parse.generated_name);
|
||||
}
|
||||
|
||||
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
@@ -1448,7 +1247,7 @@ where
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
|
||||
// Almost certainly true
|
||||
if first_message_code == 'P' && !prepared_statements_enabled {
|
||||
if first_message_code == 'P' {
|
||||
// Message layout
|
||||
// P followed by 32 int followed by null-terminated statement name
|
||||
// So message code should be in offset 0 of the buffer, first character
|
||||
@@ -1475,13 +1274,11 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
if self.transaction_mode && !server.in_copy_mode() {
|
||||
if self.transaction_mode {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1516,7 +1313,7 @@ where
|
||||
.receive_server_message(server, &address, &pool, &self.stats.clone())
|
||||
.await?;
|
||||
|
||||
match write_all_flush(&mut self.write, &response).await {
|
||||
match write_all_half(&mut self.write, &response).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
server.mark_bad();
|
||||
@@ -1526,9 +1323,7 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server
|
||||
.stats()
|
||||
.transaction(self.server_parameters.get_application_name());
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1548,13 +1343,7 @@ where
|
||||
|
||||
// The server is no longer bound to us, we can't cancel it's queries anymore.
|
||||
debug!("Releasing server back into the pool");
|
||||
|
||||
server.checkin_cleanup().await?;
|
||||
|
||||
if prepared_statements_enabled {
|
||||
server.maintain_cache().await?;
|
||||
}
|
||||
|
||||
server.stats().idle();
|
||||
self.connected_to_server = false;
|
||||
|
||||
@@ -1580,115 +1369,12 @@ where
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
|
||||
self.pool_name,
|
||||
self.username,
|
||||
self.server_parameters.get_application_name()
|
||||
self.pool_name, self.username, self.application_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrite Parse (F) message to set the prepared statement name to one we control.
|
||||
/// Save it into the client cache.
|
||||
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let parse: Parse = (&message).try_into()?;
|
||||
|
||||
let name = parse.name.clone();
|
||||
|
||||
// Don't rewrite anonymous prepared statements
|
||||
if parse.anonymous() {
|
||||
debug!("Anonymous prepared statement");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
let parse = parse.rename();
|
||||
|
||||
debug!(
|
||||
"Renamed prepared statement `{}` to `{}` and saved to cache",
|
||||
name, parse.name
|
||||
);
|
||||
|
||||
self.prepared_statements.insert(name.clone(), parse.clone());
|
||||
|
||||
Ok((Some(name), parse.try_into()?))
|
||||
}
|
||||
|
||||
/// Rewrite the Bind (F) message to use the prepared statement name
|
||||
/// saved in the client cache.
|
||||
async fn rewrite_bind(
|
||||
&mut self,
|
||||
message: BytesMut,
|
||||
) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let bind: Bind = (&message).try_into()?;
|
||||
let name = bind.prepared_statement.clone();
|
||||
|
||||
if bind.anonymous() {
|
||||
debug!("Anonymous bind message");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
match self.prepared_statements.get(&name) {
|
||||
Some(prepared_stmt) => {
|
||||
let bind = bind.reassign(prepared_stmt);
|
||||
|
||||
debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);
|
||||
|
||||
Ok((Some(name), bind.try_into()?))
|
||||
}
|
||||
None => {
|
||||
debug!("Got bind for unknown prepared statement {:?}", bind);
|
||||
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"prepared statement \"{}\" does not exist",
|
||||
bind.prepared_statement
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Prepared statement `{}` doesn't exist",
|
||||
name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrite the Describe (F) message to use the prepared statement name
|
||||
/// saved in the client cache.
|
||||
async fn rewrite_describe(
|
||||
&mut self,
|
||||
message: BytesMut,
|
||||
) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let describe: Describe = (&message).try_into()?;
|
||||
let name = describe.statement_name.clone();
|
||||
|
||||
if describe.anonymous() {
|
||||
debug!("Anonymous describe");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
match self.prepared_statements.get(&name) {
|
||||
Some(prepared_stmt) => {
|
||||
let describe = describe.rename(&prepared_stmt.name);
|
||||
|
||||
debug!(
|
||||
"Rewrote describe `{}` to `{}`",
|
||||
name, describe.statement_name
|
||||
);
|
||||
|
||||
Ok((Some(name), describe.try_into()?))
|
||||
}
|
||||
|
||||
None => {
|
||||
debug!("Got describe for unknown prepared statement {:?}", describe);
|
||||
|
||||
Ok((None, message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Release the server from the client: it can't cancel its queries anymore.
|
||||
pub fn release(&self) {
|
||||
let mut guard = self.client_server_map.lock();
|
||||
@@ -1722,7 +1408,7 @@ where
|
||||
.receive_server_message(server, address, pool, client_stats)
|
||||
.await?;
|
||||
|
||||
match write_all_flush(&mut self.write, &response).await {
|
||||
match write_all_half(&mut self.write, &response).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
server.mark_bad();
|
||||
@@ -1739,7 +1425,7 @@ where
|
||||
client_stats.query();
|
||||
server.stats().query(
|
||||
Instant::now().duration_since(query_start).as_millis() as u64,
|
||||
&self.server_parameters.get_application_name(),
|
||||
&self.application_name,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -1768,18 +1454,38 @@ where
|
||||
pool: &ConnectionPool,
|
||||
client_stats: &ClientStats,
|
||||
) -> Result<BytesMut, Error> {
|
||||
let statement_timeout_duration = match pool.settings.user.statement_timeout {
|
||||
0 => tokio::time::Duration::MAX,
|
||||
timeout => tokio::time::Duration::from_millis(timeout),
|
||||
};
|
||||
|
||||
match tokio::time::timeout(
|
||||
statement_timeout_duration,
|
||||
server.recv(Some(&mut self.server_parameters)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
if pool.settings.user.statement_timeout > 0 {
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
|
||||
server.recv(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
)
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match server.recv().await {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
@@ -1790,16 +1496,6 @@ where
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1812,6 +1508,7 @@ impl<S, T> Drop for Client<S, T> {
|
||||
|
||||
// Dirty shutdown
|
||||
// TODO: refactor, this is not the best way to handle state management.
|
||||
|
||||
if self.connected_to_server && self.last_server_stats.is_some() {
|
||||
self.last_server_stats.as_ref().unwrap().idle();
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
use clap::{Parser, ValueEnum};
|
||||
use tracing::Level;
|
||||
|
||||
/// PgCat: Nextgen PostgreSQL Pooler
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[arg(default_value_t = String::from("pgcat.toml"), env)]
|
||||
pub config_file: String,
|
||||
|
||||
#[arg(short, long, default_value_t = tracing::Level::INFO, env)]
|
||||
pub log_level: Level,
|
||||
|
||||
#[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)]
|
||||
pub log_format: LogFormat,
|
||||
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
default_value_t = false,
|
||||
env,
|
||||
help = "disable colors in the log output"
|
||||
)]
|
||||
pub no_color: bool,
|
||||
}
|
||||
|
||||
pub fn parse() -> Args {
|
||||
return Args::parse();
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Clone, Debug)]
|
||||
pub enum LogFormat {
|
||||
Text,
|
||||
Structured,
|
||||
Debug,
|
||||
}
|
||||
321
src/config.rs
321
src/config.rs
@@ -1,16 +1,13 @@
|
||||
/// Parse the configuration file.
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{error, info, warn};
|
||||
use log::{error, info};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use serde::{Deserializer, Serializer};
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
@@ -104,9 +101,6 @@ pub struct Address {
|
||||
|
||||
/// Address stats
|
||||
pub stats: Arc<AddressStats>,
|
||||
|
||||
/// Number of errors encountered since last successful checkout
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for Address {
|
||||
@@ -124,21 +118,10 @@ impl Default for Address {
|
||||
pool_name: String::from("pool_name"),
|
||||
mirrors: Vec::new(),
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Address {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"[address: {}:{}][database: {}][user: {}]",
|
||||
self.host, self.port, self.database, self.username
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// We need to implement PartialEq by ourselves so we skip stats in the comparison
|
||||
impl PartialEq for Address {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
@@ -189,18 +172,6 @@ impl Address {
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error_count(&self) -> u64 {
|
||||
self.error_count.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn increment_error_count(&self) {
|
||||
self.error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn reset_error_count(&self) {
|
||||
self.error_count.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgreSQL user.
|
||||
@@ -264,8 +235,6 @@ pub struct General {
|
||||
pub port: u16,
|
||||
|
||||
pub enable_prometheus_exporter: Option<bool>,
|
||||
|
||||
#[serde(default = "General::default_prometheus_exporter_port")]
|
||||
pub prometheus_exporter_port: i16,
|
||||
|
||||
#[serde(default = "General::default_connect_timeout")]
|
||||
@@ -280,8 +249,6 @@ pub struct General {
|
||||
pub tcp_keepalives_count: u32,
|
||||
#[serde(default = "General::default_tcp_keepalives_interval")]
|
||||
pub tcp_keepalives_interval: u64,
|
||||
#[serde(default = "General::default_tcp_user_timeout")]
|
||||
pub tcp_user_timeout: u64,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub log_client_connections: bool,
|
||||
@@ -313,9 +280,6 @@ pub struct General {
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_server_round_robin")] // False
|
||||
pub server_round_robin: bool,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
@@ -334,19 +298,10 @@ pub struct General {
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
#[serde(default = "General::default_validate_config")]
|
||||
pub validate_config: bool,
|
||||
|
||||
// Support for auth query
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub prepared_statements: bool,
|
||||
|
||||
#[serde(default = "General::default_prepared_statements_cache_size")]
|
||||
pub prepared_statements_cache_size: usize,
|
||||
}
|
||||
|
||||
impl General {
|
||||
@@ -359,7 +314,7 @@ impl General {
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 // 1 hour
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
@@ -381,12 +336,8 @@ impl General {
|
||||
5 // 5 seconds
|
||||
}
|
||||
|
||||
pub fn default_tcp_user_timeout() -> u64 {
|
||||
10000 // 10000 milliseconds
|
||||
}
|
||||
|
||||
pub fn default_idle_timeout() -> u64 {
|
||||
600000 // 10 minutes
|
||||
60000 // 10 minutes
|
||||
}
|
||||
|
||||
pub fn default_shutdown_timeout() -> u64 {
|
||||
@@ -416,22 +367,6 @@ impl General {
|
||||
pub fn default_idle_client_in_transaction_timeout() -> u64 {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn default_validate_config() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn default_prometheus_exporter_port() -> i16 {
|
||||
9930
|
||||
}
|
||||
|
||||
pub fn default_server_round_robin() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn default_prepared_statements_cache_size() -> usize {
|
||||
500
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for General {
|
||||
@@ -452,7 +387,6 @@ impl Default for General {
|
||||
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
|
||||
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
|
||||
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
|
||||
tcp_user_timeout: Self::default_tcp_user_timeout(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: None,
|
||||
@@ -467,11 +401,7 @@ impl Default for General {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: Self::default_server_lifetime(),
|
||||
server_round_robin: Self::default_server_round_robin(),
|
||||
validate_config: true,
|
||||
prepared_statements: false,
|
||||
prepared_statements_cache_size: 500,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -524,32 +454,20 @@ pub struct Pool {
|
||||
#[serde(default = "Pool::default_load_balancing_mode")]
|
||||
pub load_balancing_mode: LoadBalancingMode,
|
||||
|
||||
#[serde(default = "Pool::default_default_role")]
|
||||
pub default_role: String,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub query_parser_enabled: bool,
|
||||
|
||||
pub query_parser_max_length: Option<usize>,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub query_parser_read_write_splitting: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub primary_reads_enabled: bool,
|
||||
|
||||
/// Maximum time to allow for establishing a new server connection.
|
||||
pub connect_timeout: Option<u64>,
|
||||
|
||||
/// Close idle connections that have been opened for longer than this.
|
||||
pub idle_timeout: Option<u64>,
|
||||
|
||||
/// Close server connections that have been opened for longer than this.
|
||||
/// Only applied to idle connections. If the connection is actively used for
|
||||
/// longer than this period, the pool will not interrupt it.
|
||||
pub server_lifetime: Option<u64>,
|
||||
|
||||
#[serde(default = "Pool::default_sharding_function")]
|
||||
pub sharding_function: ShardingFunction,
|
||||
|
||||
#[serde(default = "Pool::default_automatic_sharding_key")]
|
||||
@@ -559,20 +477,10 @@ pub struct Pool {
|
||||
pub shard_id_regex: Option<String>,
|
||||
pub regex_search_limit: Option<usize>,
|
||||
|
||||
#[serde(default = "Pool::default_default_shard")]
|
||||
pub default_shard: DefaultShard,
|
||||
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
#[serde(default = "Pool::default_cleanup_server_connections")]
|
||||
pub cleanup_server_connections: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub log_client_parameter_status_changes: bool,
|
||||
|
||||
pub plugins: Option<Plugins>,
|
||||
pub shards: BTreeMap<String, Shard>,
|
||||
pub users: BTreeMap<String, User>,
|
||||
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
|
||||
@@ -597,10 +505,6 @@ impl Pool {
|
||||
PoolMode::Transaction
|
||||
}
|
||||
|
||||
pub fn default_default_shard() -> DefaultShard {
|
||||
DefaultShard::default()
|
||||
}
|
||||
|
||||
pub fn default_load_balancing_mode() -> LoadBalancingMode {
|
||||
LoadBalancingMode::Random
|
||||
}
|
||||
@@ -609,18 +513,6 @@ impl Pool {
|
||||
None
|
||||
}
|
||||
|
||||
pub fn default_default_role() -> String {
|
||||
"any".into()
|
||||
}
|
||||
|
||||
pub fn default_sharding_function() -> ShardingFunction {
|
||||
ShardingFunction::PgBigintHash
|
||||
}
|
||||
|
||||
pub fn default_cleanup_server_connections() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn validate(&mut self) -> Result<(), Error> {
|
||||
match self.default_role.as_ref() {
|
||||
"any" => (),
|
||||
@@ -661,18 +553,6 @@ impl Pool {
|
||||
}
|
||||
}
|
||||
|
||||
if self.query_parser_read_write_splitting && !self.query_parser_enabled {
|
||||
error!(
|
||||
"query_parser_read_write_splitting is only valid when query_parser_enabled is true"
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
if self.plugins.is_some() && !self.query_parser_enabled {
|
||||
error!("plugins are only valid when query_parser_enabled is true");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
self.automatic_sharding_key = match &self.automatic_sharding_key {
|
||||
Some(key) => {
|
||||
// No quotes in the key so we don't have to compare quoted
|
||||
@@ -692,16 +572,6 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
match self.default_shard {
|
||||
DefaultShard::Shard(shard_number) => {
|
||||
if shard_number >= self.shards.len() {
|
||||
error!("Invalid shard {:?}", shard_number);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
@@ -719,8 +589,6 @@ impl Default for Pool {
|
||||
users: BTreeMap::default(),
|
||||
default_role: String::from("any"),
|
||||
query_parser_enabled: false,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: false,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -729,14 +597,10 @@ impl Default for Pool {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: Some(1000),
|
||||
default_shard: Self::default_default_shard(),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
plugins: None,
|
||||
cleanup_server_connections: true,
|
||||
log_client_parameter_status_changes: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -748,50 +612,6 @@ pub struct ServerConfig {
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
// No Shard Specified handling.
|
||||
#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)]
|
||||
pub enum DefaultShard {
|
||||
Shard(usize),
|
||||
Random,
|
||||
RandomHealthy,
|
||||
}
|
||||
impl Default for DefaultShard {
|
||||
fn default() -> Self {
|
||||
DefaultShard::Shard(0)
|
||||
}
|
||||
}
|
||||
impl serde::Serialize for DefaultShard {
|
||||
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
match self {
|
||||
DefaultShard::Shard(shard) => {
|
||||
serializer.serialize_str(&format!("shard_{}", &shard.to_string()))
|
||||
}
|
||||
DefaultShard::Random => serializer.serialize_str("random"),
|
||||
DefaultShard::RandomHealthy => serializer.serialize_str("random_healthy"),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'de> serde::Deserialize<'de> for DefaultShard {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
if s.starts_with("shard_") {
|
||||
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
|
||||
return Ok(DefaultShard::Shard(shard));
|
||||
}
|
||||
|
||||
match s.as_str() {
|
||||
"random" => Ok(DefaultShard::Random),
|
||||
"random_healthy" => Ok(DefaultShard::RandomHealthy),
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"invalid value for no_shard_specified_behavior",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)]
|
||||
pub struct MirrorServerConfig {
|
||||
pub host: String,
|
||||
@@ -859,60 +679,39 @@ impl Default for Shard {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Plugins {
|
||||
pub intercept: Option<Intercept>,
|
||||
pub table_access: Option<TableAccess>,
|
||||
pub query_logger: Option<QueryLogger>,
|
||||
pub prewarmer: Option<Prewarmer>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Plugins {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
|
||||
self.intercept.is_some(),
|
||||
self.table_access.is_some(),
|
||||
self.query_logger.is_some(),
|
||||
self.prewarmer.is_some(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Intercept {
|
||||
pub enabled: bool,
|
||||
pub queries: BTreeMap<String, Query>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct TableAccess {
|
||||
pub enabled: bool,
|
||||
pub tables: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct QueryLogger {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
pub struct Prewarmer {
|
||||
pub enabled: bool,
|
||||
pub queries: Vec<String>,
|
||||
}
|
||||
|
||||
impl Intercept {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for (_, query) in self.queries.iter_mut() {
|
||||
query.substitute(db, user);
|
||||
query.query = query.query.to_ascii_lowercase();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Query {
|
||||
pub query: String,
|
||||
pub schema: Vec<Vec<String>>,
|
||||
@@ -946,13 +745,8 @@ pub struct Config {
|
||||
#[serde(default = "Config::default_path")]
|
||||
pub path: String,
|
||||
|
||||
// General and global settings.
|
||||
pub general: General,
|
||||
|
||||
// Plugins that should run in all pools.
|
||||
pub plugins: Option<Plugins>,
|
||||
|
||||
// Connection pools.
|
||||
pub pools: HashMap<String, Pool>,
|
||||
}
|
||||
|
||||
@@ -1018,17 +812,6 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
format!("pools.{}.query_parser_enabled", pool_name),
|
||||
pool.query_parser_enabled.to_string(),
|
||||
),
|
||||
(
|
||||
format!("pools.{}.query_parser_max_length", pool_name),
|
||||
match pool.query_parser_max_length {
|
||||
Some(max_length) => max_length.to_string(),
|
||||
None => String::from("unlimited"),
|
||||
},
|
||||
),
|
||||
(
|
||||
format!("pools.{}.query_parser_read_write_splitting", pool_name),
|
||||
pool.query_parser_read_write_splitting.to_string(),
|
||||
),
|
||||
(
|
||||
format!("pools.{}.default_role", pool_name),
|
||||
pool.default_role.clone(),
|
||||
@@ -1125,7 +908,6 @@ impl Config {
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
info!("Sever round robin: {}", self.general.server_round_robin);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
@@ -1149,20 +931,6 @@ impl Config {
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
info!("Prepared statements: {}", self.general.prepared_statements);
|
||||
if self.general.prepared_statements {
|
||||
info!(
|
||||
"Prepared statements server cache size: {}",
|
||||
self.general.prepared_statements_cache_size
|
||||
);
|
||||
}
|
||||
info!(
|
||||
"Plugins: {}",
|
||||
match self.plugins {
|
||||
Some(ref plugins) => plugins.to_string(),
|
||||
None => "not configured".into(),
|
||||
}
|
||||
);
|
||||
|
||||
for (pool_name, pool_config) in &self.pools {
|
||||
// TODO: Make this output prettier (maybe a table?)
|
||||
@@ -1211,15 +979,6 @@ impl Config {
|
||||
"[pool: {}] Query router: {}",
|
||||
pool_name, pool_config.query_parser_enabled
|
||||
);
|
||||
|
||||
info!(
|
||||
"[pool: {}] Query parser max length: {:?}",
|
||||
pool_name, pool_config.query_parser_max_length
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Infer role from query: {}",
|
||||
pool_name, pool_config.query_parser_read_write_splitting
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Number of shards: {}",
|
||||
pool_name,
|
||||
@@ -1238,22 +997,6 @@ impl Config {
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Cleanup server connections: {}",
|
||||
pool_name, pool_config.cleanup_server_connections
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Log client parameter status changes: {}",
|
||||
pool_name, pool_config.log_client_parameter_status_changes
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Plugins: {}",
|
||||
pool_name,
|
||||
match pool_config.plugins {
|
||||
Some(ref plugins) => plugins.to_string(),
|
||||
None => "not configured".into(),
|
||||
}
|
||||
);
|
||||
|
||||
for user in &pool_config.users {
|
||||
info!(
|
||||
@@ -1342,38 +1085,30 @@ impl Config {
|
||||
}
|
||||
|
||||
// Validate TLS!
|
||||
match self.general.tls_certificate {
|
||||
Some(ref mut tls_certificate) => {
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key {
|
||||
Some(ref tls_private_key) => {
|
||||
match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"tls_private_key is incorrectly configured: {:?}",
|
||||
err
|
||||
);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
}
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
None => {
|
||||
warn!("tls_certificate is set, but the tls_private_key is not");
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
warn!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
self.general.tls_private_key = None;
|
||||
self.general.tls_certificate = None;
|
||||
error!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1396,15 +1131,9 @@ pub fn get_config() -> Config {
|
||||
}
|
||||
|
||||
pub fn get_idle_client_in_transaction_timeout() -> u64 {
|
||||
CONFIG.load().general.idle_client_in_transaction_timeout
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements() -> bool {
|
||||
CONFIG.load().general.prepared_statements
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements_cache_size() -> usize {
|
||||
CONFIG.load().general.prepared_statements_cache_size
|
||||
(*(*CONFIG.load()))
|
||||
.general
|
||||
.idle_client_in_transaction_timeout
|
||||
}
|
||||
|
||||
/// Parse the configuration file located at the path.
|
||||
|
||||
@@ -12,7 +12,6 @@ pub enum Error {
|
||||
ProtocolSyncError(String),
|
||||
BadQuery(String),
|
||||
ServerError,
|
||||
ServerMessageParserError(String),
|
||||
ServerStartupError(String, ServerIdentifier),
|
||||
ServerAuthError(String, ServerIdentifier),
|
||||
BadConfig,
|
||||
@@ -27,8 +26,6 @@ pub enum Error {
|
||||
AuthPassthroughError(String),
|
||||
UnsupportedStatement,
|
||||
QueryRouterParserError(String),
|
||||
QueryRouterError(String),
|
||||
InvalidShardId(usize),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@@ -124,9 +121,3 @@ impl std::fmt::Display for Error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::ffi::NulError> for Error {
|
||||
fn from(err: std::ffi::NulError) -> Self {
|
||||
Error::QueryRouterError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
pub mod admin;
|
||||
pub mod auth_passthrough;
|
||||
pub mod client;
|
||||
pub mod cmd_args;
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod dns_cache;
|
||||
pub mod errors;
|
||||
pub mod logger;
|
||||
pub mod messages;
|
||||
pub mod mirrors;
|
||||
pub mod multi_logger;
|
||||
pub mod plugins;
|
||||
pub mod pool;
|
||||
pub mod prometheus;
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
use crate::cmd_args::{Args, LogFormat};
|
||||
use tracing_subscriber;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub fn init(args: &Args) {
|
||||
// Iniitalize a default filter, and then override the builtin default "warning" with our
|
||||
// commandline, (default: "info")
|
||||
let filter = EnvFilter::from_default_env().add_directive(args.log_level.into());
|
||||
|
||||
let trace_sub = tracing_subscriber::fmt()
|
||||
.with_thread_ids(true)
|
||||
.with_env_filter(filter)
|
||||
.with_ansi(!args.no_color);
|
||||
|
||||
match args.log_format {
|
||||
LogFormat::Structured => trace_sub.json().init(),
|
||||
LogFormat::Debug => trace_sub.pretty().init(),
|
||||
_ => trace_sub.init(),
|
||||
};
|
||||
}
|
||||
16
src/main.rs
16
src/main.rs
@@ -23,6 +23,7 @@ extern crate arc_swap;
|
||||
extern crate async_trait;
|
||||
extern crate bb8;
|
||||
extern crate bytes;
|
||||
extern crate env_logger;
|
||||
extern crate exitcode;
|
||||
extern crate log;
|
||||
extern crate md5;
|
||||
@@ -60,18 +61,15 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use pgcat::cmd_args;
|
||||
use pgcat::config::{get_config, reload_config, VERSION};
|
||||
use pgcat::dns_cache;
|
||||
use pgcat::logger;
|
||||
use pgcat::messages::configure_socket;
|
||||
use pgcat::pool::{ClientServerMap, ConnectionPool};
|
||||
use pgcat::prometheus::start_metric_server;
|
||||
use pgcat::stats::{Collector, Reporter, REPORTER};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = cmd_args::parse();
|
||||
logger::init(&args);
|
||||
pgcat::multi_logger::MultiLogger::init().unwrap();
|
||||
|
||||
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
|
||||
|
||||
@@ -80,12 +78,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
|
||||
let args = std::env::args().collect::<Vec<String>>();
|
||||
|
||||
let config_file = if args.len() == 2 {
|
||||
args[1].to_string()
|
||||
} else {
|
||||
String::from("pgcat.toml")
|
||||
};
|
||||
|
||||
// Create a transient runtime for loading the config for the first time.
|
||||
{
|
||||
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
match pgcat::config::parse(args.config_file.as_str()).await {
|
||||
match pgcat::config::parse(&config_file).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
|
||||
730
src/messages.rs
730
src/messages.rs
@@ -1,24 +1,17 @@
|
||||
/// Helper functions to send one-off protocol messages
|
||||
/// and handle TcpStream (TCP socket).
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error};
|
||||
use log::error;
|
||||
use md5::{Digest, Md5};
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::client::PREPARED_STATEMENT_COUNTER;
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::constants::MESSAGE_TERMINATOR;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CString;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::mem;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Postgres data type mappings
|
||||
@@ -144,10 +137,6 @@ where
|
||||
bytes.put_slice(user.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
|
||||
// Application name
|
||||
bytes.put(&b"application_name\0"[..]);
|
||||
bytes.put_slice(&b"pgcat\0"[..]);
|
||||
|
||||
// Database
|
||||
bytes.put(&b"database\0"[..]);
|
||||
bytes.put_slice(database.as_bytes());
|
||||
@@ -537,33 +526,6 @@ pub fn command_complete(command: &str) -> BytesMut {
|
||||
res
|
||||
}
|
||||
|
||||
/// Create a notify message.
|
||||
pub fn notify(message: &str, details: String) -> BytesMut {
|
||||
let mut notify_cmd = BytesMut::new();
|
||||
|
||||
notify_cmd.put_slice("SNOTICE\0".as_bytes());
|
||||
notify_cmd.put_slice("C00000\0".as_bytes());
|
||||
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
|
||||
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
|
||||
|
||||
// this extra byte says that is the end of the package
|
||||
notify_cmd.put_u8(0);
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'N');
|
||||
res.put_i32(notify_cmd.len() as i32 + 4);
|
||||
res.put(notify_cmd);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub fn flush() -> BytesMut {
|
||||
let mut bytes = BytesMut::new();
|
||||
bytes.put_u8(b'H');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Write all data in the buffer to the TcpStream.
|
||||
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||
where
|
||||
@@ -696,13 +658,6 @@ pub fn configure_socket(stream: &TcpStream) {
|
||||
let sock_ref = SockRef::from(stream);
|
||||
let conf = get_config();
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
|
||||
}
|
||||
|
||||
match sock_ref.set_keepalive(true) {
|
||||
Ok(_) => {
|
||||
match sock_ref.set_tcp_keepalive(
|
||||
@@ -712,7 +667,7 @@ pub fn configure_socket(stream: &TcpStream) {
|
||||
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
|
||||
Err(err) => error!("Could not configure socket: {}", err),
|
||||
}
|
||||
}
|
||||
Err(err) => error!("Could not configure socket: {}", err),
|
||||
@@ -734,684 +689,3 @@ impl BytesMutReader for Cursor<&BytesMut> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BytesMutReader for BytesMut {
|
||||
/// Should only be used when reading strings from the message protocol.
|
||||
/// Can be used to read multiple strings from the same message which are separated by the null byte
|
||||
fn read_string(&mut self) -> Result<String, Error> {
|
||||
let null_index = self.iter().position(|&byte| byte == b'\0');
|
||||
|
||||
match null_index {
|
||||
Some(index) => {
|
||||
let string_bytes = self.split_to(index + 1);
|
||||
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
|
||||
}
|
||||
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Parse {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
pub name: String,
|
||||
pub generated_name: String,
|
||||
query: String,
|
||||
num_params: i16,
|
||||
param_types: Vec<i32>,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Parse {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let name = cursor.read_string()?;
|
||||
let query = cursor.read_string()?;
|
||||
let num_params = cursor.get_i16();
|
||||
let mut param_types = Vec::new();
|
||||
|
||||
for _ in 0..num_params {
|
||||
param_types.push(cursor.get_i32());
|
||||
}
|
||||
|
||||
Ok(Parse {
|
||||
code,
|
||||
len,
|
||||
name,
|
||||
generated_name: prepared_statement_name(),
|
||||
query,
|
||||
num_params,
|
||||
param_types,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Parse> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(parse: Parse) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
let name_binding = CString::new(parse.name)?;
|
||||
let name = name_binding.as_bytes_with_nul();
|
||||
|
||||
let query_binding = CString::new(parse.query)?;
|
||||
let query = query_binding.as_bytes_with_nul();
|
||||
|
||||
// Recompute length of the message.
|
||||
let len = 4 // self
|
||||
+ name.len()
|
||||
+ query.len()
|
||||
+ 2
|
||||
+ 4 * parse.num_params as usize;
|
||||
|
||||
bytes.put_u8(parse.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(name);
|
||||
bytes.put_slice(query);
|
||||
bytes.put_i16(parse.num_params);
|
||||
for param in parse.param_types {
|
||||
bytes.put_i32(param);
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&Parse> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
|
||||
parse.clone().try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse {
|
||||
pub fn rename(mut self) -> Self {
|
||||
self.name = self.generated_name.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bind (B) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Bind {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i64,
|
||||
portal: String,
|
||||
pub prepared_statement: String,
|
||||
num_param_format_codes: i16,
|
||||
param_format_codes: Vec<i16>,
|
||||
num_param_values: i16,
|
||||
param_values: Vec<(i32, BytesMut)>,
|
||||
num_result_column_format_codes: i16,
|
||||
result_columns_format_codes: Vec<i16>,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Bind {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
let num_param_format_codes = cursor.get_i16();
|
||||
let mut param_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_param_format_codes {
|
||||
param_format_codes.push(cursor.get_i16());
|
||||
}
|
||||
|
||||
let num_param_values = cursor.get_i16();
|
||||
let mut param_values = Vec::new();
|
||||
|
||||
for _ in 0..num_param_values {
|
||||
let param_len = cursor.get_i32();
|
||||
// There is special occasion when the parameter is NULL
|
||||
// In that case, param length is defined as -1
|
||||
// So if the passed parameter len is over 0
|
||||
if param_len > 0 {
|
||||
let mut param = BytesMut::with_capacity(param_len as usize);
|
||||
param.resize(param_len as usize, b'0');
|
||||
cursor.copy_to_slice(&mut param);
|
||||
// we push and the length and the parameter into vector
|
||||
param_values.push((param_len, param));
|
||||
} else {
|
||||
// otherwise we push a tuple with -1 and 0-len BytesMut
|
||||
// which means that after encountering -1 postgres proceeds
|
||||
// to processing another parameter
|
||||
param_values.push((param_len, BytesMut::new()));
|
||||
}
|
||||
}
|
||||
|
||||
let num_result_column_format_codes = cursor.get_i16();
|
||||
let mut result_columns_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_result_column_format_codes {
|
||||
result_columns_format_codes.push(cursor.get_i16());
|
||||
}
|
||||
|
||||
Ok(Bind {
|
||||
code,
|
||||
len: len as i64,
|
||||
portal,
|
||||
prepared_statement,
|
||||
num_param_format_codes,
|
||||
param_format_codes,
|
||||
num_param_values,
|
||||
param_values,
|
||||
num_result_column_format_codes,
|
||||
result_columns_format_codes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bind> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bind: Bind) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
let portal_binding = CString::new(bind.portal)?;
|
||||
let portal = portal_binding.as_bytes_with_nul();
|
||||
|
||||
let prepared_statement_binding = CString::new(bind.prepared_statement)?;
|
||||
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();
|
||||
|
||||
let mut len = 4 // self
|
||||
+ portal.len()
|
||||
+ prepared_statement.len()
|
||||
+ 2 // num_param_format_codes
|
||||
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
|
||||
+ 2; // num_param_values
|
||||
|
||||
for (param_len, _) in &bind.param_values {
|
||||
len += 4 + *param_len as usize;
|
||||
}
|
||||
len += 2; // num_result_column_format_codes
|
||||
len += 2 * bind.num_result_column_format_codes as usize;
|
||||
|
||||
bytes.put_u8(bind.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(portal);
|
||||
bytes.put_slice(prepared_statement);
|
||||
bytes.put_i16(bind.num_param_format_codes);
|
||||
for param_format_code in bind.param_format_codes {
|
||||
bytes.put_i16(param_format_code);
|
||||
}
|
||||
bytes.put_i16(bind.num_param_values);
|
||||
for (param_len, param) in bind.param_values {
|
||||
bytes.put_i32(param_len);
|
||||
bytes.put_slice(¶m);
|
||||
}
|
||||
bytes.put_i16(bind.num_result_column_format_codes);
|
||||
for result_column_format_code in bind.result_columns_format_codes {
|
||||
bytes.put_i16(result_column_format_code);
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind {
|
||||
pub fn reassign(mut self, parse: &Parse) -> Self {
|
||||
self.prepared_statement = parse.name.clone();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.prepared_statement.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Describe {
|
||||
code: char,
|
||||
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
target: char,
|
||||
pub statement_name: String,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Describe {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let target = cursor.get_u8() as char;
|
||||
let statement_name = cursor.read_string()?;
|
||||
|
||||
Ok(Describe {
|
||||
code,
|
||||
len,
|
||||
target,
|
||||
statement_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Describe> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
let statement_name_binding = CString::new(describe.statement_name)?;
|
||||
let statement_name = statement_name_binding.as_bytes_with_nul();
|
||||
let len = 4 + 1 + statement_name.len();
|
||||
|
||||
bytes.put_u8(describe.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_u8(describe.target as u8);
|
||||
bytes.put_slice(statement_name);
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Describe {
|
||||
pub fn rename(mut self, name: &str) -> Self {
|
||||
self.statement_name = name.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.statement_name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Close (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Close {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
close_type: char,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Close {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let close_type = cursor.get_u8() as char;
|
||||
let name = cursor.read_string()?;
|
||||
|
||||
Ok(Close {
|
||||
code,
|
||||
len,
|
||||
close_type,
|
||||
name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Close> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(close: Close) -> Result<BytesMut, Error> {
|
||||
debug!("Close: {:?}", close);
|
||||
|
||||
let mut bytes = BytesMut::new();
|
||||
let name_binding = CString::new(close.name)?;
|
||||
let name = name_binding.as_bytes_with_nul();
|
||||
let len = 4 + 1 + name.len();
|
||||
|
||||
bytes.put_u8(close.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_u8(close.close_type as u8);
|
||||
bytes.put_slice(name);
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Close {
|
||||
pub fn new(name: &str) -> Close {
|
||||
let name = name.to_string();
|
||||
|
||||
Close {
|
||||
code: 'C',
|
||||
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
|
||||
close_type: 'S',
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_prepared_statement(&self) -> bool {
|
||||
self.close_type == 'S'
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close_complete() -> BytesMut {
|
||||
let mut bytes = BytesMut::new();
|
||||
bytes.put_u8(b'3');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn prepared_statement_name() -> String {
|
||||
format!(
|
||||
"P_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
pub struct PgErrorMsg {
|
||||
pub severity_localized: String, // S
|
||||
pub severity: String, // V
|
||||
pub code: String, // C
|
||||
pub message: String, // M
|
||||
pub detail: Option<String>, // D
|
||||
pub hint: Option<String>, // H
|
||||
pub position: Option<u32>, // P
|
||||
pub internal_position: Option<u32>, // p
|
||||
pub internal_query: Option<String>, // q
|
||||
pub where_context: Option<String>, // W
|
||||
pub schema_name: Option<String>, // s
|
||||
pub table_name: Option<String>, // t
|
||||
pub column_name: Option<String>, // c
|
||||
pub data_type_name: Option<String>, // d
|
||||
pub constraint_name: Option<String>, // n
|
||||
pub file_name: Option<String>, // F
|
||||
pub line: Option<u32>, // L
|
||||
pub routine: Option<String>, // R
|
||||
}
|
||||
|
||||
// TODO: implement with https://docs.rs/derive_more/latest/derive_more/
|
||||
impl Display for PgErrorMsg {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[severity: {}]", self.severity)?;
|
||||
write!(f, "[code: {}]", self.code)?;
|
||||
write!(f, "[message: {}]", self.message)?;
|
||||
if let Some(val) = &self.detail {
|
||||
write!(f, "[detail: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.hint {
|
||||
write!(f, "[hint: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.position {
|
||||
write!(f, "[position: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_position {
|
||||
write!(f, "[internal_position: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_query {
|
||||
write!(f, "[internal_query: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.internal_query {
|
||||
write!(f, "[internal_query: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.where_context {
|
||||
write!(f, "[where: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.schema_name {
|
||||
write!(f, "[schema_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.table_name {
|
||||
write!(f, "[table_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.column_name {
|
||||
write!(f, "[column_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.data_type_name {
|
||||
write!(f, "[data_type_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.constraint_name {
|
||||
write!(f, "[constraint_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.file_name {
|
||||
write!(f, "[file_name: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.line {
|
||||
write!(f, "[line: {val}]")?;
|
||||
}
|
||||
if let Some(val) = &self.routine {
|
||||
write!(f, "[routine: {val}]")?;
|
||||
}
|
||||
|
||||
write!(f, " ")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PgErrorMsg {
|
||||
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
|
||||
let mut out = PgErrorMsg {
|
||||
severity_localized: "".to_string(),
|
||||
severity: "".to_string(),
|
||||
code: "".to_string(),
|
||||
message: "".to_string(),
|
||||
detail: None,
|
||||
hint: None,
|
||||
position: None,
|
||||
internal_position: None,
|
||||
internal_query: None,
|
||||
where_context: None,
|
||||
schema_name: None,
|
||||
table_name: None,
|
||||
column_name: None,
|
||||
data_type_name: None,
|
||||
constraint_name: None,
|
||||
file_name: None,
|
||||
line: None,
|
||||
routine: None,
|
||||
};
|
||||
for msg_part in error_msg.split(|v| *v == MESSAGE_TERMINATOR) {
|
||||
if msg_part.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
return Err(Error::ServerMessageParserError(format!(
|
||||
"could not parse server message field. err {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match &msg_part[0] {
|
||||
b'S' => {
|
||||
out.severity_localized = msg_content;
|
||||
}
|
||||
b'V' => {
|
||||
out.severity = msg_content;
|
||||
}
|
||||
b'C' => {
|
||||
out.code = msg_content;
|
||||
}
|
||||
b'M' => {
|
||||
out.message = msg_content;
|
||||
}
|
||||
b'D' => {
|
||||
out.detail = Some(msg_content);
|
||||
}
|
||||
b'H' => {
|
||||
out.hint = Some(msg_content);
|
||||
}
|
||||
b'P' => out.position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
||||
b'p' => {
|
||||
out.internal_position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0))
|
||||
}
|
||||
b'q' => {
|
||||
out.internal_query = Some(msg_content);
|
||||
}
|
||||
b'W' => {
|
||||
out.where_context = Some(msg_content);
|
||||
}
|
||||
b's' => {
|
||||
out.schema_name = Some(msg_content);
|
||||
}
|
||||
b't' => {
|
||||
out.table_name = Some(msg_content);
|
||||
}
|
||||
b'c' => {
|
||||
out.column_name = Some(msg_content);
|
||||
}
|
||||
b'd' => {
|
||||
out.data_type_name = Some(msg_content);
|
||||
}
|
||||
b'n' => {
|
||||
out.constraint_name = Some(msg_content);
|
||||
}
|
||||
b'F' => {
|
||||
out.file_name = Some(msg_content);
|
||||
}
|
||||
b'L' => out.line = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
||||
b'R' => {
|
||||
out.routine = Some(msg_content);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::messages::PgErrorMsg;
|
||||
use log::{error, info};
|
||||
|
||||
fn field(kind: char, content: &str) -> Vec<u8> {
|
||||
format!("{kind}{content}\0").as_bytes().to_vec()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fields() {
|
||||
let mut complete_msg = vec![];
|
||||
let severity = "FATAL";
|
||||
complete_msg.extend(field('S', &severity));
|
||||
complete_msg.extend(field('V', &severity));
|
||||
|
||||
let error_code = "29P02";
|
||||
complete_msg.extend(field('C', &error_code));
|
||||
let message = "password authentication failed for user \"wrong_user\"";
|
||||
complete_msg.extend(field('M', &message));
|
||||
let detail_msg = "super detailed message";
|
||||
complete_msg.extend(field('D', &detail_msg));
|
||||
let hint_msg = "hint detail here";
|
||||
complete_msg.extend(field('H', &hint_msg));
|
||||
complete_msg.extend(field('P', "123"));
|
||||
complete_msg.extend(field('p', "234"));
|
||||
let internal_query = "SELECT * from foo;";
|
||||
complete_msg.extend(field('q', &internal_query));
|
||||
let where_msg = "where goes here";
|
||||
complete_msg.extend(field('W', &where_msg));
|
||||
let schema_msg = "schema_name";
|
||||
complete_msg.extend(field('s', &schema_msg));
|
||||
let table_msg = "table_name";
|
||||
complete_msg.extend(field('t', &table_msg));
|
||||
let column_msg = "column_name";
|
||||
complete_msg.extend(field('c', &column_msg));
|
||||
let data_type_msg = "type_name";
|
||||
complete_msg.extend(field('d', &data_type_msg));
|
||||
let constraint_msg = "constraint_name";
|
||||
complete_msg.extend(field('n', &constraint_msg));
|
||||
let file_msg = "pgcat.c";
|
||||
complete_msg.extend(field('F', &file_msg));
|
||||
complete_msg.extend(field('L', "335"));
|
||||
let routine_msg = "my_failing_routine";
|
||||
complete_msg.extend(field('R', &routine_msg));
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(tracing::Level::INFO)
|
||||
.with_ansi(true)
|
||||
.init();
|
||||
|
||||
info!(
|
||||
"full message: {}",
|
||||
PgErrorMsg::parse(complete_msg.clone()).unwrap()
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
severity_localized: severity.to_string(),
|
||||
severity: severity.to_string(),
|
||||
code: error_code.to_string(),
|
||||
message: message.to_string(),
|
||||
detail: Some(detail_msg.to_string()),
|
||||
hint: Some(hint_msg.to_string()),
|
||||
position: Some(123),
|
||||
internal_position: Some(234),
|
||||
internal_query: Some(internal_query.to_string()),
|
||||
where_context: Some(where_msg.to_string()),
|
||||
schema_name: Some(schema_msg.to_string()),
|
||||
table_name: Some(table_msg.to_string()),
|
||||
column_name: Some(column_msg.to_string()),
|
||||
data_type_name: Some(data_type_msg.to_string()),
|
||||
constraint_name: Some(constraint_msg.to_string()),
|
||||
file_name: Some(file_msg.to_string()),
|
||||
line: Some(335),
|
||||
routine: Some(routine_msg.to_string()),
|
||||
},
|
||||
PgErrorMsg::parse(complete_msg).unwrap()
|
||||
);
|
||||
|
||||
let mut only_mandatory_msg = vec![];
|
||||
only_mandatory_msg.extend(field('S', &severity));
|
||||
only_mandatory_msg.extend(field('V', &severity));
|
||||
only_mandatory_msg.extend(field('C', &error_code));
|
||||
only_mandatory_msg.extend(field('M', &message));
|
||||
only_mandatory_msg.extend(field('D', &detail_msg));
|
||||
|
||||
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
|
||||
info!("only mandatory fields: {}", &err_fields);
|
||||
error!(
|
||||
"server error: {}: {}",
|
||||
err_fields.severity, err_fields.message
|
||||
);
|
||||
assert_eq!(
|
||||
PgErrorMsg {
|
||||
severity_localized: severity.to_string(),
|
||||
severity: severity.to_string(),
|
||||
code: error_code.to_string(),
|
||||
message: message.to_string(),
|
||||
detail: Some(detail_msg.to_string()),
|
||||
hint: None,
|
||||
position: None,
|
||||
internal_position: None,
|
||||
internal_query: None,
|
||||
where_context: None,
|
||||
schema_name: None,
|
||||
table_name: None,
|
||||
column_name: None,
|
||||
data_type_name: None,
|
||||
constraint_name: None,
|
||||
file_name: None,
|
||||
line: None,
|
||||
routine: None,
|
||||
},
|
||||
PgErrorMsg::parse(only_mandatory_msg).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@ use bytes::{Bytes, BytesMut};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::config::{get_config, Address, Role, User};
|
||||
use crate::pool::{ClientServerMap, ServerPool};
|
||||
use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool};
|
||||
use crate::stats::PoolStats;
|
||||
use log::{error, info, trace, warn};
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
|
||||
@@ -23,7 +24,7 @@ impl MirroredClient {
|
||||
async fn create_pool(&self) -> Pool<ServerPool> {
|
||||
let config = get_config();
|
||||
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
|
||||
let (connection_timeout, idle_timeout, _cfg) =
|
||||
let (connection_timeout, idle_timeout, cfg) =
|
||||
match config.pools.get(&self.address.pool_name) {
|
||||
Some(cfg) => (
|
||||
cfg.connect_timeout.unwrap_or(default),
|
||||
@@ -33,15 +34,15 @@ impl MirroredClient {
|
||||
None => (default, default, crate::config::Pool::default()),
|
||||
};
|
||||
|
||||
let identifier = PoolIdentifier::new(&self.database, &self.user.username);
|
||||
|
||||
let manager = ServerPool::new(
|
||||
self.address.clone(),
|
||||
self.user.clone(),
|
||||
self.database.as_str(),
|
||||
ClientServerMap::default(),
|
||||
Arc::new(PoolStats::new(identifier, cfg.clone())),
|
||||
Arc::new(RwLock::new(None)),
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
);
|
||||
|
||||
Pool::builder()
|
||||
@@ -79,7 +80,7 @@ impl MirroredClient {
|
||||
}
|
||||
|
||||
// Incoming data from server (we read to clear the socket buffer and discard the data)
|
||||
recv_result = server.recv(None) => {
|
||||
recv_result = server.recv() => {
|
||||
match recv_result {
|
||||
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
|
||||
Err(err) => {
|
||||
|
||||
80
src/multi_logger.rs
Normal file
80
src/multi_logger.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use log::{Level, Log, Metadata, Record, SetLoggerError};
|
||||
|
||||
// This is a special kind of logger that allows sending logs to different
|
||||
// targets depending on the log level.
|
||||
//
|
||||
// By default, if nothing is set, it acts as a regular env_log logger,
|
||||
// it sends everything to standard error.
|
||||
//
|
||||
// If the Env variable `STDOUT_LOG` is defined, it will be used for
|
||||
// configuring the standard out logger.
|
||||
//
|
||||
// The behavior is:
|
||||
// - If it is an error, the message is written to standard error.
|
||||
// - If it is not, and it matches the log level of the standard output logger (`STDOUT_LOG` env var), it will be send to standard output.
|
||||
// - If the above is not true, it is sent to the stderr logger that will log it or not depending on the value
|
||||
// of the RUST_LOG env var.
|
||||
//
|
||||
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
|
||||
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
//
|
||||
pub struct MultiLogger {
|
||||
stderr_logger: env_logger::Logger,
|
||||
stdout_logger: env_logger::Logger,
|
||||
}
|
||||
|
||||
impl MultiLogger {
|
||||
fn new() -> Self {
|
||||
let stderr_logger = env_logger::builder().format_timestamp_micros().build();
|
||||
let stdout_logger = env_logger::Builder::from_env("STDOUT_LOG")
|
||||
.format_timestamp_micros()
|
||||
.target(env_logger::Target::Stdout)
|
||||
.build();
|
||||
|
||||
Self {
|
||||
stderr_logger,
|
||||
stdout_logger,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init() -> Result<(), SetLoggerError> {
|
||||
let logger = Self::new();
|
||||
|
||||
log::set_max_level(logger.stderr_logger.filter());
|
||||
log::set_boxed_logger(Box::new(logger))
|
||||
}
|
||||
}
|
||||
|
||||
impl Log for MultiLogger {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
self.stderr_logger.enabled(metadata) && self.stdout_logger.enabled(metadata)
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if record.level() == Level::Error {
|
||||
self.stderr_logger.log(record);
|
||||
} else {
|
||||
if self.stdout_logger.matches(record) {
|
||||
self.stdout_logger.log(record);
|
||||
} else {
|
||||
self.stderr_logger.log(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
self.stderr_logger.flush();
|
||||
self.stdout_logger.flush();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_init() {
|
||||
MultiLogger::init().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -2,21 +2,52 @@
|
||||
//!
|
||||
//! It intercepts queries and returns fake results.
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlparser::ast::Statement;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use log::debug;
|
||||
use log::{debug, info};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
config::Intercept as InterceptConfig,
|
||||
errors::Error,
|
||||
messages::{command_complete, data_row_nullable, row_description, DataType},
|
||||
plugins::{Plugin, PluginOutput},
|
||||
pool::{PoolIdentifier, PoolMap},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
|
||||
|
||||
/// Check if the interceptor plugin has been enabled.
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
|
||||
let mut config = HashMap::new();
|
||||
for (identifier, _) in pools.iter() {
|
||||
let mut intercept_config = intercept_config.clone();
|
||||
intercept_config.substitute(&identifier.db, &identifier.user);
|
||||
config.insert(identifier.clone(), intercept_config);
|
||||
}
|
||||
|
||||
CONFIG.store(Arc::new(config));
|
||||
|
||||
info!("Intercepting {} queries", intercept_config.queries.len());
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(HashMap::new()));
|
||||
}
|
||||
|
||||
// TODO: use these structs for deserialization
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Rule {
|
||||
@@ -32,35 +63,33 @@ pub struct Column {
|
||||
}
|
||||
|
||||
/// The intercept plugin.
|
||||
pub struct Intercept<'a> {
|
||||
pub enabled: bool,
|
||||
pub config: &'a InterceptConfig,
|
||||
}
|
||||
pub struct Intercept;
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> Plugin for Intercept<'a> {
|
||||
impl Plugin for Intercept {
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
if !self.enabled || ast.is_empty() {
|
||||
if ast.is_empty() {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let mut config = self.config.clone();
|
||||
config.substitute(
|
||||
let mut result = BytesMut::new();
|
||||
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
|
||||
&query_router.pool_settings().db,
|
||||
&query_router.pool_settings().user.username,
|
||||
);
|
||||
|
||||
let mut result = BytesMut::new();
|
||||
)) {
|
||||
Some(query_map) => query_map.clone(),
|
||||
None => return Ok(PluginOutput::Allow),
|
||||
};
|
||||
|
||||
for q in ast {
|
||||
// Normalization
|
||||
let q = q.to_string().to_ascii_lowercase();
|
||||
|
||||
for (_, target) in config.queries.iter() {
|
||||
for (_, target) in query_map.queries.iter() {
|
||||
if target.query.as_str() == q {
|
||||
debug!("Intercepting query: {}", q);
|
||||
|
||||
@@ -118,3 +147,142 @@ impl<'a> Plugin for Intercept<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make IntelliJ SQL plugin believe it's talking to an actual database
|
||||
/// instead of PgCat.
|
||||
#[allow(dead_code)]
|
||||
fn fool_datagrip(database: &str, user: &str) -> Value {
|
||||
json!([
|
||||
{
|
||||
"query": "select current_database() as a, current_schemas(false) as b",
|
||||
"schema": [
|
||||
{
|
||||
"name": "a",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"data_type": "anyarray",
|
||||
},
|
||||
],
|
||||
|
||||
"result": [
|
||||
[database, "{public}"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select current_database(), current_schema(), current_user",
|
||||
"schema": [
|
||||
{
|
||||
"name": "current_database",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_schema",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_user",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
|
||||
"result": [
|
||||
["sharded_db", "public", "sharding_user"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
|
||||
"schema": [
|
||||
{
|
||||
"name": "id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_template",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "allow_connections",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "owner",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
"result": [
|
||||
["16387", database, "", "f", "t", user],
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
|
||||
"schema": [
|
||||
{
|
||||
"name": "role_id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "role_name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_super",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_inherit",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createrole",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createdb",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_login",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_replication",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "conn_limit",
|
||||
"data_type": "int4",
|
||||
},
|
||||
{
|
||||
"name": "valid_until",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "bypass_rls",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
],
|
||||
"result": [
|
||||
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
]
|
||||
}
|
||||
])
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
//!
|
||||
|
||||
pub mod intercept;
|
||||
pub mod prewarmer;
|
||||
pub mod query_logger;
|
||||
pub mod table_access;
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
//! Prewarm new connections before giving them to the client.
|
||||
use crate::{errors::Error, server::Server};
|
||||
use log::info;
|
||||
|
||||
pub struct Prewarmer<'a> {
|
||||
pub enabled: bool,
|
||||
pub server: &'a mut Server,
|
||||
pub queries: &'a Vec<String>,
|
||||
}
|
||||
|
||||
impl<'a> Prewarmer<'a> {
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
if !self.enabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for query in self.queries {
|
||||
info!(
|
||||
"{} Prewarning with query: `{}`",
|
||||
self.server.address(),
|
||||
query
|
||||
);
|
||||
self.server.query(&query).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -5,33 +5,44 @@ use crate::{
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use sqlparser::ast::Statement;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct QueryLogger<'a> {
|
||||
pub enabled: bool,
|
||||
pub user: &'a str,
|
||||
pub db: &'a str,
|
||||
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
|
||||
|
||||
pub struct QueryLogger;
|
||||
|
||||
pub fn setup() {
|
||||
ENABLED.store(Arc::new(true));
|
||||
|
||||
info!("Logging queries to stdout");
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
ENABLED.store(Arc::new(false));
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
**ENABLED.load()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> Plugin for QueryLogger<'a> {
|
||||
impl Plugin for QueryLogger {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
if !self.enabled {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let query = ast
|
||||
.iter()
|
||||
.map(|q| q.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join("; ");
|
||||
info!("[pool: {}][user: {}] {}", self.db, self.user, query);
|
||||
info!("{}", query);
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
|
||||
@@ -5,39 +5,53 @@ use async_trait::async_trait;
|
||||
use sqlparser::ast::{visit_relations, Statement};
|
||||
|
||||
use crate::{
|
||||
config::TableAccess as TableAccessConfig,
|
||||
errors::Error,
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
use log::debug;
|
||||
use log::{debug, info};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use core::ops::ControlFlow;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct TableAccess<'a> {
|
||||
pub enabled: bool,
|
||||
pub tables: &'a Vec<String>,
|
||||
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
|
||||
|
||||
pub fn setup(config: &TableAccessConfig) {
|
||||
CONFIG.store(Arc::new(config.tables.clone()));
|
||||
|
||||
info!("Blocking access to {} tables", config.tables.len());
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(vec![]));
|
||||
}
|
||||
|
||||
pub struct TableAccess;
|
||||
|
||||
#[async_trait]
|
||||
impl<'a> Plugin for TableAccess<'a> {
|
||||
impl Plugin for TableAccess {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
if !self.enabled {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let mut found = None;
|
||||
let forbidden_tables = CONFIG.load();
|
||||
|
||||
visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
|
||||
if self.tables.contains(&table_name.to_string()) {
|
||||
if forbidden_tables.contains(&table_name.to_string()) {
|
||||
found = Some(table_name.to_string());
|
||||
ControlFlow::<()>::Break(())
|
||||
} else {
|
||||
|
||||
292
src/pool.rs
292
src/pool.rs
@@ -1,6 +1,7 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
@@ -9,8 +10,6 @@ use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@@ -18,16 +17,13 @@ use std::sync::{
|
||||
use std::time::Instant;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::config::{
|
||||
get_config, Address, DefaultShard, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
|
||||
};
|
||||
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::plugins::prewarmer;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::server::Server;
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{AddressStats, ClientStats, ServerStats};
|
||||
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
|
||||
|
||||
pub type ProcessId = i32;
|
||||
pub type SecretKey = i32;
|
||||
@@ -77,12 +73,6 @@ impl PoolIdentifier {
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PoolIdentifier {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}@{}", self.user, self.db)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Address> for PoolIdentifier {
|
||||
fn from(address: &Address) -> PoolIdentifier {
|
||||
PoolIdentifier::new(&address.database, &address.username)
|
||||
@@ -111,12 +101,6 @@ pub struct PoolSettings {
|
||||
// Enable/disable query parser.
|
||||
pub query_parser_enabled: bool,
|
||||
|
||||
// Max length of query the parser will parse.
|
||||
pub query_parser_max_length: Option<usize>,
|
||||
|
||||
// Infer role
|
||||
pub query_parser_read_write_splitting: bool,
|
||||
|
||||
// Read from the primary as well or not.
|
||||
pub primary_reads_enabled: bool,
|
||||
|
||||
@@ -141,9 +125,6 @@ pub struct PoolSettings {
|
||||
// Regex for searching for the shard id in SQL statements
|
||||
pub shard_id_regex: Option<Regex>,
|
||||
|
||||
// What to do when no shard is selected in a sharded system
|
||||
pub default_shard: DefaultShard,
|
||||
|
||||
// Limit how much of each query is searched for a potential shard regex match
|
||||
pub regex_search_limit: usize,
|
||||
|
||||
@@ -151,9 +132,6 @@ pub struct PoolSettings {
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
/// Plugins
|
||||
pub plugins: Option<Plugins>,
|
||||
}
|
||||
|
||||
impl Default for PoolSettings {
|
||||
@@ -166,8 +144,6 @@ impl Default for PoolSettings {
|
||||
db: String::default(),
|
||||
default_role: None,
|
||||
query_parser_enabled: false,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: false,
|
||||
primary_reads_enabled: true,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -177,11 +153,9 @@ impl Default for PoolSettings {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: 1000,
|
||||
default_shard: DefaultShard::Shard(0),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
plugins: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,10 +174,10 @@ pub struct ConnectionPool {
|
||||
/// that should not be queried.
|
||||
banlist: BanList,
|
||||
|
||||
/// The server information has to be passed to the
|
||||
/// The server information (K messages) have to be passed to the
|
||||
/// clients on startup. We pre-connect to all shards and replicas
|
||||
/// on pool creation and save the startup parameters here.
|
||||
original_server_parameters: Arc<RwLock<ServerParameters>>,
|
||||
/// on pool creation and save the K messages here.
|
||||
server_info: Arc<RwLock<BytesMut>>,
|
||||
|
||||
/// Pool configuration.
|
||||
pub settings: PoolSettings,
|
||||
@@ -221,6 +195,8 @@ pub struct ConnectionPool {
|
||||
paused: Arc<AtomicBool>,
|
||||
paused_waiter: Arc<Notify>,
|
||||
|
||||
pub stats: Arc<PoolStats>,
|
||||
|
||||
/// AuthInfo
|
||||
pub auth_hash: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
@@ -270,6 +246,10 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
@@ -304,7 +284,6 @@ impl ConnectionPool {
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
});
|
||||
address_id += 1;
|
||||
}
|
||||
@@ -323,7 +302,6 @@ impl ConnectionPool {
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
@@ -372,13 +350,8 @@ impl ConnectionPool {
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
match pool_config.plugins {
|
||||
Some(ref plugins) => Some(plugins.clone()),
|
||||
None => config.plugins.clone(),
|
||||
},
|
||||
pool_config.cleanup_server_connections,
|
||||
pool_config.log_client_parameter_status_changes,
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
@@ -404,15 +377,7 @@ impl ConnectionPool {
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
let queue_strategy = match config.general.server_round_robin {
|
||||
true => QueueStrategy::Fifo,
|
||||
false => QueueStrategy::Lifo,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[pool: {}][user: {}] Pool reaper rate: {}ms",
|
||||
pool_name, user.username, reaper_rate
|
||||
);
|
||||
debug!("Pool reaper rate: {}ms", reaper_rate);
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
@@ -421,14 +386,9 @@ impl ConnectionPool {
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
|
||||
.queue_strategy(queue_strategy)
|
||||
.test_on_check_out(false);
|
||||
|
||||
let pool = if config.general.validate_config {
|
||||
pool.build(manager).await?
|
||||
} else {
|
||||
pool.build_unchecked(manager)
|
||||
};
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await?;
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
@@ -449,10 +409,11 @@ impl ConnectionPool {
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: match user.pool_mode {
|
||||
@@ -471,9 +432,6 @@ impl ConnectionPool {
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
query_parser_max_length: pool_config.query_parser_max_length,
|
||||
query_parser_read_write_splitting: pool_config
|
||||
.query_parser_read_write_splitting,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
@@ -489,14 +447,9 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
default_shard: pool_config.default_shard.clone(),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
plugins: match pool_config.plugins {
|
||||
Some(ref plugins) => Some(plugins.clone()),
|
||||
None => config.plugins.clone(),
|
||||
},
|
||||
},
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
@@ -506,18 +459,42 @@ impl ConnectionPool {
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
if config.general.validate_config {
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
}
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref plugins) = config.plugins {
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
if intercept.enabled {
|
||||
crate::plugins::intercept::setup(intercept, &new_pools);
|
||||
} else {
|
||||
crate::plugins::intercept::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref table_access) = plugins.table_access {
|
||||
if table_access.enabled {
|
||||
crate::plugins::table_access::setup(table_access);
|
||||
} else {
|
||||
crate::plugins::table_access::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref query_logger) = plugins.query_logger {
|
||||
if query_logger.enabled {
|
||||
crate::plugins::query_logger::setup();
|
||||
} else {
|
||||
crate::plugins::query_logger::disable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
POOLS.store(Arc::new(new_pools.clone()));
|
||||
Ok(())
|
||||
}
|
||||
@@ -536,7 +513,7 @@ impl ConnectionPool {
|
||||
for server in 0..self.servers(shard) {
|
||||
let databases = self.databases.clone();
|
||||
let validated = Arc::clone(&validated);
|
||||
let pool_server_parameters = Arc::clone(&self.original_server_parameters);
|
||||
let pool_server_info = Arc::clone(&self.server_info);
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
let connection = match databases[shard][server].get().await {
|
||||
@@ -549,10 +526,11 @@ impl ConnectionPool {
|
||||
|
||||
let proxy = connection;
|
||||
let server = &*proxy;
|
||||
let server_parameters: ServerParameters = server.server_parameters();
|
||||
let server_info = server.server_info();
|
||||
|
||||
let mut guard = pool_server_parameters.write();
|
||||
*guard = server_parameters;
|
||||
let mut guard = pool_server_info.write();
|
||||
guard.clear();
|
||||
guard.put(server_info.clone());
|
||||
validated.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
@@ -564,7 +542,7 @@ impl ConnectionPool {
|
||||
|
||||
// TODO: compare server information to make sure
|
||||
// all shards are running identical configurations.
|
||||
if !self.validated() {
|
||||
if self.server_info.read().is_empty() {
|
||||
error!("Could not validate connection pool");
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
@@ -611,51 +589,19 @@ impl ConnectionPool {
|
||||
/// Get a connection from the pool.
|
||||
pub async fn get(
|
||||
&self,
|
||||
shard: Option<usize>, // shard number
|
||||
shard: usize, // shard number
|
||||
role: Option<Role>, // primary or replica
|
||||
client_stats: &ClientStats, // client id
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
let effective_shard_id = if self.shards() == 1 {
|
||||
// The base, unsharded case
|
||||
Some(0)
|
||||
} else {
|
||||
if !self.valid_shard_id(shard) {
|
||||
// None is valid shard ID so it is safe to unwrap here
|
||||
return Err(Error::InvalidShardId(shard.unwrap()));
|
||||
}
|
||||
shard
|
||||
};
|
||||
|
||||
let mut candidates = self
|
||||
.addresses
|
||||
let mut candidates: Vec<&Address> = self.addresses[shard]
|
||||
.iter()
|
||||
.flatten()
|
||||
.filter(|address| address.role == role)
|
||||
.collect::<Vec<&Address>>();
|
||||
.collect();
|
||||
|
||||
// We start with a shuffled list of addresses even if we end up resorting
|
||||
// this is meant to avoid hitting instance 0 everytime if the sorting metric
|
||||
// ends up being the same for all instances
|
||||
// We shuffle even if least_outstanding_queries is used to avoid imbalance
|
||||
// in cases where all candidates have more or less the same number of outstanding
|
||||
// queries
|
||||
candidates.shuffle(&mut thread_rng());
|
||||
|
||||
match effective_shard_id {
|
||||
Some(shard_id) => candidates.retain(|address| address.shard == shard_id),
|
||||
None => match self.settings.default_shard {
|
||||
DefaultShard::Shard(shard_id) => {
|
||||
candidates.retain(|address| address.shard == shard_id)
|
||||
}
|
||||
DefaultShard::Random => (),
|
||||
DefaultShard::RandomHealthy => {
|
||||
candidates.sort_by(|a, b| {
|
||||
b.error_count
|
||||
.load(Ordering::Relaxed)
|
||||
.partial_cmp(&a.error_count.load(Ordering::Relaxed))
|
||||
.unwrap()
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections {
|
||||
candidates.sort_by(|a, b| {
|
||||
self.busy_connection_count(b)
|
||||
@@ -664,10 +610,6 @@ impl ConnectionPool {
|
||||
});
|
||||
}
|
||||
|
||||
// Indicate we're waiting on a server connection from a pool.
|
||||
let now = Instant::now();
|
||||
client_stats.waiting();
|
||||
|
||||
while !candidates.is_empty() {
|
||||
// Get the next candidate
|
||||
let address = match candidates.pop() {
|
||||
@@ -686,20 +628,18 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
// Indicate we're waiting on a server connection from a pool.
|
||||
let now = Instant::now();
|
||||
client_stats.waiting();
|
||||
|
||||
// Check if we can connect
|
||||
let mut conn = match self.databases[address.shard][address.address_index]
|
||||
.get()
|
||||
.await
|
||||
{
|
||||
Ok(conn) => {
|
||||
address.reset_error_count();
|
||||
conn
|
||||
}
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Connection checkout error for instance {:?}, error: {:?}",
|
||||
address, err
|
||||
);
|
||||
error!("Banning instance {:?}, error: {:?}", address, err);
|
||||
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
|
||||
address.stats.error();
|
||||
client_stats.idle();
|
||||
@@ -720,13 +660,13 @@ impl ConnectionPool {
|
||||
// since we last checked the server is ok.
|
||||
// Health checks are pretty expensive.
|
||||
if !require_healthcheck {
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
let checkout_time: u64 = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_stats.application_name());
|
||||
server.stats().active(client_stats.application_name());
|
||||
client_stats.active();
|
||||
|
||||
return Ok((conn, address.clone()));
|
||||
}
|
||||
|
||||
@@ -734,24 +674,11 @@ impl ConnectionPool {
|
||||
.run_health_check(address, server, now, client_stats)
|
||||
.await
|
||||
{
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_stats.application_name());
|
||||
server.stats().active(client_stats.application_name());
|
||||
client_stats.active();
|
||||
return Ok((conn, address.clone()));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
client_stats.idle();
|
||||
|
||||
let checkout_time = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
|
||||
Err(Error::AllServersDown)
|
||||
}
|
||||
|
||||
@@ -788,7 +715,7 @@ impl ConnectionPool {
|
||||
// Health check failed.
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Failed health check on instance {:?}, error: {:?}",
|
||||
"Banning instance {:?} because of failed health check, {:?}",
|
||||
address, err
|
||||
);
|
||||
}
|
||||
@@ -797,7 +724,7 @@ impl ConnectionPool {
|
||||
// Health check timed out.
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Health check timeout on instance {:?}, error: {:?}",
|
||||
"Banning instance {:?} because of health check timeout, {:?}",
|
||||
address, err
|
||||
);
|
||||
}
|
||||
@@ -814,33 +741,18 @@ impl ConnectionPool {
|
||||
/// traffic for any new transactions. Existing transactions on that replica
|
||||
/// will finish successfully or error out to the clients.
|
||||
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
|
||||
// Count the number of errors since the last successful checkout
|
||||
// This is used to determine if the shard is down
|
||||
match reason {
|
||||
BanReason::FailedHealthCheck
|
||||
| BanReason::FailedCheckout
|
||||
| BanReason::MessageSendFailed
|
||||
| BanReason::MessageReceiveFailed => {
|
||||
address.increment_error_count();
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Primary can never be banned
|
||||
if address.role == Role::Primary {
|
||||
return;
|
||||
}
|
||||
|
||||
error!("Banning instance {:?}, reason: {:?}", address, reason);
|
||||
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let mut guard = self.banlist.write();
|
||||
|
||||
error!("Banning {:?}", address);
|
||||
if let Some(client_info) = client_info {
|
||||
client_info.ban_error();
|
||||
address.stats.error();
|
||||
}
|
||||
|
||||
guard[address.shard].insert(address.clone(), (reason, now));
|
||||
}
|
||||
|
||||
@@ -976,11 +888,10 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.original_server_parameters.read().clone()
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
}
|
||||
|
||||
/// Get the number of checked out connection for an address
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
let state = self.pool_state(address.shard, address.address_index);
|
||||
let idle = state.idle_connections;
|
||||
@@ -994,40 +905,16 @@ impl ConnectionPool {
|
||||
debug!("{:?} has {:?} busy connections", address, busy);
|
||||
return busy;
|
||||
}
|
||||
|
||||
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
|
||||
match shard {
|
||||
None => true,
|
||||
Some(shard) => shard < self.shards(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for the bb8 connection pool.
|
||||
pub struct ServerPool {
|
||||
/// Server address.
|
||||
address: Address,
|
||||
|
||||
/// Server Postgres user.
|
||||
user: User,
|
||||
|
||||
/// Server database.
|
||||
database: String,
|
||||
|
||||
/// Client/server mapping.
|
||||
client_server_map: ClientServerMap,
|
||||
|
||||
/// Server auth hash (for auth passthrough).
|
||||
stats: Arc<PoolStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
|
||||
/// Server plugins.
|
||||
plugins: Option<Plugins>,
|
||||
|
||||
/// Should we clean up dirty connections before putting them into the pool?
|
||||
cleanup_connections: bool,
|
||||
|
||||
/// Log client parameter status changes
|
||||
log_client_parameter_status_changes: bool,
|
||||
}
|
||||
|
||||
impl ServerPool {
|
||||
@@ -1036,20 +923,16 @@ impl ServerPool {
|
||||
user: User,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Arc<PoolStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
plugins: Option<Plugins>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
) -> ServerPool {
|
||||
ServerPool {
|
||||
address,
|
||||
user: user.clone(),
|
||||
database: database.to_string(),
|
||||
client_server_map,
|
||||
stats,
|
||||
auth_hash,
|
||||
plugins,
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1065,6 +948,7 @@ impl ManageConnection for ServerPool {
|
||||
|
||||
let stats = Arc::new(ServerStats::new(
|
||||
self.address.clone(),
|
||||
self.stats.clone(),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
|
||||
@@ -1078,24 +962,10 @@ impl ManageConnection for ServerPool {
|
||||
self.client_server_map.clone(),
|
||||
stats.clone(),
|
||||
self.auth_hash.clone(),
|
||||
self.cleanup_connections,
|
||||
self.log_client_parameter_status_changes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mut conn) => {
|
||||
if let Some(ref plugins) = self.plugins {
|
||||
if let Some(ref prewarmer) = plugins.prewarmer {
|
||||
let mut prewarmer = prewarmer::Prewarmer {
|
||||
enabled: prewarmer.enabled,
|
||||
server: &mut conn,
|
||||
queries: &prewarmer.queries,
|
||||
};
|
||||
|
||||
prewarmer.run().await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(conn) => {
|
||||
stats.idle();
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use log::{debug, error, info};
|
||||
use log::{error, info, warn};
|
||||
use phf::phf_map;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
@@ -9,9 +9,8 @@ use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::Address;
|
||||
use crate::pool::{get_all_pools, PoolIdentifier};
|
||||
use crate::stats::pool::PoolStats;
|
||||
use crate::stats::{get_server_stats, ServerStats};
|
||||
use crate::pool::get_all_pools;
|
||||
use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
|
||||
|
||||
struct MetricHelpType {
|
||||
help: &'static str,
|
||||
@@ -234,10 +233,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
Self::from_name(&format!("stats_{}", name), value, labels)
|
||||
}
|
||||
|
||||
fn from_pool(pool_id: PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
fn from_pool(pool: &(String, String), name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("pool", pool_id.db);
|
||||
labels.insert("user", pool_id.user);
|
||||
labels.insert("pool", pool.0.clone());
|
||||
labels.insert("user", pool.1.clone());
|
||||
|
||||
Self::from_name(&format!("pools_{}", name), value, labels)
|
||||
}
|
||||
@@ -275,7 +274,7 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,15 +284,18 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
|
||||
// Adds relevant metrics shown in a SHOW POOLS admin command.
|
||||
fn push_pool_stats(lines: &mut Vec<String>) {
|
||||
let pool_stats = PoolStats::construct_pool_lookup();
|
||||
for (pool_id, stats) in pool_stats.iter() {
|
||||
let pool_stats = get_pool_stats();
|
||||
for (pool, stats) in pool_stats.iter() {
|
||||
let stats = &**stats;
|
||||
for (name, value) in stats.clone() {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
|
||||
if let Some(prometheus_metric) = PrometheusMetric::<u64>::from_pool(pool, &name, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
debug!("Metric {} not implemented for ({})", name, *pool_id);
|
||||
warn!(
|
||||
"Metric {} not implemented for ({},{})",
|
||||
name, pool.0, pool.1
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -318,7 +320,7 @@ fn push_database_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -364,7 +366,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,13 +15,16 @@ use sqlparser::parser::Parser;
|
||||
use crate::config::Role;
|
||||
use crate::errors::Error;
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
|
||||
use crate::plugins::{
|
||||
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
|
||||
TableAccess,
|
||||
};
|
||||
use crate::pool::PoolSettings;
|
||||
use crate::sharding::Sharder;
|
||||
|
||||
use std::cmp;
|
||||
use std::collections::BTreeSet;
|
||||
use std::io::Cursor;
|
||||
use std::{cmp, mem};
|
||||
|
||||
/// Regexes used to parse custom commands.
|
||||
const CUSTOM_SQL_REGEXES: [&str; 7] = [
|
||||
@@ -141,24 +144,18 @@ impl QueryRouter {
|
||||
let mut message_cursor = Cursor::new(message_buffer);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
|
||||
let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some()
|
||||
|| self.pool_settings.sharding_key_regex.is_some();
|
||||
|
||||
// Check for any sharding regex matches in any queries
|
||||
if comment_shard_routing_enabled {
|
||||
match code as char {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
match code as char {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
if self.pool_settings.shard_id_regex.is_some()
|
||||
|| self.pool_settings.sharding_key_regex.is_some()
|
||||
{
|
||||
// Check only the first block of bytes configured by the pool settings
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
|
||||
|
||||
let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();
|
||||
|
||||
let initial_segment = String::from_utf8_lossy(
|
||||
&message_buffer[query_start_index..query_start_index + seg],
|
||||
);
|
||||
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);
|
||||
|
||||
// Check for a shard_id included in the query
|
||||
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
|
||||
@@ -167,7 +164,7 @@ impl QueryRouter {
|
||||
});
|
||||
if let Some(shard_id) = shard_id {
|
||||
debug!("Setting shard to {:?}", shard_id);
|
||||
self.set_shard(Some(shard_id));
|
||||
self.set_shard(shard_id);
|
||||
// Skip other command processing since a sharding command was found
|
||||
return None;
|
||||
}
|
||||
@@ -189,8 +186,8 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Only simple protocol supported for commands processed below
|
||||
@@ -198,6 +195,7 @@ impl QueryRouter {
|
||||
return None;
|
||||
}
|
||||
|
||||
let _len = message_cursor.get_i32() as usize;
|
||||
let query = message_cursor.read_string().unwrap();
|
||||
|
||||
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
|
||||
@@ -249,9 +247,7 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
Command::ShowShard => self
|
||||
.shard()
|
||||
.map_or_else(|| "unset".to_string(), |x| x.to_string()),
|
||||
Command::ShowShard => self.shard().to_string(),
|
||||
Command::ShowServerRole => match self.active_role {
|
||||
Some(Role::Primary) => Role::Primary.to_string(),
|
||||
Some(Role::Replica) => Role::Replica.to_string(),
|
||||
@@ -338,23 +334,11 @@ impl QueryRouter {
|
||||
Some((command, value))
|
||||
}
|
||||
|
||||
pub fn parse(&self, message: &BytesMut) -> Result<Vec<Statement>, Error> {
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
let len = message_cursor.get_i32() as usize;
|
||||
|
||||
match self.pool_settings.query_parser_max_length {
|
||||
Some(max_length) => {
|
||||
if len > max_length {
|
||||
return Err(Error::QueryRouterParserError(format!(
|
||||
"Query too long for parser: {} > {}",
|
||||
len, max_length
|
||||
)));
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
let _len = message_cursor.get_i32() as usize;
|
||||
|
||||
let query = match code {
|
||||
// Query
|
||||
@@ -367,13 +351,12 @@ impl QueryRouter {
|
||||
// Parse (prepared statement)
|
||||
'P' => {
|
||||
// Reads statement name
|
||||
let _name = message_cursor.read_string().unwrap();
|
||||
message_cursor.read_string().unwrap();
|
||||
|
||||
// Reads query string
|
||||
let query = message_cursor.read_string().unwrap();
|
||||
|
||||
debug!("Prepared statement: '{}'", query);
|
||||
|
||||
query
|
||||
}
|
||||
|
||||
@@ -391,10 +374,6 @@ impl QueryRouter {
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
||||
if !self.pool_settings.query_parser_read_write_splitting {
|
||||
return Ok(()); // Nothing to do
|
||||
}
|
||||
|
||||
debug!("Inferring role");
|
||||
|
||||
if ast.is_empty() {
|
||||
@@ -456,10 +435,6 @@ impl QueryRouter {
|
||||
/// N.B.: Only supports anonymous prepared statements since we don't
|
||||
/// keep a cache of them in PgCat.
|
||||
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
|
||||
if !self.pool_settings.query_parser_read_write_splitting {
|
||||
return false; // Nothing to do
|
||||
}
|
||||
|
||||
debug!("Parsing bind message");
|
||||
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
@@ -584,7 +559,7 @@ impl QueryRouter {
|
||||
// TODO: Support multi-shard queries some day.
|
||||
if shards.len() == 1 {
|
||||
debug!("Found one sharding key");
|
||||
self.set_shard(Some(*shards.first().unwrap()));
|
||||
self.set_shard(*shards.first().unwrap());
|
||||
true
|
||||
} else {
|
||||
debug!("Found no sharding keys");
|
||||
@@ -818,27 +793,13 @@ impl QueryRouter {
|
||||
|
||||
/// Add your plugins here and execute them.
|
||||
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
|
||||
let plugins = match self.pool_settings.plugins {
|
||||
Some(ref plugins) => plugins,
|
||||
None => return Ok(PluginOutput::Allow),
|
||||
};
|
||||
|
||||
if let Some(ref query_logger) = plugins.query_logger {
|
||||
let mut query_logger = QueryLogger {
|
||||
enabled: query_logger.enabled,
|
||||
user: &self.pool_settings.user.username,
|
||||
db: &self.pool_settings.db,
|
||||
};
|
||||
|
||||
if query_logger::enabled() {
|
||||
let mut query_logger = QueryLogger {};
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
}
|
||||
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
let mut intercept = Intercept {
|
||||
enabled: intercept.enabled,
|
||||
config: &intercept,
|
||||
};
|
||||
|
||||
if intercept::enabled() {
|
||||
let mut intercept = Intercept {};
|
||||
let result = intercept.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
@@ -846,12 +807,8 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref table_access) = plugins.table_access {
|
||||
let mut table_access = TableAccess {
|
||||
enabled: table_access.enabled,
|
||||
tables: &table_access.tables,
|
||||
};
|
||||
|
||||
if table_access::enabled() {
|
||||
let mut table_access = TableAccess {};
|
||||
let result = table_access.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
@@ -868,7 +825,7 @@ impl QueryRouter {
|
||||
self.pool_settings.sharding_function,
|
||||
);
|
||||
let shard = sharder.shard(sharding_key);
|
||||
self.set_shard(Some(shard));
|
||||
self.set_shard(shard);
|
||||
self.active_shard
|
||||
}
|
||||
|
||||
@@ -878,12 +835,12 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
/// Get desired shard we should be talking to.
|
||||
pub fn shard(&self) -> Option<usize> {
|
||||
self.active_shard
|
||||
pub fn shard(&self) -> usize {
|
||||
self.active_shard.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn set_shard(&mut self, shard: Option<usize>) {
|
||||
self.active_shard = shard;
|
||||
pub fn set_shard(&mut self, shard: usize) {
|
||||
self.active_shard = Some(shard);
|
||||
}
|
||||
|
||||
/// Should we attempt to parse queries?
|
||||
@@ -937,7 +894,6 @@ mod test {
|
||||
fn test_infer_replica() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
||||
assert!(qr.query_parser_enabled());
|
||||
|
||||
@@ -953,7 +909,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
}
|
||||
@@ -962,7 +918,6 @@ mod test {
|
||||
fn test_infer_primary() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
let queries = vec![
|
||||
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
|
||||
@@ -973,7 +928,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
}
|
||||
}
|
||||
@@ -985,7 +940,7 @@ mod test {
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
}
|
||||
|
||||
@@ -993,8 +948,6 @@ mod test {
|
||||
fn test_infer_parse_prepared() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
@@ -1007,7 +960,7 @@ mod test {
|
||||
res.put(prepared_stmt);
|
||||
res.put_i16(0);
|
||||
|
||||
assert!(qr.infer(&qr.parse(&res).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
|
||||
@@ -1093,7 +1046,7 @@ mod test {
|
||||
qr.try_execute_command(&query),
|
||||
Some((Command::SetShardingKey, String::from("0")))
|
||||
);
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
// SetShard
|
||||
let query = simple_query("SET SHARD TO '1'");
|
||||
@@ -1101,7 +1054,7 @@ mod test {
|
||||
qr.try_execute_command(&query),
|
||||
Some((Command::SetShard, String::from("1")))
|
||||
);
|
||||
assert_eq!(qr.shard().unwrap(), 1);
|
||||
assert_eq!(qr.shard(), 1);
|
||||
|
||||
// ShowShard
|
||||
let query = simple_query("SHOW SHARD");
|
||||
@@ -1163,8 +1116,6 @@ mod test {
|
||||
fn test_enable_query_parser() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
let query = simple_query("SET SERVER ROLE TO 'auto'");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
@@ -1173,11 +1124,11 @@ mod test {
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
let query = simple_query("INSERT INTO test_table VALUES (1)");
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
|
||||
let query = simple_query("SELECT * FROM test_table");
|
||||
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
@@ -1197,8 +1148,6 @@ mod test {
|
||||
user: crate::config::User::default(),
|
||||
default_role: Some(Role::Replica),
|
||||
query_parser_enabled: true,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: true,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: Some(String::from("test.id")),
|
||||
@@ -1207,13 +1156,11 @@ mod test {
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
plugins: None,
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
assert_eq!(qr.active_role, None);
|
||||
@@ -1244,18 +1191,18 @@ mod test {
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr
|
||||
.infer(&qr.parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
|
||||
assert!(qr
|
||||
.infer(&qr.parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Replica);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
))
|
||||
.unwrap()
|
||||
@@ -1275,8 +1222,6 @@ mod test {
|
||||
user: crate::config::User::default(),
|
||||
default_role: Some(Role::Replica),
|
||||
query_parser_enabled: true,
|
||||
query_parser_max_length: None,
|
||||
query_parser_read_write_splitting: true,
|
||||
primary_reads_enabled: false,
|
||||
sharding_function: ShardingFunction::PgBigintHash,
|
||||
automatic_sharding_key: None,
|
||||
@@ -1285,26 +1230,18 @@ mod test {
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
|
||||
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
plugins: None,
|
||||
};
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings.clone());
|
||||
|
||||
// Shard should start out unset
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Don't panic when short query eg. ; is sent
|
||||
let q0 = simple_query(";");
|
||||
assert!(qr.try_execute_command(&q0) == None);
|
||||
assert_eq!(qr.active_shard, None);
|
||||
|
||||
// Make sure setting it works
|
||||
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q1) == None);
|
||||
@@ -1328,29 +1265,25 @@ mod test {
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query("SELECT * FROM data WHERE id = 5"))
|
||||
.unwrap(),
|
||||
)
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM data
|
||||
INNER JOIN t2 ON data.id = 5
|
||||
AND t2.data_id = data.id
|
||||
@@ -1359,59 +1292,59 @@ mod test {
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
|
||||
// in the query.
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Super unique sharding key
|
||||
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr
|
||||
.infer(
|
||||
&qr.parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard().unwrap(), 0);
|
||||
assert_eq!(qr.shard(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1433,40 +1366,33 @@ mod test {
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
|
||||
assert!(qr.infer(&qr.parse(&simple_query(stmt)).unwrap()).is_ok());
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.placeholders.len(), 1);
|
||||
|
||||
assert!(qr.infer_shard_from_bind(&bind));
|
||||
assert_eq!(qr.shard().unwrap(), 2);
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert!(qr.placeholders.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_access_plugin() {
|
||||
use crate::config::{Plugins, TableAccess};
|
||||
let table_access = TableAccess {
|
||||
use crate::config::TableAccess;
|
||||
let ta = TableAccess {
|
||||
enabled: true,
|
||||
tables: vec![String::from("pg_database")],
|
||||
};
|
||||
let plugins = Plugins {
|
||||
table_access: Some(table_access),
|
||||
intercept: None,
|
||||
query_logger: None,
|
||||
prewarmer: None,
|
||||
};
|
||||
|
||||
crate::plugins::table_access::setup(&ta);
|
||||
|
||||
QueryRouter::setup();
|
||||
let mut pool_settings = PoolSettings::default();
|
||||
pool_settings.query_parser_enabled = true;
|
||||
pool_settings.plugins = Some(plugins);
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings);
|
||||
let qr = QueryRouter::new();
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
let ast = QueryRouter::parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
@@ -1477,17 +1403,4 @@ mod test {
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plugins_disabled_by_defaault() {
|
||||
QueryRouter::setup();
|
||||
let qr = QueryRouter::new();
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
assert_eq!(res, Ok(PluginOutput::Allow));
|
||||
}
|
||||
}
|
||||
|
||||
510
src/server.rs
510
src/server.rs
@@ -3,11 +3,10 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::mem;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
@@ -16,11 +15,10 @@ use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
@@ -105,166 +103,6 @@ impl StreamInner {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct CleanupState {
|
||||
/// If server connection requires RESET ALL before checkin because of set statement
|
||||
needs_cleanup_set: bool,
|
||||
|
||||
/// If server connection requires DEALLOCATE ALL before checkin because of prepare statement
|
||||
needs_cleanup_prepare: bool,
|
||||
}
|
||||
|
||||
impl CleanupState {
|
||||
fn new() -> Self {
|
||||
CleanupState {
|
||||
needs_cleanup_set: false,
|
||||
needs_cleanup_prepare: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn needs_cleanup(&self) -> bool {
|
||||
self.needs_cleanup_set || self.needs_cleanup_prepare
|
||||
}
|
||||
|
||||
fn set_true(&mut self) {
|
||||
self.needs_cleanup_set = true;
|
||||
self.needs_cleanup_prepare = true;
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.needs_cleanup_set = false;
|
||||
self.needs_cleanup_prepare = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CleanupState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"SET: {}, PREPARE: {}",
|
||||
self.needs_cleanup_set, self.needs_cleanup_prepare
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
|
||||
let mut set = HashSet::new();
|
||||
set.insert("client_encoding".to_string());
|
||||
set.insert("DateStyle".to_string());
|
||||
set.insert("TimeZone".to_string());
|
||||
set.insert("standard_conforming_strings".to_string());
|
||||
set.insert("application_name".to_string());
|
||||
set
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerParameters {
|
||||
parameters: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for ServerParameters {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameters {
|
||||
pub fn new() -> Self {
|
||||
let mut server_parameters = ServerParameters {
|
||||
parameters: HashMap::new(),
|
||||
};
|
||||
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);
|
||||
server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false);
|
||||
server_parameters.set_param(
|
||||
"standard_conforming_strings".to_string(),
|
||||
"on".to_string(),
|
||||
false,
|
||||
);
|
||||
server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false);
|
||||
|
||||
server_parameters
|
||||
}
|
||||
|
||||
/// returns true if a tracked parameter was set, false if it was a non-tracked parameter
|
||||
/// if startup is false, then then only tracked parameters will be set
|
||||
pub fn set_param(&mut self, mut key: String, value: String, startup: bool) {
|
||||
// The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys
|
||||
if key == "timezone" {
|
||||
key = "TimeZone".to_string();
|
||||
} else if key == "datestyle" {
|
||||
key = "DateStyle".to_string();
|
||||
};
|
||||
|
||||
if TRACKED_PARAMETERS.contains(&key) {
|
||||
self.parameters.insert(key, value);
|
||||
} else {
|
||||
if startup {
|
||||
self.parameters.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
|
||||
// iterate through each and call set_param
|
||||
for (key, value) in parameters {
|
||||
self.set_param(key.to_string(), value.to_string(), startup);
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the diff of the parameters
|
||||
fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap<String, String> {
|
||||
let mut diff = HashMap::new();
|
||||
|
||||
// iterate through tracked parameters
|
||||
for key in TRACKED_PARAMETERS.iter() {
|
||||
if let Some(incoming_value) = incoming_parameters.parameters.get(key) {
|
||||
if let Some(value) = self.parameters.get(key) {
|
||||
if value != incoming_value {
|
||||
diff.insert(key.to_string(), incoming_value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
diff
|
||||
}
|
||||
|
||||
pub fn get_application_name(&self) -> &String {
|
||||
// Can unwrap because we set it in the constructor
|
||||
self.parameters.get("application_name").unwrap()
|
||||
}
|
||||
|
||||
fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) {
|
||||
buffer.put_u8(b'S');
|
||||
|
||||
// 4 is len of i32, the plus for the null terminator
|
||||
let len = 4 + key.len() + 1 + value.len() + 1;
|
||||
|
||||
buffer.put_i32(len as i32);
|
||||
|
||||
buffer.put_slice(key.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
buffer.put_slice(value.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ServerParameters> for BytesMut {
|
||||
fn from(server_parameters: &ServerParameters) -> Self {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
for (key, value) in &server_parameters.parameters {
|
||||
ServerParameters::add_parameter_message(key, value, &mut bytes);
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn compare
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
/// Server host, e.g. localhost,
|
||||
@@ -278,7 +116,7 @@ pub struct Server {
|
||||
buffer: BytesMut,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_parameters: ServerParameters,
|
||||
server_info: BytesMut,
|
||||
|
||||
/// Backend id and secret key used for query cancellation.
|
||||
process_id: i32,
|
||||
@@ -290,14 +128,11 @@ pub struct Server {
|
||||
/// Is there more data for the client to read.
|
||||
data_available: bool,
|
||||
|
||||
/// Is the server in copy-in or copy-out modes
|
||||
in_copy_mode: bool,
|
||||
|
||||
/// Is the server broken? We'll remote it from the pool if so.
|
||||
bad: bool,
|
||||
|
||||
/// If server connection requires reset statements before checkin
|
||||
cleanup_state: CleanupState,
|
||||
/// If server connection requires a DISCARD ALL before checkin
|
||||
needs_cleanup: bool,
|
||||
|
||||
/// Mapping of clients and servers used for query cancellation.
|
||||
client_server_map: ClientServerMap,
|
||||
@@ -311,22 +146,13 @@ pub struct Server {
|
||||
/// Application name using the server at the moment.
|
||||
application_name: String,
|
||||
|
||||
/// Last time that a successful server send or response happened
|
||||
// Last time that a successful server send or response happened
|
||||
last_activity: SystemTime,
|
||||
|
||||
mirror_manager: Option<MirroringManager>,
|
||||
|
||||
/// Associated addresses used
|
||||
// Associated addresses used
|
||||
addr_set: Option<AddrSet>,
|
||||
|
||||
/// Should clean up dirty connections?
|
||||
cleanup_connections: bool,
|
||||
|
||||
/// Log client parameter status changes
|
||||
log_client_parameter_status_changes: bool,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: BTreeSet<String>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@@ -339,8 +165,6 @@ impl Server {
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Arc<ServerStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
cleanup_connections: bool,
|
||||
log_client_parameter_status_changes: bool,
|
||||
) -> Result<Server, Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let mut addr_set: Option<AddrSet> = None;
|
||||
@@ -471,6 +295,7 @@ impl Server {
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
@@ -482,8 +307,6 @@ impl Server {
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
@@ -713,7 +536,8 @@ impl Server {
|
||||
|
||||
// An error message will be present.
|
||||
_ => {
|
||||
let mut error = vec![0u8; len as usize];
|
||||
// Read the error message without the terminating null character.
|
||||
let mut error = vec![0u8; len as usize - 4 - 1];
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
@@ -725,14 +549,10 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let fields = match PgErrorMsg::parse(error) {
|
||||
Ok(f) => f,
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
trace!("error fields: {}", &fields);
|
||||
error!("server error: {}: {}", fields.severity, fields.message);
|
||||
// TODO: the error message contains multiple fields; we can decode them and
|
||||
// present a prettier message to the user.
|
||||
// See: https://www.postgresql.org/docs/12/protocol-error-fields.html
|
||||
error!("Server error: {}", String::from_utf8_lossy(&error));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -741,10 +561,9 @@ impl Server {
|
||||
|
||||
// ParameterStatus
|
||||
'S' => {
|
||||
let mut bytes = BytesMut::with_capacity(len as usize - 4);
|
||||
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');
|
||||
let mut param = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut bytes[..]).await {
|
||||
match stream.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -754,13 +573,12 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let key = bytes.read_string().unwrap();
|
||||
let value = bytes.read_string().unwrap();
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
// These can be server_encoding, client_encoding, server timezone, Postgres version,
|
||||
// and many more interesting things we should know about the Postgres server we are talking to.
|
||||
server_parameters.set_param(key, value, true);
|
||||
server_info.put_u8(b'S');
|
||||
server_info.put_i32(len);
|
||||
server_info.put_slice(¶m[..]);
|
||||
}
|
||||
|
||||
// BackendKeyData
|
||||
@@ -802,23 +620,22 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let server = Server {
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_parameters,
|
||||
server_info,
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
in_copy_mode: false,
|
||||
data_available: false,
|
||||
bad: false,
|
||||
cleanup_state: CleanupState::new(),
|
||||
needs_cleanup: false,
|
||||
client_server_map,
|
||||
addr_set,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats,
|
||||
application_name: "pgcat".to_string(),
|
||||
application_name: String::new(),
|
||||
last_activity: SystemTime::now(),
|
||||
mirror_manager: match address.mirrors.len() {
|
||||
0 => None,
|
||||
@@ -828,11 +645,10 @@ impl Server {
|
||||
address.mirrors.clone(),
|
||||
)),
|
||||
},
|
||||
cleanup_connections,
|
||||
log_client_parameter_status_changes,
|
||||
prepared_statements: BTreeSet::new(),
|
||||
};
|
||||
|
||||
server.set_name("pgcat").await?;
|
||||
|
||||
return Ok(server);
|
||||
}
|
||||
|
||||
@@ -889,10 +705,7 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Terminating server {:?} because of: {:?}",
|
||||
self.address, err
|
||||
);
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
self.bad = true;
|
||||
Err(err)
|
||||
}
|
||||
@@ -902,18 +715,12 @@ impl Server {
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(
|
||||
&mut self,
|
||||
mut client_server_parameters: Option<&mut ServerParameters>,
|
||||
) -> Result<BytesMut, Error> {
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
error!(
|
||||
"Terminating server {:?} because of: {:?}",
|
||||
self.address, err
|
||||
);
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
@@ -964,39 +771,28 @@ impl Server {
|
||||
break;
|
||||
}
|
||||
|
||||
// ErrorResponse
|
||||
'E' => {
|
||||
if self.in_copy_mode {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
}
|
||||
|
||||
// CommandComplete
|
||||
'C' => {
|
||||
if self.in_copy_mode {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
|
||||
match message.read_string() {
|
||||
Ok(command) => {
|
||||
let mut command_tag = String::new();
|
||||
match message.reader().read_to_string(&mut command_tag) {
|
||||
Ok(_) => {
|
||||
// Non-exhaustive list of commands that are likely to change session variables/resources
|
||||
// which can leak between clients. This is a best effort to block bad clients
|
||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||
match command.as_str() {
|
||||
"SET" => {
|
||||
match command_tag.as_str() {
|
||||
"SET\0" => {
|
||||
// We don't detect set statements in transactions
|
||||
// No great way to differentiate between set and set local
|
||||
// As a result, we will miss cases when set statements are used in transactions
|
||||
// This will reduce amount of reset statements sent
|
||||
// This will reduce amount of discard statements sent
|
||||
if !self.in_transaction {
|
||||
debug!("Server connection marked for clean up");
|
||||
self.cleanup_state.needs_cleanup_set = true;
|
||||
self.needs_cleanup = true;
|
||||
}
|
||||
}
|
||||
|
||||
"PREPARE" => {
|
||||
"PREPARE\0" => {
|
||||
debug!("Server connection marked for clean up");
|
||||
self.cleanup_state.needs_cleanup_prepare = true;
|
||||
self.needs_cleanup = true;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
@@ -1008,20 +804,6 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
'S' => {
|
||||
let key = message.read_string().unwrap();
|
||||
let value = message.read_string().unwrap();
|
||||
|
||||
if let Some(client_server_parameters) = client_server_parameters.as_mut() {
|
||||
client_server_parameters.set_param(key.clone(), value.clone(), false);
|
||||
if self.log_client_parameter_status_changes {
|
||||
info!("Client parameter status change: {} = {}", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
self.server_parameters.set_param(key, value, false);
|
||||
}
|
||||
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
@@ -1034,14 +816,10 @@ impl Server {
|
||||
}
|
||||
|
||||
// CopyInResponse: copy is starting from client to server.
|
||||
'G' => {
|
||||
self.in_copy_mode = true;
|
||||
break;
|
||||
}
|
||||
'G' => break,
|
||||
|
||||
// CopyOutResponse: copy is starting from the server to the client.
|
||||
'H' => {
|
||||
self.in_copy_mode = true;
|
||||
self.data_available = true;
|
||||
break;
|
||||
}
|
||||
@@ -1079,119 +857,6 @@ impl Server {
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Add the prepared statement to being tracked by this server.
|
||||
/// The client is processing data that will create a prepared statement on this server.
|
||||
pub fn will_prepare(&mut self, name: &str) {
|
||||
debug!("Will prepare `{}`", name);
|
||||
|
||||
self.prepared_statements.insert(name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
}
|
||||
|
||||
/// Check if we should prepare a statement on the server.
|
||||
pub fn should_prepare(&self, name: &str) -> bool {
|
||||
let should_prepare = !self.prepared_statements.contains(name);
|
||||
|
||||
debug!("Should prepare `{}`: {}", name, should_prepare);
|
||||
|
||||
if should_prepare {
|
||||
self.stats.prepared_cache_miss();
|
||||
} else {
|
||||
self.stats.prepared_cache_hit();
|
||||
}
|
||||
|
||||
should_prepare
|
||||
}
|
||||
|
||||
/// Create a prepared statement on the server.
|
||||
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
|
||||
debug!("Preparing `{}`", parse.name);
|
||||
|
||||
let bytes: BytesMut = parse.try_into()?;
|
||||
self.send(&bytes).await?;
|
||||
self.send(&flush()).await?;
|
||||
|
||||
// Read and discard ParseComplete (B)
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
self.prepared_statements.insert(parse.name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
|
||||
debug!("Prepared `{}`", parse.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Maintain adequate cache size on the server.
|
||||
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
|
||||
debug!("Cache maintenance run");
|
||||
|
||||
let max_cache_size = get_prepared_statements_cache_size();
|
||||
let mut names = Vec::new();
|
||||
|
||||
while self.prepared_statements.len() >= max_cache_size {
|
||||
// The prepared statmeents are alphanumerically sorted by the BTree.
|
||||
// FIFO.
|
||||
if let Some(name) = self.prepared_statements.pop_last() {
|
||||
names.push(name);
|
||||
}
|
||||
}
|
||||
|
||||
if !names.is_empty() {
|
||||
self.deallocate(names).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the prepared statement from being tracked by this server.
|
||||
/// The client is processing data that will cause the server to close the prepared statement.
|
||||
pub fn will_close(&mut self, name: &str) {
|
||||
debug!("Will close `{}`", name);
|
||||
|
||||
self.prepared_statements.remove(name);
|
||||
}
|
||||
|
||||
/// Close a prepared statement on the server.
|
||||
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
|
||||
for name in &names {
|
||||
debug!("Deallocating prepared statement `{}`", name);
|
||||
|
||||
let close = Close::new(name);
|
||||
let bytes: BytesMut = close.try_into()?;
|
||||
|
||||
self.send(&bytes).await?;
|
||||
}
|
||||
|
||||
if !names.is_empty() {
|
||||
self.send(&flush()).await?;
|
||||
}
|
||||
|
||||
// Read and discard CloseComplete (3)
|
||||
for name in &names {
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => {
|
||||
self.prepared_statements.remove(name);
|
||||
self.stats.prepared_cache_remove();
|
||||
debug!("Closed `{}`", name);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// If the server is still inside a transaction.
|
||||
/// If the client disconnects while the server is in a transaction, we will clean it up.
|
||||
pub fn in_transaction(&self) -> bool {
|
||||
@@ -1199,10 +864,6 @@ impl Server {
|
||||
self.in_transaction
|
||||
}
|
||||
|
||||
pub fn in_copy_mode(&self) -> bool {
|
||||
self.in_copy_mode
|
||||
}
|
||||
|
||||
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
|
||||
/// The client is responsible to call `self.recv()` while this method returns true.
|
||||
pub fn is_data_available(&self) -> bool {
|
||||
@@ -1232,28 +893,9 @@ impl Server {
|
||||
}
|
||||
|
||||
/// Get server startup information to forward it to the client.
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.server_parameters.clone()
|
||||
}
|
||||
|
||||
pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> {
|
||||
let parameter_diff = self.server_parameters.compare_params(parameters);
|
||||
|
||||
if parameter_diff.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut query = String::from("");
|
||||
|
||||
for (key, value) in parameter_diff {
|
||||
query.push_str(&format!("SET {} TO '{}';", key, value));
|
||||
}
|
||||
|
||||
let res = self.query(&query).await;
|
||||
|
||||
self.cleanup_state.reset();
|
||||
|
||||
res
|
||||
/// Not used at the moment.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.clone()
|
||||
}
|
||||
|
||||
/// Indicate that this server connection cannot be re-used and must be discarded.
|
||||
@@ -1280,14 +922,12 @@ impl Server {
|
||||
/// It will use the simple query protocol.
|
||||
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
|
||||
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
|
||||
debug!("Running `{}` on server {:?}", query, self.address);
|
||||
|
||||
let query = simple_query(query);
|
||||
|
||||
self.send(&query).await?;
|
||||
|
||||
loop {
|
||||
let _ = self.recv(None).await?;
|
||||
let _ = self.recv().await?;
|
||||
|
||||
if !self.data_available {
|
||||
break;
|
||||
@@ -1305,38 +945,42 @@ impl Server {
|
||||
// server connection thrashing if clients repeatedly do this.
|
||||
// Instead, we ROLLBACK that transaction before putting the connection back in the pool
|
||||
if self.in_transaction() {
|
||||
warn!(target: "pgcat::server::cleanup", "Server returned while still in transaction, rolling back transaction");
|
||||
warn!("Server returned while still in transaction, rolling back transaction");
|
||||
self.query("ROLLBACK").await?;
|
||||
}
|
||||
|
||||
// Client disconnected but it performed session-altering operations such as
|
||||
// SET statement_timeout to 1 or create a prepared statement. We clear that
|
||||
// to avoid leaking state between clients. For performance reasons we only
|
||||
// send `RESET ALL` if we think the session is altered instead of just sending
|
||||
// send `DISCARD ALL` if we think the session is altered instead of just sending
|
||||
// it before each checkin.
|
||||
if self.cleanup_state.needs_cleanup() && self.cleanup_connections {
|
||||
info!(target: "pgcat::server::cleanup", "Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name);
|
||||
let mut reset_string = String::from("RESET ROLE;");
|
||||
|
||||
if self.cleanup_state.needs_cleanup_set {
|
||||
reset_string.push_str("RESET ALL;");
|
||||
};
|
||||
|
||||
if self.cleanup_state.needs_cleanup_prepare {
|
||||
reset_string.push_str("DEALLOCATE ALL;");
|
||||
};
|
||||
|
||||
self.query(&reset_string).await?;
|
||||
self.cleanup_state.reset();
|
||||
}
|
||||
|
||||
if self.in_copy_mode() {
|
||||
warn!(target: "pgcat::server::cleanup", "Server returned while still in copy-mode");
|
||||
if self.needs_cleanup {
|
||||
warn!("Server returned with session state altered, discarding state");
|
||||
self.query("DISCARD ALL").await?;
|
||||
self.needs_cleanup = false;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A shorthand for `SET application_name = $1`.
|
||||
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
|
||||
if self.application_name != name {
|
||||
self.application_name = name.to_string();
|
||||
// We don't want `SET application_name` to mark the server connection
|
||||
// as needing cleanup
|
||||
let needs_cleanup_before = self.needs_cleanup;
|
||||
|
||||
let result = Ok(self
|
||||
.query(&format!("SET application_name = '{}'", name))
|
||||
.await?);
|
||||
self.needs_cleanup = needs_cleanup_before;
|
||||
result
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// get Server stats
|
||||
pub fn stats(&self) -> Arc<ServerStats> {
|
||||
self.stats.clone()
|
||||
@@ -1353,9 +997,9 @@ impl Server {
|
||||
self.last_activity
|
||||
}
|
||||
|
||||
// Marks a connection as needing cleanup at checkin
|
||||
// Marks a connection as needing DISCARD ALL at checkin
|
||||
pub fn mark_dirty(&mut self) {
|
||||
self.cleanup_state.set_true();
|
||||
self.needs_cleanup = true;
|
||||
}
|
||||
|
||||
pub fn mirror_send(&mut self, bytes: &BytesMut) {
|
||||
@@ -1389,13 +1033,11 @@ impl Server {
|
||||
client_server_map,
|
||||
Arc::new(ServerStats::default()),
|
||||
Arc::new(RwLock::new(None)),
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
debug!("Connected!, sending query.");
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv(None).await?;
|
||||
let mut message = server.recv().await?;
|
||||
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
}
|
||||
@@ -1493,18 +1135,14 @@ impl Drop for Server {
|
||||
_ => debug!("Dirty shutdown"),
|
||||
};
|
||||
|
||||
// Should not matter.
|
||||
self.bad = true;
|
||||
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let duration = now - self.connected_at;
|
||||
|
||||
let message = if self.bad {
|
||||
"Server connection terminated"
|
||||
} else {
|
||||
"Server connection closed"
|
||||
};
|
||||
|
||||
info!(
|
||||
"{} {:?}, session duration: {}",
|
||||
message,
|
||||
"Server connection closed {:?}, session duration: {}",
|
||||
self.address,
|
||||
crate::format_duration(&duration)
|
||||
);
|
||||
|
||||
37
src/stats.rs
37
src/stats.rs
@@ -1,3 +1,4 @@
|
||||
use crate::pool::PoolIdentifier;
|
||||
/// Statistics and reporting.
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
@@ -15,11 +16,13 @@ pub mod pool;
|
||||
pub mod server;
|
||||
pub use address::AddressStats;
|
||||
pub use client::{ClientState, ClientStats};
|
||||
pub use pool::PoolStats;
|
||||
pub use server::{ServerState, ServerStats};
|
||||
|
||||
/// Convenience types for various stats
|
||||
type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>;
|
||||
type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>;
|
||||
type PoolStatsLookup = HashMap<(String, String), Arc<PoolStats>>;
|
||||
|
||||
/// Stats for individual client connections
|
||||
/// Used in SHOW CLIENTS.
|
||||
@@ -31,6 +34,11 @@ static CLIENT_STATS: Lazy<Arc<RwLock<ClientStatesLookup>>> =
|
||||
static SERVER_STATS: Lazy<Arc<RwLock<ServerStatesLookup>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(ServerStatesLookup::default())));
|
||||
|
||||
/// Aggregate stats for each pool (a pool is identified by database name and username)
|
||||
/// Used in SHOW POOLS.
|
||||
static POOL_STATS: Lazy<Arc<RwLock<PoolStatsLookup>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(PoolStatsLookup::default())));
|
||||
|
||||
/// The statistics reporter. An instance is given to each possible source of statistics,
|
||||
/// e.g. client stats, server stats, connection pool stats.
|
||||
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
|
||||
@@ -72,6 +80,13 @@ impl Reporter {
|
||||
fn server_disconnecting(&self, server_id: i32) {
|
||||
SERVER_STATS.write().remove(&server_id);
|
||||
}
|
||||
|
||||
/// Register a pool with the stats system.
|
||||
fn pool_register(&self, identifier: PoolIdentifier, stats: Arc<PoolStats>) {
|
||||
POOL_STATS
|
||||
.write()
|
||||
.insert((identifier.db, identifier.user), stats);
|
||||
}
|
||||
}
|
||||
|
||||
/// The statistics collector which used for calculating averages
|
||||
@@ -92,20 +107,8 @@ impl Collector {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Hold read lock for duration of update to retain all server stats
|
||||
let server_stats = SERVER_STATS.read();
|
||||
|
||||
for stats in server_stats.values() {
|
||||
if !stats.check_address_stat_average_is_updated_status() {
|
||||
stats.address_stats().update_averages();
|
||||
stats.address_stats().reset_current_counts();
|
||||
stats.set_address_stat_average_is_updated_status(true);
|
||||
}
|
||||
}
|
||||
|
||||
// Reset to false for next update
|
||||
for stats in server_stats.values() {
|
||||
stats.set_address_stat_average_is_updated_status(false);
|
||||
for stats in SERVER_STATS.read().values() {
|
||||
stats.address_stats().update_averages();
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -124,6 +127,12 @@ pub fn get_server_stats() -> ServerStatesLookup {
|
||||
SERVER_STATS.read().clone()
|
||||
}
|
||||
|
||||
/// Get a snapshot of pool statistics.
|
||||
/// by the `Collector`.
|
||||
pub fn get_pool_stats() -> PoolStatsLookup {
|
||||
POOL_STATS.read().clone()
|
||||
}
|
||||
|
||||
/// Get the statistics reporter used to update stats across the pools/clients.
|
||||
pub fn get_reporter() -> Reporter {
|
||||
(*(*REPORTER.load())).clone()
|
||||
|
||||
@@ -1,29 +1,26 @@
|
||||
use log::warn;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct AddressStatFields {
|
||||
xact_count: Arc<AtomicU64>,
|
||||
query_count: Arc<AtomicU64>,
|
||||
bytes_received: Arc<AtomicU64>,
|
||||
bytes_sent: Arc<AtomicU64>,
|
||||
xact_time: Arc<AtomicU64>,
|
||||
query_time: Arc<AtomicU64>,
|
||||
wait_time: Arc<AtomicU64>,
|
||||
errors: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
/// Internal address stats
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AddressStats {
|
||||
total: AddressStatFields,
|
||||
|
||||
current: AddressStatFields,
|
||||
|
||||
averages: AddressStatFields,
|
||||
|
||||
// Determines if the averages have been updated since the last time they were reported
|
||||
pub averages_updated: Arc<AtomicBool>,
|
||||
pub total_xact_count: Arc<AtomicU64>,
|
||||
pub total_query_count: Arc<AtomicU64>,
|
||||
pub total_received: Arc<AtomicU64>,
|
||||
pub total_sent: Arc<AtomicU64>,
|
||||
pub total_xact_time: Arc<AtomicU64>,
|
||||
pub total_query_time: Arc<AtomicU64>,
|
||||
pub total_wait_time: Arc<AtomicU64>,
|
||||
pub total_errors: Arc<AtomicU64>,
|
||||
pub avg_query_count: Arc<AtomicU64>,
|
||||
pub avg_query_time: Arc<AtomicU64>,
|
||||
pub avg_recv: Arc<AtomicU64>,
|
||||
pub avg_sent: Arc<AtomicU64>,
|
||||
pub avg_errors: Arc<AtomicU64>,
|
||||
pub avg_xact_time: Arc<AtomicU64>,
|
||||
pub avg_xact_count: Arc<AtomicU64>,
|
||||
pub avg_wait_time: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl IntoIterator for AddressStats {
|
||||
@@ -34,67 +31,67 @@ impl IntoIterator for AddressStats {
|
||||
vec![
|
||||
(
|
||||
"total_xact_count".to_string(),
|
||||
self.total.xact_count.load(Ordering::Relaxed),
|
||||
self.total_xact_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_query_count".to_string(),
|
||||
self.total.query_count.load(Ordering::Relaxed),
|
||||
self.total_query_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_received".to_string(),
|
||||
self.total.bytes_received.load(Ordering::Relaxed),
|
||||
self.total_received.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_sent".to_string(),
|
||||
self.total.bytes_sent.load(Ordering::Relaxed),
|
||||
self.total_sent.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_xact_time".to_string(),
|
||||
self.total.xact_time.load(Ordering::Relaxed),
|
||||
self.total_xact_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_query_time".to_string(),
|
||||
self.total.query_time.load(Ordering::Relaxed),
|
||||
self.total_query_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_wait_time".to_string(),
|
||||
self.total.wait_time.load(Ordering::Relaxed),
|
||||
self.total_wait_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_errors".to_string(),
|
||||
self.total.errors.load(Ordering::Relaxed),
|
||||
self.total_errors.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_xact_count".to_string(),
|
||||
self.averages.xact_count.load(Ordering::Relaxed),
|
||||
self.avg_xact_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_query_count".to_string(),
|
||||
self.averages.query_count.load(Ordering::Relaxed),
|
||||
self.avg_query_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_recv".to_string(),
|
||||
self.averages.bytes_received.load(Ordering::Relaxed),
|
||||
self.avg_recv.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_sent".to_string(),
|
||||
self.averages.bytes_sent.load(Ordering::Relaxed),
|
||||
self.avg_sent.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_errors".to_string(),
|
||||
self.averages.errors.load(Ordering::Relaxed),
|
||||
self.avg_errors.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_xact_time".to_string(),
|
||||
self.averages.xact_time.load(Ordering::Relaxed),
|
||||
self.avg_xact_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_query_time".to_string(),
|
||||
self.averages.query_time.load(Ordering::Relaxed),
|
||||
self.avg_query_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_wait_time".to_string(),
|
||||
self.averages.wait_time.load(Ordering::Relaxed),
|
||||
self.avg_wait_time.load(Ordering::Relaxed),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
@@ -102,120 +99,22 @@ impl IntoIterator for AddressStats {
|
||||
}
|
||||
|
||||
impl AddressStats {
|
||||
pub fn xact_count_add(&self) {
|
||||
self.total.xact_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.current.xact_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn query_count_add(&self) {
|
||||
self.total.query_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.current.query_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn bytes_received_add(&self, bytes: u64) {
|
||||
self.total
|
||||
.bytes_received
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.current
|
||||
.bytes_received
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn bytes_sent_add(&self, bytes: u64) {
|
||||
self.total.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.current.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn xact_time_add(&self, time: u64) {
|
||||
self.total.xact_time.fetch_add(time, Ordering::Relaxed);
|
||||
self.current.xact_time.fetch_add(time, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn query_time_add(&self, time: u64) {
|
||||
self.total.query_time.fetch_add(time, Ordering::Relaxed);
|
||||
self.current.query_time.fetch_add(time, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn wait_time_add(&self, time: u64) {
|
||||
self.total.wait_time.fetch_add(time, Ordering::Relaxed);
|
||||
self.current.wait_time.fetch_add(time, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn error(&self) {
|
||||
self.total.errors.fetch_add(1, Ordering::Relaxed);
|
||||
self.current.errors.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_errors.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn update_averages(&self) {
|
||||
let stat_period_per_second = crate::stats::STAT_PERIOD / 1_000;
|
||||
|
||||
// xact_count
|
||||
let current_xact_count = self.current.xact_count.load(Ordering::Relaxed);
|
||||
let current_xact_time = self.current.xact_time.load(Ordering::Relaxed);
|
||||
self.averages.xact_count.store(
|
||||
current_xact_count / stat_period_per_second,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
if current_xact_count == 0 {
|
||||
self.averages.xact_time.store(0, Ordering::Relaxed);
|
||||
} else {
|
||||
self.averages
|
||||
.xact_time
|
||||
.store(current_xact_time / current_xact_count, Ordering::Relaxed);
|
||||
let (totals, averages) = self.fields_iterators();
|
||||
for data in totals.iter().zip(averages.iter()) {
|
||||
let (total, average) = data;
|
||||
if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
|
||||
let total = total.load(Ordering::Relaxed);
|
||||
let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
|
||||
Some(avg)
|
||||
}) {
|
||||
warn!("Could not update averages for addresses stats, {:?}", err);
|
||||
}
|
||||
}
|
||||
|
||||
// query_count
|
||||
let current_query_count = self.current.query_count.load(Ordering::Relaxed);
|
||||
let current_query_time = self.current.query_time.load(Ordering::Relaxed);
|
||||
self.averages.query_count.store(
|
||||
current_query_count / stat_period_per_second,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
if current_query_count == 0 {
|
||||
self.averages.query_time.store(0, Ordering::Relaxed);
|
||||
} else {
|
||||
self.averages
|
||||
.query_time
|
||||
.store(current_query_time / current_query_count, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// bytes_received
|
||||
let current_bytes_received = self.current.bytes_received.load(Ordering::Relaxed);
|
||||
self.averages.bytes_received.store(
|
||||
current_bytes_received / stat_period_per_second,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
|
||||
// bytes_sent
|
||||
let current_bytes_sent = self.current.bytes_sent.load(Ordering::Relaxed);
|
||||
self.averages.bytes_sent.store(
|
||||
current_bytes_sent / stat_period_per_second,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
|
||||
// wait_time
|
||||
let current_wait_time = self.current.wait_time.load(Ordering::Relaxed);
|
||||
self.averages.wait_time.store(
|
||||
current_wait_time / stat_period_per_second,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
|
||||
// errors
|
||||
let current_errors = self.current.errors.load(Ordering::Relaxed);
|
||||
self.averages
|
||||
.errors
|
||||
.store(current_errors / stat_period_per_second, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn reset_current_counts(&self) {
|
||||
self.current.xact_count.store(0, Ordering::Relaxed);
|
||||
self.current.xact_time.store(0, Ordering::Relaxed);
|
||||
self.current.query_count.store(0, Ordering::Relaxed);
|
||||
self.current.query_time.store(0, Ordering::Relaxed);
|
||||
self.current.bytes_received.store(0, Ordering::Relaxed);
|
||||
self.current.bytes_sent.store(0, Ordering::Relaxed);
|
||||
self.current.wait_time.store(0, Ordering::Relaxed);
|
||||
self.current.errors.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn populate_row(&self, row: &mut Vec<String>) {
|
||||
@@ -223,4 +122,28 @@ impl AddressStats {
|
||||
row.push(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
|
||||
let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
|
||||
let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
|
||||
|
||||
totals.push(self.total_xact_count.clone());
|
||||
averages.push(self.avg_xact_count.clone());
|
||||
totals.push(self.total_query_count.clone());
|
||||
averages.push(self.avg_query_count.clone());
|
||||
totals.push(self.total_received.clone());
|
||||
averages.push(self.avg_recv.clone());
|
||||
totals.push(self.total_sent.clone());
|
||||
averages.push(self.avg_sent.clone());
|
||||
totals.push(self.total_xact_time.clone());
|
||||
averages.push(self.avg_xact_time.clone());
|
||||
totals.push(self.total_query_time.clone());
|
||||
averages.push(self.avg_query_time.clone());
|
||||
totals.push(self.total_wait_time.clone());
|
||||
averages.push(self.avg_wait_time.clone());
|
||||
totals.push(self.total_errors.clone());
|
||||
averages.push(self.avg_errors.clone());
|
||||
|
||||
(totals, averages)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use atomic_enum::atomic_enum;
|
||||
use std::sync::atomic::*;
|
||||
@@ -33,14 +34,12 @@ pub struct ClientStats {
|
||||
pool_name: String,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Total time spent waiting for a connection from pool, measures in microseconds
|
||||
pub total_wait_time: Arc<AtomicU64>,
|
||||
|
||||
/// Maximum time spent waiting for a connection from pool, measures in microseconds
|
||||
pub max_wait_time: Arc<AtomicU64>,
|
||||
|
||||
/// Current state of the client
|
||||
pub state: Arc<AtomicClientState>,
|
||||
|
||||
@@ -62,8 +61,8 @@ impl Default for ClientStats {
|
||||
application_name: String::new(),
|
||||
username: String::new(),
|
||||
pool_name: String::new(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
total_wait_time: Arc::new(AtomicU64::new(0)),
|
||||
max_wait_time: Arc::new(AtomicU64::new(0)),
|
||||
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
|
||||
transaction_count: Arc::new(AtomicU64::new(0)),
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
@@ -80,9 +79,11 @@ impl ClientStats {
|
||||
username: &str,
|
||||
pool_name: &str,
|
||||
connect_time: Instant,
|
||||
pool_stats: Arc<PoolStats>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
application_name: application_name.to_string(),
|
||||
username: username.to_string(),
|
||||
@@ -95,6 +96,8 @@ impl ClientStats {
|
||||
/// update metrics on the corresponding pool.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.client_disconnecting(self.client_id);
|
||||
self.pool_stats
|
||||
.client_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Register a client with the stats system. The stats system uses client_id
|
||||
@@ -102,20 +105,27 @@ impl ClientStats {
|
||||
pub fn register(&self, stats: Arc<ClientStats>) {
|
||||
self.reporter.client_register(self.client_id, stats);
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
self.pool_stats.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done querying the server and is no longer assigned a server connection
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.client_idle(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is waiting for a connection
|
||||
pub fn waiting(&self) {
|
||||
self.pool_stats
|
||||
.client_waiting(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Waiting, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done waiting for a connection and is about to query the server.
|
||||
pub fn active(&self) {
|
||||
self.pool_stats
|
||||
.client_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Active, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
@@ -134,8 +144,6 @@ impl ClientStats {
|
||||
pub fn checkout_time(&self, microseconds: u64) {
|
||||
self.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
self.max_wait_time
|
||||
.fetch_max(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
|
||||
@@ -1,131 +1,36 @@
|
||||
use log::debug;
|
||||
|
||||
use super::{ClientState, ServerState};
|
||||
use crate::{config::PoolMode, messages::DataType, pool::PoolIdentifier};
|
||||
use std::collections::HashMap;
|
||||
use crate::config::Pool;
|
||||
use crate::config::PoolMode;
|
||||
use crate::pool::PoolIdentifier;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::pool::get_all_pools;
|
||||
use super::get_reporter;
|
||||
use super::Reporter;
|
||||
use super::{ClientState, ServerState};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
/// A struct that holds information about a Pool .
|
||||
pub struct PoolStats {
|
||||
pub identifier: PoolIdentifier,
|
||||
pub mode: PoolMode,
|
||||
pub cl_idle: u64,
|
||||
pub cl_active: u64,
|
||||
pub cl_waiting: u64,
|
||||
pub cl_cancel_req: u64,
|
||||
pub sv_active: u64,
|
||||
pub sv_idle: u64,
|
||||
pub sv_used: u64,
|
||||
pub sv_tested: u64,
|
||||
pub sv_login: u64,
|
||||
pub maxwait: u64,
|
||||
}
|
||||
impl PoolStats {
|
||||
pub fn new(identifier: PoolIdentifier, mode: PoolMode) -> Self {
|
||||
PoolStats {
|
||||
identifier,
|
||||
mode,
|
||||
cl_idle: 0,
|
||||
cl_active: 0,
|
||||
cl_waiting: 0,
|
||||
cl_cancel_req: 0,
|
||||
sv_active: 0,
|
||||
sv_idle: 0,
|
||||
sv_used: 0,
|
||||
sv_tested: 0,
|
||||
sv_login: 0,
|
||||
maxwait: 0,
|
||||
}
|
||||
}
|
||||
// Pool identifier, cannot be changed after creating the instance
|
||||
identifier: PoolIdentifier,
|
||||
|
||||
pub fn construct_pool_lookup() -> HashMap<PoolIdentifier, PoolStats> {
|
||||
let mut map: HashMap<PoolIdentifier, PoolStats> = HashMap::new();
|
||||
let client_map = super::get_client_stats();
|
||||
let server_map = super::get_server_stats();
|
||||
// Pool Config, cannot be changed after creating the instance
|
||||
config: Pool,
|
||||
|
||||
for (identifier, pool) in get_all_pools() {
|
||||
map.insert(
|
||||
identifier.clone(),
|
||||
PoolStats::new(identifier, pool.settings.pool_mode),
|
||||
);
|
||||
}
|
||||
// A reference to the global reporter.
|
||||
reporter: Reporter,
|
||||
|
||||
for client in client_map.values() {
|
||||
match map.get_mut(&PoolIdentifier {
|
||||
db: client.pool_name(),
|
||||
user: client.username(),
|
||||
}) {
|
||||
Some(pool_stats) => {
|
||||
match client.state.load(Ordering::Relaxed) {
|
||||
ClientState::Active => pool_stats.cl_active += 1,
|
||||
ClientState::Idle => pool_stats.cl_idle += 1,
|
||||
ClientState::Waiting => pool_stats.cl_waiting += 1,
|
||||
}
|
||||
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
|
||||
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, max_wait);
|
||||
}
|
||||
None => debug!("Client from an obselete pool"),
|
||||
}
|
||||
}
|
||||
|
||||
for server in server_map.values() {
|
||||
match map.get_mut(&PoolIdentifier {
|
||||
db: server.pool_name(),
|
||||
user: server.username(),
|
||||
}) {
|
||||
Some(pool_stats) => match server.state.load(Ordering::Relaxed) {
|
||||
ServerState::Active => pool_stats.sv_active += 1,
|
||||
ServerState::Idle => pool_stats.sv_idle += 1,
|
||||
ServerState::Login => pool_stats.sv_login += 1,
|
||||
ServerState::Tested => pool_stats.sv_tested += 1,
|
||||
},
|
||||
None => debug!("Server from an obselete pool"),
|
||||
}
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
pub fn generate_header() -> Vec<(&'static str, DataType)> {
|
||||
return vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
("cl_idle", DataType::Numeric),
|
||||
("cl_active", DataType::Numeric),
|
||||
("cl_waiting", DataType::Numeric),
|
||||
("cl_cancel_req", DataType::Numeric),
|
||||
("sv_active", DataType::Numeric),
|
||||
("sv_idle", DataType::Numeric),
|
||||
("sv_used", DataType::Numeric),
|
||||
("sv_tested", DataType::Numeric),
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
];
|
||||
}
|
||||
|
||||
pub fn generate_row(&self) -> Vec<String> {
|
||||
return vec![
|
||||
self.identifier.db.clone(),
|
||||
self.identifier.user.clone(),
|
||||
self.mode.to_string(),
|
||||
self.cl_idle.to_string(),
|
||||
self.cl_active.to_string(),
|
||||
self.cl_waiting.to_string(),
|
||||
self.cl_cancel_req.to_string(),
|
||||
self.sv_active.to_string(),
|
||||
self.sv_idle.to_string(),
|
||||
self.sv_used.to_string(),
|
||||
self.sv_tested.to_string(),
|
||||
self.sv_login.to_string(),
|
||||
(self.maxwait / 1_000_000).to_string(),
|
||||
(self.maxwait % 1_000_000).to_string(),
|
||||
];
|
||||
}
|
||||
/// Counters (atomics)
|
||||
pub cl_idle: Arc<AtomicU64>,
|
||||
pub cl_active: Arc<AtomicU64>,
|
||||
pub cl_waiting: Arc<AtomicU64>,
|
||||
pub cl_cancel_req: Arc<AtomicU64>,
|
||||
pub sv_active: Arc<AtomicU64>,
|
||||
pub sv_idle: Arc<AtomicU64>,
|
||||
pub sv_used: Arc<AtomicU64>,
|
||||
pub sv_tested: Arc<AtomicU64>,
|
||||
pub sv_login: Arc<AtomicU64>,
|
||||
pub maxwait: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl IntoIterator for PoolStats {
|
||||
@@ -134,18 +39,236 @@ impl IntoIterator for PoolStats {
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
vec![
|
||||
("cl_idle".to_string(), self.cl_idle),
|
||||
("cl_active".to_string(), self.cl_active),
|
||||
("cl_waiting".to_string(), self.cl_waiting),
|
||||
("cl_cancel_req".to_string(), self.cl_cancel_req),
|
||||
("sv_active".to_string(), self.sv_active),
|
||||
("sv_idle".to_string(), self.sv_idle),
|
||||
("sv_used".to_string(), self.sv_used),
|
||||
("sv_tested".to_string(), self.sv_tested),
|
||||
("sv_login".to_string(), self.sv_login),
|
||||
("maxwait".to_string(), self.maxwait / 1_000_000),
|
||||
("maxwait_us".to_string(), self.maxwait % 1_000_000),
|
||||
("cl_idle".to_string(), self.cl_idle.load(Ordering::Relaxed)),
|
||||
(
|
||||
"cl_active".to_string(),
|
||||
self.cl_active.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_waiting".to_string(),
|
||||
self.cl_waiting.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_cancel_req".to_string(),
|
||||
self.cl_cancel_req.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_active".to_string(),
|
||||
self.sv_active.load(Ordering::Relaxed),
|
||||
),
|
||||
("sv_idle".to_string(), self.sv_idle.load(Ordering::Relaxed)),
|
||||
("sv_used".to_string(), self.sv_used.load(Ordering::Relaxed)),
|
||||
(
|
||||
"sv_tested".to_string(),
|
||||
self.sv_tested.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_login".to_string(),
|
||||
self.sv_login.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"maxwait".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) / 1_000_000,
|
||||
),
|
||||
(
|
||||
"maxwait_us".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) % 1_000_000,
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl PoolStats {
|
||||
pub fn new(identifier: PoolIdentifier, config: Pool) -> Self {
|
||||
Self {
|
||||
identifier,
|
||||
config,
|
||||
reporter: get_reporter(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// Getters
|
||||
pub fn register(&self, stats: Arc<PoolStats>) {
|
||||
self.reporter.pool_register(self.identifier.clone(), stats);
|
||||
}
|
||||
|
||||
pub fn database(&self) -> String {
|
||||
self.identifier.db.clone()
|
||||
}
|
||||
|
||||
pub fn user(&self) -> String {
|
||||
self.identifier.user.clone()
|
||||
}
|
||||
|
||||
pub fn pool_mode(&self) -> PoolMode {
|
||||
self.config.pool_mode
|
||||
}
|
||||
|
||||
/// Populates an array of strings with counters (used by admin in show pools)
|
||||
pub fn populate_row(&self, row: &mut Vec<String>) {
|
||||
for (_key, value) in self.clone() {
|
||||
row.push(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes the maxwait counter, this is done everytime we obtain metrics
|
||||
pub fn clear_maxwait(&self) {
|
||||
self.maxwait.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool enters login state.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_login(&self, from: ServerState) {
|
||||
self.sv_login.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Login {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_active(&self, from: ServerState) {
|
||||
self.sv_active.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Active {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'tested'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_tested(&self, from: ServerState) {
|
||||
self.sv_tested.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Tested {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_idle(&self, from: ServerState) {
|
||||
self.sv_idle.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Idle {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'waiting'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_waiting(&self, from: ClientState) {
|
||||
if from != ClientState::Waiting {
|
||||
self.cl_waiting.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_active(&self, from: ClientState) {
|
||||
if from != ClientState::Active {
|
||||
self.cl_active.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_idle(&self, from: ClientState) {
|
||||
if from != ClientState::Idle {
|
||||
self.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_disconnect(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
ClientState::Active => &self.cl_active,
|
||||
};
|
||||
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
/// Notified when a server disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn server_disconnect(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
// helpers for counter decrease
|
||||
fn decrease_from_server_state(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_from_client_state(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Active => &self.cl_active,
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_counter(value: Arc<AtomicU64>) {
|
||||
if value.load(Ordering::Relaxed) > 0 {
|
||||
value.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decrease() {
|
||||
let stat: PoolStats = PoolStats::default();
|
||||
stat.server_login(ServerState::Login);
|
||||
stat.server_idle(ServerState::Login);
|
||||
assert_eq!(stat.sv_login.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(stat.sv_idle.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use super::AddressStats;
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use crate::config::Address;
|
||||
use atomic_enum::atomic_enum;
|
||||
@@ -37,6 +38,7 @@ pub struct ServerStats {
|
||||
address: Address,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Data
|
||||
@@ -47,9 +49,6 @@ pub struct ServerStats {
|
||||
pub transaction_count: Arc<AtomicU64>,
|
||||
pub query_count: Arc<AtomicU64>,
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
pub prepared_hit_count: Arc<AtomicU64>,
|
||||
pub prepared_miss_count: Arc<AtomicU64>,
|
||||
pub prepared_cache_size: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for ServerStats {
|
||||
@@ -58,6 +57,7 @@ impl Default for ServerStats {
|
||||
server_id: 0,
|
||||
application_name: Arc::new(RwLock::new(String::new())),
|
||||
address: Address::default(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
connect_time: Instant::now(),
|
||||
state: Arc::new(AtomicServerState::new(ServerState::Login)),
|
||||
bytes_sent: Arc::new(AtomicU64::new(0)),
|
||||
@@ -66,17 +66,15 @@ impl Default for ServerStats {
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
reporter: get_reporter(),
|
||||
prepared_hit_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_miss_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_cache_size: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerStats {
|
||||
pub fn new(address: Address, connect_time: Instant) -> Self {
|
||||
pub fn new(address: Address, pool_stats: Arc<PoolStats>, connect_time: Instant) -> Self {
|
||||
Self {
|
||||
address,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
server_id: rand::random::<i32>(),
|
||||
..Default::default()
|
||||
@@ -98,6 +96,9 @@ impl ServerStats {
|
||||
/// Reports a server connection is no longer assigned to a client
|
||||
/// and is available for the next client to pick it up
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.server_idle(self.state.load(Ordering::Relaxed));
|
||||
|
||||
self.state.store(ServerState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
@@ -105,16 +106,22 @@ impl ServerStats {
|
||||
/// Also updates metrics on the pool regarding server usage.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.server_disconnecting(self.server_id);
|
||||
self.pool_stats
|
||||
.server_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Reports a server connection is being tested before being given to a client.
|
||||
pub fn tested(&self) {
|
||||
self.set_undefined_application();
|
||||
self.pool_stats
|
||||
.server_tested(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Tested, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a server connection is attempting to login.
|
||||
pub fn login(&self) {
|
||||
self.pool_stats
|
||||
.server_login(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Login, Ordering::Relaxed);
|
||||
self.set_undefined_application();
|
||||
}
|
||||
@@ -122,6 +129,8 @@ impl ServerStats {
|
||||
/// Reports a server connection has been assigned to a client that
|
||||
/// is about to query the server
|
||||
pub fn active(&self, application_name: String) {
|
||||
self.pool_stats
|
||||
.server_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Active, Ordering::Relaxed);
|
||||
self.set_application(application_name);
|
||||
}
|
||||
@@ -130,24 +139,13 @@ impl ServerStats {
|
||||
self.address.stats.clone()
|
||||
}
|
||||
|
||||
pub fn check_address_stat_average_is_updated_status(&self) -> bool {
|
||||
self.address.stats.averages_updated.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn set_address_stat_average_is_updated_status(&self, is_checked: bool) {
|
||||
self.address
|
||||
.stats
|
||||
.averages_updated
|
||||
.store(is_checked, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// Helper methods for show_servers
|
||||
pub fn pool_name(&self) -> String {
|
||||
self.address.pool_name.clone()
|
||||
self.pool_stats.database()
|
||||
}
|
||||
|
||||
pub fn username(&self) -> String {
|
||||
self.address.username.clone()
|
||||
self.pool_stats.user()
|
||||
}
|
||||
|
||||
pub fn address_name(&self) -> String {
|
||||
@@ -168,17 +166,27 @@ impl ServerStats {
|
||||
}
|
||||
|
||||
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
|
||||
// Update server stats and address aggregation stats
|
||||
// Update server stats and address aggergation stats
|
||||
self.set_application(application_name);
|
||||
self.address.stats.wait_time_add(microseconds);
|
||||
self.address
|
||||
.stats
|
||||
.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
self.pool_stats
|
||||
.maxwait
|
||||
.fetch_max(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
pub fn query(&self, milliseconds: u64, application_name: &str) {
|
||||
self.set_application(application_name.to_string());
|
||||
self.address.stats.query_count_add();
|
||||
self.address.stats.query_time_add(milliseconds);
|
||||
self.query_count.fetch_add(1, Ordering::Relaxed);
|
||||
let address_stats = self.address_stats();
|
||||
address_stats
|
||||
.total_query_count
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
address_stats
|
||||
.total_query_time
|
||||
.fetch_add(milliseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a transaction executed by a client a server
|
||||
@@ -189,38 +197,29 @@ impl ServerStats {
|
||||
self.set_application(application_name.to_string());
|
||||
|
||||
self.transaction_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.address.stats.xact_count_add();
|
||||
self.address
|
||||
.stats
|
||||
.total_xact_count
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report data sent to a server
|
||||
pub fn data_sent(&self, amount_bytes: usize) {
|
||||
self.bytes_sent
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
self.address.stats.bytes_sent_add(amount_bytes as u64);
|
||||
self.address
|
||||
.stats
|
||||
.total_sent
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report data received from a server
|
||||
pub fn data_received(&self, amount_bytes: usize) {
|
||||
self.bytes_received
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
self.address.stats.bytes_received_add(amount_bytes as u64);
|
||||
}
|
||||
|
||||
/// Report a prepared statement that already exists on the server.
|
||||
pub fn prepared_cache_hit(&self) {
|
||||
self.prepared_hit_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a prepared statement that does not exist on the server yet.
|
||||
pub fn prepared_cache_miss(&self) {
|
||||
self.prepared_miss_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn prepared_cache_add(&self) {
|
||||
self.prepared_cache_size.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn prepared_cache_remove(&self) {
|
||||
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
|
||||
self.address
|
||||
.stats
|
||||
.total_received
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
FROM rust:bullseye
|
||||
|
||||
COPY --from=sclevine/yj /bin/yj /bin/yj
|
||||
RUN /bin/yj -h
|
||||
RUN apt-get update && apt-get install llvm-11 psmisc postgresql-contrib postgresql-client ruby ruby-dev libpq-dev python3 python3-pip lcov curl sudo iproute2 -y
|
||||
RUN cargo install cargo-binutils rustfilt
|
||||
RUN rustup component add llvm-tools-preview
|
||||
|
||||
@@ -63,7 +63,6 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
|
||||
|
||||
|
||||
def test_normal_db_access():
|
||||
pgcat_start()
|
||||
conn, cur = connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
|
||||
@@ -11,6 +11,325 @@ describe "Admin" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "SHOW STATS" do
|
||||
context "clients connect and make one query" do
|
||||
it "updates *_query_time and *_wait_time" do
|
||||
connection = PG::connect("#{pgcat_conn_str}?application_name=one_query")
|
||||
connection.async_exec("SELECT pg_sleep(0.25)")
|
||||
connection.async_exec("SELECT pg_sleep(0.25)")
|
||||
connection.async_exec("SELECT pg_sleep(0.25)")
|
||||
connection.close
|
||||
|
||||
# wait for averages to be calculated, we shouldn't do this too often
|
||||
sleep(15.5)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW STATS")[0]
|
||||
admin_conn.close
|
||||
expect(results["total_query_time"].to_i).to be_within(200).of(750)
|
||||
expect(results["avg_query_time"].to_i).to_not eq(0)
|
||||
|
||||
expect(results["total_wait_time"].to_i).to_not eq(0)
|
||||
expect(results["avg_wait_time"].to_i).to_not eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "bad database name" do
|
||||
it "does not change any stats" do
|
||||
bad_db_url = URI(pgcat_conn_str)
|
||||
bad_db_url.path = "/wrong_db"
|
||||
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connect and make one query" do
|
||||
it "only affects cl_idle, sv_idle stats" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_active"]).to eq("5")
|
||||
expect(results["sv_active"]).to eq("5")
|
||||
|
||||
sleep(3)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects and opens a transaction and closes connection uncleanly" do
|
||||
it "produces correct statistics" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new do
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT pg_sleep(0.01)")
|
||||
c.close
|
||||
end
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client fail to checkout connection from the pool" do
|
||||
it "counts clients as idle" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
threads = []
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
|
||||
end
|
||||
|
||||
sleep(2)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect normally" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_between = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).not_to eq(clients_between)
|
||||
connections.each(&:close)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect abruptly" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
threads = []
|
||||
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
random_string = (0...8).map { (65 + rand(26)).chr }.join
|
||||
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
|
||||
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
|
||||
sleep(1)
|
||||
# psql starts two processes, we only know the pid of the parent, this
|
||||
# ensure both are killed
|
||||
`pkill -9 -f '#{random_string}'`
|
||||
Process.wait(faulty_client)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients overwhelm server pools" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it "cl_waiting is updated to show it" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["cl_waiting"]).to eq("2")
|
||||
expect(results["cl_active"]).to eq("2")
|
||||
expect(results["sv_active"]).to eq("2")
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("4")
|
||||
expect(results["sv_idle"]).to eq("2")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
|
||||
it "show correct max_wait" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
|
||||
expect(results["maxwait"]).to eq("1")
|
||||
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
|
||||
|
||||
sleep(4.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
expect(results["maxwait"]).to eq("0")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW CLIENTS" do
|
||||
it "reports correct number and application names" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(21) # count admin clients
|
||||
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
|
||||
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
|
||||
|
||||
connections[0..5].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(15)
|
||||
|
||||
connections[6..].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
|
||||
admin_conn.close
|
||||
end
|
||||
|
||||
it "reports correct number of queries and transactions" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
|
||||
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
|
||||
connections.each do |c|
|
||||
c.async_exec("SELECT 1")
|
||||
c.async_exec("SELECT 2")
|
||||
c.async_exec("SELECT 3")
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT 4")
|
||||
c.async_exec("SELECT 5")
|
||||
c.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(3)
|
||||
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
|
||||
expect(normal_client_results[0]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[1]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[0]["query_count"]).to eq("7")
|
||||
expect(normal_client_results[1]["query_count"]).to eq("7")
|
||||
|
||||
admin_conn.close
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
describe "Manual Banning" do
|
||||
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) }
|
||||
before do
|
||||
@@ -81,7 +400,7 @@ describe "Admin" do
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW USERS" do
|
||||
describe "SHOW users" do
|
||||
it "returns the right users" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW USERS")[0]
|
||||
@@ -90,28 +409,4 @@ describe "Admin" do
|
||||
expect(results["pool_mode"]).to eq("transaction")
|
||||
end
|
||||
end
|
||||
|
||||
describe "PAUSE" do
|
||||
it "pauses all pools" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
|
||||
|
||||
admin_conn.async_exec("PAUSE")
|
||||
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["1"])
|
||||
|
||||
admin_conn.async_exec("RESUME")
|
||||
|
||||
results = admin_conn.async_exec("SHOW DATABASES").to_a
|
||||
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
|
||||
end
|
||||
|
||||
it "handles errors" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
expect { admin_conn.async_exec("PAUSE foo").to_a }.to raise_error(PG::SystemError)
|
||||
expect { admin_conn.async_exec("PAUSE foo,bar").to_a }.to raise_error(PG::SystemError)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -185,7 +185,7 @@ describe "Auth Query" do
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
context 'and with cleartext passwords set' do
|
||||
it 'it uses local passwords' do
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "COPY Handling" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
|
||||
before do
|
||||
new_configs = processes.pgcat.current_config
|
||||
|
||||
# Allow connections in the pool to expire faster
|
||||
new_configs["general"]["idle_timeout"] = 5
|
||||
processes.pgcat.update_config(new_configs)
|
||||
# We need to kill the old process that was using the default configs
|
||||
processes.pgcat.stop
|
||||
processes.pgcat.start
|
||||
processes.pgcat.wait_until_ready
|
||||
end
|
||||
|
||||
before do
|
||||
processes.all_databases.first.with_connection do |conn|
|
||||
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
|
||||
end
|
||||
end
|
||||
|
||||
after do
|
||||
processes.all_databases.first.with_connection do |conn|
|
||||
conn.async_exec "DROP TABLE copy_test_table;"
|
||||
end
|
||||
end
|
||||
|
||||
after do
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "COPY FROM" do
|
||||
context "within transaction" do
|
||||
it "finishes within alloted time" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
Timeout.timeout(3) do
|
||||
conn.async_exec("BEGIN")
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
sleep 0.5
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
conn.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
|
||||
expect(res).to eq([
|
||||
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
|
||||
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
|
||||
])
|
||||
end
|
||||
end
|
||||
|
||||
context "outside transaction" do
|
||||
it "finishes within alloted time" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
Timeout.timeout(3) do
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
sleep 0.5
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
end
|
||||
|
||||
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
|
||||
expect(res).to eq([
|
||||
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
|
||||
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
|
||||
])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "COPY TO" do
|
||||
before do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("BEGIN")
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "works" do
|
||||
res = []
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
|
||||
while row=conn.get_copy_data
|
||||
res << row
|
||||
end
|
||||
end
|
||||
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
@@ -33,18 +33,18 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica.port.to_i, "replica"],
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
}
|
||||
pgcat_cfg["general"]["port"] = pgcat.port.to_i
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg)
|
||||
pgcat.start
|
||||
|
||||
|
||||
pgcat.wait_until_ready(
|
||||
pgcat.connection_string(
|
||||
"sharded_db",
|
||||
@@ -92,13 +92,13 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => database,
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica.port.to_i, "replica"],
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
}
|
||||
end
|
||||
# Main proxy configs
|
||||
pgcat_cfg["pools"] = {
|
||||
@@ -109,7 +109,7 @@ module Helpers
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
|
||||
pgcat.start
|
||||
|
||||
|
||||
pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
|
||||
|
||||
OpenStruct.new.tap do |struct|
|
||||
|
||||
@@ -7,24 +7,10 @@ class PgInstance
|
||||
attr_reader :password
|
||||
attr_reader :database_name
|
||||
|
||||
def self.mass_takedown(databases)
|
||||
raise StandardError "block missing" unless block_given?
|
||||
|
||||
databases.each do |database|
|
||||
database.toxiproxy.toxic(:limit_data, bytes: 1).toxics.each(&:save)
|
||||
end
|
||||
sleep 0.1
|
||||
yield
|
||||
ensure
|
||||
databases.each do |database|
|
||||
database.toxiproxy.toxics.each(&:destroy)
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(port, username, password, database_name)
|
||||
@original_port = port.to_i
|
||||
@original_port = port
|
||||
@toxiproxy_port = 10000 + port.to_i
|
||||
@port = @toxiproxy_port.to_i
|
||||
@port = @toxiproxy_port
|
||||
|
||||
@username = username
|
||||
@password = password
|
||||
@@ -62,9 +48,9 @@ class PgInstance
|
||||
|
||||
def take_down
|
||||
if block_given?
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).apply { yield }
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).apply { yield }
|
||||
else
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).toxics.each(&:save)
|
||||
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).toxics.each(&:save)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -103,6 +89,6 @@ class PgInstance
|
||||
end
|
||||
|
||||
def count_select_1_plus_2
|
||||
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query LIKE '%SELECT $1 + $2%'")[0]["sum"].to_i }
|
||||
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query = 'SELECT $1 + $2'")[0]["sum"].to_i }
|
||||
end
|
||||
end
|
||||
|
||||
@@ -34,32 +34,14 @@ module Helpers
|
||||
"load_balancing_mode" => lb_mode,
|
||||
"primary_reads_enabled" => true,
|
||||
"query_parser_enabled" => true,
|
||||
"query_parser_read_write_splitting" => true,
|
||||
"automatic_sharding_key" => "data.id",
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_i, "primary"]] },
|
||||
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_i, "primary"]] },
|
||||
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_i, "primary"]] },
|
||||
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] },
|
||||
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
|
||||
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
|
||||
},
|
||||
"users" => { "0" => user },
|
||||
"plugins" => {
|
||||
"intercept" => {
|
||||
"enabled" => true,
|
||||
"queries" => {
|
||||
"0" => {
|
||||
"query" => "select current_database() as a, current_schemas(false) as b",
|
||||
"schema" => [
|
||||
["a", "text"],
|
||||
["b", "text"],
|
||||
],
|
||||
"result" => [
|
||||
["${DATABASE}", "{public}"],
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"users" => { "0" => user }
|
||||
}
|
||||
}
|
||||
pgcat.update_config(pgcat_cfg)
|
||||
@@ -100,7 +82,7 @@ module Helpers
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_i, "primary"]
|
||||
["localhost", primary.port.to_s, "primary"]
|
||||
]
|
||||
},
|
||||
},
|
||||
@@ -119,7 +101,7 @@ module Helpers
|
||||
end
|
||||
end
|
||||
|
||||
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", pool_settings={})
|
||||
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
|
||||
user = {
|
||||
"password" => "sharding_user",
|
||||
"pool_size" => pool_size,
|
||||
@@ -135,32 +117,28 @@ module Helpers
|
||||
replica1 = PgInstance.new(8432, user["username"], user["password"], "shard0")
|
||||
replica2 = PgInstance.new(9432, user["username"], user["password"], "shard0")
|
||||
|
||||
pool_config = {
|
||||
"default_role" => "any",
|
||||
"pool_mode" => pool_mode,
|
||||
"load_balancing_mode" => lb_mode,
|
||||
"primary_reads_enabled" => false,
|
||||
"query_parser_enabled" => false,
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_i, "primary"],
|
||||
["localhost", replica0.port.to_i, "replica"],
|
||||
["localhost", replica1.port.to_i, "replica"],
|
||||
["localhost", replica2.port.to_i, "replica"]
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user }
|
||||
}
|
||||
|
||||
pool_config = pool_config.merge(pool_settings)
|
||||
|
||||
# Main proxy configs
|
||||
pgcat_cfg["pools"] = {
|
||||
"#{pool_name}" => pool_config,
|
||||
"#{pool_name}" => {
|
||||
"default_role" => "any",
|
||||
"pool_mode" => pool_mode,
|
||||
"load_balancing_mode" => lb_mode,
|
||||
"primary_reads_enabled" => false,
|
||||
"query_parser_enabled" => false,
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica0.port.to_s, "replica"],
|
||||
["localhost", replica1.port.to_s, "replica"],
|
||||
["localhost", replica2.port.to_s, "replica"]
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user }
|
||||
}
|
||||
}
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
require 'pg'
|
||||
require 'json'
|
||||
require 'tempfile'
|
||||
require 'toml'
|
||||
require 'fileutils'
|
||||
require 'securerandom'
|
||||
|
||||
class ConfigReloadFailed < StandardError; end
|
||||
class PgcatProcess
|
||||
attr_reader :port
|
||||
attr_reader :pid
|
||||
@@ -20,7 +18,7 @@ class PgcatProcess
|
||||
end
|
||||
|
||||
def initialize(log_level)
|
||||
@env = {}
|
||||
@env = {"RUST_LOG" => log_level}
|
||||
@port = rand(20000..32760)
|
||||
@log_level = log_level
|
||||
@log_filename = "/tmp/pgcat_log_#{SecureRandom.urlsafe_base64}.log"
|
||||
@@ -32,7 +30,7 @@ class PgcatProcess
|
||||
'../../target/debug/pgcat'
|
||||
end
|
||||
|
||||
@command = "#{command_path} #{@config_filename} --log-level #{@log_level}"
|
||||
@command = "#{command_path} #{@config_filename}"
|
||||
|
||||
FileUtils.cp("../../pgcat.toml", @config_filename)
|
||||
cfg = current_config
|
||||
@@ -48,34 +46,22 @@ class PgcatProcess
|
||||
|
||||
def update_config(config_hash)
|
||||
@original_config = current_config
|
||||
Tempfile.create('json_out', '/tmp') do |f|
|
||||
f.write(config_hash.to_json)
|
||||
f.flush
|
||||
`cat #{f.path} | yj -jt > #{@config_filename}`
|
||||
end
|
||||
output_to_write = TOML::Generator.new(config_hash).body
|
||||
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*,/, ',\1,')
|
||||
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*\]/, ',\1]')
|
||||
File.write(@config_filename, output_to_write)
|
||||
end
|
||||
|
||||
def current_config
|
||||
JSON.parse(`cat #{@config_filename} | yj -tj`)
|
||||
end
|
||||
|
||||
def raw_config_file
|
||||
File.read(@config_filename)
|
||||
loadable_string = File.read(@config_filename)
|
||||
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*,/, ', "\1",')
|
||||
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*\]/, ', "\1"]')
|
||||
TOML.load(loadable_string)
|
||||
end
|
||||
|
||||
def reload_config
|
||||
conn = PG.connect(admin_connection_string)
|
||||
|
||||
conn.async_exec("RELOAD")
|
||||
rescue PG::ConnectionBad => e
|
||||
errors = logs.split("Reloading config").last
|
||||
errors = errors.gsub(/\e\[([;\d]+)?m/, '') # Remove color codes
|
||||
errors = errors.
|
||||
split("\n").select{|line| line.include?("ERROR") }.
|
||||
map { |line| line.split("pgcat::config: ").last }
|
||||
raise ConfigReloadFailed, errors.join("\n")
|
||||
ensure
|
||||
conn&.close
|
||||
`kill -s HUP #{@pid}`
|
||||
sleep 0.5
|
||||
end
|
||||
|
||||
def start
|
||||
@@ -126,16 +112,10 @@ class PgcatProcess
|
||||
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
|
||||
end
|
||||
|
||||
def connection_string(pool_name, username, password = nil, parameters: {})
|
||||
def connection_string(pool_name, username, password = nil)
|
||||
cfg = current_config
|
||||
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
|
||||
connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
|
||||
# Add the additional parameters to the connection string
|
||||
parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
|
||||
connection_string += "?#{parameter_string}" unless parameter_string.empty?
|
||||
|
||||
connection_string
|
||||
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
end
|
||||
|
||||
def example_connection_string
|
||||
|
||||
@@ -11,9 +11,9 @@ describe "Query Mirroing" do
|
||||
before do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
]
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
@@ -31,8 +31,7 @@ describe "Query Mirroing" do
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
sleep 0.5
|
||||
expect(processes.all_databases.first.count_select_1_plus_2).to eq(runs)
|
||||
# Allow some slack in mirroring successes
|
||||
expect(mirror_pg.count_select_1_plus_2).to be > ((runs - 5) * 3)
|
||||
expect(mirror_pg.count_select_1_plus_2).to eq(runs * 3)
|
||||
end
|
||||
|
||||
context "when main server connection is closed" do
|
||||
@@ -43,9 +42,9 @@ describe "Query Mirroing" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["pools"]["sharded_db"]["idle_timeout"] = 5000 + i
|
||||
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_i, 0],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
[mirror_host, mirror_pg.port.to_s, "0"],
|
||||
]
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
@@ -221,7 +221,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Does not send RESET ALL unless necessary" do
|
||||
it "Does not send DISCARD ALL unless necessary" do
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SET SERVER ROLE to 'primary'")
|
||||
@@ -229,7 +229,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -239,19 +239,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(10)
|
||||
end
|
||||
|
||||
it "Resets server roles correctly" do
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SET SERVER ROLE to 'primary'")
|
||||
conn.async_exec("SELECT 1")
|
||||
conn.async_exec("SET statement_timeout to 5000")
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -273,7 +261,7 @@ describe "Miscellaneous" do
|
||||
end
|
||||
end
|
||||
|
||||
it "Does not send RESET ALL unless necessary" do
|
||||
it "Does not send DISCARD ALL unless necessary" do
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SET SERVER ROLE to 'primary'")
|
||||
@@ -282,7 +270,7 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -292,32 +280,8 @@ describe "Miscellaneous" do
|
||||
conn.close
|
||||
end
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(10)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
|
||||
end
|
||||
|
||||
it "Respects tracked parameters on startup" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" }))
|
||||
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Respect tracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
conn.async_exec("SET application_name to 'my_pgcat_test'")
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
end
|
||||
|
||||
|
||||
it "Ignore untracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]
|
||||
|
||||
conn.async_exec("SET statement_timeout to 1500")
|
||||
expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
context "transaction mode with transactions" do
|
||||
@@ -331,7 +295,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
|
||||
10.times do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
@@ -341,30 +305,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
end
|
||||
end
|
||||
|
||||
context "server cleanup disabled" do
|
||||
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 1, "transaction", "random", "info", { "cleanup_server_connections" => false }) }
|
||||
|
||||
it "will not clean up connection state" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
processes.primary.reset_stats
|
||||
conn.async_exec("SET statement_timeout TO 1000")
|
||||
conn.close
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
end
|
||||
|
||||
it "will not clean up prepared statements" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
processes.primary.reset_stats
|
||||
conn.async_exec("PREPARE prepared_q (int) AS SELECT $1")
|
||||
|
||||
conn.close
|
||||
|
||||
expect(processes.primary.count_query("RESET ALL")).to eq(0)
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -374,9 +315,10 @@ describe "Miscellaneous" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
|
||||
puts(current_configs["general"]["idle_client_in_transaction_timeout"])
|
||||
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 0
|
||||
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 0
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
@@ -394,9 +336,9 @@ describe "Miscellaneous" do
|
||||
context "idle transaction timeout set to 500ms" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 500
|
||||
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 500
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
@@ -415,7 +357,7 @@ describe "Miscellaneous" do
|
||||
conn.async_exec("BEGIN")
|
||||
conn.async_exec("SELECT 1")
|
||||
sleep(1) # above 500ms
|
||||
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
|
||||
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
|
||||
conn.async_exec("SELECT 1") # should be able to send another query
|
||||
conn.close
|
||||
end
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
require_relative 'spec_helper'
|
||||
|
||||
describe 'Prepared statements' do
|
||||
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
|
||||
|
||||
context 'enabled' do
|
||||
it 'will work over the same connection' do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
10.times do |i|
|
||||
statement_name = "statement_#{i}"
|
||||
conn.prepare(statement_name, 'SELECT $1::int')
|
||||
conn.exec_prepared(statement_name, [1])
|
||||
conn.describe_prepared(statement_name)
|
||||
end
|
||||
end
|
||||
|
||||
it 'will work with new connections' do
|
||||
10.times do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
statement_name = 'statement1'
|
||||
conn.prepare('statement1', 'SELECT $1::int')
|
||||
conn.exec_prepared('statement1', [1])
|
||||
conn.describe_prepared('statement1')
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -7,11 +7,11 @@ describe "Sharding" do
|
||||
|
||||
before do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
# Setup the sharding data
|
||||
3.times do |i|
|
||||
conn.exec("SET SHARD TO '#{i}'")
|
||||
|
||||
conn.exec("DELETE FROM data WHERE id > 0") rescue nil
|
||||
conn.exec("DELETE FROM data WHERE id > 0")
|
||||
end
|
||||
|
||||
18.times do |i|
|
||||
@@ -19,11 +19,10 @@ describe "Sharding" do
|
||||
conn.exec("SET SHARDING KEY TO '#{i}'")
|
||||
conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')")
|
||||
end
|
||||
|
||||
conn.close
|
||||
end
|
||||
|
||||
after do
|
||||
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
@@ -49,148 +48,4 @@ describe "Sharding" do
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "no_shard_specified_behavior config" do
|
||||
context "when default shard number is invalid" do
|
||||
it "prevents config reload" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
current_configs = processes.pgcat.current_config
|
||||
current_configs["pools"]["sharded_db"]["default_shard"] = "shard_99"
|
||||
|
||||
processes.pgcat.update_config(current_configs)
|
||||
|
||||
expect { processes.pgcat.reload_config }.to raise_error(ConfigReloadFailed, /Invalid shard 99/)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "comment-based routing" do
|
||||
context "when no configs are set" do
|
||||
it "routes queries with a shard_id comment to the default shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
10.times { conn.async_exec("/* shard_id: 2 */ SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([10, 0, 0])
|
||||
end
|
||||
|
||||
it "does not honor no_shard_specified_behavior directives" do
|
||||
end
|
||||
end
|
||||
|
||||
[
|
||||
["shard_id_regex", "/\\* the_shard_id: (\\d+) \\*/", "/* the_shard_id: 1 */"],
|
||||
["sharding_key_regex", "/\\* the_sharding_key: (\\d+) \\*/", "/* the_sharding_key: 3 */"],
|
||||
].each do |config_name, config_value, comment_to_use|
|
||||
context "when #{config_name} config is set" do
|
||||
let(:no_shard_specified_behavior) { nil }
|
||||
|
||||
before do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
current_configs = processes.pgcat.current_config
|
||||
current_configs["pools"]["sharded_db"][config_name] = config_value
|
||||
if no_shard_specified_behavior
|
||||
current_configs["pools"]["sharded_db"]["default_shard"] = no_shard_specified_behavior
|
||||
else
|
||||
current_configs["pools"]["sharded_db"].delete("default_shard")
|
||||
end
|
||||
|
||||
processes.pgcat.update_config(current_configs)
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
|
||||
it "routes queries with a shard_id comment to the correct shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to random" do
|
||||
let(:no_shard_specified_behavior) { "random" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to random shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to random_healthy" do
|
||||
let(:no_shard_specified_behavior) { "random_healthy" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to random healthy shard" do
|
||||
|
||||
good_databases = [processes.all_databases[0], processes.all_databases[2]]
|
||||
bad_database = processes.all_databases[1]
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
250.times { conn.async_exec("SELECT 99") }
|
||||
bad_database.take_down do
|
||||
250.times do
|
||||
conn.async_exec("SELECT 99")
|
||||
rescue PG::ConnectionBad => e
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
end
|
||||
end
|
||||
|
||||
# Routes traffic away from bad shard
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
expect(good_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
|
||||
expect(bad_database.count_select_1_plus_2).to eq(0)
|
||||
|
||||
# Routes traffic to the bad shard if the shard_id is specified
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
bad_database = processes.all_databases[1]
|
||||
expect(bad_database.count_select_1_plus_2).to eq(25)
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "when no_shard_specified_behavior config is set to shard_x" do
|
||||
let(:no_shard_specified_behavior) { "shard_2" }
|
||||
|
||||
context "with no shard comment" do
|
||||
it "sends queries to the specified shard" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 0, 25])
|
||||
end
|
||||
end
|
||||
|
||||
context "with a shard comment" do
|
||||
it "honors the comment" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
|
||||
|
||||
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -1,369 +0,0 @@
|
||||
# frozen_string_literal: true
|
||||
require 'open3'
|
||||
require_relative 'spec_helper'
|
||||
|
||||
describe "Stats" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
let(:pgcat_conn_str) { processes.pgcat.connection_string("sharded_db", "sharding_user") }
|
||||
|
||||
after do
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "SHOW STATS" do
|
||||
context "clients connect and make one query" do
|
||||
it "updates *_query_time and *_wait_time" do
|
||||
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
|
||||
end
|
||||
sleep(1)
|
||||
connections.map(&:close)
|
||||
|
||||
# wait for averages to be calculated, we shouldn't do this too often
|
||||
sleep(15.5)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW STATS")[0]
|
||||
admin_conn.close
|
||||
expect(results["total_query_time"].to_i).to be_within(200).of(750)
|
||||
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
|
||||
|
||||
expect(results["total_wait_time"].to_i).to_not eq(0)
|
||||
expect(results["avg_wait_time"].to_i).to_not eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "bad database name" do
|
||||
it "does not change any stats" do
|
||||
bad_db_url = URI(pgcat_conn_str)
|
||||
bad_db_url.path = "/wrong_db"
|
||||
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connect and make one query" do
|
||||
it "only affects cl_idle, sv_idle stats" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_active"]).to eq("5")
|
||||
expect(results["sv_active"]).to eq("5")
|
||||
|
||||
sleep(3)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects and opens a transaction and closes connection uncleanly" do
|
||||
it "produces correct statistics" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new do
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT pg_sleep(0.01)")
|
||||
c.close
|
||||
end
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client fail to checkout connection from the pool" do
|
||||
it "counts clients as idle" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
threads = []
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
|
||||
end
|
||||
|
||||
sleep(2)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect normally" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") rescue nil }
|
||||
end
|
||||
clients_between = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).not_to eq(clients_between)
|
||||
connections.each(&:close)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect abruptly" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
threads = []
|
||||
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
random_string = (0...8).map { (65 + rand(26)).chr }.join
|
||||
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
|
||||
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
|
||||
sleep(1)
|
||||
# psql starts two processes, we only know the pid of the parent, this
|
||||
# ensure both are killed
|
||||
`pkill -9 -f '#{random_string}'`
|
||||
Process.wait(faulty_client)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients overwhelm server pools" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it "cl_waiting is updated to show it" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["cl_waiting"]).to eq("2")
|
||||
expect(results["cl_active"]).to eq("2")
|
||||
expect(results["sv_active"]).to eq("2")
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("4")
|
||||
expect(results["sv_idle"]).to eq("2")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
|
||||
it "show correct max_wait" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
|
||||
end
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
|
||||
expect(results["maxwait"]).to eq("1")
|
||||
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
|
||||
connections.map(&:close)
|
||||
|
||||
sleep(4.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
expect(results["maxwait"]).to eq("0")
|
||||
|
||||
threads.map(&:join)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW CLIENTS" do
|
||||
it "reports correct number and application names" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(21) # count admin clients
|
||||
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
|
||||
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
|
||||
|
||||
connections[0..5].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(15)
|
||||
|
||||
connections[6..].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
|
||||
admin_conn.close
|
||||
end
|
||||
|
||||
it "reports correct number of queries and transactions" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
|
||||
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
|
||||
connections.each do |c|
|
||||
c.async_exec("SELECT 1")
|
||||
c.async_exec("SELECT 2")
|
||||
c.async_exec("SELECT 3")
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT 4")
|
||||
c.async_exec("SELECT 5")
|
||||
c.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(3)
|
||||
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
|
||||
expect(normal_client_results[0]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[1]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[0]["query_count"]).to eq("7")
|
||||
expect(normal_client_results[1]["query_count"]).to eq("7")
|
||||
|
||||
admin_conn.close
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
describe "Query Storm" do
|
||||
context "when the proxy receives overwhelmingly large number of short quick queries" do
|
||||
it "should not have lingering clients or active servers" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
Array.new(40) do
|
||||
Thread.new do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SELECT pg_sleep(0.1)")
|
||||
rescue PG::SystemError
|
||||
ensure
|
||||
conn.close
|
||||
end
|
||||
end.each(&:join)
|
||||
|
||||
sleep 1
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
admin_conn.close
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
1
tests/rust/.gitignore
vendored
1
tests/rust/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
target/
|
||||
1322
tests/rust/Cargo.lock
generated
1322
tests/rust/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,10 +0,0 @@
|
||||
[package]
|
||||
name = "rust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
sqlx = { version = "0.6.2", features = [ "runtime-tokio-rustls", "postgres", "json", "tls", "migrate", "time", "uuid", "ipnetwork"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
@@ -1,29 +0,0 @@
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
test_prepared_statements().await;
|
||||
}
|
||||
|
||||
async fn test_prepared_statements() {
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for _ in 0..5 {
|
||||
let pool = pool.clone();
|
||||
let handle = tokio::task::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Build an Ubuntu deb.
|
||||
#
|
||||
script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
deb_dir="/tmp/pgcat-build"
|
||||
export PACKAGE_VERSION=${1:-"1.1.1"}
|
||||
if [[ $(arch) == "x86_64" ]]; then
|
||||
export ARCH=amd64
|
||||
else
|
||||
export ARCH=arm64
|
||||
fi
|
||||
|
||||
cd "$script_dir/.."
|
||||
cargo build --release
|
||||
|
||||
rm -rf "$deb_dir"
|
||||
mkdir -p "$deb_dir/DEBIAN"
|
||||
mkdir -p "$deb_dir/usr/bin"
|
||||
mkdir -p "$deb_dir/etc/systemd/system"
|
||||
|
||||
cp target/release/pgcat "$deb_dir/usr/bin/pgcat"
|
||||
chmod +x "$deb_dir/usr/bin/pgcat"
|
||||
|
||||
cp pgcat.toml "$deb_dir/etc/pgcat.toml"
|
||||
cp pgcat.service "$deb_dir/etc/systemd/system/pgcat.service"
|
||||
|
||||
(cat control | envsubst) > "$deb_dir/DEBIAN/control"
|
||||
cp postinst "$deb_dir/DEBIAN/postinst"
|
||||
cp postrm "$deb_dir/DEBIAN/postrm"
|
||||
cp prerm "$deb_dir/DEBIAN/prerm"
|
||||
|
||||
chmod +x ${deb_dir}/DEBIAN/post*
|
||||
chmod +x ${deb_dir}/DEBIAN/pre*
|
||||
|
||||
dpkg-deb \
|
||||
--root-owner-group \
|
||||
-z1 \
|
||||
--build "$deb_dir" \
|
||||
pgcat-${PACKAGE_VERSION}-ubuntu22.04-${ARCH}.deb
|
||||
Reference in New Issue
Block a user