Compare commits

...

29 Commits

Author SHA1 Message Date
Lev Kokotov
e7265cbf91 fix flakey test 2023-05-03 16:01:48 -07:00
Lev Kokotov
d738ba28b6 fix tests 2023-05-03 15:42:16 -07:00
Lev Kokotov
ff80bb75cc clean up 2023-05-03 15:38:03 -07:00
Lev Kokotov
374a6b138b more plugins 2023-05-03 15:29:16 -07:00
dependabot[bot]
d5e329fec5 chore(deps): bump regex from 1.8.0 to 1.8.1 (#413)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.0 to 1.8.1.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits/1.8.1)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-03 10:00:05 -07:00
Lev Kokotov
09e54e1175 Plugins! (#420)
* Some queries

* Plugins!!

* cleanup

* actual names

* the actual plugins

* comment

* fix tests

* Tests

* unused errors

* Increase reaper rate to actually enforce settings

* ok
2023-05-03 09:13:05 -07:00
dependabot[bot]
23819c8549 chore(deps): bump rustls from 0.21.0 to 0.21.1 (#419)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.0 to 0.21.1.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Changelog](https://github.com/rustls/rustls/blob/main/RELEASE_NOTES.md)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.0...v/0.21.1)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 07:32:44 -07:00
Jose Fernández
7dfbd993f2 Add dns_cache for server addresses as in pgbouncer (#249)
* Add dns_cache so server addresses are cached and invalidated when DNS changes.

Adds a module to deal with dns_cache feature. It's
main struct is CachedResolver, which is a simple thread safe
hostname <-> Ips cache with the ability to refresh resolutions
every `dns_max_ttl` seconds. This way, a client can check whether its
ip address has changed.

* Allow reloading dns cached

* Add documentation for dns_cached
2023-05-02 10:26:40 +02:00
Lev Kokotov
3601130ba1 Readme update (#418)
* Readme update

* m

* wording
2023-04-30 09:44:25 -07:00
Lev Kokotov
0d504032b2 Server TLS (#417)
* Server TLS

* Finish up TLS

* thats it

* diff

* remove dead code

* maybe?

* dirty shutdown

* skip flakey test

* remove unused error

* fetch config once
2023-04-30 09:41:46 -07:00
Lev Kokotov
4a87b4807d Add more pool settings (#416)
* Add some pool settings

* fmt
2023-04-26 16:33:26 -07:00
Shawn
cb5ff40a59 fix typo (#415)
chore: typo
2023-04-26 08:28:54 -07:00
dependabot[bot]
62b2d994c1 chore(deps): bump regex from 1.7.3 to 1.8.0 (#411)
Bumps [regex](https://github.com/rust-lang/regex) from 1.7.3 to 1.8.0.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-21 06:33:52 -07:00
Lev Kokotov
66805d7e77 README updates (#409)
* Better table

* add image

* promote auth passthrough to stable

* fmt
2023-04-20 07:53:55 -07:00
Lev Kokotov
4ccc1e7fa3 Fix CONFIG (#408)
Fix readme
2023-04-19 07:45:26 -07:00
Lev Kokotov
3dae3d0777 Separate server and client passwords optionally (#407)
* Separate server and user passwords

* config
2023-04-18 09:57:17 -07:00
dependabot[bot]
a18eb42df5 chore(deps): bump serde from 1.0.159 to 1.0.160 (#404)
Bumps [serde](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:25:00 -07:00
dependabot[bot]
6aacf1fa19 chore(deps): bump serde_derive from 1.0.159 to 1.0.160 (#403)
Bumps [serde_derive](https://github.com/serde-rs/serde) from 1.0.159 to 1.0.160.
- [Release notes](https://github.com/serde-rs/serde/releases)
- [Commits](https://github.com/serde-rs/serde/compare/v1.0.159...v1.0.160)

---
updated-dependencies:
- dependency-name: serde_derive
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:52 -07:00
dependabot[bot]
8e99e65215 chore(deps): bump sqlparser from 0.32.0 to 0.33.0 (#399)
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.32.0 to 0.33.0.
- [Release notes](https://github.com/sqlparser-rs/sqlparser-rs/releases)
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.32.0...v0.33.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:24:44 -07:00
dependabot[bot]
5dfbc102a9 chore(deps): bump hyper from 0.14.25 to 0.14.26 (#406)
Bumps [hyper](https://github.com/hyperium/hyper) from 0.14.25 to 0.14.26.
- [Release notes](https://github.com/hyperium/hyper/releases)
- [Changelog](https://github.com/hyperium/hyper/blob/v0.14.26/CHANGELOG.md)
- [Commits](https://github.com/hyperium/hyper/compare/v0.14.25...v0.14.26)

---
updated-dependencies:
- dependency-name: hyper
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-14 10:23:52 -07:00
Cluas
bae12fca99 feat: set keepalive for pgcat server itself (#402)
* feat: set keepalive for pgcat server self

* docs: note also set for client
2023-04-12 09:29:43 -07:00
Lev Kokotov
421c5d4b64 Load config on client connect (#401) 2023-04-11 10:32:48 -07:00
Kian-Meng Ang
d568739db9 Fix typos (#398)
Found via `typos --format brief`
2023-04-10 18:37:16 -07:00
Lev Kokotov
692353c839 A couple things (#397)
* Format cleanup

* fmt

* finally
2023-04-10 14:51:01 -07:00
Lev Kokotov
a62f6b0eea Fix port; add user pool mode (#395)
* Fix port; add user pool mode

* will probably break our session/transaction mode tests
2023-04-05 15:06:19 -07:00
dependabot[bot]
89e15f09b5 chore(deps): bump tokio-rustls from 0.23.4 to 0.24.0 (#394)
Bumps [tokio-rustls](https://github.com/tokio-rs/tls) from 0.23.4 to 0.24.0.
- [Release notes](https://github.com/tokio-rs/tls/releases)
- [Commits](https://github.com/tokio-rs/tls/commits)

---
updated-dependencies:
- dependency-name: tokio-rustls
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-04-02 23:00:09 -07:00
Mostafa Abdelraouf
7ddd23b514 Protocol-level test helpers (#393)
I needed to have granular control over protocol message testing. For example, being able to send protocol messages one-by-one and then be able to inspect the results.

In order to do that, I created this low-level ruby client that can be used to send protocol messages in any order without blocking and also allows inspection of response messages.
2023-04-01 15:27:57 -05:00
dependabot[bot]
faa9c1f64a chore(deps): bump futures from 0.3.27 to 0.3.28 (#392)
Bumps [futures](https://github.com/rust-lang/futures-rs) from 0.3.27 to 0.3.28.
- [Release notes](https://github.com/rust-lang/futures-rs/releases)
- [Changelog](https://github.com/rust-lang/futures-rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/futures-rs/compare/0.3.27...0.3.28)

---
updated-dependencies:
- dependency-name: futures
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:35:01 -07:00
dependabot[bot]
9094988491 chore(deps): bump postgres-protocol from 0.6.4 to 0.6.5 (#391)
Bumps [postgres-protocol](https://github.com/sfackler/rust-postgres) from 0.6.4 to 0.6.5.
- [Release notes](https://github.com/sfackler/rust-postgres/releases)
- [Commits](https://github.com/sfackler/rust-postgres/compare/postgres-protocol-v0.6.4...postgres-protocol-v0.6.5)

---
updated-dependencies:
- dependency-name: postgres-protocol
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-03-31 09:34:51 -07:00
39 changed files with 3246 additions and 407 deletions

View File

@@ -39,7 +39,7 @@ log_client_connections = false
log_client_disconnections = false log_client_disconnections = false
# Reload config automatically if it changes. # Reload config automatically if it changes.
autoreload = true autoreload = 15000
# TLS # TLS
tls_certificate = ".circleci/server.cert" tls_certificate = ".circleci/server.cert"

14
.editorconfig Normal file
View File

@@ -0,0 +1,14 @@
root = true
[*]
trim_trailing_whitespace = true
insert_final_newline = true
[*.rs]
indent_style = space
indent_size = 4
max_line_length = 120
[*.toml]
indent_style = space
indent_size = 2

137
CONFIG.md
View File

@@ -49,6 +49,14 @@ default: 30000 # milliseconds
How long an idle connection with a server is left open (ms). How long an idle connection with a server is left open (ms).
### server_lifetime
```
path: general.server_lifetime
default: 86400000 # 24 hours
```
Max connection lifetime before it's closed, even if actively used.
### idle_client_in_transaction_timeout ### idle_client_in_transaction_timeout
``` ```
path: general.idle_client_in_transaction_timeout path: general.idle_client_in_transaction_timeout
@@ -108,7 +116,7 @@ If we should log client disconnections
### autoreload ### autoreload
``` ```
path: general.autoreload path: general.autoreload
default: false default: 15000
``` ```
When set to true, PgCat reloads configs if it detects a change in the config file. When set to true, PgCat reloads configs if it detects a change in the config file.
@@ -152,7 +160,7 @@ default: <UNSET>
example: "server.cert" example: "server.cert"
``` ```
Path to TLS Certficate file to use for TLS connections Path to TLS Certificate file to use for TLS connections
### tls_private_key ### tls_private_key
``` ```
@@ -175,40 +183,26 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT
### admin_password ### admin_password
``` ```
path: general.admin_password path: general.admin_password
default: <UNSET> default: "admin_pass"
``` ```
Password to access the virtual administrative database Password to access the virtual administrative database
### auth_query (experimental) ### dns_cache_enabled
``` ```
path: general.auth_query path: general.dns_cache_enabled
default: <UNSET> default: false
``` ```
When enabled, ip resolutions for server connections specified using hostnames will be cached
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
old ip connections are closed (gracefully) and new connections will start using new ip.
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be ### dns_max_ttl
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user (experimental)
``` ```
path: general.auth_query_user path: general.dns_max_ttl
default: <UNSET> default: 30
``` ```
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
## `pools.<pool_name>` Section ## `pools.<pool_name>` Section
@@ -243,7 +237,7 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
`replica` round-robin between replicas only without touching the primary, `replica` round-robin between replicas only without touching the primary,
`primary` all queries go to the primary unless otherwise specified. `primary` all queries go to the primary unless otherwise specified.
### query_parser_enabled (experimental) ### query_parser_enabled
``` ```
path: pools.<pool_name>.query_parser_enabled path: pools.<pool_name>.query_parser_enabled
default: true default: true
@@ -264,7 +258,7 @@ If the query parser is enabled and this setting is enabled, the primary will be
load balancing of read queries. Otherwise, the primary will only be used for write load balancing of read queries. Otherwise, the primary will only be used for write
queries. The primary can always be explicitly selected with our custom protocol. queries. The primary can always be explicitly selected with our custom protocol.
### sharding_key_regex (experimental) ### sharding_key_regex
``` ```
path: pools.<pool_name>.sharding_key_regex path: pools.<pool_name>.sharding_key_regex
default: <UNSET> default: <UNSET>
@@ -286,7 +280,40 @@ Current options:
`pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function) `pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
`sha1`: A hashing function based on SHA1 `sha1`: A hashing function based on SHA1
### automatic_sharding_key (experimental) ### auth_query
```
path: pools.<pool_name>.auth_query
default: <UNSET>
example: "SELECT $1"
```
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
established using the database configured in the pool. This parameter is inherited by every pool
and can be redefined in pool configuration.
### auth_query_user
```
path: pools.<pool_name>.auth_query_user
default: <UNSET>
example: "sharding_user"
```
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### auth_query_password
```
path: pools.<pool_name>.auth_query_password
default: <UNSET>
example: "sharding_user"
```
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
This parameter is inherited by every pool and can be redefined in pool configuration.
### automatic_sharding_key
``` ```
path: pools.<pool_name>.automatic_sharding_key path: pools.<pool_name>.automatic_sharding_key
default: <UNSET> default: <UNSET>
@@ -311,30 +338,6 @@ default: 3000
Connect timeout can be overwritten in the pool Connect timeout can be overwritten in the pool
### auth_query (experimental)
```
path: general.auth_query
default: <UNSET>
```
Auth query can be overwritten in the pool
### auth_query_user (experimental)
```
path: general.auth_query_user
default: <UNSET>
```
Auth query user can be overwritten in the pool
### auth_query_password (experimental)
```
path: general.auth_query_password
default: <UNSET>
```
Auth query password can be overwritten in the pool
## `pools.<pool_name>.users.<user_index>` Section ## `pools.<pool_name>.users.<user_index>` Section
### username ### username
@@ -343,7 +346,8 @@ path: pools.<pool_name>.users.<user_index>.username
default: "sharding_user" default: "sharding_user"
``` ```
Postgresql username PostgreSQL username used to authenticate the user and connect to the server
if `server_username` is not set.
### password ### password
``` ```
@@ -351,7 +355,26 @@ path: pools.<pool_name>.users.<user_index>.password
default: "sharding_user" default: "sharding_user"
``` ```
Postgresql password PostgreSQL password used to authenticate the user and connect to the server
if `server_password` is not set.
### server_username
```
path: pools.<pool_name>.users.<user_index>.server_username
default: <UNSET>
example: "another_user"
```
PostgreSQL username used to connect to the server.
### server_password
```
path: pools.<pool_name>.users.<user_index>.server_password
default: <UNSET>
example: "another_password"
```
PostgreSQL password used to connect to the server.
### pool_size ### pool_size
``` ```
@@ -382,7 +405,7 @@ default: [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
Array of servers in the shard, each server entry is an array of `[host, port, role]` Array of servers in the shard, each server entry is an array of `[host, port, role]`
### mirrors (experimental) ### mirrors
``` ```
path: pools.<pool_name>.shards.<shard_index>.mirrors path: pools.<pool_name>.shards.<shard_index>.mirrors
default: <UNSET> default: <UNSET>

464
Cargo.lock generated
View File

@@ -4,9 +4,9 @@ version = 3
[[package]] [[package]]
name = "aho-corasick" name = "aho-corasick"
version = "0.7.20" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
@@ -26,6 +26,27 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "async-stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.68" version = "0.1.68"
@@ -54,12 +75,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.21.0" version = "0.21.0"
@@ -218,6 +233,12 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "data-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.6" version = "0.10.6"
@@ -229,6 +250,18 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "enum-as-inner"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "env_logger" name = "env_logger"
version = "0.10.0" version = "0.10.0"
@@ -282,10 +315,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]] [[package]]
name = "futures" name = "form_urlencoded"
version = "0.3.27" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549" checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
@@ -298,9 +340,9 @@ dependencies = [
[[package]] [[package]]
name = "futures-channel" name = "futures-channel"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac" checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
@@ -308,15 +350,15 @@ dependencies = [
[[package]] [[package]]
name = "futures-core" name = "futures-core"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
[[package]] [[package]]
name = "futures-executor" name = "futures-executor"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83" checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-task", "futures-task",
@@ -325,38 +367,38 @@ dependencies = [
[[package]] [[package]]
name = "futures-io" name = "futures-io"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91" checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
[[package]] [[package]]
name = "futures-macro" name = "futures-macro"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6" checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 1.0.109", "syn 2.0.9",
] ]
[[package]] [[package]]
name = "futures-sink" name = "futures-sink"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2" checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
[[package]] [[package]]
name = "futures-task" name = "futures-task"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.27" version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
@@ -393,9 +435,9 @@ dependencies = [
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.15" version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" checksum = "66b91535aa35fea1523ad1b86cb6b53c28e0ae566ba4a460f4457e936cad7c6f"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
@@ -416,6 +458,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.2.6" version = "0.2.6"
@@ -440,6 +488,17 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "hostname"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
dependencies = [
"libc",
"match_cfg",
"winapi",
]
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.9" version = "0.2.9"
@@ -482,9 +541,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "0.14.25" version = "0.14.26"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
@@ -528,6 +587,27 @@ dependencies = [
"cxx-build", "cxx-build",
] ]
[[package]]
name = "idna"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "idna"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.9.2" version = "1.9.2"
@@ -548,6 +628,24 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "ipconfig"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
dependencies = [
"socket2",
"widestring",
"winapi",
"winreg",
]
[[package]]
name = "ipnet"
version = "2.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.4" version = "0.4.4"
@@ -595,6 +693,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.139" version = "0.2.139"
@@ -610,6 +714,12 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.1.4" version = "0.1.4"
@@ -635,6 +745,27 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "lru-cache"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
[[package]]
name = "matches"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]] [[package]]
name = "md-5" name = "md-5"
version = "0.10.5" version = "0.10.5"
@@ -743,14 +874,20 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]] [[package]]
name = "pgcat" name = "pgcat"
version = "1.0.0" version = "1.0.2-alpha1"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"async-trait", "async-trait",
"atomic_enum", "atomic_enum",
"base64 0.21.0", "base64",
"bb8", "bb8",
"bytes", "bytes",
"chrono", "chrono",
@@ -768,12 +905,15 @@ dependencies = [
"once_cell", "once_cell",
"parking_lot", "parking_lot",
"phf", "phf",
"pin-project",
"postgres-protocol", "postgres-protocol",
"rand", "rand",
"regex", "regex",
"rustls",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json",
"sha-1", "sha-1",
"sha2", "sha2",
"socket2", "socket2",
@@ -781,7 +921,10 @@ dependencies = [
"stringprep", "stringprep",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-test",
"toml", "toml",
"trust-dns-resolver",
"webpki-roots",
] ]
[[package]] [[package]]
@@ -826,6 +969,26 @@ dependencies = [
"siphasher", "siphasher",
] ]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.9"
@@ -840,11 +1003,11 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]] [[package]]
name = "postgres-protocol" name = "postgres-protocol"
version = "0.6.4" version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c" checksum = "78b7fa9f396f51dffd61546fd8573ee20592287996568e6175ceb0f8699ad75d"
dependencies = [ dependencies = [
"base64 0.13.1", "base64",
"byteorder", "byteorder",
"bytes", "bytes",
"fallible-iterator", "fallible-iterator",
@@ -871,6 +1034,12 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.26" version = "1.0.26"
@@ -921,9 +1090,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.7.3" version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@@ -932,9 +1101,19 @@ dependencies = [
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.6.29" version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c"
[[package]]
name = "resolv-conf"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
dependencies = [
"hostname",
"quick-error",
]
[[package]] [[package]]
name = "ring" name = "ring"
@@ -967,14 +1146,14 @@ dependencies = [
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.20.8" version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e"
dependencies = [ dependencies = [
"log", "log",
"ring", "ring",
"rustls-webpki",
"sct", "sct",
"webpki",
] ]
[[package]] [[package]]
@@ -983,9 +1162,25 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
dependencies = [ dependencies = [
"base64 0.21.0", "base64",
] ]
[[package]]
name = "rustls-webpki"
version = "0.100.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
version = "1.1.0" version = "1.1.0"
@@ -1010,21 +1205,35 @@ dependencies = [
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.159" version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
dependencies = [
"serde_derive",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.159" version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.9", "syn 2.0.9",
] ]
[[package]]
name = "serde_json"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "0.6.1" version = "0.6.1"
@@ -1104,11 +1313,23 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]] [[package]]
name = "sqlparser" name = "sqlparser"
version = "0.32.0" version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0366f270dbabb5cc2e4c88427dc4c08bba144f81e32fbd459a013f26a4d16aa0" checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a"
dependencies = [ dependencies = [
"log", "log",
"sqlparser_derive",
]
[[package]]
name = "sqlparser_derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
] ]
[[package]] [[package]]
@@ -1164,6 +1385,26 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "time" name = "time"
version = "0.1.45" version = "0.1.45"
@@ -1223,13 +1464,36 @@ dependencies = [
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.23.4" version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
dependencies = [ dependencies = [
"rustls", "rustls",
"tokio", "tokio",
"webpki", ]
[[package]]
name = "tokio-stream"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
dependencies = [
"async-stream",
"bytes",
"futures-core",
"tokio",
"tokio-stream",
] ]
[[package]] [[package]]
@@ -1294,9 +1558,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"pin-project-lite", "pin-project-lite",
"tracing-attributes",
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-attributes"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "tracing-core" name = "tracing-core"
version = "0.1.30" version = "0.1.30"
@@ -1306,6 +1582,51 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "trust-dns-proto"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26"
dependencies = [
"async-trait",
"cfg-if",
"data-encoding",
"enum-as-inner",
"futures-channel",
"futures-io",
"futures-util",
"idna 0.2.3",
"ipnet",
"lazy_static",
"rand",
"smallvec",
"thiserror",
"tinyvec",
"tokio",
"tracing",
"url",
]
[[package]]
name = "trust-dns-resolver"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe"
dependencies = [
"cfg-if",
"futures-util",
"ipconfig",
"lazy_static",
"lru-cache",
"parking_lot",
"resolv-conf",
"smallvec",
"thiserror",
"tokio",
"tracing",
"trust-dns-proto",
]
[[package]] [[package]]
name = "try-lock" name = "try-lock"
version = "0.2.4" version = "0.2.4"
@@ -1351,6 +1672,17 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "url"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna 0.3.0",
"percent-encoding",
]
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"
@@ -1444,15 +1776,20 @@ dependencies = [
] ]
[[package]] [[package]]
name = "webpki" name = "webpki-roots"
version = "0.22.0" version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
dependencies = [ dependencies = [
"ring", "rustls-webpki",
"untrusted",
] ]
[[package]]
name = "widestring"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"
@@ -1558,3 +1895,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "pgcat" name = "pgcat"
version = "1.0.0" version = "1.0.2-alpha1"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -14,12 +14,12 @@ rand = "0.8"
chrono = "0.4" chrono = "0.4"
sha-1 = "0.10" sha-1 = "0.10"
toml = "0.7" toml = "0.7"
serde = "1" serde = { version = "1", features = ["derive"] }
serde_derive = "1" serde_derive = "1"
regex = "1" regex = "1"
num_cpus = "1" num_cpus = "1"
once_cell = "1" once_cell = "1"
sqlparser = "0.32.0" sqlparser = {version = "0.33", features = ["visitor"] }
log = "0.4" log = "0.4"
arc-swap = "1" arc-swap = "1"
env_logger = "0.10" env_logger = "0.10"
@@ -28,7 +28,7 @@ hmac = "0.12"
sha2 = "0.10" sha2 = "0.10"
base64 = "0.21" base64 = "0.21"
stringprep = "0.1" stringprep = "0.1"
tokio-rustls = "0.23" tokio-rustls = "0.24"
rustls-pemfile = "1" rustls-pemfile = "1"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.11.1", features = ["macros"] } phf = { version = "0.11.1", features = ["macros"] }
@@ -37,8 +37,15 @@ futures = "0.3"
socket2 = { version = "0.4.7", features = ["all"] } socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2" nix = "0.26.2"
atomic_enum = "0.2.0" atomic_enum = "0.2.0"
postgres-protocol = "0.6.4" postgres-protocol = "0.6.5"
fallible-iterator = "0.2" fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
serde_json = "1"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -18,24 +18,56 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. | | Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. | | Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. | | Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. | | SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. | | Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). | | Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. | | Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. | | Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. | | Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. | | Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
## Status ## Status
PgCat is stable and used in production to serve hundreds of thousands of queries per second. Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration. PgCat is stable and used in production to serve hundreds of thousands of queries per second.
| | | <table>
|-|-| <tr>
|<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f"><img src="./images/instacart.webp" height="70" width="auto"></a>|<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second"><img src="./images/postgresml.webp" height="70" width="auto"></a>| <td>
| [Instacart](https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f) | [PostgresML](https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second) | <a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
<img src="./images/instacart.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
<img src="./images/postgresml.webp" height="70" width="auto">
</a>
</td>
<td>
<a href="https://onesignal.com">
<img src="./images/one_signal.webp" height="70" width="auto">
</a>
</td>
</tr>
<tr>
<td>
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
Instacart
</a>
</td>
<td>
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
PostgresML
</a>
</td>
<td>
OneSignal
</td>
</tr>
</table>
Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
## Deployment ## Deployment
@@ -99,7 +131,7 @@ You can open a Docker development environment where you can debug tests easier.
./dev/script/console ./dev/script/console
``` ```
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine. This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the container (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
## Usage ## Usage

View File

@@ -38,9 +38,6 @@ log_client_connections = false
# If we should log client disconnections # If we should log client disconnections
log_client_disconnections = false log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = false
# TLS # TLS
# tls_certificate = "server.cert" # tls_certificate = "server.cert"
# tls_private_key = "server.key" # tls_private_key = "server.key"
@@ -76,7 +73,7 @@ query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write # load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol. # queries. The primary can always be explicitly selected with our custom protocol.
primary_reads_enabled = true primary_reads_enabled = true
# So what if you wanted to implement a different hashing function, # So what if you wanted to implement a different hashing function,

BIN
images/one_signal.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
# How long an idle connection with a server is left open (ms). # How long an idle connection with a server is left open (ms).
idle_timeout = 30000 # milliseconds idle_timeout = 30000 # milliseconds
# Max connection lifetime before it's closed, even if actively used.
server_lifetime = 86400000 # 24 hours
# How long a client is allowed to be idle while in a transaction (ms). # How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout = 0 # milliseconds idle_client_in_transaction_timeout = 0 # milliseconds
@@ -45,7 +48,7 @@ log_client_connections = false
log_client_disconnections = false log_client_disconnections = false
# When set to true, PgCat reloads configs if it detects a change in the config file. # When set to true, PgCat reloads configs if it detects a change in the config file.
autoreload = false autoreload = 15000
# Number of worker threads the Runtime will use (4 by default). # Number of worker threads the Runtime will use (4 by default).
worker_threads = 5 worker_threads = 5
@@ -57,10 +60,16 @@ tcp_keepalives_count = 5
# Number of seconds between keepalive packets. # Number of seconds between keepalive packets.
tcp_keepalives_interval = 5 tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections # Path to TLS Certificate file to use for TLS connections
# tls_certificate = "server.cert" # tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key" # tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = false
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat) # User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -113,6 +122,21 @@ primary_reads_enabled = true
# `sha1`: A hashing function based on SHA1 # `sha1`: A hashing function based on SHA1
sharding_function = "pg_bigint_hash" sharding_function = "pg_bigint_hash"
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
# established using the database configured in the pool. This parameter is inherited by every pool
# and can be redefined in pool configuration.
# auth_query = "SELECT $1"
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_user = "sharding_user"
# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
# This parameter is inherited by every pool and can be redefined in pool configuration.
# auth_query_password = "sharding_user"
# Automatically parse this from queries and route queries to the right shard! # Automatically parse this from queries and route queries to the right shard!
# automatic_sharding_key = "data.id" # automatic_sharding_key = "data.id"
@@ -122,18 +146,78 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool # Connect timeout can be overwritten in the pool
connect_timeout = 3000 connect_timeout = 3000
# When enabled, ip resolutions for server connections specified using hostnames will be cached
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
# old ip connections are closed (gracefully) and new connections will start using new ip.
# dns_cache_enabled = false
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
[plugins]
[plugins.query_logger]
enabled = false
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# User configs are structured as pool.<pool_name>.users.<user_index> # User configs are structured as pool.<pool_name>.users.<user_index>
# This secion holds the credentials for users that may connect to this cluster # This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0] [pools.sharded_db.users.0]
# Postgresql username # PostgreSQL username used to authenticate the user and connect to the server
# if `server_username` is not set.
username = "sharding_user" username = "sharding_user"
# Postgresql password
# PostgreSQL password used to authenticate the user and connect to the server
# if `server_password` is not set.
password = "sharding_user" password = "sharding_user"
pool_mode = "session"
# PostgreSQL username used to connect to the server.
# server_username = "another_user"
# PostgreSQL password used to connect to the server.
# server_password = "another_password"
# Maximum number of server connections that can be established for this user # Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster # The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users. # is the sum of pool_size across all users.
pool_size = 9 pool_size = 9
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way. # Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
# 0 means it is disabled. # 0 means it is disabled.
statement_timeout = 0 statement_timeout = 0
@@ -178,6 +262,8 @@ sharding_function = "pg_bigint_hash"
username = "simple_user" username = "simple_user"
password = "simple_user" password = "simple_user"
pool_size = 5 pool_size = 5
min_pool_size = 3
server_lifetime = 60000
statement_timeout = 0 statement_timeout = 0
[pools.simple_db.shards.0] [pools.simple_db.shards.0]

View File

@@ -12,9 +12,9 @@ use tokio::time::Instant;
use crate::config::{get_config, reload_config, VERSION}; use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool}; use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState}; use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
use crate::ClientServerMap;
pub fn generate_server_info_for_admin() -> BytesMut { pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();

View File

@@ -1,4 +1,5 @@
use crate::errors::Error; use crate::errors::Error;
use crate::pool::ConnectionPool;
use crate::server::Server; use crate::server::Server;
use log::debug; use log::debug;
@@ -71,25 +72,36 @@ impl AuthPassthrough {
let auth_user = crate::config::User { let auth_user = crate::config::User {
username: self.user.clone(), username: self.user.clone(),
password: Some(self.password.clone()), password: Some(self.password.clone()),
server_username: None,
server_password: None,
pool_size: 1, pool_size: 1,
statement_timeout: 0, statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
min_pool_size: None,
}; };
let user = &address.username; let user = &address.username;
debug!("Connecting to server to obtain auth hashes."); debug!("Connecting to server to obtain auth hashes");
let auth_query = self.query.replace("$1", user); let auth_query = self.query.replace("$1", user);
match Server::exec_simple_query(address, &auth_user, &auth_query).await { match Server::exec_simple_query(address, &auth_user, &auth_query).await {
Ok(password_data) => { Ok(password_data) => {
if password_data.len() == 2 && password_data.first().unwrap() == user { if password_data.len() == 2 && password_data.first().unwrap() == user {
if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") { if let Some(stripped_hash) = password_data
Ok(stripped_hash.to_string()) .last()
} .unwrap()
else { .to_string()
Err(Error::AuthPassthroughError( .strip_prefix("md5") {
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(), Ok(stripped_hash.to_string())
)) }
} else {
Err(Error::AuthPassthroughError(
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
))
}
} else { } else {
Err(Error::AuthPassthroughError( Err(Error::AuthPassthroughError(
"Data obtained from query does not follow the scheme 'user','hash'." "Data obtained from query does not follow the scheme 'user','hash'."
@@ -98,10 +110,25 @@ impl AuthPassthrough {
} }
} }
Err(err) => { Err(err) => {
Err(Error::AuthPassthroughError( Err(Error::AuthPassthroughError(
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
user, err))) user, err))
)
} }
} }
} }
} }
pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}

View File

@@ -1,4 +1,4 @@
use crate::errors::Error; use crate::errors::{ClientIdentifier, Error};
use crate::pool::BanReason; use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server. /// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
@@ -12,10 +12,11 @@ use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::auth_passthrough::AuthPassthrough; use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::constants::*; use crate::constants::*;
use crate::messages::*; use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::Server;
@@ -202,7 +203,7 @@ pub async fn client_entrypoint(
// Client probably disconnected rejecting our plain text connection. // Client probably disconnected rejecting our plain text connection.
Ok((ClientConnectionType::Tls, _)) Ok((ClientConnectionType::Tls, _))
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError( | Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
format!("Bad postgres client (plain)"), "Bad postgres client (plain)".into(),
)), )),
Err(err) => Err(err), Err(err) => Err(err),
@@ -369,28 +370,14 @@ pub async fn startup_tls(
} }
// Bad Postgres client. // Bad Postgres client.
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err( Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => {
Error::ProtocolSyncError(format!("Bad postgres client (tls)")), Err(Error::ProtocolSyncError("Bad postgres client (tls)".into()))
), }
Err(err) => Err(err), Err(err) => Err(err),
} }
} }
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
impl<S, T> Client<S, T> impl<S, T> Client<S, T>
where where
S: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncRead + std::marker::Unpin,
@@ -418,7 +405,7 @@ where
Some(user) => user, Some(user) => user,
None => { None => {
return Err(Error::ClientError( return Err(Error::ClientError(
"Missing user parameter on client startup".to_string(), "Missing user parameter on client startup".into(),
)) ))
} }
}; };
@@ -433,6 +420,8 @@ where
None => "pgcat", None => "pgcat",
}; };
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
.filter(|db| *db == pool_name) .filter(|db| *db == pool_name)
@@ -463,7 +452,12 @@ where
let code = match read.read_u8().await { let code = match read.read_u8().await {
Ok(p) => p, Ok(p) => p,
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
}; };
// PasswordMessage // PasswordMessage
@@ -476,19 +470,30 @@ where
let len = match read.read_i32().await { let len = match read.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
}; };
let mut password_response = vec![0u8; (len - 4) as usize]; let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await { match read.read_exact(&mut password_response).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
}; };
// Authenticate admin user. // Authenticate admin user.
let (transaction_mode, server_info) = if admin { let (transaction_mode, server_info) = if admin {
let config = get_config(); let config = get_config();
// Compare server and client hashes. // Compare server and client hashes.
let password_hash = md5_hash_password( let password_hash = md5_hash_password(
&config.general.admin_username, &config.general.admin_username,
@@ -497,10 +502,12 @@ where
); );
if password_hash != password_response { if password_hash != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
warn!("{}", error);
wrong_password(&mut write, username).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); return Err(error);
} }
(false, generate_server_info_for_admin()) (false, generate_server_info_for_admin())
@@ -519,7 +526,10 @@ where
) )
.await?; .await?;
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); return Err(Error::ClientGeneralError(
"Invalid pool name".into(),
client_identifier,
));
} }
}; };
@@ -530,16 +540,24 @@ where
Some(md5_hash_password(username, password, &salt)) Some(md5_hash_password(username, password, &salt))
} else { } else {
if !get_config().is_auth_query_configured() { if !get_config().is_auth_query_configured() {
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
} }
let mut hash = (*pool.auth_hash.read()).clone(); let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() { if hash.is_none() {
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name); warn!(
"Query auth configured \
but no hash password found \
for pool {}. Will try to refetch it.",
pool_name
);
match refetch_auth_hash(&pool).await { match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => { Ok(fetched_hash) => {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); warn!("Password for {}, obtained. Updating.", client_identifier);
{ {
let mut pool_auth_hash = pool.auth_hash.write(); let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone()); *pool_auth_hash = Some(fetched_hash.clone());
@@ -547,16 +565,14 @@ where
hash = Some(fetched_hash); hash = Some(fetched_hash);
} }
Err(err) => { Err(err) => {
return Err( wrong_password(&mut write, username).await?;
Error::ClientError(
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", return Err(Error::ClientAuthPassthroughError(
username, err.to_string(),
pool_name, client_identifier,
application_name, ));
err)
)
);
} }
} }
}; };
@@ -570,20 +586,39 @@ where
// //
// @TODO: we could end up fetching again the same password twice (see above). // @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response { if password_hash.unwrap() != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); warn!(
let fetched_hash = refetch_auth_hash(&pool).await?; "Invalid password {}, will try to refetch it.",
client_identifier
);
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible. // Ok password changed in server an auth is possible.
if new_password_hash == password_response { if new_password_hash == password_response {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name); warn!(
"Password for {}, changed in server. Updating.",
client_identifier
);
{ {
let mut pool_auth_hash = pool.auth_hash.write(); let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash); *pool_auth_hash = Some(fetched_hash);
} }
} else { } else {
wrong_password(&mut write, username).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); return Err(Error::ClientGeneralError(
"Invalid password".into(),
client_identifier,
));
} }
} }
@@ -731,6 +766,9 @@ where
self.stats.register(self.stats.clone()); self.stats.register(self.stats.clone());
// Result returned by one of the plugins.
let mut plugin_output = None;
// Our custom protocol loop. // Our custom protocol loop.
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol. // or issue commands for our sharding and server selection protocol.
@@ -753,9 +791,9 @@ where
&mut self.write, &mut self.write,
"terminating connection due to administrator command" "terminating connection due to administrator command"
).await?; ).await?;
self.stats.disconnect();
return Ok(()) self.stats.disconnect();
return Ok(());
} }
// Admin clients ignore shutdown. // Admin clients ignore shutdown.
@@ -781,7 +819,25 @@ where
'Q' => { 'Q' => {
if query_router.query_parser_enabled() { if query_router.query_parser_enabled() {
query_router.infer(&message); if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
} }
} }
@@ -789,7 +845,13 @@ where
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
if query_router.query_parser_enabled() { if query_router.query_parser_enabled() {
query_router.infer(&message); if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
let _ = query_router.infer(&ast);
}
} }
continue; continue;
@@ -823,6 +885,18 @@ where
continue; continue;
} }
// Check on plugin results.
match plugin_output {
Some(PluginOutput::Deny(error)) => {
self.buffer.clear();
error_response(&mut self.write, &error).await?;
plugin_output = None;
continue;
}
_ => (),
};
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config
// when starting a query. // when starting a query.
@@ -928,11 +1002,26 @@ where
error!("Got Sync message but failed to get a connection from the pool"); error!("Got Sync message but failed to get a connection from the pool");
self.buffer.clear(); self.buffer.clear();
} }
error_response(&mut self.write, "could not get connection from the pool") error_response(&mut self.write, "could not get connection from the pool")
.await?; .await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}", error!(
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err); "Could not get connection from pool: \
{{ \
pool_name: {:?}, \
username: {:?}, \
shard: {:?}, \
role: \"{:?}\", \
error: \"{:?}\" \
}}",
self.pool_name,
self.username,
query_router.shard(),
query_router.role(),
err
);
continue; continue;
} }
}; };
@@ -999,11 +1088,25 @@ where
Err(_) => { Err(_) => {
// Client idle in transaction timeout // Client idle in transaction timeout
error_response(&mut self.write, "idle transaction timeout").await?; error_response(&mut self.write, "idle transaction timeout").await?;
error!("Client idle in transaction timeout: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\"}}", self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role()); error!(
"Client idle in transaction timeout: \
{{ \
pool_name: {}, \
username: {}, \
shard: {}, \
role: \"{:?}\" \
}}",
self.pool_name,
self.username,
query_router.shard(),
query_router.role()
);
break; break;
} }
} }
} }
Some(message) => { Some(message) => {
initial_message = None; initial_message = None;
message message
@@ -1022,6 +1125,27 @@ where
match code { match code {
// Query // Query
'Q' => { 'Q' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
}
debug!("Sending query to server"); debug!("Sending query to server");
self.send_and_receive_loop( self.send_and_receive_loop(
@@ -1061,6 +1185,14 @@ where
// Parse // Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => { 'P' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
}
}
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
} }
@@ -1076,6 +1208,11 @@ where
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
} }
// Close the prepared statement.
'C' => {
self.buffer.put(&message[..]);
}
// Execute // Execute
// Execute a prepared statement prepared in `P` and bound in `B`. // Execute a prepared statement prepared in `P` and bound in `B`.
'E' => { 'E' => {
@@ -1087,6 +1224,24 @@ where
'S' => { 'S' => {
debug!("Sending query to server"); debug!("Sending query to server");
match plugin_output {
Some(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
Some(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
_ => (),
};
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

View File

@@ -12,6 +12,7 @@ use std::sync::Arc;
use tokio::fs::File; use tokio::fs::File;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use crate::dns_cache::CachedResolver;
use crate::errors::Error; use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction; use crate::sharding::ShardingFunction;
@@ -178,7 +179,12 @@ impl Address {
pub struct User { pub struct User {
pub username: String, pub username: String,
pub password: Option<String>, pub password: Option<String>,
pub server_username: Option<String>,
pub server_password: Option<String>,
pub pool_size: u32, pub pool_size: u32,
pub min_pool_size: Option<u32>,
pub pool_mode: Option<PoolMode>,
pub server_lifetime: Option<u64>,
#[serde(default)] // 0 #[serde(default)] // 0
pub statement_timeout: u64, pub statement_timeout: u64,
} }
@@ -188,12 +194,37 @@ impl Default for User {
User { User {
username: String::from("postgres"), username: String::from("postgres"),
password: None, password: None,
server_username: None,
server_password: None,
pool_size: 15, pool_size: 15,
min_pool_size: None,
statement_timeout: 0, statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
} }
} }
} }
impl User {
fn validate(&self) -> Result<(), Error> {
match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
}
None => (),
};
Ok(())
}
}
/// General configuration. /// General configuration.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct General { pub struct General {
@@ -201,7 +232,7 @@ pub struct General {
pub host: String, pub host: String,
#[serde(default = "General::default_port")] #[serde(default = "General::default_port")]
pub port: i16, pub port: u16,
pub enable_prometheus_exporter: Option<bool>, pub enable_prometheus_exporter: Option<bool>,
pub prometheus_exporter_port: i16, pub prometheus_exporter_port: i16,
@@ -225,6 +256,12 @@ pub struct General {
#[serde(default)] // False #[serde(default)] // False
pub log_client_disconnections: bool, pub log_client_disconnections: bool,
#[serde(default)] // False
pub dns_cache_enabled: bool,
#[serde(default = "General::default_dns_max_ttl")]
pub dns_max_ttl: u64,
#[serde(default = "General::default_shutdown_timeout")] #[serde(default = "General::default_shutdown_timeout")]
pub shutdown_timeout: u64, pub shutdown_timeout: u64,
@@ -240,17 +277,28 @@ pub struct General {
#[serde(default = "General::default_idle_client_in_transaction_timeout")] #[serde(default = "General::default_idle_client_in_transaction_timeout")]
pub idle_client_in_transaction_timeout: u64, pub idle_client_in_transaction_timeout: u64,
#[serde(default = "General::default_server_lifetime")]
pub server_lifetime: u64,
#[serde(default = "General::default_worker_threads")] #[serde(default = "General::default_worker_threads")]
pub worker_threads: usize, pub worker_threads: usize,
#[serde(default)] // False #[serde(default)] // None
pub autoreload: bool, pub autoreload: Option<u64>,
pub tls_certificate: Option<String>, pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>, pub tls_private_key: Option<String>,
#[serde(default)] // false
pub server_tls: bool,
#[serde(default)] // false
pub verify_server_certificate: bool,
pub admin_username: String, pub admin_username: String,
pub admin_password: String, pub admin_password: String,
// Support for auth query
pub auth_query: Option<String>, pub auth_query: Option<String>,
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>, pub auth_query_password: Option<String>,
@@ -261,17 +309,21 @@ impl General {
"0.0.0.0".into() "0.0.0.0".into()
} }
pub fn default_port() -> i16 { pub fn default_port() -> u16 {
5432 5432
} }
pub fn default_server_lifetime() -> u64 {
1000 * 60 * 60 * 24 // 24 hours
}
pub fn default_connect_timeout() -> u64 { pub fn default_connect_timeout() -> u64 {
1000 1000
} }
// These keepalive defaults should detect a dead connection within 30 seconds. // These keepalive defaults should detect a dead connection within 30 seconds.
// Tokio defaults to disabling keepalives which keeps dead connections around indefinitely. // Tokio defaults to disabling keepalives which keeps dead connections around indefinitely.
// This can lead to permenant server pool exhaustion // This can lead to permanent server pool exhaustion
pub fn default_tcp_keepalives_idle() -> u64 { pub fn default_tcp_keepalives_idle() -> u64 {
5 // 5 seconds 5 // 5 seconds
} }
@@ -292,6 +344,10 @@ impl General {
60000 60000
} }
pub fn default_dns_max_ttl() -> u64 {
30
}
pub fn default_healthcheck_timeout() -> u64 { pub fn default_healthcheck_timeout() -> u64 {
1000 1000
} }
@@ -333,14 +389,19 @@ impl Default for General {
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
log_client_connections: false, log_client_connections: false,
log_client_disconnections: false, log_client_disconnections: false,
autoreload: false, autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
tls_certificate: None, tls_certificate: None,
tls_private_key: None, tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"), admin_username: String::from("admin"),
admin_password: String::from("admin"), admin_password: String::from("admin"),
auth_query: None, auth_query: None,
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours,
} }
} }
} }
@@ -356,6 +417,7 @@ pub enum PoolMode {
#[serde(alias = "session", alias = "Session")] #[serde(alias = "session", alias = "Session")]
Session, Session,
} }
impl ToString for PoolMode { impl ToString for PoolMode {
fn to_string(&self) -> String { fn to_string(&self) -> String {
match *self { match *self {
@@ -404,6 +466,8 @@ pub struct Pool {
pub idle_timeout: Option<u64>, pub idle_timeout: Option<u64>,
pub server_lifetime: Option<u64>,
pub sharding_function: ShardingFunction, pub sharding_function: ShardingFunction,
#[serde(default = "Pool::default_automatic_sharding_key")] #[serde(default = "Pool::default_automatic_sharding_key")]
@@ -419,7 +483,7 @@ pub struct Pool {
pub shards: BTreeMap<String, Shard>, pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>, pub users: BTreeMap<String, User>,
// Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it // Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
// incompatible to have simple fields in TOML after complex objects. See // incompatible to have simple fields in TOML after complex objects. See
// https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903 // https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903
} }
@@ -508,6 +572,10 @@ impl Pool {
None => None, None => None,
}; };
for (_, user) in &self.users {
user.validate()?;
}
Ok(()) Ok(())
} }
} }
@@ -532,6 +600,7 @@ impl Default for Pool {
auth_query: None, auth_query: None,
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: None,
} }
} }
} }
@@ -581,7 +650,7 @@ impl Shard {
if primary_count > 1 { if primary_count > 1 {
error!( error!(
"Shard {} has more than on primary configured", "Shard {} has more than one primary configured",
self.database self.database
); );
return Err(Error::BadConfig); return Err(Error::BadConfig);
@@ -610,6 +679,55 @@ impl Default for Shard {
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
pub enabled: bool,
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
pub result: Vec<Vec<String>>,
}
impl Query {
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
}
}
}
}
/// Configuration wrapper. /// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config { pub struct Config {
@@ -628,6 +746,7 @@ pub struct Config {
pub path: String, pub path: String,
pub general: General, pub general: General,
pub plugins: Option<Plugins>,
pub pools: HashMap<String, Pool>, pub pools: HashMap<String, Pool>,
} }
@@ -665,6 +784,7 @@ impl Default for Config {
path: Self::default_path(), path: Self::default_path(),
general: General::default(), general: General::default(),
pools: HashMap::default(), pools: HashMap::default(),
plugins: None,
} }
} }
} }
@@ -784,6 +904,10 @@ impl Config {
); );
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout); info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay); info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
info!(
"Default max server lifetime: {}ms",
self.general.server_lifetime
);
match self.general.tls_certificate.clone() { match self.general.tls_certificate.clone() {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);
@@ -802,6 +926,11 @@ impl Config {
info!("TLS support is disabled"); info!("TLS support is disabled");
} }
}; };
info!("Server TLS enabled: {}", self.general.server_tls);
info!(
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
for (pool_name, pool_config) in &self.pools { for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?) // TODO: Make this output prettier (maybe a table?)
@@ -816,8 +945,9 @@ impl Config {
.to_string() .to_string()
); );
info!( info!(
"[pool: {}] Pool mode: {:?}", "[pool: {}] Default pool mode: {}",
pool_name, pool_config.pool_mode pool_name,
pool_config.pool_mode.to_string()
); );
info!( info!(
"[pool: {}] Load Balancing mode: {:?}", "[pool: {}] Load Balancing mode: {:?}",
@@ -859,16 +989,48 @@ impl Config {
pool_name, pool_name,
pool_config.users.len() pool_config.users.len()
); );
info!(
"[pool: {}] Max server lifetime: {}",
pool_name,
match pool_config.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
for user in &pool_config.users { for user in &pool_config.users {
info!( info!(
"[pool: {}][user: {}] Pool size: {}", "[pool: {}][user: {}] Pool size: {}",
pool_name, user.1.username, user.1.pool_size, pool_name, user.1.username, user.1.pool_size,
); );
info!(
"[pool: {}][user: {}] Minimum pool size: {}",
pool_name,
user.1.username,
user.1.min_pool_size.unwrap_or(0)
);
info!( info!(
"[pool: {}][user: {}] Statement timeout: {}", "[pool: {}][user: {}] Statement timeout: {}",
pool_name, user.1.username, user.1.statement_timeout pool_name, user.1.username, user.1.statement_timeout
) );
info!(
"[pool: {}][user: {}] Pool mode: {}",
pool_name,
user.1.username,
match user.1.pool_mode {
Some(pool_mode) => pool_mode.to_string(),
None => pool_config.pool_mode.to_string(),
}
);
info!(
"[pool: {}][user: {}] Max server lifetime: {}",
pool_name,
user.1.username,
match user.1.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
} }
} }
} }
@@ -879,7 +1041,13 @@ impl Config {
&& (self.general.auth_query_user.is_none() && (self.general.auth_query_user.is_none()
|| self.general.auth_query_password.is_none()) || self.general.auth_query_password.is_none())
{ {
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`"); error!(
"If auth_query is specified, \
you need to provide a value \
for `auth_query_user`, \
`auth_query_password`"
);
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
@@ -887,7 +1055,14 @@ impl Config {
if pool.auth_query.is_some() if pool.auth_query.is_some()
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none()) && (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
{ {
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name); error!(
"Error in pool {{ {} }}. \
If auth_query is specified, you need \
to provide a value for `auth_query_user`, \
`auth_query_password`",
name
);
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
@@ -897,7 +1072,13 @@ impl Config {
|| pool.auth_query_user.is_none()) || pool.auth_query_user.is_none())
&& user_data.password.is_none() && user_data.password.is_none()
{ {
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name); error!(
"Error in pool {{ {} }}. \
You have to specify a user password \
for every pool if auth_query is not specified",
name
);
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
} }
@@ -995,6 +1176,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> { pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config(); let old_config = get_config();
match parse(&old_config.path).await { match parse(&old_config.path).await {
Ok(()) => (), Ok(()) => (),
Err(err) => { Err(err) => {
@@ -1002,14 +1184,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
}; };
let new_config = get_config(); let new_config = get_config();
if old_config.pools != new_config.pools { match CachedResolver::from_config().await {
info!("Pool configuration changed"); Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};
if old_config != new_config {
info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?; ConnectionPool::from_config(client_server_map).await?;
Ok(true) Ok(true)
} else if old_config != new_config {
Ok(true)
} else { } else {
Ok(false) Ok(false)
} }

410
src/dns_cache.rs Normal file
View File

@@ -0,0 +1,410 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;
/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
// Ip addressed are returned as a set of addresses
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
}
impl AddrSet {
fn new() -> AddrSet {
AddrSet {
set: HashSet::new(),
}
}
}
impl From<LookupIp> for AddrSet {
fn from(lookup_ip: LookupIp) -> Self {
let mut addr_set = AddrSet::new();
for address in lookup_ip.iter() {
addr_set.set.insert(address);
}
addr_set
}
}
///
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
///
/// The system works as follows:
///
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
/// cache is refreshed.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
/// # })
/// ```
///
/// // Now the ip resolution is stored in local cache and subsequent
/// // calls will be returned from cache. Also, the cache is refreshed
/// // and updated every 10 seconds.
///
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,
// This is the hash that contains the hash.
data: Option<RwLock<HashMap<String, AddrSet>>>,
// The resolver to be used for DNS queries.
resolver: Option<TokioAsyncResolver>,
// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
dns_max_ttl: u64,
/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}
impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}
impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}
impl CachedResolver {
///
/// Returns a new Arc<CachedResolver> based on passed configuration.
/// It also starts the loop that will refresh cache entries.
///
/// # Arguments:
///
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// # })
/// ```
///
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}
Ok(instance)
}
pub fn enabled(&self) -> bool {
self.config.enabled
}
// Schedules the refresher
async fn refresh_dns_entries_loop(&self) {
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
let interval = Duration::from_secs(self.config.dns_max_ttl);
loop {
debug!("Begin refreshing cached DNS addresses.");
// To minimize the time we hold the lock, we first create
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}
for hostname in hostnames.iter() {
let addrset = self
.fetch_from_cache(hostname.as_str())
.expect("Could not obtain expected address from cache, this should not happen");
match resolver.lookup_ip(hostname).await {
Ok(lookup_ip) => {
let new_addrset = AddrSet::from(lookup_ip);
debug!(
"Obtained address for host ({}) -> ({:?})",
hostname, new_addrset
);
if addrset != new_addrset {
debug!(
"Addr changed from {:?} to {:?} updating cache.",
addrset, new_addrset
);
self.store_in_cache(hostname, new_addrset);
}
}
Err(err) => {
error!(
"There was an error trying to resolv {}: ({}).",
hostname, err
);
}
}
}
debug!("Finished refreshing cached DNS addresses.");
sleep(interval).await;
}
}
/// Returns a `AddrSet` given the specified hostname.
///
/// This method first tries to fetch the value from the cache, if it misses
/// then it is resolved and stored in the cache. TTL from records is ignored.
///
/// # Arguments
///
/// * `host` - A string slice referencing the hostname to be resolved.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let response = resolver.lookup_ip("www.google.com.");
/// # })
/// ```
///
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
debug!("Lookup up {} in cache", host);
match self.fetch_from_cache(host) {
Some(addr_set) => {
debug!("Cache hit!");
Ok(addr_set)
}
None => {
debug!("Not found, executing a dns query!");
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
//
// Returns true if the stored host resolution differs from the AddrSet passed.
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
return fetched_addr_set != *addr_set;
}
false
}
// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
match new_resolver {
Ok(ok) => {
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
}
} else {
Ok(())
}
}
// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;
#[tokio::test]
async fn new() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await;
assert!(resolver.is_ok());
}
#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}
#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
}
#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
}
#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
}
#[tokio::test]
// Ok, this test is based on the fact that google does DNS RR
// and does not responds with every available ip everytime, so
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
let resolver_for_refresher = resolver.clone();
let _thread_handle = tokio::task::spawn(async move {
resolver_for_refresher.refresh_dns_entries_loop().await;
});
assert!(!resolver.has_changed(hostname, &addr_set));
}
}

