Compare commits

..

68 Commits

Author SHA1 Message Date
Lev Kokotov
622891ee5b Release 1.1 2023-07-25 10:26:38 -07:00
Mostafa Abdelraouf
2a8f3653a6 Fix COPY FROM and add tests (#522)
* Fix COPY FROM and add tests

* E

* fmt
2023-07-20 23:06:01 -07:00
Sebastian Webber
19cb8a3022 add --no-color option to disable colors in the terminal (#518)
add --no-color option to disable colors

this commit adds a new option to disable colors in the terminal and also
moves the logger configuration to a different crate.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-19 21:15:55 -07:00
Sebastian Webber
f85e5bd9e8 add support for multiple log formats (#517)
this commit adds the tracing-subscriber crate and use its formatters to
support multiple log formats.

More details in
https://github.com/postgresml/pgcat/issues/464#issuecomment-1641430299

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 23:07:13 -07:00
Sebastian Webber
7bdb4e5cd9 Add cmd line parser (#512)
This commit adds the clap library and configures the necessary args to
parse from the command line,  expanding the current option of a single
file and adding support for environment variables.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 13:52:40 -07:00
Sebastian Webber
5d87e3781e push and build only in main and tags (#508)
this commit changes the CI behavior to only build and push when something is committed to main or is a new tag.
2023-07-14 10:30:49 -07:00
dependabot[bot]
3e08c6bd8d chore(deps): bump num_cpus from 1.15.0 to 1.16.0 (#507)
Bumps [num_cpus](https://github.com/seanmonstar/num_cpus) from 1.15.0 to 1.16.0.
- [Release notes](https://github.com/seanmonstar/num_cpus/releases)
- [Changelog](https://github.com/seanmonstar/num_cpus/blob/master/CHANGELOG.md)
- [Commits](https://github.com/seanmonstar/num_cpus/compare/v1.15.0...v1.16.0)

---
updated-dependencies:
- dependency-name: num_cpus
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-14 07:58:11 -07:00
Sebastian Webber
15b6db8e4e add "show help" command (#505)
This commit adds a new function to handle notify and use it
in the SHOW HELP command, which displays the available options
in the admin console.

Also, adding Fabrízio as a co-author for all the help with the
protocol and the help to structure this PR.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
Co-authored-by: Fabrízio de Royes Mello <fabriziomello@gmail.com>
2023-07-13 22:40:04 -07:00
dependabot[bot]
b2e6dfd9bb chore(deps): bump rustls-pemfile from 1.0.2 to 1.0.3 (#504)
Bumps [rustls-pemfile](https://github.com/rustls/pemfile) from 1.0.2 to 1.0.3.
- [Commits](https://github.com/rustls/pemfile/commits)

---
updated-dependencies:
- dependency-name: rustls-pemfile
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 21:41:48 -07:00
Mostafa Abdelraouf
3c9565d351 Add support for tcp_user_timeout (#503)
* Add support for tcp_user_timeout

* option

* duration

* Some()

* docs

* fmt, compile
2023-07-12 11:24:30 -07:00
dependabot[bot]
67579c9af4 chore(deps): bump rustls from 0.21.1 to 0.21.5 (#501)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.1 to 0.21.5.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.1...v/0.21.5)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 05:46:31 -07:00
Cluas
cf7f6f35ab docs: fix general.autoreload description (#491)
* docs: fix autoreload description

Signed-off-by: Cluas <Cluas@live.cn>

* docs: add blank line

Signed-off-by: Cluas <Cluas@live.cn>

---------

Signed-off-by: Cluas <Cluas@live.cn>
2023-07-12 05:42:44 -07:00
Voldemarich
7205537b49 [BUG] Fix binding of NULL value parameters in prepared statements (#496)
Fix binding of NULL value parameters in prepared statements

Co-authored-by: anon <anon@non.existent>
2023-07-10 10:35:43 +02:00
Zain Kabani
1ed6e925ed Fixes the default for round robing in General (#488) 2023-06-23 09:15:44 -07:00
Lev Kokotov
4b78af9676 Implement Close for prepared statements (#482)
* Partial support for Close

* Close

* respect config value

* prepared spec

* Hmm

* Print cache size
2023-06-18 23:02:34 -07:00
Lev Kokotov
73500c0c96 Fix build (#481) 2023-06-17 09:09:54 -07:00
Lev Kokotov
b167de5aa3 fmt (#480) 2023-06-17 08:57:33 -07:00
Juraj Bubniak
473bb3d17d Log not implemented messages as debug in prometheus metrics. (#477) 2023-06-16 18:48:38 -07:00
Lev Kokotov
c7d6273037 Support for prepared statements (#474)
* Start prepared statements

* parse

* Ok

* optional

* dont rewrite anonymous prepared stmts

* Dont rewrite anonymous prep statements

* hm?

* prep statements

* I see!

* comment

* Print config value

* Rewrite bind and add sqlx test

* fmt

* ok

* Fix

* Fix stats

* its late

* clean up PREPARE
2023-06-16 12:57:44 -07:00
Jeff Chen
94c781881f Report min_pool_size correctly (#471) 2023-06-12 09:23:56 -07:00
dependabot[bot]
a8c81e5df6 chore(deps): bump pin-project from 1.0.12 to 1.1.0 (#440)
Bumps [pin-project](https://github.com/taiki-e/pin-project) from 1.0.12 to 1.1.0.
- [Release notes](https://github.com/taiki-e/pin-project/releases)
- [Changelog](https://github.com/taiki-e/pin-project/blob/main/CHANGELOG.md)
- [Commits](https://github.com/taiki-e/pin-project/compare/v1.0.12...v1.1.0)

---
updated-dependencies:
- dependency-name: pin-project
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:51:24 -07:00
dependabot[bot]
1d3746ec9e chore(deps): bump sqlparser from 0.33.0 to 0.34.0 (#448)
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.33.0 to 0.34.0.
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.33.0...v0.34.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:51:05 -07:00
dependabot[bot]
b5489dc1e6 chore(deps): bump regex from 1.8.1 to 1.8.4 (#466)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.1 to 1.8.4.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.8.1...1.8.4)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:50:46 -07:00
dependabot[bot]
557b425fb1 chore(deps): bump log from 0.4.17 to 0.4.19 (#470)
Bumps [log](https://github.com/rust-lang/log) from 0.4.17 to 0.4.19.
- [Release notes](https://github.com/rust-lang/log/releases)
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/log/compare/0.4.17...0.4.19)

---
updated-dependencies:
- dependency-name: log
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-12 00:50:26 -07:00
Zain Kabani
aca9738821 Make queue strategy configurable and default to Fifo (#463)
* Change idle timeout default to 10 minutes

* Revert lifo for now while we investigate connection thrashing issues

* Make queue strategy configurable

* test revert idle time out

* Add pgcat start to python test
2023-06-09 11:35:20 -07:00
Zain Kabani
0bc453a771 Change default server lifetime and bump bb8 version to use LIFO correctly (#462)
Change default server lifetime and idle timeouts and bump bb8 version to use LIFO correctly
2023-05-31 08:25:42 -07:00
Zain Kabani
b67c33b6d0 Use latest bb8 and use Lifo as the queue strategy in the pool (#455)
* Use git bb8

* Use latest bb8 and change pool is use stack
2023-05-28 19:46:13 -07:00
Mostafa Abdelraouf
a8a30ad43b Refactor Pool Stats to be based off of Server/Client stats (#445)
What is wrong
Stats reported by SHOW POOLS seem to be leaking. We see lingering cl_idle , cl_waiting, and similarly for sv_idle , sv_active. We confirmed that these are reporting issues not actual lingering clients.

This behavior is readily reproducible by running

while true; do
    psql "postgres://sharding_user:sharding_user@localhost:6432/sharded_db" -c "SELECT 1" > /dev/null 2>&1  &
done

Why it happens
I wasn't able to get to figure our the reason for the bug but my best guess is that we have race conditions when updating pool-level stats. So even though individual update operations are atomic, we perform a check then update sequence which is not protected by a guard.
https://github.com/postgresml/pgcat/blob/main/src/stats/pool.rs#L174-L179

I am also suspecting that using Relaxed ordering might allow this behavior (I changed all operations to use Ordering::SeqCst but still got lingering clients)

How to fix
Since SHOW POOLS/SHOW SERVER/SHOW CLIENTS all show the current state of the proxy (as opposed to SHOW STATS which show aggregate values), this PR refactors SHOW POOLS to have it construct the results directly from SHOW SERVER and SHOW CLIENT datasets. This reduces the complexity of stat updates and eliminates the need for having locks when updating pool stats as we only care about updating individual client/server states.

This will change the semantics of maxwait, so instead of it holding the maxwait time ever encountered by a client (connected or disconnected), it will only consider connected clients which should be okay given PgCat tends to hold on to client connections more than Pgbouncer.
2023-05-23 08:44:49 -05:00
dependabot[bot]
d63be9b93a chore(deps): bump toml from 0.7.3 to 0.7.4 (#447)
Bumps [toml](https://github.com/toml-rs/toml) from 0.7.3 to 0.7.4.
- [Commits](https://github.com/toml-rs/toml/compare/toml-v0.7.3...toml-v0.7.4)

---
updated-dependencies:
- dependency-name: toml
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-19 08:58:13 -07:00
Lev Kokotov
100778670c Ensure data makes it to the client (#446)
* Ensure data makes it to the client

* flush all buffers
2023-05-18 16:41:22 -07:00
Lev Kokotov
37e3349c24 Optionally clean up server connections (#444)
* Optionally clean up server connections

* move setting to pool

* fix test

* Print setting to screen

* fmt

* Fix pool_settings override in tests
2023-05-18 10:46:55 -07:00
Zain Kabani
7f57a89d75 Fix time based average stats (#442)
* keep track of current stats and zero them after updating averages

* Try tests

* typo

* remove commented test stuff

* Avoid dividing by zero

* Fix test

* refactor, get rid of iterator. do it manually

* trigger build

* Fix
2023-05-17 21:38:10 -07:00
Lev Kokotov
0898461c01 Allow to deploy pools without checking (#438) 2023-05-12 12:48:37 -07:00
Lev Kokotov
52b1b43850 Prewarmer (#435)
* Prewarmer

* hmm

* Tests

* default

* fix test

* Correct configuration

* Added minimal config example

* remove connect_timeout
2023-05-12 09:50:52 -07:00
Zain Kabani
0907f1b77f Improve logging for connection cleanup (#428)
* initial commit

* fix

* fmt
2023-05-11 17:40:10 -07:00
Zain Kabani
73260690b0 Fixes average stats bug (#436)
* Add test

* Fix test

* Add fix
2023-05-11 17:37:58 -07:00
Mostafa Abdelraouf
5056cbe8ed Fix docker-compose dev stack for Apple silicon (#432)
The docker-compose dev setup is broken under Apple silicon, starting the stack fails with the following error. Switching to a different docker image fixes the issue.
2023-05-10 10:24:35 -05:00
Lev Kokotov
571b02e178 Calculate averages correctly and preserve totals like before (#429)
* Reset totals after avg calculation

* like it used to be
2023-05-08 10:06:16 -07:00
Andrew Tanner
159eb89bf0 First try with role reset (#427)
* First try with role rest

* update

* extra line

* Update src/server.rs

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>

* Update tests/ruby/misc_spec.rb

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>

---------

Co-authored-by: Lev Kokotov <levkk@users.noreply.github.com>
2023-05-05 15:31:27 -07:00
Lev Kokotov
389993bf3e Accurate log messages (#425) 2023-05-05 08:27:19 -07:00
Lev Kokotov
ba5243b6dd Optionally validate config on boot (#423) 2023-05-03 17:07:23 -07:00
Lev Kokotov
128ef72911 lowercase config query (#422)
* lowercase config query

* remove debug
2023-05-03 16:47:20 -07:00
Lev Kokotov
811885f464 Actually plugins (#421)
* more plugins

* clean up

* fix tests

* fix flakey test
2023-05-03 16:13:45 -07:00
dependabot[bot]
d5e329fec5 chore(deps): bump regex from 1.8.0 to 1.8.1 (#413)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.0 to 1.8.1.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits/1.8.1)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-03 10:00:05 -07:00
Lev Kokotov
09e54e1175 Plugins! (#420)
* Some queries

* Plugins!!

* cleanup

* actual names

* the actual plugins

* comment

* fix tests

* Tests

* unused errors

* Increase reaper rate to actually enforce settings

* ok
2023-05-03 09:13:05 -07:00
dependabot[bot]
23819c8549 chore(deps): bump rustls from 0.21.0 to 0.21.1 (#419)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.0 to 0.21.1.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Changelog](https://github.com/rustls/rustls/blob/main/RELEASE_NOTES.md)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.0...v/0.21.1)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 07:32:44 -07:00
Jose Fernández
7dfbd993f2 Add dns_cache for server addresses as in pgbouncer (#249)
* Add dns_cache so server addresses are cached and invalidated when DNS changes.

Adds a module to deal with dns_cache feature. It's
main struct is CachedResolver, which is a simple thread safe
hostname <-> Ips cache with the ability to refresh resolutions
every `dns_max_ttl` seconds. This way, a client can check whether its
ip address has changed.

* Allow reloading dns cached

* Add documentation for dns_cached
2023-05-02 10:26:40 +02:00
Lev Kokotov
3601130ba1 Readme update (#418)
* Readme update

* m

* wording
2023-04-30 09:44:25 -07:00
Lev Kokotov
0d504032b2 Server TLS (#417)
* Server TLS

* Finish up TLS

* thats it

* diff

* remove dead code

* maybe?

* dirty shutdown

* skip flakey test

* remove unused error

* fetch config once
2023-04-30 09:41:46 -07:00
Lev Kokotov
4a87b4807d Add more pool settings (#416)
* Add some pool settings

* fmt
2023-04-26 16:33:26 -07:00
Shawn
cb5ff40a59 fix typo (#415)
chore: typo
2023-04-26 08:28:54 -07:00
dependabot[bot]
62b2d994c1 chore(deps): bump regex from 1.7.3 to 1.8.0 (#411)
Bumps [regex](https://github.com/rust-lang/regex) from 1.7.3 to 1.8.0.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-21 06:33:52 -07:00
Lev Kokotov
66805d7e77 README updates (#409)
* Better table

* add image

* promote auth passthrough to stable

* fmt
2023-04-20 07:53:55 -07:00
Lev Kokotov
4ccc1e7fa3 Fix CONFIG (#408)
Fix readme
2023-04-19 07:45:26 -07:00
Lev Kokotov
3dae3d0777 Separate server and client passwords optionally (#407)
* Separate server and user passwords

* config
2023-04-18 09:57:17 -07:00
dependabot[bot]
a18eb42df5 chore(deps): bump serde from 1.0.159 to 1.0.160 (#404)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:25:00 -07:00
dependabot[bot]
6aacf1fa19 chore(deps): bump serde_derive from 1.0.159 to 1.0.160 (#403)
Bumps [serde_derive](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde_derive
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:52 -07:00
dependabot[bot]
8e99e65215 chore(deps): bump sqlparser from 0.32.0 to 0.33.0 (#399)
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.32.0 to 0.33.0.
- [Release notes](https://github.com/sqlparser-rs/sqlparser-rs/releases)
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.32.0...v0.33.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:44 -07:00
dependabot[bot]
5dfbc102a9 chore(deps): bump hyper from 0.14.25 to 0.14.26 (#406)
Bumps [hyper](https://github.com/hyperium/hyper) from 0.14.25 to 0.14.26.
- [Release notes](https://github.com/hyperium/hyper/releases)
- [Changelog](https://github.com/hyperium/hyper/blob/v0.14.26/CHANGELOG.md)
- [Commits](https://github.com/hyperium/hyper/compare/v0.14.25...v0.14.26)

---
updated-dependencies:
- dependency-name: hyper
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:23:52 -07:00
Cluas
bae12fca99 feat: set keepalive for pgcat server itself (#402)
* feat: set keepalive for pgcat server self

* docs: note also set for client
2023-04-12 09:29:43 -07:00
Lev Kokotov
421c5d4b64 Load config on client connect (#401) 2023-04-11 10:32:48 -07:00
Kian-Meng Ang
d568739db9 Fix typos (#398)
Found via `typos --format brief`
2023-04-10 18:37:16 -07:00
Lev Kokotov
692353c839 A couple things (#397)
* Format cleanup

* fmt

* finally
2023-04-10 14:51:01 -07:00
Lev Kokotov
a62f6b0eea Fix port; add user pool mode (#395)
* Fix port; add user pool mode

* will probably break our session/transaction mode tests
2023-04-05 15:06:19 -07:00
dependabot[bot]
89e15f09b5 chore(deps): bump tokio-rustls from 0.23.4 to 0.24.0 (#394)
Bumps [tokio-rustls](https://github.com/tokio-rs/tls) from 0.23.4 to 0.24.0.
- [Release notes](https://github.com/tokio-rs/tls/releases)
- [Commits](https://github.com/tokio-rs/tls/commits)

---
updated-dependencies:
- dependency-name: tokio-rustls
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-02 23:00:09 -07:00
Mostafa Abdelraouf
7ddd23b514 Protocol-level test helpers (#393)
I needed to have granular control over protocol message testing. For example, being able to send protocol messages one-by-one and then be able to inspect the results.

In order to do that, I created this low-level ruby client that can be used to send protocol messages in any order without blocking and also allows inspection of response messages.
2023-04-01 15:27:57 -05:00
dependabot[bot]
faa9c1f64a chore(deps): bump futures from 0.3.27 to 0.3.28 (#392)
Bumps [futures](https://github.com/rust-lang/futures-rs) from 0.3.27 to 0.3.28.
- [Release notes](https://github.com/rust-lang/futures-rs/releases)
- [Changelog](https://github.com/rust-lang/futures-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/futures-rs/compare/0.3.27...0.3.28)

---
updated-dependencies:
- dependency-name: futures
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:35:01 -07:00
dependabot[bot]
9094988491 chore(deps): bump postgres-protocol from 0.6.4 to 0.6.5 (#391)
Bumps [postgres-protocol](https://github.com/sfackler/rust-postgres) from 0.6.4 to 0.6.5.
- [Release notes](https://github.com/sfackler/rust-postgres/releases)
- [Commits](https://github.com/sfackler/rust-postgres/compare/postgres-protocol-v0.6.4...postgres-protocol-v0.6.5)

---
updated-dependencies:
- dependency-name: postgres-protocol
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:34:51 -07:00
63 changed files with 7597 additions and 2411 deletions

View File

@@ -39,7 +39,7 @@ log_client_connections = false
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = true
autoreload = 15000
# TLS
tls_certificate = ".circleci/server.cert"

14
.editorconfig Normal file
View File

@@ -0,0 +1,14 @@
root = true
[*]
trim_trailing_whitespace = true
insert_final_newline = true
[*.rs]
indent_style = space
indent_size = 4
max_line_length = 120
[*.toml]
indent_style = space
indent_size = 2

View File

@@ -1,6 +1,11 @@
name: Build and Push
on: push
on:
push:
branches:
- main
tags:
- v*
env:
registry: ghcr.io

149
CONFIG.md
View File

@@ -1,4 +1,4 @@
# PgCat Configurations
# PgCat Configurations
## `general` Section
### host
@@ -49,6 +49,14 @@ default: 30000 # milliseconds
How long an idle connection with a server is left open (ms).
### server_lifetime
```
path: general.server_lifetime
default: 86400000 # 24 hours
```
Max connection lifetime before it's closed, even if actively used.
### idle_client_in_transaction_timeout
```
path: general.idle_client_in_transaction_timeout
@@ -108,10 +116,10 @@ If we should log client disconnections
### autoreload
```
path: general.autoreload
default: false
default: 15000 # milliseconds
```
When set to true, PgCat reloads configs if it detects a change in the config file.
When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
### worker_threads
```
@@ -143,7 +151,13 @@ path: general.tcp_keepalives_interval
default: 5
```
Number of seconds between keepalive packets.
### tcp_user_timeout
```
path: general.tcp_user_timeout
default: 10000
```
A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
### tls_certificate
```
@@ -152,7 +166,7 @@ default: <UNSET>
example: "server.cert"
```
Path to TLS Certficate file to use for TLS connections
Path to TLS Certificate file to use for TLS connections
### tls_private_key
```
@@ -175,40 +189,26 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT
### admin_password
```
path: general.admin_password
default: <UNSET>
default: "admin_pass"
```
Password to access the virtual administrative database
### auth_query (experimental)
### dns_cache_enabled
```
path: general.auth_query
default: <UNSET>
path: general.dns_cache_enabled
default: false
```
When enabled, ip resolutions for server connections specified using hostnames will be cached
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
old ip connections are closed (gracefully) and new connections will start using new ip.
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user (experimental)
### dns_max_ttl
```
path: general.auth_query_user
default: <UNSET>
path: general.dns_max_ttl
default: 30
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
## `pools.<pool_name>` Section
@@ -243,7 +243,7 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified.
### query_parser_enabled (experimental)
### query_parser_enabled
```
path: pools.<pool_name>.query_parser_enabled
default: true
@@ -264,7 +264,7 @@ If the query parser is enabled and this setting is enabled, the primary will be
load balancing of read queries. Otherwise, the primary will only be used for write
queries. The primary can always be explicitly selected with our custom protocol.
### sharding_key_regex (experimental)
### sharding_key_regex
```
path: pools.<pool_name>.sharding_key_regex
default: <UNSET>
@@ -286,7 +286,40 @@ Current options:
`pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
`sha1`: A hashing function based on SHA1
### automatic_sharding_key (experimental)
### auth_query
```
path: pools.<pool_name>.auth_query
default: <UNSET>
example: "SELECT $1"
```
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user
```
path: pools.<pool_name>.auth_query_user
default: <UNSET>
example: "sharding_user"
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password
```
path: pools.<pool_name>.auth_query_password
default: <UNSET>
example: "sharding_user"
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### automatic_sharding_key
```
path: pools.<pool_name>.automatic_sharding_key
default: <UNSET>
@@ -311,30 +344,6 @@ default: 3000
Connect timeout can be overwritten in the pool
### auth_query (experimental)
```
path: general.auth_query
default: <UNSET>
```
Auth query can be overwritten in the pool
### auth_query_user (experimental)
```
path: general.auth_query_user
default: <UNSET>
```
Auth query user can be overwritten in the pool
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Auth query password can be overwritten in the pool
## `pools.<pool_name>.users.<user_index>` Section
### username
@@ -343,7 +352,8 @@ path: pools.<pool_name>.users.<user_index>.username
default: "sharding_user"
```
Postgresql username
PostgreSQL username used to authenticate the user and connect to the server
if `server_username` is not set.
### password
```
@@ -351,7 +361,26 @@ path: pools.<pool_name>.users.<user_index>.password
default: "sharding_user"
```
Postgresql password
PostgreSQL password used to authenticate the user and connect to the server
if `server_password` is not set.
### server_username
```
path: pools.<pool_name>.users.<user_index>.server_username
default: <UNSET>
example: "another_user"
```
PostgreSQL username used to connect to the server.
### server_password
```
path: pools.<pool_name>.users.<user_index>.server_password
default: <UNSET>
example: "another_password"
```
PostgreSQL password used to connect to the server.
### pool_size
```
@@ -382,7 +411,7 @@ default: [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
Array of servers in the shard, each server entry is an array of `[host, port, role]`
### mirrors (experimental)
### mirrors
```
path: pools.<pool_name>.shards.<shard_index>.mirrors
default: <UNSET>

1220
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.0.0"
version = "1.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -8,18 +8,18 @@ edition = "2021"
tokio = { version = "1", features = ["full"] }
bytes = "1"
md-5 = "0.10"
bb8 = "0.8.0"
bb8 = "0.8.1"
async-trait = "0.1"
rand = "0.8"
chrono = "0.4"
sha-1 = "0.10"
toml = "0.7"
serde = "1"
serde = { version = "1", features = ["derive"] }
serde_derive = "1"
regex = "1"
num_cpus = "1"
once_cell = "1"
sqlparser = "0.32.0"
sqlparser = {version = "0.34", features = ["visitor"] }
log = "0.4"
arc-swap = "1"
env_logger = "0.10"
@@ -28,7 +28,7 @@ hmac = "0.12"
sha2 = "0.10"
base64 = "0.21"
stringprep = "0.1"
tokio-rustls = "0.23"
tokio-rustls = "0.24"
rustls-pemfile = "1"
hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.11.1", features = ["macros"] }
@@ -37,8 +37,19 @@ futures = "0.3"
socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.4"
postgres-protocol = "0.6.5"
fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
serde_json = "1"
itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json"]}
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -18,25 +18,56 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
| SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
| Password rotation | **Experimental** | Allows to rotate passwords without downtime or using third-party tools to manage Postgres authentication. |
## Status
PgCat is stable and used in production to serve hundreds of thousands of queries per second. Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
PgCat is stable and used in production to serve hundreds of thousands of queries per second.
| | |
|-|-|
|<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f"><img src="./images/instacart.webp" height="70" width="auto"></a>|<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second"><img src="./images/postgresml.webp" height="70" width="auto"></a>|
| [Instacart](https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f) | [PostgresML](https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second) |
<table>
<tr>
<td>
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
<img src="./images/instacart.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
<img src="./images/postgresml.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://onesignal.com">
<img src="./images/one_signal.webp" height="70" width="auto">
</a>
</td>
</tr>
<tr>
<td>
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
Instacart
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
PostgresML
</a>
</td>
<td>
OneSignal
</td>
</tr>
</table>
Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
## Deployment
@@ -100,7 +131,7 @@ You can open a Docker development environment where you can debug tests easier.
./dev/script/console
```
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the container (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
## Usage
@@ -245,12 +276,6 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu
Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic.
### Password rotation
Password rotation allows to specify multiple passwords for a user, so they can connect to PgCat with multiple credentials. This allows distributed applications to change their configuration (connection strings) gradually and for PgCat to monitor their progression in admin statistics. Once the new secret is deployed everywhere, the old one can be removed from PgCat.
This also decouples server passwords from client passwords, allowing to change one without necessarily changing the other.
## License
PgCat is free and open source, released under the MIT license.

View File

@@ -1,4 +1,4 @@
FROM rust:bullseye
FROM rust:1.70-bullseye
# Dependencies
RUN apt-get update -y \

View File

@@ -25,7 +25,7 @@ x-common-env-pg:
services:
main:
image: kubernetes/pause
image: gcr.io/google_containers/pause:3.2
ports:
- 6432

View File

@@ -38,9 +38,6 @@ log_client_connections = false
# If we should log client disconnections
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = false
# TLS
# tls_certificate = "server.cert"
# tls_private_key = "server.key"
@@ -76,7 +73,7 @@ query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
# queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true
# So what if you wanted to implement a different hashing function,

BIN
images/one_signal.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

22
pgcat.minimal.toml Normal file
View File

@@ -0,0 +1,22 @@
# This is an example of the most basic config
# that will mimic what PgBouncer does in transaction mode with one server.
[general]
host = "0.0.0.0"
port = 6433
admin_username = "pgcat"
admin_password = "pgcat"
[pools.pgml.users.0]
username = "postgres"
password = "postgres"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.pgml.shards.0]
servers = [
["127.0.0.1", 28815, "primary"]
]
database = "postgres"

View File

@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
# How long an idle connection with a server is left open (ms).
idle_timeout = 30000 # milliseconds
# Max connection lifetime before it's closed, even if actively used.
server_lifetime = 86400000 # 24 hours
# How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout = 0 # milliseconds
@@ -45,7 +48,7 @@ log_client_connections = false
log_client_disconnections = false
# When set to true, PgCat reloads configs if it detects a change in the config file.
autoreload = false
autoreload = 15000
# Number of worker threads the Runtime will use (4 by default).
worker_threads = 5
@@ -57,17 +60,81 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets.
tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections
# Handle prepared statements.
prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = false
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"
# Default plugins that are configured on all pools.
[plugins]
# Prewarmer plugin that runs queries on server startup, before giving the connection
# to the client.
[plugins.prewarmer]
enabled = false
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
# Log all queries to stdout.
[plugins.query_logger]
enabled = false
# Block access to tables that Postgres does not allow us to control.
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
# Intercept user queries and give a fake reply.
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like
@@ -113,6 +180,21 @@ primary_reads_enabled = true
# `sha1`: A hashing function based on SHA1
sharding_function = "pg_bigint_hash"
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
# established using the database configured in the pool. This parameter is inherited by every pool
# and can be redefined in pool configuration.
# auth_query = "SELECT $1"
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_user = "sharding_user"
# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_password = "sharding_user"
# Automatically parse this from queries and route queries to the right shard!
# automatic_sharding_key = "data.id"
@@ -122,26 +204,86 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool
connect_timeout = 3000
# auth_query = "SELECT * FROM public.user_lookup('$1')"
# auth_query_user = "postgres"
# auth_query_password = "postgres"
# When enabled, ip resolutions for server connections specified using hostnames will be cached
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
# old ip connections are closed (gracefully) and new connections will start using new ip.
# dns_cache_enabled = false
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
# so all plugins have to be configured here again.
[pool.sharded_db.plugins]
[pools.sharded_db.plugins.prewarmer]
enabled = true
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
[pools.sharded_db.plugins.query_logger]
enabled = false
[pools.sharded_db.plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
[pools.sharded_db.plugins.intercept]
enabled = true
[pools.sharded_db.plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[pools.sharded_db.plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# User configs are structured as pool.<pool_name>.users.<user_index>
# This secion holds the credentials for users that may connect to this cluster
# This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
# Postgresql username
# PostgreSQL username used to authenticate the user and connect to the server
# if `server_username` is not set.
username = "sharding_user"
# Postgresql password
# PostgreSQL password used to authenticate the user and connect to the server
# if `server_password` is not set.
password = "sharding_user"
# # Passwords the client can use to connect. Useful for password rotations.
# secrets = [ "secret_one", "secret_two" ]
pool_mode = "session"
# PostgreSQL username used to connect to the server.
# server_username = "another_user"
# PostgreSQL password used to connect to the server.
# server_password = "another_password"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
# 0 means it is disabled.
statement_timeout = 0
@@ -186,6 +328,8 @@ sharding_function = "pg_bigint_hash"
username = "simple_user"
password = "simple_user"
pool_size = 5
min_pool_size = 3
server_lifetime = 60000
statement_timeout = 0
[pools.simple_db.shards.0]

View File

@@ -1,4 +1,5 @@
use crate::pool::BanReason;
use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace};
use nix::sys::signal::{self, Signal};
@@ -12,9 +13,9 @@ use tokio::time::Instant;
use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
use crate::ClientServerMap;
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();
@@ -83,6 +84,10 @@ where
shutdown(stream).await
}
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"HELP" => {
trace!("SHOW HELP");
show_help(stream).await
}
"BANS" => {
trace!("SHOW BANS");
show_bans(stream).await
@@ -254,41 +259,51 @@ async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let all_pool_stats = get_pool_stats();
let pool_lookup = PoolStats::construct_pool_lookup();
let mut res = BytesMut::new();
res.put(row_description(&PoolStats::generate_header()));
pool_lookup.iter().for_each(|(_identifier, pool_stats)| {
res.put(data_row(&pool_stats.generate_row()));
});
res.put(command_complete("SHOW"));
let columns = vec![
("database", DataType::Text),
("user", DataType::Text),
("secret", DataType::Text),
("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric),
("sv_active", DataType::Numeric),
("sv_idle", DataType::Numeric),
("sv_used", DataType::Numeric),
("sv_tested", DataType::Numeric),
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Show all available options.
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut res = BytesMut::new();
let detail_msg = vec![
"",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
// "SHOW FDS|SOCKETS|ACTIVE_SOCKETS|LISTS|MEM|STATE", // missing FDS|SOCKETS|ACTIVE_SOCKETS|MEM|STATE
"SHOW LISTS",
// "SHOW DNS_HOSTS|DNS_ZONES", // missing DNS_HOSTS|DNS_ZONES
"SHOW STATS", // missing STATS_TOTALS|STATS_AVERAGES|TOTALS
"SET key = arg",
"RELOAD",
"PAUSE [<db>, <user>]",
"RESUME [<db>, <user>]",
// "DISABLE <db>", // missing
// "ENABLE <db>", // missing
// "RECONNECT [<db>]", missing
// "KILL <db>",
// "SUSPEND",
"SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (_, pool_stats) in all_pool_stats {
let mut row = vec![
pool_stats.database(),
pool_stats.user(),
pool_stats.redacted_secret(),
pool_stats.pool_mode().to_string(),
];
pool_stats.populate_row(&mut row);
pool_stats.clear_maxwait();
res.put(data_row(&row));
}
res.put(notify("Console usage", detail_msg.join("\n\t")));
res.put(command_complete("SHOW"));
// ReadyForQuery
@@ -336,17 +351,17 @@ where
let paused = pool.paused();
res.put(data_row(&vec![
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
pool_config.user.username.to_string(), // force_user
pool_config.user.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool
pool_config.pool_mode.to_string(), // pool_mode
pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
pool_config.user.username.to_string(), // force_user
pool_config.user.pool_size.to_string(), // pool_size
pool_config.user.min_pool_size.unwrap_or(0).to_string(), // min_pool_size
"0".to_string(), // reserve_pool
pool_config.pool_mode.to_string(), // pool_mode
pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
match paused {
// paused
true => "1".to_string(),
@@ -727,6 +742,9 @@ where
("bytes_sent", DataType::Numeric),
("bytes_received", DataType::Numeric),
("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
];
let new_map = get_server_stats();
@@ -750,6 +768,18 @@ where
.duration_since(server.connect_time())
.as_secs()
.to_string(),
server
.prepared_hit_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_miss_count
.load(Ordering::Relaxed)
.to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
.to_string(),
];
res.put(data_row(&row));
@@ -782,7 +812,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user, None) {
match get_pool(database, user) {
Some(pool) => {
pool.pause();
@@ -829,7 +859,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user, None) {
match get_pool(database, user) {
Some(pool) => {
pool.resume();
@@ -897,20 +927,13 @@ where
res.put(row_description(&vec![
("name", DataType::Text),
("pool_mode", DataType::Text),
("secret", DataType::Text),
]));
for (user_pool, pool) in get_all_pools() {
let pool_config = &pool.settings;
let redacted_secret = match user_pool.secret {
Some(secret) => format!("****{}", &secret[secret.len() - 4..]),
None => "<no secret>".to_string(),
};
res.put(data_row(&vec![
user_pool.user.clone(),
pool_config.pool_mode.to_string(),
redacted_secret,
]));
}

View File

@@ -1,452 +0,0 @@
//! Module implementing various client authentication mechanisms.
//!
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
use crate::errors::Error;
use crate::tokio::io::AsyncReadExt;
use crate::{
auth_passthrough::AuthPassthrough,
config::get_config,
messages::{
error_response, md5_hash_password, md5_hash_second_pass, write_all, wrong_password,
},
pool::{get_pool, ConnectionPool},
};
use bytes::{BufMut, BytesMut};
use log::debug;
async fn refetch_auth_hash<S>(
pool: &ConnectionPool,
stream: &mut S,
username: &str,
pool_name: &str,
) -> Result<String, Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let config = get_config();
debug!("Fetching auth hash");
if config.is_auth_query_configured() {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
debug!("Auth query succeeded");
return Ok(hash);
}
} else {
debug!("Auth query not configured on pool");
}
error_response(
stream,
&format!(
"No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username
),
)
.await?;
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
pool_name, username
)))
}
/// Read 'p' message from client.
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
return Err(Error::SocketError(
"Error reading password code from client".to_string(),
))
}
};
if code as char != 'p' {
return Err(Error::SocketError(format!("Expected p, got {}", code)));
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::SocketError(
"Error reading password length from client".to_string(),
))
}
};
let mut response = vec![0; (len - 4) as usize];
// Too short to be a password (null-terminated)
if response.len() < 2 {
return Err(Error::ClientError(format!("Password response too short")));
}
match stream.read_exact(&mut response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::SocketError(
"Error reading password from client".to_string(),
))
}
};
Ok(response.to_vec())
}
/// Make sure the pool we authenticated to has at least one server connection
/// that can serve our request.
async fn validate_pool<W>(
stream: &mut W,
mut pool: ConnectionPool,
username: &str,
pool_name: &str,
) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
if !pool.validated() {
match pool.validate().await {
Ok(_) => Ok(()),
Err(err) => {
error_response(
stream,
&format!("Pool down for database: {}, user: {}", pool_name, username,),
)
.await?;
Err(Error::ClientError(format!("Pool down: {:?}", err)))
}
}
} else {
Ok(())
}
}
/// Clear text authentication.
///
/// The client will send the password in plain text over the wire.
/// To protect against obvious security issues, this is only used over TLS.
///
/// Clear text authentication is used to support zero-downtime password rotation.
/// It allows the client to use multiple passwords when talking to the PgCat
/// while the password is being rotated across multiple app instances.
pub struct ClearText {
username: String,
pool_name: String,
application_name: String,
}
impl ClearText {
/// Create a new ClearText authentication mechanism.
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
ClearText {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
}
}
/// Issue 'R' clear text challenge to client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
debug!("Sending plain challenge");
let mut msg = BytesMut::new();
msg.put_u8(b'R');
msg.put_i32(8);
msg.put_i32(3); // Clear text
write_all(stream, msg).await
}
/// Authenticate client with server password or secret.
pub async fn authenticate<R, W>(
&self,
read: &mut R,
write: &mut W,
) -> Result<Option<String>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let response = response(read).await?;
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
None => match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match pool.settings.user.password {
Some(ref password) => {
if password != &secret {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None)
}
}
None => {
// Server is storing hashes, we can't query it for the plain text password.
error_response(
write,
&format!(
"No server password configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"No server password configured for {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
},
Some(pool) => {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(Some(secret))
}
}
}
}
/// MD5 hash authentication.
///
/// Deprecated, but widely used everywhere, and currently required for poolers
/// to authencticate clients without involving Postgres.
///
/// Admin clients are required to use MD5.
pub struct Md5 {
username: String,
pool_name: String,
application_name: String,
salt: [u8; 4],
admin: bool,
}
impl Md5 {
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
Md5 {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
salt,
admin,
}
}
/// Issue a 'R' MD5 challenge to the client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&self.salt[..]);
write_all(stream, res).await
}
/// Authenticate client with MD5. This is used for both admin and normal users.
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let password_hash = response(read).await?;
if self.admin {
let config = get_config();
// Compare server and client hashes.
let our_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&self.salt,
);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
Ok(())
}
} else {
match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match &pool.settings.user.password {
Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(())
}
}
None => {
if !get_config().is_auth_query_configured() {
error_response(
write,
&format!(
"No password configured and auth_query is not set: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"No password configured and auth_query is not set"
)));
}
debug!("Using auth_query");
// Fetch hash from server
let hash = (*pool.auth_hash.read()).clone();
let hash = match hash {
Some(hash) => {
debug!("Using existing hash: {}", hash);
hash.clone()
}
None => {
debug!("Pool has no hash set, fetching new one");
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
(*pool.auth_hash.write()) = Some(hash.clone());
hash
}
};
let our_hash = md5_hash_second_pass(&hash, &self.salt);
// Compare hashes
if our_hash != password_hash {
debug!("Pool auth query hash did not match, refetching");
// Server hash maybe changed
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash {
debug!("Auth query failed, passwords don't match");
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
(*pool.auth_hash.write()) = Some(hash);
validate_pool(
write,
pool.clone(),
&self.username,
&self.pool_name,
)
.await?;
Ok(())
}
} else {
validate_pool(write, pool.clone(), &self.username, &self.pool_name)
.await?;
Ok(())
}
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)));
}
}
}
}
}

View File

@@ -1,4 +1,5 @@
use crate::errors::Error;
use crate::pool::ConnectionPool;
use crate::server::Server;
use log::debug;
@@ -71,23 +72,32 @@ impl AuthPassthrough {
let auth_user = crate::config::User {
username: self.user.clone(),
password: Some(self.password.clone()),
server_username: None,
server_password: None,
pool_size: 1,
statement_timeout: 0,
secrets: None,
pool_mode: None,
server_lifetime: None,
min_pool_size: None,
};
let user = &address.username;
debug!("Connecting to server to obtain auth hashes.");
debug!("Connecting to server to obtain auth hashes");
let auth_query = self.query.replace("$1", user);
match Server::exec_simple_query(address, &auth_user, &auth_query).await {
Ok(password_data) => {
if password_data.len() == 2 && password_data.first().unwrap() == user {
if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") {
Ok(stripped_hash.to_string())
} else {
if let Some(stripped_hash) = password_data
.last()
.unwrap()
.to_string()
.strip_prefix("md5") {
Ok(stripped_hash.to_string())
}
else {
Err(Error::AuthPassthroughError(
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
))
@@ -99,12 +109,26 @@ impl AuthPassthrough {
))
}
}
Err(err) => {
Err(Error::AuthPassthroughError(
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
user, err)))
}
user, err))
)
}
}
}
}
pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}

View File

@@ -1,10 +1,11 @@
use crate::errors::Error;
use crate::errors::{ClientIdentifier, Error};
use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
@@ -12,17 +13,26 @@ use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
};
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{ClientStats, PoolStats, ServerStats};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;
use tokio_rustls::server::TlsStream;
/// Incrementally count prepared statements
/// to avoid random conflicts in places where the random number generator is weak.
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
Lazy::new(|| Arc::new(AtomicUsize::new(0)));
/// Type of connection received from client.
enum ClientConnectionType {
Startup,
@@ -89,11 +99,11 @@ pub struct Client<S, T> {
/// Application name for this client (defaults to pgcat)
application_name: String,
/// Which secret the user is using to connect, if any.
secret: Option<String>,
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
/// Prepared statements
prepared_statements: HashMap<String, Parse>,
}
/// Client entrypoint.
@@ -204,7 +214,7 @@ pub async fn client_entrypoint(
// Client probably disconnected rejecting our plain text connection.
Ok((ClientConnectionType::Tls, _))
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
format!("Bad postgres client (plain)"),
"Bad postgres client (plain)".into(),
)),
Err(err) => Err(err),
@@ -292,7 +302,7 @@ pub async fn client_entrypoint(
/// Handle the first message the client sends.
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
{
// Get startup message length.
let len = match stream.read_i32().await {
@@ -371,9 +381,9 @@ pub async fn startup_tls(
}
// Bad Postgres client.
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
),
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => {
Err(Error::ProtocolSyncError("Bad postgres client (tls)".into()))
}
Err(err) => Err(err),
}
@@ -381,8 +391,8 @@ pub async fn startup_tls(
impl<S, T> Client<S, T>
where
S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
pub fn is_admin(&self) -> bool {
self.admin
@@ -406,7 +416,7 @@ where
Some(user) => user,
None => {
return Err(Error::ClientError(
"Missing user parameter on client startup".to_string(),
"Missing user parameter on client startup".into(),
))
}
};
@@ -421,6 +431,8 @@ where
None => "pgcat",
};
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == pool_name)
@@ -445,39 +457,207 @@ where
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
let config = get_config();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let secret = if admin {
debug!("Using md5 auth for admin");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else if !config.tls_enabled() {
debug!("Using md5 auth");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else {
debug!("Using plain auth");
let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
};
// Authenticated admin user.
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let config = get_config();
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
warn!("{}", error);
wrong_password(&mut write, username).await?;
return Err(error);
}
(false, generate_server_info_for_admin())
}
// Authenticated normal user.
// Authenticate normal user.
else {
let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
let mut pool = match get_pool(pool_name, username) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientGeneralError(
"Invalid pool name".into(),
client_identifier,
));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}
let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() {
warn!(
"Query auth configured \
but no hash password found \
for pool {}. Will try to refetch it.",
pool_name
);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!("Password for {}, obtained. Updating.", client_identifier);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone());
}
hash = Some(fetched_hash);
}
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
));
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!(
"Invalid password {}, will try to refetch it.",
client_identifier
);
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!(
"Password for {}, changed in server. Updating.",
client_identifier
);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
}
} else {
wrong_password(&mut write, username).await?;
return Err(Error::ClientGeneralError(
"Invalid password".into(),
client_identifier,
));
}
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
// If the pool hasn't been validated yet,
// connect to the servers and figure out what's what.
if !pool.validated() {
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error_response(
&mut write,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Pool down: {:?}", err)));
}
}
}
(transaction_mode, pool.server_info())
};
debug!("Authentication successful");
debug!("Password authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
@@ -485,24 +665,12 @@ where
ready_for_query(&mut write).await?;
trace!("Startup OK");
let pool_stats = match get_pool(pool_name, username, secret.clone()) {
Some(pool) => {
if !admin {
pool.stats
} else {
Arc::new(PoolStats::default())
}
}
None => Arc::new(PoolStats::default()),
};
let stats = Arc::new(ClientStats::new(
process_id,
application_name,
username,
pool_name,
tokio::time::Instant::now(),
pool_stats,
));
Ok(Client {
@@ -525,7 +693,7 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
secret,
prepared_statements: HashMap::new(),
})
}
@@ -560,7 +728,7 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
secret: None,
prepared_statements: HashMap::new(),
})
}
@@ -599,6 +767,13 @@ where
self.stats.register(self.stats.clone());
// Result returned by one of the plugins.
let mut plugin_output = None;
// Prepared statement being executed
let mut prepared_statement = None;
let mut will_prepare = false;
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
@@ -608,22 +783,25 @@ where
self.transaction_mode
);
// Should we rewrite prepared statements and bind messages?
let mut prepared_statements_enabled = get_prepared_statements();
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let message = tokio::select! {
let mut message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
&mut self.write,
"terminating connection due to administrator command"
).await?;
self.stats.disconnect();
return Ok(())
self.stats.disconnect();
return Ok(());
}
// Admin clients ignore shutdown.
@@ -642,28 +820,75 @@ where
// allocate a connection, we wouldn't be able to send back an error message
// to the client so we buffer them and defer the decision to error out or not
// to when we get the S message
'D' | 'E' => {
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;
if let Some(name) = name {
prepared_statement = Some(name);
}
}
self.buffer.put(&message[..]);
continue;
}
'E' => {
self.buffer.put(&message[..]);
continue;
}
'Q' => {
if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
}
}
'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
let _ = query_router.infer(&ast);
}
}
continue;
}
'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
@@ -681,6 +906,19 @@ where
return Ok(());
}
// Close (F)
'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() {
self.prepared_statements.remove(&close.name);
write_all_flush(&mut self.write, &close_complete()).await?;
continue;
}
}
}
_ => (),
}
@@ -691,6 +929,18 @@ where
continue;
}
// Check on plugin results.
match plugin_output {
Some(PluginOutput::Deny(error)) => {
self.buffer.clear();
error_response(&mut self.write, &error).await?;
plugin_output = None;
continue;
}
_ => (),
};
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
@@ -796,11 +1046,26 @@ where
error!("Got Sync message but failed to get a connection from the pool");
self.buffer.clear();
}
error_response(&mut self.write, "could not get connection from the pool")
.await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
error!(
"Could not get connection from pool: \
{{ \
pool_name: {:?}, \
username: {:?}, \
shard: {:?}, \
role: \"{:?}\", \
error: \"{:?}\" \
}}",
self.pool_name,
self.username,
query_router.shard(),
query_router.role(),
err
);
continue;
}
};
@@ -845,7 +1110,58 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let message = match initial_message {
// Only check if we should rewrite prepared statements
// in session mode. In transaction mode, we check at the beginning of
// each transaction.
if !self.transaction_mode {
prepared_statements_enabled = get_prepared_statements();
}
debug!("Prepared statement active: {:?}", prepared_statement);
// We are processing a prepared statement.
if let Some(ref name) = prepared_statement {
debug!("Checking prepared statement is on server");
// Get the prepared statement the server expects to see.
let statement = match self.prepared_statements.get(name) {
Some(statement) => {
debug!("Prepared statement `{}` found in cache", name);
statement
}
None => {
return Err(Error::ClientError(format!(
"prepared statement `{}` not found",
name
)))
}
};
// Since it's already in the buffer, we don't need to prepare it on this server.
if will_prepare {
server.will_prepare(&statement.name);
will_prepare = false;
} else {
// The statement is not prepared on the server, so we need to prepare it.
if server.should_prepare(&statement.name) {
match server.prepare(statement).await {
Ok(_) => (),
Err(err) => {
pool.ban(
&address,
BanReason::MessageSendFailed,
Some(&self.stats),
);
return Err(err);
}
}
}
}
// Done processing the prepared statement.
prepared_statement = None;
}
let mut message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");
@@ -867,11 +1183,25 @@ where
Err(_) => {
// Client idle in transaction timeout
error_response(&mut self.write, "idle transaction timeout").await?;
error!("Client idle in transaction timeout: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\"}}", self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role());
error!(
"Client idle in transaction timeout: \
{{ \
pool_name: {}, \
username: {}, \
shard: {}, \
role: \"{:?}\" \
}}",
self.pool_name,
self.username,
query_router.shard(),
query_router.role()
);
break;
}
}
}
Some(message) => {
initial_message = None;
message
@@ -890,6 +1220,27 @@ where
match code {
// Query
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
}
debug!("Sending query to server");
self.send_and_receive_loop(
@@ -909,7 +1260,7 @@ where
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
if self.transaction_mode && !server.in_copy_mode() {
self.stats.idle();
break;
@@ -923,24 +1274,74 @@ where
self.stats.disconnect();
self.release();
if prepared_statements_enabled {
server.maintain_cache().await?;
}
return Ok(());
}
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
}
}
self.buffer.put(&message[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}
self.buffer.put(&message[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;
if let Some(name) = name {
prepared_statement = Some(name);
}
}
self.buffer.put(&message[..]);
}
// Close the prepared statement.
'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() {
match self.prepared_statements.get(&close.name) {
Some(parse) => {
server.will_close(&parse.generated_name);
}
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
None => (),
};
}
}
self.buffer.put(&message[..]);
}
@@ -955,12 +1356,30 @@ where
'S' => {
debug!("Sending query to server");
match plugin_output {
Some(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
Some(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
_ => (),
};
self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
if first_message_code == 'P' && !prepared_statements_enabled {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
@@ -991,7 +1410,7 @@ where
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
if self.transaction_mode && !server.in_copy_mode() {
break;
}
}
@@ -1026,7 +1445,7 @@ where
.receive_server_message(server, &address, &pool, &self.stats.clone())
.await?;
match write_all_half(&mut self.write, &response).await {
match write_all_flush(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
@@ -1056,7 +1475,13 @@ where
// The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool");
server.checkin_cleanup().await?;
if prepared_statements_enabled {
server.maintain_cache().await?;
}
server.stats().idle();
self.connected_to_server = false;
@@ -1068,7 +1493,7 @@ where
/// Retrieve connection pool, if it exists.
/// Return an error to the client otherwise.
async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
match get_pool(&self.pool_name, &self.username) {
Some(pool) => Ok(pool),
None => {
error_response(
@@ -1088,6 +1513,107 @@ where
}
}
/// Rewrite Parse (F) message to set the prepared statement name to one we control.
/// Save it into the client cache.
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
let parse: Parse = (&message).try_into()?;
let name = parse.name.clone();
// Don't rewrite anonymous prepared statements
if parse.anonymous() {
debug!("Anonymous prepared statement");
return Ok((None, message));
}
let parse = parse.rename();
debug!(
"Renamed prepared statement `{}` to `{}` and saved to cache",
name, parse.name
);
self.prepared_statements.insert(name.clone(), parse.clone());
Ok((Some(name), parse.try_into()?))
}
/// Rewrite the Bind (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_bind(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let bind: Bind = (&message).try_into()?;
let name = bind.prepared_statement.clone();
if bind.anonymous() {
debug!("Anonymous bind message");
return Ok((None, message));
}
match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let bind = bind.reassign(prepared_stmt);
debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);
Ok((Some(name), bind.try_into()?))
}
None => {
debug!("Got bind for unknown prepared statement {:?}", bind);
error_response(
&mut self.write,
&format!(
"prepared statement \"{}\" does not exist",
bind.prepared_statement
),
)
.await?;
Err(Error::ClientError(format!(
"Prepared statement `{}` doesn't exist",
name
)))
}
}
}
/// Rewrite the Describe (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_describe(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let describe: Describe = (&message).try_into()?;
let name = describe.statement_name.clone();
if describe.anonymous() {
debug!("Anonymous describe");
return Ok((None, message));
}
match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let describe = describe.rename(&prepared_stmt.name);
debug!(
"Rewrote describe `{}` to `{}`",
name, describe.statement_name
);
Ok((Some(name), describe.try_into()?))
}
None => {
debug!("Got describe for unknown prepared statement {:?}", describe);
Ok((None, message))
}
}
}
/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();
@@ -1121,7 +1647,7 @@ where
.receive_server_message(server, address, pool, client_stats)
.await?;
match write_all_half(&mut self.write, &response).await {
match write_all_flush(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();

36
src/cmd_args.rs Normal file
View File

@@ -0,0 +1,36 @@
use clap::{Parser, ValueEnum};
use tracing::Level;
/// PgCat: Nextgen PostgreSQL Pooler
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[arg(default_value_t = String::from("pgcat.toml"), env)]
pub config_file: String,
#[arg(short, long, default_value_t = tracing::Level::INFO, env)]
pub log_level: Level,
#[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)]
pub log_format: LogFormat,
#[arg(
short,
long,
default_value_t = false,
env,
help = "disable colors in the log output"
)]
pub no_color: bool,
}
pub fn parse() -> Args {
return Args::parse();
}
#[derive(ValueEnum, Clone, Debug)]
pub enum LogFormat {
Text,
Structured,
Debug,
}

View File

@@ -1,6 +1,6 @@
/// Parse the configuration file.
use arc_swap::ArcSwap;
use log::{error, info, warn};
use log::{error, info};
use once_cell::sync::Lazy;
use regex::Regex;
use serde_derive::{Deserialize, Serialize};
@@ -12,6 +12,7 @@ use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use crate::dns_cache::CachedResolver;
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction;
@@ -121,6 +122,16 @@ impl Default for Address {
}
}
impl std::fmt::Display for Address {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"[address: {}:{}][database: {}][user: {}]",
self.host, self.port, self.database, self.username
)
}
}
// We need to implement PartialEq by ourselves so we skip stats in the comparison
impl PartialEq for Address {
fn eq(&self, other: &Self) -> bool {
@@ -178,29 +189,14 @@ impl Address {
pub struct User {
pub username: String,
pub password: Option<String>,
pub server_username: Option<String>,
pub server_password: Option<String>,
pub pool_size: u32,
pub min_pool_size: Option<u32>,
pub pool_mode: Option<PoolMode>,
pub server_lifetime: Option<u64>,
#[serde(default)] // 0
pub statement_timeout: u64,
pub secrets: Option<Vec<String>>,
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.secrets {
Some(ref secrets) => {
for secret in secrets.iter() {
if secret.len() < 16 {
warn!(
"[user: {}] Secret is too short (less than 16 characters)",
self.username
);
}
}
}
None => (),
}
Ok(())
}
}
impl Default for User {
@@ -208,13 +204,37 @@ impl Default for User {
User {
username: String::from("postgres"),
password: None,
server_username: None,
server_password: None,
pool_size: 15,
min_pool_size: None,
statement_timeout: 0,
secrets: None,
pool_mode: None,
server_lifetime: None,
}
}
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
}
None => (),
};
Ok(())
}
}
/// General configuration.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct General {
@@ -222,9 +242,11 @@ pub struct General {
pub host: String,
#[serde(default = "General::default_port")]
pub port: i16,
pub port: u16,
pub enable_prometheus_exporter: Option<bool>,
#[serde(default = "General::default_prometheus_exporter_port")]
pub prometheus_exporter_port: i16,
#[serde(default = "General::default_connect_timeout")]
@@ -239,6 +261,8 @@ pub struct General {
pub tcp_keepalives_count: u32,
#[serde(default = "General::default_tcp_keepalives_interval")]
pub tcp_keepalives_interval: u64,
#[serde(default = "General::default_tcp_user_timeout")]
pub tcp_user_timeout: u64,
#[serde(default)] // False
pub log_client_connections: bool,
@@ -246,6 +270,12 @@ pub struct General {
#[serde(default)] // False
pub log_client_disconnections: bool,
#[serde(default)] // False
pub dns_cache_enabled: bool,
#[serde(default = "General::default_dns_max_ttl")]
pub dns_max_ttl: u64,
#[serde(default = "General::default_shutdown_timeout")]
pub shutdown_timeout: u64,
@@ -261,20 +291,43 @@ pub struct General {
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
pub idle_client_in_transaction_timeout: u64,
#[serde(default = "General::default_server_lifetime")]
pub server_lifetime: u64,
#[serde(default = "General::default_server_round_robin")] // False
pub server_round_robin: bool,
#[serde(default = "General::default_worker_threads")]
pub worker_threads: usize,
#[serde(default)] // False
pub autoreload: bool,
#[serde(default)] // None
pub autoreload: Option<u64>,
pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>,
#[serde(default)] // false
pub server_tls: bool,
#[serde(default)] // false
pub verify_server_certificate: bool,
pub admin_username: String,
pub admin_password: String,
#[serde(default = "General::default_validate_config")]
pub validate_config: bool,
// Support for auth query
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default)]
pub prepared_statements: bool,
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
}
impl General {
@@ -282,17 +335,21 @@ impl General {
"0.0.0.0".into()
}
pub fn default_port() -> i16 {
pub fn default_port() -> u16 {
5432
}
pub fn default_server_lifetime() -> u64 {
1000 * 60 * 60 // 1 hour
}
pub fn default_connect_timeout() -> u64 {
1000
}
// These keepalive defaults should detect a dead connection within 30 seconds.
// Tokio defaults to disabling keepalives which keeps dead connections around indefinitely.
// This can lead to permenant server pool exhaustion
// This can lead to permanent server pool exhaustion
pub fn default_tcp_keepalives_idle() -> u64 {
5 // 5 seconds
}
@@ -305,14 +362,22 @@ impl General {
5 // 5 seconds
}
pub fn default_tcp_user_timeout() -> u64 {
10000 // 10000 milliseconds
}
pub fn default_idle_timeout() -> u64 {
60000 // 10 minutes
600000 // 10 minutes
}
pub fn default_shutdown_timeout() -> u64 {
60000
}
pub fn default_dns_max_ttl() -> u64 {
30
}
pub fn default_healthcheck_timeout() -> u64 {
1000
}
@@ -332,6 +397,22 @@ impl General {
pub fn default_idle_client_in_transaction_timeout() -> u64 {
0
}
pub fn default_validate_config() -> bool {
true
}
pub fn default_prometheus_exporter_port() -> i16 {
9930
}
pub fn default_server_round_robin() -> bool {
true
}
pub fn default_prepared_statements_cache_size() -> usize {
500
}
}
impl Default for General {
@@ -352,16 +433,26 @@ impl Default for General {
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
log_client_connections: false,
log_client_disconnections: false,
autoreload: false,
autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: Self::default_server_lifetime(),
server_round_robin: Self::default_server_round_robin(),
validate_config: true,
prepared_statements: false,
prepared_statements_cache_size: 500,
}
}
}
@@ -377,6 +468,7 @@ pub enum PoolMode {
#[serde(alias = "session", alias = "Session")]
Session,
}
impl ToString for PoolMode {
fn to_string(&self) -> String {
match *self {
@@ -413,6 +505,7 @@ pub struct Pool {
#[serde(default = "Pool::default_load_balancing_mode")]
pub load_balancing_mode: LoadBalancingMode,
#[serde(default = "Pool::default_default_role")]
pub default_role: String,
#[serde(default)] // False
@@ -421,10 +514,18 @@ pub struct Pool {
#[serde(default)] // False
pub primary_reads_enabled: bool,
/// Maximum time to allow for establishing a new server connection.
pub connect_timeout: Option<u64>,
/// Close idle connections that have been opened for longer than this.
pub idle_timeout: Option<u64>,
/// Close server connections that have been opened for longer than this.
/// Only applied to idle connections. If the connection is actively used for
/// longer than this period, the pool will not interrupt it.
pub server_lifetime: Option<u64>,
#[serde(default = "Pool::default_sharding_function")]
pub sharding_function: ShardingFunction,
#[serde(default = "Pool::default_automatic_sharding_key")]
@@ -438,9 +539,13 @@ pub struct Pool {
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default = "Pool::default_cleanup_server_connections")]
pub cleanup_server_connections: bool,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
// Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
// incompatible to have simple fields in TOML after complex objects. See
// https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903
}
@@ -470,6 +575,18 @@ impl Pool {
None
}
pub fn default_default_role() -> String {
"any".into()
}
pub fn default_sharding_function() -> ShardingFunction {
ShardingFunction::PgBigintHash
}
pub fn default_cleanup_server_connections() -> bool {
true
}
pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() {
"any" => (),
@@ -529,8 +646,8 @@ impl Pool {
None => None,
};
for user in self.users.iter() {
user.1.validate()?;
for (_, user) in &self.users {
user.validate()?;
}
Ok(())
@@ -557,6 +674,9 @@ impl Default for Pool {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
plugins: None,
cleanup_server_connections: true,
}
}
}
@@ -606,7 +726,7 @@ impl Shard {
if primary_count > 1 {
error!(
"Shard {} has more than on primary configured",
"Shard {} has more than one primary configured",
self.database
);
return Err(Error::BadConfig);
@@ -635,6 +755,76 @@ impl Default for Shard {
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
pub prewarmer: Option<Prewarmer>,
}
impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
self.intercept.is_some(),
self.table_access.is_some(),
self.query_logger.is_some(),
self.prewarmer.is_some(),
)
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct QueryLogger {
pub enabled: bool,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
query.query = query.query.to_ascii_lowercase();
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
pub result: Vec<Vec<String>>,
}
impl Query {
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
}
}
}
}
/// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config {
@@ -652,7 +842,13 @@ pub struct Config {
#[serde(default = "Config::default_path")]
pub path: String,
// General and global settings.
pub general: General,
// Plugins that should run in all pools.
pub plugins: Option<Plugins>,
// Connection pools.
pub pools: HashMap<String, Pool>,
}
@@ -682,11 +878,6 @@ impl Config {
}
}
}
/// Checks that we configured TLS.
pub fn tls_enabled(&self) -> bool {
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
}
}
impl Default for Config {
@@ -695,6 +886,7 @@ impl Default for Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
}
}
}
@@ -814,6 +1006,11 @@ impl Config {
);
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
info!(
"Default max server lifetime: {}ms",
self.general.server_lifetime
);
info!("Sever round robin: {}", self.general.server_round_robin);
match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate);
@@ -832,6 +1029,25 @@ impl Config {
info!("TLS support is disabled");
}
};
info!("Server TLS enabled: {}", self.general.server_tls);
info!(
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!(
"Plugins: {}",
match self.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?)
@@ -846,8 +1062,9 @@ impl Config {
.to_string()
);
info!(
"[pool: {}] Pool mode: {:?}",
pool_name, pool_config.pool_mode
"[pool: {}] Default pool mode: {}",
pool_name,
pool_config.pool_mode.to_string()
);
info!(
"[pool: {}] Load Balancing mode: {:?}",
@@ -889,16 +1106,60 @@ impl Config {
pool_name,
pool_config.users.len()
);
info!(
"[pool: {}] Max server lifetime: {}",
pool_name,
match pool_config.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
info!(
"[pool: {}] Cleanup server connections: {}",
pool_name, pool_config.cleanup_server_connections
);
info!(
"[pool: {}] Plugins: {}",
pool_name,
match pool_config.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for user in &pool_config.users {
info!(
"[pool: {}][user: {}] Pool size: {}",
pool_name, user.1.username, user.1.pool_size,
);
info!(
"[pool: {}][user: {}] Minimum pool size: {}",
pool_name,
user.1.username,
user.1.min_pool_size.unwrap_or(0)
);
info!(
"[pool: {}][user: {}] Statement timeout: {}",
pool_name, user.1.username, user.1.statement_timeout
)
);
info!(
"[pool: {}][user: {}] Pool mode: {}",
pool_name,
user.1.username,
match user.1.pool_mode {
Some(pool_mode) => pool_mode.to_string(),
None => pool_config.pool_mode.to_string(),
}
);
info!(
"[pool: {}][user: {}] Max server lifetime: {}",
pool_name,
user.1.username,
match user.1.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
}
}
}
@@ -909,7 +1170,13 @@ impl Config {
&& (self.general.auth_query_user.is_none()
|| self.general.auth_query_password.is_none())
{
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
error!(
"If auth_query is specified, \
you need to provide a value \
for `auth_query_user`, \
`auth_query_password`"
);
return Err(Error::BadConfig);
}
@@ -917,7 +1184,14 @@ impl Config {
if pool.auth_query.is_some()
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
{
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
error!(
"Error in pool {{ {} }}. \
If auth_query is specified, you need \
to provide a value for `auth_query_user`, \
`auth_query_password`",
name
);
return Err(Error::BadConfig);
}
@@ -927,7 +1201,13 @@ impl Config {
|| pool.auth_query_user.is_none())
&& user_data.password.is_none()
{
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
error!(
"Error in pool {{ {} }}. \
You have to specify a user password \
for every pool if auth_query is not specified",
name
);
return Err(Error::BadConfig);
}
}
@@ -980,9 +1260,15 @@ pub fn get_config() -> Config {
}
pub fn get_idle_client_in_transaction_timeout() -> u64 {
(*(*CONFIG.load()))
.general
.idle_client_in_transaction_timeout
CONFIG.load().general.idle_client_in_transaction_timeout
}
pub fn get_prepared_statements() -> bool {
CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
}
/// Parse the configuration file located at the path.
@@ -1025,6 +1311,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config();
match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
@@ -1032,14 +1319,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig);
}
};
let new_config = get_config();
if old_config.pools != new_config.pools {
info!("Pool configuration changed");
match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};
if old_config != new_config {
info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?;
Ok(true)
} else if old_config != new_config {
Ok(true)
} else {
Ok(false)
}

410
src/dns_cache.rs Normal file
View File

@@ -0,0 +1,410 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;
/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
// Ip addressed are returned as a set of addresses
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
}
impl AddrSet {
fn new() -> AddrSet {
AddrSet {
set: HashSet::new(),
}
}
}
impl From<LookupIp> for AddrSet {
fn from(lookup_ip: LookupIp) -> Self {
let mut addr_set = AddrSet::new();
for address in lookup_ip.iter() {
addr_set.set.insert(address);
}
addr_set
}
}
///
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
///
/// The system works as follows:
///
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
/// cache is refreshed.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
/// # })
/// ```
///
/// // Now the ip resolution is stored in local cache and subsequent
/// // calls will be returned from cache. Also, the cache is refreshed
/// // and updated every 10 seconds.
///
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,
// This is the hash that contains the hash.
data: Option<RwLock<HashMap<String, AddrSet>>>,
// The resolver to be used for DNS queries.
resolver: Option<TokioAsyncResolver>,
// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
dns_max_ttl: u64,
/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}
impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}
impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}
impl CachedResolver {
///
/// Returns a new Arc<CachedResolver> based on passed configuration.
/// It also starts the loop that will refresh cache entries.
///
/// # Arguments:
///
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// # })
/// ```
///
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}
Ok(instance)
}
pub fn enabled(&self) -> bool {
self.config.enabled
}
// Schedules the refresher
async fn refresh_dns_entries_loop(&self) {
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
let interval = Duration::from_secs(self.config.dns_max_ttl);
loop {
debug!("Begin refreshing cached DNS addresses.");
// To minimize the time we hold the lock, we first create
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}
for hostname in hostnames.iter() {
let addrset = self
.fetch_from_cache(hostname.as_str())
.expect("Could not obtain expected address from cache, this should not happen");
match resolver.lookup_ip(hostname).await {
Ok(lookup_ip) => {
let new_addrset = AddrSet::from(lookup_ip);
debug!(
"Obtained address for host ({}) -> ({:?})",
hostname, new_addrset
);
if addrset != new_addrset {
debug!(
"Addr changed from {:?} to {:?} updating cache.",
addrset, new_addrset
);
self.store_in_cache(hostname, new_addrset);
}
}
Err(err) => {
error!(
"There was an error trying to resolv {}: ({}).",
hostname, err
);
}
}
}
debug!("Finished refreshing cached DNS addresses.");
sleep(interval).await;
}
}
/// Returns a `AddrSet` given the specified hostname.
///
/// This method first tries to fetch the value from the cache, if it misses
/// then it is resolved and stored in the cache. TTL from records is ignored.
///
/// # Arguments
///
/// * `host` - A string slice referencing the hostname to be resolved.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let response = resolver.lookup_ip("www.google.com.");
/// # })
/// ```
///
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
debug!("Lookup up {} in cache", host);
match self.fetch_from_cache(host) {
Some(addr_set) => {
debug!("Cache hit!");
Ok(addr_set)
}
None => {
debug!("Not found, executing a dns query!");
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
//
// Returns true if the stored host resolution differs from the AddrSet passed.
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
return fetched_addr_set != *addr_set;
}
false
}
// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
match new_resolver {
Ok(ok) => {
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
}
} else {
Ok(())
}
}
// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;
#[tokio::test]
async fn new() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await;
assert!(resolver.is_ok());
}
#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}
#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
}
#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
}
#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
}
#[tokio::test]
// Ok, this test is based on the fact that google does DNS RR
// and does not responds with every available ip everytime, so
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
let resolver_for_refresher = resolver.clone();
let _thread_handle = tokio::task::spawn(async move {
resolver_for_refresher.refresh_dns_entries_loop().await;
});
assert!(!resolver.has_changed(hostname, &addr_set));
}
}

View File

@@ -1,20 +1,130 @@
/// Errors.
//! Errors.
/// Various errors.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
SocketError(String),
ClientSocketError(String, ClientIdentifier),
ClientGeneralError(String, ClientIdentifier),
ClientAuthImpossible(String),
ClientAuthPassthroughError(String, ClientIdentifier),
ClientBadStartup,
ProtocolSyncError(String),
BadQuery(String),
ServerError,
ServerStartupError(String, ServerIdentifier),
ServerAuthError(String, ServerIdentifier),
BadConfig,
AllServersDown,
ClientError(String),
TlsError,
StatementTimeout,
DNSCachedError(String),
ShuttingDown,
ParseBytesError(String),
AuthError(String),
AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
QueryRouterError(String),
}
#[derive(Clone, PartialEq, Debug)]
pub struct ClientIdentifier {
pub application_name: String,
pub username: String,
pub pool_name: String,
}
impl ClientIdentifier {
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
ClientIdentifier {
application_name: application_name.into(),
username: username.into(),
pool_name: pool_name.into(),
}
}
}
impl std::fmt::Display for ClientIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ application_name: {}, username: {}, pool_name: {} }}",
self.application_name, self.username, self.pool_name
)
}
}
#[derive(Clone, PartialEq, Debug)]
pub struct ServerIdentifier {
pub username: String,
pub database: String,
}
impl ServerIdentifier {
pub fn new(username: &str, database: &str) -> ServerIdentifier {
ServerIdentifier {
username: username.into(),
database: database.into(),
}
}
}
impl std::fmt::Display for ServerIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ username: {}, database: {} }}",
self.username, self.database
)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self {
&Error::ClientSocketError(error, client_identifier) => write!(
f,
"Error reading {} from client {}",
error, client_identifier
),
&Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{} {}", error, client_identifier)
}
&Error::ClientAuthImpossible(username) => write!(
f,
"Client auth not possible, \
no cleartext password set for username: {} \
in config and auth passthrough (query_auth) \
is not set up.",
username
),
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
f,
"No cleartext password set, \
and no auth passthrough could not \
obtain the hash from server for {}, \
the error was: {}",
client_identifier, error
),
&Error::ServerStartupError(error, server_identifier) => write!(
f,
"Error reading {} on server startup {}",
error, server_identifier,
),
&Error::ServerAuthError(error, server_identifier) => {
write!(f, "{} for {}", error, server_identifier,)
}
// The rest can use Debug.
err => write!(f, "{:?}", err),
}
}
}
impl From<std::ffi::NulError> for Error {
fn from(err: std::ffi::NulError) -> Self {
Error::QueryRouterError(err.to_string())
}
}

View File

@@ -1,11 +1,18 @@
pub mod admin;
pub mod auth_passthrough;
pub mod client;
pub mod cmd_args;
pub mod config;
pub mod constants;
pub mod dns_cache;
pub mod errors;
pub mod logger;
pub mod messages;
pub mod mirrors;
pub mod multi_logger;
pub mod plugins;
pub mod pool;
pub mod prometheus;
pub mod query_router;
pub mod scram;
pub mod server;
pub mod sharding;

14
src/logger.rs Normal file
View File

@@ -0,0 +1,14 @@
use crate::cmd_args::{Args, LogFormat};
use tracing_subscriber;
pub fn init(args: &Args) {
let trace_sub = tracing_subscriber::fmt()
.with_max_level(args.log_level)
.with_ansi(!args.no_color);
match args.log_format {
LogFormat::Structured => trace_sub.json().init(),
LogFormat::Debug => trace_sub.pretty().init(),
_ => trace_sub.init(),
};
}

View File

@@ -36,6 +36,7 @@ extern crate sqlparser;
extern crate tokio;
extern crate tokio_rustls;
extern crate toml;
extern crate trust_dns_resolver;
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
@@ -60,54 +61,32 @@ use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::broadcast;
mod admin;
mod auth;
mod auth_passthrough;
mod client;
mod config;
mod constants;
mod errors;
mod messages;
mod mirrors;
mod multi_logger;
mod pool;
mod prometheus;
mod query_router;
mod scram;
mod server;
mod sharding;
mod stats;
mod tls;
use crate::config::{get_config, reload_config, VERSION};
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER};
use pgcat::cmd_args;
use pgcat::config::{get_config, reload_config, VERSION};
use pgcat::dns_cache;
use pgcat::logger;
use pgcat::messages::configure_socket;
use pgcat::pool::{ClientServerMap, ConnectionPool};
use pgcat::prometheus::start_metric_server;
use pgcat::stats::{Collector, Reporter, REPORTER};
fn main() -> Result<(), Box<dyn std::error::Error>> {
multi_logger::MultiLogger::init().unwrap();
let args = cmd_args::parse();
logger::init(&args);
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() {
if !pgcat::query_router::QueryRouter::setup() {
error!("Could not setup query router");
std::process::exit(exitcode::CONFIG);
}
let args = std::env::args().collect::<Vec<String>>();
let config_file = if args.len() == 2 {
args[1].to_string()
} else {
String::from("pgcat.toml")
};
// Create a transient runtime for loading the config for the first time.
{
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
runtime.block_on(async {
match config::parse(&config_file).await {
match pgcat::config::parse(args.config_file.as_str()).await {
Ok(_) => (),
Err(err) => {
error!("Config parse error: {:?}", err);
@@ -166,6 +145,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Statistics reporting.
REPORTER.store(Arc::new(Reporter::default()));
// Starts (if enabled) dns cache before pools initialization
match dns_cache::CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache initialization error: {:?}", err),
};
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (),
@@ -180,16 +165,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
stats_collector.collect().await;
});
info!("Config autoreloader: {}", config.general.autoreload);
info!("Config autoreloader: {}", match config.general.autoreload {
Some(interval) => format!("{} ms", interval),
None => "disabled".into(),
});
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
let autoreload_client_server_map = client_server_map.clone();
if let Some(interval) = config.general.autoreload {
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval));
let autoreload_client_server_map = client_server_map.clone();
tokio::task::spawn(async move {
loop {
autoreload_interval.tick().await;
if config.general.autoreload {
info!("Automatically reloading config");
tokio::task::spawn(async move {
loop {
autoreload_interval.tick().await;
debug!("Automatically reloading config");
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
if changed {
@@ -197,8 +185,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
};
}
}
});
});
};
#[cfg(windows)]
let mut term_signal = win_signal::ctrl_close().unwrap();
@@ -283,18 +273,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let drain_tx = drain_tx.clone();
let client_server_map = client_server_map.clone();
let tls_certificate = config.general.tls_certificate.clone();
let tls_certificate = get_config().general.tls_certificate.clone();
configure_socket(&socket);
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
match pgcat::client::client_entrypoint(
socket,
client_server_map,
shutdown_rx,
drain_tx,
admin_only,
tls_certificate.clone(),
tls_certificate,
config.general.log_client_connections,
)
.await
@@ -302,7 +294,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
if config.general.log_client_disconnections {
if get_config().general.log_client_disconnections {
info!(
"Client {:?} disconnected, session duration: {}",
addr,
@@ -319,7 +311,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Err(err) => {
match err {
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
_ => warn!("Client disconnected with error {:?}", err),
}

View File

@@ -7,11 +7,15 @@ use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::client::PREPARED_STATEMENT_COUNTER;
use crate::config::get_config;
use crate::errors::Error;
use std::collections::HashMap;
use std::ffi::CString;
use std::io::{BufRead, Cursor};
use std::mem;
use std::sync::atomic::Ordering;
use std::time::Duration;
/// Postgres data type mappings
@@ -20,6 +24,10 @@ pub enum DataType {
Text,
Int4,
Numeric,
Bool,
Oid,
AnyArray,
Any,
}
impl From<&DataType> for i32 {
@@ -28,6 +36,10 @@ impl From<&DataType> for i32 {
DataType::Text => 25,
DataType::Int4 => 23,
DataType::Numeric => 1700,
DataType::Bool => 16,
DataType::Oid => 26,
DataType::AnyArray => 2277,
DataType::Any => 2276,
}
}
}
@@ -46,6 +58,29 @@ where
write_all(stream, auth_ok).await
}
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated
/// used in query cancellation.
pub async fn backend_key_data<S>(
@@ -93,7 +128,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -127,6 +165,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
@@ -234,8 +287,6 @@ pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8])
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
debug!("Sending hash {} to server", hash);
let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
@@ -383,7 +434,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
let mut res = BytesMut::new();
let mut row_desc = BytesMut::new();
// how many colums we are storing
// how many columns we are storing
row_desc.put_i16(columns.len() as i16);
for (name, data_type) in columns {
@@ -404,6 +455,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
DataType::Text => -1,
DataType::Int4 => 4,
DataType::Numeric => -1,
DataType::Bool => 1,
DataType::Oid => 4,
DataType::AnyArray => -1,
DataType::Any => -1,
};
row_desc.put_i16(type_size);
@@ -442,6 +497,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
res
}
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
if let Some(column) = column {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1 as i32);
}
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
/// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
@@ -452,6 +530,33 @@ pub fn command_complete(command: &str) -> BytesMut {
res
}
/// Create a notify message.
pub fn notify(message: &str, details: String) -> BytesMut {
let mut notify_cmd = BytesMut::new();
notify_cmd.put_slice("SNOTICE\0".as_bytes());
notify_cmd.put_slice("C00000\0".as_bytes());
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
// this extra byte says that is the end of the package
notify_cmd.put_u8(0);
let mut res = BytesMut::new();
res.put_u8(b'N');
res.put_i32(notify_cmd.len() as i32 + 4);
res.put(notify_cmd);
res
}
pub fn flush() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'H');
bytes.put_i32(4);
bytes
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
@@ -484,6 +589,29 @@ where
}
}
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where
@@ -561,6 +689,13 @@ pub fn configure_socket(stream: &TcpStream) {
let sock_ref = SockRef::from(stream);
let conf = get_config();
#[cfg(target_os = "linux")]
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
{
Ok(_) => (),
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
}
match sock_ref.set_keepalive(true) {
Ok(_) => {
match sock_ref.set_tcp_keepalive(
@@ -570,7 +705,7 @@ pub fn configure_socket(stream: &TcpStream) {
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
) {
Ok(_) => (),
Err(err) => error!("Could not configure socket: {}", err),
Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
}
}
Err(err) => error!("Could not configure socket: {}", err),
@@ -592,3 +727,374 @@ impl BytesMutReader for Cursor<&BytesMut> {
}
}
}
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Parse {
code: char,
#[allow(dead_code)]
len: i32,
pub name: String,
pub generated_name: String,
query: String,
num_params: i16,
param_types: Vec<i32>,
}
impl TryFrom<&BytesMut> for Parse {
type Error = Error;
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let name = cursor.read_string()?;
let query = cursor.read_string()?;
let num_params = cursor.get_i16();
let mut param_types = Vec::new();
for _ in 0..num_params {
param_types.push(cursor.get_i32());
}
Ok(Parse {
code,
len,
name,
generated_name: prepared_statement_name(),
query,
num_params,
param_types,
})
}
}
impl TryFrom<Parse> for BytesMut {
type Error = Error;
fn try_from(parse: Parse) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let name_binding = CString::new(parse.name)?;
let name = name_binding.as_bytes_with_nul();
let query_binding = CString::new(parse.query)?;
let query = query_binding.as_bytes_with_nul();
// Recompute length of the message.
let len = 4 // self
+ name.len()
+ query.len()
+ 2
+ 4 * parse.num_params as usize;
bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(name);
bytes.put_slice(query);
bytes.put_i16(parse.num_params);
for param in parse.param_types {
bytes.put_i32(param);
}
Ok(bytes)
}
}
impl TryFrom<&Parse> for BytesMut {
type Error = Error;
fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
parse.clone().try_into()
}
}
impl Parse {
pub fn rename(mut self) -> Self {
self.name = self.generated_name.to_string();
self
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}
/// Bind (B) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Bind {
code: char,
#[allow(dead_code)]
len: i64,
portal: String,
pub prepared_statement: String,
num_param_format_codes: i16,
param_format_codes: Vec<i16>,
num_param_values: i16,
param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16,
result_columns_format_codes: Vec<i16>,
}
impl TryFrom<&BytesMut> for Bind {
type Error = Error;
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16();
let mut param_format_codes = Vec::new();
for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16());
}
let num_param_values = cursor.get_i16();
let mut param_values = Vec::new();
for _ in 0..num_param_values {
let param_len = cursor.get_i32();
// There is special occasion when the parameter is NULL
// In that case, param length is defined as -1
// So if the passed parameter len is over 0
if param_len > 0 {
let mut param = BytesMut::with_capacity(param_len as usize);
param.resize(param_len as usize, b'0');
cursor.copy_to_slice(&mut param);
// we push and the length and the parameter into vector
param_values.push((param_len, param));
} else {
// otherwise we push a tuple with -1 and 0-len BytesMut
// which means that after encountering -1 postgres proceeds
// to processing another parameter
param_values.push((param_len, BytesMut::new()));
}
}
let num_result_column_format_codes = cursor.get_i16();
let mut result_columns_format_codes = Vec::new();
for _ in 0..num_result_column_format_codes {
result_columns_format_codes.push(cursor.get_i16());
}
Ok(Bind {
code,
len: len as i64,
portal,
prepared_statement,
num_param_format_codes,
param_format_codes,
num_param_values,
param_values,
num_result_column_format_codes,
result_columns_format_codes,
})
}
}
impl TryFrom<Bind> for BytesMut {
type Error = Error;
fn try_from(bind: Bind) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let portal_binding = CString::new(bind.portal)?;
let portal = portal_binding.as_bytes_with_nul();
let prepared_statement_binding = CString::new(bind.prepared_statement)?;
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();
let mut len = 4 // self
+ portal.len()
+ prepared_statement.len()
+ 2 // num_param_format_codes
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
+ 2; // num_param_values
for (param_len, _) in &bind.param_values {
len += 4 + *param_len as usize;
}
len += 2; // num_result_column_format_codes
len += 2 * bind.num_result_column_format_codes as usize;
bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32);
bytes.put_slice(portal);
bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code);
}
bytes.put_i16(bind.num_param_values);
for (param_len, param) in bind.param_values {
bytes.put_i32(param_len);
bytes.put_slice(&param);
}
bytes.put_i16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code);
}
Ok(bytes)
}
}
impl Bind {
pub fn reassign(mut self, parse: &Parse) -> Self {
self.prepared_statement = parse.name.clone();
self
}
pub fn anonymous(&self) -> bool {
self.prepared_statement.is_empty()
}
}
#[derive(Debug, Clone)]
pub struct Describe {
code: char,
#[allow(dead_code)]
len: i32,
target: char,
pub statement_name: String,
}
impl TryFrom<&BytesMut> for Describe {
type Error = Error;
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let target = cursor.get_u8() as char;
let statement_name = cursor.read_string()?;
Ok(Describe {
code,
len,
target,
statement_name,
})
}
}
impl TryFrom<Describe> for BytesMut {
type Error = Error;
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
let mut bytes = BytesMut::new();
let statement_name_binding = CString::new(describe.statement_name)?;
let statement_name = statement_name_binding.as_bytes_with_nul();
let len = 4 + 1 + statement_name.len();
bytes.put_u8(describe.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(describe.target as u8);
bytes.put_slice(statement_name);
Ok(bytes)
}
}
impl Describe {
pub fn rename(mut self, name: &str) -> Self {
self.statement_name = name.to_string();
self
}
pub fn anonymous(&self) -> bool {
self.statement_name.is_empty()
}
}
/// Close (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Close {
code: char,
#[allow(dead_code)]
len: i32,
close_type: char,
pub name: String,
}
impl TryFrom<&BytesMut> for Close {
type Error = Error;
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let close_type = cursor.get_u8() as char;
let name = cursor.read_string()?;
Ok(Close {
code,
len,
close_type,
name,
})
}
}
impl TryFrom<Close> for BytesMut {
type Error = Error;
fn try_from(close: Close) -> Result<BytesMut, Error> {
debug!("Close: {:?}", close);
let mut bytes = BytesMut::new();
let name_binding = CString::new(close.name)?;
let name = name_binding.as_bytes_with_nul();
let len = 4 + 1 + name.len();
bytes.put_u8(close.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(close.close_type as u8);
bytes.put_slice(name);
Ok(bytes)
}
}
impl Close {
pub fn new(name: &str) -> Close {
let name = name.to_string();
Close {
code: 'C',
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
close_type: 'S',
name,
}
}
pub fn is_prepared_statement(&self) -> bool {
self.close_type == 'S'
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}
pub fn close_complete() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'3');
bytes.put_i32(4);
bytes
}
pub fn prepared_statement_name() -> String {
format!(
"P_{}",
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
)
}

View File

@@ -7,8 +7,7 @@ use bytes::{Bytes, BytesMut};
use parking_lot::RwLock;
use crate::config::{get_config, Address, Role, User};
use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool};
use crate::stats::PoolStats;
use crate::pool::{ClientServerMap, ServerPool};
use log::{error, info, trace, warn};
use tokio::sync::mpsc::{channel, Receiver, Sender};
@@ -24,7 +23,7 @@ impl MirroredClient {
async fn create_pool(&self) -> Pool<ServerPool> {
let config = get_config();
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
let (connection_timeout, idle_timeout, cfg) =
let (connection_timeout, idle_timeout, _cfg) =
match config.pools.get(&self.address.pool_name) {
Some(cfg) => (
cfg.connect_timeout.unwrap_or(default),
@@ -34,15 +33,14 @@ impl MirroredClient {
None => (default, default, crate::config::Pool::default()),
};
let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
let manager = ServerPool::new(
self.address.clone(),
self.user.clone(),
self.database.as_str(),
ClientServerMap::default(),
Arc::new(PoolStats::new(identifier, cfg.clone())),
Arc::new(RwLock::new(None)),
None,
true,
);
Pool::builder()

View File

@@ -1,80 +0,0 @@
use log::{Level, Log, Metadata, Record, SetLoggerError};
// This is a special kind of logger that allows sending logs to different
// targets depending on the log level.
//
// By default, if nothing is set, it acts as a regular env_log logger,
// it sends everything to standard error.
//
// If the Env variable `STDOUT_LOG` is defined, it will be used for
// configuring the standard out logger.
//
// The behavior is:
// - If it is an error, the message is written to standard error.
// - If it is not, and it matches the log level of the standard output logger (`STDOUT_LOG` env var), it will be send to standard output.
// - If the above is not true, it is sent to the stderr logger that will log it or not depending on the value
// of the RUST_LOG env var.
//
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, erros will go to stderr, warns and infos to stdout and debugs to stderr.
//
pub struct MultiLogger {
stderr_logger: env_logger::Logger,
stdout_logger: env_logger::Logger,
}
impl MultiLogger {
fn new() -> Self {
let stderr_logger = env_logger::builder().format_timestamp_micros().build();
let stdout_logger = env_logger::Builder::from_env("STDOUT_LOG")
.format_timestamp_micros()
.target(env_logger::Target::Stdout)
.build();
Self {
stderr_logger,
stdout_logger,
}
}
pub fn init() -> Result<(), SetLoggerError> {
let logger = Self::new();
log::set_max_level(logger.stderr_logger.filter());
log::set_boxed_logger(Box::new(logger))
}
}
impl Log for MultiLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
self.stderr_logger.enabled(metadata) && self.stdout_logger.enabled(metadata)
}
fn log(&self, record: &Record) {
if record.level() == Level::Error {
self.stderr_logger.log(record);
} else {
if self.stdout_logger.matches(record) {
self.stdout_logger.log(record);
} else {
self.stderr_logger.log(record);
}
}
}
fn flush(&self) {
self.stderr_logger.flush();
self.stdout_logger.flush();
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_init() {
MultiLogger::init().unwrap();
}
}

120
src/plugins/intercept.rs Normal file
View File

@@ -0,0 +1,120 @@
//! The intercept plugin.
//!
//! It intercepts queries and returns fake results.
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use sqlparser::ast::Statement;
use log::debug;
use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
// TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)]
pub struct Rule {
query: String,
schema: Vec<Column>,
result: Vec<Vec<String>>,
}
#[derive(Serialize, Deserialize)]
pub struct Column {
name: String,
data_type: String,
}
/// The intercept plugin.
pub struct Intercept<'a> {
pub enabled: bool,
pub config: &'a InterceptConfig,
}
#[async_trait]
impl<'a> Plugin for Intercept<'a> {
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled || ast.is_empty() {
return Ok(PluginOutput::Allow);
}
let mut config = self.config.clone();
config.substitute(
&query_router.pool_settings().db,
&query_router.pool_settings().user.username,
);
let mut result = BytesMut::new();
for q in ast {
// Normalization
let q = q.to_string().to_ascii_lowercase();
for (_, target) in config.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);
let rd = target
.schema
.iter()
.map(|row| {
let name = &row[0];
let data_type = &row[1];
(
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
"bool" => DataType::Bool,
"int4" => DataType::Int4,
_ => DataType::Any,
},
)
})
.collect::<Vec<(&str, DataType)>>();
result.put(row_description(&rd));
target.result.iter().for_each(|row| {
let row = row
.iter()
.map(|s| {
let s = s.as_str().to_string();
if s == "" {
None
} else {
Some(s)
}
})
.collect::<Vec<Option<String>>>();
result.put(data_row_nullable(&row));
});
result.put(command_complete("SELECT"));
}
}
}
if !result.is_empty() {
result.put_u8(b'Z');
result.put_i32(5);
result.put_u8(b'I');
return Ok(PluginOutput::Intercept(result));
} else {
Ok(PluginOutput::Allow)
}
}
}

44
src/plugins/mod.rs Normal file
View File

@@ -0,0 +1,44 @@
//! The plugin ecosystem.
//!
//! Currently plugins only grant access or deny access to the database for a particual query.
//! Example use cases:
//! - block known bad queries
//! - block access to system catalogs
//! - block dangerous modifications like `DROP TABLE`
//! - etc
//!
pub mod intercept;
pub mod prewarmer;
pub mod query_logger;
pub mod table_access;
use crate::{errors::Error, query_router::QueryRouter};
use async_trait::async_trait;
use bytes::BytesMut;
use sqlparser::ast::Statement;
pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;
#[derive(Clone, Debug, PartialEq)]
pub enum PluginOutput {
Allow,
Deny(String),
Overwrite(Vec<Statement>),
Intercept(BytesMut),
}
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;
// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}

28
src/plugins/prewarmer.rs Normal file
View File

@@ -0,0 +1,28 @@
//! Prewarm new connections before giving them to the client.
use crate::{errors::Error, server::Server};
use log::info;
pub struct Prewarmer<'a> {
pub enabled: bool,
pub server: &'a mut Server,
pub queries: &'a Vec<String>,
}
impl<'a> Prewarmer<'a> {
pub async fn run(&mut self) -> Result<(), Error> {
if !self.enabled {
return Ok(());
}
for query in self.queries {
info!(
"{} Prewarning with query: `{}`",
self.server.address(),
query
);
self.server.query(&query).await?;
}
Ok(())
}
}

View File

@@ -0,0 +1,38 @@
//! Log all queries to stdout (or somewhere else, why not).
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use async_trait::async_trait;
use log::info;
use sqlparser::ast::Statement;
pub struct QueryLogger<'a> {
pub enabled: bool,
pub user: &'a str,
pub db: &'a str,
}
#[async_trait]
impl<'a> Plugin for QueryLogger<'a> {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("[pool: {}][user: {}] {}", self.user, self.db, query);
Ok(PluginOutput::Allow)
}
}

View File

@@ -0,0 +1,59 @@
//! This query router plugin will check if the user can access a particular
//! table as part of their query. If they can't, the query will not be routed.
use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement};
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use log::debug;
use core::ops::ControlFlow;
pub struct TableAccess<'a> {
pub enabled: bool,
pub tables: &'a Vec<String>,
}
#[async_trait]
impl<'a> Plugin for TableAccess<'a> {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let mut found = None;
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else {
ControlFlow::<()>::Continue(())
}
});
if let Some(found) = found {
debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied",
found
)))
} else {
Ok(PluginOutput::Allow)
}
}
}

View File

@@ -1,6 +1,6 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
@@ -10,6 +10,7 @@ use rand::seq::SliceRandom;
use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
@@ -17,13 +18,16 @@ use std::sync::{
use std::time::Instant;
use tokio::sync::Notify;
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::config::{
get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
};
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
use crate::stats::{AddressStats, ClientStats, ServerStats};
pub type ProcessId = i32;
pub type SecretKey = i32;
@@ -59,22 +63,32 @@ pub struct PoolIdentifier {
/// The username the client connects with. Each user gets its own pool.
pub user: String,
/// The client secret (password).
pub secret: Option<String>,
}
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
impl PoolIdentifier {
/// Create a new user/pool identifier.
pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
pub fn new(db: &str, user: &str) -> PoolIdentifier {
PoolIdentifier {
db: db.to_string(),
user: user.to_string(),
secret,
}
}
}
impl Display for PoolIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}@{}", self.user, self.db)
}
}
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
@@ -89,6 +103,7 @@ pub struct PoolSettings {
// Connecting user.
pub user: User,
pub db: String,
// Default server role to connect to.
pub default_role: Option<Role>,
@@ -127,6 +142,9 @@ pub struct PoolSettings {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
/// Plugins
pub plugins: Option<Plugins>,
}
impl Default for PoolSettings {
@@ -136,6 +154,7 @@ impl Default for PoolSettings {
load_balancing_mode: LoadBalancingMode::Random,
shards: 1,
user: User::default(),
db: String::default(),
default_role: None,
query_parser_enabled: false,
primary_reads_enabled: true,
@@ -150,6 +169,7 @@ impl Default for PoolSettings {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
plugins: None,
}
}
}
@@ -189,8 +209,6 @@ pub struct ConnectionPool {
paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>,
pub stats: Arc<PoolStats>,
/// AuthInfo
pub auth_hash: Arc<RwLock<Option<String>>>,
}
@@ -208,241 +226,276 @@ impl ConnectionPool {
// There is one pool per database/user pair.
for user in pool_config.users.values() {
let mut secrets = match &user.secrets {
Some(_) => user
.secrets
.as_ref()
.unwrap()
.iter()
.map(|secret| Some(secret.to_string()))
.collect::<Vec<Option<String>>>(),
None => vec![],
};
let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username);
secrets.push(None);
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
for secret in secrets {
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
None => (),
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
{
if ok != *pool_auth_hash_value {
warn!(
"Hash is not the same across shards \
of the same pool, client auth will \
be done using last obtained hash. \
Server: {}:{}, Database: {}",
server.host, server.port, shard.database,
);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
}
Err(err) => warn!(
"Could not obtain password hashes \
using auth_query config, ignoring. \
Error: {:?}",
err,
),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_auth_hash.clone(),
match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
pool_config.cleanup_server_connections,
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let server_lifetime = match user.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => match pool_config.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => config.general.server_lifetime,
},
};
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
let queue_strategy = match config.general.server_round_robin {
true => QueueStrategy::Fifo,
false => QueueStrategy::Lifo,
};
debug!(
"[pool: {}][user: {}] Pool reaper rate: {}ms",
pool_name, user.username, reaper_rate
);
let pool = Pool::builder()
.max_size(user.pool_size)
.min_idle(user.min_pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
.queue_strategy(queue_strategy)
.test_on_check_out(false);
let pool = if config.general.validate_config {
pool.build(manager).await?
} else {
pool.build_unchecked(manager)
};
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"[pool: {}][user: {}] creating new pool",
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(
connect_timeout,
))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
let pool = ConnectionPool {
databases: shards,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode,
None => pool_config.pool_mode,
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
db: pool_name.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
plugins: match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config {
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
}
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
}
}
@@ -561,6 +614,10 @@ impl ConnectionPool {
});
}
// Indicate we're waiting on a server connection from a pool.
let now = Instant::now();
client_stats.waiting();
while !candidates.is_empty() {
// Get the next candidate
let address = match candidates.pop() {
@@ -579,10 +636,6 @@ impl ConnectionPool {
}
}
// Indicate we're waiting on a server connection from a pool.
let now = Instant::now();
client_stats.waiting();
// Check if we can connect
let mut conn = match self.databases[address.shard][address.address_index]
.get()
@@ -590,7 +643,10 @@ impl ConnectionPool {
{
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
error!(
"Connection checkout error for instance {:?}, error: {:?}",
address, err
);
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
address.stats.error();
client_stats.idle();
@@ -617,7 +673,7 @@ impl ConnectionPool {
.stats()
.checkout_time(checkout_time, client_stats.application_name());
server.stats().active(client_stats.application_name());
client_stats.active();
return Ok((conn, address.clone()));
}
@@ -625,11 +681,19 @@ impl ConnectionPool {
.run_health_check(address, server, now, client_stats)
.await
{
let checkout_time: u64 = now.elapsed().as_micros() as u64;
client_stats.checkout_time(checkout_time);
server
.stats()
.checkout_time(checkout_time, client_stats.application_name());
server.stats().active(client_stats.application_name());
client_stats.active();
return Ok((conn, address.clone()));
} else {
continue;
}
}
client_stats.idle();
Err(Error::AllServersDown)
}
@@ -666,7 +730,7 @@ impl ConnectionPool {
// Health check failed.
Err(err) => {
error!(
"Banning instance {:?} because of failed health check, {:?}",
"Failed health check on instance {:?}, error: {:?}",
address, err
);
}
@@ -675,7 +739,7 @@ impl ConnectionPool {
// Health check timed out.
Err(err) => {
error!(
"Banning instance {:?} because of health check timeout, {:?}",
"Health check timeout on instance {:?}, error: {:?}",
address, err
);
}
@@ -697,13 +761,16 @@ impl ConnectionPool {
return;
}
error!("Banning instance {:?}, reason: {:?}", address, reason);
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
if let Some(client_info) = client_info {
client_info.ban_error();
address.stats.error();
}
guard[address.shard].insert(address.clone(), (reason, now));
}
@@ -860,12 +927,26 @@ impl ConnectionPool {
/// Wrapper for the bb8 connection pool.
pub struct ServerPool {
/// Server address.
address: Address,
/// Server Postgres user.
user: User,
/// Server database.
database: String,
/// Client/server mapping.
client_server_map: ClientServerMap,
stats: Arc<PoolStats>,
/// Server auth hash (for auth passthrough).
auth_hash: Arc<RwLock<Option<String>>>,
/// Server plugins.
plugins: Option<Plugins>,
/// Should we clean up dirty connections before putting them into the pool?
cleanup_connections: bool,
}
impl ServerPool {
@@ -874,16 +955,18 @@ impl ServerPool {
user: User,
database: &str,
client_server_map: ClientServerMap,
stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>,
plugins: Option<Plugins>,
cleanup_connections: bool,
) -> ServerPool {
ServerPool {
address,
user: user.clone(),
database: database.to_string(),
client_server_map,
stats,
auth_hash,
plugins,
cleanup_connections,
}
}
}
@@ -899,7 +982,6 @@ impl ManageConnection for ServerPool {
let stats = Arc::new(ServerStats::new(
self.address.clone(),
self.stats.clone(),
tokio::time::Instant::now(),
));
@@ -913,10 +995,23 @@ impl ManageConnection for ServerPool {
self.client_server_map.clone(),
stats.clone(),
self.auth_hash.clone(),
self.cleanup_connections,
)
.await
{
Ok(conn) => {
Ok(mut conn) => {
if let Some(ref plugins) = self.plugins {
if let Some(ref prewarmer) = plugins.prewarmer {
let mut prewarmer = prewarmer::Prewarmer {
enabled: prewarmer.enabled,
server: &mut conn,
queries: &prewarmer.queries,
};
prewarmer.run().await?;
}
}
stats.idle();
Ok(conn)
}
@@ -939,10 +1034,10 @@ impl ManageConnection for ServerPool {
}
/// Get the connection pool
pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
let identifier = PoolIdentifier::new(db, user, secret);
(*(*POOLS.load())).get(&identifier).cloned()
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
(*(*POOLS.load()))
.get(&PoolIdentifier::new(db, user))
.cloned()
}
/// Get a pointer to all configured pools.

View File

@@ -1,6 +1,6 @@
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use log::{error, info, warn};
use log::{debug, error, info};
use phf::phf_map;
use std::collections::HashMap;
use std::fmt;
@@ -10,7 +10,8 @@ use std::sync::Arc;
use crate::config::Address;
use crate::pool::{get_all_pools, PoolIdentifier};
use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
use crate::stats::pool::PoolStats;
use crate::stats::{get_server_stats, ServerStats};
struct MetricHelpType {
help: &'static str,
@@ -233,10 +234,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("stats_{}", name), value, labels)
}
fn from_pool(pool: &PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
fn from_pool(pool_id: PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new();
labels.insert("pool", pool.db.clone());
labels.insert("user", pool.user.clone());
labels.insert("pool", pool_id.db);
labels.insert("user", pool_id.user);
Self::from_name(&format!("pools_{}", name), value, labels)
}
@@ -274,7 +275,7 @@ fn push_address_stats(lines: &mut Vec<String>) {
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}
@@ -284,18 +285,15 @@ fn push_address_stats(lines: &mut Vec<String>) {
// Adds relevant metrics shown in a SHOW POOLS admin command.
fn push_pool_stats(lines: &mut Vec<String>) {
let pool_stats = get_pool_stats();
for (pool, stats) in pool_stats.iter() {
let stats = &**stats;
let pool_stats = PoolStats::construct_pool_lookup();
for (pool_id, stats) in pool_stats.iter() {
for (name, value) in stats.clone() {
if let Some(prometheus_metric) = PrometheusMetric::<u64>::from_pool(pool, &name, value)
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
{
lines.push(prometheus_metric.to_string());
} else {
warn!(
"Metric {} not implemented for ({},{})",
name, pool.db, pool.user
);
debug!("Metric {} not implemented for ({})", name, *pool_id);
}
}
}
@@ -320,7 +318,7 @@ fn push_database_stats(lines: &mut Vec<String>) {
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}
@@ -366,7 +364,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
{
lines.push(prometheus_metric.to_string());
} else {
warn!("Metric {} not implemented for {}", key, address.name());
debug!("Metric {} not implemented for {}", key, address.name());
}
}
}

View File

@@ -1,4 +1,4 @@
/// Route queries automatically based on explicitely requested
/// Route queries automatically based on explicitly requested
/// or implied query characteristics.
use bytes::{Buf, BytesMut};
use log::{debug, error};
@@ -6,13 +6,16 @@ use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
@@ -129,6 +132,10 @@ impl QueryRouter {
self.pool_settings = pool_settings;
}
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings
}
/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer);
@@ -324,10 +331,7 @@ impl QueryRouter {
Some((command, value))
}
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, message: &BytesMut) -> bool {
debug!("Inferring role");
pub fn parse(message: &BytesMut) -> Result<Vec<Statement>, Error> {
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
@@ -344,37 +348,39 @@ impl QueryRouter {
// Parse (prepared statement)
'P' => {
// Reads statement name
message_cursor.read_string().unwrap();
let _name = message_cursor.read_string().unwrap();
// Reads query string
let query = message_cursor.read_string().unwrap();
debug!("Prepared statement: '{}'", query);
query
}
_ => return false,
_ => return Err(Error::UnsupportedStatement),
};
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => Ok(ast),
Err(err) => {
// SELECT ... FOR UPDATE won't get parsed correctly.
debug!("{}: {}", err, query);
self.active_role = Some(Role::Primary);
return false;
Err(Error::QueryRouterParserError(err.to_string()))
}
};
}
}
debug!("AST: {:?}", ast);
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
debug!("Inferring role");
if ast.is_empty() {
// That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary);
return false;
return Err(Error::QueryRouterParserError("empty query".into()));
}
for q in &ast {
for q in ast {
match q {
// All transactions go to the primary, probably a write.
StartTransaction { .. } => {
@@ -418,7 +424,7 @@ impl QueryRouter {
};
}
true
Ok(())
}
/// Parse the shard number from the Bind message
@@ -783,6 +789,52 @@ impl QueryRouter {
}
}
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
let plugins = match self.pool_settings.plugins {
Some(ref plugins) => plugins,
None => return Ok(PluginOutput::Allow),
};
if let Some(ref query_logger) = plugins.query_logger {
let mut query_logger = QueryLogger {
enabled: query_logger.enabled,
user: &self.pool_settings.user.username,
db: &self.pool_settings.db,
};
let _ = query_logger.run(&self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: &intercept,
};
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
}
}
if let Some(ref table_access) = plugins.table_access {
let mut table_access = TableAccess {
enabled: table_access.enabled,
tables: &table_access.tables,
};
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
}
}
Ok(PluginOutput::Allow)
}
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
let sharder = Sharder::new(
self.pool_settings.shards,
@@ -810,11 +862,22 @@ impl QueryRouter {
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled {
None => self.pool_settings.query_parser_enabled,
Some(value) => value,
};
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
self.pool_settings.query_parser_enabled
);
self.pool_settings.query_parser_enabled
}
debug!("Query parser enabled: {}", enabled);
Some(value) => {
debug!(
"Using query parser override, query_parser_enabled: {}",
value
);
value
}
};
enabled
}
@@ -862,7 +925,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
}
@@ -881,7 +944,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
}
}
@@ -893,7 +956,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
}
@@ -913,7 +976,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);
assert!(qr.infer(&res));
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
@@ -1077,11 +1140,11 @@ mod test {
assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled());
@@ -1113,6 +1176,8 @@ mod test {
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -1142,15 +1207,24 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary);
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica);
assert!(qr.infer(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.role(), Role::Primary);
}
@@ -1177,7 +1251,10 @@ mod test {
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());
@@ -1208,47 +1285,84 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.is_ok());
assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query(
"SELECT * FROM data
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5"
)));
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
// in the query.
assert!(qr.infer(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
// Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr.infer(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
}
@@ -1272,11 +1386,61 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query(stmt)));
assert!(qr
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
.is_ok());
assert_eq!(qr.placeholders.len(), 1);
assert!(qr.infer_shard_from_bind(&bind));
assert_eq!(qr.shard(), 2);
assert!(qr.placeholders.is_empty());
}
#[tokio::test]
async fn test_table_access_plugin() {
use crate::config::{Plugins, TableAccess};
let table_access = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
let plugins = Plugins {
table_access: Some(table_access),
intercept: None,
query_logger: None,
prewarmer: None,
};
QueryRouter::setup();
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(
res,
Ok(PluginOutput::Deny(
"permission for table \"pg_database\" denied".to_string()
))
);
}
#[tokio::test]
async fn test_plugins_disabled_by_defaault() {
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(res, Ok(PluginOutput::Allow));
}
}

View File

@@ -5,24 +5,145 @@ use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::io::Read;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{Address, User};
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
use crate::constants::*;
use crate::errors::Error;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::ServerStats;
use std::io::Write;
use pin_project::pin_project;
#[pin_project(project = SteamInnerProj)]
pub enum StreamInner {
Plain {
#[pin]
stream: TcpStream,
},
Tls {
#[pin]
stream: TlsStream<TcpStream>,
},
}
impl AsyncWrite for StreamInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for StreamInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
impl StreamInner {
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamInner::Tls { stream } => {
let r = stream.get_mut();
let mut w = r.1.writer();
w.write(buf)
}
StreamInner::Plain { stream } => stream.try_write(buf),
}
}
}
#[derive(Copy, Clone)]
struct CleanupState {
/// If server connection requires DISCARD ALL before checkin because of set statement
needs_cleanup_set: bool,
/// If server connection requires DISCARD ALL before checkin because of prepare statement
needs_cleanup_prepare: bool,
}
impl CleanupState {
fn new() -> Self {
CleanupState {
needs_cleanup_set: false,
needs_cleanup_prepare: false,
}
}
fn needs_cleanup(&self) -> bool {
self.needs_cleanup_set || self.needs_cleanup_prepare
}
fn set_true(&mut self) {
self.needs_cleanup_set = true;
self.needs_cleanup_prepare = true;
}
fn reset(&mut self) {
self.needs_cleanup_set = false;
self.needs_cleanup_prepare = false;
}
}
impl std::fmt::Display for CleanupState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"SET: {}, PREPARE: {}",
self.needs_cleanup_set, self.needs_cleanup_prepare
)
}
}
/// Server state.
pub struct Server {
@@ -30,11 +151,8 @@ pub struct Server {
/// port, e.g. 5432, and role, e.g. primary or replica.
address: Address,
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
/// Server TCP connection.
stream: BufStream<StreamInner>,
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
@@ -52,11 +170,14 @@ pub struct Server {
/// Is there more data for the client to read.
data_available: bool,
/// Is the server in copy-in or copy-out modes
in_copy_mode: bool,
/// Is the server broken? We'll remote it from the pool if so.
bad: bool,
/// If server connection requires a DISCARD ALL before checkin
needs_cleanup: bool,
/// If server connection requires DISCARD ALL before checkin
cleanup_state: CleanupState,
/// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap,
@@ -70,10 +191,19 @@ pub struct Server {
/// Application name using the server at the moment.
application_name: String,
// Last time that a successful server send or response happened
/// Last time that a successful server send or response happened
last_activity: SystemTime,
mirror_manager: Option<MirroringManager>,
/// Associated addresses used
addr_set: Option<AddrSet>,
/// Should clean up dirty connections?
cleanup_connections: bool,
/// Prepared statements
prepared_statements: BTreeSet<String>,
}
impl Server {
@@ -86,7 +216,26 @@ impl Server {
client_server_map: ClientServerMap,
stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>,
cleanup_connections: bool,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
// If we are caching addresses and hostname is not an IP
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
addr_set = match cached_resolver.lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
};
let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream,
@@ -98,33 +247,137 @@ impl Server {
)));
}
};
// TCP timeouts.
configure_socket(&stream);
let config = get_config();
let mut stream = if config.general.server_tls {
// Request a TLS connection
ssl_request(&mut stream).await?;
let response = match stream.read_u8().await {
Ok(response) => response as char,
Err(err) => {
return Err(Error::SocketError(format!(
"Server socket error: {:?}",
err
)))
}
};
match response {
// Server supports TLS
'S' => {
debug!("Connecting to server using TLS");
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// Equivalent to sslmode=prefer which is fine most places.
// If you want verify-full, change `verify_server_certificate` to true.
if !config.general.verify_server_certificate {
let mut dangerous = tls_config.dangerous();
dangerous.set_certificate_verifier(Arc::new(
crate::tls::NoCertificateVerification {},
));
}
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = match connector
.connect(address.host.as_str().try_into().unwrap(), stream)
.await
{
Ok(stream) => stream,
Err(err) => {
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
}
};
StreamInner::Tls { stream }
}
// Server does not support TLS
'N' => StreamInner::Plain { stream },
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
StreamInner::Plain { stream }
};
// let (read, write) = split(stream);
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
trace!("Sending StartupMessage");
// StartupMessage
startup(&mut stream, &user.username, database).await?;
let username = match user.server_username {
Some(ref server_username) => server_username,
None => &user.username,
};
let password = match user.server_password {
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
};
startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database);
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = None;
if let Some(password) = &user.password.clone() {
scram = Some(ScramSha256::new(password));
}
let mut scram: Option<ScramSha256> = match password {
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"message code".into(),
server_identifier,
))
}
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"message len".into(),
server_identifier,
))
}
};
trace!("Message: {}", code);
@@ -135,7 +388,12 @@ impl Server {
// Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"auth code".into(),
server_identifier,
))
}
};
trace!("Auth: {}", auth_code);
@@ -148,14 +406,18 @@ impl Server {
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"salt".into(),
server_identifier,
))
}
};
match &user.password {
match password {
// Using plaintext password
Some(password) => {
md5_password(&mut stream, &user.username, password, &salt[..])
.await?
md5_password(&mut stream, username, password, &salt[..]).await?
}
// Using auth passthrough, in this case we should already have a
@@ -171,8 +433,12 @@ impl Server {
&salt[..],
)
.await?,
None =>
return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database)))
None => return Err(
Error::ServerAuthError(
"Auth passthrough (auth_query) failed and no user password is set in cleartext".into(),
server_identifier
)
),
}
}
}
@@ -182,21 +448,33 @@ impl Server {
SASL => {
if scram.is_none() {
return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database)));
return Err(Error::ServerAuthError(
"SASL auth required and no password specified. \
Auth passthrough (auth_query) method is currently \
unsupported for SASL auth"
.into(),
server_identifier,
));
}
debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl message".into(),
server_identifier,
))
}
};
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 {
if sasl_type.contains(SCRAM_SHA_256) {
debug!("Using {}", SCRAM_SHA_256);
// Generate client message.
@@ -219,7 +497,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
write_all_flush(&mut stream, &res).await?;
} else {
error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError);
@@ -233,7 +511,12 @@ impl Server {
match stream.read_exact(&mut sasl_data).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl cont message".into(),
server_identifier,
))
}
};
let msg = BytesMut::from(&sasl_data[..]);
@@ -245,7 +528,7 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
write_all_flush(&mut stream, &res).await?;
}
SASL_FINAL => {
@@ -254,7 +537,12 @@ impl Server {
let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl final message".into(),
server_identifier,
))
}
};
match scram
@@ -284,7 +572,12 @@ impl Server {
'E' => {
let error_code = match stream.read_u8().await {
Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"error code message".into(),
server_identifier,
))
}
};
trace!("Error: {}", error_code);
@@ -300,7 +593,12 @@ impl Server {
match stream.read_exact(&mut error).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"error message".into(),
server_identifier,
))
}
};
// TODO: the error message contains multiple fields; we can decode them and
@@ -319,7 +617,12 @@ impl Server {
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"parameter status message".into(),
server_identifier,
))
}
};
// Save the parameter so we can pass it to the client later.
@@ -336,12 +639,22 @@ impl Server {
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"process id message".into(),
server_identifier,
))
}
};
secret_key = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"secret key message".into(),
server_identifier,
))
}
};
}
@@ -351,24 +664,28 @@ impl Server {
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"transaction status message".into(),
server_identifier,
))
}
};
let (read, write) = stream.into_split();
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write,
stream: BufStream::new(stream),
buffer: BytesMut::with_capacity(8196),
server_info,
process_id,
secret_key,
in_transaction: false,
in_copy_mode: false,
data_available: false,
bad: false,
needs_cleanup: false,
cleanup_state: CleanupState::new(),
client_server_map,
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(),
stats,
application_name: String::new(),
@@ -381,6 +698,8 @@ impl Server {
address.mirrors.clone(),
)),
},
cleanup_connections,
prepared_statements: BTreeSet::new(),
};
server.set_name("pgcat").await?;
@@ -413,7 +732,7 @@ impl Server {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError(format!("Error reading cancel message")));
return Err(Error::SocketError("Error reading cancel message".into()));
}
};
configure_socket(&stream);
@@ -426,7 +745,7 @@ impl Server {
bytes.put_i32(process_id);
bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await
write_all_flush(&mut stream, &bytes).await
}
/// Send messages to the server from the client.
@@ -434,14 +753,17 @@ impl Server {
self.mirror_send(messages);
self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await {
match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => {
// Successfully sent to server
self.last_activity = SystemTime::now();
Ok(())
}
Err(err) => {
error!("Terminating server because of: {:?}", err);
error!(
"Terminating server {:?} because of: {:?}",
self.address, err
);
self.bad = true;
Err(err)
}
@@ -453,10 +775,13 @@ impl Server {
/// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop {
let mut message = match read_message(&mut self.read).await {
let mut message = match read_message(&mut self.stream).await {
Ok(message) => message,
Err(err) => {
error!("Terminating server because of: {:?}", err);
error!(
"Terminating server {:?} because of: {:?}",
self.address, err
);
self.bad = true;
return Err(err);
}
@@ -507,8 +832,19 @@ impl Server {
break;
}
// ErrorResponse
'E' => {
if self.in_copy_mode {
self.in_copy_mode = false;
}
}
// CommandComplete
'C' => {
if self.in_copy_mode {
self.in_copy_mode = false;
}
let mut command_tag = String::new();
match message.reader().read_to_string(&mut command_tag) {
Ok(_) => {
@@ -523,12 +859,12 @@ impl Server {
// This will reduce amount of discard statements sent
if !self.in_transaction {
debug!("Server connection marked for clean up");
self.needs_cleanup = true;
self.cleanup_state.needs_cleanup_set = true;
}
}
"PREPARE\0" => {
debug!("Server connection marked for clean up");
self.needs_cleanup = true;
self.cleanup_state.needs_cleanup_prepare = true;
}
_ => (),
}
@@ -552,10 +888,14 @@ impl Server {
}
// CopyInResponse: copy is starting from client to server.
'G' => break,
'G' => {
self.in_copy_mode = true;
break;
}
// CopyOutResponse: copy is starting from the server to the client.
'H' => {
self.in_copy_mode = true;
self.data_available = true;
break;
}
@@ -593,6 +933,115 @@ impl Server {
Ok(bytes)
}
/// Add the prepared statement to being tracked by this server.
/// The client is processing data that will create a prepared statement on this server.
pub fn will_prepare(&mut self, name: &str) {
debug!("Will prepare `{}`", name);
self.prepared_statements.insert(name.to_string());
self.stats.prepared_cache_add();
}
/// Check if we should prepare a statement on the server.
pub fn should_prepare(&self, name: &str) -> bool {
let should_prepare = !self.prepared_statements.contains(name);
debug!("Should prepare `{}`: {}", name, should_prepare);
if should_prepare {
self.stats.prepared_cache_miss();
} else {
self.stats.prepared_cache_hit();
}
should_prepare
}
/// Create a prepared statement on the server.
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
debug!("Preparing `{}`", parse.name);
let bytes: BytesMut = parse.try_into()?;
self.send(&bytes).await?;
self.send(&flush()).await?;
// Read and discard ParseComplete (B)
match read_message(&mut self.stream).await {
Ok(_) => (),
Err(err) => {
self.bad = true;
return Err(err);
}
}
self.prepared_statements.insert(parse.name.to_string());
self.stats.prepared_cache_add();
debug!("Prepared `{}`", parse.name);
Ok(())
}
/// Maintain adequate cache size on the server.
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
debug!("Cache maintenance run");
let max_cache_size = get_prepared_statements_cache_size();
let mut names = Vec::new();
while self.prepared_statements.len() >= max_cache_size {
// The prepared statmeents are alphanumerically sorted by the BTree.
// FIFO.
if let Some(name) = self.prepared_statements.pop_last() {
names.push(name);
}
}
self.deallocate(names).await?;
Ok(())
}
/// Remove the prepared statement from being tracked by this server.
/// The client is processing data that will cause the server to close the prepared statement.
pub fn will_close(&mut self, name: &str) {
debug!("Will close `{}`", name);
self.prepared_statements.remove(name);
}
/// Close a prepared statement on the server.
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
for name in &names {
debug!("Deallocating prepared statement `{}`", name);
let close = Close::new(name);
let bytes: BytesMut = close.try_into()?;
self.send(&bytes).await?;
}
self.send(&flush()).await?;
// Read and discard CloseComplete (3)
for name in &names {
match read_message(&mut self.stream).await {
Ok(_) => {
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
debug!("Closed `{}`", name);
}
Err(err) => {
self.bad = true;
return Err(err);
}
};
}
Ok(())
}
/// If the server is still inside a transaction.
/// If the client disconnects while the server is in a transaction, we will clean it up.
pub fn in_transaction(&self) -> bool {
@@ -600,6 +1049,10 @@ impl Server {
self.in_transaction
}
pub fn in_copy_mode(&self) -> bool {
self.in_copy_mode
}
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
/// The client is responsible to call `self.recv()` while this method returns true.
pub fn is_data_available(&self) -> bool {
@@ -609,7 +1062,23 @@ impl Server {
/// Server & client are out of sync, we must discard this connection.
/// This happens with clients that misbehave.
pub fn is_bad(&self) -> bool {
self.bad
if self.bad {
return self.bad;
};
let cached_resolver = CACHED_RESOLVER.load();
if cached_resolver.enabled() {
if let Some(addr_set) = &self.addr_set {
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
warn!(
"DNS changed for {}, it was {:?}. Dropping server connection.",
self.address.host.as_str(),
addr_set
);
return true;
}
}
}
false
}
/// Get server startup information to forward it to the client.
@@ -642,6 +1111,8 @@ impl Server {
/// It will use the simple query protocol.
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
debug!("Running `{}` on server {:?}", query, self.address);
let query = simple_query(query);
self.send(&query).await?;
@@ -674,10 +1145,15 @@ impl Server {
// to avoid leaking state between clients. For performance reasons we only
// send `DISCARD ALL` if we think the session is altered instead of just sending
// it before each checkin.
if self.needs_cleanup {
warn!("Server returned with session state altered, discarding state");
if self.cleanup_state.needs_cleanup() && self.cleanup_connections {
warn!("Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name);
self.query("DISCARD ALL").await?;
self.needs_cleanup = false;
self.query("RESET ROLE").await?;
self.cleanup_state.reset();
}
if self.in_copy_mode() {
warn!("Server returned while still in copy-mode");
}
Ok(())
@@ -689,12 +1165,12 @@ impl Server {
self.application_name = name.to_string();
// We don't want `SET application_name` to mark the server connection
// as needing cleanup
let needs_cleanup_before = self.needs_cleanup;
let needs_cleanup_before = self.cleanup_state;
let result = Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?);
self.needs_cleanup = needs_cleanup_before;
self.cleanup_state = needs_cleanup_before;
result
} else {
Ok(())
@@ -719,7 +1195,7 @@ impl Server {
// Marks a connection as needing DISCARD ALL at checkin
pub fn mark_dirty(&mut self) {
self.needs_cleanup = true;
self.cleanup_state.set_true();
}
pub fn mirror_send(&mut self, bytes: &BytesMut) {
@@ -753,11 +1229,10 @@ impl Server {
client_server_map,
Arc::new(ServerStats::default()),
Arc::new(RwLock::new(None)),
true,
)
.await?;
debug!("Connected!, sending query: {}", query);
debug!("Connected!, sending query.");
server.send(&simple_query(query)).await?;
let mut message = server.recv().await?;
@@ -766,8 +1241,6 @@ impl Server {
}
async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> {
debug!("Parsing query message");
let mut pair = Vec::<String>::new();
match message::backend::Message::parse(message) {
Ok(Some(message::backend::Message::RowDescription(_description))) => {}
@@ -837,9 +1310,6 @@ async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Erro
}
};
}
debug!("Got auth hash successfully");
Ok(pair)
}
@@ -853,23 +1323,27 @@ impl Drop for Server {
// Update statistics
self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4);
let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'X');
bytes.put_i32(4);
match self.write.try_write(&bytes) {
Ok(_) => (),
Err(_) => debug!("Dirty shutdown"),
match self.stream.get_mut().try_write(&bytes) {
Ok(5) => (),
_ => debug!("Dirty shutdown"),
};
// Should not matter.
self.bad = true;
let now = chrono::offset::Utc::now().naive_utc();
let duration = now - self.connected_at;
let message = if self.bad {
"Server connection terminated"
} else {
"Server connection closed"
};
info!(
"Server connection closed {:?}, session duration: {}",
"{} {:?}, session duration: {}",
message,
self.address,
crate::format_duration(&duration)
);

View File

@@ -1,4 +1,3 @@
use crate::pool::PoolIdentifier;
/// Statistics and reporting.
use arc_swap::ArcSwap;
@@ -16,13 +15,11 @@ pub mod pool;
pub mod server;
pub use address::AddressStats;
pub use client::{ClientState, ClientStats};
pub use pool::PoolStats;
pub use server::{ServerState, ServerStats};
/// Convenience types for various stats
type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>;
type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>;
type PoolStatsLookup = HashMap<PoolIdentifier, Arc<PoolStats>>;
/// Stats for individual client connections
/// Used in SHOW CLIENTS.
@@ -34,11 +31,6 @@ static CLIENT_STATS: Lazy<Arc<RwLock<ClientStatesLookup>>> =
static SERVER_STATS: Lazy<Arc<RwLock<ServerStatesLookup>>> =
Lazy::new(|| Arc::new(RwLock::new(ServerStatesLookup::default())));
/// Aggregate stats for each pool (a pool is identified by database name and username)
/// Used in SHOW POOLS.
static POOL_STATS: Lazy<Arc<RwLock<PoolStatsLookup>>> =
Lazy::new(|| Arc::new(RwLock::new(PoolStatsLookup::default())));
/// The statistics reporter. An instance is given to each possible source of statistics,
/// e.g. client stats, server stats, connection pool stats.
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
@@ -66,7 +58,7 @@ impl Reporter {
CLIENT_STATS.write().insert(client_id, stats);
}
/// Reports a client is disconecting from the pooler.
/// Reports a client is disconnecting from the pooler.
fn client_disconnecting(&self, client_id: i32) {
CLIENT_STATS.write().remove(&client_id);
}
@@ -76,15 +68,10 @@ impl Reporter {
fn server_register(&self, server_id: i32, stats: Arc<ServerStats>) {
SERVER_STATS.write().insert(server_id, stats);
}
/// Reports a server connection is disconecting from the pooler.
/// Reports a server connection is disconnecting from the pooler.
fn server_disconnecting(&self, server_id: i32) {
SERVER_STATS.write().remove(&server_id);
}
/// Register a pool with the stats system.
fn pool_register(&self, identifier: PoolIdentifier, stats: Arc<PoolStats>) {
POOL_STATS.write().insert(identifier, stats);
}
}
/// The statistics collector which used for calculating averages
@@ -105,8 +92,20 @@ impl Collector {
loop {
interval.tick().await;
for stats in SERVER_STATS.read().values() {
stats.address_stats().update_averages();
// Hold read lock for duration of update to retain all server stats
let server_stats = SERVER_STATS.read();
for stats in server_stats.values() {
if !stats.check_address_stat_average_is_updated_status() {
stats.address_stats().update_averages();
stats.address_stats().reset_current_counts();
stats.set_address_stat_average_is_updated_status(true);
}
}
// Reset to false for next update
for stats in server_stats.values() {
stats.set_address_stat_average_is_updated_status(false);
}
}
});
@@ -125,12 +124,6 @@ pub fn get_server_stats() -> ServerStatesLookup {
SERVER_STATS.read().clone()
}
/// Get a snapshot of pool statistics.
/// by the `Collector`.
pub fn get_pool_stats() -> PoolStatsLookup {
POOL_STATS.read().clone()
}
/// Get the statistics reporter used to update stats across the pools/clients.
pub fn get_reporter() -> Reporter {
(*(*REPORTER.load())).clone()

View File

@@ -1,26 +1,29 @@
use log::warn;
use std::sync::atomic::*;
use std::sync::Arc;
#[derive(Debug, Clone, Default)]
struct AddressStatFields {
xact_count: Arc<AtomicU64>,
query_count: Arc<AtomicU64>,
bytes_received: Arc<AtomicU64>,
bytes_sent: Arc<AtomicU64>,
xact_time: Arc<AtomicU64>,
query_time: Arc<AtomicU64>,
wait_time: Arc<AtomicU64>,
errors: Arc<AtomicU64>,
}
/// Internal address stats
#[derive(Debug, Clone, Default)]
pub struct AddressStats {
pub total_xact_count: Arc<AtomicU64>,
pub total_query_count: Arc<AtomicU64>,
pub total_received: Arc<AtomicU64>,
pub total_sent: Arc<AtomicU64>,
pub total_xact_time: Arc<AtomicU64>,
pub total_query_time: Arc<AtomicU64>,
pub total_wait_time: Arc<AtomicU64>,
pub total_errors: Arc<AtomicU64>,
pub avg_query_count: Arc<AtomicU64>,
pub avg_query_time: Arc<AtomicU64>,
pub avg_recv: Arc<AtomicU64>,
pub avg_sent: Arc<AtomicU64>,
pub avg_errors: Arc<AtomicU64>,
pub avg_xact_time: Arc<AtomicU64>,
pub avg_xact_count: Arc<AtomicU64>,
pub avg_wait_time: Arc<AtomicU64>,
total: AddressStatFields,
current: AddressStatFields,
averages: AddressStatFields,
// Determines if the averages have been updated since the last time they were reported
pub averages_updated: Arc<AtomicBool>,
}
impl IntoIterator for AddressStats {
@@ -31,67 +34,67 @@ impl IntoIterator for AddressStats {
vec![
(
"total_xact_count".to_string(),
self.total_xact_count.load(Ordering::Relaxed),
self.total.xact_count.load(Ordering::Relaxed),
),
(
"total_query_count".to_string(),
self.total_query_count.load(Ordering::Relaxed),
self.total.query_count.load(Ordering::Relaxed),
),
(
"total_received".to_string(),
self.total_received.load(Ordering::Relaxed),
self.total.bytes_received.load(Ordering::Relaxed),
),
(
"total_sent".to_string(),
self.total_sent.load(Ordering::Relaxed),
self.total.bytes_sent.load(Ordering::Relaxed),
),
(
"total_xact_time".to_string(),
self.total_xact_time.load(Ordering::Relaxed),
self.total.xact_time.load(Ordering::Relaxed),
),
(
"total_query_time".to_string(),
self.total_query_time.load(Ordering::Relaxed),
self.total.query_time.load(Ordering::Relaxed),
),
(
"total_wait_time".to_string(),
self.total_wait_time.load(Ordering::Relaxed),
self.total.wait_time.load(Ordering::Relaxed),
),
(
"total_errors".to_string(),
self.total_errors.load(Ordering::Relaxed),
self.total.errors.load(Ordering::Relaxed),
),
(
"avg_xact_count".to_string(),
self.avg_xact_count.load(Ordering::Relaxed),
self.averages.xact_count.load(Ordering::Relaxed),
),
(
"avg_query_count".to_string(),
self.avg_query_count.load(Ordering::Relaxed),
self.averages.query_count.load(Ordering::Relaxed),
),
(
"avg_recv".to_string(),
self.avg_recv.load(Ordering::Relaxed),
self.averages.bytes_received.load(Ordering::Relaxed),
),
(
"avg_sent".to_string(),
self.avg_sent.load(Ordering::Relaxed),
self.averages.bytes_sent.load(Ordering::Relaxed),
),
(
"avg_errors".to_string(),
self.avg_errors.load(Ordering::Relaxed),
self.averages.errors.load(Ordering::Relaxed),
),
(
"avg_xact_time".to_string(),
self.avg_xact_time.load(Ordering::Relaxed),
self.averages.xact_time.load(Ordering::Relaxed),
),
(
"avg_query_time".to_string(),
self.avg_query_time.load(Ordering::Relaxed),
self.averages.query_time.load(Ordering::Relaxed),
),
(
"avg_wait_time".to_string(),
self.avg_wait_time.load(Ordering::Relaxed),
self.averages.wait_time.load(Ordering::Relaxed),
),
]
.into_iter()
@@ -99,22 +102,120 @@ impl IntoIterator for AddressStats {
}
impl AddressStats {
pub fn xact_count_add(&self) {
self.total.xact_count.fetch_add(1, Ordering::Relaxed);
self.current.xact_count.fetch_add(1, Ordering::Relaxed);
}
pub fn query_count_add(&self) {
self.total.query_count.fetch_add(1, Ordering::Relaxed);
self.current.query_count.fetch_add(1, Ordering::Relaxed);
}
pub fn bytes_received_add(&self, bytes: u64) {
self.total
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
self.current
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn bytes_sent_add(&self, bytes: u64) {
self.total.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
self.current.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
}
pub fn xact_time_add(&self, time: u64) {
self.total.xact_time.fetch_add(time, Ordering::Relaxed);
self.current.xact_time.fetch_add(time, Ordering::Relaxed);
}
pub fn query_time_add(&self, time: u64) {
self.total.query_time.fetch_add(time, Ordering::Relaxed);
self.current.query_time.fetch_add(time, Ordering::Relaxed);
}
pub fn wait_time_add(&self, time: u64) {
self.total.wait_time.fetch_add(time, Ordering::Relaxed);
self.current.wait_time.fetch_add(time, Ordering::Relaxed);
}
pub fn error(&self) {
self.total_errors.fetch_add(1, Ordering::Relaxed);
self.total.errors.fetch_add(1, Ordering::Relaxed);
self.current.errors.fetch_add(1, Ordering::Relaxed);
}
pub fn update_averages(&self) {
let (totals, averages) = self.fields_iterators();
for data in totals.iter().zip(averages.iter()) {
let (total, average) = data;
if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
let total = total.load(Ordering::Relaxed);
let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
Some(avg)
}) {
warn!("Could not update averages for addresses stats, {:?}", err);
}
let stat_period_per_second = crate::stats::STAT_PERIOD / 1_000;
// xact_count
let current_xact_count = self.current.xact_count.load(Ordering::Relaxed);
let current_xact_time = self.current.xact_time.load(Ordering::Relaxed);
self.averages.xact_count.store(
current_xact_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_xact_count == 0 {
self.averages.xact_time.store(0, Ordering::Relaxed);
} else {
self.averages
.xact_time
.store(current_xact_time / current_xact_count, Ordering::Relaxed);
}
// query_count
let current_query_count = self.current.query_count.load(Ordering::Relaxed);
let current_query_time = self.current.query_time.load(Ordering::Relaxed);
self.averages.query_count.store(
current_query_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_query_count == 0 {
self.averages.query_time.store(0, Ordering::Relaxed);
} else {
self.averages
.query_time
.store(current_query_time / current_query_count, Ordering::Relaxed);
}
// bytes_received
let current_bytes_received = self.current.bytes_received.load(Ordering::Relaxed);
self.averages.bytes_received.store(
current_bytes_received / stat_period_per_second,
Ordering::Relaxed,
);
// bytes_sent
let current_bytes_sent = self.current.bytes_sent.load(Ordering::Relaxed);
self.averages.bytes_sent.store(
current_bytes_sent / stat_period_per_second,
Ordering::Relaxed,
);
// wait_time
let current_wait_time = self.current.wait_time.load(Ordering::Relaxed);
self.averages.wait_time.store(
current_wait_time / stat_period_per_second,
Ordering::Relaxed,
);
// errors
let current_errors = self.current.errors.load(Ordering::Relaxed);
self.averages
.errors
.store(current_errors / stat_period_per_second, Ordering::Relaxed);
}
pub fn reset_current_counts(&self) {
self.current.xact_count.store(0, Ordering::Relaxed);
self.current.xact_time.store(0, Ordering::Relaxed);
self.current.query_count.store(0, Ordering::Relaxed);
self.current.query_time.store(0, Ordering::Relaxed);
self.current.bytes_received.store(0, Ordering::Relaxed);
self.current.bytes_sent.store(0, Ordering::Relaxed);
self.current.wait_time.store(0, Ordering::Relaxed);
self.current.errors.store(0, Ordering::Relaxed);
}
pub fn populate_row(&self, row: &mut Vec<String>) {
@@ -122,28 +223,4 @@ impl AddressStats {
row.push(value.to_string());
}
}
fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
totals.push(self.total_xact_count.clone());
averages.push(self.avg_xact_count.clone());
totals.push(self.total_query_count.clone());
averages.push(self.avg_query_count.clone());
totals.push(self.total_received.clone());
averages.push(self.avg_recv.clone());
totals.push(self.total_sent.clone());
averages.push(self.avg_sent.clone());
totals.push(self.total_xact_time.clone());
averages.push(self.avg_xact_time.clone());
totals.push(self.total_query_time.clone());
averages.push(self.avg_query_time.clone());
totals.push(self.total_wait_time.clone());
averages.push(self.avg_wait_time.clone());
totals.push(self.total_errors.clone());
averages.push(self.avg_errors.clone());
(totals, averages)
}
}

View File

@@ -1,4 +1,3 @@
use super::PoolStats;
use super::{get_reporter, Reporter};
use atomic_enum::atomic_enum;
use std::sync::atomic::*;
@@ -34,12 +33,14 @@ pub struct ClientStats {
pool_name: String,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
reporter: Reporter,
/// Total time spent waiting for a connection from pool, measures in microseconds
pub total_wait_time: Arc<AtomicU64>,
/// Maximum time spent waiting for a connection from pool, measures in microseconds
pub max_wait_time: Arc<AtomicU64>,
/// Current state of the client
pub state: Arc<AtomicClientState>,
@@ -61,8 +62,8 @@ impl Default for ClientStats {
application_name: String::new(),
username: String::new(),
pool_name: String::new(),
pool_stats: Arc::new(PoolStats::default()),
total_wait_time: Arc::new(AtomicU64::new(0)),
max_wait_time: Arc::new(AtomicU64::new(0)),
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
transaction_count: Arc::new(AtomicU64::new(0)),
query_count: Arc::new(AtomicU64::new(0)),
@@ -79,11 +80,9 @@ impl ClientStats {
username: &str,
pool_name: &str,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
) -> Self {
Self {
client_id,
pool_stats,
connect_time,
application_name: application_name.to_string(),
username: username.to_string(),
@@ -92,12 +91,10 @@ impl ClientStats {
}
}
/// Reports a client is disconecting from the pooler and
/// Reports a client is disconnecting from the pooler and
/// update metrics on the corresponding pool.
pub fn disconnect(&self) {
self.reporter.client_disconnecting(self.client_id);
self.pool_stats
.client_disconnect(self.state.load(Ordering::Relaxed))
}
/// Register a client with the stats system. The stats system uses client_id
@@ -105,27 +102,20 @@ impl ClientStats {
pub fn register(&self, stats: Arc<ClientStats>) {
self.reporter.client_register(self.client_id, stats);
self.state.store(ClientState::Idle, Ordering::Relaxed);
self.pool_stats.cl_idle.fetch_add(1, Ordering::Relaxed);
}
/// Reports a client is done querying the server and is no longer assigned a server connection
pub fn idle(&self) {
self.pool_stats
.client_idle(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Idle, Ordering::Relaxed);
}
/// Reports a client is waiting for a connection
pub fn waiting(&self) {
self.pool_stats
.client_waiting(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Waiting, Ordering::Relaxed);
}
/// Reports a client is done waiting for a connection and is about to query the server.
pub fn active(&self) {
self.pool_stats
.client_active(self.state.load(Ordering::Relaxed));
self.state.store(ClientState::Active, Ordering::Relaxed);
}
@@ -140,10 +130,12 @@ impl ClientStats {
self.error_count.fetch_add(1, Ordering::Relaxed);
}
/// Reportes the time spent by a client waiting to get a healthy connection from the pool
/// Reporters the time spent by a client waiting to get a healthy connection from the pool
pub fn checkout_time(&self, microseconds: u64) {
self.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.max_wait_time
.fetch_max(microseconds, Ordering::Relaxed);
}
/// Report a query executed by a client against a server

View File

@@ -1,36 +1,131 @@
use crate::config::Pool;
use crate::config::PoolMode;
use crate::pool::PoolIdentifier;
use std::sync::atomic::*;
use std::sync::Arc;
use log::debug;
use super::get_reporter;
use super::Reporter;
use super::{ClientState, ServerState};
use crate::{config::PoolMode, messages::DataType, pool::PoolIdentifier};
use std::collections::HashMap;
use std::sync::atomic::*;
#[derive(Debug, Clone, Default)]
use crate::pool::get_all_pools;
#[derive(Debug, Clone)]
/// A struct that holds information about a Pool .
pub struct PoolStats {
// Pool identifier, cannot be changed after creating the instance
identifier: PoolIdentifier,
pub identifier: PoolIdentifier,
pub mode: PoolMode,
pub cl_idle: u64,
pub cl_active: u64,
pub cl_waiting: u64,
pub cl_cancel_req: u64,
pub sv_active: u64,
pub sv_idle: u64,
pub sv_used: u64,
pub sv_tested: u64,
pub sv_login: u64,
pub maxwait: u64,
}
impl PoolStats {
pub fn new(identifier: PoolIdentifier, mode: PoolMode) -> Self {
PoolStats {
identifier,
mode,
cl_idle: 0,
cl_active: 0,
cl_waiting: 0,
cl_cancel_req: 0,
sv_active: 0,
sv_idle: 0,
sv_used: 0,
sv_tested: 0,
sv_login: 0,
maxwait: 0,
}
}
// Pool Config, cannot be changed after creating the instance
config: Pool,
pub fn construct_pool_lookup() -> HashMap<PoolIdentifier, PoolStats> {
let mut map: HashMap<PoolIdentifier, PoolStats> = HashMap::new();
let client_map = super::get_client_stats();
let server_map = super::get_server_stats();
// A reference to the global reporter.
reporter: Reporter,
for (identifier, pool) in get_all_pools() {
map.insert(
identifier.clone(),
PoolStats::new(identifier, pool.settings.pool_mode),
);
}
/// Counters (atomics)
pub cl_idle: Arc<AtomicU64>,
pub cl_active: Arc<AtomicU64>,
pub cl_waiting: Arc<AtomicU64>,
pub cl_cancel_req: Arc<AtomicU64>,
pub sv_active: Arc<AtomicU64>,
pub sv_idle: Arc<AtomicU64>,
pub sv_used: Arc<AtomicU64>,
pub sv_tested: Arc<AtomicU64>,
pub sv_login: Arc<AtomicU64>,
pub maxwait: Arc<AtomicU64>,
for client in client_map.values() {
match map.get_mut(&PoolIdentifier {
db: client.pool_name(),
user: client.username(),
}) {
Some(pool_stats) => {
match client.state.load(Ordering::Relaxed) {
ClientState::Active => pool_stats.cl_active += 1,
ClientState::Idle => pool_stats.cl_idle += 1,
ClientState::Waiting => pool_stats.cl_waiting += 1,
}
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, max_wait);
}
None => debug!("Client from an obselete pool"),
}
}
for server in server_map.values() {
match map.get_mut(&PoolIdentifier {
db: server.pool_name(),
user: server.username(),
}) {
Some(pool_stats) => match server.state.load(Ordering::Relaxed) {
ServerState::Active => pool_stats.sv_active += 1,
ServerState::Idle => pool_stats.sv_idle += 1,
ServerState::Login => pool_stats.sv_login += 1,
ServerState::Tested => pool_stats.sv_tested += 1,
},
None => debug!("Server from an obselete pool"),
}
}
return map;
}
pub fn generate_header() -> Vec<(&'static str, DataType)> {
return vec![
("database", DataType::Text),
("user", DataType::Text),
("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric),
("sv_active", DataType::Numeric),
("sv_idle", DataType::Numeric),
("sv_used", DataType::Numeric),
("sv_tested", DataType::Numeric),
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
];
}
pub fn generate_row(&self) -> Vec<String> {
return vec![
self.identifier.db.clone(),
self.identifier.user.clone(),
self.mode.to_string(),
self.cl_idle.to_string(),
self.cl_active.to_string(),
self.cl_waiting.to_string(),
self.cl_cancel_req.to_string(),
self.sv_active.to_string(),
self.sv_idle.to_string(),
self.sv_used.to_string(),
self.sv_tested.to_string(),
self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(),
];
}
}
impl IntoIterator for PoolStats {
@@ -39,243 +134,18 @@ impl IntoIterator for PoolStats {
fn into_iter(self) -> Self::IntoIter {
vec![
("cl_idle".to_string(), self.cl_idle.load(Ordering::Relaxed)),
(
"cl_active".to_string(),
self.cl_active.load(Ordering::Relaxed),
),
(
"cl_waiting".to_string(),
self.cl_waiting.load(Ordering::Relaxed),
),
(
"cl_cancel_req".to_string(),
self.cl_cancel_req.load(Ordering::Relaxed),
),
(
"sv_active".to_string(),
self.sv_active.load(Ordering::Relaxed),
),
("sv_idle".to_string(), self.sv_idle.load(Ordering::Relaxed)),
("sv_used".to_string(), self.sv_used.load(Ordering::Relaxed)),
(
"sv_tested".to_string(),
self.sv_tested.load(Ordering::Relaxed),
),
(
"sv_login".to_string(),
self.sv_login.load(Ordering::Relaxed),
),
(
"maxwait".to_string(),
self.maxwait.load(Ordering::Relaxed) / 1_000_000,
),
(
"maxwait_us".to_string(),
self.maxwait.load(Ordering::Relaxed) % 1_000_000,
),
("cl_idle".to_string(), self.cl_idle),
("cl_active".to_string(), self.cl_active),
("cl_waiting".to_string(), self.cl_waiting),
("cl_cancel_req".to_string(), self.cl_cancel_req),
("sv_active".to_string(), self.sv_active),
("sv_idle".to_string(), self.sv_idle),
("sv_used".to_string(), self.sv_used),
("sv_tested".to_string(), self.sv_tested),
("sv_login".to_string(), self.sv_login),
("maxwait".to_string(), self.maxwait / 1_000_000),
("maxwait_us".to_string(), self.maxwait % 1_000_000),
]
.into_iter()
}
}
impl PoolStats {
pub fn new(identifier: PoolIdentifier, config: Pool) -> Self {
Self {
identifier,
config,
reporter: get_reporter(),
..Default::default()
}
}
// Getters
pub fn register(&self, stats: Arc<PoolStats>) {
self.reporter.pool_register(self.identifier.clone(), stats);
}
pub fn database(&self) -> String {
self.identifier.db.clone()
}
pub fn user(&self) -> String {
self.identifier.user.clone()
}
pub fn redacted_secret(&self) -> String {
match self.identifier.secret {
Some(ref s) => format!("****{}", &s[s.len() - 4..]),
None => "<no secret>".to_string(),
}
}
pub fn pool_mode(&self) -> PoolMode {
self.config.pool_mode
}
/// Populates an array of strings with counters (used by admin in show pools)
pub fn populate_row(&self, row: &mut Vec<String>) {
for (_key, value) in self.clone() {
row.push(value.to_string());
}
}
/// Deletes the maxwait counter, this is done everytime we obtain metrics
pub fn clear_maxwait(&self) {
self.maxwait.store(0, Ordering::Relaxed);
}
/// Notified when a server of the pool enters login state.
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_login(&self, from: ServerState) {
self.sv_login.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Login {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'active'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_active(&self, from: ServerState) {
self.sv_active.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Active {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'tested'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_tested(&self, from: ServerState) {
self.sv_tested.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Tested {
self.decrease_from_server_state(from);
}
}
/// Notified when a server of the pool become 'idle'
///
/// Arguments:
///
/// `from`: The state of the server that notifies.
pub fn server_idle(&self, from: ServerState) {
self.sv_idle.fetch_add(1, Ordering::Relaxed);
if from != ServerState::Idle {
self.decrease_from_server_state(from);
}
}
/// Notified when a client of the pool become 'waiting'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_waiting(&self, from: ClientState) {
if from != ClientState::Waiting {
self.cl_waiting.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client of the pool become 'active'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_active(&self, from: ClientState) {
if from != ClientState::Active {
self.cl_active.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client of the pool become 'idle'
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_idle(&self, from: ClientState) {
if from != ClientState::Idle {
self.cl_idle.fetch_add(1, Ordering::Relaxed);
self.decrease_from_client_state(from);
}
}
/// Notified when a client disconnects.
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn client_disconnect(&self, from: ClientState) {
let counter = match from {
ClientState::Idle => &self.cl_idle,
ClientState::Waiting => &self.cl_waiting,
ClientState::Active => &self.cl_active,
};
Self::decrease_counter(counter.clone());
}
/// Notified when a server disconnects.
///
/// Arguments:
///
/// `from`: The state of the client that notifies.
pub fn server_disconnect(&self, from: ServerState) {
let counter = match from {
ServerState::Active => &self.sv_active,
ServerState::Idle => &self.sv_idle,
ServerState::Login => &self.sv_login,
ServerState::Tested => &self.sv_tested,
};
Self::decrease_counter(counter.clone());
}
// helpers for counter decrease
fn decrease_from_server_state(&self, from: ServerState) {
let counter = match from {
ServerState::Tested => &self.sv_tested,
ServerState::Active => &self.sv_active,
ServerState::Idle => &self.sv_idle,
ServerState::Login => &self.sv_login,
};
Self::decrease_counter(counter.clone());
}
fn decrease_from_client_state(&self, from: ClientState) {
let counter = match from {
ClientState::Active => &self.cl_active,
ClientState::Idle => &self.cl_idle,
ClientState::Waiting => &self.cl_waiting,
};
Self::decrease_counter(counter.clone());
}
fn decrease_counter(value: Arc<AtomicU64>) {
if value.load(Ordering::Relaxed) > 0 {
value.fetch_sub(1, Ordering::Relaxed);
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_decrease() {
let stat: PoolStats = PoolStats::default();
stat.server_login(ServerState::Login);
stat.server_idle(ServerState::Login);
assert_eq!(stat.sv_login.load(Ordering::Relaxed), 0);
assert_eq!(stat.sv_idle.load(Ordering::Relaxed), 1);
}
}

View File

@@ -1,5 +1,4 @@
use super::AddressStats;
use super::PoolStats;
use super::{get_reporter, Reporter};
use crate::config::Address;
use atomic_enum::atomic_enum;
@@ -38,7 +37,6 @@ pub struct ServerStats {
address: Address,
connect_time: Instant,
pool_stats: Arc<PoolStats>,
reporter: Reporter,
/// Data
@@ -49,6 +47,9 @@ pub struct ServerStats {
pub transaction_count: Arc<AtomicU64>,
pub query_count: Arc<AtomicU64>,
pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
}
impl Default for ServerStats {
@@ -57,7 +58,6 @@ impl Default for ServerStats {
server_id: 0,
application_name: Arc::new(RwLock::new(String::new())),
address: Address::default(),
pool_stats: Arc::new(PoolStats::default()),
connect_time: Instant::now(),
state: Arc::new(AtomicServerState::new(ServerState::Login)),
bytes_sent: Arc::new(AtomicU64::new(0)),
@@ -66,15 +66,17 @@ impl Default for ServerStats {
query_count: Arc::new(AtomicU64::new(0)),
error_count: Arc::new(AtomicU64::new(0)),
reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
}
}
}
impl ServerStats {
pub fn new(address: Address, pool_stats: Arc<PoolStats>, connect_time: Instant) -> Self {
pub fn new(address: Address, connect_time: Instant) -> Self {
Self {
address,
pool_stats,
connect_time,
server_id: rand::random::<i32>(),
..Default::default()
@@ -96,33 +98,23 @@ impl ServerStats {
/// Reports a server connection is no longer assigned to a client
/// and is available for the next client to pick it up
pub fn idle(&self) {
self.pool_stats
.server_idle(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Idle, Ordering::Relaxed);
self.set_undefined_application();
}
/// Reports a server connection is disconecting from the pooler.
/// Reports a server connection is disconnecting from the pooler.
/// Also updates metrics on the pool regarding server usage.
pub fn disconnect(&self) {
self.reporter.server_disconnecting(self.server_id);
self.pool_stats
.server_disconnect(self.state.load(Ordering::Relaxed))
}
/// Reports a server connection is being tested before being given to a client.
pub fn tested(&self) {
self.set_undefined_application();
self.pool_stats
.server_tested(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Tested, Ordering::Relaxed);
}
/// Reports a server connection is attempting to login.
pub fn login(&self) {
self.pool_stats
.server_login(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Login, Ordering::Relaxed);
self.set_undefined_application();
}
@@ -130,8 +122,6 @@ impl ServerStats {
/// Reports a server connection has been assigned to a client that
/// is about to query the server
pub fn active(&self, application_name: String) {
self.pool_stats
.server_active(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Active, Ordering::Relaxed);
self.set_application(application_name);
}
@@ -140,13 +130,24 @@ impl ServerStats {
self.address.stats.clone()
}
pub fn check_address_stat_average_is_updated_status(&self) -> bool {
self.address.stats.averages_updated.load(Ordering::Relaxed)
}
pub fn set_address_stat_average_is_updated_status(&self, is_checked: bool) {
self.address
.stats
.averages_updated
.store(is_checked, Ordering::Relaxed);
}
// Helper methods for show_servers
pub fn pool_name(&self) -> String {
self.pool_stats.database()
self.address.pool_name.clone()
}
pub fn username(&self) -> String {
self.pool_stats.user()
self.address.username.clone()
}
pub fn address_name(&self) -> String {
@@ -167,27 +168,17 @@ impl ServerStats {
}
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
// Update server stats and address aggergation stats
// Update server stats and address aggregation stats
self.set_application(application_name);
self.address
.stats
.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.pool_stats
.maxwait
.fetch_max(microseconds, Ordering::Relaxed);
self.address.stats.wait_time_add(microseconds);
}
/// Report a query executed by a client against a server
pub fn query(&self, milliseconds: u64, application_name: &str) {
self.set_application(application_name.to_string());
let address_stats = self.address_stats();
address_stats
.total_query_count
.fetch_add(1, Ordering::Relaxed);
address_stats
.total_query_time
.fetch_add(milliseconds, Ordering::Relaxed);
self.address.stats.query_count_add();
self.address.stats.query_time_add(milliseconds);
self.query_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
@@ -198,29 +189,38 @@ impl ServerStats {
self.set_application(application_name.to_string());
self.transaction_count.fetch_add(1, Ordering::Relaxed);
self.address
.stats
.total_xact_count
.fetch_add(1, Ordering::Relaxed);
self.address.stats.xact_count_add();
}
/// Report data sent to a server
pub fn data_sent(&self, amount_bytes: usize) {
self.bytes_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address
.stats
.total_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_sent_add(amount_bytes as u64);
}
/// Report data received from a server
pub fn data_received(&self, amount_bytes: usize) {
self.bytes_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address
.stats
.total_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_received_add(amount_bytes as u64);
}
/// Report a prepared statement that already exists on the server.
pub fn prepared_cache_hit(&self) {
self.prepared_hit_count.fetch_add(1, Ordering::Relaxed);
}
/// Report a prepared statement that does not exist on the server yet.
pub fn prepared_cache_miss(&self) {
self.prepared_miss_count.fetch_add(1, Ordering::Relaxed);
}
pub fn prepared_cache_add(&self) {
self.prepared_cache_size.fetch_add(1, Ordering::Relaxed);
}
pub fn prepared_cache_remove(&self) {
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
}

View File

@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor;
use crate::config::get_config;
@@ -64,3 +69,19 @@ impl Tls {
})
}
}
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -63,6 +63,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
def test_normal_db_access():
pgcat_start()
conn, cur = connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()

View File

@@ -11,323 +11,6 @@ describe "Admin" do
processes.pgcat.shutdown
end
describe "SHOW STATS" do
context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do
connection = PG::connect("#{pgcat_conn_str}?application_name=one_query")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.close
# wait for averages to be calculated, we shouldn't do this too often
sleep(15.5)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to_not eq(0)
expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0)
end
end
end
describe "SHOW POOLS" do
context "bad credentials" do
it "does not change any stats" do
bad_passsword_url = URI(pgcat_conn_str)
bad_passsword_url.password = "wrong"
expect { PG::connect("#{bad_passsword_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "bad database name" do
it "does not change any stats" do
bad_db_url = URI(pgcat_conn_str)
bad_db_url.path = "/wrong_db"
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "client connects but issues no queries" do
it "only affects cl_idle stats" do
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq("1")
connections.map(&:close)
sleep(1.1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "clients connect and make one query" do
it "only affects cl_idle, sv_idle stats" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_active"]).to eq("5")
expect(results["sv_active"]).to eq("5")
sleep(3)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("5")
connections.map(&:close)
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client connects and opens a transaction and closes connection uncleanly" do
it "produces correct statistics" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new do
c.async_exec("BEGIN")
c.async_exec("SELECT pg_sleep(0.01)")
c.close
end
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client fail to checkout connection from the pool" do
it "counts clients as idle" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
threads = []
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
end
sleep(2)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("1")
threads.map(&:join)
connections.map(&:close)
end
end
context "clients connects and disconnect normally" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it 'shows the same number of clients before and after' do
clients_before = clients_connected_to_pool(processes: processes)
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_between = clients_connected_to_pool(processes: processes)
expect(clients_before).not_to eq(clients_between)
connections.each(&:close)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients connects and disconnect abruptly" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
it 'shows the same number of clients before and after' do
threads = []
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_before = clients_connected_to_pool(processes: processes)
random_string = (0...8).map { (65 + rand(26)).chr }.join
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
sleep(1)
# psql starts two processes, we only know the pid of the parent, this
# ensure both are killed
`pkill -9 -f '#{random_string}'`
Process.wait(faulty_client)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients overwhelm server pools" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it "cl_waiting is updated to show it" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(1.1) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_waiting"]).to eq("2")
expect(results["cl_active"]).to eq("2")
expect(results["sv_active"]).to eq("2")
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("4")
expect(results["sv_idle"]).to eq("2")
threads.map(&:join)
connections.map(&:close)
end
it "show correct max_wait" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
sleep(4.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("0")
threads.map(&:join)
connections.map(&:close)
end
end
end
describe "SHOW CLIENTS" do
it "reports correct number and application names" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(21) # count admin clients
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
connections[0..5].map(&:close)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(15)
connections[6..].map(&:close)
sleep(1) # Wait for stats to be updated
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
admin_conn.close
end
it "reports correct number of queries and transactions" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
connections.each do |c|
c.async_exec("SELECT 1")
c.async_exec("SELECT 2")
c.async_exec("SELECT 3")
c.async_exec("BEGIN")
c.async_exec("SELECT 4")
c.async_exec("SELECT 5")
c.async_exec("COMMIT")
end
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(3)
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
expect(normal_client_results[0]["transaction_count"]).to eq("4")
expect(normal_client_results[1]["transaction_count"]).to eq("4")
expect(normal_client_results[0]["query_count"]).to eq("7")
expect(normal_client_results[1]["query_count"]).to eq("7")
admin_conn.close
connections.map(&:close)
end
end
describe "Manual Banning" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) }
before do
@@ -398,7 +81,7 @@ describe "Admin" do
end
end
describe "SHOW users" do
describe "SHOW USERS" do
it "returns the right users" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW USERS")[0]

View File

@@ -67,7 +67,7 @@ describe "Auth Query" do
end
context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user' } }
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
@@ -76,7 +76,7 @@ describe "Auth Query" do
end
it 'allows passwords to be changed without closing existing connections' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil
@@ -84,7 +84,7 @@ describe "Auth Query" do
end
it 'allows passwords to be changed and that new password is needed when reconnecting' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2'))

View File

@@ -1,39 +0,0 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Authentication" do
describe "multiple secrets configured" do
let(:secrets) { ["one_secret", "two_secret"] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info", secrets=["one_secret", "two_secret"]) }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using all secrets and postgres password" do
secrets.push("sharding_user").each do |secret|
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password=secret))
conn.exec("SELECT current_user")
end
end
end
describe "no secrets configured" do
let(:secrets) { [] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info") }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using only the password" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SELECT current_user")
expect { PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password="secret_one")) }.to raise_error PG::ConnectionBad
end
end
end

102
tests/ruby/copy_spec.rb Normal file
View File

@@ -0,0 +1,102 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "COPY Handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
before do
new_configs = processes.pgcat.current_config
# Allow connections in the pool to expire faster
new_configs["general"]["idle_timeout"] = 5
processes.pgcat.update_config(new_configs)
# We need to kill the old process that was using the default configs
processes.pgcat.stop
processes.pgcat.start
processes.pgcat.wait_until_ready
end
before do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
end
end
after do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "DROP TABLE copy_test_table;"
end
end
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "COPY FROM" do
context "within transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
context "outside transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
end
describe "COPY TO" do
before do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
conn.close
end
it "works" do
res = []
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
while row=conn.get_copy_data
res << row
end
end
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
end
end
end

View File

@@ -0,0 +1,259 @@
require 'socket'
require 'digest/md5'
BACKEND_MESSAGE_CODES = {
'Z' => "ReadyForQuery",
'C' => "CommandComplete",
'T' => "RowDescription",
'D' => "DataRow",
'1' => "ParseComplete",
'2' => "BindComplete",
'E' => "ErrorResponse",
's' => "PortalSuspended",
}
class PostgresSocket
def initialize(host, port)
@port = port
@host = host
@socket = TCPSocket.new @host, @port
@parameters = {}
@verbose = true
end
def send_md5_password_message(username, password, salt)
m = Digest::MD5.hexdigest(password + username)
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
m = 'md5' + m
bytes = (m.split("").map(&:ord) + [0]).flatten
message_size = bytes.count + 4
message = []
message << 'p'.ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << bytes
message.flatten!
@socket.write(message.pack('C*'))
end
def send_startup_message(username, database, password)
message = []
message << [196608].pack('l>').unpack('CCCC') # 4
message << "user".split('').map(&:ord) # 4, 8
message << 0 # 1, 9
message << username.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << "database".split('').map(&:ord) # 8, 20
message << 0 # 1, 21
message << database.split('').map(&:ord) # 2, 23
message << 0 # 1, 24
message << 0 # 1, 25
message.flatten!
total_message_size = message.size + 4
message_len = [total_message_size].pack('l>').unpack('CCCC')
@socket.write([message_len + message].flatten.pack('C*'))
sleep 0.1
read_startup_response(username, password)
end
def read_startup_response(username, password)
message_code, message_len = @socket.recv(5).unpack("al>")
while message_code == 'R'
auth_code = @socket.recv(4).unpack('l>').pop
case auth_code
when 5 # md5
salt = @socket.recv(4).unpack('CCCC')
send_md5_password_message(username, password, salt)
message_code, message_len = @socket.recv(5).unpack("al>")
when 0 # trust
break
end
end
loop do
message_code, message_len = @socket.recv(5).unpack("al>")
if message_code == 'Z'
@socket.recv(1).unpack("a") # most likely I
break # We are good to go
end
if message_code == 'S'
actual_message = @socket.recv(message_len - 4).unpack("C*")
k,v = actual_message.pack('U*').split(/\x00/)
@parameters[k] = v
end
if message_code == 'K'
process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>")
@parameters["process_id"] = process_id
@parameters["secret_key"] = secret_key
end
end
return @parameters
end
def cancel_query
socket = TCPSocket.new @host, @port
process_key = @parameters["process_id"]
secret_key = @parameters["secret_key"]
message = []
message << [16].pack('l>').unpack('CCCC') # 4
message << [80877102].pack('l>').unpack('CCCC') # 4
message << [process_key.to_i].pack('l>').unpack('CCCC') # 4
message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4
message.flatten!
socket.write(message.flatten.pack('C*'))
socket.close
log "[F] Sent CancelRequest message"
end
def send_query_message(query)
query_size = query.length
message_size = 1 + 4 + query_size
message = []
message << "Q".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent Q message (#{query})"
end
def send_parse_message(query)
query_size = query.length
message_size = 2 + 2 + 4 + query_size
message = []
message << "P".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << [0, 0]
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent P message (#{query})"
end
def send_bind_message
message = []
message << "B".ord
message << [12].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << 0 # unnamed statement
message << [0, 0] # 2
message << [0, 0] # 2
message << [0, 0] # 2
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent B message"
end
def send_describe_message(mode)
message = []
message << "D".ord
message << [6].pack('l>').unpack('CCCC') # 4
message << mode.ord
message << 0 # unnamed statement
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent D message"
end
def send_execute_message(limit=0)
message = []
message << "E".ord
message << [9].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << [limit].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent E message"
end
def send_sync_message
message = []
message << "S".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent S message"
end
def send_copydone_message
message = []
message << "c".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent c message"
end
def send_copyfail_message
message = []
message << "f".ord
message << [5].pack('l>').unpack('CCCC') # 4
message << 0
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent f message"
end
def send_flush_message
message = []
message << "H".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent H message"
end
def read_from_server()
output_messages = []
retry_count = 0
message_code = nil
message_len = 0
loop do
begin
message_code, message_len = @socket.recv_nonblock(5).unpack("al>")
rescue IO::WaitReadable
return output_messages if retry_count > 50
retry_count += 1
sleep(0.01)
next
end
message = {
code: message_code,
len: message_len,
bytes: []
}
log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}"
actual_message_length = message_len - 4
if actual_message_length > 0
message[:bytes] = @socket.recv(message_len - 4).unpack("C*")
log "\t#{message[:bytes].join(",")}"
log "\t#{message[:bytes].map(&:chr).join(" ")}"
end
output_messages << message
return output_messages if message_code == 'Z'
end
end
def log(msg)
return unless @verbose
puts msg
end
def close
@socket.close
end
end

View File

@@ -2,6 +2,7 @@ require 'json'
require 'ostruct'
require_relative 'pgcat_process'
require_relative 'pg_instance'
require_relative 'pg_socket'
class ::Hash
def deep_merge(second)
@@ -12,18 +13,14 @@ end
module Helpers
module Pgcat
def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", secrets=nil)
def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
user = {
"password" => "sharding_user",
"pool_size" => pool_size,
"statement_timeout" => 0,
"username" => "sharding_user",
"username" => "sharding_user"
}
if !secrets.nil?
user["secrets"] = secrets
end
pgcat = PgcatProcess.new(log_level)
primary0 = PgInstance.new(5432, user["username"], user["password"], "shard0")
primary1 = PgInstance.new(7432, user["username"], user["password"], "shard1")
@@ -31,7 +28,7 @@ module Helpers
pgcat_cfg = pgcat.current_config
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
@@ -44,15 +41,26 @@ module Helpers
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
},
"users" => { "0" => user }
},
"users" => { "0" => user },
"plugins" => {
"intercept" => {
"enabled" => true,
"queries" => {
"0" => {
"query" => "select current_database() as a, current_schemas(false) as b",
"schema" => [
["a", "text"],
["b", "text"],
],
"result" => [
["${DATABASE}", "{public}"],
]
}
}
}
}
}
}
if !secrets.nil?
pgcat_cfg["general"]["tls_certificate"] = "../../.circleci/server.cert"
pgcat_cfg["general"]["tls_private_key"] = "../../.circleci/server.key"
end
pgcat.update_config(pgcat_cfg)
pgcat.start
@@ -110,7 +118,7 @@ module Helpers
end
end
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", pool_settings={})
user = {
"password" => "sharding_user",
"pool_size" => pool_size,
@@ -126,28 +134,32 @@ module Helpers
replica1 = PgInstance.new(8432, user["username"], user["password"], "shard0")
replica2 = PgInstance.new(9432, user["username"], user["password"], "shard0")
pool_config = {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica0.port.to_s, "replica"],
["localhost", replica1.port.to_s, "replica"],
["localhost", replica2.port.to_s, "replica"]
]
},
},
"users" => { "0" => user }
}
pool_config = pool_config.merge(pool_settings)
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica0.port.to_s, "replica"],
["localhost", replica1.port.to_s, "replica"],
["localhost", replica2.port.to_s, "replica"]
]
},
},
"users" => { "0" => user }
}
"#{pool_name}" => pool_config,
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg)

View File

@@ -78,6 +78,7 @@ class PgcatProcess
10.times do
Process.kill 0, @pid
PG::connect(connection_string || example_connection_string).close
return self
rescue Errno::ESRCH
raise StandardError, "Process #{@pid} died. #{logs}"
@@ -111,13 +112,10 @@ class PgcatProcess
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
end
def connection_string(pool_name, username, password=nil)
def connection_string(pool_name, username, password = nil)
cfg = current_config
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
password = if password.nil? then user_obj["password"] else password end
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}"
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
end
def example_connection_string

View File

@@ -65,7 +65,7 @@ describe "Least Outstanding Queries Load Balancing" do
processes.pgcat.shutdown
end
context "under homogenous load" do
context "under homogeneous load" do
it "balances query volume between all instances" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))

View File

@@ -25,7 +25,7 @@ describe "Query Mirroing" do
processes.pgcat.shutdown
end
it "can mirror a query" do
xit "can mirror a query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
runs = 15
runs.times { conn.async_exec("SELECT 1 + 2") }

View File

@@ -241,6 +241,18 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
end
it "Resets server roles correctly" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("SELECT 1")
conn.async_exec("SET statement_timeout to 5000")
conn.close
end
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
end
end
context "transaction mode" do
@@ -308,6 +320,31 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
end
context "server cleanup disabled" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 1, "transaction", "random", "info", { "cleanup_server_connections" => false }) }
it "will not clean up connection state" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("SET statement_timeout TO 1000")
conn.close
puts processes.pgcat.logs
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
it "will not clean up prepared statements" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("PREPARE prepared_q (int) AS SELECT $1")
conn.close
puts processes.pgcat.logs
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
end
end
describe "Idle client timeout" do

View File

@@ -0,0 +1,14 @@
require_relative 'spec_helper'
describe "Plugins" do
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
context "intercept" do
it "will intercept an intellij query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
res = conn.exec("select current_database() as a, current_schemas(false) as b")
expect(res.values).to eq([["sharded_db", "{public}"]])
end
end
end

View File

@@ -0,0 +1,29 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
context 'enabled' do
it 'will work over the same connection' do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
conn.describe_prepared(statement_name)
end
end
it 'will work with new connections' do
10.times do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
statement_name = 'statement1'
conn.prepare('statement1', 'SELECT $1::int')
conn.exec_prepared('statement1', [1])
conn.describe_prepared('statement1')
end
end
end
end

155
tests/ruby/protocol_spec.rb Normal file
View File

@@ -0,0 +1,155 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Portocol handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "session") }
let(:sequence) { [] }
let(:pgcat_socket) { PostgresSocket.new('localhost', processes.pgcat.port) }
let(:pgdb_socket) { PostgresSocket.new('localhost', processes.all_databases.first.port) }
after do
pgdb_socket.close
pgcat_socket.close
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
def run_comparison(sequence, socket_a, socket_b)
sequence.each do |msg, *args|
socket_a.send(msg, *args)
socket_b.send(msg, *args)
compare_messages(
socket_a.read_from_server,
socket_b.read_from_server
)
end
end
def compare_messages(msg_arr0, msg_arr1)
if msg_arr0.count != msg_arr1.count
error_output = []
error_output << "#{msg_arr0.count} : #{msg_arr1.count}"
error_output << "PgCat Messages"
error_output += msg_arr0.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_output << "PgServer Messages"
error_output += msg_arr1.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_desc = error_output.join("\n")
raise StandardError, "Message count mismatch #{error_desc}"
end
(0..msg_arr0.count - 1).all? do |i|
msg0 = msg_arr0[i]
msg1 = msg_arr1[i]
result = [
msg0[:code] == msg1[:code],
msg0[:len] == msg1[:len],
msg0[:bytes] == msg1[:bytes],
].all?
next result if result
if result == false
error_string = []
if msg0[:code] != msg1[:code]
error_string << "code #{msg0[:code]} != #{msg1[:code]}"
end
if msg0[:len] != msg1[:len]
error_string << "len #{msg0[:len]} != #{msg1[:len]}"
end
if msg0[:bytes] != msg1[:bytes]
error_string << "bytes #{msg0[:bytes]} != #{msg1[:bytes]}"
end
err = error_string.join("\n")
raise StandardError, "Message mismatch #{err}"
end
end
end
RSpec.shared_examples "at parity with database" do
before do
pgcat_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
pgdb_socket.send_startup_message("sharding_user", "shard0", "sharding_user")
end
it "works" do
run_comparison(sequence, pgcat_socket, pgdb_socket)
end
end
context "Cancel Query" do
let(:sequence) {
[
[:send_query_message, "SELECT pg_sleep(5)"],
[:cancel_query]
]
}
it_behaves_like "at parity with database"
end
xcontext "Simple query after parse" do
let(:sequence) {
[
[:send_parse_message, "SELECT 5"],
[:send_query_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Flush message" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_flush_message]
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Bind without parse" do
let(:sequence) {
[
[:send_bind_message]
]
}
# This is known to fail.
# Server responds immediately, Proxy buffers the message
it_behaves_like "at parity with database"
end
context "Simple message" do
let(:sequence) {
[[:send_query_message, "SELECT 1"]]
}
it_behaves_like "at parity with database"
end
context "Extended protocol" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
it_behaves_like "at parity with database"
end
end

View File

@@ -27,7 +27,7 @@ describe "Sharding" do
processes.pgcat.shutdown
end
describe "automatic routing of extended procotol" do
describe "automatic routing of extended protocol" do
it "can do it" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SET SERVER ROLE TO 'auto'")

369
tests/ruby/stats_spec.rb Normal file
View File

@@ -0,0 +1,369 @@
# frozen_string_literal: true
require 'open3'
require_relative 'spec_helper'
describe "Stats" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
let(:pgcat_conn_str) { processes.pgcat.connection_string("sharded_db", "sharding_user") }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "SHOW STATS" do
context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
end
sleep(1)
connections.map(&:close)
# wait for averages to be calculated, we shouldn't do this too often
sleep(15.5)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0)
end
end
end
describe "SHOW POOLS" do
context "bad credentials" do
it "does not change any stats" do
bad_password_url = URI(pgcat_conn_str)
bad_password_url.password = "wrong"
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "bad database name" do
it "does not change any stats" do
bad_db_url = URI(pgcat_conn_str)
bad_db_url.path = "/wrong_db"
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
end
end
context "client connects but issues no queries" do
it "only affects cl_idle stats" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq(before_test)
connections.map(&:close)
sleep(1.1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq(before_test)
end
end
context "clients connect and make one query" do
it "only affects cl_idle, sv_idle stats" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_active"]).to eq("5")
expect(results["sv_active"]).to eq("5")
sleep(3)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("5")
connections.map(&:close)
sleep(1)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client connects and opens a transaction and closes connection uncleanly" do
it "produces correct statistics" do
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new do
c.async_exec("BEGIN")
c.async_exec("SELECT pg_sleep(0.01)")
c.close
end
end
sleep(1.1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("5")
end
end
context "client fail to checkout connection from the pool" do
it "counts clients as idle" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
threads = []
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
end
sleep(2)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("5")
expect(results["sv_idle"]).to eq("1")
threads.map(&:join)
connections.map(&:close)
end
end
context "clients connects and disconnect normally" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it 'shows the same number of clients before and after' do
clients_before = clients_connected_to_pool(processes: processes)
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") rescue nil }
end
clients_between = clients_connected_to_pool(processes: processes)
expect(clients_before).not_to eq(clients_between)
connections.each(&:close)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients connects and disconnect abruptly" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
it 'shows the same number of clients before and after' do
threads = []
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT 1") }
end
clients_before = clients_connected_to_pool(processes: processes)
random_string = (0...8).map { (65 + rand(26)).chr }.join
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
sleep(1)
# psql starts two processes, we only know the pid of the parent, this
# ensure both are killed
`pkill -9 -f '#{random_string}'`
Process.wait(faulty_client)
clients_after = clients_connected_to_pool(processes: processes)
expect(clients_before).to eq(clients_after)
end
end
context "clients overwhelm server pools" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
it "cl_waiting is updated to show it" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
end
sleep(1.1) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_waiting"]).to eq("2")
expect(results["cl_active"]).to eq("2")
expect(results["sv_active"]).to eq("2")
sleep(2.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("4")
expect(results["sv_idle"]).to eq("2")
threads.map(&:join)
connections.map(&:close)
end
it "show correct max_wait" do
threads = []
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
end
sleep(2.5) # Allow time for stats to update
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("1")
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
connections.map(&:close)
sleep(4.5) # Allow time for stats to update
results = admin_conn.async_exec("SHOW POOLS")[0]
expect(results["maxwait"]).to eq("0")
threads.map(&:join)
end
end
end
describe "SHOW CLIENTS" do
it "reports correct number and application names" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(21) # count admin clients
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
connections[0..5].map(&:close)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(15)
connections[6..].map(&:close)
sleep(1) # Wait for stats to be updated
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
admin_conn.close
end
it "reports correct number of queries and transactions" do
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
connections.each do |c|
c.async_exec("SELECT 1")
c.async_exec("SELECT 2")
c.async_exec("SELECT 3")
c.async_exec("BEGIN")
c.async_exec("SELECT 4")
c.async_exec("SELECT 5")
c.async_exec("COMMIT")
end
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
sleep(1) # Wait for stats to be updated
results = admin_conn.async_exec("SHOW CLIENTS")
expect(results.count).to eq(3)
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
expect(normal_client_results[0]["transaction_count"]).to eq("4")
expect(normal_client_results[1]["transaction_count"]).to eq("4")
expect(normal_client_results[0]["query_count"]).to eq("7")
expect(normal_client_results[1]["query_count"]).to eq("7")
admin_conn.close
connections.map(&:close)
end
end
describe "Query Storm" do
context "when the proxy receives overwhelmingly large number of short quick queries" do
it "should not have lingering clients or active servers" do
new_configs = processes.pgcat.current_config
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
processes.pgcat.update_config(new_configs)
processes.pgcat.reload_config
Array.new(40) do
Thread.new do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SELECT pg_sleep(0.1)")
rescue PG::SystemError
ensure
conn.close
end
end.each(&:join)
sleep 1
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_idle cl_waiting cl_cancel_req sv_used sv_tested sv_login].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
admin_conn.close
end
end
end
end

1
tests/rust/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
target/

1322
tests/rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

10
tests/rust/Cargo.toml Normal file
View File

@@ -0,0 +1,10 @@
[package]
name = "rust"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sqlx = { version = "0.6.2", features = [ "runtime-tokio-rustls", "postgres", "json", "tls", "migrate", "time", "uuid", "ipnetwork"] }
tokio = { version = "1", features = ["full"] }

29
tests/rust/src/main.rs Normal file
View File

@@ -0,0 +1,29 @@
#[tokio::main]
async fn main() {
test_prepared_statements().await;
}
async fn test_prepared_statements() {
let pool = sqlx::postgres::PgPoolOptions::new()
.max_connections(5)
.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db")
.await
.unwrap();
let mut handles = Vec::new();
for _ in 0..5 {
let pool = pool.clone();
let handle = tokio::task::spawn(async move {
for _ in 0..1000 {
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
}

View File

@@ -0,0 +1 @@
tomli