mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
12 Commits
levkk-asyn
...
levkk-more
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7265cbf91 | ||
|
|
d738ba28b6 | ||
|
|
ff80bb75cc | ||
|
|
374a6b138b | ||
|
|
d5e329fec5 | ||
|
|
09e54e1175 | ||
|
|
23819c8549 | ||
|
|
7dfbd993f2 | ||
|
|
3601130ba1 | ||
|
|
0d504032b2 | ||
|
|
4a87b4807d | ||
|
|
cb5ff40a59 |
@@ -110,10 +110,6 @@ python3 tests/python/tests.py || exit 1
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
python3 tests/python/async_test.py
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
# Admin tests
|
||||
export PGPASSWORD=admin_pass
|
||||
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
|
||||
24
CONFIG.md
24
CONFIG.md
@@ -49,6 +49,14 @@ default: 30000 # milliseconds
|
||||
|
||||
How long an idle connection with a server is left open (ms).
|
||||
|
||||
### server_lifetime
|
||||
```
|
||||
path: general.server_lifetime
|
||||
default: 86400000 # 24 hours
|
||||
```
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
@@ -180,6 +188,22 @@ default: "admin_pass"
|
||||
|
||||
Password to access the virtual administrative database
|
||||
|
||||
### dns_cache_enabled
|
||||
```
|
||||
path: general.dns_cache_enabled
|
||||
default: false
|
||||
```
|
||||
When enabled, ip resolutions for server connections specified using hostnames will be cached
|
||||
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
|
||||
old ip connections are closed (gracefully) and new connections will start using new ip.
|
||||
|
||||
### dns_max_ttl
|
||||
```
|
||||
path: general.dns_max_ttl
|
||||
default: 30
|
||||
```
|
||||
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
|
||||
|
||||
## `pools.<pool_name>` Section
|
||||
|
||||
### pool_mode
|
||||
|
||||
365
Cargo.lock
generated
365
Cargo.lock
generated
@@ -26,6 +26,27 @@ version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
@@ -212,6 +233,12 @@ dependencies = [
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.6"
|
||||
@@ -223,6 +250,18 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enum-as-inner"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.0"
|
||||
@@ -275,6 +314,15 @@ version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.28"
|
||||
@@ -410,6 +458,12 @@ version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.2.6"
|
||||
@@ -434,6 +488,17 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hostname"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"match_cfg",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.9"
|
||||
@@ -522,6 +587,27 @@ dependencies = [
|
||||
"cxx-build",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
|
||||
dependencies = [
|
||||
"matches",
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.2"
|
||||
@@ -542,6 +628,24 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipconfig"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
|
||||
dependencies = [
|
||||
"socket2",
|
||||
"widestring",
|
||||
"winapi",
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.4"
|
||||
@@ -589,6 +693,12 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.139"
|
||||
@@ -604,6 +714,12 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.1.4"
|
||||
@@ -629,6 +745,27 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-cache"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
|
||||
dependencies = [
|
||||
"linked-hash-map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "match_cfg"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
|
||||
|
||||
[[package]]
|
||||
name = "matches"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.10.5"
|
||||
@@ -737,9 +874,15 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2-alpha1"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
@@ -762,12 +905,15 @@ dependencies = [
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf",
|
||||
"pin-project",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"sha-1",
|
||||
"sha2",
|
||||
"socket2",
|
||||
@@ -775,7 +921,10 @@ dependencies = [
|
||||
"stringprep",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-test",
|
||||
"toml",
|
||||
"trust-dns-resolver",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -820,6 +969,26 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
@@ -865,6 +1034,12 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.26"
|
||||
@@ -915,9 +1090,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.8.0"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac6cf59af1067a3fb53fbe5c88c053764e930f932be1d71d3ffe032cbe147f59"
|
||||
checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
@@ -926,9 +1101,19 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c"
|
||||
|
||||
[[package]]
|
||||
name = "resolv-conf"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5"
|
||||
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
|
||||
dependencies = [
|
||||
"hostname",
|
||||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
@@ -961,9 +1146,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.0"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07180898a28ed6a7f7ba2311594308f595e3dd2e3c3812fa0a80a47b45f17e5d"
|
||||
checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
@@ -990,6 +1175,12 @@ dependencies = [
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@@ -1017,6 +1208,9 @@ name = "serde"
|
||||
version = "1.0.160"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
@@ -1029,6 +1223,17 @@ dependencies = [
|
||||
"syn 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.1"
|
||||
@@ -1113,6 +1318,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a"
|
||||
dependencies = [
|
||||
"log",
|
||||
"sqlparser_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser_derive"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1168,6 +1385,26 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.1.45"
|
||||
@@ -1235,6 +1472,30 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-test"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.7"
|
||||
@@ -1297,9 +1558,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.30"
|
||||
@@ -1309,6 +1582,51 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trust-dns-proto"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"cfg-if",
|
||||
"data-encoding",
|
||||
"enum-as-inner",
|
||||
"futures-channel",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"idna 0.2.3",
|
||||
"ipnet",
|
||||
"lazy_static",
|
||||
"rand",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"tinyvec",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trust-dns-resolver"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"futures-util",
|
||||
"ipconfig",
|
||||
"lazy_static",
|
||||
"lru-cache",
|
||||
"parking_lot",
|
||||
"resolv-conf",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"trust-dns-proto",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.4"
|
||||
@@ -1354,6 +1672,17 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna 0.3.0",
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
@@ -1446,6 +1775,21 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
|
||||
dependencies = [
|
||||
"rustls-webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "widestring"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@@ -1551,3 +1895,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
13
Cargo.toml
13
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2-alpha1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -14,12 +14,12 @@ rand = "0.8"
|
||||
chrono = "0.4"
|
||||
sha-1 = "0.10"
|
||||
toml = "0.7"
|
||||
serde = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_derive = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
sqlparser = "0.33.0"
|
||||
sqlparser = {version = "0.33", features = ["visitor"] }
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
@@ -39,6 +39,13 @@ nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "1"
|
||||
webpki-roots = "0.23"
|
||||
rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
trust-dns-resolver = "0.22.0"
|
||||
tokio-test = "0.4.2"
|
||||
serde_json = "1"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
|
||||
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
|
||||
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
|
||||
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
|
||||
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
|
||||
| SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
|
||||
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
|
||||
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
|
||||
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
|
||||
|
||||
62
pgcat.toml
62
pgcat.toml
@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
|
||||
# How long an idle connection with a server is left open (ms).
|
||||
idle_timeout = 30000 # milliseconds
|
||||
|
||||
# Max connection lifetime before it's closed, even if actively used.
|
||||
server_lifetime = 86400000 # 24 hours
|
||||
|
||||
# How long a client is allowed to be idle while in a transaction (ms).
|
||||
idle_client_in_transaction_timeout = 0 # milliseconds
|
||||
|
||||
@@ -58,9 +61,15 @@ tcp_keepalives_count = 5
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = "server.cert"
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = "server.key"
|
||||
# tls_private_key = ".circleci/server.key"
|
||||
|
||||
# Enable/disable server TLS
|
||||
server_tls = false
|
||||
|
||||
# Verify server certificate is completely authentic.
|
||||
verify_server_certificate = false
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
@@ -137,6 +146,53 @@ idle_timeout = 40000
|
||||
# Connect timeout can be overwritten in the pool
|
||||
connect_timeout = 3000
|
||||
|
||||
# When enabled, ip resolutions for server connections specified using hostnames will be cached
|
||||
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
|
||||
# old ip connections are closed (gracefully) and new connections will start using new ip.
|
||||
# dns_cache_enabled = false
|
||||
|
||||
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
|
||||
# dns_max_ttl = 30
|
||||
|
||||
[plugins]
|
||||
|
||||
[plugins.query_logger]
|
||||
enabled = false
|
||||
|
||||
[plugins.table_access]
|
||||
enabled = false
|
||||
tables = [
|
||||
"pg_user",
|
||||
"pg_roles",
|
||||
"pg_database",
|
||||
]
|
||||
|
||||
[plugins.intercept]
|
||||
enabled = true
|
||||
|
||||
[plugins.intercept.queries.0]
|
||||
|
||||
query = "select current_database() as a, current_schemas(false) as b"
|
||||
schema = [
|
||||
["a", "text"],
|
||||
["b", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "{public}"],
|
||||
]
|
||||
|
||||
[plugins.intercept.queries.1]
|
||||
|
||||
query = "select current_database(), current_schema(), current_user"
|
||||
schema = [
|
||||
["current_database", "text"],
|
||||
["current_schema", "text"],
|
||||
["current_user", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "public", "${USER}"],
|
||||
]
|
||||
|
||||
# User configs are structured as pool.<pool_name>.users.<user_index>
|
||||
# This section holds the credentials for users that may connect to this cluster
|
||||
[pools.sharded_db.users.0]
|
||||
@@ -206,6 +262,8 @@ sharding_function = "pg_bigint_hash"
|
||||
username = "simple_user"
|
||||
password = "simple_user"
|
||||
pool_size = 5
|
||||
min_pool_size = 3
|
||||
server_lifetime = 60000
|
||||
statement_timeout = 0
|
||||
|
||||
[pools.simple_db.shards.0]
|
||||
|
||||
@@ -12,9 +12,9 @@ use tokio::time::Instant;
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
|
||||
use crate::ClientServerMap;
|
||||
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
|
||||
@@ -77,6 +77,8 @@ impl AuthPassthrough {
|
||||
pool_size: 1,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
min_pool_size: None,
|
||||
};
|
||||
|
||||
let user = &address.username;
|
||||
|
||||
152
src/client.rs
152
src/client.rs
@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
@@ -539,6 +540,7 @@ where
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
@@ -565,6 +567,8 @@ where
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
@@ -587,7 +591,15 @@ where
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = refetch_auth_hash(&pool).await?;
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
@@ -754,6 +766,9 @@ where
|
||||
|
||||
self.stats.register(self.stats.clone());
|
||||
|
||||
// Result returned by one of the plugins.
|
||||
let mut plugin_output = None;
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
// or issue commands for our sharding and server selection protocol.
|
||||
@@ -804,7 +819,25 @@ where
|
||||
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -812,7 +845,13 @@ where
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
@@ -846,6 +885,18 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check on plugin results.
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
@@ -932,7 +983,7 @@ where
|
||||
}
|
||||
|
||||
// Grab a server from the pool.
|
||||
let mut connection = match pool
|
||||
let connection = match pool
|
||||
.get(query_router.shard(), query_router.role(), &self.stats)
|
||||
.await
|
||||
{
|
||||
@@ -975,8 +1026,9 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let server = &mut *connection.0;
|
||||
let mut reference = connection.0;
|
||||
let address = connection.1;
|
||||
let server = &mut *reference;
|
||||
|
||||
// Server is assigned to the client in case the client wants to
|
||||
// cancel a query later.
|
||||
@@ -999,7 +1051,6 @@ where
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
server.switch_async(false);
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1019,37 +1070,12 @@ where
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
let message = tokio::select! {
|
||||
message = tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
) => message,
|
||||
|
||||
server_message = server.recv() => {
|
||||
debug!("Got async message");
|
||||
|
||||
let server_message = match server_message {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
match write_all_half(&mut self.write, &server_message).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(message)) => message,
|
||||
Ok(Err(err)) => {
|
||||
// Client disconnected inside a transaction.
|
||||
@@ -1099,6 +1125,27 @@ where
|
||||
match code {
|
||||
// Query
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
debug!("Sending query to server");
|
||||
|
||||
self.send_and_receive_loop(
|
||||
@@ -1138,6 +1185,14 @@ where
|
||||
// Parse
|
||||
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
|
||||
'P' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
@@ -1166,13 +1221,26 @@ where
|
||||
|
||||
// Sync
|
||||
// Frontend (client) is asking for the query result now.
|
||||
'S' | 'H' => {
|
||||
'S' => {
|
||||
debug!("Sending query to server");
|
||||
|
||||
if code == 'H' {
|
||||
server.switch_async(true);
|
||||
debug!("Client requested flush, going async");
|
||||
}
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
self.buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
Some(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
plugin_output = None;
|
||||
self.buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
|
||||
185
src/config.rs
185
src/config.rs
@@ -12,6 +12,7 @@ use std::sync::Arc;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use crate::dns_cache::CachedResolver;
|
||||
use crate::errors::Error;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::sharding::ShardingFunction;
|
||||
@@ -181,7 +182,9 @@ pub struct User {
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
pub min_pool_size: Option<u32>,
|
||||
pub pool_mode: Option<PoolMode>,
|
||||
pub server_lifetime: Option<u64>,
|
||||
#[serde(default)] // 0
|
||||
pub statement_timeout: u64,
|
||||
}
|
||||
@@ -194,12 +197,34 @@ impl Default for User {
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
min_pool_size: None,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// General configuration.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct General {
|
||||
@@ -231,6 +256,12 @@ pub struct General {
|
||||
#[serde(default)] // False
|
||||
pub log_client_disconnections: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub dns_cache_enabled: bool,
|
||||
|
||||
#[serde(default = "General::default_dns_max_ttl")]
|
||||
pub dns_max_ttl: u64,
|
||||
|
||||
#[serde(default = "General::default_shutdown_timeout")]
|
||||
pub shutdown_timeout: u64,
|
||||
|
||||
@@ -246,6 +277,9 @@ pub struct General {
|
||||
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
|
||||
pub idle_client_in_transaction_timeout: u64,
|
||||
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
@@ -254,9 +288,17 @@ pub struct General {
|
||||
|
||||
pub tls_certificate: Option<String>,
|
||||
pub tls_private_key: Option<String>,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub server_tls: bool,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub verify_server_certificate: bool,
|
||||
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
// Support for auth query
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
@@ -271,6 +313,10 @@ impl General {
|
||||
5432
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
@@ -298,6 +344,10 @@ impl General {
|
||||
60000
|
||||
}
|
||||
|
||||
pub fn default_dns_max_ttl() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
pub fn default_healthcheck_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
@@ -340,13 +390,18 @@ impl Default for General {
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: None,
|
||||
dns_cache_enabled: false,
|
||||
dns_max_ttl: Self::default_dns_max_ttl(),
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -411,6 +466,8 @@ pub struct Pool {
|
||||
|
||||
pub idle_timeout: Option<u64>,
|
||||
|
||||
pub server_lifetime: Option<u64>,
|
||||
|
||||
pub sharding_function: ShardingFunction,
|
||||
|
||||
#[serde(default = "Pool::default_automatic_sharding_key")]
|
||||
@@ -515,6 +572,10 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -539,6 +600,7 @@ impl Default for Pool {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -588,7 +650,7 @@ impl Shard {
|
||||
|
||||
if primary_count > 1 {
|
||||
error!(
|
||||
"Shard {} has more than on primary configured",
|
||||
"Shard {} has more than one primary configured",
|
||||
self.database
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
@@ -617,6 +679,55 @@ impl Default for Shard {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Plugins {
|
||||
pub intercept: Option<Intercept>,
|
||||
pub table_access: Option<TableAccess>,
|
||||
pub query_logger: Option<QueryLogger>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Intercept {
|
||||
pub enabled: bool,
|
||||
pub queries: BTreeMap<String, Query>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct TableAccess {
|
||||
pub enabled: bool,
|
||||
pub tables: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct QueryLogger {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
impl Intercept {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for (_, query) in self.queries.iter_mut() {
|
||||
query.substitute(db, user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Query {
|
||||
pub query: String,
|
||||
pub schema: Vec<Vec<String>>,
|
||||
pub result: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for col in self.result.iter_mut() {
|
||||
for i in 0..col.len() {
|
||||
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration wrapper.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
@@ -635,6 +746,7 @@ pub struct Config {
|
||||
pub path: String,
|
||||
|
||||
pub general: General,
|
||||
pub plugins: Option<Plugins>,
|
||||
pub pools: HashMap<String, Pool>,
|
||||
}
|
||||
|
||||
@@ -672,6 +784,7 @@ impl Default for Config {
|
||||
path: Self::default_path(),
|
||||
general: General::default(),
|
||||
pools: HashMap::default(),
|
||||
plugins: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -791,6 +904,10 @@ impl Config {
|
||||
);
|
||||
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
|
||||
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
|
||||
info!(
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
@@ -809,6 +926,11 @@ impl Config {
|
||||
info!("TLS support is disabled");
|
||||
}
|
||||
};
|
||||
info!("Server TLS enabled: {}", self.general.server_tls);
|
||||
info!(
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
|
||||
for (pool_name, pool_config) in &self.pools {
|
||||
// TODO: Make this output prettier (maybe a table?)
|
||||
@@ -867,12 +989,26 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config.users.len()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
for user in &pool_config.users {
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool size: {}",
|
||||
pool_name, user.1.username, user.1.pool_size,
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Minimum pool size: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
user.1.min_pool_size.unwrap_or(0)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Statement timeout: {}",
|
||||
pool_name, user.1.username, user.1.statement_timeout
|
||||
@@ -886,6 +1022,15 @@ impl Config {
|
||||
None => pool_config.pool_mode.to_string(),
|
||||
}
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -896,7 +1041,13 @@ impl Config {
|
||||
&& (self.general.auth_query_user.is_none()
|
||||
|| self.general.auth_query_password.is_none())
|
||||
{
|
||||
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
|
||||
error!(
|
||||
"If auth_query is specified, \
|
||||
you need to provide a value \
|
||||
for `auth_query_user`, \
|
||||
`auth_query_password`"
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -904,7 +1055,14 @@ impl Config {
|
||||
if pool.auth_query.is_some()
|
||||
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
|
||||
{
|
||||
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
If auth_query is specified, you need \
|
||||
to provide a value for `auth_query_user`, \
|
||||
`auth_query_password`",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -914,7 +1072,13 @@ impl Config {
|
||||
|| pool.auth_query_user.is_none())
|
||||
&& user_data.password.is_none()
|
||||
{
|
||||
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
You have to specify a user password \
|
||||
for every pool if auth_query is not specified",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
@@ -1012,6 +1176,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
|
||||
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
|
||||
let old_config = get_config();
|
||||
|
||||
match parse(&old_config.path).await {
|
||||
Ok(()) => (),
|
||||
Err(err) => {
|
||||
@@ -1019,14 +1184,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
let new_config = get_config();
|
||||
|
||||
if old_config.pools != new_config.pools {
|
||||
info!("Pool configuration changed");
|
||||
match CachedResolver::from_config().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
|
||||
};
|
||||
|
||||
if old_config != new_config {
|
||||
info!("Config changed, reloading");
|
||||
ConnectionPool::from_config(client_server_map).await?;
|
||||
Ok(true)
|
||||
} else if old_config != new_config {
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
410
src/dns_cache.rs
Normal file
410
src/dns_cache.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use trust_dns_resolver::error::{ResolveError, ResolveResult};
|
||||
use trust_dns_resolver::lookup_ip::LookupIp;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
/// Cached Resolver Globally available
|
||||
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
|
||||
|
||||
// Ip addressed are returned as a set of addresses
|
||||
// so we can compare.
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct AddrSet {
|
||||
set: HashSet<IpAddr>,
|
||||
}
|
||||
|
||||
impl AddrSet {
|
||||
fn new() -> AddrSet {
|
||||
AddrSet {
|
||||
set: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LookupIp> for AddrSet {
|
||||
fn from(lookup_ip: LookupIp) -> Self {
|
||||
let mut addr_set = AddrSet::new();
|
||||
for address in lookup_ip.iter() {
|
||||
addr_set.set.insert(address);
|
||||
}
|
||||
addr_set
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
|
||||
///
|
||||
/// The system works as follows:
|
||||
///
|
||||
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
|
||||
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
|
||||
/// cache is refreshed.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
/// // Now the ip resolution is stored in local cache and subsequent
|
||||
/// // calls will be returned from cache. Also, the cache is refreshed
|
||||
/// // and updated every 10 seconds.
|
||||
///
|
||||
/// // You can now check if an 'old' lookup differs from what it's currently
|
||||
/// // store in cache by using `has_changed`.
|
||||
/// resolver.has_changed("www.example.com.", addrset)
|
||||
#[derive(Default)]
|
||||
pub struct CachedResolver {
|
||||
// The configuration of the cached_resolver.
|
||||
config: CachedResolverConfig,
|
||||
|
||||
// This is the hash that contains the hash.
|
||||
data: Option<RwLock<HashMap<String, AddrSet>>>,
|
||||
|
||||
// The resolver to be used for DNS queries.
|
||||
resolver: Option<TokioAsyncResolver>,
|
||||
|
||||
// The RefreshLoop
|
||||
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
///
|
||||
/// Configuration
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct CachedResolverConfig {
|
||||
/// Amount of time in secods that a resolved dns address is considered stale.
|
||||
dns_max_ttl: u64,
|
||||
|
||||
/// Enabled or disabled? (this is so we can reload config)
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl CachedResolverConfig {
|
||||
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
|
||||
CachedResolverConfig {
|
||||
dns_max_ttl,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::config::Config> for CachedResolverConfig {
|
||||
fn from(config: crate::config::Config) -> Self {
|
||||
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
|
||||
}
|
||||
}
|
||||
|
||||
impl CachedResolver {
|
||||
///
|
||||
/// Returns a new Arc<CachedResolver> based on passed configuration.
|
||||
/// It also starts the loop that will refresh cache entries.
|
||||
///
|
||||
/// # Arguments:
|
||||
///
|
||||
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn new(
|
||||
config: CachedResolverConfig,
|
||||
data: Option<HashMap<String, AddrSet>>,
|
||||
) -> Result<Arc<Self>, io::Error> {
|
||||
// Construct a new Resolver with default configuration options
|
||||
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
|
||||
|
||||
let data = if let Some(hash) = data {
|
||||
Some(RwLock::new(hash))
|
||||
} else {
|
||||
Some(RwLock::new(HashMap::new()))
|
||||
};
|
||||
|
||||
let instance = Arc::new(Self {
|
||||
config,
|
||||
resolver,
|
||||
data,
|
||||
refresh_loop: RwLock::new(None),
|
||||
});
|
||||
|
||||
if instance.enabled() {
|
||||
info!("Scheduling DNS refresh loop");
|
||||
let refresh_loop = tokio::task::spawn({
|
||||
let instance = instance.clone();
|
||||
async move {
|
||||
instance.refresh_dns_entries_loop().await;
|
||||
}
|
||||
});
|
||||
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
|
||||
}
|
||||
|
||||
Ok(instance)
|
||||
}
|
||||
|
||||
pub fn enabled(&self) -> bool {
|
||||
self.config.enabled
|
||||
}
|
||||
|
||||
// Schedules the refresher
|
||||
async fn refresh_dns_entries_loop(&self) {
|
||||
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
|
||||
let interval = Duration::from_secs(self.config.dns_max_ttl);
|
||||
loop {
|
||||
debug!("Begin refreshing cached DNS addresses.");
|
||||
// To minimize the time we hold the lock, we first create
|
||||
// an array with keys.
|
||||
let mut hostnames: Vec<String> = Vec::new();
|
||||
{
|
||||
if let Some(ref data) = self.data {
|
||||
for hostname in data.read().unwrap().keys() {
|
||||
hostnames.push(hostname.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for hostname in hostnames.iter() {
|
||||
let addrset = self
|
||||
.fetch_from_cache(hostname.as_str())
|
||||
.expect("Could not obtain expected address from cache, this should not happen");
|
||||
|
||||
match resolver.lookup_ip(hostname).await {
|
||||
Ok(lookup_ip) => {
|
||||
let new_addrset = AddrSet::from(lookup_ip);
|
||||
debug!(
|
||||
"Obtained address for host ({}) -> ({:?})",
|
||||
hostname, new_addrset
|
||||
);
|
||||
|
||||
if addrset != new_addrset {
|
||||
debug!(
|
||||
"Addr changed from {:?} to {:?} updating cache.",
|
||||
addrset, new_addrset
|
||||
);
|
||||
self.store_in_cache(hostname, new_addrset);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"There was an error trying to resolv {}: ({}).",
|
||||
hostname, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Finished refreshing cached DNS addresses.");
|
||||
sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a `AddrSet` given the specified hostname.
|
||||
///
|
||||
/// This method first tries to fetch the value from the cache, if it misses
|
||||
/// then it is resolved and stored in the cache. TTL from records is ignored.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `host` - A string slice referencing the hostname to be resolved.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let response = resolver.lookup_ip("www.google.com.");
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
|
||||
debug!("Lookup up {} in cache", host);
|
||||
match self.fetch_from_cache(host) {
|
||||
Some(addr_set) => {
|
||||
debug!("Cache hit!");
|
||||
Ok(addr_set)
|
||||
}
|
||||
None => {
|
||||
debug!("Not found, executing a dns query!");
|
||||
if let Some(ref resolver) = self.resolver {
|
||||
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
|
||||
debug!("Obtained: {:?}", addr_set);
|
||||
self.store_in_cache(host, addr_set.clone());
|
||||
Ok(addr_set)
|
||||
} else {
|
||||
Err(ResolveError::from("No resolver available"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Returns true if the stored host resolution differs from the AddrSet passed.
|
||||
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
|
||||
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
|
||||
return fetched_addr_set != *addr_set;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Fetches an AddrSet from the inner cache adquiring the read lock.
|
||||
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
|
||||
if let Some(ref hash) = self.data {
|
||||
if let Some(addr_set) = hash.read().unwrap().get(key) {
|
||||
return Some(addr_set.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
|
||||
// cache.
|
||||
pub async fn from_config() -> Result<(), Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let desired_config = CachedResolverConfig::from(get_config());
|
||||
|
||||
if cached_resolver.config != desired_config {
|
||||
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
|
||||
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
|
||||
refresh_loop.abort()
|
||||
}
|
||||
let new_resolver = if let Some(ref data) = cached_resolver.data {
|
||||
let data = Some(data.read().unwrap().clone());
|
||||
CachedResolver::new(desired_config, data).await
|
||||
} else {
|
||||
CachedResolver::new(desired_config, None).await
|
||||
};
|
||||
|
||||
match new_resolver {
|
||||
Ok(ok) => {
|
||||
CACHED_RESOLVER.store(ok);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
|
||||
Err(Error::DNSCachedError(message))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the AddrSet in cache adquiring the write lock.
|
||||
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
|
||||
if let Some(ref data) = self.data {
|
||||
data.write().unwrap().insert(host.to_string(), addr_set);
|
||||
} else {
|
||||
error!("Could not insert, Hash not initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use trust_dns_resolver::error::ResolveError;
|
||||
|
||||
#[tokio::test]
|
||||
async fn new() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await;
|
||||
assert!(resolver.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lookup_ip() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let response = resolver.lookup_ip("www.google.com.").await;
|
||||
assert!(response.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn has_changed() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.google.com.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
let addr_set = response.unwrap();
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_host() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.idontexists.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
assert!(matches!(response, Err(ResolveError { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn incorrect_address() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "w ww.idontexists.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
assert!(matches!(response, Err(ResolveError { .. })));
|
||||
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
// Ok, this test is based on the fact that google does DNS RR
|
||||
// and does not responds with every available ip everytime, so
|
||||
// if I cache here, it will miss after one cache iteration or two.
|
||||
async fn thread() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.google.com.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
let addr_set = response.unwrap();
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
let resolver_for_refresher = resolver.clone();
|
||||
let _thread_handle = tokio::task::spawn(async move {
|
||||
resolver_for_refresher.refresh_dns_entries_loop().await;
|
||||
});
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Errors.
|
||||
|
||||
/// Various errors.
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum Error {
|
||||
SocketError(String),
|
||||
ClientSocketError(String, ClientIdentifier),
|
||||
@@ -19,10 +19,13 @@ pub enum Error {
|
||||
ClientError(String),
|
||||
TlsError,
|
||||
StatementTimeout,
|
||||
DNSCachedError(String),
|
||||
ShuttingDown,
|
||||
ParseBytesError(String),
|
||||
AuthError(String),
|
||||
AuthPassthroughError(String),
|
||||
UnsupportedStatement,
|
||||
QueryRouterParserError(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
pub mod admin;
|
||||
pub mod auth_passthrough;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod dns_cache;
|
||||
pub mod errors;
|
||||
pub mod messages;
|
||||
pub mod mirrors;
|
||||
pub mod multi_logger;
|
||||
pub mod plugins;
|
||||
pub mod pool;
|
||||
pub mod prometheus;
|
||||
pub mod query_router;
|
||||
pub mod scram;
|
||||
pub mod server;
|
||||
pub mod sharding;
|
||||
|
||||
46
src/main.rs
46
src/main.rs
@@ -36,6 +36,7 @@ extern crate sqlparser;
|
||||
extern crate tokio;
|
||||
extern crate tokio_rustls;
|
||||
extern crate toml;
|
||||
extern crate trust_dns_resolver;
|
||||
|
||||
#[cfg(not(target_env = "msvc"))]
|
||||
use jemallocator::Jemalloc;
|
||||
@@ -60,36 +61,19 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
mod admin;
|
||||
mod auth_passthrough;
|
||||
mod client;
|
||||
mod config;
|
||||
mod constants;
|
||||
mod errors;
|
||||
mod messages;
|
||||
mod mirrors;
|
||||
mod multi_logger;
|
||||
mod pool;
|
||||
mod prometheus;
|
||||
mod query_router;
|
||||
mod scram;
|
||||
mod server;
|
||||
mod sharding;
|
||||
mod stats;
|
||||
mod tls;
|
||||
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::messages::configure_socket;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::prometheus::start_metric_server;
|
||||
use crate::stats::{Collector, Reporter, REPORTER};
|
||||
use pgcat::config::{get_config, reload_config, VERSION};
|
||||
use pgcat::dns_cache;
|
||||
use pgcat::messages::configure_socket;
|
||||
use pgcat::pool::{ClientServerMap, ConnectionPool};
|
||||
use pgcat::prometheus::start_metric_server;
|
||||
use pgcat::stats::{Collector, Reporter, REPORTER};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
multi_logger::MultiLogger::init().unwrap();
|
||||
pgcat::multi_logger::MultiLogger::init().unwrap();
|
||||
|
||||
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
|
||||
|
||||
if !query_router::QueryRouter::setup() {
|
||||
if !pgcat::query_router::QueryRouter::setup() {
|
||||
error!("Could not setup query router");
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
@@ -107,7 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
match config::parse(&config_file).await {
|
||||
match pgcat::config::parse(&config_file).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
@@ -166,6 +150,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Statistics reporting.
|
||||
REPORTER.store(Arc::new(Reporter::default()));
|
||||
|
||||
// Starts (if enabled) dns cache before pools initialization
|
||||
match dns_cache::CachedResolver::from_config().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("DNS cache initialization error: {:?}", err),
|
||||
};
|
||||
|
||||
// Connection pool that allows to query all shards and replicas.
|
||||
match ConnectionPool::from_config(client_server_map.clone()).await {
|
||||
Ok(_) => (),
|
||||
@@ -295,7 +285,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
|
||||
match client::client_entrypoint(
|
||||
match pgcat::client::client_entrypoint(
|
||||
socket,
|
||||
client_server_map,
|
||||
shutdown_rx,
|
||||
@@ -326,7 +316,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
Err(err) => {
|
||||
match err {
|
||||
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
_ => warn!("Client disconnected with error {:?}", err),
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ pub enum DataType {
|
||||
Text,
|
||||
Int4,
|
||||
Numeric,
|
||||
Bool,
|
||||
Oid,
|
||||
AnyArray,
|
||||
Any,
|
||||
}
|
||||
|
||||
impl From<&DataType> for i32 {
|
||||
@@ -28,6 +32,10 @@ impl From<&DataType> for i32 {
|
||||
DataType::Text => 25,
|
||||
DataType::Int4 => 23,
|
||||
DataType::Numeric => 1700,
|
||||
DataType::Bool => 16,
|
||||
DataType::Oid => 26,
|
||||
DataType::AnyArray => 2277,
|
||||
DataType::Any => 2276,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,7 +124,10 @@ where
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
/// This tells the server which user we are and what database we want.
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
@@ -150,6 +161,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(12);
|
||||
|
||||
bytes.put_i32(8);
|
||||
bytes.put_i32(80877103);
|
||||
|
||||
match stream.write_all(&bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing SSLRequest to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
@@ -425,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
DataType::Text => -1,
|
||||
DataType::Int4 => 4,
|
||||
DataType::Numeric => -1,
|
||||
DataType::Bool => 1,
|
||||
DataType::Oid => 4,
|
||||
DataType::AnyArray => -1,
|
||||
DataType::Any => -1,
|
||||
};
|
||||
|
||||
row_desc.put_i16(type_size);
|
||||
@@ -463,6 +493,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
|
||||
res
|
||||
}
|
||||
|
||||
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut data_row = BytesMut::new();
|
||||
|
||||
data_row.put_i16(row.len() as i16);
|
||||
|
||||
for column in row {
|
||||
if let Some(column) = column {
|
||||
let column = column.as_bytes();
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(column);
|
||||
} else {
|
||||
data_row.put_i32(-1 as i32);
|
||||
}
|
||||
}
|
||||
|
||||
res.put_u8(b'D');
|
||||
res.put_i32(data_row.len() as i32 + 4);
|
||||
res.put(data_row);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Create a CommandComplete message.
|
||||
pub fn command_complete(command: &str) -> BytesMut {
|
||||
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
|
||||
@@ -505,6 +558,29 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a complete message from the socket.
|
||||
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
||||
where
|
||||
|
||||
288
src/plugins/intercept.rs
Normal file
288
src/plugins/intercept.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
//! The intercept plugin.
|
||||
//!
|
||||
//! It intercepts queries and returns fake results.
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlparser::ast::Statement;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use log::{debug, info};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
config::Intercept as InterceptConfig,
|
||||
errors::Error,
|
||||
messages::{command_complete, data_row_nullable, row_description, DataType},
|
||||
plugins::{Plugin, PluginOutput},
|
||||
pool::{PoolIdentifier, PoolMap},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
|
||||
|
||||
/// Check if the interceptor plugin has been enabled.
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
|
||||
let mut config = HashMap::new();
|
||||
for (identifier, _) in pools.iter() {
|
||||
let mut intercept_config = intercept_config.clone();
|
||||
intercept_config.substitute(&identifier.db, &identifier.user);
|
||||
config.insert(identifier.clone(), intercept_config);
|
||||
}
|
||||
|
||||
CONFIG.store(Arc::new(config));
|
||||
|
||||
info!("Intercepting {} queries", intercept_config.queries.len());
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(HashMap::new()));
|
||||
}
|
||||
|
||||
// TODO: use these structs for deserialization
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Rule {
|
||||
query: String,
|
||||
schema: Vec<Column>,
|
||||
result: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Column {
|
||||
name: String,
|
||||
data_type: String,
|
||||
}
|
||||
|
||||
/// The intercept plugin.
|
||||
pub struct Intercept;
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for Intercept {
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
if ast.is_empty() {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let mut result = BytesMut::new();
|
||||
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
|
||||
&query_router.pool_settings().db,
|
||||
&query_router.pool_settings().user.username,
|
||||
)) {
|
||||
Some(query_map) => query_map.clone(),
|
||||
None => return Ok(PluginOutput::Allow),
|
||||
};
|
||||
|
||||
for q in ast {
|
||||
// Normalization
|
||||
let q = q.to_string().to_ascii_lowercase();
|
||||
|
||||
for (_, target) in query_map.queries.iter() {
|
||||
if target.query.as_str() == q {
|
||||
debug!("Intercepting query: {}", q);
|
||||
|
||||
let rd = target
|
||||
.schema
|
||||
.iter()
|
||||
.map(|row| {
|
||||
let name = &row[0];
|
||||
let data_type = &row[1];
|
||||
(
|
||||
name.as_str(),
|
||||
match data_type.as_str() {
|
||||
"text" => DataType::Text,
|
||||
"anyarray" => DataType::AnyArray,
|
||||
"oid" => DataType::Oid,
|
||||
"bool" => DataType::Bool,
|
||||
"int4" => DataType::Int4,
|
||||
_ => DataType::Any,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Vec<(&str, DataType)>>();
|
||||
|
||||
result.put(row_description(&rd));
|
||||
|
||||
target.result.iter().for_each(|row| {
|
||||
let row = row
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let s = s.as_str().to_string();
|
||||
|
||||
if s == "" {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Option<String>>>();
|
||||
result.put(data_row_nullable(&row));
|
||||
});
|
||||
|
||||
result.put(command_complete("SELECT"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !result.is_empty() {
|
||||
result.put_u8(b'Z');
|
||||
result.put_i32(5);
|
||||
result.put_u8(b'I');
|
||||
|
||||
return Ok(PluginOutput::Intercept(result));
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make IntelliJ SQL plugin believe it's talking to an actual database
|
||||
/// instead of PgCat.
|
||||
#[allow(dead_code)]
|
||||
fn fool_datagrip(database: &str, user: &str) -> Value {
|
||||
json!([
|
||||
{
|
||||
"query": "select current_database() as a, current_schemas(false) as b",
|
||||
"schema": [
|
||||
{
|
||||
"name": "a",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"data_type": "anyarray",
|
||||
},
|
||||
],
|
||||
|
||||
"result": [
|
||||
[database, "{public}"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select current_database(), current_schema(), current_user",
|
||||
"schema": [
|
||||
{
|
||||
"name": "current_database",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_schema",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_user",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
|
||||
"result": [
|
||||
["sharded_db", "public", "sharding_user"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
|
||||
"schema": [
|
||||
{
|
||||
"name": "id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_template",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "allow_connections",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "owner",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
"result": [
|
||||
["16387", database, "", "f", "t", user],
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
|
||||
"schema": [
|
||||
{
|
||||
"name": "role_id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "role_name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_super",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_inherit",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createrole",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createdb",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_login",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_replication",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "conn_limit",
|
||||
"data_type": "int4",
|
||||
},
|
||||
{
|
||||
"name": "valid_until",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "bypass_rls",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
],
|
||||
"result": [
|
||||
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
]
|
||||
}
|
||||
])
|
||||
}
|
||||
43
src/plugins/mod.rs
Normal file
43
src/plugins/mod.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
//! The plugin ecosystem.
|
||||
//!
|
||||
//! Currently plugins only grant access or deny access to the database for a particual query.
|
||||
//! Example use cases:
|
||||
//! - block known bad queries
|
||||
//! - block access to system catalogs
|
||||
//! - block dangerous modifications like `DROP TABLE`
|
||||
//! - etc
|
||||
//!
|
||||
|
||||
pub mod intercept;
|
||||
pub mod query_logger;
|
||||
pub mod table_access;
|
||||
|
||||
use crate::{errors::Error, query_router::QueryRouter};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use sqlparser::ast::Statement;
|
||||
|
||||
pub use intercept::Intercept;
|
||||
pub use query_logger::QueryLogger;
|
||||
pub use table_access::TableAccess;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum PluginOutput {
|
||||
Allow,
|
||||
Deny(String),
|
||||
Overwrite(Vec<Statement>),
|
||||
Intercept(BytesMut),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Plugin {
|
||||
// Run before the query is sent to the server.
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error>;
|
||||
|
||||
// TODO: run after the result is returned
|
||||
// async fn callback(&mut self, query_router: &QueryRouter);
|
||||
}
|
||||
49
src/plugins/query_logger.rs
Normal file
49
src/plugins/query_logger.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Log all queries to stdout (or somewhere else, why not).
|
||||
|
||||
use crate::{
|
||||
errors::Error,
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use sqlparser::ast::Statement;
|
||||
use std::sync::Arc;
|
||||
|
||||
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
|
||||
|
||||
pub struct QueryLogger;
|
||||
|
||||
pub fn setup() {
|
||||
ENABLED.store(Arc::new(true));
|
||||
|
||||
info!("Logging queries to stdout");
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
ENABLED.store(Arc::new(false));
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
**ENABLED.load()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for QueryLogger {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
let query = ast
|
||||
.iter()
|
||||
.map(|q| q.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join("; ");
|
||||
info!("{}", query);
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
73
src/plugins/table_access.rs
Normal file
73
src/plugins/table_access.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
//! This query router plugin will check if the user can access a particular
|
||||
//! table as part of their query. If they can't, the query will not be routed.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlparser::ast::{visit_relations, Statement};
|
||||
|
||||
use crate::{
|
||||
config::TableAccess as TableAccessConfig,
|
||||
errors::Error,
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use core::ops::ControlFlow;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Arc;
|
||||
|
||||
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
|
||||
|
||||
pub fn setup(config: &TableAccessConfig) {
|
||||
CONFIG.store(Arc::new(config.tables.clone()));
|
||||
|
||||
info!("Blocking access to {} tables", config.tables.len());
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(vec![]));
|
||||
}
|
||||
|
||||
pub struct TableAccess;
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for TableAccess {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
let mut found = None;
|
||||
let forbidden_tables = CONFIG.load();
|
||||
|
||||
visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
|
||||
if forbidden_tables.contains(&table_name.to_string()) {
|
||||
found = Some(table_name.to_string());
|
||||
ControlFlow::<()>::Break(())
|
||||
} else {
|
||||
ControlFlow::<()>::Continue(())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found) = found {
|
||||
debug!("Blocking access to table \"{}\"", found);
|
||||
|
||||
Ok(PluginOutput::Deny(format!(
|
||||
"permission for table \"{}\" denied",
|
||||
found
|
||||
)))
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
}
|
||||
101
src/pool.rs
101
src/pool.rs
@@ -61,6 +61,8 @@ pub struct PoolIdentifier {
|
||||
pub user: String,
|
||||
}
|
||||
|
||||
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
|
||||
|
||||
impl PoolIdentifier {
|
||||
/// Create a new user/pool identifier.
|
||||
pub fn new(db: &str, user: &str) -> PoolIdentifier {
|
||||
@@ -91,6 +93,7 @@ pub struct PoolSettings {
|
||||
|
||||
// Connecting user.
|
||||
pub user: User,
|
||||
pub db: String,
|
||||
|
||||
// Default server role to connect to.
|
||||
pub default_role: Option<Role>,
|
||||
@@ -138,6 +141,7 @@ impl Default for PoolSettings {
|
||||
load_balancing_mode: LoadBalancingMode::Random,
|
||||
shards: 1,
|
||||
user: User::default(),
|
||||
db: String::default(),
|
||||
default_role: None,
|
||||
query_parser_enabled: false,
|
||||
primary_reads_enabled: true,
|
||||
@@ -311,21 +315,34 @@ impl ConnectionPool {
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
|
||||
{
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!(
|
||||
"Hash is not the same across shards \
|
||||
of the same pool, client auth will \
|
||||
be done using last obtained hash. \
|
||||
Server: {}:{}, Database: {}",
|
||||
server.host, server.port, shard.database,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Could not obtain password hashes \
|
||||
using auth_query config, ignoring. \
|
||||
Error: {:?}",
|
||||
err,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
@@ -347,14 +364,31 @@ impl ConnectionPool {
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let server_lifetime = match user.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => config.general.server_lifetime,
|
||||
},
|
||||
};
|
||||
|
||||
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
|
||||
.iter()
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
debug!("Pool reaper rate: {}ms", reaper_rate);
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.min_idle(user.min_pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
@@ -390,6 +424,7 @@ impl ConnectionPool {
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
db: pool_name.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
@@ -434,6 +469,32 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref plugins) = config.plugins {
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
if intercept.enabled {
|
||||
crate::plugins::intercept::setup(intercept, &new_pools);
|
||||
} else {
|
||||
crate::plugins::intercept::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref table_access) = plugins.table_access {
|
||||
if table_access.enabled {
|
||||
crate::plugins::table_access::setup(table_access);
|
||||
} else {
|
||||
crate::plugins::table_access::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref query_logger) = plugins.query_logger {
|
||||
if query_logger.enabled {
|
||||
crate::plugins::query_logger::setup();
|
||||
} else {
|
||||
crate::plugins::query_logger::disable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
POOLS.store(Arc::new(new_pools.clone()));
|
||||
Ok(())
|
||||
}
|
||||
@@ -777,7 +838,6 @@ impl ConnectionPool {
|
||||
self.databases.len()
|
||||
}
|
||||
|
||||
/// Retrieve all bans for all servers.
|
||||
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
|
||||
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
|
||||
let guard = self.banlist.read();
|
||||
@@ -789,7 +849,7 @@ impl ConnectionPool {
|
||||
return bans;
|
||||
}
|
||||
|
||||
/// Get the address from the host url.
|
||||
/// Get the address from the host url
|
||||
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
|
||||
let mut addresses = Vec::new();
|
||||
for shard in 0..self.shards() {
|
||||
@@ -828,13 +888,10 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
/// Get server settings retrieved at connection setup.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
}
|
||||
|
||||
/// Calculate how many used connections in the pool
|
||||
/// for the given server.
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
let state = self.pool_state(address.shard, address.address_index);
|
||||
let idle = state.idle_connections;
|
||||
|
||||
@@ -6,13 +6,19 @@ use once_cell::sync::OnceCell;
|
||||
use regex::{Regex, RegexSet};
|
||||
use sqlparser::ast::Statement::{Query, StartTransaction};
|
||||
use sqlparser::ast::{
|
||||
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
|
||||
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
|
||||
Value,
|
||||
};
|
||||
use sqlparser::dialect::PostgreSqlDialect;
|
||||
use sqlparser::parser::Parser;
|
||||
|
||||
use crate::config::Role;
|
||||
use crate::errors::Error;
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::plugins::{
|
||||
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
|
||||
TableAccess,
|
||||
};
|
||||
use crate::pool::PoolSettings;
|
||||
use crate::sharding::Sharder;
|
||||
|
||||
@@ -129,6 +135,10 @@ impl QueryRouter {
|
||||
self.pool_settings = pool_settings;
|
||||
}
|
||||
|
||||
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
|
||||
&self.pool_settings
|
||||
}
|
||||
|
||||
/// Try to parse a command and execute it.
|
||||
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
|
||||
let mut message_cursor = Cursor::new(message_buffer);
|
||||
@@ -324,10 +334,7 @@ impl QueryRouter {
|
||||
Some((command, value))
|
||||
}
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, message: &BytesMut) -> bool {
|
||||
debug!("Inferring role");
|
||||
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
@@ -353,28 +360,29 @@ impl QueryRouter {
|
||||
query
|
||||
}
|
||||
|
||||
_ => return false,
|
||||
_ => return Err(Error::UnsupportedStatement),
|
||||
};
|
||||
|
||||
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
|
||||
Ok(ast) => ast,
|
||||
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
|
||||
Ok(ast) => Ok(ast),
|
||||
Err(err) => {
|
||||
// SELECT ... FOR UPDATE won't get parsed correctly.
|
||||
debug!("{}: {}", err, query);
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
Err(Error::QueryRouterParserError(err.to_string()))
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
debug!("AST: {:?}", ast);
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
||||
debug!("Inferring role");
|
||||
|
||||
if ast.is_empty() {
|
||||
// That's weird, no idea, let's go to primary
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
return Err(Error::QueryRouterParserError("empty query".into()));
|
||||
}
|
||||
|
||||
for q in &ast {
|
||||
for q in ast {
|
||||
match q {
|
||||
// All transactions go to the primary, probably a write.
|
||||
StartTransaction { .. } => {
|
||||
@@ -418,7 +426,7 @@ impl QueryRouter {
|
||||
};
|
||||
}
|
||||
|
||||
true
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse the shard number from the Bind message
|
||||
@@ -783,6 +791,34 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add your plugins here and execute them.
|
||||
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
|
||||
if query_logger::enabled() {
|
||||
let mut query_logger = QueryLogger {};
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
}
|
||||
|
||||
if intercept::enabled() {
|
||||
let mut intercept = Intercept {};
|
||||
let result = intercept.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
return Ok(PluginOutput::Intercept(output));
|
||||
}
|
||||
}
|
||||
|
||||
if table_access::enabled() {
|
||||
let mut table_access = TableAccess {};
|
||||
let result = table_access.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
return Ok(PluginOutput::Deny(error));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
|
||||
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
@@ -810,11 +846,22 @@ impl QueryRouter {
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
None => self.pool_settings.query_parser_enabled,
|
||||
Some(value) => value,
|
||||
};
|
||||
None => {
|
||||
debug!(
|
||||
"Using pool settings, query_parser_enabled: {}",
|
||||
self.pool_settings.query_parser_enabled
|
||||
);
|
||||
self.pool_settings.query_parser_enabled
|
||||
}
|
||||
|
||||
debug!("Query parser enabled: {}", enabled);
|
||||
Some(value) => {
|
||||
debug!(
|
||||
"Using query parser override, query_parser_enabled: {}",
|
||||
value
|
||||
);
|
||||
value
|
||||
}
|
||||
};
|
||||
|
||||
enabled
|
||||
}
|
||||
@@ -862,7 +909,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
}
|
||||
@@ -881,7 +928,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
}
|
||||
}
|
||||
@@ -893,7 +940,7 @@ mod test {
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
}
|
||||
|
||||
@@ -913,7 +960,7 @@ mod test {
|
||||
res.put(prepared_stmt);
|
||||
res.put_i16(0);
|
||||
|
||||
assert!(qr.infer(&res));
|
||||
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
|
||||
@@ -1077,11 +1124,11 @@ mod test {
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
let query = simple_query("INSERT INTO test_table VALUES (1)");
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
|
||||
let query = simple_query("SELECT * FROM test_table");
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
@@ -1113,6 +1160,7 @@ mod test {
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
assert_eq!(qr.active_role, None);
|
||||
@@ -1142,15 +1190,24 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Replica);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
}
|
||||
|
||||
@@ -1177,6 +1234,7 @@ mod test {
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings.clone());
|
||||
@@ -1208,47 +1266,84 @@ mod test {
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM data
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM data
|
||||
INNER JOIN t2 ON data.id = 5
|
||||
AND t2.data_id = data.id
|
||||
WHERE data.id = 5"
|
||||
)));
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
|
||||
// in the query.
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Super unique sharding key
|
||||
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
}
|
||||
|
||||
@@ -1272,11 +1367,40 @@ mod test {
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr.infer(&simple_query(stmt)));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.placeholders.len(), 1);
|
||||
|
||||
assert!(qr.infer_shard_from_bind(&bind));
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert!(qr.placeholders.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_access_plugin() {
|
||||
use crate::config::TableAccess;
|
||||
let ta = TableAccess {
|
||||
enabled: true,
|
||||
tables: vec![String::from("pg_database")],
|
||||
};
|
||||
|
||||
crate::plugins::table_access::setup(&ta);
|
||||
|
||||
QueryRouter::setup();
|
||||
|
||||
let qr = QueryRouter::new();
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = QueryRouter::parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
Ok(PluginOutput::Deny(
|
||||
"permission for table \"pg_database\" denied".to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
269
src/server.rs
269
src/server.rs
@@ -7,22 +7,101 @@ use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::scram::ScramSha256;
|
||||
use crate::stats::ServerStats;
|
||||
use std::io::Write;
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
#[pin_project(project = SteamInnerProj)]
|
||||
pub enum StreamInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: TcpStream,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: TlsStream<TcpStream>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsyncWrite for StreamInner {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for StreamInner {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamInner {
|
||||
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
StreamInner::Tls { stream } => {
|
||||
let r = stream.get_mut();
|
||||
let mut w = r.1.writer();
|
||||
w.write(buf)
|
||||
}
|
||||
StreamInner::Plain { stream } => stream.try_write(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
@@ -30,15 +109,11 @@ pub struct Server {
|
||||
/// port, e.g. 5432, and role, e.g. primary or replica.
|
||||
address: Address,
|
||||
|
||||
/// Buffered read socket.
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
|
||||
/// Unbuffered write socket (our client code buffers).
|
||||
write: OwnedWriteHalf,
|
||||
/// Server TCP connection.
|
||||
stream: BufStream<StreamInner>,
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
is_async: bool,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_info: BytesMut,
|
||||
@@ -75,6 +150,9 @@ pub struct Server {
|
||||
last_activity: SystemTime,
|
||||
|
||||
mirror_manager: Option<MirroringManager>,
|
||||
|
||||
// Associated addresses used
|
||||
addr_set: Option<AddrSet>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@@ -88,6 +166,24 @@ impl Server {
|
||||
stats: Arc<ServerStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
) -> Result<Server, Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let mut addr_set: Option<AddrSet> = None;
|
||||
|
||||
// If we are caching addresses and hostname is not an IP
|
||||
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
|
||||
debug!("Resolving {}", &address.host);
|
||||
addr_set = match cached_resolver.lookup_ip(&address.host).await {
|
||||
Ok(ok) => {
|
||||
debug!("Obtained: {:?}", ok);
|
||||
Some(ok)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream =
|
||||
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -99,8 +195,88 @@ impl Server {
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// TCP timeouts.
|
||||
configure_socket(&stream);
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let mut stream = if config.general.server_tls {
|
||||
// Request a TLS connection
|
||||
ssl_request(&mut stream).await?;
|
||||
|
||||
let response = match stream.read_u8().await {
|
||||
Ok(response) => response as char,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Server socket error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match response {
|
||||
// Server supports TLS
|
||||
'S' => {
|
||||
debug!("Connecting to server using TLS");
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
let mut tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
// Equivalent to sslmode=prefer which is fine most places.
|
||||
// If you want verify-full, change `verify_server_certificate` to true.
|
||||
if !config.general.verify_server_certificate {
|
||||
let mut dangerous = tls_config.dangerous();
|
||||
dangerous.set_certificate_verifier(Arc::new(
|
||||
crate::tls::NoCertificateVerification {},
|
||||
));
|
||||
}
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(tls_config));
|
||||
let stream = match connector
|
||||
.connect(address.host.as_str().try_into().unwrap(), stream)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
|
||||
}
|
||||
};
|
||||
|
||||
StreamInner::Tls { stream }
|
||||
}
|
||||
|
||||
// Server does not support TLS
|
||||
'N' => StreamInner::Plain { stream },
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StreamInner::Plain { stream }
|
||||
};
|
||||
|
||||
// let (read, write) = split(stream);
|
||||
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
// StartupMessage
|
||||
@@ -246,7 +422,7 @@ impl Server {
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
if sasl_type.contains(SCRAM_SHA_256) {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Generate client message.
|
||||
@@ -269,7 +445,7 @@ impl Server {
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
} else {
|
||||
error!("Unsupported SCRAM version: {}", sasl_type);
|
||||
return Err(Error::ServerError);
|
||||
@@ -300,7 +476,7 @@ impl Server {
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
}
|
||||
|
||||
SASL_FINAL => {
|
||||
@@ -444,14 +620,10 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
is_async: false,
|
||||
server_info,
|
||||
process_id,
|
||||
secret_key,
|
||||
@@ -460,6 +632,7 @@ impl Server {
|
||||
bad: false,
|
||||
needs_cleanup: false,
|
||||
client_server_map,
|
||||
addr_set,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats,
|
||||
application_name: String::new(),
|
||||
@@ -517,7 +690,7 @@ impl Server {
|
||||
bytes.put_i32(process_id);
|
||||
bytes.put_i32(secret_key);
|
||||
|
||||
write_all(&mut stream, bytes).await
|
||||
write_all_flush(&mut stream, &bytes).await
|
||||
}
|
||||
|
||||
/// Send messages to the server from the client.
|
||||
@@ -525,7 +698,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_half(&mut self.write, messages).await {
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -539,22 +712,12 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
/// Switch to async mode, flushing messages as soon
|
||||
/// as we receive them without buffering or waiting for "ReadyForQuery".
|
||||
pub fn switch_async(&mut self, on: bool) {
|
||||
if on {
|
||||
self.is_async = true;
|
||||
} else {
|
||||
self.is_async = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.read).await {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
@@ -644,10 +807,7 @@ impl Server {
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
// If we're async, flush to client now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
self.data_available = true;
|
||||
|
||||
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
|
||||
if self.buffer.len() >= 8196 {
|
||||
@@ -660,10 +820,7 @@ impl Server {
|
||||
|
||||
// CopyOutResponse: copy is starting from the server to the client.
|
||||
'H' => {
|
||||
// If we're in async mode, flush now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
self.data_available = true;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -683,10 +840,6 @@ impl Server {
|
||||
// Keep buffering until ReadyForQuery shows up.
|
||||
_ => (),
|
||||
};
|
||||
|
||||
if self.is_async {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let bytes = self.buffer.clone();
|
||||
@@ -720,7 +873,23 @@ impl Server {
|
||||
/// Server & client are out of sync, we must discard this connection.
|
||||
/// This happens with clients that misbehave.
|
||||
pub fn is_bad(&self) -> bool {
|
||||
self.bad
|
||||
if self.bad {
|
||||
return self.bad;
|
||||
};
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
if cached_resolver.enabled() {
|
||||
if let Some(addr_set) = &self.addr_set {
|
||||
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
|
||||
warn!(
|
||||
"DNS changed for {}, it was {:?}. Dropping server connection.",
|
||||
self.address.host.as_str(),
|
||||
addr_set
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Get server startup information to forward it to the client.
|
||||
@@ -957,13 +1126,13 @@ impl Drop for Server {
|
||||
// Update statistics
|
||||
self.stats.disconnect();
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(4);
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
bytes.put_u8(b'X');
|
||||
bytes.put_i32(4);
|
||||
|
||||
match self.write.try_write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(_) => debug!("Dirty shutdown"),
|
||||
match self.stream.get_mut().try_write(&bytes) {
|
||||
Ok(5) => (),
|
||||
_ => debug!("Dirty shutdown"),
|
||||
};
|
||||
|
||||
// Should not matter.
|
||||
|
||||
23
src/tls.rs
23
src/tls.rs
@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
|
||||
use std::iter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::rustls::{
|
||||
self,
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate, PrivateKey, ServerName,
|
||||
};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
@@ -64,3 +69,19 @@ impl Tls {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoCertificateVerification;
|
||||
|
||||
impl ServerCertVerifier for NoCertificateVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
import psycopg2
|
||||
import asyncio
|
||||
import asyncpg
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
|
||||
def regular_main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = psycopg2.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Open a cursor to perform database operations
|
||||
cur = conn.cursor()
|
||||
|
||||
# Execute a SQL query
|
||||
cur.execute("SELECT 1")
|
||||
|
||||
# Fetch the results
|
||||
rows = cur.fetchall()
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the cursor and the database connection
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
async def main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = await asyncpg.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Execute a SQL query
|
||||
for _ in range(25):
|
||||
rows = await conn.fetch("SELECT 1")
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the database connection
|
||||
await conn.close()
|
||||
|
||||
|
||||
regular_main()
|
||||
asyncio.run(main())
|
||||
@@ -1,11 +1,2 @@
|
||||
asyncio==3.4.3
|
||||
asyncpg==0.27.0
|
||||
black==23.3.0
|
||||
click==8.1.3
|
||||
mypy-extensions==1.0.0
|
||||
packaging==23.1
|
||||
pathspec==0.11.1
|
||||
platformdirs==3.2.0
|
||||
psutil==5.9.1
|
||||
psycopg2==2.9.3
|
||||
tomli==2.0.1
|
||||
psutil==5.9.1
|
||||
@@ -71,15 +71,17 @@ describe "Admin" do
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
@@ -87,7 +89,7 @@ describe "Admin" do
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ describe "Query Mirroing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
it "can mirror a query" do
|
||||
xit "can mirror a query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
runs = 15
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
14
tests/ruby/plugins_spec.rb
Normal file
14
tests/ruby/plugins_spec.rb
Normal file
@@ -0,0 +1,14 @@
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "Plugins" do
|
||||
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
|
||||
|
||||
context "intercept" do
|
||||
it "will intercept an intellij query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
res = conn.exec("select current_database() as a, current_schemas(false) as b")
|
||||
expect(res.values).to eq([["sharded_db", "{public}"]])
|
||||
end
|
||||
end
|
||||
end
|
||||
Reference in New Issue
Block a user