View File

@@ -1,20 +1,123 @@
/// Errors. //! Errors.
/// Various errors. /// Various errors.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Clone)]
pub enum Error { pub enum Error {
SocketError(String), SocketError(String),
ClientSocketError(String, ClientIdentifier),
ClientGeneralError(String, ClientIdentifier),
ClientAuthImpossible(String),
ClientAuthPassthroughError(String, ClientIdentifier),
ClientBadStartup, ClientBadStartup,
ProtocolSyncError(String), ProtocolSyncError(String),
BadQuery(String), BadQuery(String),
ServerError, ServerError,
ServerStartupError(String, ServerIdentifier),
ServerAuthError(String, ServerIdentifier),
BadConfig, BadConfig,
AllServersDown, AllServersDown,
ClientError(String), ClientError(String),
TlsError, TlsError,
StatementTimeout, StatementTimeout,
DNSCachedError(String),
ShuttingDown, ShuttingDown,
ParseBytesError(String), ParseBytesError(String),
AuthError(String), AuthError(String),
AuthPassthroughError(String), AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
}
#[derive(Clone, PartialEq, Debug)]
pub struct ClientIdentifier {
pub application_name: String,
pub username: String,
pub pool_name: String,
}
impl ClientIdentifier {
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
ClientIdentifier {
application_name: application_name.into(),
username: username.into(),
pool_name: pool_name.into(),
}
}
}
impl std::fmt::Display for ClientIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ application_name: {}, username: {}, pool_name: {} }}",
self.application_name, self.username, self.pool_name
)
}
}
#[derive(Clone, PartialEq, Debug)]
pub struct ServerIdentifier {
pub username: String,
pub database: String,
}
impl ServerIdentifier {
pub fn new(username: &str, database: &str) -> ServerIdentifier {
ServerIdentifier {
username: username.into(),
database: database.into(),
}
}
}
impl std::fmt::Display for ServerIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{{ username: {}, database: {} }}",
self.username, self.database
)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self {
&Error::ClientSocketError(error, client_identifier) => write!(
f,
"Error reading {} from client {}",
error, client_identifier
),
&Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{} {}", error, client_identifier)
}
&Error::ClientAuthImpossible(username) => write!(
f,
"Client auth not possible, \
no cleartext password set for username: {} \
in config and auth passthrough (query_auth) \
is not set up.",
username
),
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
f,
"No cleartext password set, \
and no auth passthrough could not \
obtain the hash from server for {}, \
the error was: {}",
client_identifier, error
),
&Error::ServerStartupError(error, server_identifier) => write!(
f,
"Error reading {} on server startup {}",
error, server_identifier,
),
&Error::ServerAuthError(error, server_identifier) => {
write!(f, "{} for {}", error, server_identifier,)
}
// The rest can use Debug.
err => write!(f, "{:?}", err),
}
}
} }

