Compare commits

..

11 Commits

Author SHA1 Message Date
Lev Kokotov
35c20dc9b4 remove debug 2023-05-03 16:46:48 -07:00
Lev Kokotov
72ce2be5c7 lowercase config query 2023-05-03 16:46:07 -07:00
Lev Kokotov
811885f464 Actually plugins (#421)
* more plugins

* clean up

* fix tests

* fix flakey test
2023-05-03 16:13:45 -07:00
dependabot[bot]
d5e329fec5 chore(deps): bump regex from 1.8.0 to 1.8.1 (#413)
Bumps [regex](https://github.com/rust-lang/regex) from 1.8.0 to 1.8.1.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/commits/1.8.1)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-03 10:00:05 -07:00
Lev Kokotov
09e54e1175 Plugins! (#420)
* Some queries

* Plugins!!

* cleanup

* actual names

* the actual plugins

* comment

* fix tests

* Tests

* unused errors

* Increase reaper rate to actually enforce settings

* ok
2023-05-03 09:13:05 -07:00
dependabot[bot]
23819c8549 chore(deps): bump rustls from 0.21.0 to 0.21.1 (#419)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.0 to 0.21.1.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Changelog](https://github.com/rustls/rustls/blob/main/RELEASE_NOTES.md)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.0...v/0.21.1)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 07:32:44 -07:00
Jose Fernández
7dfbd993f2 Add dns_cache for server addresses as in pgbouncer (#249)
* Add dns_cache so server addresses are cached and invalidated when DNS changes.

Adds a module to deal with dns_cache feature. It's
main struct is CachedResolver, which is a simple thread safe
hostname <-> Ips cache with the ability to refresh resolutions
every `dns_max_ttl` seconds. This way, a client can check whether its
ip address has changed.

* Allow reloading dns cached

* Add documentation for dns_cached
2023-05-02 10:26:40 +02:00
Lev Kokotov
3601130ba1 Readme update (#418)
* Readme update

* m

* wording
2023-04-30 09:44:25 -07:00
Lev Kokotov
0d504032b2 Server TLS (#417)
* Server TLS

* Finish up TLS

* thats it

* diff

* remove dead code

* maybe?

* dirty shutdown

* skip flakey test

* remove unused error

* fetch config once
2023-04-30 09:41:46 -07:00
Lev Kokotov
4a87b4807d Add more pool settings (#416)
* Add some pool settings

* fmt
2023-04-26 16:33:26 -07:00
Shawn
cb5ff40a59 fix typo (#415)
chore: typo
2023-04-26 08:28:54 -07:00
28 changed files with 2230 additions and 296 deletions

View File

@@ -110,10 +110,6 @@ python3 tests/python/tests.py || exit 1
start_pgcat "info"
python3 tests/python/async_test.py
start_pgcat "info"
# Admin tests
export PGPASSWORD=admin_pass
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null

View File

@@ -49,6 +49,14 @@ default: 30000 # milliseconds
How long an idle connection with a server is left open (ms).
### server_lifetime
```
path: general.server_lifetime
default: 86400000 # 24 hours
```
Max connection lifetime before it's closed, even if actively used.
### idle_client_in_transaction_timeout
```
path: general.idle_client_in_transaction_timeout
@@ -180,6 +188,22 @@ default: "admin_pass"
Password to access the virtual administrative database
### dns_cache_enabled
```
path: general.dns_cache_enabled
default: false
```
When enabled, ip resolutions for server connections specified using hostnames will be cached
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
old ip connections are closed (gracefully) and new connections will start using new ip.
### dns_max_ttl
```
path: general.dns_max_ttl
default: 30
```
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
## `pools.<pool_name>` Section
### pool_mode

365
Cargo.lock generated
View File

@@ -26,6 +26,27 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "async-stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "async-trait"
version = "0.1.68"
@@ -212,6 +233,12 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "data-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
[[package]]
name = "digest"
version = "0.10.6"
@@ -223,6 +250,18 @@ dependencies = [
"subtle",
]
[[package]]
name = "enum-as-inner"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "env_logger"
version = "0.10.0"
@@ -275,6 +314,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.28"
@@ -410,6 +458,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]]
name = "hermit-abi"
version = "0.2.6"
@@ -434,6 +488,17 @@ dependencies = [
"digest",
]
[[package]]
name = "hostname"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
dependencies = [
"libc",
"match_cfg",
"winapi",
]
[[package]]
name = "http"
version = "0.2.9"
@@ -522,6 +587,27 @@ dependencies = [
"cxx-build",
]
[[package]]
name = "idna"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "idna"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "indexmap"
version = "1.9.2"
@@ -542,6 +628,24 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "ipconfig"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
dependencies = [
"socket2",
"widestring",
"winapi",
"winreg",
]
[[package]]
name = "ipnet"
version = "2.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
[[package]]
name = "is-terminal"
version = "0.4.4"
@@ -589,6 +693,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.139"
@@ -604,6 +714,12 @@ dependencies = [
"cc",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
@@ -629,6 +745,27 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "lru-cache"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
[[package]]
name = "matches"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]]
name = "md-5"
version = "0.10.5"
@@ -737,9 +874,15 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]]
name = "pgcat"
version = "1.0.1"
version = "1.0.2-alpha1"
dependencies = [
"arc-swap",
"async-trait",
@@ -762,12 +905,15 @@ dependencies = [
"once_cell",
"parking_lot",
"phf",
"pin-project",
"postgres-protocol",
"rand",
"regex",
"rustls",
"rustls-pemfile",
"serde",
"serde_derive",
"serde_json",
"sha-1",
"sha2",
"socket2",
@@ -775,7 +921,10 @@ dependencies = [
"stringprep",
"tokio",
"tokio-rustls",
"tokio-test",
"toml",
"trust-dns-resolver",
"webpki-roots",
]
[[package]]
@@ -820,6 +969,26 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "pin-project-lite"
version = "0.2.9"
@@ -865,6 +1034,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
version = "1.0.26"
@@ -915,9 +1090,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.8.0"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6cf59af1067a3fb53fbe5c88c053764e930f932be1d71d3ffe032cbe147f59"
checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370"
dependencies = [
"aho-corasick",
"memchr",
@@ -926,9 +1101,19 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c"
[[package]]
name = "resolv-conf"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5"
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
dependencies = [
"hostname",
"quick-error",
]
[[package]]
name = "ring"
@@ -961,9 +1146,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.21.0"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07180898a28ed6a7f7ba2311594308f595e3dd2e3c3812fa0a80a47b45f17e5d"
checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e"
dependencies = [
"log",
"ring",
@@ -990,6 +1175,12 @@ dependencies = [
"untrusted",
]
[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -1017,6 +1208,9 @@ name = "serde"
version = "1.0.160"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
@@ -1029,6 +1223,17 @@ dependencies = [
"syn 2.0.9",
]
[[package]]
name = "serde_json"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.1"
@@ -1113,6 +1318,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a"
dependencies = [
"log",
"sqlparser_derive",
]
[[package]]
name = "sqlparser_derive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
@@ -1168,6 +1385,26 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "time"
version = "0.1.45"
@@ -1235,6 +1472,30 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
dependencies = [
"async-stream",
"bytes",
"futures-core",
"tokio",
"tokio-stream",
]
[[package]]
name = "tokio-util"
version = "0.7.7"
@@ -1297,9 +1558,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "tracing-core"
version = "0.1.30"
@@ -1309,6 +1582,51 @@ dependencies = [
"once_cell",
]
[[package]]
name = "trust-dns-proto"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26"
dependencies = [
"async-trait",
"cfg-if",
"data-encoding",
"enum-as-inner",
"futures-channel",
"futures-io",
"futures-util",
"idna 0.2.3",
"ipnet",
"lazy_static",
"rand",
"smallvec",
"thiserror",
"tinyvec",
"tokio",
"tracing",
"url",
]
[[package]]
name = "trust-dns-resolver"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe"
dependencies = [
"cfg-if",
"futures-util",
"ipconfig",
"lazy_static",
"lru-cache",
"parking_lot",
"resolv-conf",
"smallvec",
"thiserror",
"tokio",
"tracing",
"trust-dns-proto",
]
[[package]]
name = "try-lock"
version = "0.2.4"
@@ -1354,6 +1672,17 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "url"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna 0.3.0",
"percent-encoding",
]
[[package]]
name = "version_check"
version = "0.9.4"
@@ -1446,6 +1775,21 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
dependencies = [
"rustls-webpki",
]
[[package]]
name = "widestring"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983"
[[package]]
name = "winapi"
version = "0.3.9"
@@ -1551,3 +1895,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
dependencies = [
"memchr",
]
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.0.1"
version = "1.0.2-alpha1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -14,12 +14,12 @@ rand = "0.8"
chrono = "0.4"
sha-1 = "0.10"
toml = "0.7"
serde = "1"
serde = { version = "1", features = ["derive"] }
serde_derive = "1"
regex = "1"
num_cpus = "1"
once_cell = "1"
sqlparser = "0.33.0"
sqlparser = {version = "0.33", features = ["visitor"] }
log = "0.4"
arc-swap = "1"
env_logger = "0.10"
@@ -39,6 +39,13 @@ nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.5"
fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
serde_json = "1"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -18,7 +18,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
| SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|

View File

@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
# How long an idle connection with a server is left open (ms).
idle_timeout = 30000 # milliseconds
# Max connection lifetime before it's closed, even if actively used.
server_lifetime = 86400000 # 24 hours
# How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout = 0 # milliseconds
@@ -58,9 +61,15 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5
# Path to TLS Certificate file to use for TLS connections
# tls_certificate = "server.cert"
# tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key"
# tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = false
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -137,6 +146,53 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool
connect_timeout = 3000
# When enabled, ip resolutions for server connections specified using hostnames will be cached
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
# old ip connections are closed (gracefully) and new connections will start using new ip.
# dns_cache_enabled = false
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
[plugins]
[plugins.query_logger]
enabled = false
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# User configs are structured as pool.<pool_name>.users.<user_index>
# This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
@@ -206,6 +262,8 @@ sharding_function = "pg_bigint_hash"
username = "simple_user"
password = "simple_user"
pool_size = 5
min_pool_size = 3
server_lifetime = 60000
statement_timeout = 0
[pools.simple_db.shards.0]

View File

@@ -12,9 +12,9 @@ use tokio::time::Instant;
use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
use crate::ClientServerMap;
pub fn generate_server_info_for_admin() -> BytesMut {
let mut server_info = BytesMut::new();

View File

@@ -77,6 +77,8 @@ impl AuthPassthrough {
pool_size: 1,
statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
min_pool_size: None,
};
let user = &address.username;

View File

@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
@@ -539,6 +540,7 @@ where
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}
@@ -565,6 +567,8 @@ where
}
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
@@ -587,7 +591,15 @@ where
client_identifier
);
let fetched_hash = refetch_auth_hash(&pool).await?;
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
@@ -754,6 +766,9 @@ where
self.stats.register(self.stats.clone());
// Result returned by one of the plugins.
let mut plugin_output = None;
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
@@ -804,7 +819,25 @@ where
'Q' => {
if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
}
}
@@ -812,7 +845,13 @@ where
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
query_router.infer(&message);
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
let _ = query_router.infer(&ast);
}
}
continue;
@@ -846,6 +885,18 @@ where
continue;
}
// Check on plugin results.
match plugin_output {
Some(PluginOutput::Deny(error)) => {
self.buffer.clear();
error_response(&mut self.write, &error).await?;
plugin_output = None;
continue;
}
_ => (),
};
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
@@ -932,7 +983,7 @@ where
}
// Grab a server from the pool.
let mut connection = match pool
let connection = match pool
.get(query_router.shard(), query_router.role(), &self.stats)
.await
{
@@ -975,8 +1026,9 @@ where
}
};
let server = &mut *connection.0;
let mut reference = connection.0;
let address = connection.1;
let server = &mut *reference;
// Server is assigned to the client in case the client wants to
// cancel a query later.
@@ -999,7 +1051,6 @@ where
// Set application_name.
server.set_name(&self.application_name).await?;
server.switch_async(false);
let mut initial_message = Some(message);
@@ -1019,37 +1070,12 @@ where
None => {
trace!("Waiting for message inside transaction or in session mode");
let message = tokio::select! {
message = tokio::time::timeout(
idle_client_timeout_duration,
read_message(&mut self.read),
) => message,
server_message = server.recv() => {
debug!("Got async message");
let server_message = match server_message {
Ok(message) => message,
Err(err) => {
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
server.mark_bad();
return Err(err);
}
};
match write_all_half(&mut self.write, &server_message).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
continue;
}
};
match message {
match tokio::time::timeout(
idle_client_timeout_duration,
read_message(&mut self.read),
)
.await
{
Ok(Ok(message)) => message,
Ok(Err(err)) => {
// Client disconnected inside a transaction.
@@ -1099,6 +1125,27 @@ where
match code {
// Query
'Q' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
let plugin_result = query_router.execute_plugins(&ast).await;
match plugin_result {
Ok(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
continue;
}
Ok(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
continue;
}
_ => (),
};
let _ = query_router.infer(&ast);
}
}
debug!("Sending query to server");
self.send_and_receive_loop(
@@ -1138,6 +1185,14 @@ where
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
plugin_output = Some(output);
}
}
}
self.buffer.put(&message[..]);
}
@@ -1166,13 +1221,26 @@ where
// Sync
// Frontend (client) is asking for the query result now.
'S' | 'H' => {
'S' => {
debug!("Sending query to server");
if code == 'H' {
server.switch_async(true);
debug!("Client requested flush, going async");
}
match plugin_output {
Some(PluginOutput::Deny(error)) => {
error_response(&mut self.write, &error).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
Some(PluginOutput::Intercept(result)) => {
write_all(&mut self.write, result).await?;
plugin_output = None;
self.buffer.clear();
continue;
}
_ => (),
};
self.buffer.put(&message[..]);

View File

@@ -12,6 +12,7 @@ use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use crate::dns_cache::CachedResolver;
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction;
@@ -181,7 +182,9 @@ pub struct User {
pub server_username: Option<String>,
pub server_password: Option<String>,
pub pool_size: u32,
pub min_pool_size: Option<u32>,
pub pool_mode: Option<PoolMode>,
pub server_lifetime: Option<u64>,
#[serde(default)] // 0
pub statement_timeout: u64,
}
@@ -194,12 +197,34 @@ impl Default for User {
server_username: None,
server_password: None,
pool_size: 15,
min_pool_size: None,
statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
}
}
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size {
error!(
"min_pool_size of {} cannot be larger than pool_size of {}",
min_pool_size, self.pool_size
);
return Err(Error::BadConfig);
}
}
None => (),
};
Ok(())
}
}
/// General configuration.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct General {
@@ -231,6 +256,12 @@ pub struct General {
#[serde(default)] // False
pub log_client_disconnections: bool,
#[serde(default)] // False
pub dns_cache_enabled: bool,
#[serde(default = "General::default_dns_max_ttl")]
pub dns_max_ttl: u64,
#[serde(default = "General::default_shutdown_timeout")]
pub shutdown_timeout: u64,
@@ -246,6 +277,9 @@ pub struct General {
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
pub idle_client_in_transaction_timeout: u64,
#[serde(default = "General::default_server_lifetime")]
pub server_lifetime: u64,
#[serde(default = "General::default_worker_threads")]
pub worker_threads: usize,
@@ -254,9 +288,17 @@ pub struct General {
pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>,
#[serde(default)] // false
pub server_tls: bool,
#[serde(default)] // false
pub verify_server_certificate: bool,
pub admin_username: String,
pub admin_password: String,
// Support for auth query
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
@@ -271,6 +313,10 @@ impl General {
5432
}
pub fn default_server_lifetime() -> u64 {
1000 * 60 * 60 * 24 // 24 hours
}
pub fn default_connect_timeout() -> u64 {
1000
}
@@ -298,6 +344,10 @@ impl General {
60000
}
pub fn default_dns_max_ttl() -> u64 {
30
}
pub fn default_healthcheck_timeout() -> u64 {
1000
}
@@ -340,13 +390,18 @@ impl Default for General {
log_client_connections: false,
log_client_disconnections: false,
autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
tls_certificate: None,
tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours,
}
}
}
@@ -411,6 +466,8 @@ pub struct Pool {
pub idle_timeout: Option<u64>,
pub server_lifetime: Option<u64>,
pub sharding_function: ShardingFunction,
#[serde(default = "Pool::default_automatic_sharding_key")]
@@ -515,6 +572,10 @@ impl Pool {
None => None,
};
for (_, user) in &self.users {
user.validate()?;
}
Ok(())
}
}
@@ -539,6 +600,7 @@ impl Default for Pool {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
}
}
}
@@ -588,7 +650,7 @@ impl Shard {
if primary_count > 1 {
error!(
"Shard {} has more than on primary configured",
"Shard {} has more than one primary configured",
self.database
);
return Err(Error::BadConfig);
@@ -617,6 +679,56 @@ impl Default for Shard {
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
pub enabled: bool,
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
query.query = query.query.to_ascii_lowercase();
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
pub result: Vec<Vec<String>>,
}
impl Query {
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
}
}
}
}
/// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config {
@@ -635,6 +747,7 @@ pub struct Config {
pub path: String,
pub general: General,
pub plugins: Option<Plugins>,
pub pools: HashMap<String, Pool>,
}
@@ -672,6 +785,7 @@ impl Default for Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
}
}
}
@@ -791,6 +905,10 @@ impl Config {
);
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
info!(
"Default max server lifetime: {}ms",
self.general.server_lifetime
);
match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate);
@@ -809,6 +927,11 @@ impl Config {
info!("TLS support is disabled");
}
};
info!("Server TLS enabled: {}", self.general.server_tls);
info!(
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?)
@@ -867,12 +990,26 @@ impl Config {
pool_name,
pool_config.users.len()
);
info!(
"[pool: {}] Max server lifetime: {}",
pool_name,
match pool_config.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
for user in &pool_config.users {
info!(
"[pool: {}][user: {}] Pool size: {}",
pool_name, user.1.username, user.1.pool_size,
);
info!(
"[pool: {}][user: {}] Minimum pool size: {}",
pool_name,
user.1.username,
user.1.min_pool_size.unwrap_or(0)
);
info!(
"[pool: {}][user: {}] Statement timeout: {}",
pool_name, user.1.username, user.1.statement_timeout
@@ -886,6 +1023,15 @@ impl Config {
None => pool_config.pool_mode.to_string(),
}
);
info!(
"[pool: {}][user: {}] Max server lifetime: {}",
pool_name,
user.1.username,
match user.1.server_lifetime {
Some(server_lifetime) => format!("{}ms", server_lifetime),
None => "default".to_string(),
}
);
}
}
}
@@ -896,7 +1042,13 @@ impl Config {
&& (self.general.auth_query_user.is_none()
|| self.general.auth_query_password.is_none())
{
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
error!(
"If auth_query is specified, \
you need to provide a value \
for `auth_query_user`, \
`auth_query_password`"
);
return Err(Error::BadConfig);
}
@@ -904,7 +1056,14 @@ impl Config {
if pool.auth_query.is_some()
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
{
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
error!(
"Error in pool {{ {} }}. \
If auth_query is specified, you need \
to provide a value for `auth_query_user`, \
`auth_query_password`",
name
);
return Err(Error::BadConfig);
}
@@ -914,7 +1073,13 @@ impl Config {
|| pool.auth_query_user.is_none())
&& user_data.password.is_none()
{
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
error!(
"Error in pool {{ {} }}. \
You have to specify a user password \
for every pool if auth_query is not specified",
name
);
return Err(Error::BadConfig);
}
}
@@ -1012,6 +1177,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config();
match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
@@ -1019,14 +1185,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig);
}
};
let new_config = get_config();
if old_config.pools != new_config.pools {
info!("Pool configuration changed");
match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};
if old_config != new_config {
info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?;
Ok(true)
} else if old_config != new_config {
Ok(true)
} else {
Ok(false)
}

410
src/dns_cache.rs Normal file
View File

@@ -0,0 +1,410 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;
/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
// Ip addressed are returned as a set of addresses
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
}
impl AddrSet {
fn new() -> AddrSet {
AddrSet {
set: HashSet::new(),
}
}
}
impl From<LookupIp> for AddrSet {
fn from(lookup_ip: LookupIp) -> Self {
let mut addr_set = AddrSet::new();
for address in lookup_ip.iter() {
addr_set.set.insert(address);
}
addr_set
}
}
///
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
///
/// The system works as follows:
///
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
/// cache is refreshed.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
/// # })
/// ```
///
/// // Now the ip resolution is stored in local cache and subsequent
/// // calls will be returned from cache. Also, the cache is refreshed
/// // and updated every 10 seconds.
///
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,
// This is the hash that contains the hash.
data: Option<RwLock<HashMap<String, AddrSet>>>,
// The resolver to be used for DNS queries.
resolver: Option<TokioAsyncResolver>,
// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
dns_max_ttl: u64,
/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}
impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}
impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}
impl CachedResolver {
///
/// Returns a new Arc<CachedResolver> based on passed configuration.
/// It also starts the loop that will refresh cache entries.
///
/// # Arguments:
///
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// # })
/// ```
///
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}
Ok(instance)
}
pub fn enabled(&self) -> bool {
self.config.enabled
}
// Schedules the refresher
async fn refresh_dns_entries_loop(&self) {
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
let interval = Duration::from_secs(self.config.dns_max_ttl);
loop {
debug!("Begin refreshing cached DNS addresses.");
// To minimize the time we hold the lock, we first create
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}
for hostname in hostnames.iter() {
let addrset = self
.fetch_from_cache(hostname.as_str())
.expect("Could not obtain expected address from cache, this should not happen");
match resolver.lookup_ip(hostname).await {
Ok(lookup_ip) => {
let new_addrset = AddrSet::from(lookup_ip);
debug!(
"Obtained address for host ({}) -> ({:?})",
hostname, new_addrset
);
if addrset != new_addrset {
debug!(
"Addr changed from {:?} to {:?} updating cache.",
addrset, new_addrset
);
self.store_in_cache(hostname, new_addrset);
}
}
Err(err) => {
error!(
"There was an error trying to resolv {}: ({}).",
hostname, err
);
}
}
}
debug!("Finished refreshing cached DNS addresses.");
sleep(interval).await;
}
}
/// Returns a `AddrSet` given the specified hostname.
///
/// This method first tries to fetch the value from the cache, if it misses
/// then it is resolved and stored in the cache. TTL from records is ignored.
///
/// # Arguments
///
/// * `host` - A string slice referencing the hostname to be resolved.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let response = resolver.lookup_ip("www.google.com.");
/// # })
/// ```
///
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
debug!("Lookup up {} in cache", host);
match self.fetch_from_cache(host) {
Some(addr_set) => {
debug!("Cache hit!");
Ok(addr_set)
}
None => {
debug!("Not found, executing a dns query!");
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
//
// Returns true if the stored host resolution differs from the AddrSet passed.
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
return fetched_addr_set != *addr_set;
}
false
}
// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
match new_resolver {
Ok(ok) => {
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
}
} else {
Ok(())
}
}
// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;
#[tokio::test]
async fn new() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await;
assert!(resolver.is_ok());
}
#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}
#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
}
#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
}
#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
}
#[tokio::test]
// Ok, this test is based on the fact that google does DNS RR
// and does not responds with every available ip everytime, so
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
let resolver_for_refresher = resolver.clone();
let _thread_handle = tokio::task::spawn(async move {
resolver_for_refresher.refresh_dns_entries_loop().await;
});
assert!(!resolver.has_changed(hostname, &addr_set));
}
}

