Compare commits

...

111 Commits

Author SHA1 Message Date
Lev
fa17bb5cc6 TLS misconfiguration demoted to warning 2023-09-26 10:14:42 -07:00
Kevin Elliott
04e9814770 Fix incorrect data output for plugin query_logger (#601)
Update query_logger.rs

Pool and user were incorrectly swapped and needed to be fixed.
2023-09-25 18:45:51 -07:00
Lev Kokotov
037d232fcd Mark admin clients as disconnected on error (#597) 2023-09-21 15:55:22 -07:00
Lev Kokotov
b2933762e7 Report maxwait for clients that end up not getting a connection (#596) 2023-09-21 14:50:18 -07:00
Mohammad Dashti
df8aa888f9 Add a cache layer to Docker for development (#594)
* Add a cache layer to Docker.

* Created a separate `dev` Docker file.

* Fixed `Docker.dev` to build in non-release mode.
2023-09-20 10:29:30 -07:00
Mohammad Dashti
7f5639c94a Include thread_id in the logs (#592)
Include `thread_id` in the logs.
2023-09-20 09:11:16 -07:00
Lev Kokotov
c0112f6f12 Revert "User-friendly error messages" (#587)
Revert "User-friendly error messages (#586)"

This reverts commit b7ceee2ddf.
2023-09-11 16:39:31 -07:00
Lev Kokotov
b7ceee2ddf User-friendly error messages (#586) 2023-09-11 16:39:11 -07:00
Mostafa Abdelraouf
0b01d70b55 Allow configuring routing decision when no shard is selected (#578)
The TL;DR for the change is that we allow QueryRouter to set the active shard to None. This signals to the Pool::get method that we have no shard selected. The get method follows a no_shard_specified_behavior config to know how to route the query.

Original PR description
Ruby-pg library makes a startup query to SET client_encoding to ... if Encoding.default_internal value is set (Code). This query is troublesome because we cannot possibly attach a routing comment to it. PgCat, by default, will route that query to the default shard.

Everything is fine until shard 0 has issues, Clients will all be attempting to send this query to shard0 which increases the connection latency significantly for all clients, even those not interested in shard0

This PR introduces no_shard_specified_behavior that defines the behavior in case we have routing-by-comment enabled but we get a query without a comment. The allowed behaviors are

random: Picks a shard at random
random_healthy: Picks a shard at random favoring shards with the least number of recent connection/checkout errors
shard_<number>: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
In order to achieve this, this PR introduces an error_count on the Address Object that tracks the number of errors since the last checkout and uses that metric to sort shards by error count before making a routing decision.
I didn't want to use address stats to avoid introducing a routing dependency on internal stats (We might do that in the future but I prefer to avoid this for the time being.

I also made changes to the test environment to replace Ruby's TOML reader library, It appears to be abandoned and does not support mixed arrays (which we use in the config toml), and it also does not play nicely with single-quoted regular expressions. I opted for using yj which is a CLI tool that can convert from toml to JSON and back. So I refactor the tests to use that library.
2023-09-11 13:47:28 -05:00
hellower
33db0dffa8 stream.peer_addr() & auth_query (#575)
* Don't unwrap stream.peer_addr()

https://github.com/postgresml/pgcat/pull/562 (same code)
(another lines changed)

* auth_query (real sample)

# single quote need
auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
2023-08-31 14:11:38 -07:00
hi019
7994a661d9 Fix Docker image runs erroring due to glibc incompatability (#572)
Fix Docker image builds breaking due to glibc incompatability
2023-08-30 16:51:31 -07:00
Tommy Li
9937193332 Allow pause/resuming all pools (#566)
support pausing all pools
2023-08-29 10:07:36 -07:00
Mostafa Abdelraouf
baa00ff546 Add yj to CI image (#568) 2023-08-28 21:20:53 -05:00
Zain Kabani
ffe820497f Don't unwrap stream.peer_addr() (#562) 2023-08-25 10:33:39 -07:00
Zain Kabani
be549f3faa Fixes try_execute_command message parsing bug (#560)
* Fixes try_execute_command message parsing bug

* Fix initial segment logic

* Add test
2023-08-24 11:25:43 -07:00
dependabot[bot]
4301ab0606 chore(deps): bump rustls-webpki from 0.100.1 to 0.100.2 (#555)
Bumps [rustls-webpki](https://github.com/rustls/webpki) from 0.100.1 to 0.100.2.
- [Release notes](https://github.com/rustls/webpki/releases)
- [Commits](https://github.com/rustls/webpki/compare/v/0.100.1...v/0.100.2)

---
updated-dependencies:
- dependency-name: rustls-webpki
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-22 11:41:09 -07:00
Cluas
5143500c9a docs: complete the missing general items (#553)
docs: complete the missing general items.
2023-08-20 19:14:19 -07:00
Zain Kabani
3255323bff Adds option to log which parameter status is changed by the client (#550) 2023-08-16 11:01:21 -07:00
Zain Kabani
bb27586758 Reset instead of discard all (#549)
* Use reset all instead of discard all

* Move 'X' handling to before admin handle

* fix tests
2023-08-16 10:08:48 -07:00
Lev Kokotov
4f0f45b576 Add pgcat user (#546)
* Add pgcat user

* warn

* dev
2023-08-10 12:25:43 -07:00
Zain Kabani
f94ce97ebc Handle and track startup parameters (#478)
* User server parameters struct instead of server info bytesmut

* Refactor to use hashmap for all params and add server parameters to client

* Sync parameters on client server checkout

* minor refactor

* update client side parameters when changed

* Move the SET statement logic from the C packet to the S packet.

* trigger build

* revert validation changes

* remove comment

* Try fix

* Reset cleanup state after sync

* fix server version test

* Track application name through client life for stats

* Add tests

* minor refactoring

* fmt

* fix

* fmt
2023-08-10 08:18:46 -07:00
Sebastian Webber
9ab128579d parse server error messages (#543)
This commit adds a parser to the Postgres error message, providing better
error messages.

Implemented based in:
  https://www.postgresql.org/docs/12/protocol-error-fields.html

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-08-09 09:14:05 -07:00
Lev Kokotov
1cde74f05e Revert "Preserve existing behavior" (#542)
Revert "Preserve existing behavior (#541)"

This reverts commit a4de6c1eb6.
2023-08-08 17:45:48 -07:00
Lev Kokotov
a4de6c1eb6 Preserve existing behavior (#541) 2023-08-08 13:48:52 -07:00
Zain Kabani
e14b283f0c Make infer role configurable and fix double parse bug (#533)
* Make infer role configurable and fix double parse bug

* Fix tests

* Enable infer_role_from query in toml for tests

* Fix test

* Add max length config, add logging for which application is failing to parse, and change config name

* fmt

* Update src/config.rs

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
2023-08-08 13:10:03 -07:00
Lev Kokotov
7c3c90c38e Add systemd service (#540) 2023-08-08 11:51:38 -07:00
Lev Kokotov
2ca21b2bec pgcat deb package (#539) 2023-08-08 11:08:46 -07:00
Matthias Pfeil
3986eaa4b2 Add github tag as tag to image (#537) 2023-08-04 10:20:56 -07:00
Lev Kokotov
1f2c6507f7 debug -> release 2023-08-01 17:47:34 -07:00
Lev Kokotov
aefcf4281c Fix for #534 and #535 2023-08-01 17:46:34 -07:00
Bertrand Paquet
9d1c46a3e9 Fix typo in the config documentation (#532) 2023-07-28 00:31:53 -07:00
Spindel Ljungmark
328108aeb5 Restore the ability to filter spammy log messages (#530)
* Move connection checkin log messages to their own target

Under heavy load they can happen thousands of times per second, and
should generally be considered a nuisance at best. This marks the state
discard as an info rather than a warning, and moves all the messages
into their own log-target, so they can be filtered separately from the
more relevant warnings.

Signed-off-by: D.S. Ljungmark <spider@skuggor.se>

* Remove left-over env_logger dependencies

When moving to tracing-subscriber for logging, the env_logger
dependencies were left around, this cuts them out as dead code.

Signed-off-by: D.S. Ljungmark <spider@skuggor.se>

* Restore ability to filter log messages at runtime

This restores the RUST_LOG filters from env_logger but now with the
tracing subscriber setup. The filters are chained so commandline options
mark the default in case either option is set, which should be the path
of least confusion for users.  ( RUST_LOG setting level to debug, and
commandline to warning is an odd user case, and I don't know what a user
who does that is expecting. )

It also bumps the version number as a fix to see which versions have
which behaviour.

Signed-off-by: D.S. Ljungmark <spider@skuggor.se>

---------

Signed-off-by: D.S. Ljungmark <spider@skuggor.se>
2023-07-27 08:51:23 -07:00
Lev Kokotov
4cf54a6122 Release 1.1 (#526) 2023-07-25 10:27:04 -07:00
Mostafa Abdelraouf
2a8f3653a6 Fix COPY FROM and add tests (#522)
* Fix COPY FROM and add tests

* E

* fmt
2023-07-20 23:06:01 -07:00
Sebastian Webber
19cb8a3022 add --no-color option to disable colors in the terminal (#518)
add --no-color option to disable colors

this commit adds a new option to disable colors in the terminal and also
moves the logger configuration to a different crate.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-19 21:15:55 -07:00
Sebastian Webber
f85e5bd9e8 add support for multiple log formats (#517)
this commit adds the tracing-subscriber crate and use its formatters to
support multiple log formats.

More details in
https://github.com/postgresml/pgcat/issues/464#issuecomment-1641430299

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 23:07:13 -07:00
Sebastian Webber
7bdb4e5cd9 Add cmd line parser (#512)
This commit adds the clap library and configures the necessary args to
parse from the command line,  expanding the current option of a single
file and adding support for environment variables.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 13:52:40 -07:00
Sebastian Webber
5d87e3781e push and build only in main and tags (#508)
this commit changes the CI behavior to only build and push when something is committed to main or is a new tag.
2023-07-14 10:30:49 -07:00
dependabot[bot]
3e08c6bd8d chore(deps): bump num_cpus from 1.15.0 to 1.16.0 (#507)
Bumps [num_cpus](https://github.com/seanmonstar/num_cpus) from 1.15.0 to 1.16.0.
- [Release notes](https://github.com/seanmonstar/num_cpus/releases)
- [Changelog](https://github.com/seanmonstar/num_cpus/blob/master/CHANGELOG.md)
- [Commits](https://github.com/seanmonstar/num_cpus/compare/v1.15.0...v1.16.0)

---
updated-dependencies:
- dependency-name: num_cpus
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-14 07:58:11 -07:00
Sebastian Webber
15b6db8e4e add "show help" command (#505)
This commit adds a new function to handle notify and use it
in the SHOW HELP command, which displays the available options
in the admin console.

Also, adding Fabrízio as a co-author for all the help with the
protocol and the help to structure this PR.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
Co-authored-by: Fabrízio de Royes Mello <fabriziomello@gmail.com>
2023-07-13 22:40:04 -07:00
dependabot[bot]
b2e6dfd9bb chore(deps): bump rustls-pemfile from 1.0.2 to 1.0.3 (#504)
Bumps [rustls-pemfile](https://github.com/rustls/pemfile) from 1.0.2 to 1.0.3.
- [Commits](https://github.com/rustls/pemfile/commits)

---
updated-dependencies:
- dependency-name: rustls-pemfile
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 21:41:48 -07:00
Mostafa Abdelraouf
3c9565d351 Add support for tcp_user_timeout (#503)
* Add support for tcp_user_timeout

* option

* duration

* Some()

* docs

* fmt, compile
2023-07-12 11:24:30 -07:00
dependabot[bot]
67579c9af4 chore(deps): bump rustls from 0.21.1 to 0.21.5 (#501)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.1 to 0.21.5.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.1...v/0.21.5)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 05:46:31 -07:00
Cluas
cf7f6f35ab docs: fix general.autoreload description (#491)
* docs: fix autoreload description

Signed-off-by: Cluas <Cluas@live.cn>

* docs: add blank line

Signed-off-by: Cluas <Cluas@live.cn>

---------

Signed-off-by: Cluas <Cluas@live.cn>
2023-07-12 05:42:44 -07:00
Voldemarich
7205537b49 [BUG] Fix binding of NULL value parameters in prepared statements (#496)
Fix binding of NULL value parameters in prepared statements

Co-authored-by: anon <anon@non.existent>
2023-07-10 10:35:43 +02:00
Zain Kabani
1ed6e925ed Fixes the default for round robing in General (#488) 2023-06-23 09:15:44 -07:00
Lev Kokotov
4b78af9676 Implement Close for prepared statements (#482)
* Partial support for Close

* Close

* respect config value

* prepared spec

* Hmm

* Print cache size
2023-06-18 23:02:34 -07:00
Lev Kokotov
73500c0c96 Fix build (#481) 2023-06-17 09:09:54 -07:00
Lev Kokotov
b167de5aa3 fmt (#480) 2023-06-17 08:57:33 -07:00
Juraj Bubniak
473bb3d17d Log not implemented messages as debug in prometheus metrics. (#477) 2023-06-16 18:48:38 -07:00
Lev Kokotov
c7d6273037 Support for prepared statements (#474)
* Start prepared statements

* parse

* Ok

* optional

* dont rewrite anonymous prepared stmts

* Dont rewrite anonymous prep statements

* hm?

* prep statements

* I see!

* comment

* Print config value

* Rewrite bind and add sqlx test

* fmt

* ok

* Fix

* Fix stats

* its late

* clean up PREPARE
2023-06-16 12:57:44 -07:00
Jeff Chen
94c781881f Report min_pool_size correctly (#471) 2023-06-12 09:23:56 -07:00
dependabot[bot]
a8c81e5df6 chore(deps): bump pin-project from 1.0.12 to 1.1.0 (#440)
Bumps [pin-project](https://github.com/taiki-e/pin-project) from 1.0.12 to 1.1.0.
- [Release notes](https://github.com/taiki-e/pin-project/releases)
- [Changelog](https://github.com/taiki-e/pin-project/blob/main/CHANGELOG.md)
- [Commits](https://github.com/taiki-e/pin-project/compare/v1.0.12...v1.1.0)

---
updated-dependencies:
- dependency-name: pin-project
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:51:24 -07:00
dependabot[bot]
1d3746ec9e chore(deps): bump sqlparser from 0.33.0 to 0.34.0 (#448)
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.33.0 to 0.34.0.
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.33.0...v0.34.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:51:05 -07:00
dependabot[bot]
b5489dc1e6 chore(deps): bump regex from 1.8.1 to 1.8.4 (#466)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.1 to 1.8.4.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.8.1...1.8.4)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:50:46 -07:00
dependabot[bot]
557b425fb1 chore(deps): bump log from 0.4.17 to 0.4.19 (#470)
Bumps [log](https://github.com/rust-lang/log) from 0.4.17 to 0.4.19.
- [Release notes](https://github.com/rust-lang/log/releases)
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/log/compare/0.4.17...0.4.19)

---
updated-dependencies:
- dependency-name: log
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:50:26 -07:00
Zain Kabani
aca9738821 Make queue strategy configurable and default to Fifo (#463)
* Change idle timeout default to 10 minutes

* Revert lifo for now while we investigate connection thrashing issues

* Make queue strategy configurable

* test revert idle time out

* Add pgcat start to python test
2023-06-09 11:35:20 -07:00
Zain Kabani
0bc453a771 Change default server lifetime and bump bb8 version to use LIFO correctly (#462)
Change default server lifetime and idle timeouts and bump bb8 version to use LIFO correctly
2023-05-31 08:25:42 -07:00
Zain Kabani
b67c33b6d0 Use latest bb8 and use Lifo as the queue strategy in the pool (#455)
* Use git bb8

* Use latest bb8 and change pool is use stack
2023-05-28 19:46:13 -07:00
Mostafa Abdelraouf
a8a30ad43b Refactor Pool Stats to be based off of Server/Client stats (#445)
What is wrong
Stats reported by SHOW POOLS seem to be leaking. We see lingering cl_idle , cl_waiting, and similarly for sv_idle , sv_active. We confirmed that these are reporting issues not actual lingering clients.

This behavior is readily reproducible by running

while true; do
    psql "postgres://sharding_user:sharding_user@localhost:6432/sharded_db" -c "SELECT 1" > /dev/null 2>&1  &
done

Why it happens
I wasn't able to get to figure our the reason for the bug but my best guess is that we have race conditions when updating pool-level stats. So even though individual update operations are atomic, we perform a check then update sequence which is not protected by a guard.
https://github.com/postgresml/pgcat/blob/main/src/stats/pool.rs#L174-L179

I am also suspecting that using Relaxed ordering might allow this behavior (I changed all operations to use Ordering::SeqCst but still got lingering clients)

How to fix
Since SHOW POOLS/SHOW SERVER/SHOW CLIENTS all show the current state of the proxy (as opposed to SHOW STATS which show aggregate values), this PR refactors SHOW POOLS to have it construct the results directly from SHOW SERVER and SHOW CLIENT datasets. This reduces the complexity of stat updates and eliminates the need for having locks when updating pool stats as we only care about updating individual client/server states.

This will change the semantics of maxwait, so instead of it holding the maxwait time ever encountered by a client (connected or disconnected), it will only consider connected clients which should be okay given PgCat tends to hold on to client connections more than Pgbouncer.
2023-05-23 08:44:49 -05:00
dependabot[bot]
d63be9b93a chore(deps): bump toml from 0.7.3 to 0.7.4 (#447)
Bumps [toml](https://github.com/toml-rs/toml) from 0.7.3 to 0.7.4.
- [Commits](https://github.com/toml-rs/toml/compare/toml-v0.7.3...toml-v0.7.4)

---
updated-dependencies:
- dependency-name: toml
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-19 08:58:13 -07:00
Lev Kokotov
100778670c Ensure data makes it to the client (#446)
* Ensure data makes it to the client

* flush all buffers
2023-05-18 16:41:22 -07:00
Lev Kokotov
37e3349c24 Optionally clean up server connections (#444)
* Optionally clean up server connections

* move setting to pool

* fix test

* Print setting to screen

* fmt

* Fix pool_settings override in tests
2023-05-18 10:46:55 -07:00
Zain Kabani
7f57a89d75 Fix time based average stats (#442)
* keep track of current stats and zero them after updating averages

* Try tests

* typo

* remove commented test stuff

* Avoid dividing by zero

* Fix test

* refactor, get rid of iterator. do it manually

* trigger build

* Fix
2023-05-17 21:38:10 -07:00
Lev Kokotov
0898461c01 Allow to deploy pools without checking (#438) 2023-05-12 12:48:37 -07:00
Lev Kokotov
52b1b43850 Prewarmer (#435)
* Prewarmer

* hmm

* Tests

* default

* fix test

* Correct configuration

* Added minimal config example

* remove connect_timeout
2023-05-12 09:50:52 -07:00
Zain Kabani
0907f1b77f Improve logging for connection cleanup (#428)
* initial commit

* fix

* fmt
2023-05-11 17:40:10 -07:00
Zain Kabani
73260690b0 Fixes average stats bug (#436)
* Add test

* Fix test

* Add fix
2023-05-11 17:37:58 -07:00
Mostafa Abdelraouf
5056cbe8ed Fix docker-compose dev stack for Apple silicon (#432)
The docker-compose dev setup is broken under Apple silicon, starting the stack fails with the following error. Switching to a different docker image fixes the issue.
2023-05-10 10:24:35 -05:00
Lev Kokotov
571b02e178 Calculate averages correctly and preserve totals like before (#429)
* Reset totals after avg calculation

* like it used to be
2023-05-08 10:06:16 -07:00
Andrew Tanner
159eb89bf0 First try with role reset (#427)
* First try with role rest

* update

* extra line

* Update src/server.rs

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>

* Update tests/ruby/misc_spec.rb

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
2023-05-05 15:31:27 -07:00
Lev Kokotov
389993bf3e Accurate log messages (#425) 2023-05-05 08:27:19 -07:00
Lev Kokotov
ba5243b6dd Optionally validate config on boot (#423) 2023-05-03 17:07:23 -07:00
Lev Kokotov
128ef72911 lowercase config query (#422)
* lowercase config query

* remove debug
2023-05-03 16:47:20 -07:00
Lev Kokotov
811885f464 Actually plugins (#421)
* more plugins

* clean up

* fix tests

* fix flakey test
2023-05-03 16:13:45 -07:00
dependabot[bot]
d5e329fec5 chore(deps): bump regex from 1.8.0 to 1.8.1 (#413)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.0 to 1.8.1.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits/1.8.1)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-03 10:00:05 -07:00
Lev Kokotov
09e54e1175 Plugins! (#420)
* Some queries

* Plugins!!

* cleanup

* actual names

* the actual plugins

* comment

* fix tests

* Tests

* unused errors

* Increase reaper rate to actually enforce settings

* ok
2023-05-03 09:13:05 -07:00
dependabot[bot]
23819c8549 chore(deps): bump rustls from 0.21.0 to 0.21.1 (#419)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.0 to 0.21.1.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Changelog](https://github.com/rustls/rustls/blob/main/RELEASE_NOTES.md)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.0...v/0.21.1)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 07:32:44 -07:00
Jose Fernández
7dfbd993f2 Add dns_cache for server addresses as in pgbouncer (#249)
* Add dns_cache so server addresses are cached and invalidated when DNS changes.

Adds a module to deal with dns_cache feature. It's
main struct is CachedResolver, which is a simple thread safe
hostname <-> Ips cache with the ability to refresh resolutions
every `dns_max_ttl` seconds. This way, a client can check whether its
ip address has changed.

* Allow reloading dns cached

* Add documentation for dns_cached
2023-05-02 10:26:40 +02:00
Lev Kokotov
3601130ba1 Readme update (#418)
* Readme update

* m

* wording
2023-04-30 09:44:25 -07:00
Lev Kokotov
0d504032b2 Server TLS (#417)
* Server TLS

* Finish up TLS

* thats it

* diff

* remove dead code

* maybe?

* dirty shutdown

* skip flakey test

* remove unused error

* fetch config once
2023-04-30 09:41:46 -07:00
Lev Kokotov
4a87b4807d Add more pool settings (#416)
* Add some pool settings

* fmt
2023-04-26 16:33:26 -07:00
Shawn
cb5ff40a59 fix typo (#415)
chore: typo
2023-04-26 08:28:54 -07:00
dependabot[bot]
62b2d994c1 chore(deps): bump regex from 1.7.3 to 1.8.0 (#411)
Bumps [regex](https://github.com/rust-lang/regex) from 1.7.3 to 1.8.0.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-21 06:33:52 -07:00
Lev Kokotov
66805d7e77 README updates (#409)
* Better table

* add image

* promote auth passthrough to stable

* fmt
2023-04-20 07:53:55 -07:00
Lev Kokotov
4ccc1e7fa3 Fix CONFIG (#408)
Fix readme
2023-04-19 07:45:26 -07:00
Lev Kokotov
3dae3d0777 Separate server and client passwords optionally (#407)
* Separate server and user passwords

* config
2023-04-18 09:57:17 -07:00
dependabot[bot]
a18eb42df5 chore(deps): bump serde from 1.0.159 to 1.0.160 (#404)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:25:00 -07:00
dependabot[bot]
6aacf1fa19 chore(deps): bump serde_derive from 1.0.159 to 1.0.160 (#403)
Bumps [serde_derive](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde_derive
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:52 -07:00
dependabot[bot]
8e99e65215 chore(deps): bump sqlparser from 0.32.0 to 0.33.0 (#399)
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.32.0 to 0.33.0.
- [Release notes](https://github.com/sqlparser-rs/sqlparser-rs/releases)
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.32.0...v0.33.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:44 -07:00
dependabot[bot]
5dfbc102a9 chore(deps): bump hyper from 0.14.25 to 0.14.26 (#406)
Bumps [hyper](https://github.com/hyperium/hyper) from 0.14.25 to 0.14.26.
- [Release notes](https://github.com/hyperium/hyper/releases)
- [Changelog](https://github.com/hyperium/hyper/blob/v0.14.26/CHANGELOG.md)
- [Commits](https://github.com/hyperium/hyper/compare/v0.14.25...v0.14.26)

---
updated-dependencies:
- dependency-name: hyper
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:23:52 -07:00
Cluas
bae12fca99 feat: set keepalive for pgcat server itself (#402)
* feat: set keepalive for pgcat server self

* docs: note also set for client
2023-04-12 09:29:43 -07:00
Lev Kokotov
421c5d4b64 Load config on client connect (#401) 2023-04-11 10:32:48 -07:00
Kian-Meng Ang
d568739db9 Fix typos (#398)
Found via `typos --format brief`
2023-04-10 18:37:16 -07:00
Lev Kokotov
692353c839 A couple things (#397)
* Format cleanup

* fmt

* finally
2023-04-10 14:51:01 -07:00
Lev Kokotov
a62f6b0eea Fix port; add user pool mode (#395)
* Fix port; add user pool mode

* will probably break our session/transaction mode tests
2023-04-05 15:06:19 -07:00
dependabot[bot]
89e15f09b5 chore(deps): bump tokio-rustls from 0.23.4 to 0.24.0 (#394)
Bumps [tokio-rustls](https://github.com/tokio-rs/tls) from 0.23.4 to 0.24.0.
- [Release notes](https://github.com/tokio-rs/tls/releases)
- [Commits](https://github.com/tokio-rs/tls/commits)

---
updated-dependencies:
- dependency-name: tokio-rustls
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-02 23:00:09 -07:00
Mostafa Abdelraouf
7ddd23b514 Protocol-level test helpers (#393)
I needed to have granular control over protocol message testing. For example, being able to send protocol messages one-by-one and then be able to inspect the results.

In order to do that, I created this low-level ruby client that can be used to send protocol messages in any order without blocking and also allows inspection of response messages.
2023-04-01 15:27:57 -05:00
dependabot[bot]
faa9c1f64a chore(deps): bump futures from 0.3.27 to 0.3.28 (#392)
Bumps [futures](https://github.com/rust-lang/futures-rs) from 0.3.27 to 0.3.28.
- [Release notes](https://github.com/rust-lang/futures-rs/releases)
- [Changelog](https://github.com/rust-lang/futures-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/futures-rs/compare/0.3.27...0.3.28)

---
updated-dependencies:
- dependency-name: futures
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:35:01 -07:00
dependabot[bot]
9094988491 chore(deps): bump postgres-protocol from 0.6.4 to 0.6.5 (#391)
Bumps [postgres-protocol](https://github.com/sfackler/rust-postgres) from 0.6.4 to 0.6.5.
- [Release notes](https://github.com/sfackler/rust-postgres/releases)
- [Commits](https://github.com/sfackler/rust-postgres/compare/postgres-protocol-v0.6.4...postgres-protocol-v0.6.5)

---
updated-dependencies:
- dependency-name: postgres-protocol
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:34:51 -07:00
Jose Fernández
6f768a84ce Auth passthrough (auth_query) (#266)
* Add a new exec_simple_query method

This adds a new `exec_simple_query` method so we can make 'out of band'
queries to servers that don't interfere with pools at all.
In order to reuse startup code for making these simple queries,
we need to set the stats (`Reporter`) optional, so using these
simple queries wont interfere with stats.

* Add auth passthough (auth_query)

Adds a feature that allows setting auth passthrough for md5 auth.

It adds 3 new (general and pool) config parameters:

- `auth_query`: An string containing a query that will be executed on boot
to obtain the hash of a given user. This query have to use a placeholder `$1`,
so pgcat can replace it with the user its trying to fetch the hash from.
- `auth_query_user`: The user to use for connecting to the server and executing the
auth_query.
- `auth_query_password`: The password to use for connecting to the server and executing the
auth_query.

The configuration can be done either on the general config (so pools share them) or in a per-pool basis.

The behavior is, at boot time, when validating server connections, a hash is fetched per server
and stored in the pool. When new server connections are created, and no cleartext password is specified,
the obtained hash is used for creating them, if the hash could not be obtained for whatever reason, it retries
it.

When client authentication is tried, it uses cleartext passwords if specified, it not, it checks whether
we have query_auth set up, if so, it tries to use the obtained hash for making client auth. If there is no
hash (we could not obtain one when validating the connection), a new fetch is tried.

Once we have a hash, we authenticate using it against whathever the client has sent us, if there is a failure
we refetch the hash and retry auth (so password changes can be done).

The idea with this 'retrial' mechanism is to make it fault tolerant, so if for whatever reason hash could not be
obtained during connection validation, or the password has change, we can still connect later.

* Add documentation for Auth passthrough
2023-03-30 13:29:23 -07:00
dependabot[bot]
0757d7f3a0 chore(deps): bump serde from 1.0.158 to 1.0.159 (#386)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.158 to 1.0.159.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.158...v1.0.159)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-28 09:54:39 -07:00
dependabot[bot]
568f04feee chore(deps): bump serde_derive from 1.0.154 to 1.0.159 (#387)
Bumps [serde_derive](https://github.com/serde-rs/serde) from 1.0.154 to 1.0.159.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.154...v1.0.159)

---
updated-dependencies:
- dependency-name: serde_derive
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-28 09:54:31 -07:00
Jose Fernández
58ce76d9b9 Refactor stats to use atomics (#375)
* Refactor stats to use atomics

When we are dealing with a high number of connections, generated
stats cannot be consumed fast enough by the stats collector loop.
This makes the stats subsystem inconsistent and a log of
warning messages are thrown due to unregistered server/clients.

This change refactors the stats subsystem so it uses atomics:

- Now counters are handled using U64 atomics
- Event system is dropped and averages are calculated using a loop
  every 15 seconds.
- Now, instead of snapshots being generated ever second we keep track of servers/clients
  that have registered. Each pool/server/client has its own instance of the counter and
  makes changes directly, instead of adding an event that gets processed later.

* Manually mplement Hash/Eq in `config::Address` ignoring stats

* Add tests for client connection counters

* Allow connecting to dockerized dev pgcat from the host

* stats: Decrease cl_idle when idle socket disconnects
2023-03-28 17:19:37 +02:00
dependabot[bot]
9a2076a9eb chore(deps): bump futures from 0.3.26 to 0.3.27 (#356)
Bumps [futures](https://github.com/rust-lang/futures-rs) from 0.3.26 to 0.3.27.
- [Release notes](https://github.com/rust-lang/futures-rs/releases)
- [Changelog](https://github.com/rust-lang/futures-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/futures-rs/compare/0.3.26...0.3.27)

---
updated-dependencies:
- dependency-name: futures
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:45 -07:00
dependabot[bot]
e7e7118725 chore(deps): bump hyper from 0.14.24 to 0.14.25 (#358)
Bumps [hyper](https://github.com/hyperium/hyper) from 0.14.24 to 0.14.25.
- [Release notes](https://github.com/hyperium/hyper/releases)
- [Changelog](https://github.com/hyperium/hyper/blob/v0.14.25/CHANGELOG.md)
- [Commits](https://github.com/hyperium/hyper/compare/v0.14.24...v0.14.25)

---
updated-dependencies:
- dependency-name: hyper
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:36 -07:00
dependabot[bot]
99f790cacf chore(deps): bump toml from 0.7.2 to 0.7.3 (#360)
Bumps [toml](https://github.com/toml-rs/toml) from 0.7.2 to 0.7.3.
- [Release notes](https://github.com/toml-rs/toml/releases)
- [Commits](https://github.com/toml-rs/toml/compare/toml-v0.7.2...toml-v0.7.3)

---
updated-dependencies:
- dependency-name: toml
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:25 -07:00
dependabot[bot]
434b0bb69e chore(deps): bump serde from 1.0.154 to 1.0.158 (#376)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.154 to 1.0.158.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.154...v1.0.158)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:13:13 -07:00
dependabot[bot]
714e043ef0 chore(deps): bump async-trait from 0.1.66 to 0.1.68 (#382)
Bumps [async-trait](https://github.com/dtolnay/async-trait) from 0.1.66 to 0.1.68.
- [Release notes](https://github.com/dtolnay/async-trait/releases)
- [Commits](https://github.com/dtolnay/async-trait/compare/0.1.66...0.1.68)

---
updated-dependencies:
- dependency-name: async-trait
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:12:54 -07:00
dependabot[bot]
863104aadd chore(deps): bump regex from 1.7.1 to 1.7.3 (#385)
Bumps [regex](https://github.com/rust-lang/regex) from 1.7.1 to 1.7.3.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.7.1...1.7.3)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-27 09:12:43 -07:00
Lev Kokotov
7dd96141e3 Update README.md 2023-03-26 00:33:05 -07:00
80 changed files with 10687 additions and 2609 deletions

View File

@@ -9,7 +9,7 @@ jobs:
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
docker:
- image: ghcr.io/levkk/pgcat-ci:1.67
- image: ghcr.io/postgresml/pgcat-ci:latest
environment:
RUST_LOG: info
LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
@@ -46,6 +46,14 @@ jobs:
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
- image: postgres:14
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements"]
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
# Add steps to the job
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
steps:

View File

@@ -39,7 +39,7 @@ log_client_connections = false
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = true
autoreload = 15000
# TLS
tls_certificate = ".circleci/server.cert"
@@ -74,6 +74,10 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
@@ -134,6 +138,7 @@ database = "shard2"
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"

View File

@@ -19,6 +19,7 @@ PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/q
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 7432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 8432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 9432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 10432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i

14
.editorconfig Normal file
View File

@@ -0,0 +1,14 @@
root = true
[*]
trim_trailing_whitespace = true
insert_final_newline = true
[*.rs]
indent_style = space
indent_size = 4
max_line_length = 120
[*.toml]
indent_style = space
indent_size = 2

View File

@@ -1,6 +1,11 @@
name: Build and Push
on: push
on:
push:
branches:
- main
tags:
- v*
env:
registry: ghcr.io
@@ -29,6 +34,7 @@ jobs:
tags: |
type=sha,prefix=,format=long
type=schedule
type=ref,event=tag
type=ref,event=branch
type=ref,event=pr
type=raw,value=latest,enable={{ is_default_branch }}

View File

@@ -0,0 +1,48 @@
name: pgcat package (deb)
on:
workflow_dispatch:
inputs:
packageVersion:
default: "1.1.2-dev"
jobs:
build:
strategy:
max-parallel: 1
fail-fast: false # Let the other job finish, or they can lock each other out
matrix:
os: ["buildjet-4vcpu-ubuntu-2204", "buildjet-4vcpu-ubuntu-2204-arm"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install dependencies
env:
DEBIAN_FRONTEND: noninteractive
TZ: Etc/UTC
run: |
curl -sLO https://github.com/deb-s3/deb-s3/releases/download/0.11.4/deb-s3-0.11.4.gem
sudo gem install deb-s3-0.11.4.gem
dpkg-deb --version
- name: Build and release package
env:
AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_DEFAULT_REGION: ${{ vars.AWS_DEFAULT_REGION }}
run: |
if [[ $(arch) == "x86_64" ]]; then
export ARCH=amd64
else
export ARCH=arm64
fi
bash utilities/deb.sh ${{ inputs.packageVersion }}
deb-s3 upload \
--lock \
--bucket apt.postgresml.org \
pgcat-${{ inputs.packageVersion }}-ubuntu22.04-${ARCH}.deb \
--codename $(lsb_release -cs)

2
.rustfmt.toml Normal file
View File

@@ -0,0 +1,2 @@
edition = "2021"
hard_tabs = false

188
CONFIG.md
View File

@@ -1,4 +1,4 @@
# PgCat Configurations
# PgCat Configurations
## `general` Section
### host
@@ -49,6 +49,46 @@ default: 30000 # milliseconds
How long an idle connection with a server is left open (ms).
### server_lifetime
```
path: general.server_lifetime
default: 86400000 # 24 hours
```
Max connection lifetime before it's closed, even if actively used.
### server_round_robin
```
path: general.server_round_robin
default: false
```
Whether to use round robin for server selection or not.
### server_tls
```
path: general.server_tls
default: false
```
Whether to use TLS for server connections or not.
### verify_server_certificate
```
path: general.verify_server_certificate
default: false
```
Whether to verify server certificate or not.
### verify_config
```
path: general.verify_config
default: true
```
Whether to verify config or not.
### idle_client_in_transaction_timeout
```
path: general.idle_client_in_transaction_timeout
@@ -108,10 +148,10 @@ If we should log client disconnections
### autoreload
```
path: general.autoreload
default: false
default: 15000 # milliseconds
```
When set to true, PgCat reloads configs if it detects a change in the config file.
When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
### worker_threads
```
@@ -143,7 +183,13 @@ path: general.tcp_keepalives_interval
default: 5
```
Number of seconds between keepalive packets.
### tcp_user_timeout
```
path: general.tcp_user_timeout
default: 10000
```
A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
### tls_certificate
```
@@ -152,7 +198,7 @@ default: <UNSET>
example: "server.cert"
```
Path to TLS Certficate file to use for TLS connections
Path to TLS Certificate file to use for TLS connections
### tls_private_key
```
@@ -180,6 +226,71 @@ default: "admin_pass"
Password to access the virtual administrative database
### auth_query
```
path: general.auth_query
default: <UNSET>
example: "SELECT $1"
```
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user
```
path: general.auth_query_user
default: <UNSET>
example: "sharding_user"
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password
```
path: general.auth_query_password
default: <UNSET>
example: "sharding_user"
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### prepared_statements
```
path: general.prepared_statements
default: false
```
Whether to use prepared statements or not.
### prepared_statements_cache_size
```
path: general.prepared_statements_cache_size
default: 500
```
Size of the prepared statements cache.
### dns_cache_enabled
```
path: general.dns_cache_enabled
default: false
```
When enabled, ip resolutions for server connections specified using hostnames will be cached
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
old ip connections are closed (gracefully) and new connections will start using new ip.
### dns_max_ttl
```
path: general.dns_max_ttl
default: 30
```
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
## `pools.<pool_name>` Section
### pool_mode
@@ -200,7 +311,7 @@ default: "random"
Load balancing mode
`random` selects the server at random
`loc` selects the server with the least outstanding busy conncetions
`loc` selects the server with the least outstanding busy connections
### default_role
```
@@ -213,7 +324,7 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified.
### query_parser_enabled (experimental)
### query_parser_enabled
```
path: pools.<pool_name>.query_parser_enabled
default: true
@@ -234,7 +345,7 @@ If the query parser is enabled and this setting is enabled, the primary will be
load balancing of read queries. Otherwise, the primary will only be used for write
queries. The primary can always be explicitly selected with our custom protocol.
### sharding_key_regex (experimental)
### sharding_key_regex
```
path: pools.<pool_name>.sharding_key_regex
default: <UNSET>
@@ -256,7 +367,40 @@ Current options:
`pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
`sha1`: A hashing function based on SHA1
### automatic_sharding_key (experimental)
### auth_query
```
path: pools.<pool_name>.auth_query
default: <UNSET>
example: "SELECT $1"
```
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user
```
path: pools.<pool_name>.auth_query_user
default: <UNSET>
example: "sharding_user"
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password
```
path: pools.<pool_name>.auth_query_password
default: <UNSET>
example: "sharding_user"
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### automatic_sharding_key
```
path: pools.<pool_name>.automatic_sharding_key
default: <UNSET>
@@ -289,7 +433,8 @@ path: pools.<pool_name>.users.<user_index>.username
default: "sharding_user"
```
Postgresql username
PostgreSQL username used to authenticate the user and connect to the server
if `server_username` is not set.
### password
```
@@ -297,7 +442,26 @@ path: pools.<pool_name>.users.<user_index>.password
default: "sharding_user"
```
Postgresql password
PostgreSQL password used to authenticate the user and connect to the server
if `server_password` is not set.
### server_username
```
path: pools.<pool_name>.users.<user_index>.server_username
default: <UNSET>
example: "another_user"
```
PostgreSQL username used to connect to the server.
### server_password
```
path: pools.<pool_name>.users.<user_index>.server_password
default: <UNSET>
example: "another_password"
```
PostgreSQL password used to connect to the server.
### pool_size
```
@@ -328,7 +492,7 @@ default: [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
Array of servers in the shard, each server entry is an array of `[host, port, role]`
### mirrors (experimental)
### mirrors
```
path: pools.<pool_name>.shards.<shard_index>.mirrors
default: <UNSET>

1311
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,35 +1,33 @@
[package]
name = "pgcat"
version = "1.0.0"
version = "1.1.2-dev"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
bytes = "1"
md-5 = "0.10"
bb8 = "0.8.0"
bb8 = "0.8.1"
async-trait = "0.1"
rand = "0.8"
chrono = "0.4"
sha-1 = "0.10"
toml = "0.7"
serde = "1"
serde = { version = "1", features = ["derive"] }
serde_derive = "1"
regex = "1"
num_cpus = "1"
once_cell = "1"
sqlparser = "0.32.0"
sqlparser = {version = "0.34", features = ["visitor"] }
log = "0.4"
arc-swap = "1"
env_logger = "0.10"
parking_lot = "0.12.1"
hmac = "0.12"
sha2 = "0.10"
base64 = "0.21"
stringprep = "0.1"
tokio-rustls = "0.23"
tokio-rustls = "0.24"
rustls-pemfile = "1"
hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.11.1", features = ["macros"] }
@@ -37,6 +35,20 @@ exitcode = "1.1.2"
futures = "0.3"
socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.5"
fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
serde_json = "1"
itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter", "std"]}
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -1,9 +1,13 @@
FROM rust:1 AS builder
FROM rust:1-slim-bookworm AS builder
RUN apt-get update && \
apt-get install -y build-essential
COPY . /app
WORKDIR /app
RUN cargo build --release
FROM debian:bullseye-slim
FROM debian:bookworm-slim
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
WORKDIR /etc/pgcat

View File

@@ -1,4 +1,6 @@
FROM cimg/rust:1.67.1
COPY --from=sclevine/yj /bin/yj /bin/yj
RUN /bin/yj -h
RUN sudo apt-get update && \
sudo apt-get install -y \
psmisc postgresql-contrib-14 postgresql-client-14 libpq-dev \

25
Dockerfile.dev Normal file
View File

@@ -0,0 +1,25 @@
FROM lukemathwalker/cargo-chef:latest-rust-1 AS chef
RUN apt-get update && \
apt-get install -y build-essential
WORKDIR /app
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
COPY --from=planner /app/recipe.json recipe.json
# Build dependencies - this is the caching Docker layer!
RUN cargo chef cook --release --recipe-path recipe.json
# Build application
COPY . .
RUN cargo build
FROM debian:bookworm-slim
COPY --from=builder /app/target/release/pgcat /usr/bin/pgcat
COPY --from=builder /app/pgcat.toml /etc/pgcat/pgcat.toml
WORKDIR /etc/pgcat
ENV RUST_LOG=info
CMD ["pgcat"]

View File

@@ -18,23 +18,56 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
| SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
| Automatic sharding | **Experimental** | PgCat can parse queries detect sharding keys automatically, and route queries to the route shard. |
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
## Status
PgCat is stable and used in production to serve hundreds of thousands of queries per second. Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
PgCat is stable and used in production to serve hundreds of thousands of queries per second.
| | |
|-|-|
|<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f"><img src="./images/instacart.webp" height="70" width="auto"></a>|<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second"><img src="./images/postgresml.webp" height="70" width="auto"></a>|
| [Instacart](https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f) | [PostgresML](https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second) |
<table>
<tr>
<td>
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
<img src="./images/instacart.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
<img src="./images/postgresml.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://onesignal.com">
<img src="./images/one_signal.webp" height="70" width="auto">
</a>
</td>
</tr>
<tr>
<td>
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
Instacart
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
PostgresML
</a>
</td>
<td>
OneSignal
</td>
</tr>
</table>
Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
## Deployment
@@ -79,7 +112,7 @@ pgbench -t 1000 -p 6432 -h 127.0.0.1 --protocol extended
See [sharding README](./tests/sharding/README.md) for sharding logic testing.
Additionally, all features are tested with Ruby, Python, and Rust tests unit and integration tests.
Additionally, all features are tested with Ruby, Python, and Rust unit and integration tests.
Run `cargo test` to run Rust unit tests.
@@ -98,7 +131,7 @@ You can open a Docker development environment where you can debug tests easier.
./dev/script/console
```
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the container (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
## Usage

9
control Normal file
View File

@@ -0,0 +1,9 @@
Package: pgcat
Version: ${PACKAGE_VERSION}
Section: database
Priority: optional
Architecture: ${ARCH}
Maintainer: PostgresML <team@postgresml.org>
Homepage: https://postgresml.org
Description: PgCat - NextGen PostgreSQL Pooler
PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load balancing, failover and mirroring.

View File

@@ -1,4 +1,4 @@
FROM rust:bullseye
FROM rust:1.70-bullseye
# Dependencies
RUN apt-get update -y \

View File

@@ -25,7 +25,9 @@ x-common-env-pg:
services:
main:
image: kubernetes/pause
image: gcr.io/google_containers/pause:3.2
ports:
- 6432
pg1:
<<: *common-definition-pg
@@ -56,6 +58,13 @@ services:
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
PGPORT: 9432
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
pg5:
<<: *common-definition-pg
environment:
<<: *common-env-pg
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
PGPORT: 10432
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
toxiproxy:
build: .
@@ -69,6 +78,7 @@ services:
- pg2
- pg3
- pg4
- pg5
pgcat-shell:
stdin_open: true

View File

@@ -38,9 +38,6 @@ log_client_connections = false
# If we should log client disconnections
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = false
# TLS
# tls_certificate = "server.cert"
# tls_private_key = "server.key"
@@ -74,9 +71,13 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
# queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true
# So what if you wanted to implement a different hashing function,

BIN
images/one_signal.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

22
pgcat.minimal.toml Normal file
View File

@@ -0,0 +1,22 @@
# This is an example of the most basic config
# that will mimic what PgBouncer does in transaction mode with one server.
[general]
host = "0.0.0.0"
port = 6433
admin_username = "pgcat"
admin_password = "pgcat"
[pools.pgml.users.0]
username = "postgres"
password = "postgres"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.pgml.shards.0]
servers = [
["127.0.0.1", 28815, "primary"]
]
database = "postgres"

16
pgcat.service Normal file
View File

@@ -0,0 +1,16 @@
[Unit]
Description=PgCat pooler
After=network.target
StartLimitIntervalSec=0
[Service]
User=pgcat
Type=simple
Restart=always
RestartSec=1
Environment=RUST_LOG=info
LimitNOFILE=65536
ExecStart=/usr/bin/pgcat /etc/pgcat.toml
[Install]
WantedBy=multi-user.target

View File

@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
# How long an idle connection with a server is left open (ms).
idle_timeout = 30000 # milliseconds
# Max connection lifetime before it's closed, even if actively used.
server_lifetime = 86400000 # 24 hours
# How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout = 0 # milliseconds
@@ -45,7 +48,7 @@ log_client_connections = false
log_client_disconnections = false
# When set to true, PgCat reloads configs if it detects a change in the config file.
autoreload = false
autoreload = 15000
# Number of worker threads the Runtime will use (4 by default).
worker_threads = 5
@@ -57,10 +60,22 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections
# tls_certificate = "server.cert"
# Handle prepared statements.
prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key"
# tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = false
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -68,6 +83,58 @@ admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"
# Default plugins that are configured on all pools.
[plugins]
# Prewarmer plugin that runs queries on server startup, before giving the connection
# to the client.
[plugins.prewarmer]
enabled = false
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
# Log all queries to stdout.
[plugins.query_logger]
enabled = false
# Block access to tables that Postgres does not allow us to control.
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
# Intercept user queries and give a fake reply.
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like
@@ -95,6 +162,10 @@ default_role = "any"
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitly selected with our custom protocol.
@@ -106,6 +177,12 @@ primary_reads_enabled = true
# shard_id_regex = '/\* shard_id: (\d+) \*/'
# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
# Defines the behavior when no shard is selected in a sharded system.
# `random`: picks a shard at random
# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
# no_shard_specified_behavior = "shard_0"
# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
# Current options:
@@ -113,6 +190,21 @@ primary_reads_enabled = true
# `sha1`: A hashing function based on SHA1
sharding_function = "pg_bigint_hash"
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
# established using the database configured in the pool. This parameter is inherited by every pool
# and can be redefined in pool configuration.
# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_user = "sharding_user"
# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_password = "sharding_user"
# Automatically parse this from queries and route queries to the right shard!
# automatic_sharding_key = "data.id"
@@ -122,18 +214,86 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool
connect_timeout = 3000
# When enabled, ip resolutions for server connections specified using hostnames will be cached
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
# old ip connections are closed (gracefully) and new connections will start using new ip.
# dns_cache_enabled = false
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
# so all plugins have to be configured here again.
[pool.sharded_db.plugins]
[pools.sharded_db.plugins.prewarmer]
enabled = true
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
[pools.sharded_db.plugins.query_logger]
enabled = false
[pools.sharded_db.plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
[pools.sharded_db.plugins.intercept]
enabled = true
[pools.sharded_db.plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[pools.sharded_db.plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# User configs are structured as pool.<pool_name>.users.<user_index>
# This secion holds the credentials for users that may connect to this cluster
# This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
# Postgresql username
# PostgreSQL username used to authenticate the user and connect to the server
# if `server_username` is not set.
username = "sharding_user"
# Postgresql password
# PostgreSQL password used to authenticate the user and connect to the server
# if `server_password` is not set.
password = "sharding_user"
pool_mode = "transaction"
# PostgreSQL username used to connect to the server.
# server_username = "another_user"
# PostgreSQL password used to connect to the server.
# server_password = "another_password"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
# 0 means it is disabled.
statement_timeout = 0
@@ -178,6 +338,8 @@ sharding_function = "pg_bigint_hash"
username = "simple_user"
password = "simple_user"
pool_size = 5
min_pool_size = 3
server_lifetime = 60000
statement_timeout = 0
[pools.simple_db.shards.0]

9
postinst Normal file
View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -e
systemctl daemon-reload
systemctl enable pgcat
if ! id pgcat 2> /dev/null; then
useradd -s /usr/bin/false pgcat
fi

4
postrm Normal file
View File

@@ -0,0 +1,4 @@
#!/bin/bash
set -e
systemctl daemon-reload

5
prerm Normal file
View File

@@ -0,0 +1,5 @@
#!/bin/bash
set -e
systemctl stop pgcat
systemctl disable pgcat

View File

@@ -1,32 +1,33 @@
use crate::pool::BanReason;
/// Admin database.
use crate::server::ServerParameters;
use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
/// Admin database.
use std::sync::atomic::Ordering;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::Instant;
use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{
get_address_stats, get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState,
};
use crate::ClientServerMap;
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();
pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_parameters = ServerParameters::new();
server_info.put(server_parameter_message("application_name", ""));
server_info.put(server_parameter_message("client_encoding", "UTF8"));
server_info.put(server_parameter_message("server_encoding", "UTF8"));
server_info.put(server_parameter_message("server_version", VERSION));
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
server_info
server_parameters
}
/// Handle admin client.
@@ -73,17 +74,21 @@ where
}
"PAUSE" => {
trace!("PAUSE");
pause(stream, query_parts[1]).await
pause(stream, query_parts).await
}
"RESUME" => {
trace!("RESUME");
resume(stream, query_parts[1]).await
resume(stream, query_parts).await
}
"SHUTDOWN" => {
trace!("SHUTDOWN");
shutdown(stream).await
}
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"HELP" => {
trace!("SHOW HELP");
show_help(stream).await
}
"BANS" => {
trace!("SHOW BANS");
show_bans(stream).await
@@ -158,7 +163,14 @@ where
"free_clients".to_string(),
client_stats
.keys()
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Idle)
.filter(|client_id| {
client_stats
.get(client_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ClientState::Idle
})
.count()
.to_string(),
]));
@@ -166,7 +178,14 @@ where
"used_clients".to_string(),
client_stats
.keys()
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Active)
.filter(|client_id| {
client_stats
.get(client_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ClientState::Active
})
.count()
.to_string(),
]));
@@ -178,7 +197,14 @@ where
"free_servers".to_string(),
server_stats
.keys()
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Idle)
.filter(|server_id| {
server_stats
.get(server_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ServerState::Idle
})
.count()
.to_string(),
]));
@@ -186,7 +212,14 @@ where
"used_servers".to_string(),
server_stats
.keys()
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Active)
.filter(|server_id| {
server_stats
.get(server_id)
.unwrap()
.state
.load(Ordering::Relaxed)
== ServerState::Active
})
.count()
.to_string(),
]));
@@ -227,52 +260,51 @@ async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let all_pool_stats = get_pool_stats();
let pool_lookup = PoolStats::construct_pool_lookup();
let mut res = BytesMut::new();
res.put(row_description(&PoolStats::generate_header()));
pool_lookup.iter().for_each(|(_identifier, pool_stats)| {
res.put(data_row(&pool_stats.generate_row()));
});
res.put(command_complete("SHOW"));
let columns = vec![
("database", DataType::Text),
("user", DataType::Text),
("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric),
("sv_active", DataType::Numeric),
("sv_idle", DataType::Numeric),
("sv_used", DataType::Numeric),
("sv_tested", DataType::Numeric),
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Show all available options.
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut res = BytesMut::new();
let detail_msg = vec![
"",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
// "SHOW FDS|SOCKETS|ACTIVE_SOCKETS|LISTS|MEM|STATE", // missing FDS|SOCKETS|ACTIVE_SOCKETS|MEM|STATE
"SHOW LISTS",
// "SHOW DNS_HOSTS|DNS_ZONES", // missing DNS_HOSTS|DNS_ZONES
"SHOW STATS", // missing STATS_TOTALS|STATS_AVERAGES|TOTALS
"SET key = arg",
"RELOAD",
"PAUSE [<db>, <user>]",
"RESUME [<db>, <user>]",
// "DISABLE <db>", // missing
// "ENABLE <db>", // missing
// "RECONNECT [<db>]", missing
// "KILL <db>",
// "SUSPEND",
"SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (user_pool, pool) in get_all_pools() {
let def = HashMap::default();
let pool_stats = all_pool_stats
.get(&(user_pool.db.clone(), user_pool.user.clone()))
.unwrap_or(&def);
let pool_config = &pool.settings;
let mut row = vec![
user_pool.db.clone(),
user_pool.user.clone(),
pool_config.pool_mode.to_string(),
];
for column in &columns[3..columns.len()] {
let value = match column.0 {
"maxwait" => (pool_stats.get("maxwait_us").unwrap_or(&0) / 1_000_000).to_string(),
"maxwait_us" => {
(pool_stats.get("maxwait_us").unwrap_or(&0) % 1_000_000).to_string()
}
_other_values => pool_stats.get(column.0).unwrap_or(&0).to_string(),
};
row.push(value);
}
res.put(data_row(&row));
}
res.put(notify("Console usage", detail_msg.join("\n\t")));
res.put(command_complete("SHOW"));
// ReadyForQuery
@@ -320,17 +352,17 @@ where
let paused = pool.paused();
res.put(data_row(&vec![
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
pool_config.user.username.to_string(), // force_user
pool_config.user.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool
pool_config.pool_mode.to_string(), // pool_mode
pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
pool_config.user.username.to_string(), // force_user
pool_config.user.pool_size.to_string(), // pool_size
pool_config.user.min_pool_size.unwrap_or(0).to_string(), // min_pool_size
"0".to_string(), // reserve_pool
pool_config.pool_mode.to_string(), // pool_mode
pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
match paused {
// paused
true => "1".to_string(),
@@ -400,7 +432,7 @@ where
for (id, pool) in get_all_pools().iter() {
for address in pool.get_addresses_from_host(host) {
if !pool.is_banned(&address) {
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
pool.ban(&address, BanReason::AdminBan(duration_seconds), None);
res.put(data_row(&vec![
id.db.clone(),
id.user.clone(),
@@ -617,7 +649,6 @@ where
("avg_wait_time", DataType::Numeric),
];
let all_stats = get_address_stats();
let mut res = BytesMut::new();
res.put(row_description(&columns));
@@ -625,15 +656,10 @@ where
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match all_stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
let mut row = vec![address.name(), user_pool.db.clone(), user_pool.user.clone()];
for column in &columns[3..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string());
}
let stats = address.stats.clone();
stats.populate_row(&mut row);
res.put(data_row(&row));
}
@@ -673,16 +699,16 @@ where
for (_, client) in new_map {
let row = vec![
format!("{:#010X}", client.client_id),
client.pool_name,
client.username,
client.application_name.clone(),
client.state.to_string(),
client.transaction_count.to_string(),
client.query_count.to_string(),
client.error_count.to_string(),
format!("{:#010X}", client.client_id()),
client.pool_name(),
client.username(),
client.application_name(),
client.state.load(Ordering::Relaxed).to_string(),
client.transaction_count.load(Ordering::Relaxed).to_string(),
client.query_count.load(Ordering::Relaxed).to_string(),
client.error_count.load(Ordering::Relaxed).to_string(),
Instant::now()
.duration_since(client.connect_time)
.duration_since(client.connect_time())
.as_secs()
.to_string(),
];
@@ -717,6 +743,9 @@ where
("bytes_sent", DataType::Numeric),
("bytes_received", DataType::Numeric),
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];
let new_map = get_server_stats();
@@ -724,21 +753,34 @@ where
res.put(row_description(&columns));
for (_, server) in new_map {
let application_name = server.application_name.read();
let row = vec![
format!("{:#010X}", server.server_id),
server.pool_name,
server.username,
server.address_name,
server.application_name,
server.state.to_string(),
server.transaction_count.to_string(),
server.query_count.to_string(),
server.bytes_sent.to_string(),
server.bytes_received.to_string(),
format!("{:#010X}", server.server_id()),
server.pool_name(),
server.username(),
server.address_name(),
application_name.clone(),
server.state.load(Ordering::Relaxed).to_string(),
server.transaction_count.load(Ordering::Relaxed).to_string(),
server.query_count.load(Ordering::Relaxed).to_string(),
server.bytes_sent.load(Ordering::Relaxed).to_string(),
server.bytes_received.load(Ordering::Relaxed).to_string(),
Instant::now()
.duration_since(server.connect_time)
.duration_since(server.connect_time())
.as_secs()
.to_string(),
server
.prepared_hit_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
.to_string(),
];
res.put(data_row(&row));
@@ -755,96 +797,128 @@ where
}
/// Pause a pool. It won't pass any more queries to the backends.
async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
async fn pause<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(),
false => Vec::new(),
};
if parts.len() != 2 {
error_response(
stream,
"PAUSE requires a database and a user, e.g. PAUSE my_db,my_user",
)
.await
} else {
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
Some(pool) => {
match parts.len() {
0 => {
for (_, pool) in get_all_pools() {
pool.pause();
let mut res = BytesMut::new();
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
None => {
error_response(
stream,
&format!(
"No pool configured for database: {}, user: {}",
database, user
),
)
.await
let mut res = BytesMut::new();
res.put(command_complete("PAUSE"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
2 => {
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
Some(pool) => {
pool.pause();
let mut res = BytesMut::new();
res.put(command_complete(&format!("PAUSE {},{}", database, user)));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
None => {
error_response(
stream,
&format!(
"No pool configured for database: {}, user: {}",
database, user
),
)
.await
}
}
}
_ => error_response(stream, "usage: PAUSE [db, user]").await,
}
}
/// Resume a pool. Queries are allowed again.
async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
async fn resume<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(),
false => Vec::new(),
};
if parts.len() != 2 {
error_response(
stream,
"RESUME requires a database and a user, e.g. RESUME my_db,my_user",
)
.await
} else {
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
Some(pool) => {
match parts.len() {
0 => {
for (_, pool) in get_all_pools() {
pool.resume();
let mut res = BytesMut::new();
res.put(command_complete(&format!("RESUME {},{}", database, user)));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
None => {
error_response(
stream,
&format!(
"No pool configured for database: {}, user: {}",
database, user
),
)
.await
let mut res = BytesMut::new();
res.put(command_complete("RESUME"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
2 => {
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
Some(pool) => {
pool.resume();
let mut res = BytesMut::new();
res.put(command_complete(&format!("RESUME {},{}", database, user)));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
None => {
error_response(
stream,
&format!(
"No pool configured for database: {}, user: {}",
database, user
),
)
.await
}
}
}
_ => error_response(stream, "usage: RESUME [db, user]").await,
}
}

134
src/auth_passthrough.rs Normal file
View File

@@ -0,0 +1,134 @@
use crate::errors::Error;
use crate::pool::ConnectionPool;
use crate::server::Server;
use log::debug;
#[derive(Clone, Debug)]
pub struct AuthPassthrough {
password: String,
query: String,
user: String,
}
impl AuthPassthrough {
/// Initializes an AuthPassthrough.
pub fn new(query: &str, user: &str, password: &str) -> Self {
AuthPassthrough {
password: password.to_string(),
query: query.to_string(),
user: user.to_string(),
}
}
/// Returns an AuthPassthrough given the pool configuration.
/// If any of required values is not set, None is returned.
pub fn from_pool_config(pool_config: &crate::config::Pool) -> Option<Self> {
if pool_config.is_auth_query_configured() {
return Some(AuthPassthrough::new(
pool_config.auth_query.as_ref().unwrap(),
pool_config.auth_query_user.as_ref().unwrap(),
pool_config.auth_query_password.as_ref().unwrap(),
));
}
None
}
/// Returns an AuthPassthrough given the pool settings.
/// If any of required values is not set, None is returned.
pub fn from_pool_settings(pool_settings: &crate::pool::PoolSettings) -> Option<Self> {
let pool_config = crate::config::Pool {
auth_query: pool_settings.auth_query.clone(),
auth_query_password: pool_settings.auth_query_password.clone(),
auth_query_user: pool_settings.auth_query_user.clone(),
..Default::default()
};
AuthPassthrough::from_pool_config(&pool_config)
}
/// Connects to server and executes auth_query for the specified address.
/// If the response is a row with two columns containing the username set in the address.
/// and its MD5 hash, the MD5 hash returned.
///
/// Note that the query is executed, changing $1 with the name of the user
/// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data.
/// Also, it is compatible with pgbouncer.
///
/// # Arguments
///
/// * `address` - An Address of the server we want to connect to. The username for the hash will be obtained from this value.
///
/// # Examples
///
/// ```
/// use pgcat::auth_passthrough::AuthPassthrough;
/// use pgcat::config::Address;
/// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres");
/// auth_passthrough.fetch_hash(&Address::default());
/// ```
///
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
let auth_user = crate::config::User {
username: self.user.clone(),
password: Some(self.password.clone()),
server_username: None,
server_password: None,
pool_size: 1,
statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
min_pool_size: None,
};
let user = &address.username;
debug!("Connecting to server to obtain auth hashes");
let auth_query = self.query.replace("$1", user);
match Server::exec_simple_query(address, &auth_user, &auth_query).await {
Ok(password_data) => {
if password_data.len() == 2 && password_data.first().unwrap() == user {
if let Some(stripped_hash) = password_data
.last()
.unwrap()
.to_string()
.strip_prefix("md5") {
Ok(stripped_hash.to_string())
}
else {
Err(Error::AuthPassthroughError(
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
))
}
} else {
Err(Error::AuthPassthroughError(
"Data obtained from query does not follow the scheme 'user','hash'."
.to_string(),
))
}
}
Err(err) => {
Err(Error::AuthPassthroughError(
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
user, err))
)
}
}
}
}
pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}

File diff suppressed because it is too large Load Diff

36
src/cmd_args.rs Normal file
View File

@@ -0,0 +1,36 @@
use clap::{Parser, ValueEnum};
use tracing::Level;
/// PgCat: Nextgen PostgreSQL Pooler
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[arg(default_value_t = String::from("pgcat.toml"), env)]
pub config_file: String,
#[arg(short, long, default_value_t = tracing::Level::INFO, env)]
pub log_level: Level,
#[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)]
pub log_format: LogFormat,
#[arg(
short,
long,
default_value_t = false,
env,
help = "disable colors in the log output"
)]
pub no_color: bool,
}
pub fn parse() -> Args {
return Args::parse();
}
#[derive(ValueEnum, Clone, Debug)]
pub enum LogFormat {
Text,
Structured,
Debug,
}

File diff suppressed because it is too large Load Diff

410
src/dns_cache.rs Normal file
View File

@@ -0,0 +1,410 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;
/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
// Ip addressed are returned as a set of addresses
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
}
impl AddrSet {
fn new() -> AddrSet {
AddrSet {
set: HashSet::new(),
}
}
}
impl From<LookupIp> for AddrSet {
fn from(lookup_ip: LookupIp) -> Self {
let mut addr_set = AddrSet::new();
for address in lookup_ip.iter() {
addr_set.set.insert(address);
}
addr_set
}
}
///
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
///
/// The system works as follows:
///
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
/// cache is refreshed.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
/// # })
/// ```
///
/// // Now the ip resolution is stored in local cache and subsequent
/// // calls will be returned from cache. Also, the cache is refreshed
/// // and updated every 10 seconds.
///
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,
// This is the hash that contains the hash.
data: Option<RwLock<HashMap<String, AddrSet>>>,
// The resolver to be used for DNS queries.
resolver: Option<TokioAsyncResolver>,
// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
dns_max_ttl: u64,
/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}
impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}
impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}
impl CachedResolver {
///
/// Returns a new Arc<CachedResolver> based on passed configuration.
/// It also starts the loop that will refresh cache entries.
///
/// # Arguments:
///
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// # })
/// ```
///
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}
Ok(instance)
}
pub fn enabled(&self) -> bool {
self.config.enabled
}
// Schedules the refresher
async fn refresh_dns_entries_loop(&self) {
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
let interval = Duration::from_secs(self.config.dns_max_ttl);
loop {
debug!("Begin refreshing cached DNS addresses.");
// To minimize the time we hold the lock, we first create
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}
for hostname in hostnames.iter() {
let addrset = self
.fetch_from_cache(hostname.as_str())
.expect("Could not obtain expected address from cache, this should not happen");
match resolver.lookup_ip(hostname).await {
Ok(lookup_ip) => {
let new_addrset = AddrSet::from(lookup_ip);
debug!(
"Obtained address for host ({}) -> ({:?})",
hostname, new_addrset
);
if addrset != new_addrset {
debug!(
"Addr changed from {:?} to {:?} updating cache.",
addrset, new_addrset
);
self.store_in_cache(hostname, new_addrset);
}
}
Err(err) => {
error!(
"There was an error trying to resolv {}: ({}).",
hostname, err
);
}
}
}
debug!("Finished refreshing cached DNS addresses.");
sleep(interval).await;
}
}
/// Returns a `AddrSet` given the specified hostname.
///
/// This method first tries to fetch the value from the cache, if it misses
/// then it is resolved and stored in the cache. TTL from records is ignored.
///
/// # Arguments
///
/// * `host` - A string slice referencing the hostname to be resolved.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let response = resolver.lookup_ip("www.google.com.");
/// # })
/// ```
///
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
debug!("Lookup up {} in cache", host);
match self.fetch_from_cache(host) {
Some(addr_set) => {
debug!("Cache hit!");
Ok(addr_set)
}
None => {
debug!("Not found, executing a dns query!");
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
//
// Returns true if the stored host resolution differs from the AddrSet passed.
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
return fetched_addr_set != *addr_set;
}
false
}
// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
match new_resolver {
Ok(ok) => {
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
}
} else {
Ok(())
}
}
// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;
#[tokio::test]
async fn new() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await;
assert!(resolver.is_ok());
}
#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}
#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
}
#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
}
#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
}
#[tokio::test]
// Ok, this test is based on the fact that google does DNS RR
// and does not responds with every available ip everytime, so
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
let resolver_for_refresher = resolver.clone();
let _thread_handle = tokio::task::spawn(async move {
resolver_for_refresher.refresh_dns_entries_loop().await;
});
assert!(!resolver.has_changed(hostname, &addr_set));
}
}

View File

@@ -1,18 +1,132 @@
/// Errors.
//! Errors.
/// Various errors.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
SocketError(String),
ClientSocketError(String, ClientIdentifier),
ClientGeneralError(String, ClientIdentifier),
ClientAuthImpossible(String),
ClientAuthPassthroughError(String, ClientIdentifier),
ClientBadStartup,
ProtocolSyncError(String),
BadQuery(String),
ServerError,
ServerMessageParserError(String),
ServerStartupError(String, ServerIdentifier),
ServerAuthError(String, ServerIdentifier),
BadConfig,
AllServersDown,
ClientError(String),
TlsError,
StatementTimeout,
DNSCachedError(String),
ShuttingDown,
ParseBytesError(String),
AuthError(String),
AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
QueryRouterError(String),
InvalidShardId(usize),
}
#[derive(Clone, PartialEq, Debug)]
pub struct ClientIdentifier {
pub application_name: String,
pub username: String,
pub pool_name: String,
}
impl ClientIdentifier {
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
ClientIdentifier {
application_name: application_name.into(),
username: username.into(),
pool_name: pool_name.into(),
}
}
}
impl std::fmt::Display for ClientIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ application_name: {}, username: {}, pool_name: {} }}",
self.application_name, self.username, self.pool_name
)
}
}
#[derive(Clone, PartialEq, Debug)]
pub struct ServerIdentifier {
pub username: String,
pub database: String,
}
impl ServerIdentifier {
pub fn new(username: &str, database: &str) -> ServerIdentifier {
ServerIdentifier {
username: username.into(),
database: database.into(),
}
}
}
impl std::fmt::Display for ServerIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ username: {}, database: {} }}",
self.username, self.database
)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self {
&Error::ClientSocketError(error, client_identifier) => write!(
f,
"Error reading {} from client {}",
error, client_identifier
),
&Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{} {}", error, client_identifier)
}
&Error::ClientAuthImpossible(username) => write!(
f,
"Client auth not possible, \
no cleartext password set for username: {} \
in config and auth passthrough (query_auth) \
is not set up.",
username
),
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
f,
"No cleartext password set, \
and no auth passthrough could not \
obtain the hash from server for {}, \
the error was: {}",
client_identifier, error
),
&Error::ServerStartupError(error, server_identifier) => write!(
f,
"Error reading {} on server startup {}",
error, server_identifier,
),
&Error::ServerAuthError(error, server_identifier) => {
write!(f, "{} for {}", error, server_identifier,)
}
// The rest can use Debug.
err => write!(f, "{:?}", err),
}
}
}
impl From<std::ffi::NulError> for Error {
fn from(err: std::ffi::NulError) -> Self {
Error::QueryRouterError(err.to_string())
}
}

View File

@@ -1,10 +1,18 @@
pub mod admin;
pub mod auth_passthrough;
pub mod client;
pub mod cmd_args;
pub mod config;
pub mod constants;
pub mod dns_cache;
pub mod errors;
pub mod logger;
pub mod messages;
pub mod mirrors;
pub mod multi_logger;
pub mod plugins;
pub mod pool;
pub mod prometheus;
pub mod query_router;
pub mod scram;
pub mod server;
pub mod sharding;

20
src/logger.rs Normal file
View File

@@ -0,0 +1,20 @@
use crate::cmd_args::{Args, LogFormat};
use tracing_subscriber;
use tracing_subscriber::EnvFilter;
pub fn init(args: &Args) {
// Iniitalize a default filter, and then override the builtin default "warning" with our
// commandline, (default: "info")
let filter = EnvFilter::from_default_env().add_directive(args.log_level.into());
let trace_sub = tracing_subscriber::fmt()
.with_thread_ids(true)
.with_env_filter(filter)
.with_ansi(!args.no_color);
match args.log_format {
LogFormat::Structured => trace_sub.json().init(),
LogFormat::Debug => trace_sub.pretty().init(),
_ => trace_sub.init(),
};
}

View File

@@ -23,7 +23,6 @@ extern crate arc_swap;
extern crate async_trait;
extern crate bb8;
extern crate bytes;
extern crate env_logger;
extern crate exitcode;
extern crate log;
extern crate md5;
@@ -36,6 +35,7 @@ extern crate sqlparser;
extern crate tokio;
extern crate tokio_rustls;
extern crate toml;
extern crate trust_dns_resolver;
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
@@ -60,52 +60,32 @@ use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::broadcast;
mod admin;
mod client;
mod config;
mod constants;
mod errors;
mod messages;
mod mirrors;
mod multi_logger;
mod pool;
mod prometheus;
mod query_router;
mod scram;
mod server;
mod sharding;
mod stats;
mod tls;
use crate::config::{get_config, reload_config, VERSION};
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER};
use pgcat::cmd_args;
use pgcat::config::{get_config, reload_config, VERSION};
use pgcat::dns_cache;
use pgcat::logger;
use pgcat::messages::configure_socket;
use pgcat::pool::{ClientServerMap, ConnectionPool};
use pgcat::prometheus::start_metric_server;
use pgcat::stats::{Collector, Reporter, REPORTER};
fn main() -> Result<(), Box<dyn std::error::Error>> {
multi_logger::MultiLogger::init().unwrap();
let args = cmd_args::parse();
logger::init(&args);
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() {
if !pgcat::query_router::QueryRouter::setup() {
error!("Could not setup query router");
std::process::exit(exitcode::CONFIG);
}
let args = std::env::args().collect::<Vec<String>>();
let config_file = if args.len() == 2 {
args[1].to_string()
} else {
String::from("pgcat.toml")
};
// Create a transient runtime for loading the config for the first time.
{
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
runtime.block_on(async {
match config::parse(&config_file).await {
match pgcat::config::parse(args.config_file.as_str()).await {
Ok(_) => (),
Err(err) => {
error!("Config parse error: {:?}", err);
@@ -162,8 +142,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Statistics reporting.
let (stats_tx, stats_rx) = mpsc::channel(500_000);
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
REPORTER.store(Arc::new(Reporter::default()));
// Starts (if enabled) dns cache before pools initialization
match dns_cache::CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache initialization error: {:?}", err),
};
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
@@ -175,20 +160,23 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
};
tokio::task::spawn(async move {
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
let mut stats_collector = Collector::default();
stats_collector.collect().await;
});
info!("Config autoreloader: {}", config.general.autoreload);
info!("Config autoreloader: {}", match config.general.autoreload {
Some(interval) => format!("{} ms", interval),
None => "disabled".into(),
});
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
let autoreload_client_server_map = client_server_map.clone();
if let Some(interval) = config.general.autoreload {
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval));
let autoreload_client_server_map = client_server_map.clone();
tokio::task::spawn(async move {
loop {
autoreload_interval.tick().await;
if config.general.autoreload {
info!("Automatically reloading config");
tokio::task::spawn(async move {
loop {
autoreload_interval.tick().await;
debug!("Automatically reloading config");
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
if changed {
@@ -196,8 +184,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
};
}
}
});
});
};
#[cfg(windows)]
let mut term_signal = win_signal::ctrl_close().unwrap();
@@ -282,18 +272,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let drain_tx = drain_tx.clone();
let client_server_map = client_server_map.clone();
let tls_certificate = config.general.tls_certificate.clone();
let tls_certificate = get_config().general.tls_certificate.clone();
configure_socket(&socket);
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
match pgcat::client::client_entrypoint(
socket,
client_server_map,
shutdown_rx,
drain_tx,
admin_only,
tls_certificate.clone(),
tls_certificate,
config.general.log_client_connections,
)
.await
@@ -301,7 +293,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
if config.general.log_client_disconnections {
if get_config().general.log_client_disconnections {
info!(
"Client {:?} disconnected, session duration: {}",
addr,
@@ -318,7 +310,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Err(err) => {
match err {
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
_ => warn!("Client disconnected with error {:?}", err),
}

View File

@@ -1,17 +1,24 @@
/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut};
use log::error;
use log::{debug, error};
use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::client::PREPARED_STATEMENT_COUNTER;
use crate::config::get_config;
use crate::errors::Error;
use crate::constants::MESSAGE_TERMINATOR;
use std::collections::HashMap;
use std::ffi::CString;
use std::fmt::{Display, Formatter};
use std::io::{BufRead, Cursor};
use std::mem;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::time::Duration;
/// Postgres data type mappings
@@ -20,6 +27,10 @@ pub enum DataType {
Text,
Int4,
Numeric,
Bool,
Oid,
AnyArray,
Any,
}
impl From<&DataType> for i32 {
@@ -28,6 +39,10 @@ impl From<&DataType> for i32 {
DataType::Text => 25,
DataType::Int4 => 23,
DataType::Numeric => 1700,
DataType::Bool => 16,
DataType::Oid => 26,
DataType::AnyArray => 2277,
DataType::Any => 2276,
}
}
}
@@ -116,7 +131,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -126,6 +144,10 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
bytes.put_slice(user.as_bytes());
bytes.put_u8(0);
// Application name
bytes.put(&b"application_name\0"[..]);
bytes.put_slice(&b"pgcat\0"[..]);
// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(database.as_bytes());
@@ -150,6 +172,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
@@ -213,7 +250,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5_hash_second_pass(&(format!("{:x}", output)), salt)
}
pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// Second pass
md5.update(hash);
md5.update(salt);
let mut password = format!("md5{:x}", md5.finalize())
@@ -247,6 +290,20 @@ where
write_all(stream, message).await
}
pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]);
write_all(stream, message).await
}
/// Implements a response to our custom `SET SHARDING KEY`
/// and `SET SERVER ROLE` commands.
/// This tells the client we're ready for the next query.
@@ -384,7 +441,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
let mut res = BytesMut::new();
let mut row_desc = BytesMut::new();
// how many colums we are storing
// how many columns we are storing
row_desc.put_i16(columns.len() as i16);
for (name, data_type) in columns {
@@ -405,6 +462,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
DataType::Text => -1,
DataType::Int4 => 4,
DataType::Numeric => -1,
DataType::Bool => 1,
DataType::Oid => 4,
DataType::AnyArray => -1,
DataType::Any => -1,
};
row_desc.put_i16(type_size);
@@ -443,6 +504,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
res
}
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
if let Some(column) = column {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1 as i32);
}
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
/// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
@@ -453,6 +537,33 @@ pub fn command_complete(command: &str) -> BytesMut {
res
}
/// Create a notify message.
pub fn notify(message: &str, details: String) -> BytesMut {
let mut notify_cmd = BytesMut::new();
notify_cmd.put_slice("SNOTICE\0".as_bytes());
notify_cmd.put_slice("C00000\0".as_bytes());
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
// this extra byte says that is the end of the package
notify_cmd.put_u8(0);
let mut res = BytesMut::new();
res.put_u8(b'N');
res.put_i32(notify_cmd.len() as i32 + 4);
res.put(notify_cmd);
res
}
pub fn flush() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'H');
bytes.put_i32(4);
bytes
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -485,6 +596,29 @@ where
}
}
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where
@@ -562,6 +696,13 @@ pub fn configure_socket(stream: &TcpStream) {
let sock_ref = SockRef::from(stream);
let conf = get_config();
#[cfg(target_os = "linux")]
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
{
Ok(_) => (),
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
}
match sock_ref.set_keepalive(true) {
Ok(_) => {
match sock_ref.set_tcp_keepalive(
@@ -571,7 +712,7 @@ pub fn configure_socket(stream: &TcpStream) {
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
) {
Ok(_) => (),
Err(err) => error!("Could not configure socket: {}", err),
Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
}
}
Err(err) => error!("Could not configure socket: {}", err),
@@ -593,3 +734,684 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}
}
impl BytesMutReader for BytesMut {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let null_index = self.iter().position(|&byte| byte == b'\0');
match null_index {
Some(index) => {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Parse {
code: char,
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
}
impl TryFrom<&BytesMut> for Parse {
type Error = Error;
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let name = cursor.read_string()?;
let query = cursor.read_string()?;
let num_params = cursor.get_i16();
let mut param_types = Vec::new();
for _ in 0..num_params {
param_types.push(cursor.get_i32());
}
Ok(Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
})
}
}
impl TryFrom<Parse> for BytesMut {
type Error = Error;
fn try_from(parse: Parse) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let name_binding = CString::new(parse.name)?;
let name = name_binding.as_bytes_with_nul();
let query_binding = CString::new(parse.query)?;
let query = query_binding.as_bytes_with_nul();
// Recompute length of the message.
let len = 4 // self
+ name.len()
+ query.len()
+ 2
+ 4 * parse.num_params as usize;
bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(name);
bytes.put_slice(query);
bytes.put_i16(parse.num_params);
for param in parse.param_types {
bytes.put_i32(param);
}
Ok(bytes)
}
}
impl TryFrom<&Parse> for BytesMut {
type Error = Error;
fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
parse.clone().try_into()
}
}
impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
self
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}
/// Bind (B) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Bind {
code: char,
#[allow(dead_code)]
len: i64,
portal: String,
pub prepared_statement: String,
num_param_format_codes: i16,
param_format_codes: Vec<i16>,
num_param_values: i16,
param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16,
result_columns_format_codes: Vec<i16>,
}
impl TryFrom<&BytesMut> for Bind {
type Error = Error;
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16();
let mut param_format_codes = Vec::new();
for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16());
}
let num_param_values = cursor.get_i16();
let mut param_values = Vec::new();
for _ in 0..num_param_values {
let param_len = cursor.get_i32();
// There is special occasion when the parameter is NULL
// In that case, param length is defined as -1
// So if the passed parameter len is over 0
if param_len > 0 {
let mut param = BytesMut::with_capacity(param_len as usize);
param.resize(param_len as usize, b'0');
cursor.copy_to_slice(&mut param);
// we push and the length and the parameter into vector
param_values.push((param_len, param));
} else {
// otherwise we push a tuple with -1 and 0-len BytesMut
// which means that after encountering -1 postgres proceeds
// to processing another parameter
param_values.push((param_len, BytesMut::new()));
}
}
let num_result_column_format_codes = cursor.get_i16();
let mut result_columns_format_codes = Vec::new();
for _ in 0..num_result_column_format_codes {
result_columns_format_codes.push(cursor.get_i16());
}
Ok(Bind {
code,
len: len as i64,
portal,
prepared_statement,
num_param_format_codes,
param_format_codes,
num_param_values,
param_values,
num_result_column_format_codes,
result_columns_format_codes,
})
}
}
impl TryFrom<Bind> for BytesMut {
type Error = Error;
fn try_from(bind: Bind) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let portal_binding = CString::new(bind.portal)?;
let portal = portal_binding.as_bytes_with_nul();
let prepared_statement_binding = CString::new(bind.prepared_statement)?;
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();
let mut len = 4 // self
+ portal.len()
+ prepared_statement.len()
+ 2 // num_param_format_codes
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
+ 2; // num_param_values
for (param_len, _) in &bind.param_values {
len += 4 + *param_len as usize;
}
len += 2; // num_result_column_format_codes
len += 2 * bind.num_result_column_format_codes as usize;
bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(portal);
bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code);
}
bytes.put_i16(bind.num_param_values);
for (param_len, param) in bind.param_values {
bytes.put_i32(param_len);
bytes.put_slice(&param);
}
bytes.put_i16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code);
}
Ok(bytes)
}
}
impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
}
pub fn anonymous(&self) -> bool {
self.prepared_statement.is_empty()
}
}
#[derive(Debug, Clone)]
pub struct Describe {
code: char,
#[allow(dead_code)]
len: i32,
target: char,
pub statement_name: String,
}
impl TryFrom<&BytesMut> for Describe {
type Error = Error;
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let target = cursor.get_u8() as char;
let statement_name = cursor.read_string()?;
Ok(Describe {
code,
len,
target,
statement_name,
})
}
}
impl TryFrom<Describe> for BytesMut {
type Error = Error;
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let statement_name_binding = CString::new(describe.statement_name)?;
let statement_name = statement_name_binding.as_bytes_with_nul();
let len = 4 + 1 + statement_name.len();
bytes.put_u8(describe.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(describe.target as u8);
bytes.put_slice(statement_name);
Ok(bytes)
}
}
impl Describe {
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
}
pub fn anonymous(&self) -> bool {
self.statement_name.is_empty()
}
}
/// Close (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Close {
code: char,
#[allow(dead_code)]
len: i32,
close_type: char,
pub name: String,
}
impl TryFrom<&BytesMut> for Close {
type Error = Error;
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let close_type = cursor.get_u8() as char;
let name = cursor.read_string()?;
Ok(Close {
code,
len,
close_type,
name,
})
}
}
impl TryFrom<Close> for BytesMut {
type Error = Error;
fn try_from(close: Close) -> Result<BytesMut, Error> {
debug!("Close: {:?}", close);
let mut bytes = BytesMut::new();
let name_binding = CString::new(close.name)?;
let name = name_binding.as_bytes_with_nul();
let len = 4 + 1 + name.len();
bytes.put_u8(close.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(close.close_type as u8);
bytes.put_slice(name);
Ok(bytes)
}
}
impl Close {
pub fn new(name: &str) -> Close {
let name = name.to_string();
Close {
code: 'C',
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
close_type: 'S',
name,
}
}
pub fn is_prepared_statement(&self) -> bool {
self.close_type == 'S'
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}
pub fn close_complete() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'3');
bytes.put_i32(4);
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
pub severity_localized: String, // S
pub severity: String, // V
pub code: String, // C
pub message: String, // M
pub detail: Option<String>, // D
pub hint: Option<String>, // H
pub position: Option<u32>, // P
pub internal_position: Option<u32>, // p
pub internal_query: Option<String>, // q
pub where_context: Option<String>, // W
pub schema_name: Option<String>, // s
pub table_name: Option<String>, // t
pub column_name: Option<String>, // c
pub data_type_name: Option<String>, // d
pub constraint_name: Option<String>, // n
pub file_name: Option<String>, // F
pub line: Option<u32>, // L
pub routine: Option<String>, // R
}
// TODO: implement with https://docs.rs/derive_more/latest/derive_more/
impl Display for PgErrorMsg {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[severity: {}]", self.severity)?;
write!(f, "[code: {}]", self.code)?;
write!(f, "[message: {}]", self.message)?;
if let Some(val) = &self.detail {
write!(f, "[detail: {val}]")?;
}
if let Some(val) = &self.hint {
write!(f, "[hint: {val}]")?;
}
if let Some(val) = &self.position {
write!(f, "[position: {val}]")?;
}
if let Some(val) = &self.internal_position {
write!(f, "[internal_position: {val}]")?;
}
if let Some(val) = &self.internal_query {
write!(f, "[internal_query: {val}]")?;
}
if let Some(val) = &self.internal_query {
write!(f, "[internal_query: {val}]")?;
}
if let Some(val) = &self.where_context {
write!(f, "[where: {val}]")?;
}
if let Some(val) = &self.schema_name {
write!(f, "[schema_name: {val}]")?;
}
if let Some(val) = &self.table_name {
write!(f, "[table_name: {val}]")?;
}
if let Some(val) = &self.column_name {
write!(f, "[column_name: {val}]")?;
}
if let Some(val) = &self.data_type_name {
write!(f, "[data_type_name: {val}]")?;
}
if let Some(val) = &self.constraint_name {
write!(f, "[constraint_name: {val}]")?;
}
if let Some(val) = &self.file_name {
write!(f, "[file_name: {val}]")?;
}
if let Some(val) = &self.line {
write!(f, "[line: {val}]")?;
}
if let Some(val) = &self.routine {
write!(f, "[routine: {val}]")?;
}
write!(f, " ")?;
Ok(())
}
}
impl PgErrorMsg {
pub fn parse(error_msg: Vec<u8>) -> Result<PgErrorMsg, Error> {
let mut out = PgErrorMsg {
severity_localized: "".to_string(),
severity: "".to_string(),
code: "".to_string(),
message: "".to_string(),
detail: None,
hint: None,
position: None,
internal_position: None,
internal_query: None,
where_context: None,
schema_name: None,
table_name: None,
column_name: None,
data_type_name: None,
constraint_name: None,
file_name: None,
line: None,
routine: None,
};
for msg_part in error_msg.split(|v| *v == MESSAGE_TERMINATOR) {
if msg_part.is_empty() {
continue;
}
let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() {
Ok(c) => c,
Err(err) => {
return Err(Error::ServerMessageParserError(format!(
"could not parse server message field. err {:?}",
err
)))
}
};
match &msg_part[0] {
b'S' => {
out.severity_localized = msg_content;
}
b'V' => {
out.severity = msg_content;
}
b'C' => {
out.code = msg_content;
}
b'M' => {
out.message = msg_content;
}
b'D' => {
out.detail = Some(msg_content);
}
b'H' => {
out.hint = Some(msg_content);
}
b'P' => out.position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
b'p' => {
out.internal_position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0))
}
b'q' => {
out.internal_query = Some(msg_content);
}
b'W' => {
out.where_context = Some(msg_content);
}
b's' => {
out.schema_name = Some(msg_content);
}
b't' => {
out.table_name = Some(msg_content);
}
b'c' => {
out.column_name = Some(msg_content);
}
b'd' => {
out.data_type_name = Some(msg_content);
}
b'n' => {
out.constraint_name = Some(msg_content);
}
b'F' => {
out.file_name = Some(msg_content);
}
b'L' => out.line = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
b'R' => {
out.routine = Some(msg_content);
}
_ => {}
}
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
use crate::messages::PgErrorMsg;
use log::{error, info};
fn field(kind: char, content: &str) -> Vec<u8> {
format!("{kind}{content}\0").as_bytes().to_vec()
}
#[test]
fn parse_fields() {
let mut complete_msg = vec![];
let severity = "FATAL";
complete_msg.extend(field('S', &severity));
complete_msg.extend(field('V', &severity));
let error_code = "29P02";
complete_msg.extend(field('C', &error_code));
let message = "password authentication failed for user \"wrong_user\"";
complete_msg.extend(field('M', &message));
let detail_msg = "super detailed message";
complete_msg.extend(field('D', &detail_msg));
let hint_msg = "hint detail here";
complete_msg.extend(field('H', &hint_msg));
complete_msg.extend(field('P', "123"));
complete_msg.extend(field('p', "234"));
let internal_query = "SELECT * from foo;";
complete_msg.extend(field('q', &internal_query));
let where_msg = "where goes here";
complete_msg.extend(field('W', &where_msg));
let schema_msg = "schema_name";
complete_msg.extend(field('s', &schema_msg));
let table_msg = "table_name";
complete_msg.extend(field('t', &table_msg));
let column_msg = "column_name";
complete_msg.extend(field('c', &column_msg));
let data_type_msg = "type_name";
complete_msg.extend(field('d', &data_type_msg));
let constraint_msg = "constraint_name";
complete_msg.extend(field('n', &constraint_msg));
let file_msg = "pgcat.c";
complete_msg.extend(field('F', &file_msg));
complete_msg.extend(field('L', "335"));
let routine_msg = "my_failing_routine";
complete_msg.extend(field('R', &routine_msg));
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_ansi(true)
.init();
info!(
"full message: {}",
PgErrorMsg::parse(complete_msg.clone()).unwrap()
);
assert_eq!(
PgErrorMsg {
severity_localized: severity.to_string(),
severity: severity.to_string(),
code: error_code.to_string(),
message: message.to_string(),
detail: Some(detail_msg.to_string()),
hint: Some(hint_msg.to_string()),
position: Some(123),
internal_position: Some(234),
internal_query: Some(internal_query.to_string()),
where_context: Some(where_msg.to_string()),
schema_name: Some(schema_msg.to_string()),
table_name: Some(table_msg.to_string()),
column_name: Some(column_msg.to_string()),
data_type_name: Some(data_type_msg.to_string()),
constraint_name: Some(constraint_msg.to_string()),
file_name: Some(file_msg.to_string()),
line: Some(335),
routine: Some(routine_msg.to_string()),
},
PgErrorMsg::parse(complete_msg).unwrap()
);
let mut only_mandatory_msg = vec![];
only_mandatory_msg.extend(field('S', &severity));
only_mandatory_msg.extend(field('V', &severity));
only_mandatory_msg.extend(field('C', &error_code));
only_mandatory_msg.extend(field('M', &message));
only_mandatory_msg.extend(field('D', &detail_msg));
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
info!("only mandatory fields: {}", &err_fields);
error!(
"server error: {}: {}",
err_fields.severity, err_fields.message
);
assert_eq!(
PgErrorMsg {
severity_localized: severity.to_string(),
severity: severity.to_string(),
code: error_code.to_string(),
message: message.to_string(),
detail: Some(detail_msg.to_string()),
hint: None,
position: None,
internal_position: None,
internal_query: None,
where_context: None,
schema_name: None,
table_name: None,
column_name: None,
data_type_name: None,
constraint_name: None,
file_name: None,
line: None,
routine: None,
},
PgErrorMsg::parse(only_mandatory_msg).unwrap()
);
}
}

View File

@@ -1,11 +1,13 @@
use std::sync::Arc;
/// A mirrored PostgreSQL client.
/// Packets arrive to us through a channel from the main client and we send them to the server.
use bb8::Pool;
use bytes::{Bytes, BytesMut};
use parking_lot::RwLock;
use crate::config::{get_config, Address, Role, User};
use crate::pool::{ClientServerMap, ServerPool};
use crate::stats::get_reporter;
use log::{error, info, trace, warn};
use tokio::sync::mpsc::{channel, Receiver, Sender};
@@ -21,20 +23,25 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout) = match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
),
None => (default, default),
};
let (connection_timeout, idle_timeout, _cfg) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
cfg.idle_timeout.unwrap_or(default),
cfg.clone(),
),
None => (default, default, crate::config::Pool::default()),
};
let manager = ServerPool::new(
self.address.clone(),
self.user.clone(),
self.database.as_str(),
ClientServerMap::default(),
get_reporter(),
Arc::new(RwLock::new(None)),
None,
true,
false,
);
Pool::builder()
@@ -72,7 +79,7 @@ impl MirroredClient {
}
// Incoming data from server (we read to clear the socket buffer and discard the data)
recv_result = server.recv() => {
recv_result = server.recv(None) => {
match recv_result {
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
Err(err) => {

View File

@@ -1,80 +0,0 @@
use log::{Level, Log, Metadata, Record, SetLoggerError};
// This is a special kind of logger that allows sending logs to different
// targets depending on the log level.
//
// By default, if nothing is set, it acts as a regular env_log logger,
// it sends everything to standard error.
//
// If the Env variable `STDOUT_LOG` is defined, it will be used for
// configuring the standard out logger.
//
// The behavior is:
// - If it is an error, the message is written to standard error.
// - If it is not, and it matches the log level of the standard output logger (`STDOUT_LOG` env var), it will be send to standard output.
// - If the above is not true, it is sent to the stderr logger that will log it or not depending on the value
// of the RUST_LOG env var.
//
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, erros will go to stderr, warns and infos to stdout and debugs to stderr.
//
pub struct MultiLogger {
stderr_logger: env_logger::Logger,
stdout_logger: env_logger::Logger,
}
impl MultiLogger {
fn new() -> Self {
let stderr_logger = env_logger::builder().format_timestamp_micros().build();
let stdout_logger = env_logger::Builder::from_env("STDOUT_LOG")
.format_timestamp_micros()
.target(env_logger::Target::Stdout)
.build();
Self {
stderr_logger,
stdout_logger,
}
}
pub fn init() -> Result<(), SetLoggerError> {
let logger = Self::new();
log::set_max_level(logger.stderr_logger.filter());
log::set_boxed_logger(Box::new(logger))
}
}
impl Log for MultiLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
self.stderr_logger.enabled(metadata) && self.stdout_logger.enabled(metadata)
}
fn log(&self, record: &Record) {
if record.level() == Level::Error {
self.stderr_logger.log(record);
} else {
if self.stdout_logger.matches(record) {
self.stdout_logger.log(record);
} else {
self.stderr_logger.log(record);
}
}
}
fn flush(&self) {
self.stderr_logger.flush();
self.stdout_logger.flush();
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_init() {
MultiLogger::init().unwrap();
}
}

120
src/plugins/intercept.rs Normal file
View File

@@ -0,0 +1,120 @@
//! The intercept plugin.
//!
//! It intercepts queries and returns fake results.
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use sqlparser::ast::Statement;
use log::debug;
use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
// TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)]
pub struct Rule {
query: String,
schema: Vec<Column>,
result: Vec<Vec<String>>,
}
#[derive(Serialize, Deserialize)]
pub struct Column {
name: String,
data_type: String,
}
/// The intercept plugin.
pub struct Intercept<'a> {
pub enabled: bool,
pub config: &'a InterceptConfig,
}
#[async_trait]
impl<'a> Plugin for Intercept<'a> {
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled || ast.is_empty() {
return Ok(PluginOutput::Allow);
}
let mut config = self.config.clone();
config.substitute(
&query_router.pool_settings().db,
&query_router.pool_settings().user.username,
);
let mut result = BytesMut::new();
for q in ast {
// Normalization
let q = q.to_string().to_ascii_lowercase();
for (_, target) in config.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);
let rd = target
.schema
.iter()
.map(|row| {
let name = &row[0];
let data_type = &row[1];
(
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
"bool" => DataType::Bool,
"int4" => DataType::Int4,
_ => DataType::Any,
},
)
})
.collect::<Vec<(&str, DataType)>>();
result.put(row_description(&rd));
target.result.iter().for_each(|row| {
let row = row
.iter()
.map(|s| {
let s = s.as_str().to_string();
if s == "" {
None
} else {
Some(s)
}
})
.collect::<Vec<Option<String>>>();
result.put(data_row_nullable(&row));
});
result.put(command_complete("SELECT"));
}
}
}
if !result.is_empty() {
result.put_u8(b'Z');
result.put_i32(5);
result.put_u8(b'I');
return Ok(PluginOutput::Intercept(result));
} else {
Ok(PluginOutput::Allow)
}
}
}

44
src/plugins/mod.rs Normal file
View File

@@ -0,0 +1,44 @@
//! The plugin ecosystem.
//!
//! Currently plugins only grant access or deny access to the database for a particual query.
//! Example use cases:
//! - block known bad queries
//! - block access to system catalogs
//! - block dangerous modifications like `DROP TABLE`
//! - etc
//!
pub mod intercept;
pub mod prewarmer;
pub mod query_logger;
pub mod table_access;
use crate::{errors::Error, query_router::QueryRouter};
use async_trait::async_trait;
use bytes::BytesMut;
use sqlparser::ast::Statement;
pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;
#[derive(Clone, Debug, PartialEq)]
pub enum PluginOutput {
Allow,
Deny(String),
Overwrite(Vec<Statement>),
Intercept(BytesMut),
}
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;
// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}

28
src/plugins/prewarmer.rs Normal file
View File

@@ -0,0 +1,28 @@
//! Prewarm new connections before giving them to the client.
use crate::{errors::Error, server::Server};
use log::info;
pub struct Prewarmer<'a> {
pub enabled: bool,
pub server: &'a mut Server,
pub queries: &'a Vec<String>,
}
impl<'a> Prewarmer<'a> {
pub async fn run(&mut self) -> Result<(), Error> {
if !self.enabled {
return Ok(());
}
for query in self.queries {
info!(
"{} Prewarning with query: `{}`",
self.server.address(),
query
);
self.server.query(&query).await?;
}
Ok(())
}
}

View File

@@ -0,0 +1,38 @@
//! Log all queries to stdout (or somewhere else, why not).
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use async_trait::async_trait;
use log::info;
use sqlparser::ast::Statement;
pub struct QueryLogger<'a> {
pub enabled: bool,
pub user: &'a str,
pub db: &'a str,
}
#[async_trait]
impl<'a> Plugin for QueryLogger<'a> {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("[pool: {}][user: {}] {}", self.db, self.user, query);
Ok(PluginOutput::Allow)
}
}

View File

@@ -0,0 +1,59 @@
//! This query router plugin will check if the user can access a particular
//! table as part of their query. If they can't, the query will not be routed.
use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement};
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use log::debug;
use core::ops::ControlFlow;
pub struct TableAccess<'a> {
pub enabled: bool,
pub tables: &'a Vec<String>,
}
#[async_trait]
impl<'a> Plugin for TableAccess<'a> {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let mut found = None;
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else {
ControlFlow::<()>::Continue(())
}
});
if let Some(found) = found {
debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied",
found
)))
} else {
Ok(PluginOutput::Allow)
}
}
}

View File

@@ -1,7 +1,6 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::{BufMut, BytesMut};
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
@@ -10,6 +9,8 @@ use rand::seq::SliceRandom;
use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::sync::atomic::AtomicU64;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
@@ -17,12 +18,16 @@ use std::sync::{
use std::time::Instant;
use tokio::sync::Notify;
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::config::{
get_config, Address, DefaultShard, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
};
use crate::errors::Error;
use crate::server::Server;
use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction;
use crate::stats::{get_reporter, Reporter};
use crate::stats::{AddressStats, ClientStats, ServerStats};
pub type ProcessId = i32;
pub type SecretKey = i32;
@@ -51,7 +56,7 @@ pub enum BanReason {
/// An identifier for a PgCat pool,
/// a database visible to clients.
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
pub struct PoolIdentifier {
// The name of the database clients want to connect to.
pub db: String,
@@ -60,6 +65,8 @@ pub struct PoolIdentifier {
pub user: String,
}
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
impl PoolIdentifier {
/// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier {
@@ -70,6 +77,12 @@ impl PoolIdentifier {
}
}
impl Display for PoolIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}@{}", self.user, self.db)
}
}
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
@@ -90,6 +103,7 @@ pub struct PoolSettings {
// Connecting user.
pub user: User,
pub db: String,
// Default server role to connect to.
pub default_role: Option<Role>,
@@ -97,6 +111,12 @@ pub struct PoolSettings {
// Enable/disable query parser.
pub query_parser_enabled: bool,
// Max length of query the parser will parse.
pub query_parser_max_length: Option<usize>,
// Infer role
pub query_parser_read_write_splitting: bool,
// Read from the primary as well or not.
pub primary_reads_enabled: bool,
@@ -121,8 +141,19 @@ pub struct PoolSettings {
// Regex for searching for the shard id in SQL statements
pub shard_id_regex: Option<Regex>,
// What to do when no shard is selected in a sharded system
pub default_shard: DefaultShard,
// Limit how much of each query is searched for a potential shard regex match
pub regex_search_limit: usize,
// Auth query parameters
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
/// Plugins
pub plugins: Option<Plugins>,
}
impl Default for PoolSettings {
@@ -132,8 +163,11 @@ impl Default for PoolSettings {
load_balancing_mode: LoadBalancingMode::Random,
shards: 1,
user: User::default(),
db: String::default(),
default_role: None,
query_parser_enabled: false,
query_parser_max_length: None,
query_parser_read_write_splitting: false,
primary_reads_enabled: true,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
@@ -143,6 +177,11 @@ impl Default for PoolSettings {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: 1000,
default_shard: DefaultShard::Shard(0),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
plugins: None,
}
}
}
@@ -161,14 +200,10 @@ pub struct ConnectionPool {
/// that should not be queried.
banlist: BanList,
/// The statistics aggregator runs in a separate task
/// and receives stats from clients, servers, and the pool.
stats: Reporter,
/// The server information (K messages) have to be passed to the
/// The server information has to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: Arc<RwLock<BytesMut>>,
/// on pool creation and save the startup parameters here.
original_server_parameters: Arc<RwLock<ServerParameters>>,
/// Pool configuration.
pub settings: PoolSettings,
@@ -185,6 +220,9 @@ pub struct ConnectionPool {
/// If the pool has been paused or not.
paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>,
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
}
impl ConnectionPool {
@@ -201,6 +239,7 @@ impl ConnectionPool {
// There is one pool per database/user pair.
for user in pool_config.users.values() {
let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username);
match old_pool_ref {
Some(pool) => {
@@ -211,10 +250,7 @@ impl ConnectionPool {
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(
PoolIdentifier::new(pool_name, &user.username),
pool.clone(),
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
@@ -237,6 +273,7 @@ impl ConnectionPool {
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
@@ -266,6 +303,8 @@ impl ConnectionPool {
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
});
address_id += 1;
}
@@ -283,6 +322,8 @@ impl ConnectionPool {
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
};
address_id += 1;
@@ -291,12 +332,53 @@ impl ConnectionPool {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
{
if ok != *pool_auth_hash_value {
warn!(
"Hash is not the same across shards \
of the same pool, client auth will \
be done using last obtained hash. \
Server: {}:{}, Database: {}",
server.host, server.port, shard.database,
);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
}
Err(err) => warn!(
"Could not obtain password hashes \
using auth_query config, ignoring. \
Error: {:?}",
err,
),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
pool_auth_hash.clone(),
match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
pool_config.cleanup_server_connections,
pool_config.log_client_parameter_status_changes,
);
let connect_timeout = match pool_config.connect_timeout {
@@ -309,14 +391,44 @@ impl ConnectionPool {
None => config.general.idle_timeout,
};
let server_lifetime = match user.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => match pool_config.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => config.general.server_lifetime,
},
};
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
let queue_strategy = match config.general.server_round_robin {
true => QueueStrategy::Fifo,
false => QueueStrategy::Lifo,
};
debug!(
"[pool: {}][user: {}] Pool reaper rate: {}ms",
pool_name, user.username, reaper_rate
);
let pool = Pool::builder()
.max_size(user.pool_size)
.min_idle(user.min_pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
.queue_strategy(queue_strategy)
.test_on_check_out(false);
let pool = if config.general.validate_config {
pool.build(manager).await?
} else {
pool.build_unchecked(manager)
};
pools.push(pool);
servers.push(address);
@@ -328,20 +440,30 @@ impl ConnectionPool {
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode,
None => pool_config.pool_mode,
},
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
db: pool_name.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
@@ -349,6 +471,9 @@ impl ConnectionPool {
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
query_parser_max_length: pool_config.query_parser_max_length,
query_parser_read_write_splitting: pool_config
.query_parser_read_write_splitting,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
@@ -364,6 +489,14 @@ impl ConnectionPool {
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
default_shard: pool_config.default_shard.clone(),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
plugins: match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
@@ -373,10 +506,12 @@ impl ConnectionPool {
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
if config.general.validate_config {
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
}
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
@@ -387,7 +522,8 @@ impl ConnectionPool {
Ok(())
}
/// Connect to all shards and grab server information.
/// Connect to all shards, grab server information, and possibly
/// passwords to use in client auth.
/// Return server information we will pass to the clients
/// when they connect.
/// This also warms up the pool for clients that connect when
@@ -400,7 +536,7 @@ impl ConnectionPool {
for server in 0..self.servers(shard) {
let databases = self.databases.clone();
let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info);
let pool_server_parameters = Arc::clone(&self.original_server_parameters);
let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await {
@@ -413,11 +549,10 @@ impl ConnectionPool {
let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();
let server_parameters: ServerParameters = server.server_parameters();
let mut guard = pool_server_info.write();
guard.clear();
guard.put(server_info.clone());
let mut guard = pool_server_parameters.write();
*guard = server_parameters;
validated.store(true, Ordering::Relaxed);
});
@@ -429,7 +564,7 @@ impl ConnectionPool {
// TODO: compare server information to make sure
// all shards are running identical configurations.
if self.server_info.read().is_empty() {
if !self.validated() {
error!("Could not validate connection pool");
return Err(Error::AllServersDown);
}
@@ -476,19 +611,51 @@ impl ConnectionPool {
/// Get a connection from the pool.
pub async fn get(
&self,
shard: usize, // shard number
role: Option<Role>, // primary or replica
client_process_id: i32, // client id
shard: Option<usize>, // shard number
role: Option<Role>, // primary or replica
client_stats: &ClientStats, // client id
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let mut candidates: Vec<&Address> = self.addresses[shard]
.iter()
.filter(|address| address.role == role)
.collect();
let effective_shard_id = if self.shards() == 1 {
// The base, unsharded case
Some(0)
} else {
if !self.valid_shard_id(shard) {
// None is valid shard ID so it is safe to unwrap here
return Err(Error::InvalidShardId(shard.unwrap()));
}
shard
};
// We shuffle even if least_outstanding_queries is used to avoid imbalance
// in cases where all candidates have more or less the same number of outstanding
// queries
let mut candidates = self
.addresses
.iter()
.flatten()
.filter(|address| address.role == role)
.collect::<Vec<&Address>>();
// We start with a shuffled list of addresses even if we end up resorting
// this is meant to avoid hitting instance 0 everytime if the sorting metric
// ends up being the same for all instances
candidates.shuffle(&mut thread_rng());
match effective_shard_id {
Some(shard_id) => candidates.retain(|address| address.shard == shard_id),
None => match self.settings.default_shard {
DefaultShard::Shard(shard_id) => {
candidates.retain(|address| address.shard == shard_id)
}
DefaultShard::Random => (),
DefaultShard::RandomHealthy => {
candidates.sort_by(|a, b| {
b.error_count
.load(Ordering::Relaxed)
.partial_cmp(&a.error_count.load(Ordering::Relaxed))
.unwrap()
});
}
},
};
if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections {
candidates.sort_by(|a, b| {
self.busy_connection_count(b)
@@ -497,6 +664,10 @@ impl ConnectionPool {
});
}
// Indicate we're waiting on a server connection from a pool.
let now = Instant::now();
client_stats.waiting();
while !candidates.is_empty() {
// Get the next candidate
let address = match candidates.pop() {
@@ -515,21 +686,24 @@ impl ConnectionPool {
}
}
// Indicate we're waiting on a server connection from a pool.
let now = Instant::now();
self.stats.client_waiting(client_process_id);
// Check if we can connect
let mut conn = match self.databases[address.shard][address.address_index]
.get()
.await
{
Ok(conn) => conn,
Ok(conn) => {
address.reset_error_count();
conn
}
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, BanReason::FailedCheckout, client_process_id);
self.stats
.client_checkout_error(client_process_id, address.id);
error!(
"Connection checkout error for instance {:?}, error: {:?}",
address, err
);
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
address.stats.error();
client_stats.idle();
client_stats.checkout_error();
continue;
}
};
@@ -546,26 +720,38 @@ impl ConnectionPool {
// since we last checked the server is ok.
// Health checks are pretty expensive.
if !require_healthcheck {
self.stats.checkout_time(
now.elapsed().as_micros(),
client_process_id,
server.server_id(),
);
self.stats
.server_active(client_process_id, server.server_id());
let checkout_time = now.elapsed().as_micros() as u64;
client_stats.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_stats.application_name());
server.stats().active(client_stats.application_name());
client_stats.active();
return Ok((conn, address.clone()));
}
if self
.run_health_check(address, server, now, client_process_id)
.run_health_check(address, server, now, client_stats)
.await
{
let checkout_time = now.elapsed().as_micros() as u64;
client_stats.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_stats.application_name());
server.stats().active(client_stats.application_name());
client_stats.active();
return Ok((conn, address.clone()));
} else {
continue;
}
}
client_stats.idle();
let checkout_time = now.elapsed().as_micros() as u64;
client_stats.checkout_time(checkout_time);
Err(Error::AllServersDown)
}
@@ -574,11 +760,11 @@ impl ConnectionPool {
address: &Address,
server: &mut Server,
start: Instant,
client_process_id: i32,
client_info: &ClientStats,
) -> bool {
debug!("Running health check on server {:?}", address);
self.stats.server_tested(server.server_id());
server.stats().tested();
match tokio::time::timeout(
tokio::time::Duration::from_millis(self.settings.healthcheck_timeout),
@@ -589,20 +775,20 @@ impl ConnectionPool {
// Check if health check succeeded.
Ok(res) => match res {
Ok(_) => {
self.stats.checkout_time(
start.elapsed().as_micros(),
client_process_id,
server.server_id(),
);
self.stats
.server_active(client_process_id, server.server_id());
let checkout_time: u64 = start.elapsed().as_micros() as u64;
client_info.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_info.application_name());
server.stats().active(client_info.application_name());
return true;
}
// Health check failed.
Err(err) => {
error!(
"Banning instance {:?} because of failed health check, {:?}",
"Failed health check on instance {:?}, error: {:?}",
address, err
);
}
@@ -611,7 +797,7 @@ impl ConnectionPool {
// Health check timed out.
Err(err) => {
error!(
"Banning instance {:?} because of health check timeout, {:?}",
"Health check timeout on instance {:?}, error: {:?}",
address, err
);
}
@@ -620,23 +806,41 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
return false;
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
// Count the number of errors since the last successful checkout
// This is used to determine if the shard is down
match reason {
BanReason::FailedHealthCheck
| BanReason::FailedCheckout
| BanReason::MessageSendFailed
| BanReason::MessageReceiveFailed => {
address.increment_error_count();
}
_ => (),
};
// Primary can never be banned
if address.role == Role::Primary {
return;
}
error!("Banning instance {:?}, reason: {:?}", address, reason);
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
self.stats.client_ban_error(client_id, address.id);
if let Some(client_info) = client_info {
client_info.ban_error();
address.stats.error();
}
guard[address.shard].insert(address.clone(), (reason, now));
}
@@ -772,10 +976,11 @@ impl ConnectionPool {
&self.addresses[shard][server]
}
pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
pub fn server_parameters(&self) -> ServerParameters {
self.original_server_parameters.read().clone()
}
/// Get the number of checked out connection for an address
fn busy_connection_count(&self, address: &Address) -> u32 {
let state = self.pool_state(address.shard, address.address_index);
let idle = state.idle_connections;
@@ -789,15 +994,40 @@ impl ConnectionPool {
debug!("{:?} has {:?} busy connections", address, busy);
return busy;
}
fn valid_shard_id(&self, shard: Option<usize>) -> bool {
match shard {
None => true,
Some(shard) => shard < self.shards(),
}
}
}
/// Wrapper for the bb8 connection pool.
pub struct ServerPool {
/// Server address.
address: Address,
/// Server Postgres user.
user: User,
/// Server database.
database: String,
/// Client/server mapping.
client_server_map: ClientServerMap,
stats: Reporter,
/// Server auth hash (for auth passthrough).
auth_hash: Arc<RwLock<Option<String>>>,
/// Server plugins.
plugins: Option<Plugins>,
/// Should we clean up dirty connections before putting them into the pool?
cleanup_connections: bool,
/// Log client parameter status changes
log_client_parameter_status_changes: bool,
}
impl ServerPool {
@@ -806,14 +1036,20 @@ impl ServerPool {
user: User,
database: &str,
client_server_map: ClientServerMap,
stats: Reporter,
auth_hash: Arc<RwLock<Option<String>>>,
plugins: Option<Plugins>,
cleanup_connections: bool,
log_client_parameter_status_changes: bool,
) -> ServerPool {
ServerPool {
address,
user,
user: user.clone(),
database: database.to_string(),
client_server_map,
stats,
auth_hash,
plugins,
cleanup_connections,
log_client_parameter_status_changes,
}
}
}
@@ -826,34 +1062,45 @@ impl ManageConnection for ServerPool {
/// Attempts to create a new connection.
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
info!("Creating a new server connection {:?}", self.address);
let server_id = rand::random::<i32>();
self.stats.server_register(
server_id,
self.address.id,
self.address.name(),
self.address.pool_name.clone(),
self.address.username.clone(),
);
self.stats.server_login(server_id);
let stats = Arc::new(ServerStats::new(
self.address.clone(),
tokio::time::Instant::now(),
));
stats.register(stats.clone());
// Connect to the PostgreSQL server.
match Server::startup(
server_id,
&self.address,
&self.user,
&self.database,
self.client_server_map.clone(),
self.stats.clone(),
stats.clone(),
self.auth_hash.clone(),
self.cleanup_connections,
self.log_client_parameter_status_changes,
)
.await
{
Ok(conn) => {
self.stats.server_idle(server_id);
Ok(mut conn) => {
if let Some(ref plugins) = self.plugins {
if let Some(ref prewarmer) = plugins.prewarmer {
let mut prewarmer = prewarmer::Prewarmer {
enabled: prewarmer.enabled,
server: &mut conn,
queries: &prewarmer.queries,
};
prewarmer.run().await?;
}
}
stats.idle();
Ok(conn)
}
Err(err) => {
self.stats.server_disconnecting(server_id);
stats.disconnect();
Err(err)
}
}
@@ -881,11 +1128,3 @@ pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
(*(*POOLS.load())).clone()
}
/// How many total servers we have in the config.
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
.map(|(_, pool)| pool.databases())
.sum()
}

View File

@@ -1,14 +1,17 @@
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use log::{error, info, warn};
use log::{debug, error, info};
use phf::phf_map;
use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use crate::config::Address;
use crate::pool::get_all_pools;
use crate::stats::{get_address_stats, get_pool_stats, get_server_stats, ServerInformation};
use crate::pool::{get_all_pools, PoolIdentifier};
use crate::stats::pool::PoolStats;
use crate::stats::{get_server_stats, ServerStats};
struct MetricHelpType {
help: &'static str,
@@ -220,7 +223,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("servers_{}", name), value, labels)
}
fn from_address(address: &Address, name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
fn from_address(address: &Address, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new();
labels.insert("host", address.host.clone());
labels.insert("shard", address.shard.to_string());
@@ -231,10 +234,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("stats_{}", name), value, labels)
}
fn from_pool(pool: &(String, String), name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
fn from_pool(pool_id: PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new();
labels.insert("pool", pool.0.clone());
labels.insert("user", pool.1.clone());
labels.insert("pool", pool_id.db);
labels.insert("user", pool_id.user);
Self::from_name(&format!("pools_{}", name), value, labels)
}
@@ -261,20 +264,18 @@ async fn prometheus_stats(request: Request<Body>) -> Result<Response<Body>, hype
// Adds metrics shown in a SHOW STATS admin command.
fn push_address_stats(lines: &mut Vec<String>) {
let address_stats: HashMap<usize, HashMap<String, i64>> = get_address_stats();
for (_, pool) in get_all_pools() {
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
if let Some(address_stats) = address_stats.get(&address.id) {
for (key, value) in address_stats.iter() {
if let Some(prometheus_metric) =
PrometheusMetric::<i64>::from_address(address, key, *value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
}
let stats = &*address.stats;
for (key, value) in stats.clone() {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_address(address, &key, value)
{
lines.push(prometheus_metric.to_string());
} else {
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}
@@ -284,17 +285,15 @@ fn push_address_stats(lines: &mut Vec<String>) {
// Adds relevant metrics shown in a SHOW POOLS admin command.
fn push_pool_stats(lines: &mut Vec<String>) {
let pool_stats = get_pool_stats();
for (pool, stats) in pool_stats.iter() {
for (name, value) in stats.iter() {
if let Some(prometheus_metric) = PrometheusMetric::<i64>::from_pool(pool, name, *value)
let pool_stats = PoolStats::construct_pool_lookup();
for (pool_id, stats) in pool_stats.iter() {
for (name, value) in stats.clone() {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!(
"Metric {} not implemented for ({},{})",
name, pool.0, pool.1
);
debug!("Metric {} not implemented for ({})", name, *pool_id);
}
}
}
@@ -319,7 +318,7 @@ fn push_database_stats(lines: &mut Vec<String>) {
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}
@@ -330,9 +329,9 @@ fn push_database_stats(lines: &mut Vec<String>) {
// Adds relevant metrics shown in a SHOW SERVERS admin command.
fn push_server_stats(lines: &mut Vec<String>) {
let server_stats = get_server_stats();
let mut server_stats_by_addresses = HashMap::<String, ServerInformation>::new();
for (_, info) in server_stats {
server_stats_by_addresses.insert(info.address_name.clone(), info);
let mut server_stats_by_addresses = HashMap::<String, Arc<ServerStats>>::new();
for (_, stats) in server_stats {
server_stats_by_addresses.insert(stats.address_name(), stats);
}
for (_, pool) in get_all_pools() {
@@ -341,11 +340,23 @@ fn push_server_stats(lines: &mut Vec<String>) {
let address = pool.address(shard, server);
if let Some(server_info) = server_stats_by_addresses.get(&address.name()) {
let metrics = [
("bytes_received", server_info.bytes_received),
("bytes_sent", server_info.bytes_sent),
("transaction_count", server_info.transaction_count),
("query_count", server_info.query_count),
("error_count", server_info.error_count),
(
"bytes_received",
server_info.bytes_received.load(Ordering::Relaxed),
),
("bytes_sent", server_info.bytes_sent.load(Ordering::Relaxed)),
(
"transaction_count",
server_info.transaction_count.load(Ordering::Relaxed),
),
(
"query_count",
server_info.query_count.load(Ordering::Relaxed),
),
(
"error_count",
server_info.error_count.load(Ordering::Relaxed),
),
];
for (key, value) in metrics {
if let Some(prometheus_metric) =
@@ -353,7 +364,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}

View File

@@ -1,4 +1,4 @@
/// Route queries automatically based on explicitely requested
/// Route queries automatically based on explicitly requested
/// or implied query characteristics.
use bytes::{Buf, BytesMut};
use log::{debug, error};
@@ -6,19 +6,22 @@ use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
use std::cmp;
use std::collections::BTreeSet;
use std::io::Cursor;
use std::{cmp, mem};
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
@@ -129,23 +132,33 @@ impl QueryRouter {
self.pool_settings = pool_settings;
}
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings
}
/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer);
let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize;
let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some()
|| self.pool_settings.sharding_key_regex.is_some();
// Check for any sharding regex matches in any queries
match code as char {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
if self.pool_settings.shard_id_regex.is_some()
|| self.pool_settings.sharding_key_regex.is_some()
{
if comment_shard_routing_enabled {
match code as char {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings
let len = message_cursor.get_i32() as usize;
let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
let initial_segment = String::from_utf8_lossy(&message_buffer[0..seg]);
let query_start_index = mem::size_of::<u8>() + mem::size_of::<i32>();
let initial_segment = String::from_utf8_lossy(
&message_buffer[query_start_index..query_start_index + seg],
);
// Check for a shard_id included in the query
if let Some(shard_id_regex) = &self.pool_settings.shard_id_regex {
@@ -154,7 +167,7 @@ impl QueryRouter {
});
if let Some(shard_id) = shard_id {
debug!("Setting shard to {:?}", shard_id);
self.set_shard(shard_id);
self.set_shard(Some(shard_id));
// Skip other command processing since a sharding command was found
return None;
}
@@ -176,8 +189,8 @@ impl QueryRouter {
}
}
}
_ => {}
}
_ => {}
}
// Only simple protocol supported for commands processed below
@@ -185,7 +198,6 @@ impl QueryRouter {
return None;
}
let _len = message_cursor.get_i32() as usize;
let query = message_cursor.read_string().unwrap();
let regex_set = match CUSTOM_SQL_REGEX_SET.get() {
@@ -237,7 +249,9 @@ impl QueryRouter {
}
}
Command::ShowShard => self.shard().to_string(),
Command::ShowShard => self
.shard()
.map_or_else(|| "unset".to_string(), |x| x.to_string()),
Command::ShowServerRole => match self.active_role {
Some(Role::Primary) => Role::Primary.to_string(),
Some(Role::Replica) => Role::Replica.to_string(),
@@ -324,14 +338,23 @@ impl QueryRouter {
Some((command, value))
}
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, message: &BytesMut) -> bool {
debug!("Inferring role");
pub fn parse(&self, message: &BytesMut) -> Result<Vec<Statement>, Error> {
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
let _len = message_cursor.get_i32() as usize;
let len = message_cursor.get_i32() as usize;
match self.pool_settings.query_parser_max_length {
Some(max_length) => {
if len > max_length {
return Err(Error::QueryRouterParserError(format!(
"Query too long for parser: {} > {}",
len, max_length
)));
}
}
None => (),
};
let query = match code {
// Query
@@ -344,37 +367,43 @@ impl QueryRouter {
// Parse (prepared statement)
'P' => {
// Reads statement name
message_cursor.read_string().unwrap();
let _name = message_cursor.read_string().unwrap();
// Reads query string
let query = message_cursor.read_string().unwrap();
debug!("Prepared statement: '{}'", query);
query
}
_ => return false,
_ => return Err(Error::UnsupportedStatement),
};
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => Ok(ast),
Err(err) => {
// SELECT ... FOR UPDATE won't get parsed correctly.
debug!("{}: {}", err, query);
self.active_role = Some(Role::Primary);
return false;
Err(Error::QueryRouterParserError(err.to_string()))
}
};
}
}
debug!("AST: {:?}", ast);
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
if !self.pool_settings.query_parser_read_write_splitting {
return Ok(()); // Nothing to do
}
debug!("Inferring role");
if ast.is_empty() {
// That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary);
return false;
return Err(Error::QueryRouterParserError("empty query".into()));
}
for q in &ast {
for q in ast {
match q {
// All transactions go to the primary, probably a write.
StartTransaction { .. } => {
@@ -418,7 +447,7 @@ impl QueryRouter {
};
}
true
Ok(())
}
/// Parse the shard number from the Bind message
@@ -427,6 +456,10 @@ impl QueryRouter {
/// N.B.: Only supports anonymous prepared statements since we don't
/// keep a cache of them in PgCat.
pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool {
if !self.pool_settings.query_parser_read_write_splitting {
return false; // Nothing to do
}
debug!("Parsing bind message");
let mut message_cursor = Cursor::new(message);
@@ -551,7 +584,7 @@ impl QueryRouter {
// TODO: Support multi-shard queries some day.
if shards.len() == 1 {
debug!("Found one sharding key");
self.set_shard(*shards.first().unwrap());
self.set_shard(Some(*shards.first().unwrap()));
true
} else {
debug!("Found no sharding keys");
@@ -783,13 +816,59 @@ impl QueryRouter {
}
}
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
let plugins = match self.pool_settings.plugins {
Some(ref plugins) => plugins,
None => return Ok(PluginOutput::Allow),
};
if let Some(ref query_logger) = plugins.query_logger {
let mut query_logger = QueryLogger {
enabled: query_logger.enabled,
user: &self.pool_settings.user.username,
db: &self.pool_settings.db,
};
let _ = query_logger.run(&self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: &intercept,
};
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
}
}
if let Some(ref table_access) = plugins.table_access {
let mut table_access = TableAccess {
enabled: table_access.enabled,
tables: &table_access.tables,
};
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
}
}
Ok(PluginOutput::Allow)
}
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
let shard = sharder.shard(sharding_key);
self.set_shard(shard);
self.set_shard(Some(shard));
self.active_shard
}
@@ -799,22 +878,33 @@ impl QueryRouter {
}
/// Get desired shard we should be talking to.
pub fn shard(&self) -> usize {
self.active_shard.unwrap_or(0)
pub fn shard(&self) -> Option<usize> {
self.active_shard
}
pub fn set_shard(&mut self, shard: usize) {
self.active_shard = Some(shard);
pub fn set_shard(&mut self, shard: Option<usize>) {
self.active_shard = shard;
}
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled {
None => self.pool_settings.query_parser_enabled,
Some(value) => value,
};
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
self.pool_settings.query_parser_enabled
);
self.pool_settings.query_parser_enabled
}
debug!("Query parser enabled: {}", enabled);
Some(value) => {
debug!(
"Using query parser override, query_parser_enabled: {}",
value
);
value
}
};
enabled
}
@@ -847,6 +937,7 @@ mod test {
fn test_infer_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert!(qr.query_parser_enabled());
@@ -862,7 +953,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
}
@@ -871,6 +962,7 @@ mod test {
fn test_infer_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
@@ -881,7 +973,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
}
}
@@ -893,7 +985,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&query));
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
}
@@ -901,6 +993,8 @@ mod test {
fn test_infer_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
@@ -913,7 +1007,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);
assert!(qr.infer(&res));
assert!(qr.infer(&qr.parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
@@ -999,7 +1093,7 @@ mod test {
qr.try_execute_command(&query),
Some((Command::SetShardingKey, String::from("0")))
);
assert_eq!(qr.shard(), 0);
assert_eq!(qr.shard().unwrap(), 0);
// SetShard
let query = simple_query("SET SHARD TO '1'");
@@ -1007,7 +1101,7 @@ mod test {
qr.try_execute_command(&query),
Some((Command::SetShard, String::from("1")))
);
assert_eq!(qr.shard(), 1);
assert_eq!(qr.shard().unwrap(), 1);
// ShowShard
let query = simple_query("SHOW SHARD");
@@ -1069,6 +1163,8 @@ mod test {
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
@@ -1077,11 +1173,11 @@ mod test {
assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&query));
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&query));
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled());
@@ -1101,6 +1197,8 @@ mod test {
user: crate::config::User::default(),
default_role: Some(Role::Replica),
query_parser_enabled: true,
query_parser_max_length: None,
query_parser_read_write_splitting: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: Some(String::from("test.id")),
@@ -1109,7 +1207,13 @@ mod test {
ban_time: PoolSettings::default().ban_time,
sharding_key_regex: None,
shard_id_regex: None,
default_shard: crate::config::DefaultShard::Shard(0),
regex_search_limit: 1000,
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -1139,15 +1243,24 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
assert!(qr
.infer(&qr.parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary);
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
assert!(qr
.infer(&qr.parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica);
assert!(qr.infer(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
)));
assert!(qr
.infer(
&qr.parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.role(), Role::Primary);
}
@@ -1162,6 +1275,8 @@ mod test {
user: crate::config::User::default(),
default_role: Some(Role::Replica),
query_parser_enabled: true,
query_parser_max_length: None,
query_parser_read_write_splitting: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
automatic_sharding_key: None,
@@ -1170,14 +1285,26 @@ mod test {
ban_time: PoolSettings::default().ban_time,
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
default_shard: crate::config::DefaultShard::Shard(0),
regex_search_limit: 1000,
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());
// Shard should start out unset
assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent
let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0) == None);
assert_eq!(qr.active_shard, None);
// Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None);
@@ -1201,49 +1328,90 @@ mod test {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
assert_eq!(qr.shard(), 2);
assert!(qr
.infer(
&qr.parse(&simple_query("SELECT * FROM data WHERE id = 5"))
.unwrap(),
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 2);
assert!(qr.infer(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
)));
assert_eq!(qr.shard(), 0);
assert!(qr
.infer(
&qr.parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 0);
assert!(qr.infer(&simple_query(
"SELECT * FROM data
assert!(qr
.infer(
&qr.parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5"
)));
assert_eq!(qr.shard(), 2);
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 2);
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
// in the query.
assert!(qr.infer(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
)));
assert_eq!(qr.shard(), 2);
assert!(qr
.infer(
&qr.parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 2);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
)));
assert_eq!(qr.shard(), 0);
assert!(qr
.infer(
&qr.parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 0);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
)));
assert_eq!(qr.shard(), 2);
assert!(qr
.infer(
&qr.parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 2);
// Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr.infer(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
)));
assert_eq!(qr.shard(), 0);
assert!(qr
.infer(
&qr.parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 0);
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
assert_eq!(qr.shard(), 0);
assert!(qr
.infer(
&qr.parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard().unwrap(), 0);
}
#[test]
@@ -1265,12 +1433,61 @@ mod test {
let mut qr = QueryRouter::new();
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.infer(&simple_query(stmt)));
assert!(qr.infer(&qr.parse(&simple_query(stmt)).unwrap()).is_ok());
assert_eq!(qr.placeholders.len(), 1);
assert!(qr.infer_shard_from_bind(&bind));
assert_eq!(qr.shard(), 2);
assert_eq!(qr.shard().unwrap(), 2);
assert!(qr.placeholders.is_empty());
}
#[tokio::test]
async fn test_table_access_plugin() {
use crate::config::{Plugins, TableAccess};
let table_access = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
let plugins = Plugins {
table_access: Some(table_access),
intercept: None,
query_logger: None,
prewarmer: None,
};
QueryRouter::setup();
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database");
let ast = qr.parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(
res,
Ok(PluginOutput::Deny(
"permission for table \"pg_database\" denied".to_string()
))
);
}
#[tokio::test]
async fn test_plugins_disabled_by_defaault() {
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = qr.parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(res, Ok(PluginOutput::Allow));
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

226
src/stats/address.rs Normal file
View File

@@ -0,0 +1,226 @@
use std::sync::atomic::*;
use std::sync::Arc;
#[derive(Debug, Clone, Default)]
struct AddressStatFields {
xact_count: Arc<AtomicU64>,
query_count: Arc<AtomicU64>,
bytes_received: Arc<AtomicU64>,
bytes_sent: Arc<AtomicU64>,
xact_time: Arc<AtomicU64>,
query_time: Arc<AtomicU64>,
wait_time: Arc<AtomicU64>,
errors: Arc<AtomicU64>,
}
/// Internal address stats
#[derive(Debug, Clone, Default)]
pub struct AddressStats {
total: AddressStatFields,
current: AddressStatFields,
averages: AddressStatFields,
// Determines if the averages have been updated since the last time they were reported
pub averages_updated: Arc<AtomicBool>,
}
impl IntoIterator for AddressStats {
type Item = (String, u64);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
vec![
(
"total_xact_count".to_string(),
self.total.xact_count.load(Ordering::Relaxed),
),
(
"total_query_count".to_string(),
self.total.query_count.load(Ordering::Relaxed),
),
(
"total_received".to_string(),
self.total.bytes_received.load(Ordering::Relaxed),
),
(
"total_sent".to_string(),
self.total.bytes_sent.load(Ordering::Relaxed),
),
(
"total_xact_time".to_string(),
self.total.xact_time.load(Ordering::Relaxed),
),
(
"total_query_time".to_string(),
self.total.query_time.load(Ordering::Relaxed),
),
(
"total_wait_time".to_string(),
self.total.wait_time.load(Ordering::Relaxed),
),
(
"total_errors".to_string(),
self.total.errors.load(Ordering::Relaxed),
),
(
"avg_xact_count".to_string(),
self.averages.xact_count.load(Ordering::Relaxed),
),
(
"avg_query_count".to_string(),
self.averages.query_count.load(Ordering::Relaxed),
),
(
"avg_recv".to_string(),
self.averages.bytes_received.load(Ordering::Relaxed),
),
(
"avg_sent".to_string(),
self.averages.bytes_sent.load(Ordering::Relaxed),
),
(
"avg_errors".to_string(),
self.averages.errors.load(Ordering::Relaxed),
),
(
"avg_xact_time".to_string(),
self.averages.xact_time.load(Ordering::Relaxed),
),
(
"avg_query_time".to_string(),
self.averages.query_time.load(Ordering::Relaxed),
),
(
"avg_wait_time".to_string(),
self.averages.wait_time.load(Ordering::Relaxed),
),
]
.into_iter()
}
}
impl AddressStats {
pub fn xact_count_add(&self) {
self.total.xact_count.fetch_add(1, Ordering::Relaxed);
self.current.xact_count.fetch_add(1, Ordering::Relaxed);
}
pub fn query_count_add(&self) {
self.total.query_count.fetch_add(1, Ordering::Relaxed);
self.current.query_count.fetch_add(1, Ordering::Relaxed);
}
pub fn bytes_received_add(&self, bytes: u64) {
self.total
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
self.current
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn bytes_sent_add(&self, bytes: u64) {
self.total.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
self.current.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
}
pub fn xact_time_add(&self, time: u64) {
self.total.xact_time.fetch_add(time, Ordering::Relaxed);
self.current.xact_time.fetch_add(time, Ordering::Relaxed);
}
pub fn query_time_add(&self, time: u64) {
self.total.query_time.fetch_add(time, Ordering::Relaxed);
self.current.query_time.fetch_add(time, Ordering::Relaxed);
}
pub fn wait_time_add(&self, time: u64) {
self.total.wait_time.fetch_add(time, Ordering::Relaxed);
self.current.wait_time.fetch_add(time, Ordering::Relaxed);
}
pub fn error(&self) {
self.total.errors.fetch_add(1, Ordering::Relaxed);
self.current.errors.fetch_add(1, Ordering::Relaxed);
}
pub fn update_averages(&self) {
let stat_period_per_second = crate::stats::STAT_PERIOD / 1_000;
// xact_count
let current_xact_count = self.current.xact_count.load(Ordering::Relaxed);
let current_xact_time = self.current.xact_time.load(Ordering::Relaxed);
self.averages.xact_count.store(
current_xact_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_xact_count == 0 {
self.averages.xact_time.store(0, Ordering::Relaxed);
} else {
self.averages
.xact_time
.store(current_xact_time / current_xact_count, Ordering::Relaxed);
}
// query_count
let current_query_count = self.current.query_count.load(Ordering::Relaxed);
let current_query_time = self.current.query_time.load(Ordering::Relaxed);
self.averages.query_count.store(
current_query_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_query_count == 0 {
self.averages.query_time.store(0, Ordering::Relaxed);
} else {
self.averages
.query_time
.store(current_query_time / current_query_count, Ordering::Relaxed);
}
// bytes_received
let current_bytes_received = self.current.bytes_received.load(Ordering::Relaxed);
self.averages.bytes_received.store(
current_bytes_received / stat_period_per_second,
Ordering::Relaxed,
);
// bytes_sent
let current_bytes_sent = self.current.bytes_sent.load(Ordering::Relaxed);
self.averages.bytes_sent.store(
current_bytes_sent / stat_period_per_second,
Ordering::Relaxed,
);
// wait_time
let current_wait_time = self.current.wait_time.load(Ordering::Relaxed);
self.averages.wait_time.store(
current_wait_time / stat_period_per_second,
Ordering::Relaxed,
);
// errors
let current_errors = self.current.errors.load(Ordering::Relaxed);
self.averages
.errors
.store(current_errors / stat_period_per_second, Ordering::Relaxed);
}
pub fn reset_current_counts(&self) {
self.current.xact_count.store(0, Ordering::Relaxed);
self.current.xact_time.store(0, Ordering::Relaxed);
self.current.query_count.store(0, Ordering::Relaxed);
self.current.query_time.store(0, Ordering::Relaxed);
self.current.bytes_received.store(0, Ordering::Relaxed);
self.current.bytes_sent.store(0, Ordering::Relaxed);
self.current.wait_time.store(0, Ordering::Relaxed);
self.current.errors.store(0, Ordering::Relaxed);
}
pub fn populate_row(&self, row: &mut Vec<String>) {
for (_key, value) in self.clone() {
row.push(value.to_string());
}
}
}

174
src/stats/client.rs Normal file
View File

@@ -0,0 +1,174 @@
use super::{get_reporter, Reporter};
use atomic_enum::atomic_enum;
use std::sync::atomic::*;
use std::sync::Arc;
use tokio::time::Instant;
/// The various states that a client can be in
#[atomic_enum]
#[derive(PartialEq)]
pub enum ClientState {
Idle = 0,
Waiting,
Active,
}
impl std::fmt::Display for ClientState {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
ClientState::Idle => write!(f, "idle"),
ClientState::Waiting => write!(f, "waiting"),
ClientState::Active => write!(f, "active"),
}
}
}
#[derive(Debug, Clone)]
/// Information we keep track of which can be queried by SHOW CLIENTS
pub struct ClientStats {
/// A random integer assigned to the client and used by stats to track the client
client_id: i32,
/// Data associated with the client, not writable, only set when we construct the ClientStat
application_name: String,
username: String,
pool_name: String,
connect_time: Instant,
reporter: Reporter,
/// Total time spent waiting for a connection from pool, measures in microseconds
pub total_wait_time: Arc<AtomicU64>,
/// Maximum time spent waiting for a connection from pool, measures in microseconds
pub max_wait_time: Arc<AtomicU64>,
/// Current state of the client
pub state: Arc<AtomicClientState>,
/// Number of transactions executed by this client
pub transaction_count: Arc<AtomicU64>,
/// Number of queries executed by this client
pub query_count: Arc<AtomicU64>,
/// Number of errors made by this client
pub error_count: Arc<AtomicU64>,
}
impl Default for ClientStats {
fn default() -> Self {
ClientStats {
client_id: 0,
connect_time: Instant::now(),
application_name: String::new(),
username: String::new(),
pool_name: String::new(),
total_wait_time: Arc::new(AtomicU64::new(0)),
max_wait_time: Arc::new(AtomicU64::new(0)),
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
error_count: Arc::new(AtomicU64::new(0)),
reporter: get_reporter(),
}
}
}
impl ClientStats {
pub fn new(
client_id: i32,
application_name: &str,
username: &str,
pool_name: &str,
connect_time: Instant,
) -> Self {
Self {
client_id,
connect_time,
application_name: application_name.to_string(),
username: username.to_string(),
pool_name: pool_name.to_string(),
..Default::default()
}
}
/// Reports a client is disconnecting from the pooler and
/// update metrics on the corresponding pool.
pub fn disconnect(&self) {
self.reporter.client_disconnecting(self.client_id);
}
/// Register a client with the stats system. The stats system uses client_id
/// to track and aggregate statistics from all source that relate to that client
pub fn register(&self, stats: Arc<ClientStats>) {
self.reporter.client_register(self.client_id, stats);
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client is done querying the server and is no longer assigned a server connection
pub fn idle(&self) {
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client is waiting for a connection
pub fn waiting(&self) {
self.state.store(ClientState::Waiting, Ordering::Relaxed);
}
/// Reports a client is done waiting for a connection and is about to query the server.
pub fn active(&self) {
self.state.store(ClientState::Active, Ordering::Relaxed);
}
/// Reports a client has failed to obtain a connection from a connection pool
pub fn checkout_error(&self) {
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client has had the server assigned to it be banned
pub fn ban_error(&self) {
self.state.store(ClientState::Idle, Ordering::Relaxed);
self.error_count.fetch_add(1, Ordering::Relaxed);
}
/// Reporters the time spent by a client waiting to get a healthy connection from the pool
pub fn checkout_time(&self, microseconds: u64) {
self.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.max_wait_time
.fetch_max(microseconds, Ordering::Relaxed);
}
/// Report a query executed by a client against a server
pub fn query(&self) {
self.query_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
/// we report each individual queries outside a transaction as a transaction
/// We only count the initial BEGIN as a transaction, all queries within do not
/// count as transactions
pub fn transaction(&self) {
self.transaction_count.fetch_add(1, Ordering::Relaxed);
}
// Helper methods for show clients
pub fn connect_time(&self) -> Instant {
self.connect_time
}
pub fn client_id(&self) -> i32 {
self.client_id
}
pub fn application_name(&self) -> String {
self.application_name.clone()
}
pub fn username(&self) -> String {
self.username.clone()
}
pub fn pool_name(&self) -> String {
self.pool_name.clone()
}
}

151
src/stats/pool.rs Normal file
View File

@@ -0,0 +1,151 @@
use log::debug;
use super::{ClientState, ServerState};
use crate::{config::PoolMode, messages::DataType, pool::PoolIdentifier};
use std::collections::HashMap;
use std::sync::atomic::*;
use crate::pool::get_all_pools;
#[derive(Debug, Clone)]
/// A struct that holds information about a Pool .
pub struct PoolStats {
pub identifier: PoolIdentifier,
pub mode: PoolMode,
pub cl_idle: u64,
pub cl_active: u64,
pub cl_waiting: u64,
pub cl_cancel_req: u64,
pub sv_active: u64,
pub sv_idle: u64,
pub sv_used: u64,
pub sv_tested: u64,
pub sv_login: u64,
pub maxwait: u64,
}
impl PoolStats {
pub fn new(identifier: PoolIdentifier, mode: PoolMode) -> Self {
PoolStats {
identifier,
mode,
cl_idle: 0,
cl_active: 0,
cl_waiting: 0,
cl_cancel_req: 0,
sv_active: 0,
sv_idle: 0,
sv_used: 0,
sv_tested: 0,
sv_login: 0,
maxwait: 0,
}
}
pub fn construct_pool_lookup() -> HashMap<PoolIdentifier, PoolStats> {
let mut map: HashMap<PoolIdentifier, PoolStats> = HashMap::new();
let client_map = super::get_client_stats();
let server_map = super::get_server_stats();
for (identifier, pool) in get_all_pools() {
map.insert(
identifier.clone(),
PoolStats::new(identifier, pool.settings.pool_mode),
);
}
for client in client_map.values() {
match map.get_mut(&PoolIdentifier {
db: client.pool_name(),
user: client.username(),
}) {
Some(pool_stats) => {
match client.state.load(Ordering::Relaxed) {
ClientState::Active => pool_stats.cl_active += 1,
ClientState::Idle => pool_stats.cl_idle += 1,
ClientState::Waiting => pool_stats.cl_waiting += 1,
}
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, max_wait);
}
None => debug!("Client from an obselete pool"),
}
}
for server in server_map.values() {
match map.get_mut(&PoolIdentifier {
db: server.pool_name(),
user: server.username(),
}) {
Some(pool_stats) => match server.state.load(Ordering::Relaxed) {
ServerState::Active => pool_stats.sv_active += 1,
ServerState::Idle => pool_stats.sv_idle += 1,
ServerState::Login => pool_stats.sv_login += 1,
ServerState::Tested => pool_stats.sv_tested += 1,
},
None => debug!("Server from an obselete pool"),
}
}
return map;
}
pub fn generate_header() -> Vec<(&'static str, DataType)> {
return vec![
("database", DataType::Text),
("user", DataType::Text),
("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric),
("sv_active", DataType::Numeric),
("sv_idle", DataType::Numeric),
("sv_used", DataType::Numeric),
("sv_tested", DataType::Numeric),
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
];
}
pub fn generate_row(&self) -> Vec<String> {
return vec![
self.identifier.db.clone(),
self.identifier.user.clone(),
self.mode.to_string(),
self.cl_idle.to_string(),
self.cl_active.to_string(),
self.cl_waiting.to_string(),
self.cl_cancel_req.to_string(),
self.sv_active.to_string(),
self.sv_idle.to_string(),
self.sv_used.to_string(),
self.sv_tested.to_string(),
self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(),
];
}
}
impl IntoIterator for PoolStats {
type Item = (String, u64);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
vec![
("cl_idle".to_string(), self.cl_idle),
("cl_active".to_string(), self.cl_active),
("cl_waiting".to_string(), self.cl_waiting),
("cl_cancel_req".to_string(), self.cl_cancel_req),
("sv_active".to_string(), self.sv_active),
("sv_idle".to_string(), self.sv_idle),
("sv_used".to_string(), self.sv_used),
("sv_tested".to_string(), self.sv_tested),
("sv_login".to_string(), self.sv_login),
("maxwait".to_string(), self.maxwait / 1_000_000),
("maxwait_us".to_string(), self.maxwait % 1_000_000),
]
.into_iter()
}
}

226
src/stats/server.rs Normal file
View File

@@ -0,0 +1,226 @@
use super::AddressStats;
use super::{get_reporter, Reporter};
use crate::config::Address;
use atomic_enum::atomic_enum;
use parking_lot::RwLock;
use std::sync::atomic::*;
use std::sync::Arc;
use tokio::time::Instant;
/// The various states that a server can be in
#[atomic_enum]
#[derive(PartialEq)]
pub enum ServerState {
Login = 0,
Active,
Tested,
Idle,
}
impl std::fmt::Display for ServerState {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
ServerState::Login => write!(f, "login"),
ServerState::Active => write!(f, "active"),
ServerState::Tested => write!(f, "tested"),
ServerState::Idle => write!(f, "idle"),
}
}
}
/// Information we keep track of which can be queried by SHOW SERVERS
#[derive(Debug, Clone)]
pub struct ServerStats {
/// A random integer assigned to the server and used by stats to track the server
server_id: i32,
/// Context information, only to be read
address: Address,
connect_time: Instant,
reporter: Reporter,
/// Data
pub application_name: Arc<RwLock<String>>,
pub state: Arc<AtomicServerState>,
pub bytes_sent: Arc<AtomicU64>,
pub bytes_received: Arc<AtomicU64>,
pub transaction_count: Arc<AtomicU64>,
pub query_count: Arc<AtomicU64>,
pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
}
impl Default for ServerStats {
fn default() -> Self {
ServerStats {
server_id: 0,
application_name: Arc::new(RwLock::new(String::new())),
address: Address::default(),
connect_time: Instant::now(),
state: Arc::new(AtomicServerState::new(ServerState::Login)),
bytes_sent: Arc::new(AtomicU64::new(0)),
bytes_received: Arc::new(AtomicU64::new(0)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
error_count: Arc::new(AtomicU64::new(0)),
reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
}
}
}
impl ServerStats {
pub fn new(address: Address, connect_time: Instant) -> Self {
Self {
address,
connect_time,
server_id: rand::random::<i32>(),
..Default::default()
}
}
pub fn server_id(&self) -> i32 {
self.server_id
}
/// Register a server connection with the stats system. The stats system uses server_id
/// to track and aggregate statistics from all source that relate to that server
// Delegates to reporter
pub fn register(&self, stats: Arc<ServerStats>) {
self.reporter.server_register(self.server_id, stats);
self.login();
}
/// Reports a server connection is no longer assigned to a client
/// and is available for the next client to pick it up
pub fn idle(&self) {
self.state.store(ServerState::Idle, Ordering::Relaxed);
}
/// Reports a server connection is disconnecting from the pooler.
/// Also updates metrics on the pool regarding server usage.
pub fn disconnect(&self) {
self.reporter.server_disconnecting(self.server_id);
}
/// Reports a server connection is being tested before being given to a client.
pub fn tested(&self) {
self.set_undefined_application();
self.state.store(ServerState::Tested, Ordering::Relaxed);
}
/// Reports a server connection is attempting to login.
pub fn login(&self) {
self.state.store(ServerState::Login, Ordering::Relaxed);
self.set_undefined_application();
}
/// Reports a server connection has been assigned to a client that
/// is about to query the server
pub fn active(&self, application_name: String) {
self.state.store(ServerState::Active, Ordering::Relaxed);
self.set_application(application_name);
}
pub fn address_stats(&self) -> Arc<AddressStats> {
self.address.stats.clone()
}
pub fn check_address_stat_average_is_updated_status(&self) -> bool {
self.address.stats.averages_updated.load(Ordering::Relaxed)
}
pub fn set_address_stat_average_is_updated_status(&self, is_checked: bool) {
self.address
.stats
.averages_updated
.store(is_checked, Ordering::Relaxed);
}
// Helper methods for show_servers
pub fn pool_name(&self) -> String {
self.address.pool_name.clone()
}
pub fn username(&self) -> String {
self.address.username.clone()
}
pub fn address_name(&self) -> String {
self.address.name()
}
pub fn connect_time(&self) -> Instant {
self.connect_time
}
fn set_application(&self, name: String) {
let mut application_name = self.application_name.write();
*application_name = name;
}
fn set_undefined_application(&self) {
self.set_application(String::from("Undefined"))
}
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
// Update server stats and address aggregation stats
self.set_application(application_name);
self.address.stats.wait_time_add(microseconds);
}
/// Report a query executed by a client against a server
pub fn query(&self, milliseconds: u64, application_name: &str) {
self.set_application(application_name.to_string());
self.address.stats.query_count_add();
self.address.stats.query_time_add(milliseconds);
self.query_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
/// we report each individual queries outside a transaction as a transaction
/// We only count the initial BEGIN as a transaction, all queries within do not
/// count as transactions
pub fn transaction(&self, application_name: &str) {
self.set_application(application_name.to_string());
self.transaction_count.fetch_add(1, Ordering::Relaxed);
self.address.stats.xact_count_add();
}
/// Report data sent to a server
pub fn data_sent(&self, amount_bytes: usize) {
self.bytes_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_sent_add(amount_bytes as u64);
}
/// Report data received from a server
pub fn data_received(&self, amount_bytes: usize) {
self.bytes_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_received_add(amount_bytes as u64);
}
/// Report a prepared statement that already exists on the server.
pub fn prepared_cache_hit(&self) {
self.prepared_hit_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a prepared statement that does not exist on the server yet.
pub fn prepared_cache_miss(&self) {
self.prepared_miss_count.fetch_add(1, Ordering::Relaxed);
}
pub fn prepared_cache_add(&self) {
self.prepared_cache_size.fetch_add(1, Ordering::Relaxed);
}
pub fn prepared_cache_remove(&self) {
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
}

View File

@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor;
use crate::config::get_config;
@@ -64,3 +69,19 @@ impl Tls {
})
}
}
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -1,5 +1,7 @@
FROM rust:bullseye
COPY --from=sclevine/yj /bin/yj /bin/yj
RUN /bin/yj -h
RUN apt-get update && apt-get install llvm-11 psmisc postgresql-contrib postgresql-client ruby ruby-dev libpq-dev python3 python3-pip lcov curl sudo iproute2 -y
RUN cargo install cargo-binutils rustfilt
RUN rustup component add llvm-tools-preview

View File

@@ -36,6 +36,15 @@ services:
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
pg5:
image: postgres:14
network_mode: "service:main"
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
command: ["postgres", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-p", "10432"]
main:
build: .
command: ["bash", "/app/tests/docker/run.sh"]

View File

@@ -63,6 +63,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
def test_normal_db_access():
pgcat_start()
conn, cur = connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()

View File

@@ -11,282 +11,6 @@ describe "Admin" do
processes.pgcat.shutdown
end
describe "SHOW STATS" do
context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do
connection = PG::connect("#{pgcat_conn_str}?application_name=one_query")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.close
# wait for averages to be calculated, we shouldn't do this too often
sleep(15.5)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to_not eq(0)
expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0)
end
end
end
describe "SHOW POOLS" do
context "bad credentials" do
it "does not change any stats" do
bad_passsword_url = URI(pgcat_conn_str)
bad_passsword_url.password = "wrong"
expect { PG::connect("#{bad_passsword_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "bad database name" do
it "does not change any stats" do
bad_db_url = URI(pgcat_conn_str)
bad_db_url.path = "/wrong_db"
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "client connects but issues no queries" do
it "only affects cl_idle stats" do
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq("1")
connections.map(&:close)
sleep(1.1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "clients connect and make one query" do
it "only affects cl_idle, sv_idle stats" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_active"]).to eq("5")
expect(results["sv_active"]).to eq("5")
sleep(3)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("5")
connections.map(&:close)
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client connects and opens a transaction and closes connection uncleanly" do
it "produces correct statistics" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new do
c.async_exec("BEGIN")
c.async_exec("SELECT pg_sleep(0.01)")
c.close
end
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client fail to checkout connection from the pool" do
it "counts clients as idle" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
threads = []
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
end
sleep(2)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("1")
threads.map(&:join)
connections.map(&:close)
end
end
context "clients overwhelm server pools" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it "cl_waiting is updated to show it" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(1.1) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_waiting"]).to eq("2")
expect(results["cl_active"]).to eq("2")
expect(results["sv_active"]).to eq("2")
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("4")
expect(results["sv_idle"]).to eq("2")
threads.map(&:join)
connections.map(&:close)
end
it "show correct max_wait" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
sleep(4.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("0")
threads.map(&:join)
connections.map(&:close)
end
end
end
describe "SHOW CLIENTS" do
it "reports correct number and application names" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(21) # count admin clients
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
connections[0..5].map(&:close)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(15)
connections[6..].map(&:close)
sleep(1) # Wait for stats to be updated
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
admin_conn.close
end
it "reports correct number of queries and transactions" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
connections.each do |c|
c.async_exec("SELECT 1")
c.async_exec("SELECT 2")
c.async_exec("SELECT 3")
c.async_exec("BEGIN")
c.async_exec("SELECT 4")
c.async_exec("SELECT 5")
c.async_exec("COMMIT")
end
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(3)
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
expect(normal_client_results[0]["transaction_count"]).to eq("4")
expect(normal_client_results[1]["transaction_count"]).to eq("4")
expect(normal_client_results[0]["query_count"]).to eq("7")
expect(normal_client_results[1]["query_count"]).to eq("7")
admin_conn.close
connections.map(&:close)
end
end
describe "Manual Banning" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) }
before do
@@ -357,7 +81,7 @@ describe "Admin" do
end
end
describe "SHOW users" do
describe "SHOW USERS" do
it "returns the right users" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW USERS")[0]
@@ -366,4 +90,28 @@ describe "Admin" do
expect(results["pool_mode"]).to eq("transaction")
end
end
describe "PAUSE" do
it "pauses all pools" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW DATABASES").to_a
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
admin_conn.async_exec("PAUSE")
results = admin_conn.async_exec("SHOW DATABASES").to_a
expect(results.map{ |r| r["paused"] }.uniq).to eq(["1"])
admin_conn.async_exec("RESUME")
results = admin_conn.async_exec("SHOW DATABASES").to_a
expect(results.map{ |r| r["paused"] }.uniq).to eq(["0"])
end
it "handles errors" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
expect { admin_conn.async_exec("PAUSE foo").to_a }.to raise_error(PG::SystemError)
expect { admin_conn.async_exec("PAUSE foo,bar").to_a }.to raise_error(PG::SystemError)
end
end
end

View File

@@ -0,0 +1,215 @@
# frozen_string_literal: true
require_relative 'spec_helper'
require_relative 'helpers/auth_query_helper'
describe "Auth Query" do
let(:configured_instances) {[5432, 10432]}
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
let(:pg_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
let(:processes) { Helpers::AuthQuery.single_shard_auth_query(pool_name: "sharded_db", pg_user: pg_user, config_user: config_user, extra_conf: config, wait_until_ready: wait_until_ready ) }
let(:config) { {} }
let(:wait_until_ready) { true }
after do
unless @failing_process
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
@failing_process = false
end
context "when auth_query is not configured" do
context 'and cleartext passwords are set' do
it "uses local passwords" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", config_user['username'], config_user['password']))
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and cleartext passwords are not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it "does not start because it is not possible to authenticate" do
@failing_process = true
expect { processes.pgcat }.to raise_error(StandardError, /You have to specify a user password for every pool if auth_query is not specified/)
end
end
end
context 'when auth_query is configured' do
context 'with global configuration' do
around(:example) do |example|
# Set up auth query
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
example.run
# Drop auth query support
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
end
context 'with correct global parameters' do
let(:config) { { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } }
context 'and with cleartext passwords set' do
it 'it uses local passwords' do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
it 'allows passwords to be changed without closing existing connections' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
end
it 'allows passwords to be changed and that new password is needed when reconnecting' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2'))
expect(newconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
end
end
end
context 'with wrong parameters' do
let(:config) { { 'general' => { 'auth_query' => 'SELECT 1', 'auth_query_user' => 'wrong_user', 'auth_query_password' => 'wrong' } } }
context 'and with clear text passwords set' do
it "it uses local passwords" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it "it fails to start as it cannot authenticate against servers" do
@failing_process = true
expect { PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) }.to raise_error(StandardError, /Error trying to obtain password from auth_query/ )
end
context 'and we fix the issue and reload' do
let(:wait_until_ready) { false }
it 'fails in the beginning but starts working after reloading config' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
while !(processes.pgcat.logs =~ /Waiting for clients/) do
sleep 0.5
end
expect { PG.connect(connection_string)}.to raise_error(PG::ConnectionBad)
expect(processes.pgcat.logs).to match(/Error trying to obtain password from auth_query/)
current_config = processes.pgcat.current_config
config = { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } }
processes.pgcat.update_config(current_config.deep_merge(config))
processes.pgcat.reload_config
conn = nil
expect { conn = PG.connect(connection_string)}.not_to raise_error
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
end
end
end
context 'with per pool configuration' do
around(:example) do |example|
# Set up auth query
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
Helpers::AuthQuery.set_up_auth_query_for_user(
user: 'md5_auth_user1',
password: 'secret',
database: 'shard1'
);
example.run
# Tear down auth query
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user',
password: 'secret'
);
Helpers::AuthQuery.tear_down_auth_query_for_user(
user: 'md5_auth_user1',
password: 'secret',
database: 'shard1'
);
end
context 'with correct parameters' do
let(:processes) { Helpers::AuthQuery.two_pools_auth_query(pool_names: ["sharded_db0", "sharded_db1"], pg_user: pg_user, config_user: config_user, extra_conf: config ) }
let(:config) {
{ 'pools' =>
{
'sharded_db0' => {
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
'auth_query_user' => 'md5_auth_user',
'auth_query_password' => 'secret'
},
'sharded_db1' => {
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
'auth_query_user' => 'md5_auth_user1',
'auth_query_password' => 'secret'
},
}
}
}
context 'and with cleartext passwords set' do
it 'it uses local passwords' do
conn = PG.connect(processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
conn = PG.connect(processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password']))
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
end
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
connection_string = processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password'])
conn = PG.connect(connection_string)
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
end
end
end
end
end
end

BIN
tests/ruby/capture Normal file

Binary file not shown.

102
tests/ruby/copy_spec.rb Normal file
View File

@@ -0,0 +1,102 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "COPY Handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
before do
new_configs = processes.pgcat.current_config
# Allow connections in the pool to expire faster
new_configs["general"]["idle_timeout"] = 5
processes.pgcat.update_config(new_configs)
# We need to kill the old process that was using the default configs
processes.pgcat.stop
processes.pgcat.start
processes.pgcat.wait_until_ready
end
before do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
end
end
after do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "DROP TABLE copy_test_table;"
end
end
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "COPY FROM" do
context "within transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
context "outside transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
end
describe "COPY TO" do
before do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
conn.close
end
it "works" do
res = []
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
while row=conn.get_copy_data
res << row
end
end
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
end
end
end

View File

@@ -0,0 +1,173 @@
module Helpers
module AuthQuery
def self.single_shard_auth_query(
pg_user:,
config_user:,
pool_name:,
extra_conf: {},
log_level: 'debug',
wait_until_ready: true
)
user = {
"pool_size" => 10,
"statement_timeout" => 0,
}
pgcat = PgcatProcess.new(log_level)
pgcat_cfg = pgcat.current_config.deep_merge(extra_conf)
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => "transaction",
"load_balancing_mode" => "random",
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_i, "primary"],
["localhost", replica.port.to_i, "replica"],
]
},
},
"users" => { "0" => user.merge(config_user) }
}
}
pgcat_cfg["general"]["port"] = pgcat.port.to_i
pgcat.update_config(pgcat_cfg)
pgcat.start
pgcat.wait_until_ready(
pgcat.connection_string(
"sharded_db",
pg_user['username'],
pg_user['password']
)
) if wait_until_ready
OpenStruct.new.tap do |struct|
struct.pgcat = pgcat
struct.primary = primary
struct.replicas = [replica]
struct.all_databases = [primary]
end
end
def self.two_pools_auth_query(
pg_user:,
config_user:,
pool_names:,
extra_conf: {},
log_level: 'debug'
)
user = {
"pool_size" => 10,
"statement_timeout" => 0,
}
pgcat = PgcatProcess.new(log_level)
pgcat_cfg = pgcat.current_config
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
pool_template = Proc.new do |database|
{
"default_role" => "any",
"pool_mode" => "transaction",
"load_balancing_mode" => "random",
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => database,
"servers" => [
["localhost", primary.port.to_i, "primary"],
["localhost", replica.port.to_i, "replica"],
]
},
},
"users" => { "0" => user.merge(config_user) }
}
end
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_names[0]}" => pool_template.call("shard0"),
"#{pool_names[1]}" => pool_template.call("shard1")
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
pgcat.start
pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
OpenStruct.new.tap do |struct|
struct.pgcat = pgcat
struct.primary = primary
struct.replicas = [replica]
struct.all_databases = [primary]
end
end
def self.create_query_auth_function(user)
return <<-SQL
CREATE OR REPLACE FUNCTION public.user_lookup(in i_username text, out uname text, out phash text)
RETURNS record AS $$
BEGIN
SELECT usename, passwd FROM pg_catalog.pg_shadow
WHERE usename = i_username INTO uname, phash;
RETURN;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;
GRANT EXECUTE ON FUNCTION public.user_lookup(text) TO #{user};
SQL
end
def self.exec_in_instances(query:, instance_ports: [ 5432, 10432 ], database: 'postgres', user: 'postgres', password: 'postgres')
instance_ports.each do |port|
c = PG.connect("postgres://#{user}:#{password}@localhost:#{port}/#{database}")
c.exec(query)
c.close
end
end
def self.set_up_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
instance_ports.each do |port|
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
connection.exec("DROP ROLE #{user}") rescue PG::UndefinedObject
connection.exec("CREATE ROLE #{user} ENCRYPTED PASSWORD '#{password}' LOGIN;")
connection.exec(self.create_query_auth_function(user))
connection.close
end
end
def self.tear_down_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
instance_ports.each do |port|
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
connection.exec("DROP ROLE #{user}")
connection.close
end
end
def self.drop_query_auth_function(user)
return <<-SQL
REVOKE ALL ON FUNCTION public.user_lookup(text) FROM public, #{user};
DROP FUNCTION public.user_lookup(in i_username text, out uname text, out phash text);
SQL
end
end
end

View File

@@ -7,10 +7,24 @@ class PgInstance
attr_reader :password
attr_reader :database_name
def self.mass_takedown(databases)
raise StandardError "block missing" unless block_given?
databases.each do |database|
database.toxiproxy.toxic(:limit_data, bytes: 1).toxics.each(&:save)
end
sleep 0.1
yield
ensure
databases.each do |database|
database.toxiproxy.toxics.each(&:destroy)
end
end
def initialize(port, username, password, database_name)
@original_port = port
@original_port = port.to_i
@toxiproxy_port = 10000 + port.to_i
@port = @toxiproxy_port
@port = @toxiproxy_port.to_i
@username = username
@password = password
@@ -48,9 +62,9 @@ class PgInstance
def take_down
if block_given?
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).apply { yield }
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).apply { yield }
else
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).toxics.each(&:save)
Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).toxics.each(&:save)
end
end
@@ -89,6 +103,6 @@ class PgInstance
end
def count_select_1_plus_2
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query = 'SELECT $1 + $2'")[0]["sum"].to_i }
with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query LIKE '%SELECT $1 + $2%'")[0]["sum"].to_i }
end
end

View File

@@ -0,0 +1,259 @@
require 'socket'
require 'digest/md5'
BACKEND_MESSAGE_CODES = {
'Z' => "ReadyForQuery",
'C' => "CommandComplete",
'T' => "RowDescription",
'D' => "DataRow",
'1' => "ParseComplete",
'2' => "BindComplete",
'E' => "ErrorResponse",
's' => "PortalSuspended",
}
class PostgresSocket
def initialize(host, port)
@port = port
@host = host
@socket = TCPSocket.new @host, @port
@parameters = {}
@verbose = true
end
def send_md5_password_message(username, password, salt)
m = Digest::MD5.hexdigest(password + username)
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
m = 'md5' + m
bytes = (m.split("").map(&:ord) + [0]).flatten
message_size = bytes.count + 4
message = []
message << 'p'.ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << bytes
message.flatten!
@socket.write(message.pack('C*'))
end
def send_startup_message(username, database, password)
message = []
message << [196608].pack('l>').unpack('CCCC') # 4
message << "user".split('').map(&:ord) # 4, 8
message << 0 # 1, 9
message << username.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << "database".split('').map(&:ord) # 8, 20
message << 0 # 1, 21
message << database.split('').map(&:ord) # 2, 23
message << 0 # 1, 24
message << 0 # 1, 25
message.flatten!
total_message_size = message.size + 4
message_len = [total_message_size].pack('l>').unpack('CCCC')
@socket.write([message_len + message].flatten.pack('C*'))
sleep 0.1
read_startup_response(username, password)
end
def read_startup_response(username, password)
message_code, message_len = @socket.recv(5).unpack("al>")
while message_code == 'R'
auth_code = @socket.recv(4).unpack('l>').pop
case auth_code
when 5 # md5
salt = @socket.recv(4).unpack('CCCC')
send_md5_password_message(username, password, salt)
message_code, message_len = @socket.recv(5).unpack("al>")
when 0 # trust
break
end
end
loop do
message_code, message_len = @socket.recv(5).unpack("al>")
if message_code == 'Z'
@socket.recv(1).unpack("a") # most likely I
break # We are good to go
end
if message_code == 'S'
actual_message = @socket.recv(message_len - 4).unpack("C*")
k,v = actual_message.pack('U*').split(/\x00/)
@parameters[k] = v
end
if message_code == 'K'
process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>")
@parameters["process_id"] = process_id
@parameters["secret_key"] = secret_key
end
end
return @parameters
end
def cancel_query
socket = TCPSocket.new @host, @port
process_key = @parameters["process_id"]
secret_key = @parameters["secret_key"]
message = []
message << [16].pack('l>').unpack('CCCC') # 4
message << [80877102].pack('l>').unpack('CCCC') # 4
message << [process_key.to_i].pack('l>').unpack('CCCC') # 4
message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4
message.flatten!
socket.write(message.flatten.pack('C*'))
socket.close
log "[F] Sent CancelRequest message"
end
def send_query_message(query)
query_size = query.length
message_size = 1 + 4 + query_size
message = []
message << "Q".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent Q message (#{query})"
end
def send_parse_message(query)
query_size = query.length
message_size = 2 + 2 + 4 + query_size
message = []
message << "P".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << [0, 0]
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent P message (#{query})"
end
def send_bind_message
message = []
message << "B".ord
message << [12].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << 0 # unnamed statement
message << [0, 0] # 2
message << [0, 0] # 2
message << [0, 0] # 2
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent B message"
end
def send_describe_message(mode)
message = []
message << "D".ord
message << [6].pack('l>').unpack('CCCC') # 4
message << mode.ord
message << 0 # unnamed statement
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent D message"
end
def send_execute_message(limit=0)
message = []
message << "E".ord
message << [9].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << [limit].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent E message"
end
def send_sync_message
message = []
message << "S".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent S message"
end
def send_copydone_message
message = []
message << "c".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent c message"
end
def send_copyfail_message
message = []
message << "f".ord
message << [5].pack('l>').unpack('CCCC') # 4
message << 0
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent f message"
end
def send_flush_message
message = []
message << "H".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent H message"
end
def read_from_server()
output_messages = []
retry_count = 0
message_code = nil
message_len = 0
loop do
begin
message_code, message_len = @socket.recv_nonblock(5).unpack("al>")
rescue IO::WaitReadable
return output_messages if retry_count > 50
retry_count += 1
sleep(0.01)
next
end
message = {
code: message_code,
len: message_len,
bytes: []
}
log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}"
actual_message_length = message_len - 4
if actual_message_length > 0
message[:bytes] = @socket.recv(message_len - 4).unpack("C*")
log "\t#{message[:bytes].join(",")}"
log "\t#{message[:bytes].map(&:chr).join(" ")}"
end
output_messages << message
return output_messages if message_code == 'Z'
end
end
def log(msg)
return unless @verbose
puts msg
end
def close
@socket.close
end
end

View File

@@ -2,6 +2,14 @@ require 'json'
require 'ostruct'
require_relative 'pgcat_process'
require_relative 'pg_instance'
require_relative 'pg_socket'
class ::Hash
def deep_merge(second)
merger = proc { |key, v1, v2| Hash === v1 && Hash === v2 ? v1.merge(v2, &merger) : v2 }
self.merge(second, &merger)
end
end
module Helpers
module Pgcat
@@ -26,14 +34,32 @@ module Helpers
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => true,
"query_parser_enabled" => true,
"query_parser_read_write_splitting" => true,
"automatic_sharding_key" => "data.id",
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] },
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
"0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_i, "primary"]] },
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_i, "primary"]] },
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_i, "primary"]] },
},
"users" => { "0" => user }
"users" => { "0" => user },
"plugins" => {
"intercept" => {
"enabled" => true,
"queries" => {
"0" => {
"query" => "select current_database() as a, current_schemas(false) as b",
"schema" => [
["a", "text"],
["b", "text"],
],
"result" => [
["${DATABASE}", "{public}"],
]
}
}
}
}
}
}
pgcat.update_config(pgcat_cfg)
@@ -74,7 +100,7 @@ module Helpers
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"]
["localhost", primary.port.to_i, "primary"]
]
},
},
@@ -93,7 +119,7 @@ module Helpers
end
end
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", pool_settings={})
user = {
"password" => "sharding_user",
"pool_size" => pool_size,
@@ -109,28 +135,32 @@ module Helpers
replica1 = PgInstance.new(8432, user["username"], user["password"], "shard0")
replica2 = PgInstance.new(9432, user["username"], user["password"], "shard0")
pool_config = {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_i, "primary"],
["localhost", replica0.port.to_i, "replica"],
["localhost", replica1.port.to_i, "replica"],
["localhost", replica2.port.to_i, "replica"]
]
},
},
"users" => { "0" => user }
}
pool_config = pool_config.merge(pool_settings)
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica0.port.to_s, "replica"],
["localhost", replica1.port.to_s, "replica"],
["localhost", replica2.port.to_s, "replica"]
]
},
},
"users" => { "0" => user }
}
"#{pool_name}" => pool_config,
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg)

View File

@@ -1,8 +1,10 @@
require 'pg'
require 'toml'
require 'json'
require 'tempfile'
require 'fileutils'
require 'securerandom'
class ConfigReloadFailed < StandardError; end
class PgcatProcess
attr_reader :port
attr_reader :pid
@@ -18,7 +20,7 @@ class PgcatProcess
end
def initialize(log_level)
@env = {"RUST_LOG" => log_level}
@env = {}
@port = rand(20000..32760)
@log_level = log_level
@log_filename = "/tmp/pgcat_log_#{SecureRandom.urlsafe_base64}.log"
@@ -30,7 +32,7 @@ class PgcatProcess
'../../target/debug/pgcat'
end
@command = "#{command_path} #{@config_filename}"
@command = "#{command_path} #{@config_filename} --log-level #{@log_level}"
FileUtils.cp("../../pgcat.toml", @config_filename)
cfg = current_config
@@ -46,38 +48,54 @@ class PgcatProcess
def update_config(config_hash)
@original_config = current_config
output_to_write = TOML::Generator.new(config_hash).body
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*,/, ',\1,')
output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*\]/, ',\1]')
File.write(@config_filename, output_to_write)
Tempfile.create('json_out', '/tmp') do |f|
f.write(config_hash.to_json)
f.flush
`cat #{f.path} | yj -jt > #{@config_filename}`
end
end
def current_config
loadable_string = File.read(@config_filename)
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*,/, ', "\1",')
loadable_string = loadable_string.gsub(/,\s*(\d+)\s*\]/, ', "\1"]')
TOML.load(loadable_string)
JSON.parse(`cat #{@config_filename} | yj -tj`)
end
def raw_config_file
File.read(@config_filename)
end
def reload_config
`kill -s HUP #{@pid}`
sleep 0.5
conn = PG.connect(admin_connection_string)
conn.async_exec("RELOAD")
rescue PG::ConnectionBad => e
errors = logs.split("Reloading config").last
errors = errors.gsub(/\e\[([;\d]+)?m/, '') # Remove color codes
errors = errors.
split("\n").select{|line| line.include?("ERROR") }.
map { |line| line.split("pgcat::config: ").last }
raise ConfigReloadFailed, errors.join("\n")
ensure
conn&.close
end
def start
raise StandardError, "Process is already started" unless @pid.nil?
@pid = Process.spawn(@env, @command, err: @log_filename, out: @log_filename)
Process.detach(@pid)
ObjectSpace.define_finalizer(@log_filename, proc { PgcatProcess.finalize(@pid, @log_filename, @config_filename) })
return self
end
def wait_until_ready
def wait_until_ready(connection_string = nil)
exc = nil
10.times do
PG::connect(example_connection_string).close
Process.kill 0, @pid
PG::connect(connection_string || example_connection_string).close
return self
rescue Errno::ESRCH
raise StandardError, "Process #{@pid} died. #{logs}"
rescue => e
exc = e
sleep(0.5)
@@ -108,13 +126,16 @@ class PgcatProcess
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
end
def connection_string(pool_name, username)
def connection_string(pool_name, username, password = nil, parameters: {})
cfg = current_config
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
password = user_obj["password"]
connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}"
# Add the additional parameters to the connection string
parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
connection_string += "?#{parameter_string}" unless parameter_string.empty?
connection_string
end
def example_connection_string

View File

@@ -65,7 +65,7 @@ describe "Least Outstanding Queries Load Balancing" do
processes.pgcat.shutdown
end
context "under homogenous load" do
context "under homogeneous load" do
it "balances query volume between all instances" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))

View File

@@ -11,9 +11,9 @@ describe "Query Mirroing" do
before do
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_i, 0],
[mirror_host, mirror_pg.port.to_i, 0],
[mirror_host, mirror_pg.port.to_i, 0],
]
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
@@ -25,13 +25,14 @@ describe "Query Mirroing" do
processes.pgcat.shutdown
end
it "can mirror a query" do
xit "can mirror a query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
runs = 15
runs.times { conn.async_exec("SELECT 1 + 2") }
sleep 0.5
expect(processes.all_databases.first.count_select_1_plus_2).to eq(runs)
expect(mirror_pg.count_select_1_plus_2).to eq(runs * 3)
# Allow some slack in mirroring successes
expect(mirror_pg.count_select_1_plus_2).to be > ((runs - 5) * 3)
end
context "when main server connection is closed" do
@@ -42,9 +43,9 @@ describe "Query Mirroing" do
new_configs = processes.pgcat.current_config
new_configs["pools"]["sharded_db"]["idle_timeout"] = 5000 + i
new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_s, "0"],
[mirror_host, mirror_pg.port.to_i, 0],
[mirror_host, mirror_pg.port.to_i, 0],
[mirror_host, mirror_pg.port.to_i, 0],
]
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config

View File

@@ -221,7 +221,7 @@ describe "Miscellaneous" do
conn.close
end
it "Does not send DISCARD ALL unless necessary" do
it "Does not send RESET ALL unless necessary" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
@@ -229,7 +229,7 @@ describe "Miscellaneous" do
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
expect(processes.primary.count_query("RESET ALL")).to eq(0)
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
@@ -239,7 +239,19 @@ describe "Miscellaneous" do
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
expect(processes.primary.count_query("RESET ALL")).to eq(10)
end
it "Resets server roles correctly" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("SELECT 1")
conn.async_exec("SET statement_timeout to 5000")
conn.close
end
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
end
end
@@ -261,7 +273,7 @@ describe "Miscellaneous" do
end
end
it "Does not send DISCARD ALL unless necessary" do
it "Does not send RESET ALL unless necessary" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
@@ -270,7 +282,7 @@ describe "Miscellaneous" do
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
expect(processes.primary.count_query("RESET ALL")).to eq(0)
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
@@ -280,8 +292,32 @@ describe "Miscellaneous" do
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
expect(processes.primary.count_query("RESET ALL")).to eq(10)
end
it "Respects tracked parameters on startup" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" }))
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
conn.close
end
it "Respect tracked parameter on set statemet" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET application_name to 'my_pgcat_test'")
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
end
it "Ignore untracked parameter on set statemet" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]
conn.async_exec("SET statement_timeout to 1500")
expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
end
end
context "transaction mode with transactions" do
@@ -295,7 +331,7 @@ describe "Miscellaneous" do
conn.async_exec("COMMIT")
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
expect(processes.primary.count_query("RESET ALL")).to eq(0)
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
@@ -305,7 +341,30 @@ describe "Miscellaneous" do
conn.async_exec("COMMIT")
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
expect(processes.primary.count_query("RESET ALL")).to eq(0)
end
end
context "server cleanup disabled" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 1, "transaction", "random", "info", { "cleanup_server_connections" => false }) }
it "will not clean up connection state" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("SET statement_timeout TO 1000")
conn.close
expect(processes.primary.count_query("RESET ALL")).to eq(0)
end
it "will not clean up prepared statements" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("PREPARE prepared_q (int) AS SELECT $1")
conn.close
expect(processes.primary.count_query("RESET ALL")).to eq(0)
end
end
end
@@ -315,10 +374,9 @@ describe "Miscellaneous" do
before do
current_configs = processes.pgcat.current_config
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
puts(current_configs["general"]["idle_client_in_transaction_timeout"])
current_configs["general"]["idle_client_in_transaction_timeout"] = 0
processes.pgcat.update_config(current_configs) # with timeout 0
processes.pgcat.reload_config
end
@@ -336,9 +394,9 @@ describe "Miscellaneous" do
context "idle transaction timeout set to 500ms" do
before do
current_configs = processes.pgcat.current_config
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
current_configs["general"]["idle_client_in_transaction_timeout"] = 500
processes.pgcat.update_config(current_configs) # with timeout 500
processes.pgcat.reload_config
end
@@ -357,7 +415,7 @@ describe "Miscellaneous" do
conn.async_exec("BEGIN")
conn.async_exec("SELECT 1")
sleep(1) # above 500ms
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
conn.async_exec("SELECT 1") # should be able to send another query
conn.close
end

View File

@@ -0,0 +1,14 @@
require_relative 'spec_helper'
describe "Plugins" do
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
context "intercept" do
it "will intercept an intellij query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
res = conn.exec("select current_database() as a, current_schemas(false) as b")
expect(res.values).to eq([["sharded_db", "{public}"]])
end
end
end

View File

@@ -0,0 +1,29 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
context 'enabled' do
it 'will work over the same connection' do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
conn.describe_prepared(statement_name)
end
end
it 'will work with new connections' do
10.times do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
statement_name = 'statement1'
conn.prepare('statement1', 'SELECT $1::int')
conn.exec_prepared('statement1', [1])
conn.describe_prepared('statement1')
end
end
end
end

155
tests/ruby/protocol_spec.rb Normal file
View File

@@ -0,0 +1,155 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Portocol handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "session") }
let(:sequence) { [] }
let(:pgcat_socket) { PostgresSocket.new('localhost', processes.pgcat.port) }
let(:pgdb_socket) { PostgresSocket.new('localhost', processes.all_databases.first.port) }
after do
pgdb_socket.close
pgcat_socket.close
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
def run_comparison(sequence, socket_a, socket_b)
sequence.each do |msg, *args|
socket_a.send(msg, *args)
socket_b.send(msg, *args)
compare_messages(
socket_a.read_from_server,
socket_b.read_from_server
)
end
end
def compare_messages(msg_arr0, msg_arr1)
if msg_arr0.count != msg_arr1.count
error_output = []
error_output << "#{msg_arr0.count} : #{msg_arr1.count}"
error_output << "PgCat Messages"
error_output += msg_arr0.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_output << "PgServer Messages"
error_output += msg_arr1.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_desc = error_output.join("\n")
raise StandardError, "Message count mismatch #{error_desc}"
end
(0..msg_arr0.count - 1).all? do |i|
msg0 = msg_arr0[i]
msg1 = msg_arr1[i]
result = [
msg0[:code] == msg1[:code],
msg0[:len] == msg1[:len],
msg0[:bytes] == msg1[:bytes],
].all?
next result if result
if result == false
error_string = []
if msg0[:code] != msg1[:code]
error_string << "code #{msg0[:code]} != #{msg1[:code]}"
end
if msg0[:len] != msg1[:len]
error_string << "len #{msg0[:len]} != #{msg1[:len]}"
end
if msg0[:bytes] != msg1[:bytes]
error_string << "bytes #{msg0[:bytes]} != #{msg1[:bytes]}"
end
err = error_string.join("\n")
raise StandardError, "Message mismatch #{err}"
end
end
end
RSpec.shared_examples "at parity with database" do
before do
pgcat_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
pgdb_socket.send_startup_message("sharding_user", "shard0", "sharding_user")
end
it "works" do
run_comparison(sequence, pgcat_socket, pgdb_socket)
end
end
context "Cancel Query" do
let(:sequence) {
[
[:send_query_message, "SELECT pg_sleep(5)"],
[:cancel_query]
]
}
it_behaves_like "at parity with database"
end
xcontext "Simple query after parse" do
let(:sequence) {
[
[:send_parse_message, "SELECT 5"],
[:send_query_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Flush message" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_flush_message]
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Bind without parse" do
let(:sequence) {
[
[:send_bind_message]
]
}
# This is known to fail.
# Server responds immediately, Proxy buffers the message
it_behaves_like "at parity with database"
end
context "Simple message" do
let(:sequence) {
[[:send_query_message, "SELECT 1"]]
}
it_behaves_like "at parity with database"
end
context "Extended protocol" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
it_behaves_like "at parity with database"
end
end

View File

@@ -7,11 +7,11 @@ describe "Sharding" do
before do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
# Setup the sharding data
3.times do |i|
conn.exec("SET SHARD TO '#{i}'")
conn.exec("DELETE FROM data WHERE id > 0")
conn.exec("DELETE FROM data WHERE id > 0") rescue nil
end
18.times do |i|
@@ -19,15 +19,16 @@ describe "Sharding" do
conn.exec("SET SHARDING KEY TO '#{i}'")
conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')")
end
conn.close
end
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "automatic routing of extended procotol" do
describe "automatic routing of extended protocol" do
it "can do it" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SET SERVER ROLE TO 'auto'")
@@ -48,4 +49,148 @@ describe "Sharding" do
end
end
end
describe "no_shard_specified_behavior config" do
context "when default shard number is invalid" do
it "prevents config reload" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
current_configs = processes.pgcat.current_config
current_configs["pools"]["sharded_db"]["default_shard"] = "shard_99"
processes.pgcat.update_config(current_configs)
expect { processes.pgcat.reload_config }.to raise_error(ConfigReloadFailed, /Invalid shard 99/)
end
end
end
describe "comment-based routing" do
context "when no configs are set" do
it "routes queries with a shard_id comment to the default shard" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
10.times { conn.async_exec("/* shard_id: 2 */ SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([10, 0, 0])
end
it "does not honor no_shard_specified_behavior directives" do
end
end
[
["shard_id_regex", "/\\* the_shard_id: (\\d+) \\*/", "/* the_shard_id: 1 */"],
["sharding_key_regex", "/\\* the_sharding_key: (\\d+) \\*/", "/* the_sharding_key: 3 */"],
].each do |config_name, config_value, comment_to_use|
context "when #{config_name} config is set" do
let(:no_shard_specified_behavior) { nil }
before do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
current_configs = processes.pgcat.current_config
current_configs["pools"]["sharded_db"][config_name] = config_value
if no_shard_specified_behavior
current_configs["pools"]["sharded_db"]["default_shard"] = no_shard_specified_behavior
else
current_configs["pools"]["sharded_db"].delete("default_shard")
end
processes.pgcat.update_config(current_configs)
processes.pgcat.reload_config
end
it "routes queries with a shard_id comment to the correct shard" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
end
context "when no_shard_specified_behavior config is set to random" do
let(:no_shard_specified_behavior) { "random" }
context "with no shard comment" do
it "sends queries to random shard" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
end
end
context "with a shard comment" do
it "honors the comment" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
end
end
end
context "when no_shard_specified_behavior config is set to random_healthy" do
let(:no_shard_specified_behavior) { "random_healthy" }
context "with no shard comment" do
it "sends queries to random healthy shard" do
good_databases = [processes.all_databases[0], processes.all_databases[2]]
bad_database = processes.all_databases[1]
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
250.times { conn.async_exec("SELECT 99") }
bad_database.take_down do
250.times do
conn.async_exec("SELECT 99")
rescue PG::ConnectionBad => e
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
end
end
# Routes traffic away from bad shard
25.times { conn.async_exec("SELECT 1 + 2") }
expect(good_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
expect(bad_database.count_select_1_plus_2).to eq(0)
# Routes traffic to the bad shard if the shard_id is specified
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
bad_database = processes.all_databases[1]
expect(bad_database.count_select_1_plus_2).to eq(25)
end
end
context "with a shard comment" do
it "honors the comment" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
end
end
end
context "when no_shard_specified_behavior config is set to shard_x" do
let(:no_shard_specified_behavior) { "shard_2" }
context "with no shard comment" do
it "sends queries to the specified shard" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 0, 25])
end
end
context "with a shard comment" do
it "honors the comment" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
end
end
end
end
end
end
end

View File

@@ -19,3 +19,10 @@ ensure
STDOUT.reopen(sout)
STDERR.reopen(serr)
end
def clients_connected_to_pool(pool_index: 0, processes:)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[pool_index]
admin_conn.close
results['cl_idle'].to_i + results['cl_active'].to_i + results['cl_waiting'].to_i
end

369
tests/ruby/stats_spec.rb Normal file
View File

@@ -0,0 +1,369 @@
# frozen_string_literal: true
require 'open3'
require_relative 'spec_helper'
describe "Stats" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
let(:pgcat_conn_str) { processes.pgcat.connection_string("sharded_db", "sharding_user") }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "SHOW STATS" do
context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
end
sleep(1)
connections.map(&:close)
# wait for averages to be calculated, we shouldn't do this too often
sleep(15.5)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0)
end
end
end
describe "SHOW POOLS" do
context "bad credentials" do
it "does not change any stats" do
bad_password_url = URI(pgcat_conn_str)
bad_password_url.password = "wrong"
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "bad database name" do
it "does not change any stats" do
bad_db_url = URI(pgcat_conn_str)
bad_db_url.path = "/wrong_db"
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "client connects but issues no queries" do
it "only affects cl_idle stats" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq(before_test)
connections.map(&:close)
sleep(1.1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq(before_test)
end
end
context "clients connect and make one query" do
it "only affects cl_idle, sv_idle stats" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_active"]).to eq("5")
expect(results["sv_active"]).to eq("5")
sleep(3)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("5")
connections.map(&:close)
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client connects and opens a transaction and closes connection uncleanly" do
it "produces correct statistics" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new do
c.async_exec("BEGIN")
c.async_exec("SELECT pg_sleep(0.01)")
c.close
end
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client fail to checkout connection from the pool" do
it "counts clients as idle" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
threads = []
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
end
sleep(2)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("1")
threads.map(&:join)
connections.map(&:close)
end
end
context "clients connects and disconnect normally" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it 'shows the same number of clients before and after' do
clients_before = clients_connected_to_pool(processes: processes)
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") rescue nil }
end
clients_between = clients_connected_to_pool(processes: processes)
expect(clients_before).not_to eq(clients_between)
connections.each(&:close)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients connects and disconnect abruptly" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
it 'shows the same number of clients before and after' do
threads = []
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_before = clients_connected_to_pool(processes: processes)
random_string = (0...8).map { (65 + rand(26)).chr }.join
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
sleep(1)
# psql starts two processes, we only know the pid of the parent, this
# ensure both are killed
`pkill -9 -f '#{random_string}'`
Process.wait(faulty_client)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients overwhelm server pools" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it "cl_waiting is updated to show it" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(1.1) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_waiting"]).to eq("2")
expect(results["cl_active"]).to eq("2")
expect(results["sv_active"]).to eq("2")
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("4")
expect(results["sv_idle"]).to eq("2")
threads.map(&:join)
connections.map(&:close)
end
it "show correct max_wait" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
connections.map(&:close)
sleep(4.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("0")
threads.map(&:join)
end
end
end
describe "SHOW CLIENTS" do
it "reports correct number and application names" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(21) # count admin clients
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
connections[0..5].map(&:close)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(15)
connections[6..].map(&:close)
sleep(1) # Wait for stats to be updated
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
admin_conn.close
end
it "reports correct number of queries and transactions" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
connections.each do |c|
c.async_exec("SELECT 1")
c.async_exec("SELECT 2")
c.async_exec("SELECT 3")
c.async_exec("BEGIN")
c.async_exec("SELECT 4")
c.async_exec("SELECT 5")
c.async_exec("COMMIT")
end
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(3)
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
expect(normal_client_results[0]["transaction_count"]).to eq("4")
expect(normal_client_results[1]["transaction_count"]).to eq("4")
expect(normal_client_results[0]["query_count"]).to eq("7")
expect(normal_client_results[1]["query_count"]).to eq("7")
admin_conn.close
connections.map(&:close)
end
end
describe "Query Storm" do
context "when the proxy receives overwhelmingly large number of short quick queries" do
it "should not have lingering clients or active servers" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
Array.new(40) do
Thread.new do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SELECT pg_sleep(0.1)")
rescue PG::SystemError
ensure
conn.close
end
end.each(&:join)
sleep 1
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
admin_conn.close
end
end
end
end

1
tests/rust/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
target/

1322
tests/rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

10
tests/rust/Cargo.toml Normal file
View File

@@ -0,0 +1,10 @@
[package]
name = "rust"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sqlx = { version = "0.6.2", features = [ "runtime-tokio-rustls", "postgres", "json", "tls", "migrate", "time", "uuid", "ipnetwork"] }
tokio = { version = "1", features = ["full"] }

29
tests/rust/src/main.rs Normal file
View File

@@ -0,0 +1,29 @@
#[tokio::main]
async fn main() {
test_prepared_statements().await;
}
async fn test_prepared_statements() {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db")
.await
.unwrap();
let mut handles = Vec::new();
for _ in 0..5 {
let pool = pool.clone();
let handle = tokio::task::spawn(async move {
for _ in 0..1000 {
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
}

40
utilities/deb.sh Normal file
View File

@@ -0,0 +1,40 @@
#!/bin/bash
#
# Build an Ubuntu deb.
#
script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
deb_dir="/tmp/pgcat-build"
export PACKAGE_VERSION=${1:-"1.1.1"}
if [[ $(arch) == "x86_64" ]]; then
export ARCH=amd64
else
export ARCH=arm64
fi
cd "$script_dir/.."
cargo build --release
rm -rf "$deb_dir"
mkdir -p "$deb_dir/DEBIAN"
mkdir -p "$deb_dir/usr/bin"
mkdir -p "$deb_dir/etc/systemd/system"
cp target/release/pgcat "$deb_dir/usr/bin/pgcat"
chmod +x "$deb_dir/usr/bin/pgcat"
cp pgcat.toml "$deb_dir/etc/pgcat.toml"
cp pgcat.service "$deb_dir/etc/systemd/system/pgcat.service"
(cat control | envsubst) > "$deb_dir/DEBIAN/control"
cp postinst "$deb_dir/DEBIAN/postinst"
cp postrm "$deb_dir/DEBIAN/postrm"
cp prerm "$deb_dir/DEBIAN/prerm"
chmod +x ${deb_dir}/DEBIAN/post*
chmod +x ${deb_dir}/DEBIAN/pre*
dpkg-deb \
--root-owner-group \
-z1 \
--build "$deb_dir" \
pgcat-${PACKAGE_VERSION}-ubuntu22.04-${ARCH}.deb

View File

@@ -0,0 +1 @@
tomli