View File

@@ -1,11 +1,17 @@
pub mod admin;
pub mod auth_passthrough; pub mod auth_passthrough;
pub mod client;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod dns_cache;
pub mod errors; pub mod errors;
pub mod messages; pub mod messages;
pub mod mirrors; pub mod mirrors;
pub mod multi_logger; pub mod multi_logger;
pub mod plugins;
pub mod pool; pub mod pool;
pub mod prometheus;
pub mod query_router;
pub mod scram; pub mod scram;
pub mod server; pub mod server;
pub mod sharding; pub mod sharding;

View File

@@ -36,6 +36,7 @@ extern crate sqlparser;
extern crate tokio; extern crate tokio;
extern crate tokio_rustls; extern crate tokio_rustls;
extern crate toml; extern crate toml;
extern crate trust_dns_resolver;
#[cfg(not(target_env = "msvc"))] #[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc; use jemallocator::Jemalloc;
@@ -60,35 +61,19 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
mod admin; use pgcat::config::{get_config, reload_config, VERSION};
mod auth_passthrough; use pgcat::dns_cache;
mod client; use pgcat::messages::configure_socket;
mod config; use pgcat::pool::{ClientServerMap, ConnectionPool};
mod constants; use pgcat::prometheus::start_metric_server;
mod errors; use pgcat::stats::{Collector, Reporter, REPORTER};
mod messages;
mod mirrors;
mod multi_logger;
mod pool;
mod prometheus;
mod query_router;
mod scram;
mod server;
mod sharding;
mod stats;
mod tls;
use crate::config::{get_config, reload_config, VERSION};
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER};
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
multi_logger::MultiLogger::init().unwrap(); pgcat::multi_logger::MultiLogger::init().unwrap();
info!("Welcome to PgCat! Meow. (Version {})", VERSION); info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() { if !pgcat::query_router::QueryRouter::setup() {
error!("Could not setup query router"); error!("Could not setup query router");
std::process::exit(exitcode::CONFIG); std::process::exit(exitcode::CONFIG);
} }
@@ -106,7 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let runtime = Builder::new_multi_thread().worker_threads(1).build()?; let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
runtime.block_on(async { runtime.block_on(async {
match config::parse(&config_file).await { match pgcat::config::parse(&config_file).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
error!("Config parse error: {:?}", err); error!("Config parse error: {:?}", err);
@@ -165,6 +150,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Statistics reporting. // Statistics reporting.
REPORTER.store(Arc::new(Reporter::default())); REPORTER.store(Arc::new(Reporter::default()));
// Starts (if enabled) dns cache before pools initialization
match dns_cache::CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache initialization error: {:?}", err),
};
// Connection pool that allows to query all shards and replicas. // Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await { match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (), Ok(_) => (),
@@ -179,16 +170,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
stats_collector.collect().await; stats_collector.collect().await;
}); });
info!("Config autoreloader: {}", config.general.autoreload); info!("Config autoreloader: {}", match config.general.autoreload {
Some(interval) => format!("{} ms", interval),
None => "disabled".into(),
});
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); if let Some(interval) = config.general.autoreload {
let autoreload_client_server_map = client_server_map.clone(); let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval));
let autoreload_client_server_map = client_server_map.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
loop { loop {
autoreload_interval.tick().await; autoreload_interval.tick().await;
if config.general.autoreload { debug!("Automatically reloading config");
info!("Automatically reloading config");
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await { if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
if changed { if changed {
@@ -196,8 +190,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
}; };
} }
} });
}); };
#[cfg(windows)] #[cfg(windows)]
let mut term_signal = win_signal::ctrl_close().unwrap(); let mut term_signal = win_signal::ctrl_close().unwrap();
@@ -282,18 +278,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let drain_tx = drain_tx.clone(); let drain_tx = drain_tx.clone();
let client_server_map = client_server_map.clone(); let client_server_map = client_server_map.clone();
let tls_certificate = config.general.tls_certificate.clone(); let tls_certificate = get_config().general.tls_certificate.clone();
configure_socket(&socket);
tokio::task::spawn(async move { tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc(); let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint( match pgcat::client::client_entrypoint(
socket, socket,
client_server_map, client_server_map,
shutdown_rx, shutdown_rx,
drain_tx, drain_tx,
admin_only, admin_only,
tls_certificate.clone(), tls_certificate,
config.general.log_client_connections, config.general.log_client_connections,
) )
.await .await
@@ -301,7 +299,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(()) => { Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start; let duration = chrono::offset::Utc::now().naive_utc() - start;
if config.general.log_client_disconnections { if get_config().general.log_client_disconnections {
info!( info!(
"Client {:?} disconnected, session duration: {}", "Client {:?} disconnected, session duration: {}",
addr, addr,
@@ -318,7 +316,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Err(err) => { Err(err) => {
match err { match err {
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
_ => warn!("Client disconnected with error {:?}", err), _ => warn!("Client disconnected with error {:?}", err),
} }

View File

@@ -20,6 +20,10 @@ pub enum DataType {
Text, Text,
Int4, Int4,
Numeric, Numeric,
Bool,
Oid,
AnyArray,
Any,
} }
impl From<&DataType> for i32 { impl From<&DataType> for i32 {
@@ -28,6 +32,10 @@ impl From<&DataType> for i32 {
DataType::Text => 25, DataType::Text => 25,
DataType::Int4 => 23, DataType::Int4 => 23,
DataType::Numeric => 1700, DataType::Numeric => 1700,
DataType::Bool => 16,
DataType::Oid => 26,
DataType::AnyArray => 2277,
DataType::Any => 2276,
} }
} }
} }
@@ -116,7 +124,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want. /// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25); let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number bytes.put_i32(196608); // Protocol number
@@ -150,6 +161,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
} }
} }
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format. /// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> { pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new(); let mut result = HashMap::new();
@@ -404,7 +430,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
let mut res = BytesMut::new(); let mut res = BytesMut::new();
let mut row_desc = BytesMut::new(); let mut row_desc = BytesMut::new();
// how many colums we are storing // how many columns we are storing
row_desc.put_i16(columns.len() as i16); row_desc.put_i16(columns.len() as i16);
for (name, data_type) in columns { for (name, data_type) in columns {
@@ -425,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
DataType::Text => -1, DataType::Text => -1,
DataType::Int4 => 4, DataType::Int4 => 4,
DataType::Numeric => -1, DataType::Numeric => -1,
DataType::Bool => 1,
DataType::Oid => 4,
DataType::AnyArray => -1,
DataType::Any => -1,
}; };
row_desc.put_i16(type_size); row_desc.put_i16(type_size);
@@ -463,6 +493,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
res res
} }
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
if let Some(column) = column {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1 as i32);
}
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
/// Create a CommandComplete message. /// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut { pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
@@ -505,6 +558,29 @@ where
} }
} }
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket. /// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error> pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where where

