Compare commits

..

27 Commits

Author SHA1 Message Date
Lev Kokotov
3ca28a62c4 Dont accept empty passwords 2023-03-30 18:09:01 -07:00
Lev Kokotov
b65c1ddd56 readme 2023-03-30 17:36:49 -07:00
Lev Kokotov
ed31053cdb Fix spec 2023-03-30 17:35:32 -07:00
Lev Kokotov
4969abf355 Hmm 2023-03-30 15:29:10 -07:00
Lev Kokotov
112c0bdae8 Rebased 2023-03-30 15:19:52 -07:00
Lev Kokotov
fef737ea43 fmt 2023-03-30 14:16:50 -07:00
Lev Kokotov
345ee88342 Warn when secrets are too short 2023-03-30 14:16:38 -07:00
Lev Kokotov
db3d6c3baa Some tests 2023-03-30 14:16:36 -07:00
Lev Kokotov
197c32b4e8 Readme 2023-03-30 14:15:30 -07:00
Lev Kokotov
6345c39bd5 fix ci config 2023-03-30 14:15:07 -07:00
Lev Kokotov
32b913af94 update admin 2023-03-30 14:15:07 -07:00
Lev Kokotov
5c673b4333 Zero-downtime password rotation 2023-03-30 14:15:05 -07:00
Jose Fernández
6f768a84ce Auth passthrough (auth_query) (#266)
* Add a new exec_simple_query method

This adds a new `exec_simple_query` method so we can make 'out of band'
queries to servers that don't interfere with pools at all.
In order to reuse startup code for making these simple queries,
we need to set the stats (`Reporter`) optional, so using these
simple queries wont interfere with stats.

* Add auth passthough (auth_query)

Adds a feature that allows setting auth passthrough for md5 auth.

It adds 3 new (general and pool) config parameters:

- `auth_query`: An string containing a query that will be executed on boot
to obtain the hash of a given user. This query have to use a placeholder `$1`,
so pgcat can replace it with the user its trying to fetch the hash from.
- `auth_query_user`: The user to use for connecting to the server and executing the
auth_query.
- `auth_query_password`: The password to use for connecting to the server and executing the
auth_query.

The configuration can be done either on the general config (so pools share them) or in a per-pool basis.

The behavior is, at boot time, when validating server connections, a hash is fetched per server
and stored in the pool. When new server connections are created, and no cleartext password is specified,
the obtained hash is used for creating them, if the hash could not be obtained for whatever reason, it retries
it.

When client authentication is tried, it uses cleartext passwords if specified, it not, it checks whether
we have query_auth set up, if so, it tries to use the obtained hash for making client auth. If there is no
hash (we could not obtain one when validating the connection), a new fetch is tried.

Once we have a hash, we authenticate using it against whathever the client has sent us, if there is a failure
we refetch the hash and retry auth (so password changes can be done).

The idea with this 'retrial' mechanism is to make it fault tolerant, so if for whatever reason hash could not be
obtained during connection validation, or the password has change, we can still connect later.

* Add documentation for Auth passthrough
2023-03-30 13:29:23 -07:00
dependabot[bot]
0757d7f3a0 chore(deps): bump serde from 1.0.158 to 1.0.159 (#386)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.158 to 1.0.159.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.158...v1.0.159)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-28 09:54:39 -07:00
dependabot[bot]
568f04feee chore(deps): bump serde_derive from 1.0.154 to 1.0.159 (#387)
Bumps [serde_derive](https://github.com/serde-rs/serde) from 1.0.154 to 1.0.159.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.154...v1.0.159)

---
updated-dependencies:
- dependency-name: serde_derive
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-28 09:54:31 -07:00
Jose Fernández
58ce76d9b9 Refactor stats to use atomics (#375)
* Refactor stats to use atomics

When we are dealing with a high number of connections, generated
stats cannot be consumed fast enough by the stats collector loop.
This makes the stats subsystem inconsistent and a log of
warning messages are thrown due to unregistered server/clients.

This change refactors the stats subsystem so it uses atomics:

- Now counters are handled using U64 atomics
- Event system is dropped and averages are calculated using a loop
  every 15 seconds.
- Now, instead of snapshots being generated ever second we keep track of servers/clients
  that have registered. Each pool/server/client has its own instance of the counter and
  makes changes directly, instead of adding an event that gets processed later.

* Manually mplement Hash/Eq in `config::Address` ignoring stats

* Add tests for client connection counters

* Allow connecting to dockerized dev pgcat from the host

* stats: Decrease cl_idle when idle socket disconnects
2023-03-28 17:19:37 +02:00
dependabot[bot]
9a2076a9eb chore(deps): bump futures from 0.3.26 to 0.3.27 (#356)
Bumps [futures](https://github.com/rust-lang/futures-rs) from 0.3.26 to 0.3.27.
- [Release notes](https://github.com/rust-lang/futures-rs/releases)
- [Changelog](https://github.com/rust-lang/futures-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/futures-rs/compare/0.3.26...0.3.27)

---
updated-dependencies:
- dependency-name: futures
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:45 -07:00
dependabot[bot]
e7e7118725 chore(deps): bump hyper from 0.14.24 to 0.14.25 (#358)
Bumps [hyper](https://github.com/hyperium/hyper) from 0.14.24 to 0.14.25.
- [Release notes](https://github.com/hyperium/hyper/releases)
- [Changelog](https://github.com/hyperium/hyper/blob/v0.14.25/CHANGELOG.md)
- [Commits](https://github.com/hyperium/hyper/compare/v0.14.24...v0.14.25)

---
updated-dependencies:
- dependency-name: hyper
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:36 -07:00
dependabot[bot]
99f790cacf chore(deps): bump toml from 0.7.2 to 0.7.3 (#360)
Bumps [toml](https://github.com/toml-rs/toml) from 0.7.2 to 0.7.3.
- [Release notes](https://github.com/toml-rs/toml/releases)
- [Commits](https://github.com/toml-rs/toml/compare/toml-v0.7.2...toml-v0.7.3)

---
updated-dependencies:
- dependency-name: toml
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:25 -07:00
dependabot[bot]
434b0bb69e chore(deps): bump serde from 1.0.154 to 1.0.158 (#376)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.154 to 1.0.158.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.154...v1.0.158)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:13 -07:00
dependabot[bot]
714e043ef0 chore(deps): bump async-trait from 0.1.66 to 0.1.68 (#382)
Bumps [async-trait](https://github.com/dtolnay/async-trait) from 0.1.66 to 0.1.68.
- [Release notes](https://github.com/dtolnay/async-trait/releases)
- [Commits](https://github.com/dtolnay/async-trait/compare/0.1.66...0.1.68)

---
updated-dependencies:
- dependency-name: async-trait
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:12:54 -07:00
dependabot[bot]
863104aadd chore(deps): bump regex from 1.7.1 to 1.7.3 (#385)
Bumps [regex](https://github.com/rust-lang/regex) from 1.7.1 to 1.7.3.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.7.1...1.7.3)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:12:43 -07:00
Lev Kokotov
7dd96141e3 Update README.md 2023-03-26 00:33:05 -07:00
Lev Kokotov
0d5feac4b2 Contributors (#384) 2023-03-24 17:12:12 -07:00
Lev Kokotov
90aba9c011 V1 (#383) 2023-03-24 17:10:12 -07:00
Montana Low
0f34b49503 point CI at updated repo 2023-03-24 12:59:03 -07:00
Zain Kabani
ca4431b67e Add idle client in transaction configuration (#380)
* Add idle client in transaction configuration

* fmt

* Update docs

* trigger build

* Add tests

* Make the config dynamic from reloads

* fmt

* comments

* trigger build

* fix config.md

* remove error
2023-03-24 08:20:30 -07:00
42 changed files with 3643 additions and 2299 deletions

View File

@@ -46,6 +46,14 @@ jobs:
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
- image: postgres:14
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements"]
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
# Add steps to the job
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
steps:

View File

@@ -19,6 +19,7 @@ PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/q
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 7432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 8432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 9432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 10432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i

View File

@@ -15,6 +15,6 @@ jobs:
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build CI Docker image
run: |
docker build . -f Dockerfile.ci --tag ghcr.io/levkk/pgcat-ci:latest
docker run ghcr.io/levkk/pgcat-ci:latest
docker push ghcr.io/levkk/pgcat-ci:latest
docker build . -f Dockerfile.ci --tag ghcr.io/postgresml/pgcat-ci:latest
docker run ghcr.io/postgresml/pgcat-ci:latest
docker push ghcr.io/postgresml/pgcat-ci:latest

2
.rustfmt.toml Normal file
View File

@@ -0,0 +1,2 @@
edition = "2021"
hard_tabs = false

View File

@@ -49,6 +49,14 @@ default: 30000 # milliseconds
How long an idle connection with a server is left open (ms).
### idle_client_in_transaction_timeout
```
path: general.idle_client_in_transaction_timeout
default: 0 # milliseconds
```
How long a client is allowed to be idle while in a transaction (ms).
### healthcheck_timeout
```
path: general.healthcheck_timeout
@@ -167,11 +175,41 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT
### admin_password
```
path: general.admin_password
default: "admin_pass"
default: <UNSET>
```
Password to access the virtual administrative database
### auth_query (experimental)
```
path: general.auth_query
default: <UNSET>
```
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user (experimental)
```
path: general.auth_query_user
default: <UNSET>
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
## `pools.<pool_name>` Section
### pool_mode
@@ -273,6 +311,30 @@ default: 3000
Connect timeout can be overwritten in the pool
### auth_query (experimental)
```
path: general.auth_query
default: <UNSET>
```
Auth query can be overwritten in the pool
### auth_query_user (experimental)
```
path: general.auth_query_user
default: <UNSET>
```
Auth query user can be overwritten in the pool
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Auth query password can be overwritten in the pool
## `pools.<pool_name>.users.<user_index>` Section
### username

161
Cargo.lock generated
View File

@@ -28,13 +28,24 @@ checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "async-trait"
version = "0.1.66"
version = "0.1.68"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b84f9ebcc6c1f5b8cb160f6990096a5c127f423fcb6e1ccc46c370cbdfb75dfc"
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.9",
]
[[package]]
name = "atomic_enum"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6227a8d6fdb862bcb100c4314d0d9579e5cd73fa6df31a2e6f6e1acd3c5f1207"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
@@ -43,6 +54,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.0"
@@ -83,6 +100,12 @@ version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "1.4.0"
@@ -175,7 +198,7 @@ dependencies = [
"proc-macro2",
"quote",
"scratch",
"syn",
"syn 1.0.109",
]
[[package]]
@@ -192,7 +215,7 @@ checksum = "086c685979a698443656e5cf7856c95c642295a38599f12fb1ff76fb28d19892"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@@ -246,6 +269,12 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193"
[[package]]
name = "fallible-iterator"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fnv"
version = "1.0.7"
@@ -254,9 +283,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "futures"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84"
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549"
dependencies = [
"futures-channel",
"futures-core",
@@ -269,9 +298,9 @@ dependencies = [
[[package]]
name = "futures-channel"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5"
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
dependencies = [
"futures-core",
"futures-sink",
@@ -279,15 +308,15 @@ dependencies = [
[[package]]
name = "futures-core"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608"
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
[[package]]
name = "futures-executor"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e"
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83"
dependencies = [
"futures-core",
"futures-task",
@@ -296,38 +325,38 @@ dependencies = [
[[package]]
name = "futures-io"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531"
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
[[package]]
name = "futures-macro"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70"
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
name = "futures-sink"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364"
checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2"
[[package]]
name = "futures-task"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366"
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
[[package]]
name = "futures-util"
version = "0.3.26"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1"
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
dependencies = [
"futures-channel",
"futures-core",
@@ -453,9 +482,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
version = "0.14.24"
version = "0.14.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c"
checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
dependencies = [
"bytes",
"futures-channel",
@@ -716,16 +745,18 @@ dependencies = [
[[package]]
name = "pgcat"
version = "0.6.0-alpha1"
version = "1.0.0"
dependencies = [
"arc-swap",
"async-trait",
"base64",
"atomic_enum",
"base64 0.21.0",
"bb8",
"bytes",
"chrono",
"env_logger",
"exitcode",
"fallible-iterator",
"futures",
"hmac",
"hyper",
@@ -737,6 +768,7 @@ dependencies = [
"once_cell",
"parking_lot",
"phf",
"postgres-protocol",
"rand",
"regex",
"rustls-pemfile",
@@ -782,7 +814,7 @@ dependencies = [
"phf_shared",
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@@ -806,6 +838,24 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "postgres-protocol"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c"
dependencies = [
"base64 0.13.1",
"byteorder",
"bytes",
"fallible-iterator",
"hmac",
"md-5",
"memchr",
"rand",
"sha2",
"stringprep",
]
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@@ -814,18 +864,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.51"
version = "1.0.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.23"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [
"proc-macro2",
]
@@ -871,9 +921,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.7.1"
version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d"
dependencies = [
"aho-corasick",
"memchr",
@@ -882,9 +932,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.6.28"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "ring"
@@ -933,7 +983,7 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [
"base64",
"base64 0.21.0",
]
[[package]]
@@ -960,19 +1010,19 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.154"
version = "1.0.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cdd151213925e7f1ab45a9bbfb129316bd00799784b174b7cc7bcd16961c49e"
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065"
[[package]]
name = "serde_derive"
version = "1.0.154"
version = "1.0.159"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fc80d722935453bcafdc2c9a73cd6fac4dc1938f0346035d84bf99fa9e33217"
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.9",
]
[[package]]
@@ -1094,6 +1144,17 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0da4a3c17e109f700685ec577c0f85efd9b19bcf15c913985f14dc1ac01775aa"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "termcolor"
version = "1.2.0"
@@ -1157,7 +1218,7 @@ checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
]
[[package]]
@@ -1187,9 +1248,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7afcae9e3f0fe2c370fd4657108972cbb2fa9db1b9f84849cefd80741b01cb6"
checksum = "b403acf6f2bb0859c93c7f0d967cb4a75a7ac552100f9322faf64dc047669b21"
dependencies = [
"serde",
"serde_spanned",
@@ -1208,9 +1269,9 @@ dependencies = [
[[package]]
name = "toml_edit"
version = "0.19.4"
version = "0.19.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a1eb0622d28f4b9c90adc4ea4b2b46b47663fde9ac5fafcb14a1369d5508825"
checksum = "08de71aa0d6e348f070457f85af8bd566e2bc452156a423ddf22861b3a953fae"
dependencies = [
"indexmap",
"serde",
@@ -1339,7 +1400,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
"wasm-bindgen-shared",
]
@@ -1361,7 +1422,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 1.0.109",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]

View File

@@ -1,10 +1,9 @@
[package]
name = "pgcat"
version = "0.6.0-alpha1"
version = "1.0.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
bytes = "1"
@@ -37,6 +36,9 @@ exitcode = "1.1.2"
futures = "0.3"
socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.4"
fallible-iterator = "0.2"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -1,4 +1,4 @@
Copyright (c) 2022 Lev Kokotov <lev@levthe.dev>
Copyright (c) 2023 PgCat Contributors
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the

400
README.md
View File

@@ -1,33 +1,49 @@
##### PgCat: PostgreSQL at petabyte scale
## PgCat: Nextgen PostgreSQL Pooler
[![CircleCI](https://circleci.com/gh/levkk/pgcat/tree/main.svg?style=svg)](https://circleci.com/gh/levkk/pgcat/tree/main)
[![CircleCI](https://circleci.com/gh/postgresml/pgcat/tree/main.svg?style=svg)](https://circleci.com/gh/postgresml/pgcat/tree/main)
<a href="https://discord.gg/DmyJP3qJ7U" target="_blank">
<img src="https://img.shields.io/discord/1013868243036930099" alt="Join our Discord!" />
</a>
PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover support.
**Beta**: looking for beta testers, see [#35](https://github.com/levkk/pgcat/issues/35).
PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load balancing, failover and mirroring.
## Features
| **Feature** | **Status** | **Comments** |
|--------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
| Load balancing of read queries | :white_check_mark: | Using random between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
| Client authentication | :white_check_mark: :wrench: | MD5 password authentication is supported, SCRAM is on the roadmap; one user is used to connect to Postgres with both SCRAM and MD5 supported. |
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
| **Feature** | **Status** | **Comments** |
|-------------|------------|--------------|
| Transaction pooling | **Stable** | Identical to PgBouncer with notable improvements for handling bad clients and abandoned transactions. |
| Session pooling | **Stable** | Identical to PgBouncer. |
| Multi-threaded runtime | **Stable** | Using Tokio asynchronous runtime, the pooler takes advantage of multicore machines. |
| Load balancing of read queries | **Stable** | Queries are automatically load balanced between replicas and the primary. |
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
| Password rotation | **Experimental** | Allows to rotate passwords without downtime or using third-party tools to manage Postgres authentication. |
## Status
PgCat is stable and used in production to serve hundreds of thousands of queries per second. Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
| | |
|-|-|
|<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f"><img src="./images/instacart.webp" height="70" width="auto"></a>|<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second"><img src="./images/postgresml.webp" height="70" width="auto"></a>|
| [Instacart](https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f) | [PostgresML](https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second) |
## Deployment
See `Dockerfile` for example deployment using Docker. The pooler is configured to spawn 4 workers so 4 CPUs are recommended for optimal performance. That setting can be adjusted to spawn as many (or as little) workers as needed.
A Docker image is available from `docker pull ghcr.io/postgresml/pgcat:latest`. See our [Github packages repository](https://github.com/postgresml/pgcat/pkgs/container/pgcat).
For quick local example, use the Docker Compose environment provided:
```bash
@@ -39,9 +55,13 @@ PGPASSWORD=postgres psql -h 127.0.0.1 -p 6432 -U postgres -c 'SELECT 1'
### Config
See [Configurations page](https://github.com/levkk/pgcat/blob/main/CONFIG.md)
See **[Configuration](https://github.com/levkk/pgcat/blob/main/CONFIG.md)**.
## Local development
## Contributing
The project is being actively developed and looking for additional contributors and production deployments.
### Local development
1. Install Rust (latest stable will work great).
2. `cargo build --release` (to get better benchmarks).
@@ -51,7 +71,7 @@ See [Configurations page](https://github.com/levkk/pgcat/blob/main/CONFIG.md)
### Tests
Quickest way to test your changes is to use pgbench:
When making substantial modifications to the protocol implementation, make sure to test them with pgbench:
```
pgbench -i -h 127.0.0.1 -p 6432 && \
@@ -61,36 +81,26 @@ pgbench -t 1000 -p 6432 -h 127.0.0.1 --protocol extended
See [sharding README](./tests/sharding/README.md) for sharding logic testing.
Run `cargo test` to run Rust tests.
Additionally, all features are tested with Ruby, Python, and Rust unit and integration tests.
Run `cargo test` to run Rust unit tests.
Run the following commands to run Ruby and Python integration tests:
Run the following commands to run Integration tests locally.
```
cd tests/docker/
docker compose up --exit-code-from main # This will also produce coverage report under ./cov/
```
| **Feature** | **Tested in CI** | **Tested manually** | **Comments** |
|-----------------------|--------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------|
| Transaction pooling | :white_check_mark: | :white_check_mark: | Used by default for all tests. |
| Session pooling | :white_check_mark: | :white_check_mark: | Tested by running pgbench with `--protocol prepared` which only works in session mode. |
| `COPY` | :white_check_mark: | :white_check_mark: | `pgbench -i` uses `COPY`. `COPY FROM` is tested as well. |
| Query cancellation | :white_check_mark: | :white_check_mark: | `psql -c 'SELECT pg_sleep(1000);'` and press `Ctrl-C`. |
| Load balancing | :white_check_mark: | :white_check_mark: | We could test this by emitting statistics for each replica and compare them. |
| Failover | :white_check_mark: | :white_check_mark: | Misconfigure a replica in `pgcat.toml` and watch it forward queries to spares. CI testing is using Toxiproxy. |
| Sharding | :white_check_mark: | :white_check_mark: | See `tests/sharding` and `tests/ruby` for an Rails/ActiveRecord example. |
| Statistics | :white_check_mark: | :white_check_mark: | Query the admin database with `psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS'`. |
| Live config reloading | :white_check_mark: | :white_check_mark: | Run `kill -s SIGHUP $(pgrep pgcat)` and watch the config reload. |
### Docker-based local development
### Dev
Also, you can open a 'dev' environment where you can debug tests easier by running the following command:
You can open a Docker development environment where you can debug tests easier. Run the following command to spin it up:
```
./dev/script/console
```
This will open a terminal in an environment similar to that used in tests. In there you can compile, run tests, do some debugging with the test environment, etc. Objects
compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have in your host.
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
## Usage
@@ -105,11 +115,9 @@ In transaction mode, a client talks to one server for the duration of a single t
This mode is enabled by default.
### Load balancing of read queries
All queries are load balanced against the configured servers using the random algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
All queries are load balanced against the configured servers using either the random or least open connections algorithms. The most straightforward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
The query parser is disabled by default.
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser, implemented with the `sqlparser` crate, will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
#### Query parser
The query parser will do its best to determine where the query should go, but sometimes that's not possible. In that case, the client can select which server it wants using this custom SQL syntax:
@@ -136,38 +144,14 @@ The setting will persist until it's changed again or the client disconnects.
By default, all queries are routed to the first available server; `default_role` setting controls this behavior.
### Failover
All servers are checked with a `SELECT 1` query before being given to a client. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the ban list is cleared: this is a safety precaution against false positives. The primary can never be banned.
All servers are checked with a `;` (very fast) query before being given to a client. Additionally, the server health is monitored with every client query that it processes. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the ban list is cleared: this is a safety precaution against false positives. The primary can never be banned.
The ban time can be changed with `ban_time`. The default is 60 seconds.
Failover behavior can get pretty interesting (read complex) when multiple configurations and factors are involved. The table below will try to explain what PgCat does in each scenario:
| **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** |
|---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the random loop. |
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the random loop. |
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the random loop. |
| Read query | replica | false | false | up | Query is routed to the first replica instance in the random loop. |
| Read query | primary | false | false | up | Query is routed to the primary. |
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the random loop is attempted. |
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the random loop. |
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. |
| | | | | | |
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the random loop. If the instance is a replica, the query fails and the client receives an error. |
| Write query | unset (any) | true | false | up | The query is routed to the primary. |
| Write query | unset (any) | true | true | up | The query is routed to the primary. |
| Write query | primary | false | false | up | The query is routed to the primary. |
| Write query | replica | false | false | up | The query is routed to the replica and fails. The client receives an error. |
| Write query | unset (any) | true | false | down | The query is routed to the primary and fails. The client receives an error. |
| Write query | unset (any) | true | true | down | The query is routed to the primary and fails. The client receives an error. |
| Write query | primary | false | false | down | The query is routed to the primary and fails. The client receives an error. |
| | | | | | |
### Sharding
We use the `PARTITION BY HASH` hashing function, the same as used by Postgres for declarative partitioning. This allows to shard the database using Postgres partitions and place the partitions on different servers (shards). Both read and write queries can be routed to the shards using this pooler.
#### Extended syntax
To route queries to a particular shard, we use this custom SQL syntax:
```sql
@@ -182,7 +166,8 @@ The active shard will last until it's changed again or the client disconnects. B
For hash function implementation, see `src/sharding.rs` and `tests/sharding/partition_hash_test_setup.sql`.
#### ActiveRecord/Rails
##### ActiveRecord/Rails
```ruby
class User < ActiveRecord::Base
@@ -210,7 +195,7 @@ User.connection.execute "SET SERVER ROLE TO 'auto'"
User.find_by_email("test@example.com")
```
#### Raw SQL
##### Raw SQL
```sql
-- Grab a bunch of users from shard 1
@@ -230,268 +215,51 @@ SET SERVER ROLE TO 'auto'; -- let the query router figure out where the query sh
SELECT * FROM users WHERE email = 'test@example.com'; -- shard setting lasts until set again; we are reading from the primary
```
#### With comments
Issuing queries to the pooler can cause additional latency. To reduce its impact, it's possible to include sharding information inside SQL comments sent via the query. This is reasonably easy to implement with ORMs like [ActiveRecord](https://api.rubyonrails.org/classes/ActiveRecord/QueryMethods.html#method-i-annotate) and [SQLAlchemy](https://docs.sqlalchemy.org/en/20/core/events.html#sql-execution-and-connection-events).
```
/* shard_id: 5 */ SELECT * FROM foo WHERE id = 1234;
/* sharding_key: 1234 */ SELECT * FROM foo WHERE id = 1234;
```
#### Automatic query parsing
PgCat can use the `sqlparser` crate to parse SQL queries and extract the sharding key. This is configurable with the `automatic_sharding_key` setting. This feature is still experimental, but it's the ideal implementation for sharding, requiring no client modifications.
### Statistics reporting
The stats are very similar to what Pgbouncer reports and the names are kept to be comparable. They are accessible by querying the admin database `pgcat`, and `pgbouncer` for compatibility.
The stats are very similar to what PgBouncer reports and the names are kept to be comparable. They are accessible by querying the admin database `pgcat`, and `pgbouncer` for compatibility.
```
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES'
```
Additionally, Prometheus statistics are available at `/metrics` via HTTP.
### Live configuration reloading
The config can be reloaded by sending a `kill -s SIGHUP` to the process or by querying `RELOAD` to the admin database. Not all settings are currently supported by live reload:
The config can be reloaded by sending a `kill -s SIGHUP` to the process or by querying `RELOAD` to the admin database. All settings except the `host` and `port` can be reloaded without restarting the pooler, including sharding and replicas configurations.
| **Config** | **Requires restart** |
|-------------------------|----------------------|
| `host` | yes |
| `port` | yes |
| `pool_mode` | no |
| `connect_timeout` | yes |
| `healthcheck_timeout` | no |
| `shutdown_timeout` | no |
| `healthcheck_delay` | no |
| `ban_time` | no |
| `user` | yes |
| `shards` | yes |
| `default_role` | no |
| `primary_reads_enabled` | no |
| `query_parser_enabled` | no |
### Mirroring
Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic.
## Benchmarks
### Password rotation
You can setup PgBench locally through PgCat:
Password rotation allows to specify multiple passwords for a user, so they can connect to PgCat with multiple credentials. This allows distributed applications to change their configuration (connection strings) gradually and for PgCat to monitor their progression in admin statistics. Once the new secret is deployed everywhere, the old one can be removed from PgCat.
```
pgbench -h 127.0.0.1 -p 6432 -i
```
This also decouples server passwords from client passwords, allowing to change one without necessarily changing the other.
Coincidenly, this uses `COPY` so you can test if that works. Additionally, we'll be running the following PgBench configurations:
## License
1. 16 clients, 2 threads
2. 32 clients, 2 threads
3. 64 clients, 2 threads
4. 128 clients, 2 threads
PgCat is free and open source, released under the MIT license.
All queries will be `SELECT` only (`-S`) just so disks don't get in the way, since the dataset will be effectively all in RAM.
## Contributors
My setup:
Many thanks to our amazing contributors!
- 8 cores, 16 hyperthreaded (AMD Ryzen 5800X)
- 32GB RAM (doesn't matter for this benchmark, except to prove that Postgres will fit the whole dataset into RAM)
<a href = "https://github.com/postgresml/pgcat/graphs/contributors">
<img src = "https://contrib.rocks/image?repo=postgresml/pgcat"/>
</a>
### PgBouncer
#### Config
```ini
[databases]
shard0 = host=localhost port=5432 user=sharding_user password=sharding_user
[pgbouncer]
pool_mode = transaction
max_client_conn = 1000
```
Everything else stays default.
#### Runs
```
$ pgbench -t 1000 -c 16 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 16
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 16000/16000
latency average = 0.155 ms
tps = 103417.377469 (including connections establishing)
tps = 103510.639935 (excluding connections establishing)
$ pgbench -t 1000 -c 32 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 32
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 32000/32000
latency average = 0.290 ms
tps = 110325.939785 (including connections establishing)
tps = 110386.513435 (excluding connections establishing)
$ pgbench -t 1000 -c 64 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 64
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 64000/64000
latency average = 0.692 ms
tps = 92470.427412 (including connections establishing)
tps = 92618.389350 (excluding connections establishing)
$ pgbench -t 1000 -c 128 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 128
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 128000/128000
latency average = 1.406 ms
tps = 91013.429985 (including connections establishing)
tps = 91067.583928 (excluding connections establishing)
```
### PgCat
#### Config
The only thing that matters here is the number of workers in the Tokio pool. Make sure to set it to < than the number of your CPU cores.
Also account for hyper-threading, so if you have that, take the number you got above and divide it by two, that way only "real" cores serving
requests.
My setup is 16 threads, 8 cores (`htop` shows as 16 CPUs), so I set the `max_workers` in Tokio to 4. Too many, and it starts conflicting with PgBench
which is also running on the same system.
#### Runs
```
$ pgbench -t 1000 -c 16 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 16
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 16000/16000
latency average = 0.164 ms
tps = 97705.088232 (including connections establishing)
tps = 97872.216045 (excluding connections establishing)
$ pgbench -t 1000 -c 32 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 32
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 32000/32000
latency average = 0.288 ms
tps = 111300.488119 (including connections establishing)
tps = 111413.107800 (excluding connections establishing)
$ pgbench -t 1000 -c 64 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 64
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 64000/64000
latency average = 0.556 ms
tps = 115190.496139 (including connections establishing)
tps = 115247.521295 (excluding connections establishing)
$ pgbench -t 1000 -c 128 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 128
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 128000/128000
latency average = 1.135 ms
tps = 112770.562239 (including connections establishing)
tps = 112796.502381 (excluding connections establishing)
```
### Direct Postgres
Always good to have a base line.
#### Runs
```
$ pgbench -t 1000 -c 16 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
Password:
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 16
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 16000/16000
latency average = 0.115 ms
tps = 139443.955722 (including connections establishing)
tps = 142314.859075 (excluding connections establishing)
$ pgbench -t 1000 -c 32 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
Password:
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 32
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 32000/32000
latency average = 0.212 ms
tps = 150644.840891 (including connections establishing)
tps = 152218.499430 (excluding connections establishing)
$ pgbench -t 1000 -c 64 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
Password:
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 64
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 64000/64000
latency average = 0.420 ms
tps = 152517.663404 (including connections establishing)
tps = 153319.188482 (excluding connections establishing)
$ pgbench -t 1000 -c 128 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
Password:
starting vacuum...end.
transaction type: <builtin: select only>
scaling factor: 1
query mode: extended
number of clients: 128
number of threads: 2
number of transactions per client: 1000
number of transactions actually processed: 128000/128000
latency average = 0.854 ms
tps = 149818.594087 (including connections establishing)
tps = 150200.603049 (excluding connections establishing)
```

View File

@@ -26,6 +26,8 @@ x-common-env-pg:
services:
main:
image: kubernetes/pause
ports:
- 6432
pg1:
<<: *common-definition-pg
@@ -56,6 +58,13 @@ services:
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
PGPORT: 9432
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
pg5:
<<: *common-definition-pg
environment:
<<: *common-env-pg
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
PGPORT: 10432
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
toxiproxy:
build: .
@@ -69,6 +78,7 @@ services:
- pg2
- pg3
- pg4
- pg5
pgcat-shell:
stdin_open: true

BIN
images/instacart.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

BIN
images/postgresml.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

View File

@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
# How long an idle connection with a server is left open (ms).
idle_timeout = 30000 # milliseconds
# How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout = 0 # milliseconds
# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000 # milliseconds
@@ -55,9 +58,9 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections
# tls_certificate = "server.cert"
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key"
# tls_private_key = ".circleci/server.key"
# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -119,6 +122,10 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool
connect_timeout = 3000
# auth_query = "SELECT * FROM public.user_lookup('$1')"
# auth_query_user = "postgres"
# auth_query_password = "postgres"
# User configs are structured as pool.<pool_name>.users.<user_index>
# This secion holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
@@ -126,6 +133,10 @@ connect_timeout = 3000
username = "sharding_user"
# Postgresql password
password = "sharding_user"
# # Passwords the client can use to connect. Useful for password rotations.
# secrets = [ "secret_one", "secret_two" ]
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.

View File

@@ -1,10 +1,11 @@
use crate::pool::BanReason;
/// Admin database.
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
/// Admin database.
use std::sync::atomic::Ordering;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::Instant;
@@ -12,9 +13,7 @@ use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{
get_address_stats, get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState,
};
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
use crate::ClientServerMap;
pub fn generate_server_info_for_admin() -> BytesMut {
@@ -158,7 +157,14 @@ where
"free_clients".to_string(),
client_stats
.keys()
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Idle)
.filter(|client_id| {
client_stats
.get(client_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ClientState::Idle
})
.count()
.to_string(),
]));
@@ -166,7 +172,14 @@ where
"used_clients".to_string(),
client_stats
.keys()
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Active)
.filter(|client_id| {
client_stats
.get(client_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ClientState::Active
})
.count()
.to_string(),
]));
@@ -178,7 +191,14 @@ where
"free_servers".to_string(),
server_stats
.keys()
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Idle)
.filter(|server_id| {
server_stats
.get(server_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ServerState::Idle
})
.count()
.to_string(),
]));
@@ -186,7 +206,14 @@ where
"used_servers".to_string(),
server_stats
.keys()
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Active)
.filter(|server_id| {
server_stats
.get(server_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ServerState::Active
})
.count()
.to_string(),
]));
@@ -232,6 +259,7 @@ where
let columns = vec![
("database", DataType::Text),
("user", DataType::Text),
("secret", DataType::Text),
("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric),
@@ -248,28 +276,16 @@ where
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (user_pool, pool) in get_all_pools() {
let def = HashMap::default();
let pool_stats = all_pool_stats
.get(&(user_pool.db.clone(), user_pool.user.clone()))
.unwrap_or(&def);
let pool_config = &pool.settings;
for (_, pool_stats) in all_pool_stats {
let mut row = vec![
user_pool.db.clone(),
user_pool.user.clone(),
pool_config.pool_mode.to_string(),
pool_stats.database(),
pool_stats.user(),
pool_stats.redacted_secret(),
pool_stats.pool_mode().to_string(),
];
for column in &columns[3..columns.len()] {
let value = match column.0 {
"maxwait" => (pool_stats.get("maxwait_us").unwrap_or(&0) / 1_000_000).to_string(),
"maxwait_us" => {
(pool_stats.get("maxwait_us").unwrap_or(&0) % 1_000_000).to_string()
}
_other_values => pool_stats.get(column.0).unwrap_or(&0).to_string(),
};
row.push(value);
}
pool_stats.populate_row(&mut row);
pool_stats.clear_maxwait();
res.put(data_row(&row));
}
@@ -400,7 +416,7 @@ where
for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if !pool.is_banned(&address) {
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
pool.ban(&address, BanReason::AdminBan(duration_seconds), None);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
@@ -617,7 +633,6 @@ where
("avg_wait_time", DataType::Numeric),
];
let all_stats = get_address_stats();
let mut res = BytesMut::new();
res.put(row_description(&columns));
@@ -625,15 +640,10 @@ where
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match all_stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
let mut row = vec![address.name(), user_pool.db.clone(), user_pool.user.clone()];
for column in &columns[3..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string());
}
let stats = address.stats.clone();
stats.populate_row(&mut row);
res.put(data_row(&row));
}
@@ -673,16 +683,16 @@ where
for (_, client) in new_map {
let row = vec![
format!("{:#010X}", client.client_id),
client.pool_name,
client.username,
client.application_name.clone(),
client.state.to_string(),
client.transaction_count.to_string(),
client.query_count.to_string(),
client.error_count.to_string(),
format!("{:#010X}", client.client_id()),
client.pool_name(),
client.username(),
client.application_name(),
client.state.load(Ordering::Relaxed).to_string(),
client.transaction_count.load(Ordering::Relaxed).to_string(),
client.query_count.load(Ordering::Relaxed).to_string(),
client.error_count.load(Ordering::Relaxed).to_string(),
Instant::now()
.duration_since(client.connect_time)
.duration_since(client.connect_time())
.as_secs()
.to_string(),
];
@@ -724,19 +734,20 @@ where
res.put(row_description(&columns));
for (_, server) in new_map {
let application_name = server.application_name.read();
let row = vec![
format!("{:#010X}", server.server_id),
server.pool_name,
server.username,
server.address_name,
server.application_name,
server.state.to_string(),
server.transaction_count.to_string(),
server.query_count.to_string(),
server.bytes_sent.to_string(),
server.bytes_received.to_string(),
format!("{:#010X}", server.server_id()),
server.pool_name(),
server.username(),
server.address_name(),
application_name.clone(),
server.state.load(Ordering::Relaxed).to_string(),
server.transaction_count.load(Ordering::Relaxed).to_string(),
server.query_count.load(Ordering::Relaxed).to_string(),
server.bytes_sent.load(Ordering::Relaxed).to_string(),
server.bytes_received.load(Ordering::Relaxed).to_string(),
Instant::now()
.duration_since(server.connect_time)
.duration_since(server.connect_time())
.as_secs()
.to_string(),
];
@@ -771,7 +782,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
match get_pool(database, user, None) {
Some(pool) => {
pool.pause();
@@ -818,7 +829,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
match get_pool(database, user, None) {
Some(pool) => {
pool.resume();
@@ -886,13 +897,20 @@ where
res.put(row_description(&vec![
("name", DataType::Text),
("pool_mode", DataType::Text),
("secret", DataType::Text),
]));
for (user_pool, pool) in get_all_pools() {
let pool_config = &pool.settings;
let redacted_secret = match user_pool.secret {
Some(secret) => format!("****{}", &secret[secret.len() - 4..]),
None => "<no secret>".to_string(),
};
res.put(data_row(&vec![
user_pool.user.clone(),
pool_config.pool_mode.to_string(),
redacted_secret,
]));
}

452
src/auth.rs Normal file
View File

@@ -0,0 +1,452 @@
//! Module implementing various client authentication mechanisms.
//!
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
use crate::errors::Error;
use crate::tokio::io::AsyncReadExt;
use crate::{
auth_passthrough::AuthPassthrough,
config::get_config,
messages::{
error_response, md5_hash_password, md5_hash_second_pass, write_all, wrong_password,
},
pool::{get_pool, ConnectionPool},
};
use bytes::{BufMut, BytesMut};
use log::debug;
async fn refetch_auth_hash<S>(
pool: &ConnectionPool,
stream: &mut S,
username: &str,
pool_name: &str,
) -> Result<String, Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let config = get_config();
debug!("Fetching auth hash");
if config.is_auth_query_configured() {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
debug!("Auth query succeeded");
return Ok(hash);
}
} else {
debug!("Auth query not configured on pool");
}
error_response(
stream,
&format!(
"No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username
),
)
.await?;
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
pool_name, username
)))
}
/// Read 'p' message from client.
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
return Err(Error::SocketError(
"Error reading password code from client".to_string(),
))
}
};
if code as char != 'p' {
return Err(Error::SocketError(format!("Expected p, got {}", code)));
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::SocketError(
"Error reading password length from client".to_string(),
))
}
};
let mut response = vec![0; (len - 4) as usize];
// Too short to be a password (null-terminated)
if response.len() < 2 {
return Err(Error::ClientError(format!("Password response too short")));
}
match stream.read_exact(&mut response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::SocketError(
"Error reading password from client".to_string(),
))
}
};
Ok(response.to_vec())
}
/// Make sure the pool we authenticated to has at least one server connection
/// that can serve our request.
async fn validate_pool<W>(
stream: &mut W,
mut pool: ConnectionPool,
username: &str,
pool_name: &str,
) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
if !pool.validated() {
match pool.validate().await {
Ok(_) => Ok(()),
Err(err) => {
error_response(
stream,
&format!("Pool down for database: {}, user: {}", pool_name, username,),
)
.await?;
Err(Error::ClientError(format!("Pool down: {:?}", err)))
}
}
} else {
Ok(())
}
}
/// Clear text authentication.
///
/// The client will send the password in plain text over the wire.
/// To protect against obvious security issues, this is only used over TLS.
///
/// Clear text authentication is used to support zero-downtime password rotation.
/// It allows the client to use multiple passwords when talking to the PgCat
/// while the password is being rotated across multiple app instances.
pub struct ClearText {
username: String,
pool_name: String,
application_name: String,
}
impl ClearText {
/// Create a new ClearText authentication mechanism.
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
ClearText {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
}
}
/// Issue 'R' clear text challenge to client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
debug!("Sending plain challenge");
let mut msg = BytesMut::new();
msg.put_u8(b'R');
msg.put_i32(8);
msg.put_i32(3); // Clear text
write_all(stream, msg).await
}
/// Authenticate client with server password or secret.
pub async fn authenticate<R, W>(
&self,
read: &mut R,
write: &mut W,
) -> Result<Option<String>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let response = response(read).await?;
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
None => match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match pool.settings.user.password {
Some(ref password) => {
if password != &secret {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None)
}
}
None => {
// Server is storing hashes, we can't query it for the plain text password.
error_response(
write,
&format!(
"No server password configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"No server password configured for {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
},
Some(pool) => {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(Some(secret))
}
}
}
}
/// MD5 hash authentication.
///
/// Deprecated, but widely used everywhere, and currently required for poolers
/// to authencticate clients without involving Postgres.
///
/// Admin clients are required to use MD5.
pub struct Md5 {
username: String,
pool_name: String,
application_name: String,
salt: [u8; 4],
admin: bool,
}
impl Md5 {
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
Md5 {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
salt,
admin,
}
}
/// Issue a 'R' MD5 challenge to the client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&self.salt[..]);
write_all(stream, res).await
}
/// Authenticate client with MD5. This is used for both admin and normal users.
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let password_hash = response(read).await?;
if self.admin {
let config = get_config();
// Compare server and client hashes.
let our_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&self.salt,
);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
Ok(())
}
} else {
match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match &pool.settings.user.password {
Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(())
}
}
None => {
if !get_config().is_auth_query_configured() {
error_response(
write,
&format!(
"No password configured and auth_query is not set: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"No password configured and auth_query is not set"
)));
}
debug!("Using auth_query");
// Fetch hash from server
let hash = (*pool.auth_hash.read()).clone();
let hash = match hash {
Some(hash) => {
debug!("Using existing hash: {}", hash);
hash.clone()
}
None => {
debug!("Pool has no hash set, fetching new one");
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
(*pool.auth_hash.write()) = Some(hash.clone());
hash
}
};
let our_hash = md5_hash_second_pass(&hash, &self.salt);
// Compare hashes
if our_hash != password_hash {
debug!("Pool auth query hash did not match, refetching");
// Server hash maybe changed
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash {
debug!("Auth query failed, passwords don't match");
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
(*pool.auth_hash.write()) = Some(hash);
validate_pool(
write,
pool.clone(),
&self.username,
&self.pool_name,
)
.await?;
Ok(())
}
} else {
validate_pool(write, pool.clone(), &self.username, &self.pool_name)
.await?;
Ok(())
}
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)));
}
}
}
}
}

110
src/auth_passthrough.rs Normal file
View File

@@ -0,0 +1,110 @@
use crate::errors::Error;
use crate::server::Server;
use log::debug;
#[derive(Clone, Debug)]
pub struct AuthPassthrough {
password: String,
query: String,
user: String,
}
impl AuthPassthrough {
/// Initializes an AuthPassthrough.
pub fn new(query: &str, user: &str, password: &str) -> Self {
AuthPassthrough {
password: password.to_string(),
query: query.to_string(),
user: user.to_string(),
}
}
/// Returns an AuthPassthrough given the pool configuration.
/// If any of required values is not set, None is returned.
pub fn from_pool_config(pool_config: &crate::config::Pool) -> Option<Self> {
if pool_config.is_auth_query_configured() {
return Some(AuthPassthrough::new(
pool_config.auth_query.as_ref().unwrap(),
pool_config.auth_query_user.as_ref().unwrap(),
pool_config.auth_query_password.as_ref().unwrap(),
));
}
None
}
/// Returns an AuthPassthrough given the pool settings.
/// If any of required values is not set, None is returned.
pub fn from_pool_settings(pool_settings: &crate::pool::PoolSettings) -> Option<Self> {
let pool_config = crate::config::Pool {
auth_query: pool_settings.auth_query.clone(),
auth_query_password: pool_settings.auth_query_password.clone(),
auth_query_user: pool_settings.auth_query_user.clone(),
..Default::default()
};
AuthPassthrough::from_pool_config(&pool_config)
}
/// Connects to server and executes auth_query for the specified address.
/// If the response is a row with two columns containing the username set in the address.
/// and its MD5 hash, the MD5 hash returned.
///
/// Note that the query is executed, changing $1 with the name of the user
/// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data.
/// Also, it is compatible with pgbouncer.
///
/// # Arguments
///
/// * `address` - An Address of the server we want to connect to. The username for the hash will be obtained from this value.
///
/// # Examples
///
/// ```
/// use pgcat::auth_passthrough::AuthPassthrough;
/// use pgcat::config::Address;
/// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres");
/// auth_passthrough.fetch_hash(&Address::default());
/// ```
///
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
let auth_user = crate::config::User {
username: self.user.clone(),
password: Some(self.password.clone()),
pool_size: 1,
statement_timeout: 0,
secrets: None,
};
let user = &address.username;
debug!("Connecting to server to obtain auth hashes.");
let auth_query = self.query.replace("$1", user);
match Server::exec_simple_query(address, &auth_user, &auth_query).await {
Ok(password_data) => {
if password_data.len() == 2 && password_data.first().unwrap() == user {
if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") {
Ok(stripped_hash.to_string())
} else {
Err(Error::AuthPassthroughError(
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
))
}
} else {
Err(Error::AuthPassthroughError(
"Data obtained from query does not follow the scheme 'user','hash'."
.to_string(),
))
}
}
Err(err) => {
Err(Error::AuthPassthroughError(
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
user, err)))
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
/// Parse the configuration file.
use arc_swap::ArcSwap;
use log::{error, info};
use log::{error, info, warn};
use once_cell::sync::Lazy;
use regex::Regex;
use serde_derive::{Deserialize, Serialize};
@@ -15,6 +15,7 @@ use tokio::io::AsyncReadExt;
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction;
use crate::stats::AddressStats;
use crate::tls::{load_certs, load_keys};
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -62,7 +63,7 @@ impl PartialEq<Role> for Option<Role> {
}
/// Address identifying a PostgreSQL server uniquely.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
#[derive(Clone, Debug)]
pub struct Address {
/// Unique ID per addressable Postgres server.
pub id: usize,
@@ -96,6 +97,9 @@ pub struct Address {
/// List of addresses to receive mirrored traffic.
pub mirrors: Vec<Address>,
/// Address stats
pub stats: Arc<AddressStats>,
}
impl Default for Address {
@@ -112,10 +116,46 @@ impl Default for Address {
username: String::from("username"),
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
stats: Arc::new(AddressStats::default()),
}
}
}
// We need to implement PartialEq by ourselves so we skip stats in the comparison
impl PartialEq for Address {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
&& self.host == other.host
&& self.port == other.port
&& self.shard == other.shard
&& self.address_index == other.address_index
&& self.replica_number == other.replica_number
&& self.database == other.database
&& self.role == other.role
&& self.username == other.username
&& self.pool_name == other.pool_name
&& self.mirrors == other.mirrors
}
}
impl Eq for Address {}
// We need to implement Hash by ourselves so we skip stats in the comparison
impl Hash for Address {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.host.hash(state);
self.port.hash(state);
self.shard.hash(state);
self.address_index.hash(state);
self.replica_number.hash(state);
self.database.hash(state);
self.role.hash(state);
self.username.hash(state);
self.pool_name.hash(state);
self.mirrors.hash(state);
}
}
impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String {
@@ -137,19 +177,40 @@ impl Address {
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)]
pub struct User {
pub username: String,
pub password: String,
pub password: Option<String>,
pub pool_size: u32,
#[serde(default)] // 0
pub statement_timeout: u64,
pub secrets: Option<Vec<String>>,
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.secrets {
Some(ref secrets) => {
for secret in secrets.iter() {
if secret.len() < 16 {
warn!(
"[user: {}] Secret is too short (less than 16 characters)",
self.username
);
}
}
}
None => (),
}
Ok(())
}
}
impl Default for User {
fn default() -> User {
User {
username: String::from("postgres"),
password: String::new(),
password: None,
pool_size: 15,
statement_timeout: 0,
secrets: None,
}
}
}
@@ -197,6 +258,9 @@ pub struct General {
#[serde(default = "General::default_ban_time")]
pub ban_time: i64,
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
pub idle_client_in_transaction_timeout: u64,
#[serde(default = "General::default_worker_threads")]
pub worker_threads: usize,
@@ -207,6 +271,10 @@ pub struct General {
pub tls_private_key: Option<String>,
pub admin_username: String,
pub admin_password: String,
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
}
impl General {
@@ -260,6 +328,10 @@ impl General {
pub fn default_worker_threads() -> usize {
4
}
pub fn default_idle_client_in_transaction_timeout() -> u64 {
0
}
}
impl Default for General {
@@ -276,6 +348,7 @@ impl Default for General {
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
worker_threads: Self::default_worker_threads(),
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
@@ -286,6 +359,9 @@ impl Default for General {
tls_private_key: None,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
}
}
}
@@ -358,6 +434,10 @@ pub struct Pool {
pub shard_id_regex: Option<String>,
pub regex_search_limit: Option<usize>,
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
// Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it
@@ -372,6 +452,12 @@ impl Pool {
s.finish()
}
pub fn is_auth_query_configured(&self) -> bool {
self.auth_query_password.is_some()
&& self.auth_query_user.is_some()
&& self.auth_query_password.is_some()
}
pub fn default_pool_mode() -> PoolMode {
PoolMode::Transaction
}
@@ -443,6 +529,10 @@ impl Pool {
None => None,
};
for user in self.users.iter() {
user.1.validate()?;
}
Ok(())
}
}
@@ -464,6 +554,9 @@ impl Default for Pool {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
}
}
}
@@ -564,9 +657,36 @@ pub struct Config {
}
impl Config {
pub fn is_auth_query_configured(&self) -> bool {
self.pools
.iter()
.any(|(_name, pool)| pool.is_auth_query_configured())
}
pub fn default_path() -> String {
String::from("pgcat.toml")
}
pub fn fill_up_auth_query_config(&mut self) {
for (_name, pool) in self.pools.iter_mut() {
if pool.auth_query.is_none() {
pool.auth_query = self.general.auth_query.clone();
}
if pool.auth_query_user.is_none() {
pool.auth_query_user = self.general.auth_query_user.clone();
}
if pool.auth_query_password.is_none() {
pool.auth_query_password = self.general.auth_query_password.clone();
}
}
}
/// Checks that we configured TLS.
pub fn tls_enabled(&self) -> bool {
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
}
}
impl Default for Config {
@@ -655,6 +775,13 @@ impl From<&Config> for std::collections::HashMap<String, String> {
config.general.healthcheck_delay.to_string(),
),
("ban_time".to_string(), config.general.ban_time.to_string()),
(
"idle_client_in_transaction_timeout".to_string(),
config
.general
.idle_client_in_transaction_timeout
.to_string(),
),
];
r.append(&mut static_settings);
@@ -666,6 +793,10 @@ impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Ban time: {}s", self.general.ban_time);
info!(
"Idle client in transaction timeout: {}ms",
self.general.idle_client_in_transaction_timeout
);
info!("Worker threads: {}", self.general.worker_threads);
info!(
"Healthcheck timeout: {}ms",
@@ -773,6 +904,35 @@ impl Config {
}
pub fn validate(&mut self) -> Result<(), Error> {
// Validation for auth_query feature
if self.general.auth_query.is_some()
&& (self.general.auth_query_user.is_none()
|| self.general.auth_query_password.is_none())
{
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
return Err(Error::BadConfig);
}
for (name, pool) in self.pools.iter() {
if pool.auth_query.is_some()
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
{
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
return Err(Error::BadConfig);
}
for (_name, user_data) in pool.users.iter() {
if (pool.auth_query.is_none()
|| pool.auth_query_password.is_none()
|| pool.auth_query_user.is_none())
&& user_data.password.is_none()
{
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
return Err(Error::BadConfig);
}
}
}
// Validate TLS!
match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
@@ -819,6 +979,12 @@ pub fn get_config() -> Config {
(*(*CONFIG.load())).clone()
}
pub fn get_idle_client_in_transaction_timeout() -> u64 {
(*(*CONFIG.load()))
.general
.idle_client_in_transaction_timeout
}
/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();
@@ -846,6 +1012,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
config.fill_up_auth_query_config();
config.validate()?;
config.path = path.to_string();
@@ -889,6 +1056,7 @@ mod test {
assert_eq!(get_config().path, "pgcat.toml".to_string());
assert_eq!(get_config().general.ban_time, 60);
assert_eq!(get_config().general.idle_client_in_transaction_timeout, 0);
assert_eq!(get_config().general.idle_timeout, 30000);
assert_eq!(get_config().pools.len(), 2);
assert_eq!(get_config().pools["sharded_db"].shards.len(), 3);
@@ -914,7 +1082,10 @@ mod test {
"sharding_user"
);
assert_eq!(
get_config().pools["sharded_db"].users["1"].password,
get_config().pools["sharded_db"].users["1"]
.password
.as_ref()
.unwrap(),
"other_user"
);
assert_eq!(get_config().pools["sharded_db"].users["1"].pool_size, 21);
@@ -939,10 +1110,16 @@ mod test {
"simple_user"
);
assert_eq!(
get_config().pools["simple_db"].users["0"].password,
get_config().pools["simple_db"].users["0"]
.password
.as_ref()
.unwrap(),
"simple_user"
);
assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5);
assert_eq!(get_config().general.auth_query, None);
assert_eq!(get_config().general.auth_query_user, None);
assert_eq!(get_config().general.auth_query_password, None);
}
#[tokio::test]

View File

@@ -15,4 +15,6 @@ pub enum Error {
StatementTimeout,
ShuttingDown,
ParseBytesError(String),
AuthError(String),
AuthPassthroughError(String),
}

View File

@@ -1,3 +1,4 @@
pub mod auth_passthrough;
pub mod config;
pub mod constants;
pub mod errors;

View File

@@ -61,6 +61,8 @@ use std::sync::Arc;
use tokio::sync::broadcast;
mod admin;
mod auth;
mod auth_passthrough;
mod client;
mod config;
mod constants;
@@ -162,8 +164,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Statistics reporting.
let (stats_tx, stats_rx) = mpsc::channel(500_000);
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
REPORTER.store(Arc::new(Reporter::default()));
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
@@ -175,7 +176,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};
tokio::task::spawn(async move {
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
let mut stats_collector = Collector::default();
stats_collector.collect().await;
});

View File

@@ -1,7 +1,7 @@
/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut};
use log::error;
use log::{debug, error};
use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -46,29 +46,6 @@ where
write_all(stream, auth_ok).await
}
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated
/// used in query cancellation.
pub async fn backend_key_data<S>(
@@ -213,7 +190,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5_hash_second_pass(&(format!("{:x}", output)), salt)
}
pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// Second pass
md5.update(hash);
md5.update(salt);
let mut password = format!("md5{:x}", md5.finalize())
@@ -247,6 +230,22 @@ where
write_all(stream, message).await
}
pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
debug!("Sending hash {} to server", hash);
let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]);
write_all(stream, message).await
}
/// Implements a response to our custom `SET SHARDING KEY`
/// and `SET SERVER ROLE` commands.
/// This tells the client we're ready for the next query.

View File

@@ -1,11 +1,14 @@
use std::sync::Arc;
/// A mirrored PostgreSQL client.
/// Packets arrive to us through a channel from the main client and we send them to the server.
use bb8::Pool;
use bytes::{Bytes, BytesMut};
use parking_lot::RwLock;
use crate::config::{get_config, Address, Role, User};
use crate::pool::{ClientServerMap, ServerPool};
use crate::stats::get_reporter;
use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool};
use crate::stats::PoolStats;
use log::{error, info, trace, warn};
use tokio::sync::mpsc::{channel, Receiver, Sender};
@@ -21,20 +24,25 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout) = match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
),
None => (default, default),
};
let (connection_timeout, idle_timeout, cfg) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
cfg.clone(),
),
None => (default, default, crate::config::Pool::default()),
};
let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
let manager = ServerPool::new(
self.address.clone(),
self.user.clone(),
self.database.as_str(),
ClientServerMap::default(),
get_reporter(),
Arc::new(PoolStats::new(identifier, cfg.clone())),
Arc::new(RwLock::new(None)),
);
Pool::builder()

View File

@@ -20,9 +20,10 @@ use tokio::sync::Notify;
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{get_reporter, Reporter};
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
pub type ProcessId = i32;
pub type SecretKey = i32;
@@ -51,31 +52,29 @@ pub enum BanReason {
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
pub struct PoolIdentifier {
// The name of the database clients want to connect to.
pub db: String,
/// The username the client connects with. Each user gets its own pool.
pub user: String,
/// The client secret (password).
pub secret: Option<String>,
}
impl PoolIdentifier {
/// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier {
pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
PoolIdentifier {
db: db.to_string(),
user: user.to_string(),
secret,
}
}
}
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
@@ -123,6 +122,11 @@ pub struct PoolSettings {
// Limit how much of each query is searched for a potential shard regex match
pub regex_search_limit: usize,
// Auth query parameters
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
}
impl Default for PoolSettings {
@@ -143,6 +147,9 @@ impl Default for PoolSettings {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: 1000,
auth_query: None,
auth_query_user: None,
auth_query_password: None,
}
}
}
@@ -161,10 +168,6 @@ pub struct ConnectionPool {
/// that should not be queried.
banlist: BanList,
/// The statistics aggregator runs in a separate task
/// and receives stats from clients, servers, and the pool.
stats: Reporter,
/// The server information (K messages) have to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
@@ -185,6 +188,11 @@ pub struct ConnectionPool {
/// If the pool has been paused or not.
paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>,
pub stats: Arc<PoolStats>,
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
}
impl ConnectionPool {
@@ -200,186 +208,241 @@ impl ConnectionPool {
// There is one pool per database/user pair.
for user in pool_config.users.values() {
let old_pool_ref = get_pool(pool_name, &user.username);
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(
PoolIdentifier::new(pool_name, &user.username),
pool.clone(),
);
continue;
}
}
None => (),
}
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
let pool = ConnectionPool {
databases: shards,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
let mut secrets = match &user.secrets {
Some(_) => user
.secrets
.as_ref()
.unwrap()
.iter()
.map(|secret| Some(secret.to_string()))
.collect::<Vec<Option<String>>>(),
None => vec![],
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
secrets.push(None);
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
for secret in secrets {
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(
connect_timeout,
))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
}
}
}
@@ -387,7 +450,8 @@ impl ConnectionPool {
Ok(())
}
/// Connect to all shards and grab server information.
/// Connect to all shards, grab server information, and possibly
/// passwords to use in client auth.
/// Return server information we will pass to the clients
/// when they connect.
/// This also warms up the pool for clients that connect when
@@ -476,9 +540,9 @@ impl ConnectionPool {
/// Get a connection from the pool.
pub async fn get(
&self,
shard: usize, // shard number
role: Option<Role>, // primary or replica
client_process_id: i32, // client id
shard: usize, // shard number
role: Option<Role>, // primary or replica
client_stats: &ClientStats, // client id
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let mut candidates: Vec<&Address> = self.addresses[shard]
.iter()
@@ -517,7 +581,7 @@ impl ConnectionPool {
// Indicate we're waiting on a server connection from a pool.
let now = Instant::now();
self.stats.client_waiting(client_process_id);
client_stats.waiting();
// Check if we can connect
let mut conn = match self.databases[address.shard][address.address_index]
@@ -527,9 +591,10 @@ impl ConnectionPool {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
address.stats.error();
client_stats.idle();
client_stats.checkout_error();
continue;
}
};
@@ -546,18 +611,18 @@ impl ConnectionPool {
// since we last checked the server is ok.
// Health checks are pretty expensive.
if !require_healthcheck {
self.stats.checkout_time(
now.elapsed().as_micros(),
client_process_id,
server.server_id(),
);
self.stats
.server_active(client_process_id, server.server_id());
let checkout_time: u64 = now.elapsed().as_micros() as u64;
client_stats.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_stats.application_name());
server.stats().active(client_stats.application_name());
return Ok((conn, address.clone()));
}
if self
.run_health_check(address, server, now, client_process_id)
.run_health_check(address, server, now, client_stats)
.await
{
return Ok((conn, address.clone()));
@@ -565,7 +630,6 @@ impl ConnectionPool {
continue;
}
}
Err(Error::AllServersDown)
}
@@ -574,11 +638,11 @@ impl ConnectionPool {
address: &Address,
server: &mut Server,
start: Instant,
client_process_id: i32,
client_info: &ClientStats,
) -> bool {
debug!("Running health check on server {:?}", address);
self.stats.server_tested(server.server_id());
server.stats().tested();
match tokio::time::timeout(
tokio::time::Duration::from_millis(self.settings.healthcheck_timeout),
@@ -589,13 +653,13 @@ impl ConnectionPool {
// Check if health check succeeded.
Ok(res) => match res {
Ok(_) => {
self.stats.checkout_time(
start.elapsed().as_micros(),
client_process_id,
server.server_id(),
);
self.stats
.server_active(client_process_id, server.server_id());
let checkout_time: u64 = start.elapsed().as_micros() as u64;
client_info.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_info.application_name());
server.stats().active(client_info.application_name());
return true;
}
@@ -620,14 +684,14 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
return false;
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
// Primary can never be banned
if address.role == Role::Primary {
return;
@@ -636,7 +700,10 @@ impl ConnectionPool {
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
if let Some(client_info) = client_info {
client_info.ban_error();
address.stats.error();
}
guard[address.shard].insert(address.clone(), (reason, now));
}
@@ -797,7 +864,8 @@ pub struct ServerPool {
user: User,
database: String,
client_server_map: ClientServerMap,
stats: Reporter,
stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>,
}
impl ServerPool {
@@ -806,14 +874,16 @@ impl ServerPool {
user: User,
database: &str,
client_server_map: ClientServerMap,
stats: Reporter,
stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>,
) -> ServerPool {
ServerPool {
address,
user,
user: user.clone(),
database: database.to_string(),
client_server_map,
stats,
auth_hash,
}
}
}
@@ -826,34 +896,32 @@ impl ManageConnection for ServerPool {
/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
info!("Creating a new server connection {:?}", self.address);
let server_id = rand::random::<i32>();
self.stats.server_register(
server_id,
self.address.id,
self.address.name(),
self.address.pool_name.clone(),
self.address.username.clone(),
);
self.stats.server_login(server_id);
let stats = Arc::new(ServerStats::new(
self.address.clone(),
self.stats.clone(),
tokio::time::Instant::now(),
));
stats.register(stats.clone());
// Connect to the PostgreSQL server.
match Server::startup(
server_id,
&self.address,
&self.user,
&self.database,
self.client_server_map.clone(),
self.stats.clone(),
stats.clone(),
self.auth_hash.clone(),
)
.await
{
Ok(conn) => {
self.stats.server_idle(server_id);
stats.idle();
Ok(conn)
}
Err(err) => {
self.stats.server_disconnecting(server_id);
stats.disconnect();
Err(err)
}
}
@@ -871,21 +939,13 @@ impl ManageConnection for ServerPool {
}
/// Get the connection pool
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
(*(*POOLS.load()))
.get(&PoolIdentifier::new(db, user))
.cloned()
pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
let identifier = PoolIdentifier::new(db, user, secret);
(*(*POOLS.load())).get(&identifier).cloned()
}
/// Get a pointer to all configured pools.
pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
(*(*POOLS.load())).clone()
}
/// How many total servers we have in the config.
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
.map(|(_, pool)| pool.databases())
.sum()
}

View File

@@ -5,10 +5,12 @@ use phf::phf_map;
use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use crate::config::Address;
use crate::pool::get_all_pools;
use crate::stats::{get_address_stats, get_pool_stats, get_server_stats, ServerInformation};
use crate::pool::{get_all_pools, PoolIdentifier};
use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
struct MetricHelpType {
help: &'static str,
@@ -220,7 +222,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("servers_{}", name), value, labels)
}
fn from_address(address: &Address, name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
fn from_address(address: &Address, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new();
labels.insert("host", address.host.clone());
labels.insert("shard", address.shard.to_string());
@@ -231,10 +233,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("stats_{}", name), value, labels)
}
fn from_pool(pool: &(String, String), name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
fn from_pool(pool: &PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new();
labels.insert("pool", pool.0.clone());
labels.insert("user", pool.1.clone());
labels.insert("pool", pool.db.clone());
labels.insert("user", pool.user.clone());
Self::from_name(&format!("pools_{}", name), value, labels)
}
@@ -261,20 +263,18 @@ async fn prometheus_stats(request: Request<Body>) -> Result<Response<Body>, hype
// Adds metrics shown in a SHOW STATS admin command.
fn push_address_stats(lines: &mut Vec<String>) {
let address_stats: HashMap<usize, HashMap<String, i64>> = get_address_stats();
for (_, pool) in get_all_pools() {
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
if let Some(address_stats) = address_stats.get(&address.id) {
for (key, value) in address_stats.iter() {
if let Some(prometheus_metric) =
PrometheusMetric::<i64>::from_address(address, key, *value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
}
let stats = &*address.stats;
for (key, value) in stats.clone() {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_address(address, &key, value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
}
}
}
@@ -286,14 +286,15 @@ fn push_address_stats(lines: &mut Vec<String>) {
fn push_pool_stats(lines: &mut Vec<String>) {
let pool_stats = get_pool_stats();
for (pool, stats) in pool_stats.iter() {
for (name, value) in stats.iter() {
if let Some(prometheus_metric) = PrometheusMetric::<i64>::from_pool(pool, name, *value)
let stats = &**stats;
for (name, value) in stats.clone() {
if let Some(prometheus_metric) = PrometheusMetric::<u64>::from_pool(pool, &name, value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!(
"Metric {} not implemented for ({},{})",
name, pool.0, pool.1
name, pool.db, pool.user
);
}
}
@@ -330,9 +331,9 @@ fn push_database_stats(lines: &mut Vec<String>) {
// Adds relevant metrics shown in a SHOW SERVERS admin command.
fn push_server_stats(lines: &mut Vec<String>) {
let server_stats = get_server_stats();
let mut server_stats_by_addresses = HashMap::<String, ServerInformation>::new();
for (_, info) in server_stats {
server_stats_by_addresses.insert(info.address_name.clone(), info);
let mut server_stats_by_addresses = HashMap::<String, Arc<ServerStats>>::new();
for (_, stats) in server_stats {
server_stats_by_addresses.insert(stats.address_name(), stats);
}
for (_, pool) in get_all_pools() {
@@ -341,11 +342,23 @@ fn push_server_stats(lines: &mut Vec<String>) {
let address = pool.address(shard, server);
if let Some(server_info) = server_stats_by_addresses.get(&address.name()) {
let metrics = [
("bytes_received", server_info.bytes_received),
("bytes_sent", server_info.bytes_sent),
("transaction_count", server_info.transaction_count),
("query_count", server_info.query_count),
("error_count", server_info.error_count),
(
"bytes_received",
server_info.bytes_received.load(Ordering::Relaxed),
),
("bytes_sent", server_info.bytes_sent.load(Ordering::Relaxed)),
(
"transaction_count",
server_info.transaction_count.load(Ordering::Relaxed),
),
(
"query_count",
server_info.query_count.load(Ordering::Relaxed),
),
(
"error_count",
server_info.error_count.load(Ordering::Relaxed),
),
];
for (key, value) in metrics {
if let Some(prometheus_metric) =

View File

@@ -1110,6 +1110,9 @@ mod test {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: 1000,
auth_query: None,
auth_query_password: None,
auth_query_user: None,
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -1171,6 +1174,9 @@ mod test {
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
regex_search_limit: 1000,
auth_query: None,
auth_query_password: None,
auth_query_user: None,
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());

View File

@@ -1,8 +1,13 @@
/// Implementation of the PostgreSQL server (database) protocol.
/// Here we are pretending to the a Postgres client.
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
@@ -17,12 +22,10 @@ use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::Reporter;
use crate::stats::ServerStats;
/// Server state.
pub struct Server {
server_id: i32,
/// Server host, e.g. localhost,
/// port, e.g. 5432, and role, e.g. primary or replica.
address: Address,
@@ -62,7 +65,7 @@ pub struct Server {
connected_at: chrono::naive::NaiveDateTime,
/// Reports various metrics, e.g. data sent & received.
stats: Reporter,
stats: Arc<ServerStats>,
/// Application name using the server at the moment.
application_name: String,
@@ -77,12 +80,12 @@ impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready for query state.
pub async fn startup(
server_id: i32,
address: &Address,
user: &User,
database: &str,
client_server_map: ClientServerMap,
stats: Reporter,
stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>,
) -> Result<Server, Error> {
let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
@@ -108,7 +111,10 @@ impl Server {
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram = ScramSha256::new(&user.password);
let mut scram: Option<ScramSha256> = None;
if let Some(password) = &user.password.clone() {
scram = Some(ScramSha256::new(password));
}
loop {
let code = match stream.read_u8().await {
@@ -145,13 +151,40 @@ impl Server {
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
md5_password(&mut stream, &user.username, &user.password, &salt[..])
.await?;
match &user.password {
// Using plaintext password
Some(password) => {
md5_password(&mut stream, &user.username, password, &salt[..])
.await?
}
// Using auth passthrough, in this case we should already have a
// hash obtained when the pool was validated. If we reach this point
// and don't have a hash, we return an error.
None => {
let option_hash = (*auth_hash.read()).clone();
match option_hash {
Some(hash) =>
md5_password_with_hash(
&mut stream,
&hash,
&salt[..],
)
.await?,
None =>
return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database)))
}
}
}
}
AUTHENTICATION_SUCCESSFUL => (),
SASL => {
if scram.is_none() {
return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database)));
}
debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len];
@@ -167,7 +200,7 @@ impl Server {
debug!("Using {}", SCRAM_SHA_256);
// Generate client message.
let sasl_response = scram.message();
let sasl_response = scram.as_mut().unwrap().message();
// SASLInitialResponse (F)
let mut res = BytesMut::new();
@@ -204,7 +237,7 @@ impl Server {
};
let msg = BytesMut::from(&sasl_data[..]);
let sasl_response = scram.update(&msg)?;
let sasl_response = scram.as_mut().unwrap().update(&msg)?;
// SASLResponse
let mut res = BytesMut::new();
@@ -224,7 +257,11 @@ impl Server {
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
match scram.finish(&BytesMut::from(&sasl_final[..])) {
match scram
.as_mut()
.unwrap()
.finish(&BytesMut::from(&sasl_final[..]))
{
Ok(_) => {
debug!("SASL authentication successful");
}
@@ -325,7 +362,6 @@ impl Server {
write,
buffer: BytesMut::with_capacity(8196),
server_info,
server_id,
process_id,
secret_key,
in_transaction: false,
@@ -396,7 +432,7 @@ impl Server {
/// Send messages to the server from the client.
pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> {
self.mirror_send(messages);
self.stats.data_sent(messages.len(), self.server_id);
self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await {
Ok(_) => {
@@ -545,7 +581,7 @@ impl Server {
let bytes = self.buffer.clone();
// Keep track of how much data we got from the server for stats.
self.stats.data_received(bytes.len(), self.server_id);
self.stats().data_received(bytes.len());
// Clear the buffer for next query.
self.buffer.clear();
@@ -665,18 +701,17 @@ impl Server {
}
}
/// get Server stats
pub fn stats(&self) -> Arc<ServerStats> {
self.stats.clone()
}
/// Get the servers address.
#[allow(dead_code)]
pub fn address(&self) -> Address {
self.address.clone()
}
/// Get the server connection identifier
/// Used to uniquely identify connection in statistics
pub fn server_id(&self) -> i32 {
self.server_id
}
// Get server's latest response timestamp
pub fn last_activity(&self) -> SystemTime {
self.last_activity
@@ -700,6 +735,112 @@ impl Server {
None => (),
}
}
// This is so we can execute out of band queries to the server.
// The connection will be opened, the query executed and closed.
pub async fn exec_simple_query(
address: &Address,
user: &User,
query: &str,
) -> Result<Vec<String>, Error> {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
debug!("Connecting to server to obtain auth hashes.");
let mut server = Server::startup(
address,
user,
&address.database,
client_server_map,
Arc::new(ServerStats::default()),
Arc::new(RwLock::new(None)),
)
.await?;
debug!("Connected!, sending query: {}", query);
server.send(&simple_query(query)).await?;
let mut message = server.recv().await?;
Ok(parse_query_message(&mut message).await?)
}
}
async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> {
debug!("Parsing query message");
let mut pair = Vec::<String>::new();
match message::backend::Message::parse(message) {
Ok(Some(message::backend::Message::RowDescription(_description))) => {}
Ok(Some(message::backend::Message::ErrorResponse(err))) => {
return Err(Error::ProtocolSyncError(format!(
"Protocol error parsing response. Err: {:?}",
err.fields()
.iterator()
.fold(String::default(), |acc, element| acc
+ element.unwrap().value())
)))
}
Ok(_) => {
return Err(Error::ProtocolSyncError(
"Protocol error, expected Row Description.".to_string(),
))
}
Err(err) => {
return Err(Error::ProtocolSyncError(format!(
"Protocol error parsing response. Err: {:?}",
err
)))
}
}
while !message.is_empty() {
match message::backend::Message::parse(message) {
Ok(postgres_message) => {
match postgres_message {
Some(message::backend::Message::DataRow(data)) => {
let buf = data.buffer();
trace!("Data: {:?}", buf);
for item in data.ranges().iterator() {
match item.as_ref() {
Ok(range) => match range {
Some(range) => {
pair.push(String::from_utf8_lossy(&buf[range.clone()]).to_string());
}
None => return Err(Error::ProtocolSyncError(String::from(
"Data expected while receiving query auth data, found nothing.",
))),
},
Err(err) => {
return Err(Error::ProtocolSyncError(format!(
"Data error, err: {:?}",
err
)))
}
}
}
}
Some(message::backend::Message::CommandComplete(_)) => {}
Some(message::backend::Message::ReadyForQuery(_)) => {}
_ => {
return Err(Error::ProtocolSyncError(
"Unexpected message while receiving auth query data.".to_string(),
))
}
}
}
Err(err) => {
return Err(Error::ProtocolSyncError(format!(
"Parse error, err: {:?}",
err
)))
}
};
}
debug!("Got auth hash successfully");
Ok(pair)
}
impl Drop for Server {
@@ -708,7 +849,9 @@ impl Drop for Server {
/// for a write.
fn drop(&mut self) {
self.mirror_disconnect();
self.stats.server_disconnecting(self.server_id);
// Update statistics
self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4);
bytes.put_u8(b'X');

File diff suppressed because it is too large Load Diff

149
src/stats/address.rs Normal file
View File

@@ -0,0 +1,149 @@
use log::warn;
use std::sync::atomic::*;
use std::sync::Arc;
/// Internal address stats
#[derive(Debug, Clone, Default)]
pub struct AddressStats {
pub total_xact_count: Arc<AtomicU64>,
pub total_query_count: Arc<AtomicU64>,
pub total_received: Arc<AtomicU64>,
pub total_sent: Arc<AtomicU64>,
pub total_xact_time: Arc<AtomicU64>,
pub total_query_time: Arc<AtomicU64>,
pub total_wait_time: Arc<AtomicU64>,
pub total_errors: Arc<AtomicU64>,
pub avg_query_count: Arc<AtomicU64>,
pub avg_query_time: Arc<AtomicU64>,
pub avg_recv: Arc<AtomicU64>,
pub avg_sent: Arc<AtomicU64>,
pub avg_errors: Arc<AtomicU64>,
pub avg_xact_time: Arc<AtomicU64>,
pub avg_xact_count: Arc<AtomicU64>,
pub avg_wait_time: Arc<AtomicU64>,
}
impl IntoIterator for AddressStats {
type Item = (String, u64);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
vec![
(
"total_xact_count".to_string(),
self.total_xact_count.load(Ordering::Relaxed),
),
(
"total_query_count".to_string(),
self.total_query_count.load(Ordering::Relaxed),
),
(
"total_received".to_string(),
self.total_received.load(Ordering::Relaxed),
),
(
"total_sent".to_string(),
self.total_sent.load(Ordering::Relaxed),
),
(
"total_xact_time".to_string(),
self.total_xact_time.load(Ordering::Relaxed),
),
(
"total_query_time".to_string(),
self.total_query_time.load(Ordering::Relaxed),
),
(
"total_wait_time".to_string(),
self.total_wait_time.load(Ordering::Relaxed),
),
(
"total_errors".to_string(),
self.total_errors.load(Ordering::Relaxed),
),
(
"avg_xact_count".to_string(),
self.avg_xact_count.load(Ordering::Relaxed),
),
(
"avg_query_count".to_string(),
self.avg_query_count.load(Ordering::Relaxed),
),
(
"avg_recv".to_string(),
self.avg_recv.load(Ordering::Relaxed),
),
(
"avg_sent".to_string(),
self.avg_sent.load(Ordering::Relaxed),
),
(
"avg_errors".to_string(),
self.avg_errors.load(Ordering::Relaxed),
),
(
"avg_xact_time".to_string(),
self.avg_xact_time.load(Ordering::Relaxed),
),
(
"avg_query_time".to_string(),
self.avg_query_time.load(Ordering::Relaxed),
),
(
"avg_wait_time".to_string(),
self.avg_wait_time.load(Ordering::Relaxed),
),
]
.into_iter()
}
}
impl AddressStats {
pub fn error(&self) {
self.total_errors.fetch_add(1, Ordering::Relaxed);
}
pub fn update_averages(&self) {
let (totals, averages) = self.fields_iterators();
for data in totals.iter().zip(averages.iter()) {
let (total, average) = data;
if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
let total = total.load(Ordering::Relaxed);
let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
Some(avg)
}) {
warn!("Could not update averages for addresses stats, {:?}", err);
}
}
}
pub fn populate_row(&self, row: &mut Vec<String>) {
for (_key, value) in self.clone() {
row.push(value.to_string());
}
}
fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
totals.push(self.total_xact_count.clone());
averages.push(self.avg_xact_count.clone());
totals.push(self.total_query_count.clone());
averages.push(self.avg_query_count.clone());
totals.push(self.total_received.clone());
averages.push(self.avg_recv.clone());
totals.push(self.total_sent.clone());
averages.push(self.avg_sent.clone());
totals.push(self.total_xact_time.clone());
averages.push(self.avg_xact_time.clone());
totals.push(self.total_query_time.clone());
averages.push(self.avg_query_time.clone());
totals.push(self.total_wait_time.clone());
averages.push(self.avg_wait_time.clone());
totals.push(self.total_errors.clone());
averages.push(self.avg_errors.clone());
(totals, averages)
}
}

182
src/stats/client.rs Normal file
View File

@@ -0,0 +1,182 @@
use super::PoolStats;
use super::{get_reporter, Reporter};
use atomic_enum::atomic_enum;
use std::sync::atomic::*;
use std::sync::Arc;
use tokio::time::Instant;
/// The various states that a client can be in
#[atomic_enum]
#[derive(PartialEq)]
pub enum ClientState {
Idle = 0,
Waiting,
Active,
}
impl std::fmt::Display for ClientState {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
ClientState::Idle => write!(f, "idle"),
ClientState::Waiting => write!(f, "waiting"),
ClientState::Active => write!(f, "active"),
}
}
}
#[derive(Debug, Clone)]
/// Information we keep track of which can be queried by SHOW CLIENTS
pub struct ClientStats {
/// A random integer assigned to the client and used by stats to track the client
client_id: i32,
/// Data associated with the client, not writable, only set when we construct the ClientStat
application_name: String,
username: String,
pool_name: String,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
reporter: Reporter,
/// Total time spent waiting for a connection from pool, measures in microseconds
pub total_wait_time: Arc<AtomicU64>,
/// Current state of the client
pub state: Arc<AtomicClientState>,
/// Number of transactions executed by this client
pub transaction_count: Arc<AtomicU64>,
/// Number of queries executed by this client
pub query_count: Arc<AtomicU64>,
/// Number of errors made by this client
pub error_count: Arc<AtomicU64>,
}
impl Default for ClientStats {
fn default() -> Self {
ClientStats {
client_id: 0,
connect_time: Instant::now(),
application_name: String::new(),
username: String::new(),
pool_name: String::new(),
pool_stats: Arc::new(PoolStats::default()),
total_wait_time: Arc::new(AtomicU64::new(0)),
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
error_count: Arc::new(AtomicU64::new(0)),
reporter: get_reporter(),
}
}
}
impl ClientStats {
pub fn new(
client_id: i32,
application_name: &str,
username: &str,
pool_name: &str,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
) -> Self {
Self {
client_id,
pool_stats,
connect_time,
application_name: application_name.to_string(),
username: username.to_string(),
pool_name: pool_name.to_string(),
..Default::default()
}
}
/// Reports a client is disconecting from the pooler and
/// update metrics on the corresponding pool.
pub fn disconnect(&self) {
self.reporter.client_disconnecting(self.client_id);
self.pool_stats
.client_disconnect(self.state.load(Ordering::Relaxed))
}
/// Register a client with the stats system. The stats system uses client_id
/// to track and aggregate statistics from all source that relate to that client
pub fn register(&self, stats: Arc<ClientStats>) {
self.reporter.client_register(self.client_id, stats);
self.state.store(ClientState::Idle, Ordering::Relaxed);
self.pool_stats.cl_idle.fetch_add(1, Ordering::Relaxed);
}
/// Reports a client is done querying the server and is no longer assigned a server connection
pub fn idle(&self) {
self.pool_stats
.client_idle(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client is waiting for a connection
pub fn waiting(&self) {
self.pool_stats
.client_waiting(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Waiting, Ordering::Relaxed);
}
/// Reports a client is done waiting for a connection and is about to query the server.
pub fn active(&self) {
self.pool_stats
.client_active(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Active, Ordering::Relaxed);
}
/// Reports a client has failed to obtain a connection from a connection pool
pub fn checkout_error(&self) {
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client has had the server assigned to it be banned
pub fn ban_error(&self) {
self.state.store(ClientState::Idle, Ordering::Relaxed);
self.error_count.fetch_add(1, Ordering::Relaxed);
}
/// Reportes the time spent by a client waiting to get a healthy connection from the pool
pub fn checkout_time(&self, microseconds: u64) {
self.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
}
/// Report a query executed by a client against a server
pub fn query(&self) {
self.query_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
/// we report each individual queries outside a transaction as a transaction
/// We only count the initial BEGIN as a transaction, all queries within do not
/// count as transactions
pub fn transaction(&self) {
self.transaction_count.fetch_add(1, Ordering::Relaxed);
}
// Helper methods for show clients
pub fn connect_time(&self) -> Instant {
self.connect_time
}
pub fn client_id(&self) -> i32 {
self.client_id
}
pub fn application_name(&self) -> String {
self.application_name.clone()
}
pub fn username(&self) -> String {
self.username.clone()
}
pub fn pool_name(&self) -> String {
self.pool_name.clone()
}
}

281
src/stats/pool.rs Normal file
View File

@@ -0,0 +1,281 @@
use crate::config::Pool;
use crate::config::PoolMode;
use crate::pool::PoolIdentifier;
use std::sync::atomic::*;
use std::sync::Arc;
use super::get_reporter;
use super::Reporter;
use super::{ClientState, ServerState};
#[derive(Debug, Clone, Default)]
/// A struct that holds information about a Pool .
pub struct PoolStats {
// Pool identifier, cannot be changed after creating the instance
identifier: PoolIdentifier,
// Pool Config, cannot be changed after creating the instance
config: Pool,
// A reference to the global reporter.
reporter: Reporter,
/// Counters (atomics)
pub cl_idle: Arc<AtomicU64>,
pub cl_active: Arc<AtomicU64>,
pub cl_waiting: Arc<AtomicU64>,
pub cl_cancel_req: Arc<AtomicU64>,
pub sv_active: Arc<AtomicU64>,
pub sv_idle: Arc<AtomicU64>,
pub sv_used: Arc<AtomicU64>,
pub sv_tested: Arc<AtomicU64>,
pub sv_login: Arc<AtomicU64>,
pub maxwait: Arc<AtomicU64>,
}
impl IntoIterator for PoolStats {
type Item = (String, u64);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
vec![
("cl_idle".to_string(), self.cl_idle.load(Ordering::Relaxed)),
(
"cl_active".to_string(),
self.cl_active.load(Ordering::Relaxed),
),
(
"cl_waiting".to_string(),
self.cl_waiting.load(Ordering::Relaxed),
),
(
"cl_cancel_req".to_string(),
self.cl_cancel_req.load(Ordering::Relaxed),
),
(
"sv_active".to_string(),
self.sv_active.load(Ordering::Relaxed),
),
("sv_idle".to_string(), self.sv_idle.load(Ordering::Relaxed)),
("sv_used".to_string(), self.sv_used.load(Ordering::Relaxed)),
(
"sv_tested".to_string(),
self.sv_tested.load(Ordering::Relaxed),
),
(
"sv_login".to_string(),
self.sv_login.load(Ordering::Relaxed),
),
(
"maxwait".to_string(),
self.maxwait.load(Ordering::Relaxed) / 1_000_000,
),
(
"maxwait_us".to_string(),
self.maxwait.load(Ordering::Relaxed) % 1_000_000,
),
]
.into_iter()
}
}
impl PoolStats {
pub fn new(identifier: PoolIdentifier, config: Pool) -> Self {
Self {
identifier,
config,
reporter: get_reporter(),
..Default::default()
}
}
// Getters
pub fn register(&self, stats: Arc<PoolStats>) {
self.reporter.pool_register(self.identifier.clone(), stats);
}
pub fn database(&self) -> String {
self.identifier.db.clone()
}
pub fn user(&self) -> String {
self.identifier.user.clone()
}
pub fn redacted_secret(&self) -> String {
match self.identifier.secret {
Some(ref s) => format!("****{}", &s[s.len() - 4..]),
None => "<no secret>".to_string(),
}
}
pub fn pool_mode(&self) -> PoolMode {
self.config.pool_mode
}
/// Populates an array of strings with counters (used by admin in show pools)
pub fn populate_row(&self, row: &mut Vec<String>) {
for (_key, value) in self.clone() {
row.push(value.to_string());
}
}
/// Deletes the maxwait counter, this is done everytime we obtain metrics
pub fn clear_maxwait(&self) {
self.maxwait.store(0, Ordering::Relaxed);
}
/// Notified when a server of the pool enters login state.
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_login(&self, from: ServerState) {
self.sv_login.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Login {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'active'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_active(&self, from: ServerState) {
self.sv_active.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Active {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'tested'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_tested(&self, from: ServerState) {
self.sv_tested.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Tested {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'idle'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_idle(&self, from: ServerState) {
self.sv_idle.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Idle {
self.decrease_from_server_state(from);
}
}
/// Notified when a client of the pool become 'waiting'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_waiting(&self, from: ClientState) {
if from != ClientState::Waiting {
self.cl_waiting.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client of the pool become 'active'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_active(&self, from: ClientState) {
if from != ClientState::Active {
self.cl_active.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client of the pool become 'idle'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_idle(&self, from: ClientState) {
if from != ClientState::Idle {
self.cl_idle.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client disconnects.
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_disconnect(&self, from: ClientState) {
let counter = match from {
ClientState::Idle => &self.cl_idle,
ClientState::Waiting => &self.cl_waiting,
ClientState::Active => &self.cl_active,
};
Self::decrease_counter(counter.clone());
}
/// Notified when a server disconnects.
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn server_disconnect(&self, from: ServerState) {
let counter = match from {
ServerState::Active => &self.sv_active,
ServerState::Idle => &self.sv_idle,
ServerState::Login => &self.sv_login,
ServerState::Tested => &self.sv_tested,
};
Self::decrease_counter(counter.clone());
}
// helpers for counter decrease
fn decrease_from_server_state(&self, from: ServerState) {
let counter = match from {
ServerState::Tested => &self.sv_tested,
ServerState::Active => &self.sv_active,
ServerState::Idle => &self.sv_idle,
ServerState::Login => &self.sv_login,
};
Self::decrease_counter(counter.clone());
}
fn decrease_from_client_state(&self, from: ClientState) {
let counter = match from {
ClientState::Active => &self.cl_active,
ClientState::Idle => &self.cl_idle,
ClientState::Waiting => &self.cl_waiting,
};
Self::decrease_counter(counter.clone());
}
fn decrease_counter(value: Arc<AtomicU64>) {
if value.load(Ordering::Relaxed) > 0 {
value.fetch_sub(1, Ordering::Relaxed);
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_decrease() {
let stat: PoolStats = PoolStats::default();
stat.server_login(ServerState::Login);
stat.server_idle(ServerState::Login);
assert_eq!(stat.sv_login.load(Ordering::Relaxed), 0);
assert_eq!(stat.sv_idle.load(Ordering::Relaxed), 1);
}
}

226
src/stats/server.rs Normal file
View File

@@ -0,0 +1,226 @@
use super::AddressStats;
use super::PoolStats;
use super::{get_reporter, Reporter};
use crate::config::Address;
use atomic_enum::atomic_enum;
use parking_lot::RwLock;
use std::sync::atomic::*;
use std::sync::Arc;
use tokio::time::Instant;
/// The various states that a server can be in
#[atomic_enum]
#[derive(PartialEq)]
pub enum ServerState {
Login = 0,
Active,
Tested,
Idle,
}
impl std::fmt::Display for ServerState {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
ServerState::Login => write!(f, "login"),
ServerState::Active => write!(f, "active"),
ServerState::Tested => write!(f, "tested"),
ServerState::Idle => write!(f, "idle"),
}
}
}
/// Information we keep track of which can be queried by SHOW SERVERS
#[derive(Debug, Clone)]
pub struct ServerStats {
/// A random integer assigned to the server and used by stats to track the server
server_id: i32,
/// Context information, only to be read
address: Address,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
reporter: Reporter,
/// Data
pub application_name: Arc<RwLock<String>>,
pub state: Arc<AtomicServerState>,
pub bytes_sent: Arc<AtomicU64>,
pub bytes_received: Arc<AtomicU64>,
pub transaction_count: Arc<AtomicU64>,
pub query_count: Arc<AtomicU64>,
pub error_count: Arc<AtomicU64>,
}
impl Default for ServerStats {
fn default() -> Self {
ServerStats {
server_id: 0,
application_name: Arc::new(RwLock::new(String::new())),
address: Address::default(),
pool_stats: Arc::new(PoolStats::default()),
connect_time: Instant::now(),
state: Arc::new(AtomicServerState::new(ServerState::Login)),
bytes_sent: Arc::new(AtomicU64::new(0)),
bytes_received: Arc::new(AtomicU64::new(0)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
error_count: Arc::new(AtomicU64::new(0)),
reporter: get_reporter(),
}
}
}
impl ServerStats {
pub fn new(address: Address, pool_stats: Arc<PoolStats>, connect_time: Instant) -> Self {
Self {
address,
pool_stats,
connect_time,
server_id: rand::random::<i32>(),
..Default::default()
}
}
pub fn server_id(&self) -> i32 {
self.server_id
}
/// Register a server connection with the stats system. The stats system uses server_id
/// to track and aggregate statistics from all source that relate to that server
// Delegates to reporter
pub fn register(&self, stats: Arc<ServerStats>) {
self.reporter.server_register(self.server_id, stats);
self.login();
}
/// Reports a server connection is no longer assigned to a client
/// and is available for the next client to pick it up
pub fn idle(&self) {
self.pool_stats
.server_idle(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Idle, Ordering::Relaxed);
self.set_undefined_application();
}
/// Reports a server connection is disconecting from the pooler.
/// Also updates metrics on the pool regarding server usage.
pub fn disconnect(&self) {
self.reporter.server_disconnecting(self.server_id);
self.pool_stats
.server_disconnect(self.state.load(Ordering::Relaxed))
}
/// Reports a server connection is being tested before being given to a client.
pub fn tested(&self) {
self.set_undefined_application();
self.pool_stats
.server_tested(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Tested, Ordering::Relaxed);
}
/// Reports a server connection is attempting to login.
pub fn login(&self) {
self.pool_stats
.server_login(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Login, Ordering::Relaxed);
self.set_undefined_application();
}
/// Reports a server connection has been assigned to a client that
/// is about to query the server
pub fn active(&self, application_name: String) {
self.pool_stats
.server_active(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Active, Ordering::Relaxed);
self.set_application(application_name);
}
pub fn address_stats(&self) -> Arc<AddressStats> {
self.address.stats.clone()
}
// Helper methods for show_servers
pub fn pool_name(&self) -> String {
self.pool_stats.database()
}
pub fn username(&self) -> String {
self.pool_stats.user()
}
pub fn address_name(&self) -> String {
self.address.name()
}
pub fn connect_time(&self) -> Instant {
self.connect_time
}
fn set_application(&self, name: String) {
let mut application_name = self.application_name.write();
*application_name = name;
}
fn set_undefined_application(&self) {
self.set_application(String::from("Undefined"))
}
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
// Update server stats and address aggergation stats
self.set_application(application_name);
self.address
.stats
.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.pool_stats
.maxwait
.fetch_max(microseconds, Ordering::Relaxed);
}
/// Report a query executed by a client against a server
pub fn query(&self, milliseconds: u64, application_name: &str) {
self.set_application(application_name.to_string());
let address_stats = self.address_stats();
address_stats
.total_query_count
.fetch_add(1, Ordering::Relaxed);
address_stats
.total_query_time
.fetch_add(milliseconds, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
/// we report each individual queries outside a transaction as a transaction
/// We only count the initial BEGIN as a transaction, all queries within do not
/// count as transactions
pub fn transaction(&self, application_name: &str) {
self.set_application(application_name.to_string());
self.transaction_count.fetch_add(1, Ordering::Relaxed);
self.address
.stats
.total_xact_count
.fetch_add(1, Ordering::Relaxed);
}
/// Report data sent to a server
pub fn data_sent(&self, amount_bytes: usize) {
self.bytes_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address
.stats
.total_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
}
/// Report data received from a server
pub fn data_received(&self, amount_bytes: usize) {
self.bytes_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address
.stats
.total_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
}
}

View File

@@ -36,6 +36,15 @@ services:
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
pg5:
image: postgres:14
network_mode: "service:main"
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
command: ["postgres", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-p", "10432"]
main:
build: .
command: ["bash", "/app/tests/docker/run.sh"]

View File

@@ -176,6 +176,47 @@ describe "Admin" do
end
end
context "clients connects and disconnect normally" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it 'shows the same number of clients before and after' do
clients_before = clients_connected_to_pool(processes: processes)
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_between = clients_connected_to_pool(processes: processes)
expect(clients_before).not_to eq(clients_between)
connections.each(&:close)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients connects and disconnect abruptly" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
it 'shows the same number of clients before and after' do
threads = []
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_before = clients_connected_to_pool(processes: processes)
random_string = (0...8).map { (65 + rand(26)).chr }.join
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
sleep(1)
# psql starts two processes, we only know the pid of the parent, this
# ensure both are killed
`pkill -9 -f '#{random_string}'`
Process.wait(faulty_client)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients overwhelm server pools" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
@@ -199,7 +240,7 @@ describe "Admin" do
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("4")

View File

@@ -0,0 +1,215 @@
# frozen_string_literal: true
require_relative 'spec_helper'
require_relative 'helpers/auth_query_helper'
describe "Auth Query" do
let(:configured_instances) {[5432, 10432]}
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
let(:pg_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
let(:processes) { Helpers::AuthQuery.single_shard_auth_query(pool_name: "sharded_db", pg_user: pg_user, config_user: config_user, extra_conf: config, wait_until_ready: wait_until_ready ) }
let(:config) { {} }
let(:wait_until_ready) { true }
after do
unless @failing_process
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
@failing_process = false
end
context "when auth_query is not configured" do
context 'and cleartext passwords are set' do
it "uses local passwords" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", config_user['username'], config_user['password']))
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and cleartext passwords are not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it "does not start because it is not possible to authenticate" do
@failing_process = true
expect { processes.pgcat }.to raise_error(StandardError, /You have to specify a user password for every pool if auth_query is not specified/)
end
end
end
context 'when auth_query is configured' do
context 'with global configuration' do
around(:example) do |example|
# Set up auth query
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
example.run
# Drop auth query support
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
end
context 'with correct global parameters' do
let(:config) { { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } }
context 'and with cleartext passwords set' do
it 'it uses local passwords' do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
it 'allows passwords to be changed without closing existing connections' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
end
it 'allows passwords to be changed and that new password is needed when reconnecting' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2'))
expect(newconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
end
end
end
context 'with wrong parameters' do
let(:config) { { 'general' => { 'auth_query' => 'SELECT 1', 'auth_query_user' => 'wrong_user', 'auth_query_password' => 'wrong' } } }
context 'and with clear text passwords set' do
it "it uses local passwords" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it "it fails to start as it cannot authenticate against servers" do
@failing_process = true
expect { PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) }.to raise_error(StandardError, /Error trying to obtain password from auth_query/ )
end
context 'and we fix the issue and reload' do
let(:wait_until_ready) { false }
it 'fails in the beginning but starts working after reloading config' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
while !(processes.pgcat.logs =~ /Waiting for clients/) do
sleep 0.5
end
expect { PG.connect(connection_string)}.to raise_error(PG::ConnectionBad)
expect(processes.pgcat.logs).to match(/Error trying to obtain password from auth_query/)
current_config = processes.pgcat.current_config
config = { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } }
processes.pgcat.update_config(current_config.deep_merge(config))
processes.pgcat.reload_config
conn = nil
expect { conn = PG.connect(connection_string)}.not_to raise_error
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
end
end
end
context 'with per pool configuration' do
around(:example) do |example|
# Set up auth query
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user1',
password: 'secret',
database: 'shard1'
);
example.run
# Tear down auth query
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user1',
password: 'secret',
database: 'shard1'
);
end
context 'with correct parameters' do
let(:processes) { Helpers::AuthQuery.two_pools_auth_query(pool_names: ["sharded_db0", "sharded_db1"], pg_user: pg_user, config_user: config_user, extra_conf: config ) }
let(:config) {
{ 'pools' =>
{
'sharded_db0' => {
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
'auth_query_user' => 'md5_auth_user',
'auth_query_password' => 'secret'
},
'sharded_db1' => {
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
'auth_query_user' => 'md5_auth_user1',
'auth_query_password' => 'secret'
},
}
}
}
context 'and with cleartext passwords set' do
it 'it uses local passwords' do
conn = PG.connect(processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
conn = PG.connect(processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
connection_string = processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
end
end
end
end

39
tests/ruby/auth_spec.rb Normal file
View File

@@ -0,0 +1,39 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Authentication" do
describe "multiple secrets configured" do
let(:secrets) { ["one_secret", "two_secret"] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info", secrets=["one_secret", "two_secret"]) }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using all secrets and postgres password" do
secrets.push("sharding_user").each do |secret|
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password=secret))
conn.exec("SELECT current_user")
end
end
end
describe "no secrets configured" do
let(:secrets) { [] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info") }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using only the password" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SELECT current_user")
expect { PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password="secret_one")) }.to raise_error PG::ConnectionBad
end
end
end

BIN
tests/ruby/capture Normal file

Binary file not shown.

View File

@@ -0,0 +1,173 @@
module Helpers
module AuthQuery
def self.single_shard_auth_query(
pg_user:,
config_user:,
pool_name:,
extra_conf: {},
log_level: 'debug',
wait_until_ready: true
)
user = {
"pool_size" => 10,
"statement_timeout" => 0,
}
pgcat = PgcatProcess.new(log_level)
pgcat_cfg = pgcat.current_config.deep_merge(extra_conf)
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => "transaction",
"load_balancing_mode" => "random",
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica.port.to_s, "replica"],
]
},
},
"users" => { "0" => user.merge(config_user) }
}
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg)
pgcat.start
pgcat.wait_until_ready(
pgcat.connection_string(
"sharded_db",
pg_user['username'],
pg_user['password']
)
) if wait_until_ready
OpenStruct.new.tap do |struct|
struct.pgcat = pgcat
struct.primary = primary
struct.replicas = [replica]
struct.all_databases = [primary]
end
end
def self.two_pools_auth_query(
pg_user:,
config_user:,
pool_names:,
extra_conf: {},
log_level: 'debug'
)
user = {
"pool_size" => 10,
"statement_timeout" => 0,
}
pgcat = PgcatProcess.new(log_level)
pgcat_cfg = pgcat.current_config
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
pool_template = Proc.new do |database|
{
"default_role" => "any",
"pool_mode" => "transaction",
"load_balancing_mode" => "random",
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => database,
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica.port.to_s, "replica"],
]
},
},
"users" => { "0" => user.merge(config_user) }
}
end
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_names[0]}" => pool_template.call("shard0"),
"#{pool_names[1]}" => pool_template.call("shard1")
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
pgcat.start
pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
OpenStruct.new.tap do |struct|
struct.pgcat = pgcat
struct.primary = primary
struct.replicas = [replica]
struct.all_databases = [primary]
end
end
def self.create_query_auth_function(user)
return <<-SQL
CREATE OR REPLACE FUNCTION public.user_lookup(in i_username text, out uname text, out phash text)
RETURNS record AS $$
BEGIN
SELECT usename, passwd FROM pg_catalog.pg_shadow
WHERE usename = i_username INTO uname, phash;
RETURN;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;
GRANT EXECUTE ON FUNCTION public.user_lookup(text) TO #{user};
SQL
end
def self.exec_in_instances(query:, instance_ports: [ 5432, 10432 ], database: 'postgres', user: 'postgres', password: 'postgres')
instance_ports.each do |port|
c = PG.connect("postgres://#{user}:#{password}@localhost:#{port}/#{database}")
c.exec(query)
c.close
end
end
def self.set_up_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
instance_ports.each do |port|
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
connection.exec("DROP ROLE #{user}") rescue PG::UndefinedObject
connection.exec("CREATE ROLE #{user} ENCRYPTED PASSWORD '#{password}' LOGIN;")
connection.exec(self.create_query_auth_function(user))
connection.close
end
end
def self.tear_down_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
instance_ports.each do |port|
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
connection.exec("DROP ROLE #{user}")
connection.close
end
end
def self.drop_query_auth_function(user)
return <<-SQL
REVOKE ALL ON FUNCTION public.user_lookup(text) FROM public, #{user};
DROP FUNCTION public.user_lookup(in i_username text, out uname text, out phash text);
SQL
end
end
end

View File

@@ -3,16 +3,27 @@ require 'ostruct'
require_relative 'pgcat_process'
require_relative 'pg_instance'
class ::Hash
def deep_merge(second)
merger = proc { |key, v1, v2| Hash === v1 && Hash === v2 ? v1.merge(v2, &merger) : v2 }
self.merge(second, &merger)
end
end
module Helpers
module Pgcat
def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", secrets=nil)
user = {
"password" => "sharding_user",
"pool_size" => pool_size,
"statement_timeout" => 0,
"username" => "sharding_user"
"username" => "sharding_user",
}
if !secrets.nil?
user["secrets"] = secrets
end
pgcat = PgcatProcess.new(log_level)
primary0 = PgInstance.new(5432, user["username"], user["password"], "shard0")
primary1 = PgInstance.new(7432, user["username"], user["password"], "shard1")
@@ -20,7 +31,7 @@ module Helpers
pgcat_cfg = pgcat.current_config
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
@@ -34,8 +45,14 @@ module Helpers
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
},
"users" => { "0" => user }
}
},
}
if !secrets.nil?
pgcat_cfg["general"]["tls_certificate"] = "../../.circleci/server.cert"
pgcat_cfg["general"]["tls_private_key"] = "../../.circleci/server.key"
end
pgcat.update_config(pgcat_cfg)
pgcat.start

View File

@@ -67,17 +67,20 @@ class PgcatProcess
def start
raise StandardError, "Process is already started" unless @pid.nil?
@pid = Process.spawn(@env, @command, err: @log_filename, out: @log_filename)
Process.detach(@pid)
ObjectSpace.define_finalizer(@log_filename, proc { PgcatProcess.finalize(@pid, @log_filename, @config_filename) })
return self
end
def wait_until_ready
def wait_until_ready(connection_string = nil)
exc = nil
10.times do
PG::connect(example_connection_string).close
Process.kill 0, @pid
PG::connect(connection_string || example_connection_string).close
return self
rescue Errno::ESRCH
raise StandardError, "Process #{@pid} died. #{logs}"
rescue => e
exc = e
sleep(0.5)
@@ -108,11 +111,11 @@ class PgcatProcess
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
end
def connection_string(pool_name, username)
def connection_string(pool_name, username, password=nil)
cfg = current_config
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
password = user_obj["password"]
password = if password.nil? then user_obj["password"] else password end
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}"
end

View File

@@ -309,4 +309,58 @@ describe "Miscellaneous" do
end
end
end
describe "Idle client timeout" do
context "idle transaction timeout set to 0" do
before do
current_configs = processes.pgcat.current_config
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
puts(current_configs["general"]["idle_client_in_transaction_timeout"])
current_configs["general"]["idle_client_in_transaction_timeout"] = 0
processes.pgcat.update_config(current_configs) # with timeout 0
processes.pgcat.reload_config
end
it "Allow client to be idle in transaction" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.async_exec("SELECT 1")
sleep(2)
conn.async_exec("COMMIT")
conn.close
end
end
context "idle transaction timeout set to 500ms" do
before do
current_configs = processes.pgcat.current_config
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
current_configs["general"]["idle_client_in_transaction_timeout"] = 500
processes.pgcat.update_config(current_configs) # with timeout 500
processes.pgcat.reload_config
end
it "Allow client to be idle in transaction below timeout" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.async_exec("SELECT 1")
sleep(0.4) # below 500ms
conn.async_exec("COMMIT")
conn.close
end
it "Error when client idle in transaction time exceeds timeout" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.async_exec("SELECT 1")
sleep(1) # above 500ms
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
conn.async_exec("SELECT 1") # should be able to send another query
conn.close
end
end
end
end

View File

@@ -19,3 +19,10 @@ ensure
STDOUT.reopen(sout)
STDERR.reopen(serr)
end
def clients_connected_to_pool(pool_index: 0, processes:)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[pool_index]
admin_conn.close
results['cl_idle'].to_i + results['cl_active'].to_i + results['cl_waiting'].to_i
end