View File

@@ -1,7 +1,7 @@
//! Errors.
/// Various errors.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
pub enum Error {
SocketError(String),
ClientSocketError(String, ClientIdentifier),
@@ -19,10 +19,13 @@ pub enum Error {
ClientError(String),
TlsError,
StatementTimeout,
DNSCachedError(String),
ShuttingDown,
ParseBytesError(String),
AuthError(String),
AuthPassthroughError(String),
UnsupportedStatement,
QueryRouterParserError(String),
}
#[derive(Clone, PartialEq, Debug)]

View File

@@ -1,11 +1,17 @@
pub mod admin;
pub mod auth_passthrough;
pub mod client;
pub mod config;
pub mod constants;
pub mod dns_cache;
pub mod errors;
pub mod messages;
pub mod mirrors;
pub mod multi_logger;
pub mod plugins;
pub mod pool;
pub mod prometheus;
pub mod query_router;
pub mod scram;
pub mod server;
pub mod sharding;

View File

@@ -36,6 +36,7 @@ extern crate sqlparser;
extern crate tokio;
extern crate tokio_rustls;
extern crate toml;
extern crate trust_dns_resolver;
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
@@ -60,36 +61,19 @@ use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::broadcast;
mod admin;
mod auth_passthrough;
mod client;
mod config;
mod constants;
mod errors;
mod messages;
mod mirrors;
mod multi_logger;
mod pool;
mod prometheus;
mod query_router;
mod scram;
mod server;
mod sharding;
mod stats;
mod tls;
use crate::config::{get_config, reload_config, VERSION};
use crate::messages::configure_socket;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER};
use pgcat::config::{get_config, reload_config, VERSION};
use pgcat::dns_cache;
use pgcat::messages::configure_socket;
use pgcat::pool::{ClientServerMap, ConnectionPool};
use pgcat::prometheus::start_metric_server;
use pgcat::stats::{Collector, Reporter, REPORTER};
fn main() -> Result<(), Box<dyn std::error::Error>> {
multi_logger::MultiLogger::init().unwrap();
pgcat::multi_logger::MultiLogger::init().unwrap();
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() {
if !pgcat::query_router::QueryRouter::setup() {
error!("Could not setup query router");
std::process::exit(exitcode::CONFIG);
}
@@ -107,7 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
runtime.block_on(async {
match config::parse(&config_file).await {
match pgcat::config::parse(&config_file).await {
Ok(_) => (),
Err(err) => {
error!("Config parse error: {:?}", err);
@@ -166,6 +150,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Statistics reporting.
REPORTER.store(Arc::new(Reporter::default()));
// Starts (if enabled) dns cache before pools initialization
match dns_cache::CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache initialization error: {:?}", err),
};
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (),
@@ -295,7 +285,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
match pgcat::client::client_entrypoint(
socket,
client_server_map,
shutdown_rx,
@@ -326,7 +316,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Err(err) => {
match err {
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
_ => warn!("Client disconnected with error {:?}", err),
}

View File

@@ -20,6 +20,10 @@ pub enum DataType {
Text,
Int4,
Numeric,
Bool,
Oid,
AnyArray,
Any,
}
impl From<&DataType> for i32 {
@@ -28,6 +32,10 @@ impl From<&DataType> for i32 {
DataType::Text => 25,
DataType::Int4 => 23,
DataType::Numeric => 1700,
DataType::Bool => 16,
DataType::Oid => 26,
DataType::AnyArray => 2277,
DataType::Any => 2276,
}
}
}
@@ -116,7 +124,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -150,6 +161,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
@@ -425,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
DataType::Text => -1,
DataType::Int4 => 4,
DataType::Numeric => -1,
DataType::Bool => 1,
DataType::Oid => 4,
DataType::AnyArray => -1,
DataType::Any => -1,
};
row_desc.put_i16(type_size);
@@ -463,6 +493,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
res
}
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
if let Some(column) = column {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
} else {
data_row.put_i32(-1 as i32);
}
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
/// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
@@ -505,6 +558,29 @@ where
}
}
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where

288
src/plugins/intercept.rs Normal file
View File

@@ -0,0 +1,288 @@
//! The intercept plugin.
//!
//! It intercepts queries and returns fake results.
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlparser::ast::Statement;
use std::collections::HashMap;
use log::{debug, info};
use std::sync::Arc;
use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter,
};
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}
CONFIG.store(Arc::new(config));
info!("Intercepting {} queries", intercept_config.queries.len());
}
pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}
// TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)]
pub struct Rule {
query: String,
schema: Vec<Column>,
result: Vec<Vec<String>>,
}
#[derive(Serialize, Deserialize)]
pub struct Column {
name: String,
data_type: String,
}
/// The intercept plugin.
pub struct Intercept;
#[async_trait]
impl Plugin for Intercept {
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if ast.is_empty() {
return Ok(PluginOutput::Allow);
}
let mut result = BytesMut::new();
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
&query_router.pool_settings().db,
&query_router.pool_settings().user.username,
)) {
Some(query_map) => query_map.clone(),
None => return Ok(PluginOutput::Allow),
};
for q in ast {
// Normalization
let q = q.to_string().to_ascii_lowercase();
for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);
let rd = target
.schema
.iter()
.map(|row| {
let name = &row[0];
let data_type = &row[1];
(
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
"bool" => DataType::Bool,
"int4" => DataType::Int4,
_ => DataType::Any,
},
)
})
.collect::<Vec<(&str, DataType)>>();
result.put(row_description(&rd));
target.result.iter().for_each(|row| {
let row = row
.iter()
.map(|s| {
let s = s.as_str().to_string();
if s == "" {
None
} else {
Some(s)
}
})
.collect::<Vec<Option<String>>>();
result.put(data_row_nullable(&row));
});
result.put(command_complete("SELECT"));
}
}
}
if !result.is_empty() {
result.put_u8(b'Z');
result.put_i32(5);
result.put_u8(b'I');
return Ok(PluginOutput::Intercept(result));
} else {
Ok(PluginOutput::Allow)
}
}
}
/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
"query": "select current_database() as a, current_schemas(false) as b",
"schema": [
{
"name": "a",
"data_type": "text",
},
{
"name": "b",
"data_type": "anyarray",
},
],
"result": [
[database, "{public}"],
],
},
{
"query": "select current_database(), current_schema(), current_user",
"schema": [
{
"name": "current_database",
"data_type": "text",
},
{
"name": "current_schema",
"data_type": "text",
},
{
"name": "current_user",
"data_type": "text",
}
],
"result": [
["sharded_db", "public", "sharding_user"],
],
},
{
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
"schema": [
{
"name": "id",
"data_type": "oid",
},
{
"name": "name",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
{
"name": "is_template",
"data_type": "bool",
},
{
"name": "allow_connections",
"data_type": "bool",
},
{
"name": "owner",
"data_type": "text",
}
],
"result": [
["16387", database, "", "f", "t", user],
]
},
{
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
"schema": [
{
"name": "role_id",
"data_type": "oid",
},
{
"name": "role_name",
"data_type": "text",
},
{
"name": "is_super",
"data_type": "bool",
},
{
"name": "is_inherit",
"data_type": "bool",
},
{
"name": "can_createrole",
"data_type": "bool",
},
{
"name": "can_createdb",
"data_type": "bool",
},
{
"name": "can_login",
"data_type": "bool",
},
{
"name": "is_replication",
"data_type": "bool",
},
{
"name": "conn_limit",
"data_type": "int4",
},
{
"name": "valid_until",
"data_type": "text",
},
{
"name": "bypass_rls",
"data_type": "bool",
},
{
"name": "config",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
],
"result": [
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
]
}
])
}