View File

@@ -17,7 +17,7 @@ use log::{Level, Log, Metadata, Record, SetLoggerError};
// //
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything // So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration // but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, erros will go to stderr, warns and infos to stdout and debugs to stderr. // where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
// //
pub struct MultiLogger { pub struct MultiLogger {
stderr_logger: env_logger::Logger, stderr_logger: env_logger::Logger,

288
src/plugins/intercept.rs Normal file
View File

@@ -0,0 +1,288 @@
//! The intercept plugin.
//!
//! It intercepts queries and returns fake results.
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlparser::ast::Statement;
use std::collections::HashMap;
use log::{debug, info};
use std::sync::Arc;
use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter,
};
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}
CONFIG.store(Arc::new(config));
info!("Intercepting {} queries", intercept_config.queries.len());
}
pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}
// TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)]
pub struct Rule {
query: String,
schema: Vec<Column>,
result: Vec<Vec<String>>,
}
#[derive(Serialize, Deserialize)]
pub struct Column {
name: String,
data_type: String,
}
/// The intercept plugin.
pub struct Intercept;
#[async_trait]
impl Plugin for Intercept {
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if ast.is_empty() {
return Ok(PluginOutput::Allow);
}
let mut result = BytesMut::new();
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
&query_router.pool_settings().db,
&query_router.pool_settings().user.username,
)) {
Some(query_map) => query_map.clone(),
None => return Ok(PluginOutput::Allow),
};
for q in ast {
// Normalization
let q = q.to_string().to_ascii_lowercase();
for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);
let rd = target
.schema
.iter()
.map(|row| {
let name = &row[0];
let data_type = &row[1];
(
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
"bool" => DataType::Bool,
"int4" => DataType::Int4,
_ => DataType::Any,
},
)
})
.collect::<Vec<(&str, DataType)>>();
result.put(row_description(&rd));
target.result.iter().for_each(|row| {
let row = row
.iter()
.map(|s| {
let s = s.as_str().to_string();
if s == "" {
None
} else {
Some(s)
}
})
.collect::<Vec<Option<String>>>();
result.put(data_row_nullable(&row));
});
result.put(command_complete("SELECT"));
}
}
}
if !result.is_empty() {
result.put_u8(b'Z');
result.put_i32(5);
result.put_u8(b'I');
return Ok(PluginOutput::Intercept(result));
} else {
Ok(PluginOutput::Allow)
}
}
}
/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
"query": "select current_database() as a, current_schemas(false) as b",
"schema": [
{
"name": "a",
"data_type": "text",
},
{
"name": "b",
"data_type": "anyarray",
},
],
"result": [
[database, "{public}"],
],
},
{
"query": "select current_database(), current_schema(), current_user",
"schema": [
{
"name": "current_database",
"data_type": "text",
},
{
"name": "current_schema",
"data_type": "text",
},
{
"name": "current_user",
"data_type": "text",
}
],
"result": [
["sharded_db", "public", "sharding_user"],
],
},
{
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
"schema": [
{
"name": "id",
"data_type": "oid",
},
{
"name": "name",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
{
"name": "is_template",
"data_type": "bool",
},
{
"name": "allow_connections",
"data_type": "bool",
},
{
"name": "owner",
"data_type": "text",
}
],
"result": [
["16387", database, "", "f", "t", user],
]
},
{
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
"schema": [
{
"name": "role_id",
"data_type": "oid",
},
{
"name": "role_name",
"data_type": "text",
},
{
"name": "is_super",
"data_type": "bool",
},
{
"name": "is_inherit",
"data_type": "bool",
},
{
"name": "can_createrole",
"data_type": "bool",
},
{
"name": "can_createdb",
"data_type": "bool",
},
{
"name": "can_login",
"data_type": "bool",
},
{
"name": "is_replication",
"data_type": "bool",
},
{
"name": "conn_limit",
"data_type": "int4",
},
{
"name": "valid_until",
"data_type": "text",
},
{
"name": "bypass_rls",
"data_type": "bool",
},
{
"name": "config",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
],
"result": [
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
]
}
])
}

43
src/plugins/mod.rs Normal file
View File

@@ -0,0 +1,43 @@
//! The plugin ecosystem.
//!
//! Currently plugins only grant access or deny access to the database for a particual query.
//! Example use cases:
//! - block known bad queries
//! - block access to system catalogs
//! - block dangerous modifications like `DROP TABLE`
//! - etc
//!
pub mod intercept;
pub mod query_logger;
pub mod table_access;
use crate::{errors::Error, query_router::QueryRouter};
use async_trait::async_trait;
use bytes::BytesMut;
use sqlparser::ast::Statement;
pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;
#[derive(Clone, Debug, PartialEq)]
pub enum PluginOutput {
Allow,
Deny(String),
Overwrite(Vec<Statement>),
Intercept(BytesMut),
}
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;
// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}

View File

@@ -0,0 +1,49 @@
//! Log all queries to stdout (or somewhere else, why not).
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::Arc;
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
pub struct QueryLogger;
pub fn setup() {
ENABLED.store(Arc::new(true));
info!("Logging queries to stdout");
}
pub fn disable() {
ENABLED.store(Arc::new(false));
}
pub fn enabled() -> bool {
**ENABLED.load()
}
#[async_trait]
impl Plugin for QueryLogger {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("{}", query);
Ok(PluginOutput::Allow)
}
}