43
src/plugins/mod.rs Normal file
View File

@@ -0,0 +1,43 @@
//! The plugin ecosystem.
//!
//! Currently plugins only grant access or deny access to the database for a particual query.
//! Example use cases:
//! - block known bad queries
//! - block access to system catalogs
//! - block dangerous modifications like `DROP TABLE`
//! - etc
//!
pub mod intercept;
pub mod query_logger;
pub mod table_access;
use crate::{errors::Error, query_router::QueryRouter};
use async_trait::async_trait;
use bytes::BytesMut;
use sqlparser::ast::Statement;
pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;
#[derive(Clone, Debug, PartialEq)]
pub enum PluginOutput {
Allow,
Deny(String),
Overwrite(Vec<Statement>),
Intercept(BytesMut),
}
#[async_trait]
pub trait Plugin {
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;
// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}

View File

@@ -0,0 +1,49 @@
//! Log all queries to stdout (or somewhere else, why not).
use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::Arc;
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
pub struct QueryLogger;
pub fn setup() {
ENABLED.store(Arc::new(true));
info!("Logging queries to stdout");
}
pub fn disable() {
ENABLED.store(Arc::new(false));
}
pub fn enabled() -> bool {
**ENABLED.load()
}
#[async_trait]
impl Plugin for QueryLogger {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("{}", query);
Ok(PluginOutput::Allow)
}
}