View File

@@ -0,0 +1,73 @@
//! This query router plugin will check if the user can access a particular
//! table as part of their query. If they can't, the query will not be routed.
use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement};
use crate::{
config::TableAccess as TableAccessConfig,
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use log::{debug, info};
use arc_swap::ArcSwap;
use core::ops::ControlFlow;
use once_cell::sync::Lazy;
use std::sync::Arc;
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
pub fn setup(config: &TableAccessConfig) {
CONFIG.store(Arc::new(config.tables.clone()));
info!("Blocking access to {} tables", config.tables.len());
}
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn disable() {
CONFIG.store(Arc::new(vec![]));
}
pub struct TableAccess;
#[async_trait]
impl Plugin for TableAccess {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let mut found = None;
let forbidden_tables = CONFIG.load();
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if forbidden_tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else {
ControlFlow::<()>::Continue(())
}
});
if let Some(found) = found {
debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied",
found
)))
} else {
Ok(PluginOutput::Allow)
}
}
}

View File

@@ -61,6 +61,8 @@ pub struct PoolIdentifier {
pub user: String, pub user: String,
} }
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
impl PoolIdentifier { impl PoolIdentifier {
/// Create a new user/pool identifier. /// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier { pub fn new(db: &str, user: &str) -> PoolIdentifier {
@@ -91,6 +93,7 @@ pub struct PoolSettings {
// Connecting user. // Connecting user.
pub user: User, pub user: User,
pub db: String,
// Default server role to connect to. // Default server role to connect to.
pub default_role: Option<Role>, pub default_role: Option<Role>,
@@ -138,6 +141,7 @@ impl Default for PoolSettings {
load_balancing_mode: LoadBalancingMode::Random, load_balancing_mode: LoadBalancingMode::Random,
shards: 1, shards: 1,
user: User::default(), user: User::default(),
db: String::default(),
default_role: None, default_role: None,
query_parser_enabled: false, query_parser_enabled: false,
primary_reads_enabled: true, primary_reads_enabled: true,
@@ -311,21 +315,34 @@ impl ConnectionPool {
if let Some(apt) = &auth_passthrough { if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await { match apt.fetch_hash(&address).await {
Ok(ok) => { Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
if ok != *pool_auth_hash_value { {
warn!("Hash is not the same across shards of the same pool, client auth will \ if ok != *pool_auth_hash_value {
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); warn!(
} "Hash is not the same across shards \
} of the same pool, client auth will \
debug!("Hash obtained for {:?}", address); be done using last obtained hash. \
{ Server: {}:{}, Database: {}",
let mut pool_auth_hash = pool_auth_hash.write(); server.host, server.port, shard.database,
*pool_auth_hash = Some(ok.clone()); );
} }
}, }
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
} debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
}
Err(err) => warn!(
"Could not obtain password hashes \
using auth_query config, ignoring. \
Error: {:?}",
err,
),
}
} }
let manager = ServerPool::new( let manager = ServerPool::new(
@@ -347,14 +364,31 @@ impl ConnectionPool {
None => config.general.idle_timeout, None => config.general.idle_timeout,
}; };
let server_lifetime = match user.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => match pool_config.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => config.general.server_lifetime,
},
};
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
debug!("Pool reaper rate: {}ms", reaper_rate);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(user.pool_size) .max_size(user.pool_size)
.min_idle(user.min_pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout)) .connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
.test_on_check_out(false) .test_on_check_out(false)
.build(manager) .build(manager)
.await .await?;
.unwrap();
pools.push(pool); pools.push(pool);
servers.push(address); servers.push(address);
@@ -382,11 +416,15 @@ impl ConnectionPool {
server_info: Arc::new(RwLock::new(BytesMut::new())), server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash, auth_hash: pool_auth_hash,
settings: PoolSettings { settings: PoolSettings {
pool_mode: pool_config.pool_mode, pool_mode: match user.pool_mode {
Some(pool_mode) => pool_mode,
None => pool_config.pool_mode,
},
load_balancing_mode: pool_config.load_balancing_mode, load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(), // shards: pool_config.shards.clone(),
shards: shard_ids.len(), shards: shard_ids.len(),
user: user.clone(), user: user.clone(),
db: pool_name.clone(),
default_role: match pool_config.default_role.as_str() { default_role: match pool_config.default_role.as_str() {
"any" => None, "any" => None,
"replica" => Some(Role::Replica), "replica" => Some(Role::Replica),
@@ -431,6 +469,32 @@ impl ConnectionPool {
} }
} }
if let Some(ref plugins) = config.plugins {
if let Some(ref intercept) = plugins.intercept {
if intercept.enabled {
crate::plugins::intercept::setup(intercept, &new_pools);
} else {
crate::plugins::intercept::disable();
}
}
if let Some(ref table_access) = plugins.table_access {
if table_access.enabled {
crate::plugins::table_access::setup(table_access);
} else {
crate::plugins::table_access::disable();
}
}
if let Some(ref query_logger) = plugins.query_logger {
if query_logger.enabled {
crate::plugins::query_logger::setup();
} else {
crate::plugins::query_logger::disable();
}
}
}
POOLS.store(Arc::new(new_pools.clone())); POOLS.store(Arc::new(new_pools.clone()));
Ok(()) Ok(())
} }

View File

@@ -1,4 +1,4 @@
/// Route queries automatically based on explicitely requested /// Route queries automatically based on explicitly requested
/// or implied query characteristics. /// or implied query characteristics.
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use log::{debug, error}; use log::{debug, error};
@@ -6,13 +6,19 @@ use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet}; use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction}; use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::{ use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
}; };
use sqlparser::dialect::PostgreSqlDialect; use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser; use sqlparser::parser::Parser;
use crate::config::Role; use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader; use crate::messages::BytesMutReader;
use crate::plugins::{
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
TableAccess,
};
use crate::pool::PoolSettings; use crate::pool::PoolSettings;
use crate::sharding::Sharder; use crate::sharding::Sharder;
@@ -129,6 +135,10 @@ impl QueryRouter {
self.pool_settings = pool_settings; self.pool_settings = pool_settings;
} }
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings
}
/// Try to parse a command and execute it. /// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> { pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer); let mut message_cursor = Cursor::new(message_buffer);
@@ -324,10 +334,7 @@ impl QueryRouter {
Some((command, value)) Some((command, value))
} }
/// Try to infer which server to connect to based on the contents of the query. pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
pub fn infer(&mut self, message: &BytesMut) -> bool {
debug!("Inferring role");
let mut message_cursor = Cursor::new(message); let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char; let code = message_cursor.get_u8() as char;
@@ -353,28 +360,29 @@ impl QueryRouter {
query query
} }
_ => return false, _ => return Err(Error::UnsupportedStatement),
}; };
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) { match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast, Ok(ast) => Ok(ast),
Err(err) => { Err(err) => {
// SELECT ... FOR UPDATE won't get parsed correctly.
debug!("{}: {}", err, query); debug!("{}: {}", err, query);
self.active_role = Some(Role::Primary); Err(Error::QueryRouterParserError(err.to_string()))
return false;
} }
}; }
}
debug!("AST: {:?}", ast); /// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
debug!("Inferring role");
if ast.is_empty() { if ast.is_empty() {
// That's weird, no idea, let's go to primary // That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary); self.active_role = Some(Role::Primary);
return false; return Err(Error::QueryRouterParserError("empty query".into()));
} }
for q in &ast { for q in ast {
match q { match q {
// All transactions go to the primary, probably a write. // All transactions go to the primary, probably a write.
StartTransaction { .. } => { StartTransaction { .. } => {
@@ -418,7 +426,7 @@ impl QueryRouter {
}; };
} }
true Ok(())
} }
/// Parse the shard number from the Bind message /// Parse the shard number from the Bind message
@@ -783,6 +791,34 @@ impl QueryRouter {
} }
} }
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
if query_logger::enabled() {
let mut query_logger = QueryLogger {};
let _ = query_logger.run(&self, ast).await;
}
if intercept::enabled() {
let mut intercept = Intercept {};
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
}
}
if table_access::enabled() {
let mut table_access = TableAccess {};
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
}
}
Ok(PluginOutput::Allow)
}
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> { fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
let sharder = Sharder::new( let sharder = Sharder::new(
self.pool_settings.shards, self.pool_settings.shards,
@@ -810,11 +846,22 @@ impl QueryRouter {
/// Should we attempt to parse queries? /// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool { pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled { let enabled = match self.query_parser_enabled {
None => self.pool_settings.query_parser_enabled, None => {
Some(value) => value, debug!(
}; "Using pool settings, query_parser_enabled: {}",
self.pool_settings.query_parser_enabled
);
self.pool_settings.query_parser_enabled
}
debug!("Query parser enabled: {}", enabled); Some(value) => {
debug!(
"Using query parser override, query_parser_enabled: {}",
value
);
value
}
};
enabled enabled
} }
@@ -862,7 +909,7 @@ mod test {
for query in queries { for query in queries {
// It's a recognized query // It's a recognized query
assert!(qr.infer(&query)); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica)); assert_eq!(qr.role(), Some(Role::Replica));
} }
} }
@@ -881,7 +928,7 @@ mod test {
for query in queries { for query in queries {
// It's a recognized query // It's a recognized query
assert!(qr.infer(&query)); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary)); assert_eq!(qr.role(), Some(Role::Primary));
} }
} }
@@ -893,7 +940,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&query)); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
} }
@@ -913,7 +960,7 @@ mod test {
res.put(prepared_stmt); res.put(prepared_stmt);
res.put_i16(0); res.put_i16(0);
assert!(qr.infer(&res)); assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica)); assert_eq!(qr.role(), Some(Role::Replica));
} }
@@ -1077,11 +1124,11 @@ mod test {
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)"); let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&query)); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary)); assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table"); let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&query)); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica)); assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
@@ -1113,6 +1160,7 @@ mod test {
auth_query: None, auth_query: None,
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(),
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None); assert_eq!(qr.active_role, None);
@@ -1142,15 +1190,24 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;"))); assert!(qr
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary); assert_eq!(qr.role(), Role::Primary);
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;"))); assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica); assert_eq!(qr.role(), Role::Replica);
assert!(qr.infer(&simple_query( assert!(qr
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" .infer(
))); &QueryRouter::parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.role(), Role::Primary); assert_eq!(qr.role(), Role::Primary);
} }
@@ -1177,6 +1234,7 @@ mod test {
auth_query: None, auth_query: None,
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(),
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone()); qr.update_pool_settings(pool_settings.clone());
@@ -1208,47 +1266,84 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3; qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5"))); assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.is_ok());
assert_eq!(qr.shard(), 2); assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query( assert!(qr
"SELECT one, two, three FROM public.data WHERE id = 6" .infer(
))); &QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0); assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query( assert!(qr
"SELECT * FROM data .infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5 INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id AND t2.data_id = data.id
WHERE data.id = 5" WHERE data.id = 5"
))); ))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2); assert_eq!(qr.shard(), 2);
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
// in the query. // in the query.
assert!(qr.infer(&simple_query( assert!(qr
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" .infer(
))); &QueryRouter::parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2); assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query( assert!(qr
r#"SELECT * FROM "public"."data" WHERE "id" = 6"# .infer(
))); &QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0); assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query( assert!(qr
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# .infer(
))); &QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2); assert_eq!(qr.shard(), 2);
// Super unique sharding key // Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr.infer(&simple_query( assert!(qr
"SELECT * FROM table_x WHERE unique_enough_column_name = 6" .infer(
))); &QueryRouter::parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0); assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))); assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0); assert_eq!(qr.shard(), 0);
} }
@@ -1272,11 +1367,40 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3; qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query(stmt))); assert!(qr
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
.is_ok());
assert_eq!(qr.placeholders.len(), 1); assert_eq!(qr.placeholders.len(), 1);
assert!(qr.infer_shard_from_bind(&bind)); assert!(qr.infer_shard_from_bind(&bind));
assert_eq!(qr.shard(), 2); assert_eq!(qr.shard(), 2);
assert!(qr.placeholders.is_empty()); assert!(qr.placeholders.is_empty());
} }
#[tokio::test]
async fn test_table_access_plugin() {
use crate::config::TableAccess;
let ta = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
crate::plugins::table_access::setup(&ta);
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(
res,
Ok(PluginOutput::Deny(
"permission for table \"pg_database\" denied".to_string()
))
);
}
} }