View File

@@ -0,0 +1,73 @@
//! This query router plugin will check if the user can access a particular
//! table as part of their query. If they can't, the query will not be routed.
use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement};
use crate::{
config::TableAccess as TableAccessConfig,
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use log::{debug, info};
use arc_swap::ArcSwap;
use core::ops::ControlFlow;
use once_cell::sync::Lazy;
use std::sync::Arc;
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
pub fn setup(config: &TableAccessConfig) {
CONFIG.store(Arc::new(config.tables.clone()));
info!("Blocking access to {} tables", config.tables.len());
}
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn disable() {
CONFIG.store(Arc::new(vec![]));
}
pub struct TableAccess;
#[async_trait]
impl Plugin for TableAccess {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let mut found = None;
let forbidden_tables = CONFIG.load();
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if forbidden_tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else {
ControlFlow::<()>::Continue(())
}
});
if let Some(found) = found {
debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied",
found
)))
} else {
Ok(PluginOutput::Allow)
}
}
}

View File

@@ -61,6 +61,8 @@ pub struct PoolIdentifier {
pub user: String,
}
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
impl PoolIdentifier {
/// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier {
@@ -91,6 +93,7 @@ pub struct PoolSettings {
// Connecting user.
pub user: User,
pub db: String,
// Default server role to connect to.
pub default_role: Option<Role>,
@@ -138,6 +141,7 @@ impl Default for PoolSettings {
load_balancing_mode: LoadBalancingMode::Random,
shards: 1,
user: User::default(),
db: String::default(),
default_role: None,
query_parser_enabled: false,
primary_reads_enabled: true,
@@ -311,21 +315,34 @@ impl ConnectionPool {
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
{
if ok != *pool_auth_hash_value {
warn!(
"Hash is not the same across shards \
of the same pool, client auth will \
be done using last obtained hash. \
Server: {}:{}, Database: {}",
server.host, server.port, shard.database,
);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
}
Err(err) => warn!(
"Could not obtain password hashes \
using auth_query config, ignoring. \
Error: {:?}",
err,
),
}
}
let manager = ServerPool::new(
@@ -347,14 +364,31 @@ impl ConnectionPool {
None => config.general.idle_timeout,
};
let server_lifetime = match user.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => match pool_config.server_lifetime {
Some(server_lifetime) => server_lifetime,
None => config.general.server_lifetime,
},
};
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter()
.min()
.unwrap();
debug!("Pool reaper rate: {}ms", reaper_rate);
let pool = Pool::builder()
.max_size(user.pool_size)
.min_idle(user.min_pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
.await?;
pools.push(pool);
servers.push(address);
@@ -390,6 +424,7 @@ impl ConnectionPool {
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
db: pool_name.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
@@ -434,6 +469,32 @@ impl ConnectionPool {
}
}
if let Some(ref plugins) = config.plugins {
if let Some(ref intercept) = plugins.intercept {
if intercept.enabled {
crate::plugins::intercept::setup(intercept, &new_pools);
} else {
crate::plugins::intercept::disable();
}
}
if let Some(ref table_access) = plugins.table_access {
if table_access.enabled {
crate::plugins::table_access::setup(table_access);
} else {
crate::plugins::table_access::disable();
}
}
if let Some(ref query_logger) = plugins.query_logger {
if query_logger.enabled {
crate::plugins::query_logger::setup();
} else {
crate::plugins::query_logger::disable();
}
}
}
POOLS.store(Arc::new(new_pools.clone()));
Ok(())
}
@@ -777,7 +838,6 @@ impl ConnectionPool {
self.databases.len()
}
/// Retrieve all bans for all servers.
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read();
@@ -789,7 +849,7 @@ impl ConnectionPool {
return bans;
}
/// Get the address from the host url.
/// Get the address from the host url
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new();
for shard in 0..self.shards() {
@@ -828,13 +888,10 @@ impl ConnectionPool {
&self.addresses[shard][server]
}
/// Get server settings retrieved at connection setup.
pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone()
}
/// Calculate how many used connections in the pool
/// for the given server.
fn busy_connection_count(&self, address: &Address) -> u32 {
let state = self.pool_state(address.shard, address.address_index);
let idle = state.idle_connections;

View File

@@ -6,13 +6,19 @@ use once_cell::sync::OnceCell;
use regex::{Regex, RegexSet};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::ast::{
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
Value,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::plugins::{
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
TableAccess,
};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
@@ -129,6 +135,10 @@ impl QueryRouter {
self.pool_settings = pool_settings;
}
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings
}
/// Try to parse a command and execute it.
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
let mut message_cursor = Cursor::new(message_buffer);
@@ -324,10 +334,7 @@ impl QueryRouter {
Some((command, value))
}
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, message: &BytesMut) -> bool {
debug!("Inferring role");
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
let mut message_cursor = Cursor::new(message);
let code = message_cursor.get_u8() as char;
@@ -353,28 +360,29 @@ impl QueryRouter {
query
}
_ => return false,
_ => return Err(Error::UnsupportedStatement),
};
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => Ok(ast),
Err(err) => {
// SELECT ... FOR UPDATE won't get parsed correctly.
debug!("{}: {}", err, query);
self.active_role = Some(Role::Primary);
return false;
Err(Error::QueryRouterParserError(err.to_string()))
}
};
}
}
debug!("AST: {:?}", ast);
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
debug!("Inferring role");
if ast.is_empty() {
// That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary);
return false;
return Err(Error::QueryRouterParserError("empty query".into()));
}
for q in &ast {
for q in ast {
match q {
// All transactions go to the primary, probably a write.
StartTransaction { .. } => {
@@ -418,7 +426,7 @@ impl QueryRouter {
};
}
true
Ok(())
}
/// Parse the shard number from the Bind message
@@ -783,6 +791,34 @@ impl QueryRouter {
}
}
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
if query_logger::enabled() {
let mut query_logger = QueryLogger {};
let _ = query_logger.run(&self, ast).await;
}
if intercept::enabled() {
let mut intercept = Intercept {};
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
}
}
if table_access::enabled() {
let mut table_access = TableAccess {};
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error));
}
}
Ok(PluginOutput::Allow)
}
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
let sharder = Sharder::new(
self.pool_settings.shards,
@@ -810,11 +846,22 @@ impl QueryRouter {
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled {
None => self.pool_settings.query_parser_enabled,
Some(value) => value,
};
None => {
debug!(
"Using pool settings, query_parser_enabled: {}",
self.pool_settings.query_parser_enabled
);
self.pool_settings.query_parser_enabled
}
debug!("Query parser enabled: {}", enabled);
Some(value) => {
debug!(
"Using query parser override, query_parser_enabled: {}",
value
);
value
}
};
enabled
}
@@ -862,7 +909,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
}
@@ -881,7 +928,7 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
}
}
@@ -893,7 +940,7 @@ mod test {
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None);
}
@@ -913,7 +960,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);
assert!(qr.infer(&res));
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
}
@@ -1077,11 +1124,11 @@ mod test {
assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table");
assert!(qr.infer(&query));
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled());
@@ -1113,6 +1160,7 @@ mod test {
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -1142,15 +1190,24 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Primary);
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
.is_ok());
assert_eq!(qr.role(), Role::Replica);
assert!(qr.infer(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.role(), Role::Primary);
}
@@ -1177,6 +1234,7 @@ mod test {
auth_query: None,
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());
@@ -1208,47 +1266,84 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
assert!(qr
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
.is_ok());
assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT one, two, three FROM public.data WHERE id = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query(
"SELECT * FROM data
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM data
INNER JOIN t2 ON data.id = 5
AND t2.data_id = data.id
WHERE data.id = 5"
)));
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
// in the query.
assert!(qr.infer(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 2);
// Super unique sharding key
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
assert!(qr.infer(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
)));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query(
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
assert!(qr
.infer(
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
.unwrap()
)
.is_ok());
assert_eq!(qr.shard(), 0);
}
@@ -1272,11 +1367,40 @@ mod test {
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
qr.pool_settings.shards = 3;
assert!(qr.infer(&simple_query(stmt)));
assert!(qr
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
.is_ok());
assert_eq!(qr.placeholders.len(), 1);
assert!(qr.infer_shard_from_bind(&bind));
assert_eq!(qr.shard(), 2);
assert!(qr.placeholders.is_empty());
}
#[tokio::test]
async fn test_table_access_plugin() {
use crate::config::TableAccess;
let ta = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
crate::plugins::table_access::setup(&ta);
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(
res,
Ok(PluginOutput::Deny(
"permission for table \"pg_database\" denied".to_string()
))
);
}
}