View File

@@ -7,22 +7,101 @@ use parking_lot::{Mutex, RwLock};
use postgres_protocol::message; use postgres_protocol::message;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Read; use std::io::Read;
use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::{ use tokio::net::TcpStream;
tcp::{OwnedReadHalf, OwnedWriteHalf}, use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
TcpStream, use tokio_rustls::{client::TlsStream, TlsConnector};
};
use crate::config::{Address, User}; use crate::config::{get_config, Address, User};
use crate::constants::*; use crate::constants::*;
use crate::errors::Error; use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
use crate::messages::*; use crate::messages::*;
use crate::mirrors::MirroringManager; use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap; use crate::pool::ClientServerMap;
use crate::scram::ScramSha256; use crate::scram::ScramSha256;
use crate::stats::ServerStats; use crate::stats::ServerStats;
use std::io::Write;
use pin_project::pin_project;
#[pin_project(project = SteamInnerProj)]
pub enum StreamInner {
Plain {
#[pin]
stream: TcpStream,
},
Tls {
#[pin]
stream: TlsStream<TcpStream>,
},
}
impl AsyncWrite for StreamInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for StreamInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
impl StreamInner {
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamInner::Tls { stream } => {
let r = stream.get_mut();
let mut w = r.1.writer();
w.write(buf)
}
StreamInner::Plain { stream } => stream.try_write(buf),
}
}
}
/// Server state. /// Server state.
pub struct Server { pub struct Server {
@@ -30,11 +109,8 @@ pub struct Server {
/// port, e.g. 5432, and role, e.g. primary or replica. /// port, e.g. 5432, and role, e.g. primary or replica.
address: Address, address: Address,
/// Buffered read socket. /// Server TCP connection.
read: BufReader<OwnedReadHalf>, stream: BufStream<StreamInner>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
/// Our server response buffer. We buffer data before we give it to the client. /// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut, buffer: BytesMut,
@@ -74,6 +150,9 @@ pub struct Server {
last_activity: SystemTime, last_activity: SystemTime,
mirror_manager: Option<MirroringManager>, mirror_manager: Option<MirroringManager>,
// Associated addresses used
addr_set: Option<AddrSet>,
} }
impl Server { impl Server {
@@ -87,6 +166,24 @@ impl Server {
stats: Arc<ServerStats>, stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>, auth_hash: Arc<RwLock<Option<String>>>,
) -> Result<Server, Error> { ) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
// If we are caching addresses and hostname is not an IP
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
addr_set = match cached_resolver.lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
};
let mut stream = let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream, Ok(stream) => stream,
@@ -98,33 +195,137 @@ impl Server {
))); )));
} }
}; };
// TCP timeouts.
configure_socket(&stream); configure_socket(&stream);
let config = get_config();
let mut stream = if config.general.server_tls {
// Request a TLS connection
ssl_request(&mut stream).await?;
let response = match stream.read_u8().await {
Ok(response) => response as char,
Err(err) => {
return Err(Error::SocketError(format!(
"Server socket error: {:?}",
err
)))
}
};
match response {
// Server supports TLS
'S' => {
debug!("Connecting to server using TLS");
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// Equivalent to sslmode=prefer which is fine most places.
// If you want verify-full, change `verify_server_certificate` to true.
if !config.general.verify_server_certificate {
let mut dangerous = tls_config.dangerous();
dangerous.set_certificate_verifier(Arc::new(
crate::tls::NoCertificateVerification {},
));
}
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = match connector
.connect(address.host.as_str().try_into().unwrap(), stream)
.await
{
Ok(stream) => stream,
Err(err) => {
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
}
};
StreamInner::Tls { stream }
}
// Server does not support TLS
'N' => StreamInner::Plain { stream },
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
StreamInner::Plain { stream }
};
// let (read, write) = split(stream);
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
// StartupMessage // StartupMessage
startup(&mut stream, &user.username, database).await?; let username = match user.server_username {
Some(ref server_username) => server_username,
None => &user.username,
};
let password = match user.server_password {
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
};
startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database);
// We'll be handling multiple packets, but they will all be structured the same. // We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete. // We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = None; let mut scram: Option<ScramSha256> = match password {
if let Some(password) = &user.password.clone() { Some(password) => Some(ScramSha256::new(password)),
scram = Some(ScramSha256::new(password)); None => None,
} };
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"message code".into(),
server_identifier,
))
}
}; };
let len = match stream.read_i32().await { let len = match stream.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"message len".into(),
server_identifier,
))
}
}; };
trace!("Message: {}", code); trace!("Message: {}", code);
@@ -135,7 +336,12 @@ impl Server {
// Determine which kind of authentication is required, if any. // Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await { let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code, Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"auth code".into(),
server_identifier,
))
}
}; };
trace!("Auth: {}", auth_code); trace!("Auth: {}", auth_code);
@@ -148,14 +354,18 @@ impl Server {
match stream.read_exact(&mut salt).await { match stream.read_exact(&mut salt).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"salt".into(),
server_identifier,
))
}
}; };
match &user.password { match password {
// Using plaintext password // Using plaintext password
Some(password) => { Some(password) => {
md5_password(&mut stream, &user.username, password, &salt[..]) md5_password(&mut stream, username, password, &salt[..]).await?
.await?
} }
// Using auth passthrough, in this case we should already have a // Using auth passthrough, in this case we should already have a
@@ -171,8 +381,12 @@ impl Server {
&salt[..], &salt[..],
) )
.await?, .await?,
None => None => return Err(
return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database))) Error::ServerAuthError(
"Auth passthrough (auth_query) failed and no user password is set in cleartext".into(),
server_identifier
)
),
} }
} }
} }
@@ -182,21 +396,33 @@ impl Server {
SASL => { SASL => {
if scram.is_none() { if scram.is_none() {
return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database))); return Err(Error::ServerAuthError(
"SASL auth required and no password specified. \
Auth passthrough (auth_query) method is currently \
unsupported for SASL auth"
.into(),
server_identifier,
));
} }
debug!("Starting SASL authentication"); debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize; let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len]; let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await { match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"sasl message".into(),
server_identifier,
))
}
}; };
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 { if sasl_type.contains(SCRAM_SHA_256) {
debug!("Using {}", SCRAM_SHA_256); debug!("Using {}", SCRAM_SHA_256);
// Generate client message. // Generate client message.
@@ -219,7 +445,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32); res.put_i32(sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all_flush(&mut stream, &res).await?;
} else { } else {
error!("Unsupported SCRAM version: {}", sasl_type); error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError); return Err(Error::ServerError);
@@ -233,7 +459,12 @@ impl Server {
match stream.read_exact(&mut sasl_data).await { match stream.read_exact(&mut sasl_data).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"sasl cont message".into(),
server_identifier,
))
}
}; };
let msg = BytesMut::from(&sasl_data[..]); let msg = BytesMut::from(&sasl_data[..]);
@@ -245,7 +476,7 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32); res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all_flush(&mut stream, &res).await?;
} }
SASL_FINAL => { SASL_FINAL => {
@@ -254,7 +485,12 @@ impl Server {
let mut sasl_final = vec![0u8; len as usize - 8]; let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await { match stream.read_exact(&mut sasl_final).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"sasl final message".into(),
server_identifier,
))
}
}; };
match scram match scram
@@ -284,7 +520,12 @@ impl Server {
'E' => { 'E' => {
let error_code = match stream.read_u8().await { let error_code = match stream.read_u8().await {
Ok(error_code) => error_code, Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"error code message".into(),
server_identifier,
))
}
}; };
trace!("Error: {}", error_code); trace!("Error: {}", error_code);
@@ -300,7 +541,12 @@ impl Server {
match stream.read_exact(&mut error).await { match stream.read_exact(&mut error).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"error message".into(),
server_identifier,
))
}
}; };
// TODO: the error message contains multiple fields; we can decode them and // TODO: the error message contains multiple fields; we can decode them and
@@ -319,7 +565,12 @@ impl Server {
match stream.read_exact(&mut param).await { match stream.read_exact(&mut param).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"parameter status message".into(),
server_identifier,
))
}
}; };
// Save the parameter so we can pass it to the client later. // Save the parameter so we can pass it to the client later.
@@ -336,12 +587,22 @@ impl Server {
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>. // See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await { process_id = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"process id message".into(),
server_identifier,
))
}
}; };
secret_key = match stream.read_i32().await { secret_key = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"secret key message".into(),
server_identifier,
))
}
}; };
} }
@@ -351,15 +612,17 @@ impl Server {
match stream.read_exact(&mut idle).await { match stream.read_exact(&mut idle).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), Err(_) => {
return Err(Error::ServerStartupError(
"transaction status message".into(),
server_identifier,
))
}
}; };
let (read, write) = stream.into_split();
let mut server = Server { let mut server = Server {
address: address.clone(), address: address.clone(),
read: BufReader::new(read), stream: BufStream::new(stream),
write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info, server_info,
process_id, process_id,
@@ -369,6 +632,7 @@ impl Server {
bad: false, bad: false,
needs_cleanup: false, needs_cleanup: false,
client_server_map, client_server_map,
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(), connected_at: chrono::offset::Utc::now().naive_utc(),
stats, stats,
application_name: String::new(), application_name: String::new(),
@@ -413,7 +677,7 @@ impl Server {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => { Err(err) => {
error!("Could not connect to server: {}", err); error!("Could not connect to server: {}", err);
return Err(Error::SocketError(format!("Error reading cancel message"))); return Err(Error::SocketError("Error reading cancel message".into()));
} }
}; };
configure_socket(&stream); configure_socket(&stream);
@@ -426,7 +690,7 @@ impl Server {
bytes.put_i32(process_id); bytes.put_i32(process_id);
bytes.put_i32(secret_key); bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await write_all_flush(&mut stream, &bytes).await
} }
/// Send messages to the server from the client. /// Send messages to the server from the client.
@@ -434,7 +698,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await { match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -453,7 +717,7 @@ impl Server {
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> { pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop { loop {
let mut message = match read_message(&mut self.read).await { let mut message = match read_message(&mut self.stream).await {
Ok(message) => message, Ok(message) => message,
Err(err) => { Err(err) => {
error!("Terminating server because of: {:?}", err); error!("Terminating server because of: {:?}", err);
@@ -609,7 +873,23 @@ impl Server {
/// Server & client are out of sync, we must discard this connection. /// Server & client are out of sync, we must discard this connection.
/// This happens with clients that misbehave. /// This happens with clients that misbehave.
pub fn is_bad(&self) -> bool { pub fn is_bad(&self) -> bool {
self.bad if self.bad {
return self.bad;
};
let cached_resolver = CACHED_RESOLVER.load();
if cached_resolver.enabled() {
if let Some(addr_set) = &self.addr_set {
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
warn!(
"DNS changed for {}, it was {:?}. Dropping server connection.",
self.address.host.as_str(),
addr_set
);
return true;
}
}
}
false
} }
/// Get server startup information to forward it to the client. /// Get server startup information to forward it to the client.
@@ -846,13 +1126,13 @@ impl Drop for Server {
// Update statistics // Update statistics
self.stats.disconnect(); self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4); let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'X'); bytes.put_u8(b'X');
bytes.put_i32(4); bytes.put_i32(4);
match self.write.try_write(&bytes) { match self.stream.get_mut().try_write(&bytes) {
Ok(_) => (), Ok(5) => (),
Err(_) => debug!("Dirty shutdown"), _ => debug!("Dirty shutdown"),
}; };
// Should not matter. // Should not matter.

View File

@@ -66,7 +66,7 @@ impl Reporter {
CLIENT_STATS.write().insert(client_id, stats); CLIENT_STATS.write().insert(client_id, stats);
} }
/// Reports a client is disconecting from the pooler. /// Reports a client is disconnecting from the pooler.
fn client_disconnecting(&self, client_id: i32) { fn client_disconnecting(&self, client_id: i32) {
CLIENT_STATS.write().remove(&client_id); CLIENT_STATS.write().remove(&client_id);
} }
@@ -76,7 +76,7 @@ impl Reporter {
fn server_register(&self, server_id: i32, stats: Arc<ServerStats>) { fn server_register(&self, server_id: i32, stats: Arc<ServerStats>) {
SERVER_STATS.write().insert(server_id, stats); SERVER_STATS.write().insert(server_id, stats);
} }
/// Reports a server connection is disconecting from the pooler. /// Reports a server connection is disconnecting from the pooler.
fn server_disconnecting(&self, server_id: i32) { fn server_disconnecting(&self, server_id: i32) {
SERVER_STATS.write().remove(&server_id); SERVER_STATS.write().remove(&server_id);
} }

View File

@@ -92,7 +92,7 @@ impl ClientStats {
} }
} }
/// Reports a client is disconecting from the pooler and /// Reports a client is disconnecting from the pooler and
/// update metrics on the corresponding pool. /// update metrics on the corresponding pool.
pub fn disconnect(&self) { pub fn disconnect(&self) {
self.reporter.client_disconnecting(self.client_id); self.reporter.client_disconnecting(self.client_id);
@@ -140,7 +140,7 @@ impl ClientStats {
self.error_count.fetch_add(1, Ordering::Relaxed); self.error_count.fetch_add(1, Ordering::Relaxed);
} }
/// Reportes the time spent by a client waiting to get a healthy connection from the pool /// Reporters the time spent by a client waiting to get a healthy connection from the pool
pub fn checkout_time(&self, microseconds: u64) { pub fn checkout_time(&self, microseconds: u64) {
self.total_wait_time self.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed); .fetch_add(microseconds, Ordering::Relaxed);

View File

@@ -100,10 +100,9 @@ impl ServerStats {
.server_idle(self.state.load(Ordering::Relaxed)); .server_idle(self.state.load(Ordering::Relaxed));
self.state.store(ServerState::Idle, Ordering::Relaxed); self.state.store(ServerState::Idle, Ordering::Relaxed);
self.set_undefined_application();
} }
/// Reports a server connection is disconecting from the pooler. /// Reports a server connection is disconnecting from the pooler.
/// Also updates metrics on the pool regarding server usage. /// Also updates metrics on the pool regarding server usage.
pub fn disconnect(&self) { pub fn disconnect(&self) {
self.reporter.server_disconnecting(self.server_id); self.reporter.server_disconnecting(self.server_id);

View File

@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter; use std::iter;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::config::get_config; use crate::config::get_config;
@@ -64,3 +69,19 @@ impl Tls {
}) })
} }
} }
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -37,9 +37,9 @@ describe "Admin" do
describe "SHOW POOLS" do describe "SHOW POOLS" do
context "bad credentials" do context "bad credentials" do
it "does not change any stats" do it "does not change any stats" do
bad_passsword_url = URI(pgcat_conn_str) bad_password_url = URI(pgcat_conn_str)
bad_passsword_url.password = "wrong" bad_password_url.password = "wrong"
expect { PG::connect("#{bad_passsword_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad) expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
sleep(1) sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string) admin_conn = PG::connect(processes.pgcat.admin_connection_string)
@@ -71,15 +71,17 @@ describe "Admin" do
context "client connects but issues no queries" do context "client connects but issues no queries" do
it "only affects cl_idle stats" do it "only affects cl_idle stats" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
connections = Array.new(20) { PG::connect(pgcat_conn_str) } connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1) sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0] results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s| %w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0" raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end end
expect(results["cl_idle"]).to eq("20") expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq("1") expect(results["sv_idle"]).to eq(before_test)
connections.map(&:close) connections.map(&:close)
sleep(1.1) sleep(1.1)
@@ -87,7 +89,7 @@ describe "Admin" do
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s| %w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0" raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end end
expect(results["sv_idle"]).to eq("1") expect(results["sv_idle"]).to eq(before_test)
end end
end end

View File

@@ -0,0 +1,259 @@
require 'socket'
require 'digest/md5'
BACKEND_MESSAGE_CODES = {
'Z' => "ReadyForQuery",
'C' => "CommandComplete",
'T' => "RowDescription",
'D' => "DataRow",
'1' => "ParseComplete",
'2' => "BindComplete",
'E' => "ErrorResponse",
's' => "PortalSuspended",
}
class PostgresSocket
def initialize(host, port)
@port = port
@host = host
@socket = TCPSocket.new @host, @port
@parameters = {}
@verbose = true
end
def send_md5_password_message(username, password, salt)
m = Digest::MD5.hexdigest(password + username)
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
m = 'md5' + m
bytes = (m.split("").map(&:ord) + [0]).flatten
message_size = bytes.count + 4
message = []
message << 'p'.ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << bytes
message.flatten!
@socket.write(message.pack('C*'))
end
def send_startup_message(username, database, password)
message = []
message << [196608].pack('l>').unpack('CCCC') # 4
message << "user".split('').map(&:ord) # 4, 8
message << 0 # 1, 9
message << username.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << "database".split('').map(&:ord) # 8, 20
message << 0 # 1, 21
message << database.split('').map(&:ord) # 2, 23
message << 0 # 1, 24
message << 0 # 1, 25
message.flatten!
total_message_size = message.size + 4
message_len = [total_message_size].pack('l>').unpack('CCCC')
@socket.write([message_len + message].flatten.pack('C*'))
sleep 0.1
read_startup_response(username, password)
end
def read_startup_response(username, password)
message_code, message_len = @socket.recv(5).unpack("al>")
while message_code == 'R'
auth_code = @socket.recv(4).unpack('l>').pop
case auth_code
when 5 # md5
salt = @socket.recv(4).unpack('CCCC')
send_md5_password_message(username, password, salt)
message_code, message_len = @socket.recv(5).unpack("al>")
when 0 # trust
break
end
end
loop do
message_code, message_len = @socket.recv(5).unpack("al>")
if message_code == 'Z'
@socket.recv(1).unpack("a") # most likely I
break # We are good to go
end
if message_code == 'S'
actual_message = @socket.recv(message_len - 4).unpack("C*")
k,v = actual_message.pack('U*').split(/\x00/)
@parameters[k] = v
end
if message_code == 'K'
process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>")
@parameters["process_id"] = process_id
@parameters["secret_key"] = secret_key
end
end
return @parameters
end
def cancel_query
socket = TCPSocket.new @host, @port
process_key = @parameters["process_id"]
secret_key = @parameters["secret_key"]
message = []
message << [16].pack('l>').unpack('CCCC') # 4
message << [80877102].pack('l>').unpack('CCCC') # 4
message << [process_key.to_i].pack('l>').unpack('CCCC') # 4
message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4
message.flatten!
socket.write(message.flatten.pack('C*'))
socket.close
log "[F] Sent CancelRequest message"
end
def send_query_message(query)
query_size = query.length
message_size = 1 + 4 + query_size
message = []
message << "Q".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent Q message (#{query})"
end
def send_parse_message(query)
query_size = query.length
message_size = 2 + 2 + 4 + query_size
message = []
message << "P".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << [0, 0]
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent P message (#{query})"
end
def send_bind_message
message = []
message << "B".ord
message << [12].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << 0 # unnamed statement
message << [0, 0] # 2
message << [0, 0] # 2
message << [0, 0] # 2
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent B message"
end
def send_describe_message(mode)
message = []
message << "D".ord
message << [6].pack('l>').unpack('CCCC') # 4
message << mode.ord
message << 0 # unnamed statement
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent D message"
end
def send_execute_message(limit=0)
message = []
message << "E".ord
message << [9].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << [limit].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent E message"
end
def send_sync_message
message = []
message << "S".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent S message"
end
def send_copydone_message
message = []
message << "c".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent c message"
end
def send_copyfail_message
message = []
message << "f".ord
message << [5].pack('l>').unpack('CCCC') # 4
message << 0
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent f message"
end
def send_flush_message
message = []
message << "H".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent H message"
end
def read_from_server()
output_messages = []
retry_count = 0
message_code = nil
message_len = 0
loop do
begin
message_code, message_len = @socket.recv_nonblock(5).unpack("al>")
rescue IO::WaitReadable
return output_messages if retry_count > 50
retry_count += 1
sleep(0.01)
next
end
message = {
code: message_code,
len: message_len,
bytes: []
}
log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}"
actual_message_length = message_len - 4
if actual_message_length > 0
message[:bytes] = @socket.recv(message_len - 4).unpack("C*")
log "\t#{message[:bytes].join(",")}"
log "\t#{message[:bytes].map(&:chr).join(" ")}"
end
output_messages << message
return output_messages if message_code == 'Z'
end
end
def log(msg)
return unless @verbose
puts msg
end
def close
@socket.close
end
end

View File

@@ -2,6 +2,7 @@ require 'json'
require 'ostruct' require 'ostruct'
require_relative 'pgcat_process' require_relative 'pgcat_process'
require_relative 'pg_instance' require_relative 'pg_instance'
require_relative 'pg_socket'
class ::Hash class ::Hash
def deep_merge(second) def deep_merge(second)

View File

@@ -65,7 +65,7 @@ describe "Least Outstanding Queries Load Balancing" do
processes.pgcat.shutdown processes.pgcat.shutdown
end end
context "under homogenous load" do context "under homogeneous load" do
it "balances query volume between all instances" do it "balances query volume between all instances" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))