View File

@@ -7,22 +7,101 @@ use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::HashMap;
use std::io::Read;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{Address, User};
use crate::config::{get_config, Address, User};
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::ServerStats;
use std::io::Write;
use pin_project::pin_project;
#[pin_project(project = SteamInnerProj)]
pub enum StreamInner {
Plain {
#[pin]
stream: TcpStream,
},
Tls {
#[pin]
stream: TlsStream<TcpStream>,
},
}
impl AsyncWrite for StreamInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for StreamInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
impl StreamInner {
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamInner::Tls { stream } => {
let r = stream.get_mut();
let mut w = r.1.writer();
w.write(buf)
}
StreamInner::Plain { stream } => stream.try_write(buf),
}
}
}
/// Server state.
pub struct Server {
@@ -30,15 +109,11 @@ pub struct Server {
/// port, e.g. 5432, and role, e.g. primary or replica.
address: Address,
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
/// Server TCP connection.
stream: BufStream<StreamInner>,
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
is_async: bool,
/// Server information the server sent us over on startup.
server_info: BytesMut,
@@ -75,6 +150,9 @@ pub struct Server {
last_activity: SystemTime,
mirror_manager: Option<MirroringManager>,
// Associated addresses used
addr_set: Option<AddrSet>,
}
impl Server {
@@ -88,6 +166,24 @@ impl Server {
stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
// If we are caching addresses and hostname is not an IP
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
addr_set = match cached_resolver.lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
};
let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream,
@@ -99,8 +195,88 @@ impl Server {
)));
}
};
// TCP timeouts.
configure_socket(&stream);
let config = get_config();
let mut stream = if config.general.server_tls {
// Request a TLS connection
ssl_request(&mut stream).await?;
let response = match stream.read_u8().await {
Ok(response) => response as char,
Err(err) => {
return Err(Error::SocketError(format!(
"Server socket error: {:?}",
err
)))
}
};
match response {
// Server supports TLS
'S' => {
debug!("Connecting to server using TLS");
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// Equivalent to sslmode=prefer which is fine most places.
// If you want verify-full, change `verify_server_certificate` to true.
if !config.general.verify_server_certificate {
let mut dangerous = tls_config.dangerous();
dangerous.set_certificate_verifier(Arc::new(
crate::tls::NoCertificateVerification {},
));
}
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = match connector
.connect(address.host.as_str().try_into().unwrap(), stream)
.await
{
Ok(stream) => stream,
Err(err) => {
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
}
};
StreamInner::Tls { stream }
}
// Server does not support TLS
'N' => StreamInner::Plain { stream },
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
StreamInner::Plain { stream }
};
// let (read, write) = split(stream);
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
trace!("Sending StartupMessage");
// StartupMessage
@@ -246,7 +422,7 @@ impl Server {
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 {
if sasl_type.contains(SCRAM_SHA_256) {
debug!("Using {}", SCRAM_SHA_256);
// Generate client message.
@@ -269,7 +445,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
write_all_flush(&mut stream, &res).await?;
} else {
error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError);
@@ -300,7 +476,7 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
write_all_flush(&mut stream, &res).await?;
}
SASL_FINAL => {
@@ -444,14 +620,10 @@ impl Server {
}
};
let (read, write) = stream.into_split();
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write,
stream: BufStream::new(stream),
buffer: BytesMut::with_capacity(8196),
is_async: false,
server_info,
process_id,
secret_key,
@@ -460,6 +632,7 @@ impl Server {
bad: false,
needs_cleanup: false,
client_server_map,
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(),
stats,
application_name: String::new(),
@@ -517,7 +690,7 @@ impl Server {
bytes.put_i32(process_id);
bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await
write_all_flush(&mut stream, &bytes).await
}
/// Send messages to the server from the client.
@@ -525,7 +698,7 @@ impl Server {
self.mirror_send(messages);
self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await {
match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => {
// Successfully sent to server
self.last_activity = SystemTime::now();
@@ -539,22 +712,12 @@ impl Server {
}
}
/// Switch to async mode, flushing messages as soon
/// as we receive them without buffering or waiting for "ReadyForQuery".
pub fn switch_async(&mut self, on: bool) {
if on {
self.is_async = true;
} else {
self.is_async = false;
}
}
/// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop {
let mut message = match read_message(&mut self.read).await {
let mut message = match read_message(&mut self.stream).await {
Ok(message) => message,
Err(err) => {
error!("Terminating server because of: {:?}", err);
@@ -644,10 +807,7 @@ impl Server {
// DataRow
'D' => {
// More data is available after this message, this is not the end of the reply.
// If we're async, flush to client now.
if !self.is_async {
self.data_available = true;
}
self.data_available = true;
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
if self.buffer.len() >= 8196 {
@@ -660,10 +820,7 @@ impl Server {
// CopyOutResponse: copy is starting from the server to the client.
'H' => {
// If we're in async mode, flush now.
if !self.is_async {
self.data_available = true;
}
self.data_available = true;
break;
}
@@ -683,10 +840,6 @@ impl Server {
// Keep buffering until ReadyForQuery shows up.
_ => (),
};
if self.is_async {
break;
}
}
let bytes = self.buffer.clone();
@@ -720,7 +873,23 @@ impl Server {
/// Server & client are out of sync, we must discard this connection.
/// This happens with clients that misbehave.
pub fn is_bad(&self) -> bool {
self.bad
if self.bad {
return self.bad;
};
let cached_resolver = CACHED_RESOLVER.load();
if cached_resolver.enabled() {
if let Some(addr_set) = &self.addr_set {
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
warn!(
"DNS changed for {}, it was {:?}. Dropping server connection.",
self.address.host.as_str(),
addr_set
);
return true;
}
}
}
false
}
/// Get server startup information to forward it to the client.
@@ -957,13 +1126,13 @@ impl Drop for Server {
// Update statistics
self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4);
let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'X');
bytes.put_i32(4);
match self.write.try_write(&bytes) {
Ok(_) => (),
Err(_) => debug!("Dirty shutdown"),
match self.stream.get_mut().try_write(&bytes) {
Ok(5) => (),
_ => debug!("Dirty shutdown"),
};
// Should not matter.

View File

@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter;
use std::path::Path;
use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor;
use crate::config::get_config;
@@ -64,3 +69,19 @@ impl Tls {
})
}
}
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -1,60 +0,0 @@
import psycopg2
import asyncio
import asyncpg
PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"
def regular_main():
# Connect to the PostgreSQL database
conn = psycopg2.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)
# Open a cursor to perform database operations
cur = conn.cursor()
# Execute a SQL query
cur.execute("SELECT 1")
# Fetch the results
rows = cur.fetchall()
# Print the results
for row in rows:
print(row[0])
# Close the cursor and the database connection
cur.close()
conn.close()
async def main():
# Connect to the PostgreSQL database
conn = await asyncpg.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)
# Execute a SQL query
for _ in range(25):
rows = await conn.fetch("SELECT 1")
# Print the results
for row in rows:
print(row[0])
# Close the database connection
await conn.close()
regular_main()
asyncio.run(main())

View File

@@ -1,11 +1,2 @@
asyncio==3.4.3
asyncpg==0.27.0
black==23.3.0
click==8.1.3
mypy-extensions==1.0.0
packaging==23.1
pathspec==0.11.1
platformdirs==3.2.0
psutil==5.9.1
psycopg2==2.9.3
tomli==2.0.1
psutil==5.9.1

View File

@@ -71,15 +71,17 @@ describe "Admin" do
context "client connects but issues no queries" do
it "only affects cl_idle stats" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq("1")
expect(results["sv_idle"]).to eq(before_test)
connections.map(&:close)
sleep(1.1)
@@ -87,7 +89,7 @@ describe "Admin" do
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end
expect(results["sv_idle"]).to eq("1")
expect(results["sv_idle"]).to eq(before_test)
end
end

View File

@@ -25,7 +25,7 @@ describe "Query Mirroing" do
processes.pgcat.shutdown
end
it "can mirror a query" do
xit "can mirror a query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
runs = 15
runs.times { conn.async_exec("SELECT 1 + 2") }

View File

@@ -0,0 +1,14 @@
require_relative 'spec_helper'
describe "Plugins" do
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
context "intercept" do
it "will intercept an intellij query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
res = conn.exec("select current_database() as a, current_schemas(false) as b")
expect(res.values).to eq([["sharded_db", "{public}"]])
end
end
end