View File

@@ -25,7 +25,7 @@ describe "Query Mirroing" do
processes.pgcat.shutdown processes.pgcat.shutdown
end end
it "can mirror a query" do xit "can mirror a query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
runs = 15 runs = 15
runs.times { conn.async_exec("SELECT 1 + 2") } runs.times { conn.async_exec("SELECT 1 + 2") }

View File

@@ -0,0 +1,14 @@
require_relative 'spec_helper'
describe "Plugins" do
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
context "intercept" do
it "will intercept an intellij query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
res = conn.exec("select current_database() as a, current_schemas(false) as b")
expect(res.values).to eq([["sharded_db", "{public}"]])
end
end
end

155
tests/ruby/protocol_spec.rb Normal file
View File

@@ -0,0 +1,155 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Portocol handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "session") }
let(:sequence) { [] }
let(:pgcat_socket) { PostgresSocket.new('localhost', processes.pgcat.port) }
let(:pgdb_socket) { PostgresSocket.new('localhost', processes.all_databases.first.port) }
after do
pgdb_socket.close
pgcat_socket.close
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
def run_comparison(sequence, socket_a, socket_b)
sequence.each do |msg, *args|
socket_a.send(msg, *args)
socket_b.send(msg, *args)
compare_messages(
socket_a.read_from_server,
socket_b.read_from_server
)
end
end
def compare_messages(msg_arr0, msg_arr1)
if msg_arr0.count != msg_arr1.count
error_output = []
error_output << "#{msg_arr0.count} : #{msg_arr1.count}"
error_output << "PgCat Messages"
error_output += msg_arr0.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_output << "PgServer Messages"
error_output += msg_arr1.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
error_desc = error_output.join("\n")
raise StandardError, "Message count mismatch #{error_desc}"
end
(0..msg_arr0.count - 1).all? do |i|
msg0 = msg_arr0[i]
msg1 = msg_arr1[i]
result = [
msg0[:code] == msg1[:code],
msg0[:len] == msg1[:len],
msg0[:bytes] == msg1[:bytes],
].all?
next result if result
if result == false
error_string = []
if msg0[:code] != msg1[:code]
error_string << "code #{msg0[:code]} != #{msg1[:code]}"
end
if msg0[:len] != msg1[:len]
error_string << "len #{msg0[:len]} != #{msg1[:len]}"
end
if msg0[:bytes] != msg1[:bytes]
error_string << "bytes #{msg0[:bytes]} != #{msg1[:bytes]}"
end
err = error_string.join("\n")
raise StandardError, "Message mismatch #{err}"
end
end
end
RSpec.shared_examples "at parity with database" do
before do
pgcat_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
pgdb_socket.send_startup_message("sharding_user", "shard0", "sharding_user")
end
it "works" do
run_comparison(sequence, pgcat_socket, pgdb_socket)
end
end
context "Cancel Query" do
let(:sequence) {
[
[:send_query_message, "SELECT pg_sleep(5)"],
[:cancel_query]
]
}
it_behaves_like "at parity with database"
end
xcontext "Simple query after parse" do
let(:sequence) {
[
[:send_parse_message, "SELECT 5"],
[:send_query_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Flush message" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_flush_message]
]
}
# Known to fail due to PgCat not supporting flush
it_behaves_like "at parity with database"
end
xcontext "Bind without parse" do
let(:sequence) {
[
[:send_bind_message]
]
}
# This is known to fail.
# Server responds immediately, Proxy buffers the message
it_behaves_like "at parity with database"
end
context "Simple message" do
let(:sequence) {
[[:send_query_message, "SELECT 1"]]
}
it_behaves_like "at parity with database"
end
context "Extended protocol" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
it_behaves_like "at parity with database"
end
end

View File

@@ -27,7 +27,7 @@ describe "Sharding" do
processes.pgcat.shutdown processes.pgcat.shutdown
end end
describe "automatic routing of extended procotol" do describe "automatic routing of extended protocol" do
it "can do it" do it "can do it" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SET SERVER ROLE TO 'auto'") conn.exec("SET SERVER ROLE TO 'auto'")

View File

@@ -0,0 +1 @@
tomli