mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
43 Commits
mostafa_se
...
levkk-lowe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35c20dc9b4 | ||
|
|
72ce2be5c7 | ||
|
|
811885f464 | ||
|
|
d5e329fec5 | ||
|
|
09e54e1175 | ||
|
|
23819c8549 | ||
|
|
7dfbd993f2 | ||
|
|
3601130ba1 | ||
|
|
0d504032b2 | ||
|
|
4a87b4807d | ||
|
|
cb5ff40a59 | ||
|
|
62b2d994c1 | ||
|
|
66805d7e77 | ||
|
|
4ccc1e7fa3 | ||
|
|
3dae3d0777 | ||
|
|
a18eb42df5 | ||
|
|
6aacf1fa19 | ||
|
|
8e99e65215 | ||
|
|
5dfbc102a9 | ||
|
|
bae12fca99 | ||
|
|
421c5d4b64 | ||
|
|
d568739db9 | ||
|
|
692353c839 | ||
|
|
a62f6b0eea | ||
|
|
89e15f09b5 | ||
|
|
7ddd23b514 | ||
|
|
faa9c1f64a | ||
|
|
9094988491 | ||
|
|
6f768a84ce | ||
|
|
0757d7f3a0 | ||
|
|
568f04feee | ||
|
|
58ce76d9b9 | ||
|
|
9a2076a9eb | ||
|
|
e7e7118725 | ||
|
|
99f790cacf | ||
|
|
434b0bb69e | ||
|
|
714e043ef0 | ||
|
|
863104aadd | ||
|
|
7dd96141e3 | ||
|
|
0d5feac4b2 | ||
|
|
90aba9c011 | ||
|
|
0f34b49503 | ||
|
|
ca4431b67e |
@@ -46,6 +46,14 @@ jobs:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
|
||||
|
||||
- image: postgres:14
|
||||
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements"]
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
|
||||
|
||||
# Add steps to the job
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
|
||||
steps:
|
||||
|
||||
@@ -39,7 +39,7 @@ log_client_connections = false
|
||||
log_client_disconnections = false
|
||||
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = true
|
||||
autoreload = 15000
|
||||
|
||||
# TLS
|
||||
tls_certificate = ".circleci/server.cert"
|
||||
|
||||
@@ -19,6 +19,7 @@ PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/q
|
||||
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 7432 -U postgres -f tests/sharding/query_routing_setup.sql
|
||||
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 8432 -U postgres -f tests/sharding/query_routing_setup.sql
|
||||
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 9432 -U postgres -f tests/sharding/query_routing_setup.sql
|
||||
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 10432 -U postgres -f tests/sharding/query_routing_setup.sql
|
||||
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
|
||||
|
||||
14
.editorconfig
Normal file
14
.editorconfig
Normal file
@@ -0,0 +1,14 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.rs]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
max_line_length = 120
|
||||
|
||||
[*.toml]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
@@ -15,6 +15,6 @@ jobs:
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Build CI Docker image
|
||||
run: |
|
||||
docker build . -f Dockerfile.ci --tag ghcr.io/levkk/pgcat-ci:latest
|
||||
docker run ghcr.io/levkk/pgcat-ci:latest
|
||||
docker push ghcr.io/levkk/pgcat-ci:latest
|
||||
docker build . -f Dockerfile.ci --tag ghcr.io/postgresml/pgcat-ci:latest
|
||||
docker run ghcr.io/postgresml/pgcat-ci:latest
|
||||
docker push ghcr.io/postgresml/pgcat-ci:latest
|
||||
|
||||
2
.rustfmt.toml
Normal file
2
.rustfmt.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
edition = "2021"
|
||||
hard_tabs = false
|
||||
101
CONFIG.md
101
CONFIG.md
@@ -49,6 +49,22 @@ default: 30000 # milliseconds
|
||||
|
||||
How long an idle connection with a server is left open (ms).
|
||||
|
||||
### server_lifetime
|
||||
```
|
||||
path: general.server_lifetime
|
||||
default: 86400000 # 24 hours
|
||||
```
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
default: 0 # milliseconds
|
||||
```
|
||||
|
||||
How long a client is allowed to be idle while in a transaction (ms).
|
||||
|
||||
### healthcheck_timeout
|
||||
```
|
||||
path: general.healthcheck_timeout
|
||||
@@ -100,7 +116,7 @@ If we should log client disconnections
|
||||
### autoreload
|
||||
```
|
||||
path: general.autoreload
|
||||
default: false
|
||||
default: 15000
|
||||
```
|
||||
|
||||
When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
@@ -144,7 +160,7 @@ default: <UNSET>
|
||||
example: "server.cert"
|
||||
```
|
||||
|
||||
Path to TLS Certficate file to use for TLS connections
|
||||
Path to TLS Certificate file to use for TLS connections
|
||||
|
||||
### tls_private_key
|
||||
```
|
||||
@@ -172,6 +188,22 @@ default: "admin_pass"
|
||||
|
||||
Password to access the virtual administrative database
|
||||
|
||||
### dns_cache_enabled
|
||||
```
|
||||
path: general.dns_cache_enabled
|
||||
default: false
|
||||
```
|
||||
When enabled, ip resolutions for server connections specified using hostnames will be cached
|
||||
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
|
||||
old ip connections are closed (gracefully) and new connections will start using new ip.
|
||||
|
||||
### dns_max_ttl
|
||||
```
|
||||
path: general.dns_max_ttl
|
||||
default: 30
|
||||
```
|
||||
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
|
||||
|
||||
## `pools.<pool_name>` Section
|
||||
|
||||
### pool_mode
|
||||
@@ -205,7 +237,7 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
|
||||
`replica` round-robin between replicas only without touching the primary,
|
||||
`primary` all queries go to the primary unless otherwise specified.
|
||||
|
||||
### query_parser_enabled (experimental)
|
||||
### query_parser_enabled
|
||||
```
|
||||
path: pools.<pool_name>.query_parser_enabled
|
||||
default: true
|
||||
@@ -226,7 +258,7 @@ If the query parser is enabled and this setting is enabled, the primary will be
|
||||
load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
queries. The primary can always be explicitly selected with our custom protocol.
|
||||
|
||||
### sharding_key_regex (experimental)
|
||||
### sharding_key_regex
|
||||
```
|
||||
path: pools.<pool_name>.sharding_key_regex
|
||||
default: <UNSET>
|
||||
@@ -248,7 +280,40 @@ Current options:
|
||||
`pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
|
||||
`sha1`: A hashing function based on SHA1
|
||||
|
||||
### automatic_sharding_key (experimental)
|
||||
### auth_query
|
||||
```
|
||||
path: pools.<pool_name>.auth_query
|
||||
default: <UNSET>
|
||||
example: "SELECT $1"
|
||||
```
|
||||
|
||||
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
established using the database configured in the pool. This parameter is inherited by every pool
|
||||
and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_user
|
||||
```
|
||||
path: pools.<pool_name>.auth_query_user
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_password
|
||||
```
|
||||
path: pools.<pool_name>.auth_query_password
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### automatic_sharding_key
|
||||
```
|
||||
path: pools.<pool_name>.automatic_sharding_key
|
||||
default: <UNSET>
|
||||
@@ -281,7 +346,8 @@ path: pools.<pool_name>.users.<user_index>.username
|
||||
default: "sharding_user"
|
||||
```
|
||||
|
||||
Postgresql username
|
||||
PostgreSQL username used to authenticate the user and connect to the server
|
||||
if `server_username` is not set.
|
||||
|
||||
### password
|
||||
```
|
||||
@@ -289,7 +355,26 @@ path: pools.<pool_name>.users.<user_index>.password
|
||||
default: "sharding_user"
|
||||
```
|
||||
|
||||
Postgresql password
|
||||
PostgreSQL password used to authenticate the user and connect to the server
|
||||
if `server_password` is not set.
|
||||
|
||||
### server_username
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.server_username
|
||||
default: <UNSET>
|
||||
example: "another_user"
|
||||
```
|
||||
|
||||
PostgreSQL username used to connect to the server.
|
||||
|
||||
### server_password
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.server_password
|
||||
default: <UNSET>
|
||||
example: "another_password"
|
||||
```
|
||||
|
||||
PostgreSQL password used to connect to the server.
|
||||
|
||||
### pool_size
|
||||
```
|
||||
@@ -320,7 +405,7 @@ default: [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
|
||||
|
||||
Array of servers in the shard, each server entry is an array of `[host, port, role]`
|
||||
|
||||
### mirrors (experimental)
|
||||
### mirrors
|
||||
```
|
||||
path: pools.<pool_name>.shards.<shard_index>.mirrors
|
||||
default: <UNSET>
|
||||
|
||||
543
Cargo.lock
generated
543
Cargo.lock
generated
@@ -4,9 +4,9 @@ version = 3
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
|
||||
checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
@@ -27,14 +27,46 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.66"
|
||||
name = "async-stream"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b84f9ebcc6c1f5b8cb160f6990096a5c127f423fcb6e1ccc46c370cbdfb75dfc"
|
||||
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.68"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic_enum"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6227a8d6fdb862bcb100c4314d0d9579e5cd73fa6df31a2e6f6e1acd3c5f1207"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -83,6 +115,12 @@ version = "3.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.4.0"
|
||||
@@ -175,7 +213,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"scratch",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -192,9 +230,15 @@ checksum = "086c685979a698443656e5cf7856c95c642295a38599f12fb1ff76fb28d19892"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.6"
|
||||
@@ -206,6 +250,18 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "enum-as-inner"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.0"
|
||||
@@ -246,6 +302,12 @@ version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@@ -253,10 +315,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.26"
|
||||
name = "form_urlencoded"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84"
|
||||
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -269,9 +340,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5"
|
||||
checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
@@ -279,15 +350,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608"
|
||||
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e"
|
||||
checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
@@ -296,38 +367,38 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531"
|
||||
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70"
|
||||
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364"
|
||||
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366"
|
||||
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.26"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1"
|
||||
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -364,9 +435,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.15"
|
||||
version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
|
||||
checksum = "66b91535aa35fea1523ad1b86cb6b53c28e0ae566ba4a460f4457e936cad7c6f"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
@@ -387,6 +458,12 @@ version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.2.6"
|
||||
@@ -411,6 +488,17 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hostname"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"match_cfg",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.9"
|
||||
@@ -453,9 +541,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.24"
|
||||
version = "0.14.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c"
|
||||
checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
@@ -499,6 +587,27 @@ dependencies = [
|
||||
"cxx-build",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
|
||||
dependencies = [
|
||||
"matches",
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "1.9.2"
|
||||
@@ -519,6 +628,24 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipconfig"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
|
||||
dependencies = [
|
||||
"socket2",
|
||||
"widestring",
|
||||
"winapi",
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.4"
|
||||
@@ -566,6 +693,12 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.139"
|
||||
@@ -581,6 +714,12 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-hash-map"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.1.4"
|
||||
@@ -606,6 +745,27 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-cache"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
|
||||
dependencies = [
|
||||
"linked-hash-map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "match_cfg"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
|
||||
|
||||
[[package]]
|
||||
name = "matches"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.10.5"
|
||||
@@ -714,18 +874,26 @@ dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "0.6.0-alpha1"
|
||||
version = "1.0.2-alpha1"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"atomic_enum",
|
||||
"base64",
|
||||
"bb8",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"env_logger",
|
||||
"exitcode",
|
||||
"fallible-iterator",
|
||||
"futures",
|
||||
"hmac",
|
||||
"hyper",
|
||||
@@ -737,11 +905,15 @@ dependencies = [
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf",
|
||||
"pin-project",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"sha-1",
|
||||
"sha2",
|
||||
"socket2",
|
||||
@@ -749,7 +921,10 @@ dependencies = [
|
||||
"stringprep",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-test",
|
||||
"toml",
|
||||
"trust-dns-resolver",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -782,7 +957,7 @@ dependencies = [
|
||||
"phf_shared",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -794,6 +969,26 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
@@ -806,6 +1001,24 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78b7fa9f396f51dffd61546fd8573ee20592287996568e6175ceb0f8699ad75d"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"hmac",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"rand",
|
||||
"sha2",
|
||||
"stringprep",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
@@ -814,18 +1027,24 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.51"
|
||||
version = "1.0.53"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
|
||||
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.23"
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
|
||||
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
@@ -871,9 +1090,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.7.1"
|
||||
version = "1.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
|
||||
checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
@@ -882,9 +1101,19 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.28"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
|
||||
checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c"
|
||||
|
||||
[[package]]
|
||||
name = "resolv-conf"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
|
||||
dependencies = [
|
||||
"hostname",
|
||||
"quick-error",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
@@ -917,14 +1146,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.20.8"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f"
|
||||
checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-webpki",
|
||||
"sct",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -936,6 +1165,22 @@ dependencies = [
|
||||
"base64",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@@ -960,19 +1205,33 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.154"
|
||||
version = "1.0.160"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cdd151213925e7f1ab45a9bbfb129316bd00799784b174b7cc7bcd16961c49e"
|
||||
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.154"
|
||||
version = "1.0.160"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fc80d722935453bcafdc2c9a73cd6fac4dc1938f0346035d84bf99fa9e33217"
|
||||
checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1054,11 +1313,23 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser"
|
||||
version = "0.32.0"
|
||||
version = "0.33.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0366f270dbabb5cc2e4c88427dc4c08bba144f81e32fbd459a013f26a4d16aa0"
|
||||
checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a"
|
||||
dependencies = [
|
||||
"log",
|
||||
"sqlparser_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser_derive"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1094,6 +1365,17 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0da4a3c17e109f700685ec577c0f85efd9b19bcf15c913985f14dc1ac01775aa"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "1.2.0"
|
||||
@@ -1103,6 +1385,26 @@ dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.1.45"
|
||||
@@ -1157,18 +1459,41 @@ checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.23.4"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
|
||||
checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-test"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1187,9 +1512,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.7.2"
|
||||
version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7afcae9e3f0fe2c370fd4657108972cbb2fa9db1b9f84849cefd80741b01cb6"
|
||||
checksum = "b403acf6f2bb0859c93c7f0d967cb4a75a7ac552100f9322faf64dc047669b21"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_spanned",
|
||||
@@ -1208,9 +1533,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "toml_edit"
|
||||
version = "0.19.4"
|
||||
version = "0.19.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a1eb0622d28f4b9c90adc4ea4b2b46b47663fde9ac5fafcb14a1369d5508825"
|
||||
checksum = "08de71aa0d6e348f070457f85af8bd566e2bc452156a423ddf22861b3a953fae"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
@@ -1233,9 +1558,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.30"
|
||||
@@ -1245,6 +1582,51 @@ dependencies = [
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trust-dns-proto"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"cfg-if",
|
||||
"data-encoding",
|
||||
"enum-as-inner",
|
||||
"futures-channel",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"idna 0.2.3",
|
||||
"ipnet",
|
||||
"lazy_static",
|
||||
"rand",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"tinyvec",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "trust-dns-resolver"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"futures-util",
|
||||
"ipconfig",
|
||||
"lazy_static",
|
||||
"lru-cache",
|
||||
"parking_lot",
|
||||
"resolv-conf",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"trust-dns-proto",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.4"
|
||||
@@ -1290,6 +1672,17 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna 0.3.0",
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
@@ -1339,7 +1732,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@@ -1361,7 +1754,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 1.0.109",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@@ -1383,15 +1776,20 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.22.0"
|
||||
name = "webpki-roots"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
|
||||
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"rustls-webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "widestring"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@@ -1497,3 +1895,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
19
Cargo.toml
19
Cargo.toml
@@ -1,10 +1,9 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "0.6.0-alpha1"
|
||||
version = "1.0.2-alpha1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = "1"
|
||||
@@ -15,12 +14,12 @@ rand = "0.8"
|
||||
chrono = "0.4"
|
||||
sha-1 = "0.10"
|
||||
toml = "0.7"
|
||||
serde = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_derive = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
sqlparser = "0.32.0"
|
||||
sqlparser = {version = "0.33", features = ["visitor"] }
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
@@ -29,7 +28,7 @@ hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
base64 = "0.21"
|
||||
stringprep = "0.1"
|
||||
tokio-rustls = "0.23"
|
||||
tokio-rustls = "0.24"
|
||||
rustls-pemfile = "1"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
phf = { version = "0.11.1", features = ["macros"] }
|
||||
@@ -37,6 +36,16 @@ exitcode = "1.1.2"
|
||||
futures = "0.3"
|
||||
socket2 = { version = "0.4.7", features = ["all"] }
|
||||
nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "1"
|
||||
webpki-roots = "0.23"
|
||||
rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
trust-dns-resolver = "0.22.0"
|
||||
tokio-test = "0.4.2"
|
||||
serde_json = "1"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2022 Lev Kokotov <lev@levthe.dev>
|
||||
Copyright (c) 2023 PgCat Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
|
||||
431
README.md
431
README.md
@@ -1,33 +1,80 @@
|
||||
##### PgCat: PostgreSQL at petabyte scale
|
||||
## PgCat: Nextgen PostgreSQL Pooler
|
||||
|
||||
[](https://circleci.com/gh/levkk/pgcat/tree/main)
|
||||
[](https://circleci.com/gh/postgresml/pgcat/tree/main)
|
||||
<a href="https://discord.gg/DmyJP3qJ7U" target="_blank">
|
||||
<img src="https://img.shields.io/discord/1013868243036930099" alt="Join our Discord!" />
|
||||
</a>
|
||||
|
||||
PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover support.
|
||||
|
||||
**Beta**: looking for beta testers, see [#35](https://github.com/levkk/pgcat/issues/35).
|
||||
PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load balancing, failover and mirroring.
|
||||
|
||||
## Features
|
||||
| **Feature** | **Status** | **Comments** |
|
||||
|--------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
|
||||
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
|
||||
| Load balancing of read queries | :white_check_mark: | Using random between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
|
||||
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
|
||||
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
|
||||
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
|
||||
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
|
||||
| Client authentication | :white_check_mark: :wrench: | MD5 password authentication is supported, SCRAM is on the roadmap; one user is used to connect to Postgres with both SCRAM and MD5 supported. |
|
||||
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
|
||||
|
||||
| **Feature** | **Status** | **Comments** |
|
||||
|-------------|------------|--------------|
|
||||
| Transaction pooling | **Stable** | Identical to PgBouncer with notable improvements for handling bad clients and abandoned transactions. |
|
||||
| Session pooling | **Stable** | Identical to PgBouncer. |
|
||||
| Multi-threaded runtime | **Stable** | Using Tokio asynchronous runtime, the pooler takes advantage of multicore machines. |
|
||||
| Load balancing of read queries | **Stable** | Queries are automatically load balanced between replicas and the primary. |
|
||||
| Failover | **Stable** | Queries are automatically rerouted around broken replicas, validated by regular health checks. |
|
||||
| Admin database statistics | **Stable** | Pooler statistics and administration via the `pgbouncer` and `pgcat` databases. |
|
||||
| Prometheus statistics | **Stable** | Statistics are reported via a HTTP endpoint for Prometheus. |
|
||||
| SSL/TLS | **Stable** | Clients can connect to the pooler using TLS. Pooler can connect to Postgres servers using TLS. |
|
||||
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
|
||||
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
|
||||
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
|
||||
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
|
||||
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
|
||||
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
|
||||
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
|
||||
|
||||
|
||||
## Status
|
||||
|
||||
PgCat is stable and used in production to serve hundreds of thousands of queries per second.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
|
||||
<img src="./images/instacart.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
|
||||
<img src="./images/postgresml.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://onesignal.com">
|
||||
<img src="./images/one_signal.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
|
||||
Instacart
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
|
||||
PostgresML
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
OneSignal
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
|
||||
|
||||
## Deployment
|
||||
|
||||
See `Dockerfile` for example deployment using Docker. The pooler is configured to spawn 4 workers so 4 CPUs are recommended for optimal performance. That setting can be adjusted to spawn as many (or as little) workers as needed.
|
||||
|
||||
A Docker image is available from `docker pull ghcr.io/postgresml/pgcat:latest`. See our [Github packages repository](https://github.com/postgresml/pgcat/pkgs/container/pgcat).
|
||||
|
||||
For quick local example, use the Docker Compose environment provided:
|
||||
|
||||
```bash
|
||||
@@ -39,9 +86,13 @@ PGPASSWORD=postgres psql -h 127.0.0.1 -p 6432 -U postgres -c 'SELECT 1'
|
||||
|
||||
### Config
|
||||
|
||||
See [Configurations page](https://github.com/levkk/pgcat/blob/main/CONFIG.md)
|
||||
See **[Configuration](https://github.com/levkk/pgcat/blob/main/CONFIG.md)**.
|
||||
|
||||
## Local development
|
||||
## Contributing
|
||||
|
||||
The project is being actively developed and looking for additional contributors and production deployments.
|
||||
|
||||
### Local development
|
||||
|
||||
1. Install Rust (latest stable will work great).
|
||||
2. `cargo build --release` (to get better benchmarks).
|
||||
@@ -51,7 +102,7 @@ See [Configurations page](https://github.com/levkk/pgcat/blob/main/CONFIG.md)
|
||||
|
||||
### Tests
|
||||
|
||||
Quickest way to test your changes is to use pgbench:
|
||||
When making substantial modifications to the protocol implementation, make sure to test them with pgbench:
|
||||
|
||||
```
|
||||
pgbench -i -h 127.0.0.1 -p 6432 && \
|
||||
@@ -61,36 +112,26 @@ pgbench -t 1000 -p 6432 -h 127.0.0.1 --protocol extended
|
||||
|
||||
See [sharding README](./tests/sharding/README.md) for sharding logic testing.
|
||||
|
||||
Run `cargo test` to run Rust tests.
|
||||
Additionally, all features are tested with Ruby, Python, and Rust unit and integration tests.
|
||||
|
||||
Run `cargo test` to run Rust unit tests.
|
||||
|
||||
Run the following commands to run Ruby and Python integration tests:
|
||||
|
||||
Run the following commands to run Integration tests locally.
|
||||
```
|
||||
cd tests/docker/
|
||||
docker compose up --exit-code-from main # This will also produce coverage report under ./cov/
|
||||
```
|
||||
|
||||
| **Feature** | **Tested in CI** | **Tested manually** | **Comments** |
|
||||
|-----------------------|--------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| Transaction pooling | :white_check_mark: | :white_check_mark: | Used by default for all tests. |
|
||||
| Session pooling | :white_check_mark: | :white_check_mark: | Tested by running pgbench with `--protocol prepared` which only works in session mode. |
|
||||
| `COPY` | :white_check_mark: | :white_check_mark: | `pgbench -i` uses `COPY`. `COPY FROM` is tested as well. |
|
||||
| Query cancellation | :white_check_mark: | :white_check_mark: | `psql -c 'SELECT pg_sleep(1000);'` and press `Ctrl-C`. |
|
||||
| Load balancing | :white_check_mark: | :white_check_mark: | We could test this by emitting statistics for each replica and compare them. |
|
||||
| Failover | :white_check_mark: | :white_check_mark: | Misconfigure a replica in `pgcat.toml` and watch it forward queries to spares. CI testing is using Toxiproxy. |
|
||||
| Sharding | :white_check_mark: | :white_check_mark: | See `tests/sharding` and `tests/ruby` for an Rails/ActiveRecord example. |
|
||||
| Statistics | :white_check_mark: | :white_check_mark: | Query the admin database with `psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS'`. |
|
||||
| Live config reloading | :white_check_mark: | :white_check_mark: | Run `kill -s SIGHUP $(pgrep pgcat)` and watch the config reload. |
|
||||
### Docker-based local development
|
||||
|
||||
### Dev
|
||||
|
||||
Also, you can open a 'dev' environment where you can debug tests easier by running the following command:
|
||||
You can open a Docker development environment where you can debug tests easier. Run the following command to spin it up:
|
||||
|
||||
```
|
||||
./dev/script/console
|
||||
```
|
||||
|
||||
This will open a terminal in an environment similar to that used in tests. In there you can compile, run tests, do some debugging with the test environment, etc. Objects
|
||||
compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have in your host.
|
||||
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the container (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -105,11 +146,9 @@ In transaction mode, a client talks to one server for the duration of a single t
|
||||
This mode is enabled by default.
|
||||
|
||||
### Load balancing of read queries
|
||||
All queries are load balanced against the configured servers using the random algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
|
||||
All queries are load balanced against the configured servers using either the random or least open connections algorithms. The most straightforward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
|
||||
|
||||
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
|
||||
|
||||
The query parser is disabled by default.
|
||||
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser, implemented with the `sqlparser` crate, will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
|
||||
|
||||
#### Query parser
|
||||
The query parser will do its best to determine where the query should go, but sometimes that's not possible. In that case, the client can select which server it wants using this custom SQL syntax:
|
||||
@@ -136,38 +175,14 @@ The setting will persist until it's changed again or the client disconnects.
|
||||
By default, all queries are routed to the first available server; `default_role` setting controls this behavior.
|
||||
|
||||
### Failover
|
||||
All servers are checked with a `SELECT 1` query before being given to a client. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the ban list is cleared: this is a safety precaution against false positives. The primary can never be banned.
|
||||
All servers are checked with a `;` (very fast) query before being given to a client. Additionally, the server health is monitored with every client query that it processes. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the ban list is cleared: this is a safety precaution against false positives. The primary can never be banned.
|
||||
|
||||
The ban time can be changed with `ban_time`. The default is 60 seconds.
|
||||
|
||||
Failover behavior can get pretty interesting (read complex) when multiple configurations and factors are involved. The table below will try to explain what PgCat does in each scenario:
|
||||
|
||||
| **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** |
|
||||
|---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the random loop. |
|
||||
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the random loop. |
|
||||
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the random loop. |
|
||||
| Read query | replica | false | false | up | Query is routed to the first replica instance in the random loop. |
|
||||
| Read query | primary | false | false | up | Query is routed to the primary. |
|
||||
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the random loop is attempted. |
|
||||
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
|
||||
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the random loop. |
|
||||
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
|
||||
| Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. |
|
||||
| | | | | | |
|
||||
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the random loop. If the instance is a replica, the query fails and the client receives an error. |
|
||||
| Write query | unset (any) | true | false | up | The query is routed to the primary. |
|
||||
| Write query | unset (any) | true | true | up | The query is routed to the primary. |
|
||||
| Write query | primary | false | false | up | The query is routed to the primary. |
|
||||
| Write query | replica | false | false | up | The query is routed to the replica and fails. The client receives an error. |
|
||||
| Write query | unset (any) | true | false | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| Write query | unset (any) | true | true | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| Write query | primary | false | false | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| | | | | | |
|
||||
|
||||
### Sharding
|
||||
We use the `PARTITION BY HASH` hashing function, the same as used by Postgres for declarative partitioning. This allows to shard the database using Postgres partitions and place the partitions on different servers (shards). Both read and write queries can be routed to the shards using this pooler.
|
||||
|
||||
#### Extended syntax
|
||||
To route queries to a particular shard, we use this custom SQL syntax:
|
||||
|
||||
```sql
|
||||
@@ -182,7 +197,8 @@ The active shard will last until it's changed again or the client disconnects. B
|
||||
|
||||
For hash function implementation, see `src/sharding.rs` and `tests/sharding/partition_hash_test_setup.sql`.
|
||||
|
||||
#### ActiveRecord/Rails
|
||||
|
||||
##### ActiveRecord/Rails
|
||||
|
||||
```ruby
|
||||
class User < ActiveRecord::Base
|
||||
@@ -210,7 +226,7 @@ User.connection.execute "SET SERVER ROLE TO 'auto'"
|
||||
User.find_by_email("test@example.com")
|
||||
```
|
||||
|
||||
#### Raw SQL
|
||||
##### Raw SQL
|
||||
|
||||
```sql
|
||||
-- Grab a bunch of users from shard 1
|
||||
@@ -230,268 +246,45 @@ SET SERVER ROLE TO 'auto'; -- let the query router figure out where the query sh
|
||||
SELECT * FROM users WHERE email = 'test@example.com'; -- shard setting lasts until set again; we are reading from the primary
|
||||
```
|
||||
|
||||
#### With comments
|
||||
Issuing queries to the pooler can cause additional latency. To reduce its impact, it's possible to include sharding information inside SQL comments sent via the query. This is reasonably easy to implement with ORMs like [ActiveRecord](https://api.rubyonrails.org/classes/ActiveRecord/QueryMethods.html#method-i-annotate) and [SQLAlchemy](https://docs.sqlalchemy.org/en/20/core/events.html#sql-execution-and-connection-events).
|
||||
|
||||
```
|
||||
/* shard_id: 5 */ SELECT * FROM foo WHERE id = 1234;
|
||||
|
||||
/* sharding_key: 1234 */ SELECT * FROM foo WHERE id = 1234;
|
||||
```
|
||||
|
||||
#### Automatic query parsing
|
||||
PgCat can use the `sqlparser` crate to parse SQL queries and extract the sharding key. This is configurable with the `automatic_sharding_key` setting. This feature is still experimental, but it's the ideal implementation for sharding, requiring no client modifications.
|
||||
|
||||
### Statistics reporting
|
||||
|
||||
The stats are very similar to what Pgbouncer reports and the names are kept to be comparable. They are accessible by querying the admin database `pgcat`, and `pgbouncer` for compatibility.
|
||||
The stats are very similar to what PgBouncer reports and the names are kept to be comparable. They are accessible by querying the admin database `pgcat`, and `pgbouncer` for compatibility.
|
||||
|
||||
```
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES'
|
||||
```
|
||||
|
||||
Additionally, Prometheus statistics are available at `/metrics` via HTTP.
|
||||
|
||||
### Live configuration reloading
|
||||
|
||||
The config can be reloaded by sending a `kill -s SIGHUP` to the process or by querying `RELOAD` to the admin database. Not all settings are currently supported by live reload:
|
||||
The config can be reloaded by sending a `kill -s SIGHUP` to the process or by querying `RELOAD` to the admin database. All settings except the `host` and `port` can be reloaded without restarting the pooler, including sharding and replicas configurations.
|
||||
|
||||
| **Config** | **Requires restart** |
|
||||
|-------------------------|----------------------|
|
||||
| `host` | yes |
|
||||
| `port` | yes |
|
||||
| `pool_mode` | no |
|
||||
| `connect_timeout` | yes |
|
||||
| `healthcheck_timeout` | no |
|
||||
| `shutdown_timeout` | no |
|
||||
| `healthcheck_delay` | no |
|
||||
| `ban_time` | no |
|
||||
| `user` | yes |
|
||||
| `shards` | yes |
|
||||
| `default_role` | no |
|
||||
| `primary_reads_enabled` | no |
|
||||
| `query_parser_enabled` | no |
|
||||
### Mirroring
|
||||
|
||||
Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic.
|
||||
|
||||
## Benchmarks
|
||||
## License
|
||||
|
||||
You can setup PgBench locally through PgCat:
|
||||
PgCat is free and open source, released under the MIT license.
|
||||
|
||||
```
|
||||
pgbench -h 127.0.0.1 -p 6432 -i
|
||||
```
|
||||
## Contributors
|
||||
|
||||
Coincidenly, this uses `COPY` so you can test if that works. Additionally, we'll be running the following PgBench configurations:
|
||||
Many thanks to our amazing contributors!
|
||||
|
||||
1. 16 clients, 2 threads
|
||||
2. 32 clients, 2 threads
|
||||
3. 64 clients, 2 threads
|
||||
4. 128 clients, 2 threads
|
||||
<a href = "https://github.com/postgresml/pgcat/graphs/contributors">
|
||||
<img src = "https://contrib.rocks/image?repo=postgresml/pgcat"/>
|
||||
</a>
|
||||
|
||||
All queries will be `SELECT` only (`-S`) just so disks don't get in the way, since the dataset will be effectively all in RAM.
|
||||
|
||||
My setup:
|
||||
|
||||
- 8 cores, 16 hyperthreaded (AMD Ryzen 5800X)
|
||||
- 32GB RAM (doesn't matter for this benchmark, except to prove that Postgres will fit the whole dataset into RAM)
|
||||
|
||||
### PgBouncer
|
||||
|
||||
#### Config
|
||||
|
||||
```ini
|
||||
[databases]
|
||||
shard0 = host=localhost port=5432 user=sharding_user password=sharding_user
|
||||
|
||||
[pgbouncer]
|
||||
pool_mode = transaction
|
||||
max_client_conn = 1000
|
||||
```
|
||||
|
||||
Everything else stays default.
|
||||
|
||||
#### Runs
|
||||
|
||||
|
||||
```
|
||||
$ pgbench -t 1000 -c 16 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 16
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 16000/16000
|
||||
latency average = 0.155 ms
|
||||
tps = 103417.377469 (including connections establishing)
|
||||
tps = 103510.639935 (excluding connections establishing)
|
||||
|
||||
|
||||
$ pgbench -t 1000 -c 32 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 32
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 32000/32000
|
||||
latency average = 0.290 ms
|
||||
tps = 110325.939785 (including connections establishing)
|
||||
tps = 110386.513435 (excluding connections establishing)
|
||||
|
||||
|
||||
$ pgbench -t 1000 -c 64 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 64
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 64000/64000
|
||||
latency average = 0.692 ms
|
||||
tps = 92470.427412 (including connections establishing)
|
||||
tps = 92618.389350 (excluding connections establishing)
|
||||
|
||||
$ pgbench -t 1000 -c 128 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 128
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 128000/128000
|
||||
latency average = 1.406 ms
|
||||
tps = 91013.429985 (including connections establishing)
|
||||
tps = 91067.583928 (excluding connections establishing)
|
||||
```
|
||||
|
||||
### PgCat
|
||||
|
||||
#### Config
|
||||
|
||||
The only thing that matters here is the number of workers in the Tokio pool. Make sure to set it to < than the number of your CPU cores.
|
||||
Also account for hyper-threading, so if you have that, take the number you got above and divide it by two, that way only "real" cores serving
|
||||
requests.
|
||||
|
||||
My setup is 16 threads, 8 cores (`htop` shows as 16 CPUs), so I set the `max_workers` in Tokio to 4. Too many, and it starts conflicting with PgBench
|
||||
which is also running on the same system.
|
||||
|
||||
#### Runs
|
||||
|
||||
|
||||
```
|
||||
$ pgbench -t 1000 -c 16 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 16
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 16000/16000
|
||||
latency average = 0.164 ms
|
||||
tps = 97705.088232 (including connections establishing)
|
||||
tps = 97872.216045 (excluding connections establishing)
|
||||
|
||||
|
||||
$ pgbench -t 1000 -c 32 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 32
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 32000/32000
|
||||
latency average = 0.288 ms
|
||||
tps = 111300.488119 (including connections establishing)
|
||||
tps = 111413.107800 (excluding connections establishing)
|
||||
|
||||
|
||||
$ pgbench -t 1000 -c 64 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 64
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 64000/64000
|
||||
latency average = 0.556 ms
|
||||
tps = 115190.496139 (including connections establishing)
|
||||
tps = 115247.521295 (excluding connections establishing)
|
||||
|
||||
$ pgbench -t 1000 -c 128 -j 2 -p 6432 -h 127.0.0.1 -S --protocol extended
|
||||
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 128
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 128000/128000
|
||||
latency average = 1.135 ms
|
||||
tps = 112770.562239 (including connections establishing)
|
||||
tps = 112796.502381 (excluding connections establishing)
|
||||
```
|
||||
|
||||
### Direct Postgres
|
||||
|
||||
Always good to have a base line.
|
||||
|
||||
#### Runs
|
||||
|
||||
```
|
||||
$ pgbench -t 1000 -c 16 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
Password:
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 16
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 16000/16000
|
||||
latency average = 0.115 ms
|
||||
tps = 139443.955722 (including connections establishing)
|
||||
tps = 142314.859075 (excluding connections establishing)
|
||||
|
||||
$ pgbench -t 1000 -c 32 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
Password:
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 32
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 32000/32000
|
||||
latency average = 0.212 ms
|
||||
tps = 150644.840891 (including connections establishing)
|
||||
tps = 152218.499430 (excluding connections establishing)
|
||||
|
||||
$ pgbench -t 1000 -c 64 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
Password:
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 64
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 64000/64000
|
||||
latency average = 0.420 ms
|
||||
tps = 152517.663404 (including connections establishing)
|
||||
tps = 153319.188482 (excluding connections establishing)
|
||||
|
||||
$ pgbench -t 1000 -c 128 -j 2 -p 5432 -h 127.0.0.1 -S --protocol extended shard0
|
||||
Password:
|
||||
starting vacuum...end.
|
||||
transaction type: <builtin: select only>
|
||||
scaling factor: 1
|
||||
query mode: extended
|
||||
number of clients: 128
|
||||
number of threads: 2
|
||||
number of transactions per client: 1000
|
||||
number of transactions actually processed: 128000/128000
|
||||
latency average = 0.854 ms
|
||||
tps = 149818.594087 (including connections establishing)
|
||||
tps = 150200.603049 (excluding connections establishing)
|
||||
```
|
||||
|
||||
@@ -26,6 +26,8 @@ x-common-env-pg:
|
||||
services:
|
||||
main:
|
||||
image: kubernetes/pause
|
||||
ports:
|
||||
- 6432
|
||||
|
||||
pg1:
|
||||
<<: *common-definition-pg
|
||||
@@ -56,6 +58,13 @@ services:
|
||||
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
|
||||
PGPORT: 9432
|
||||
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
|
||||
pg5:
|
||||
<<: *common-definition-pg
|
||||
environment:
|
||||
<<: *common-env-pg
|
||||
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
|
||||
PGPORT: 10432
|
||||
command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
|
||||
|
||||
toxiproxy:
|
||||
build: .
|
||||
@@ -69,6 +78,7 @@ services:
|
||||
- pg2
|
||||
- pg3
|
||||
- pg4
|
||||
- pg5
|
||||
|
||||
pgcat-shell:
|
||||
stdin_open: true
|
||||
|
||||
@@ -38,9 +38,6 @@ log_client_connections = false
|
||||
# If we should log client disconnections
|
||||
log_client_disconnections = false
|
||||
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = false
|
||||
|
||||
# TLS
|
||||
# tls_certificate = "server.cert"
|
||||
# tls_private_key = "server.key"
|
||||
@@ -76,7 +73,7 @@ query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitely selected with our custom protocol.
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
primary_reads_enabled = true
|
||||
|
||||
# So what if you wanted to implement a different hashing function,
|
||||
|
||||
BIN
images/instacart.webp
Normal file
BIN
images/instacart.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.4 KiB |
BIN
images/one_signal.webp
Normal file
BIN
images/one_signal.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
BIN
images/postgresml.webp
Normal file
BIN
images/postgresml.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.7 KiB |
103
pgcat.toml
103
pgcat.toml
@@ -23,6 +23,12 @@ connect_timeout = 5000 # milliseconds
|
||||
# How long an idle connection with a server is left open (ms).
|
||||
idle_timeout = 30000 # milliseconds
|
||||
|
||||
# Max connection lifetime before it's closed, even if actively used.
|
||||
server_lifetime = 86400000 # 24 hours
|
||||
|
||||
# How long a client is allowed to be idle while in a transaction (ms).
|
||||
idle_client_in_transaction_timeout = 0 # milliseconds
|
||||
|
||||
# How much time to give the health check query to return with a result (ms).
|
||||
healthcheck_timeout = 1000 # milliseconds
|
||||
|
||||
@@ -42,7 +48,7 @@ log_client_connections = false
|
||||
log_client_disconnections = false
|
||||
|
||||
# When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
autoreload = false
|
||||
autoreload = 15000
|
||||
|
||||
# Number of worker threads the Runtime will use (4 by default).
|
||||
worker_threads = 5
|
||||
@@ -54,10 +60,16 @@ tcp_keepalives_count = 5
|
||||
# Number of seconds between keepalive packets.
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certficate file to use for TLS connections
|
||||
# tls_certificate = "server.cert"
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = "server.key"
|
||||
# tls_private_key = ".circleci/server.key"
|
||||
|
||||
# Enable/disable server TLS
|
||||
server_tls = false
|
||||
|
||||
# Verify server certificate is completely authentic.
|
||||
verify_server_certificate = false
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
@@ -110,6 +122,21 @@ primary_reads_enabled = true
|
||||
# `sha1`: A hashing function based on SHA1
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
# established using the database configured in the pool. This parameter is inherited by every pool
|
||||
# and can be redefined in pool configuration.
|
||||
# auth_query = "SELECT $1"
|
||||
|
||||
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
# This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
# auth_query_user = "sharding_user"
|
||||
|
||||
# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
# This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
# auth_query_password = "sharding_user"
|
||||
|
||||
# Automatically parse this from queries and route queries to the right shard!
|
||||
# automatic_sharding_key = "data.id"
|
||||
|
||||
@@ -119,18 +146,78 @@ idle_timeout = 40000
|
||||
# Connect timeout can be overwritten in the pool
|
||||
connect_timeout = 3000
|
||||
|
||||
# When enabled, ip resolutions for server connections specified using hostnames will be cached
|
||||
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
|
||||
# old ip connections are closed (gracefully) and new connections will start using new ip.
|
||||
# dns_cache_enabled = false
|
||||
|
||||
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
|
||||
# dns_max_ttl = 30
|
||||
|
||||
[plugins]
|
||||
|
||||
[plugins.query_logger]
|
||||
enabled = false
|
||||
|
||||
[plugins.table_access]
|
||||
enabled = false
|
||||
tables = [
|
||||
"pg_user",
|
||||
"pg_roles",
|
||||
"pg_database",
|
||||
]
|
||||
|
||||
[plugins.intercept]
|
||||
enabled = true
|
||||
|
||||
[plugins.intercept.queries.0]
|
||||
|
||||
query = "select current_database() as a, current_schemas(false) as b"
|
||||
schema = [
|
||||
["a", "text"],
|
||||
["b", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "{public}"],
|
||||
]
|
||||
|
||||
[plugins.intercept.queries.1]
|
||||
|
||||
query = "select current_database(), current_schema(), current_user"
|
||||
schema = [
|
||||
["current_database", "text"],
|
||||
["current_schema", "text"],
|
||||
["current_user", "text"],
|
||||
]
|
||||
result = [
|
||||
["${DATABASE}", "public", "${USER}"],
|
||||
]
|
||||
|
||||
# User configs are structured as pool.<pool_name>.users.<user_index>
|
||||
# This secion holds the credentials for users that may connect to this cluster
|
||||
# This section holds the credentials for users that may connect to this cluster
|
||||
[pools.sharded_db.users.0]
|
||||
# Postgresql username
|
||||
# PostgreSQL username used to authenticate the user and connect to the server
|
||||
# if `server_username` is not set.
|
||||
username = "sharding_user"
|
||||
# Postgresql password
|
||||
|
||||
# PostgreSQL password used to authenticate the user and connect to the server
|
||||
# if `server_password` is not set.
|
||||
password = "sharding_user"
|
||||
|
||||
pool_mode = "session"
|
||||
|
||||
# PostgreSQL username used to connect to the server.
|
||||
# server_username = "another_user"
|
||||
|
||||
# PostgreSQL password used to connect to the server.
|
||||
# server_password = "another_password"
|
||||
|
||||
# Maximum number of server connections that can be established for this user
|
||||
# The maximum number of connection from a single Pgcat process to any database in the cluster
|
||||
# is the sum of pool_size across all users.
|
||||
pool_size = 9
|
||||
|
||||
|
||||
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
|
||||
# 0 means it is disabled.
|
||||
statement_timeout = 0
|
||||
@@ -175,6 +262,8 @@ sharding_function = "pg_bigint_hash"
|
||||
username = "simple_user"
|
||||
password = "simple_user"
|
||||
pool_size = 5
|
||||
min_pool_size = 3
|
||||
server_lifetime = 60000
|
||||
statement_timeout = 0
|
||||
|
||||
[pools.simple_db.shards.0]
|
||||
|
||||
123
src/admin.rs
123
src/admin.rs
@@ -1,21 +1,20 @@
|
||||
use crate::pool::BanReason;
|
||||
/// Admin database.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{error, info, trace};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
use nix::unistd::Pid;
|
||||
use std::collections::HashMap;
|
||||
/// Admin database.
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{
|
||||
get_address_stats, get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState,
|
||||
};
|
||||
use crate::ClientServerMap;
|
||||
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
|
||||
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
@@ -158,7 +157,14 @@ where
|
||||
"free_clients".to_string(),
|
||||
client_stats
|
||||
.keys()
|
||||
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Idle)
|
||||
.filter(|client_id| {
|
||||
client_stats
|
||||
.get(client_id)
|
||||
.unwrap()
|
||||
.state
|
||||
.load(Ordering::Relaxed)
|
||||
== ClientState::Idle
|
||||
})
|
||||
.count()
|
||||
.to_string(),
|
||||
]));
|
||||
@@ -166,7 +172,14 @@ where
|
||||
"used_clients".to_string(),
|
||||
client_stats
|
||||
.keys()
|
||||
.filter(|client_id| client_stats.get(client_id).unwrap().state == ClientState::Active)
|
||||
.filter(|client_id| {
|
||||
client_stats
|
||||
.get(client_id)
|
||||
.unwrap()
|
||||
.state
|
||||
.load(Ordering::Relaxed)
|
||||
== ClientState::Active
|
||||
})
|
||||
.count()
|
||||
.to_string(),
|
||||
]));
|
||||
@@ -178,7 +191,14 @@ where
|
||||
"free_servers".to_string(),
|
||||
server_stats
|
||||
.keys()
|
||||
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Idle)
|
||||
.filter(|server_id| {
|
||||
server_stats
|
||||
.get(server_id)
|
||||
.unwrap()
|
||||
.state
|
||||
.load(Ordering::Relaxed)
|
||||
== ServerState::Idle
|
||||
})
|
||||
.count()
|
||||
.to_string(),
|
||||
]));
|
||||
@@ -186,7 +206,14 @@ where
|
||||
"used_servers".to_string(),
|
||||
server_stats
|
||||
.keys()
|
||||
.filter(|server_id| server_stats.get(server_id).unwrap().state == ServerState::Active)
|
||||
.filter(|server_id| {
|
||||
server_stats
|
||||
.get(server_id)
|
||||
.unwrap()
|
||||
.state
|
||||
.load(Ordering::Relaxed)
|
||||
== ServerState::Active
|
||||
})
|
||||
.count()
|
||||
.to_string(),
|
||||
]));
|
||||
@@ -248,28 +275,15 @@ where
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
for (user_pool, pool) in get_all_pools() {
|
||||
let def = HashMap::default();
|
||||
let pool_stats = all_pool_stats
|
||||
.get(&(user_pool.db.clone(), user_pool.user.clone()))
|
||||
.unwrap_or(&def);
|
||||
|
||||
let pool_config = &pool.settings;
|
||||
for ((_user_pool, _pool), pool_stats) in all_pool_stats {
|
||||
let mut row = vec![
|
||||
user_pool.db.clone(),
|
||||
user_pool.user.clone(),
|
||||
pool_config.pool_mode.to_string(),
|
||||
pool_stats.database(),
|
||||
pool_stats.user(),
|
||||
pool_stats.pool_mode().to_string(),
|
||||
];
|
||||
for column in &columns[3..columns.len()] {
|
||||
let value = match column.0 {
|
||||
"maxwait" => (pool_stats.get("maxwait_us").unwrap_or(&0) / 1_000_000).to_string(),
|
||||
"maxwait_us" => {
|
||||
(pool_stats.get("maxwait_us").unwrap_or(&0) % 1_000_000).to_string()
|
||||
}
|
||||
_other_values => pool_stats.get(column.0).unwrap_or(&0).to_string(),
|
||||
};
|
||||
row.push(value);
|
||||
}
|
||||
pool_stats.populate_row(&mut row);
|
||||
pool_stats.clear_maxwait();
|
||||
res.put(data_row(&row));
|
||||
}
|
||||
|
||||
@@ -400,7 +414,7 @@ where
|
||||
for (id, pool) in get_all_pools().iter() {
|
||||
for address in pool.get_addresses_from_host(host) {
|
||||
if !pool.is_banned(&address) {
|
||||
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
|
||||
pool.ban(&address, BanReason::AdminBan(duration_seconds), None);
|
||||
res.put(data_row(&vec![
|
||||
id.db.clone(),
|
||||
id.user.clone(),
|
||||
@@ -617,7 +631,6 @@ where
|
||||
("avg_wait_time", DataType::Numeric),
|
||||
];
|
||||
|
||||
let all_stats = get_address_stats();
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
@@ -625,15 +638,10 @@ where
|
||||
for shard in 0..pool.shards() {
|
||||
for server in 0..pool.servers(shard) {
|
||||
let address = pool.address(shard, server);
|
||||
let stats = match all_stats.get(&address.id) {
|
||||
Some(stats) => stats.clone(),
|
||||
None => HashMap::new(),
|
||||
};
|
||||
|
||||
let mut row = vec![address.name(), user_pool.db.clone(), user_pool.user.clone()];
|
||||
for column in &columns[3..] {
|
||||
row.push(stats.get(column.0).unwrap_or(&0).to_string());
|
||||
}
|
||||
let stats = address.stats.clone();
|
||||
stats.populate_row(&mut row);
|
||||
|
||||
res.put(data_row(&row));
|
||||
}
|
||||
@@ -673,16 +681,16 @@ where
|
||||
|
||||
for (_, client) in new_map {
|
||||
let row = vec![
|
||||
format!("{:#010X}", client.client_id),
|
||||
client.pool_name,
|
||||
client.username,
|
||||
client.application_name.clone(),
|
||||
client.state.to_string(),
|
||||
client.transaction_count.to_string(),
|
||||
client.query_count.to_string(),
|
||||
client.error_count.to_string(),
|
||||
format!("{:#010X}", client.client_id()),
|
||||
client.pool_name(),
|
||||
client.username(),
|
||||
client.application_name(),
|
||||
client.state.load(Ordering::Relaxed).to_string(),
|
||||
client.transaction_count.load(Ordering::Relaxed).to_string(),
|
||||
client.query_count.load(Ordering::Relaxed).to_string(),
|
||||
client.error_count.load(Ordering::Relaxed).to_string(),
|
||||
Instant::now()
|
||||
.duration_since(client.connect_time)
|
||||
.duration_since(client.connect_time())
|
||||
.as_secs()
|
||||
.to_string(),
|
||||
];
|
||||
@@ -724,19 +732,20 @@ where
|
||||
res.put(row_description(&columns));
|
||||
|
||||
for (_, server) in new_map {
|
||||
let application_name = server.application_name.read();
|
||||
let row = vec![
|
||||
format!("{:#010X}", server.server_id),
|
||||
server.pool_name,
|
||||
server.username,
|
||||
server.address_name,
|
||||
server.application_name,
|
||||
server.state.to_string(),
|
||||
server.transaction_count.to_string(),
|
||||
server.query_count.to_string(),
|
||||
server.bytes_sent.to_string(),
|
||||
server.bytes_received.to_string(),
|
||||
format!("{:#010X}", server.server_id()),
|
||||
server.pool_name(),
|
||||
server.username(),
|
||||
server.address_name(),
|
||||
application_name.clone(),
|
||||
server.state.load(Ordering::Relaxed).to_string(),
|
||||
server.transaction_count.load(Ordering::Relaxed).to_string(),
|
||||
server.query_count.load(Ordering::Relaxed).to_string(),
|
||||
server.bytes_sent.load(Ordering::Relaxed).to_string(),
|
||||
server.bytes_received.load(Ordering::Relaxed).to_string(),
|
||||
Instant::now()
|
||||
.duration_since(server.connect_time)
|
||||
.duration_since(server.connect_time())
|
||||
.as_secs()
|
||||
.to_string(),
|
||||
];
|
||||
|
||||
134
src/auth_passthrough.rs
Normal file
134
src/auth_passthrough.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use crate::errors::Error;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::server::Server;
|
||||
use log::debug;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthPassthrough {
|
||||
password: String,
|
||||
query: String,
|
||||
user: String,
|
||||
}
|
||||
|
||||
impl AuthPassthrough {
|
||||
/// Initializes an AuthPassthrough.
|
||||
pub fn new(query: &str, user: &str, password: &str) -> Self {
|
||||
AuthPassthrough {
|
||||
password: password.to_string(),
|
||||
query: query.to_string(),
|
||||
user: user.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an AuthPassthrough given the pool configuration.
|
||||
/// If any of required values is not set, None is returned.
|
||||
pub fn from_pool_config(pool_config: &crate::config::Pool) -> Option<Self> {
|
||||
if pool_config.is_auth_query_configured() {
|
||||
return Some(AuthPassthrough::new(
|
||||
pool_config.auth_query.as_ref().unwrap(),
|
||||
pool_config.auth_query_user.as_ref().unwrap(),
|
||||
pool_config.auth_query_password.as_ref().unwrap(),
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Returns an AuthPassthrough given the pool settings.
|
||||
/// If any of required values is not set, None is returned.
|
||||
pub fn from_pool_settings(pool_settings: &crate::pool::PoolSettings) -> Option<Self> {
|
||||
let pool_config = crate::config::Pool {
|
||||
auth_query: pool_settings.auth_query.clone(),
|
||||
auth_query_password: pool_settings.auth_query_password.clone(),
|
||||
auth_query_user: pool_settings.auth_query_user.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
AuthPassthrough::from_pool_config(&pool_config)
|
||||
}
|
||||
|
||||
/// Connects to server and executes auth_query for the specified address.
|
||||
/// If the response is a row with two columns containing the username set in the address.
|
||||
/// and its MD5 hash, the MD5 hash returned.
|
||||
///
|
||||
/// Note that the query is executed, changing $1 with the name of the user
|
||||
/// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data.
|
||||
/// Also, it is compatible with pgbouncer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `address` - An Address of the server we want to connect to. The username for the hash will be obtained from this value.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::auth_passthrough::AuthPassthrough;
|
||||
/// use pgcat::config::Address;
|
||||
/// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres");
|
||||
/// auth_passthrough.fetch_hash(&Address::default());
|
||||
/// ```
|
||||
///
|
||||
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
|
||||
let auth_user = crate::config::User {
|
||||
username: self.user.clone(),
|
||||
password: Some(self.password.clone()),
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 1,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
min_pool_size: None,
|
||||
};
|
||||
|
||||
let user = &address.username;
|
||||
|
||||
debug!("Connecting to server to obtain auth hashes");
|
||||
|
||||
let auth_query = self.query.replace("$1", user);
|
||||
|
||||
match Server::exec_simple_query(address, &auth_user, &auth_query).await {
|
||||
Ok(password_data) => {
|
||||
if password_data.len() == 2 && password_data.first().unwrap() == user {
|
||||
if let Some(stripped_hash) = password_data
|
||||
.last()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
.strip_prefix("md5") {
|
||||
Ok(stripped_hash.to_string())
|
||||
}
|
||||
else {
|
||||
Err(Error::AuthPassthroughError(
|
||||
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(Error::AuthPassthroughError(
|
||||
"Data obtained from query does not follow the scheme 'user','hash'."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
Err(Error::AuthPassthroughError(
|
||||
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
|
||||
user, err))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
|
||||
let address = pool.address(0, 0);
|
||||
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
|
||||
let hash = apt.fetch_hash(address).await?;
|
||||
|
||||
return Ok(hash);
|
||||
}
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
|
||||
address.username, address.database
|
||||
)))
|
||||
}
|
||||
448
src/client.rs
448
src/client.rs
@@ -1,10 +1,10 @@
|
||||
use crate::errors::Error;
|
||||
use crate::errors::{ClientIdentifier, Error};
|
||||
use crate::pool::BanReason;
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
@@ -12,14 +12,15 @@ use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::config::{get_config, Address, PoolMode};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::constants::*;
|
||||
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
use crate::stats::{ClientStats, PoolStats, ServerStats};
|
||||
use crate::tls::Tls;
|
||||
|
||||
use tokio_rustls::server::TlsStream;
|
||||
@@ -66,8 +67,8 @@ pub struct Client<S, T> {
|
||||
#[allow(dead_code)]
|
||||
parameters: HashMap<String, String>,
|
||||
|
||||
/// Statistics
|
||||
stats: Reporter,
|
||||
/// Statistics related to this client
|
||||
stats: Arc<ClientStats>,
|
||||
|
||||
/// Clients want to talk to admin database.
|
||||
admin: bool,
|
||||
@@ -75,8 +76,8 @@ pub struct Client<S, T> {
|
||||
/// Last address the client talked to.
|
||||
last_address_id: Option<usize>,
|
||||
|
||||
/// Last server process id we talked to.
|
||||
last_server_id: Option<i32>,
|
||||
/// Last server process stats we talked to.
|
||||
last_server_stats: Option<Arc<ServerStats>>,
|
||||
|
||||
/// Connected to server
|
||||
connected_to_server: bool,
|
||||
@@ -135,6 +136,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -183,6 +188,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -194,7 +203,7 @@ pub async fn client_entrypoint(
|
||||
// Client probably disconnected rejecting our plain text connection.
|
||||
Ok((ClientConnectionType::Tls, _))
|
||||
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
|
||||
format!("Bad postgres client (plain)"),
|
||||
"Bad postgres client (plain)".into(),
|
||||
)),
|
||||
|
||||
Err(err) => Err(err),
|
||||
@@ -233,6 +242,10 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
@@ -258,8 +271,11 @@ pub async fn client_entrypoint(
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
if result.is_err() {
|
||||
client.stats.disconnect();
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
@@ -354,9 +370,9 @@ pub async fn startup_tls(
|
||||
}
|
||||
|
||||
// Bad Postgres client.
|
||||
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
|
||||
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
|
||||
),
|
||||
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => {
|
||||
Err(Error::ProtocolSyncError("Bad postgres client (tls)".into()))
|
||||
}
|
||||
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
@@ -382,7 +398,6 @@ where
|
||||
shutdown: Receiver<()>,
|
||||
admin_only: bool,
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
let stats = get_reporter();
|
||||
let parameters = parse_startup(bytes.clone())?;
|
||||
|
||||
// This parameter is mandatory by the protocol.
|
||||
@@ -390,7 +405,7 @@ where
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Error::ClientError(
|
||||
"Missing user parameter on client startup".to_string(),
|
||||
"Missing user parameter on client startup".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
@@ -405,6 +420,8 @@ where
|
||||
None => "pgcat",
|
||||
};
|
||||
|
||||
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
.filter(|db| *db == pool_name)
|
||||
@@ -435,7 +452,12 @@ where
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
@@ -448,19 +470,30 @@ where
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
@@ -469,10 +502,12 @@ where
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
|
||||
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
@@ -491,18 +526,100 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid pool name".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt);
|
||||
// Obtain the hash to compare, we give preference to that written in cleartext in config
|
||||
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
|
||||
// when the pool was created. If there is no hash there, we try to fetch it one more time.
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
if password_hash != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
|
||||
wrong_password(&mut write, username).await?;
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
if hash.is_none() {
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!("Password for {}, obtained. Updating.", client_identifier);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
@@ -537,6 +654,25 @@ where
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
trace!("Startup OK");
|
||||
let pool_stats = match get_pool(pool_name, username) {
|
||||
Some(pool) => {
|
||||
if !admin {
|
||||
pool.stats
|
||||
} else {
|
||||
Arc::new(PoolStats::default())
|
||||
}
|
||||
}
|
||||
None => Arc::new(PoolStats::default()),
|
||||
};
|
||||
|
||||
let stats = Arc::new(ClientStats::new(
|
||||
process_id,
|
||||
application_name,
|
||||
username,
|
||||
pool_name,
|
||||
tokio::time::Instant::now(),
|
||||
pool_stats,
|
||||
));
|
||||
|
||||
Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
@@ -552,7 +688,7 @@ where
|
||||
stats,
|
||||
admin,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
last_server_stats: None,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
application_name: application_name.to_string(),
|
||||
@@ -583,10 +719,10 @@ where
|
||||
secret_key,
|
||||
client_server_map,
|
||||
parameters: HashMap::new(),
|
||||
stats: get_reporter(),
|
||||
stats: Arc::new(ClientStats::default()),
|
||||
admin: false,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
last_server_stats: None,
|
||||
pool_name: String::from("undefined"),
|
||||
username: String::from("undefined"),
|
||||
application_name: String::from("undefined"),
|
||||
@@ -627,12 +763,11 @@ where
|
||||
// The query router determines where the query is going to go,
|
||||
// e.g. primary, replica, which shard.
|
||||
let mut query_router = QueryRouter::new();
|
||||
self.stats.client_register(
|
||||
self.process_id,
|
||||
self.pool_name.clone(),
|
||||
self.username.clone(),
|
||||
self.application_name.clone(),
|
||||
);
|
||||
|
||||
self.stats.register(self.stats.clone());
|
||||
|
||||
// Result returned by one of the plugins.
|
||||
let mut plugin_output = None;
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
@@ -656,7 +791,9 @@ where
|
||||
&mut self.write,
|
||||
"terminating connection due to administrator command"
|
||||
).await?;
|
||||
return Ok(())
|
||||
|
||||
self.stats.disconnect();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Admin clients ignore shutdown.
|
||||
@@ -682,7 +819,25 @@ where
|
||||
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -690,7 +845,13 @@ where
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
query_router.infer(&message);
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
@@ -708,6 +869,9 @@ where
|
||||
|
||||
'X' => {
|
||||
debug!("Client disconnecting");
|
||||
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -721,6 +885,18 @@ where
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check on plugin results.
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
@@ -757,7 +933,7 @@ where
|
||||
current_shard,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
.await?;
|
||||
} else {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
|
||||
}
|
||||
@@ -802,10 +978,13 @@ where
|
||||
};
|
||||
|
||||
debug!("Waiting for connection from pool");
|
||||
if !self.admin {
|
||||
self.stats.waiting();
|
||||
}
|
||||
|
||||
// Grab a server from the pool.
|
||||
let connection = match pool
|
||||
.get(query_router.shard(), query_router.role(), self.process_id)
|
||||
.get(query_router.shard(), query_router.role(), &self.stats)
|
||||
.await
|
||||
{
|
||||
Ok(conn) => {
|
||||
@@ -817,15 +996,32 @@ where
|
||||
// but we were unable to grab a connection from the pool
|
||||
// We'll send back an error message and clean the extended
|
||||
// protocol buffer
|
||||
self.stats.idle();
|
||||
|
||||
if message[0] as char == 'S' {
|
||||
error!("Got Sync message but failed to get a connection from the pool");
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
error_response(&mut self.write, "could not get connection from the pool")
|
||||
.await?;
|
||||
|
||||
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
|
||||
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
|
||||
error!(
|
||||
"Could not get connection from pool: \
|
||||
{{ \
|
||||
pool_name: {:?}, \
|
||||
username: {:?}, \
|
||||
shard: {:?}, \
|
||||
role: \"{:?}\", \
|
||||
error: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
self.username,
|
||||
query_router.shard(),
|
||||
query_router.role(),
|
||||
err
|
||||
);
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -840,11 +1036,10 @@ where
|
||||
self.connected_to_server = true;
|
||||
|
||||
// Update statistics
|
||||
self.stats
|
||||
.client_active(self.process_id, server.server_id());
|
||||
self.stats.active();
|
||||
|
||||
self.last_address_id = Some(address.id);
|
||||
self.last_server_id = Some(server.server_id());
|
||||
self.last_server_stats = Some(server.stats());
|
||||
|
||||
debug!(
|
||||
"Client {:?} talking to server {:?}",
|
||||
@@ -859,6 +1054,11 @@ where
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
let idle_client_timeout_duration = match get_idle_client_in_transaction_timeout() {
|
||||
0 => tokio::time::Duration::MAX,
|
||||
timeout => tokio::time::Duration::from_millis(timeout),
|
||||
};
|
||||
|
||||
// Transaction loop. Multiple queries can be issued by the client here.
|
||||
// The connection belongs to the client until the transaction is over,
|
||||
// or until the client disconnects if we are in session mode.
|
||||
@@ -870,17 +1070,43 @@ where
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
match read_message(&mut self.read).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(message)) => message,
|
||||
Ok(Err(err)) => {
|
||||
// Client disconnected inside a transaction.
|
||||
// Clean up the server and re-use it.
|
||||
self.stats.disconnect();
|
||||
server.checkin_cleanup().await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
Err(_) => {
|
||||
// Client idle in transaction timeout
|
||||
error_response(&mut self.write, "idle transaction timeout").await?;
|
||||
error!(
|
||||
"Client idle in transaction timeout: \
|
||||
{{ \
|
||||
pool_name: {}, \
|
||||
username: {}, \
|
||||
shard: {}, \
|
||||
role: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
self.username,
|
||||
query_router.shard(),
|
||||
query_router.role()
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
initial_message = None;
|
||||
message
|
||||
@@ -899,18 +1125,49 @@ where
|
||||
match code {
|
||||
// Query
|
||||
'Q' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
let plugin_result = query_router.execute_plugins(&ast).await;
|
||||
|
||||
match plugin_result {
|
||||
Ok(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = query_router.infer(&ast);
|
||||
}
|
||||
}
|
||||
debug!("Sending query to server");
|
||||
|
||||
self.send_and_receive_loop(code, Some(&message), server, &address, &pool)
|
||||
.await?;
|
||||
self.send_and_receive_loop(
|
||||
code,
|
||||
Some(&message),
|
||||
server,
|
||||
&address,
|
||||
&pool,
|
||||
&self.stats.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
if !server.in_transaction() {
|
||||
// Report transaction executed statistics.
|
||||
self.stats.transaction(self.process_id, server.server_id());
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
if self.transaction_mode {
|
||||
self.stats.idle();
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -919,6 +1176,7 @@ where
|
||||
// Terminate
|
||||
'X' => {
|
||||
server.checkin_cleanup().await?;
|
||||
self.stats.disconnect();
|
||||
self.release();
|
||||
|
||||
return Ok(());
|
||||
@@ -927,6 +1185,14 @@ where
|
||||
// Parse
|
||||
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
|
||||
'P' => {
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
plugin_output = Some(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
@@ -942,6 +1208,11 @@ where
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Close the prepared statement.
|
||||
'C' => {
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Execute
|
||||
// Execute a prepared statement prepared in `P` and bound in `B`.
|
||||
'E' => {
|
||||
@@ -953,6 +1224,24 @@ where
|
||||
'S' => {
|
||||
debug!("Sending query to server");
|
||||
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
self.buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
Some(PluginOutput::Intercept(result)) => {
|
||||
write_all(&mut self.write, result).await?;
|
||||
plugin_output = None;
|
||||
self.buffer.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
@@ -971,13 +1260,21 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
self.send_and_receive_loop(code, None, server, &address, &pool)
|
||||
.await?;
|
||||
self.send_and_receive_loop(
|
||||
code,
|
||||
None,
|
||||
server,
|
||||
&address,
|
||||
&pool,
|
||||
&self.stats.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.buffer.clear();
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction(self.process_id, server.server_id());
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1012,7 +1309,9 @@ where
|
||||
// Clear the buffer
|
||||
self.buffer.clear();
|
||||
|
||||
let response = self.receive_server_message(server, &address, &pool).await?;
|
||||
let response = self
|
||||
.receive_server_message(server, &address, &pool, &self.stats.clone())
|
||||
.await?;
|
||||
|
||||
match write_all_half(&mut self.write, &response).await {
|
||||
Ok(_) => (),
|
||||
@@ -1023,7 +1322,8 @@ where
|
||||
};
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction(self.process_id, server.server_id());
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1044,11 +1344,11 @@ where
|
||||
// The server is no longer bound to us, we can't cancel it's queries anymore.
|
||||
debug!("Releasing server back into the pool");
|
||||
server.checkin_cleanup().await?;
|
||||
self.stats.server_idle(server.server_id());
|
||||
server.stats().idle();
|
||||
self.connected_to_server = false;
|
||||
|
||||
self.release();
|
||||
self.stats.client_idle(self.process_id);
|
||||
self.stats.idle();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1088,6 +1388,7 @@ where
|
||||
server: &mut Server,
|
||||
address: &Address,
|
||||
pool: &ConnectionPool,
|
||||
client_stats: &ClientStats,
|
||||
) -> Result<(), Error> {
|
||||
debug!("Sending {} to server", code);
|
||||
|
||||
@@ -1103,7 +1404,9 @@ where
|
||||
// Read all data the server has to offer, which can be multiple messages
|
||||
// buffered in 8196 bytes chunks.
|
||||
loop {
|
||||
let response = self.receive_server_message(server, address, pool).await?;
|
||||
let response = self
|
||||
.receive_server_message(server, address, pool, client_stats)
|
||||
.await?;
|
||||
|
||||
match write_all_half(&mut self.write, &response).await {
|
||||
Ok(_) => (),
|
||||
@@ -1119,10 +1422,10 @@ where
|
||||
}
|
||||
|
||||
// Report query executed statistics.
|
||||
self.stats.query(
|
||||
self.process_id,
|
||||
server.server_id(),
|
||||
Instant::now().duration_since(query_start).as_millis(),
|
||||
client_stats.query();
|
||||
server.stats().query(
|
||||
Instant::now().duration_since(query_start).as_millis() as u64,
|
||||
&self.application_name,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -1138,7 +1441,7 @@ where
|
||||
match server.send(message).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageSendFailed, self.process_id);
|
||||
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
@@ -1149,6 +1452,7 @@ where
|
||||
server: &mut Server,
|
||||
address: &Address,
|
||||
pool: &ConnectionPool,
|
||||
client_stats: &ClientStats,
|
||||
) -> Result<BytesMut, Error> {
|
||||
if pool.settings.user.statement_timeout > 0 {
|
||||
match tokio::time::timeout(
|
||||
@@ -1160,7 +1464,7 @@ where
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
@@ -1175,7 +1479,7 @@ where
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, self.process_id);
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
@@ -1184,7 +1488,7 @@ where
|
||||
match server.recv().await {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
@@ -1204,9 +1508,9 @@ impl<S, T> Drop for Client<S, T> {
|
||||
|
||||
// Dirty shutdown
|
||||
// TODO: refactor, this is not the best way to handle state management.
|
||||
self.stats.client_disconnecting(self.process_id);
|
||||
if self.connected_to_server && self.last_server_id.is_some() {
|
||||
self.stats.server_idle(self.last_server_id.unwrap());
|
||||
|
||||
if self.connected_to_server && self.last_server_stats.is_some() {
|
||||
self.last_server_stats.as_ref().unwrap().idle();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
374
src/config.rs
374
src/config.rs
@@ -12,9 +12,11 @@ use std::sync::Arc;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use crate::dns_cache::CachedResolver;
|
||||
use crate::errors::Error;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::AddressStats;
|
||||
use crate::tls::{load_certs, load_keys};
|
||||
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
@@ -62,7 +64,7 @@ impl PartialEq<Role> for Option<Role> {
|
||||
}
|
||||
|
||||
/// Address identifying a PostgreSQL server uniquely.
|
||||
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Address {
|
||||
/// Unique ID per addressable Postgres server.
|
||||
pub id: usize,
|
||||
@@ -96,6 +98,9 @@ pub struct Address {
|
||||
|
||||
/// List of addresses to receive mirrored traffic.
|
||||
pub mirrors: Vec<Address>,
|
||||
|
||||
/// Address stats
|
||||
pub stats: Arc<AddressStats>,
|
||||
}
|
||||
|
||||
impl Default for Address {
|
||||
@@ -112,10 +117,46 @@ impl Default for Address {
|
||||
username: String::from("username"),
|
||||
pool_name: String::from("pool_name"),
|
||||
mirrors: Vec::new(),
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We need to implement PartialEq by ourselves so we skip stats in the comparison
|
||||
impl PartialEq for Address {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id == other.id
|
||||
&& self.host == other.host
|
||||
&& self.port == other.port
|
||||
&& self.shard == other.shard
|
||||
&& self.address_index == other.address_index
|
||||
&& self.replica_number == other.replica_number
|
||||
&& self.database == other.database
|
||||
&& self.role == other.role
|
||||
&& self.username == other.username
|
||||
&& self.pool_name == other.pool_name
|
||||
&& self.mirrors == other.mirrors
|
||||
}
|
||||
}
|
||||
impl Eq for Address {}
|
||||
|
||||
// We need to implement Hash by ourselves so we skip stats in the comparison
|
||||
impl Hash for Address {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.id.hash(state);
|
||||
self.host.hash(state);
|
||||
self.port.hash(state);
|
||||
self.shard.hash(state);
|
||||
self.address_index.hash(state);
|
||||
self.replica_number.hash(state);
|
||||
self.database.hash(state);
|
||||
self.role.hash(state);
|
||||
self.username.hash(state);
|
||||
self.pool_name.hash(state);
|
||||
self.mirrors.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl Address {
|
||||
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
|
||||
pub fn name(&self) -> String {
|
||||
@@ -137,8 +178,13 @@ impl Address {
|
||||
#[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)]
|
||||
pub struct User {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub password: Option<String>,
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
pub min_pool_size: Option<u32>,
|
||||
pub pool_mode: Option<PoolMode>,
|
||||
pub server_lifetime: Option<u64>,
|
||||
#[serde(default)] // 0
|
||||
pub statement_timeout: u64,
|
||||
}
|
||||
@@ -147,13 +193,38 @@ impl Default for User {
|
||||
fn default() -> User {
|
||||
User {
|
||||
username: String::from("postgres"),
|
||||
password: String::new(),
|
||||
password: None,
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
min_pool_size: None,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// General configuration.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct General {
|
||||
@@ -161,7 +232,7 @@ pub struct General {
|
||||
pub host: String,
|
||||
|
||||
#[serde(default = "General::default_port")]
|
||||
pub port: i16,
|
||||
pub port: u16,
|
||||
|
||||
pub enable_prometheus_exporter: Option<bool>,
|
||||
pub prometheus_exporter_port: i16,
|
||||
@@ -185,6 +256,12 @@ pub struct General {
|
||||
#[serde(default)] // False
|
||||
pub log_client_disconnections: bool,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub dns_cache_enabled: bool,
|
||||
|
||||
#[serde(default = "General::default_dns_max_ttl")]
|
||||
pub dns_max_ttl: u64,
|
||||
|
||||
#[serde(default = "General::default_shutdown_timeout")]
|
||||
pub shutdown_timeout: u64,
|
||||
|
||||
@@ -197,16 +274,34 @@ pub struct General {
|
||||
#[serde(default = "General::default_ban_time")]
|
||||
pub ban_time: i64,
|
||||
|
||||
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
|
||||
pub idle_client_in_transaction_timeout: u64,
|
||||
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub autoreload: bool,
|
||||
#[serde(default)] // None
|
||||
pub autoreload: Option<u64>,
|
||||
|
||||
pub tls_certificate: Option<String>,
|
||||
pub tls_private_key: Option<String>,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub server_tls: bool,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub verify_server_certificate: bool,
|
||||
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
// Support for auth query
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
}
|
||||
|
||||
impl General {
|
||||
@@ -214,17 +309,21 @@ impl General {
|
||||
"0.0.0.0".into()
|
||||
}
|
||||
|
||||
pub fn default_port() -> i16 {
|
||||
pub fn default_port() -> u16 {
|
||||
5432
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
|
||||
// These keepalive defaults should detect a dead connection within 30 seconds.
|
||||
// Tokio defaults to disabling keepalives which keeps dead connections around indefinitely.
|
||||
// This can lead to permenant server pool exhaustion
|
||||
// This can lead to permanent server pool exhaustion
|
||||
pub fn default_tcp_keepalives_idle() -> u64 {
|
||||
5 // 5 seconds
|
||||
}
|
||||
@@ -245,6 +344,10 @@ impl General {
|
||||
60000
|
||||
}
|
||||
|
||||
pub fn default_dns_max_ttl() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
pub fn default_healthcheck_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
@@ -260,6 +363,10 @@ impl General {
|
||||
pub fn default_worker_threads() -> usize {
|
||||
4
|
||||
}
|
||||
|
||||
pub fn default_idle_client_in_transaction_timeout() -> u64 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for General {
|
||||
@@ -276,16 +383,25 @@ impl Default for General {
|
||||
healthcheck_delay: Self::default_healthcheck_delay(),
|
||||
ban_time: Self::default_ban_time(),
|
||||
worker_threads: Self::default_worker_threads(),
|
||||
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
|
||||
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
|
||||
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
|
||||
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: false,
|
||||
autoreload: None,
|
||||
dns_cache_enabled: false,
|
||||
dns_max_ttl: Self::default_dns_max_ttl(),
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -301,6 +417,7 @@ pub enum PoolMode {
|
||||
#[serde(alias = "session", alias = "Session")]
|
||||
Session,
|
||||
}
|
||||
|
||||
impl ToString for PoolMode {
|
||||
fn to_string(&self) -> String {
|
||||
match *self {
|
||||
@@ -349,6 +466,8 @@ pub struct Pool {
|
||||
|
||||
pub idle_timeout: Option<u64>,
|
||||
|
||||
pub server_lifetime: Option<u64>,
|
||||
|
||||
pub sharding_function: ShardingFunction,
|
||||
|
||||
#[serde(default = "Pool::default_automatic_sharding_key")]
|
||||
@@ -358,9 +477,13 @@ pub struct Pool {
|
||||
pub shard_id_regex: Option<String>,
|
||||
pub regex_search_limit: Option<usize>,
|
||||
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
pub shards: BTreeMap<String, Shard>,
|
||||
pub users: BTreeMap<String, User>,
|
||||
// Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it
|
||||
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
|
||||
// incompatible to have simple fields in TOML after complex objects. See
|
||||
// https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903
|
||||
}
|
||||
@@ -372,6 +495,12 @@ impl Pool {
|
||||
s.finish()
|
||||
}
|
||||
|
||||
pub fn is_auth_query_configured(&self) -> bool {
|
||||
self.auth_query_password.is_some()
|
||||
&& self.auth_query_user.is_some()
|
||||
&& self.auth_query_password.is_some()
|
||||
}
|
||||
|
||||
pub fn default_pool_mode() -> PoolMode {
|
||||
PoolMode::Transaction
|
||||
}
|
||||
@@ -443,6 +572,10 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -464,6 +597,10 @@ impl Default for Pool {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: Some(1000),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -513,7 +650,7 @@ impl Shard {
|
||||
|
||||
if primary_count > 1 {
|
||||
error!(
|
||||
"Shard {} has more than on primary configured",
|
||||
"Shard {} has more than one primary configured",
|
||||
self.database
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
@@ -542,6 +679,56 @@ impl Default for Shard {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Plugins {
|
||||
pub intercept: Option<Intercept>,
|
||||
pub table_access: Option<TableAccess>,
|
||||
pub query_logger: Option<QueryLogger>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Intercept {
|
||||
pub enabled: bool,
|
||||
pub queries: BTreeMap<String, Query>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct TableAccess {
|
||||
pub enabled: bool,
|
||||
pub tables: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct QueryLogger {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
impl Intercept {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for (_, query) in self.queries.iter_mut() {
|
||||
query.substitute(db, user);
|
||||
query.query = query.query.to_ascii_lowercase();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Query {
|
||||
pub query: String,
|
||||
pub schema: Vec<Vec<String>>,
|
||||
pub result: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for col in self.result.iter_mut() {
|
||||
for i in 0..col.len() {
|
||||
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration wrapper.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
@@ -560,13 +747,36 @@ pub struct Config {
|
||||
pub path: String,
|
||||
|
||||
pub general: General,
|
||||
pub plugins: Option<Plugins>,
|
||||
pub pools: HashMap<String, Pool>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn is_auth_query_configured(&self) -> bool {
|
||||
self.pools
|
||||
.iter()
|
||||
.any(|(_name, pool)| pool.is_auth_query_configured())
|
||||
}
|
||||
|
||||
pub fn default_path() -> String {
|
||||
String::from("pgcat.toml")
|
||||
}
|
||||
|
||||
pub fn fill_up_auth_query_config(&mut self) {
|
||||
for (_name, pool) in self.pools.iter_mut() {
|
||||
if pool.auth_query.is_none() {
|
||||
pool.auth_query = self.general.auth_query.clone();
|
||||
}
|
||||
|
||||
if pool.auth_query_user.is_none() {
|
||||
pool.auth_query_user = self.general.auth_query_user.clone();
|
||||
}
|
||||
|
||||
if pool.auth_query_password.is_none() {
|
||||
pool.auth_query_password = self.general.auth_query_password.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@@ -575,6 +785,7 @@ impl Default for Config {
|
||||
path: Self::default_path(),
|
||||
general: General::default(),
|
||||
pools: HashMap::default(),
|
||||
plugins: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -655,6 +866,13 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
config.general.healthcheck_delay.to_string(),
|
||||
),
|
||||
("ban_time".to_string(), config.general.ban_time.to_string()),
|
||||
(
|
||||
"idle_client_in_transaction_timeout".to_string(),
|
||||
config
|
||||
.general
|
||||
.idle_client_in_transaction_timeout
|
||||
.to_string(),
|
||||
),
|
||||
];
|
||||
|
||||
r.append(&mut static_settings);
|
||||
@@ -666,6 +884,10 @@ impl Config {
|
||||
/// Print current configuration.
|
||||
pub fn show(&self) {
|
||||
info!("Ban time: {}s", self.general.ban_time);
|
||||
info!(
|
||||
"Idle client in transaction timeout: {}ms",
|
||||
self.general.idle_client_in_transaction_timeout
|
||||
);
|
||||
info!("Worker threads: {}", self.general.worker_threads);
|
||||
info!(
|
||||
"Healthcheck timeout: {}ms",
|
||||
@@ -683,6 +905,10 @@ impl Config {
|
||||
);
|
||||
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
|
||||
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
|
||||
info!(
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
@@ -701,6 +927,11 @@ impl Config {
|
||||
info!("TLS support is disabled");
|
||||
}
|
||||
};
|
||||
info!("Server TLS enabled: {}", self.general.server_tls);
|
||||
info!(
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
|
||||
for (pool_name, pool_config) in &self.pools {
|
||||
// TODO: Make this output prettier (maybe a table?)
|
||||
@@ -715,8 +946,9 @@ impl Config {
|
||||
.to_string()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Pool mode: {:?}",
|
||||
pool_name, pool_config.pool_mode
|
||||
"[pool: {}] Default pool mode: {}",
|
||||
pool_name,
|
||||
pool_config.pool_mode.to_string()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Load Balancing mode: {:?}",
|
||||
@@ -758,21 +990,101 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config.users.len()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
for user in &pool_config.users {
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool size: {}",
|
||||
pool_name, user.1.username, user.1.pool_size,
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Minimum pool size: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
user.1.min_pool_size.unwrap_or(0)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Statement timeout: {}",
|
||||
pool_name, user.1.username, user.1.statement_timeout
|
||||
)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool mode: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.pool_mode {
|
||||
Some(pool_mode) => pool_mode.to_string(),
|
||||
None => pool_config.pool_mode.to_string(),
|
||||
}
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate(&mut self) -> Result<(), Error> {
|
||||
// Validation for auth_query feature
|
||||
if self.general.auth_query.is_some()
|
||||
&& (self.general.auth_query_user.is_none()
|
||||
|| self.general.auth_query_password.is_none())
|
||||
{
|
||||
error!(
|
||||
"If auth_query is specified, \
|
||||
you need to provide a value \
|
||||
for `auth_query_user`, \
|
||||
`auth_query_password`"
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
for (name, pool) in self.pools.iter() {
|
||||
if pool.auth_query.is_some()
|
||||
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
|
||||
{
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
If auth_query is specified, you need \
|
||||
to provide a value for `auth_query_user`, \
|
||||
`auth_query_password`",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
for (_name, user_data) in pool.users.iter() {
|
||||
if (pool.auth_query.is_none()
|
||||
|| pool.auth_query_password.is_none()
|
||||
|| pool.auth_query_user.is_none())
|
||||
&& user_data.password.is_none()
|
||||
{
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
You have to specify a user password \
|
||||
for every pool if auth_query is not specified",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate TLS!
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
@@ -819,6 +1131,12 @@ pub fn get_config() -> Config {
|
||||
(*(*CONFIG.load())).clone()
|
||||
}
|
||||
|
||||
pub fn get_idle_client_in_transaction_timeout() -> u64 {
|
||||
(*(*CONFIG.load()))
|
||||
.general
|
||||
.idle_client_in_transaction_timeout
|
||||
}
|
||||
|
||||
/// Parse the configuration file located at the path.
|
||||
pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
let mut contents = String::new();
|
||||
@@ -846,6 +1164,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
}
|
||||
};
|
||||
|
||||
config.fill_up_auth_query_config();
|
||||
config.validate()?;
|
||||
|
||||
config.path = path.to_string();
|
||||
@@ -858,6 +1177,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
|
||||
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
|
||||
let old_config = get_config();
|
||||
|
||||
match parse(&old_config.path).await {
|
||||
Ok(()) => (),
|
||||
Err(err) => {
|
||||
@@ -865,14 +1185,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
let new_config = get_config();
|
||||
|
||||
if old_config.pools != new_config.pools {
|
||||
info!("Pool configuration changed");
|
||||
match CachedResolver::from_config().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
|
||||
};
|
||||
|
||||
if old_config != new_config {
|
||||
info!("Config changed, reloading");
|
||||
ConnectionPool::from_config(client_server_map).await?;
|
||||
Ok(true)
|
||||
} else if old_config != new_config {
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
@@ -889,6 +1213,7 @@ mod test {
|
||||
assert_eq!(get_config().path, "pgcat.toml".to_string());
|
||||
|
||||
assert_eq!(get_config().general.ban_time, 60);
|
||||
assert_eq!(get_config().general.idle_client_in_transaction_timeout, 0);
|
||||
assert_eq!(get_config().general.idle_timeout, 30000);
|
||||
assert_eq!(get_config().pools.len(), 2);
|
||||
assert_eq!(get_config().pools["sharded_db"].shards.len(), 3);
|
||||
@@ -914,7 +1239,10 @@ mod test {
|
||||
"sharding_user"
|
||||
);
|
||||
assert_eq!(
|
||||
get_config().pools["sharded_db"].users["1"].password,
|
||||
get_config().pools["sharded_db"].users["1"]
|
||||
.password
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"other_user"
|
||||
);
|
||||
assert_eq!(get_config().pools["sharded_db"].users["1"].pool_size, 21);
|
||||
@@ -939,10 +1267,16 @@ mod test {
|
||||
"simple_user"
|
||||
);
|
||||
assert_eq!(
|
||||
get_config().pools["simple_db"].users["0"].password,
|
||||
get_config().pools["simple_db"].users["0"]
|
||||
.password
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"simple_user"
|
||||
);
|
||||
assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5);
|
||||
assert_eq!(get_config().general.auth_query, None);
|
||||
assert_eq!(get_config().general.auth_query_user, None);
|
||||
assert_eq!(get_config().general.auth_query_password, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
410
src/dns_cache.rs
Normal file
410
src/dns_cache.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use trust_dns_resolver::error::{ResolveError, ResolveResult};
|
||||
use trust_dns_resolver::lookup_ip::LookupIp;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
/// Cached Resolver Globally available
|
||||
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
|
||||
|
||||
// Ip addressed are returned as a set of addresses
|
||||
// so we can compare.
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct AddrSet {
|
||||
set: HashSet<IpAddr>,
|
||||
}
|
||||
|
||||
impl AddrSet {
|
||||
fn new() -> AddrSet {
|
||||
AddrSet {
|
||||
set: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LookupIp> for AddrSet {
|
||||
fn from(lookup_ip: LookupIp) -> Self {
|
||||
let mut addr_set = AddrSet::new();
|
||||
for address in lookup_ip.iter() {
|
||||
addr_set.set.insert(address);
|
||||
}
|
||||
addr_set
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
|
||||
///
|
||||
/// The system works as follows:
|
||||
///
|
||||
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
|
||||
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
|
||||
/// cache is refreshed.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
/// // Now the ip resolution is stored in local cache and subsequent
|
||||
/// // calls will be returned from cache. Also, the cache is refreshed
|
||||
/// // and updated every 10 seconds.
|
||||
///
|
||||
/// // You can now check if an 'old' lookup differs from what it's currently
|
||||
/// // store in cache by using `has_changed`.
|
||||
/// resolver.has_changed("www.example.com.", addrset)
|
||||
#[derive(Default)]
|
||||
pub struct CachedResolver {
|
||||
// The configuration of the cached_resolver.
|
||||
config: CachedResolverConfig,
|
||||
|
||||
// This is the hash that contains the hash.
|
||||
data: Option<RwLock<HashMap<String, AddrSet>>>,
|
||||
|
||||
// The resolver to be used for DNS queries.
|
||||
resolver: Option<TokioAsyncResolver>,
|
||||
|
||||
// The RefreshLoop
|
||||
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
///
|
||||
/// Configuration
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct CachedResolverConfig {
|
||||
/// Amount of time in secods that a resolved dns address is considered stale.
|
||||
dns_max_ttl: u64,
|
||||
|
||||
/// Enabled or disabled? (this is so we can reload config)
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl CachedResolverConfig {
|
||||
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
|
||||
CachedResolverConfig {
|
||||
dns_max_ttl,
|
||||
enabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::config::Config> for CachedResolverConfig {
|
||||
fn from(config: crate::config::Config) -> Self {
|
||||
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
|
||||
}
|
||||
}
|
||||
|
||||
impl CachedResolver {
|
||||
///
|
||||
/// Returns a new Arc<CachedResolver> based on passed configuration.
|
||||
/// It also starts the loop that will refresh cache entries.
|
||||
///
|
||||
/// # Arguments:
|
||||
///
|
||||
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn new(
|
||||
config: CachedResolverConfig,
|
||||
data: Option<HashMap<String, AddrSet>>,
|
||||
) -> Result<Arc<Self>, io::Error> {
|
||||
// Construct a new Resolver with default configuration options
|
||||
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
|
||||
|
||||
let data = if let Some(hash) = data {
|
||||
Some(RwLock::new(hash))
|
||||
} else {
|
||||
Some(RwLock::new(HashMap::new()))
|
||||
};
|
||||
|
||||
let instance = Arc::new(Self {
|
||||
config,
|
||||
resolver,
|
||||
data,
|
||||
refresh_loop: RwLock::new(None),
|
||||
});
|
||||
|
||||
if instance.enabled() {
|
||||
info!("Scheduling DNS refresh loop");
|
||||
let refresh_loop = tokio::task::spawn({
|
||||
let instance = instance.clone();
|
||||
async move {
|
||||
instance.refresh_dns_entries_loop().await;
|
||||
}
|
||||
});
|
||||
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
|
||||
}
|
||||
|
||||
Ok(instance)
|
||||
}
|
||||
|
||||
pub fn enabled(&self) -> bool {
|
||||
self.config.enabled
|
||||
}
|
||||
|
||||
// Schedules the refresher
|
||||
async fn refresh_dns_entries_loop(&self) {
|
||||
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
|
||||
let interval = Duration::from_secs(self.config.dns_max_ttl);
|
||||
loop {
|
||||
debug!("Begin refreshing cached DNS addresses.");
|
||||
// To minimize the time we hold the lock, we first create
|
||||
// an array with keys.
|
||||
let mut hostnames: Vec<String> = Vec::new();
|
||||
{
|
||||
if let Some(ref data) = self.data {
|
||||
for hostname in data.read().unwrap().keys() {
|
||||
hostnames.push(hostname.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for hostname in hostnames.iter() {
|
||||
let addrset = self
|
||||
.fetch_from_cache(hostname.as_str())
|
||||
.expect("Could not obtain expected address from cache, this should not happen");
|
||||
|
||||
match resolver.lookup_ip(hostname).await {
|
||||
Ok(lookup_ip) => {
|
||||
let new_addrset = AddrSet::from(lookup_ip);
|
||||
debug!(
|
||||
"Obtained address for host ({}) -> ({:?})",
|
||||
hostname, new_addrset
|
||||
);
|
||||
|
||||
if addrset != new_addrset {
|
||||
debug!(
|
||||
"Addr changed from {:?} to {:?} updating cache.",
|
||||
addrset, new_addrset
|
||||
);
|
||||
self.store_in_cache(hostname, new_addrset);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"There was an error trying to resolv {}: ({}).",
|
||||
hostname, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Finished refreshing cached DNS addresses.");
|
||||
sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a `AddrSet` given the specified hostname.
|
||||
///
|
||||
/// This method first tries to fetch the value from the cache, if it misses
|
||||
/// then it is resolved and stored in the cache. TTL from records is ignored.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `host` - A string slice referencing the hostname to be resolved.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let response = resolver.lookup_ip("www.google.com.");
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
|
||||
debug!("Lookup up {} in cache", host);
|
||||
match self.fetch_from_cache(host) {
|
||||
Some(addr_set) => {
|
||||
debug!("Cache hit!");
|
||||
Ok(addr_set)
|
||||
}
|
||||
None => {
|
||||
debug!("Not found, executing a dns query!");
|
||||
if let Some(ref resolver) = self.resolver {
|
||||
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
|
||||
debug!("Obtained: {:?}", addr_set);
|
||||
self.store_in_cache(host, addr_set.clone());
|
||||
Ok(addr_set)
|
||||
} else {
|
||||
Err(ResolveError::from("No resolver available"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Returns true if the stored host resolution differs from the AddrSet passed.
|
||||
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
|
||||
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
|
||||
return fetched_addr_set != *addr_set;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Fetches an AddrSet from the inner cache adquiring the read lock.
|
||||
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
|
||||
if let Some(ref hash) = self.data {
|
||||
if let Some(addr_set) = hash.read().unwrap().get(key) {
|
||||
return Some(addr_set.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
|
||||
// cache.
|
||||
pub async fn from_config() -> Result<(), Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let desired_config = CachedResolverConfig::from(get_config());
|
||||
|
||||
if cached_resolver.config != desired_config {
|
||||
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
|
||||
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
|
||||
refresh_loop.abort()
|
||||
}
|
||||
let new_resolver = if let Some(ref data) = cached_resolver.data {
|
||||
let data = Some(data.read().unwrap().clone());
|
||||
CachedResolver::new(desired_config, data).await
|
||||
} else {
|
||||
CachedResolver::new(desired_config, None).await
|
||||
};
|
||||
|
||||
match new_resolver {
|
||||
Ok(ok) => {
|
||||
CACHED_RESOLVER.store(ok);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
|
||||
Err(Error::DNSCachedError(message))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the AddrSet in cache adquiring the write lock.
|
||||
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
|
||||
if let Some(ref data) = self.data {
|
||||
data.write().unwrap().insert(host.to_string(), addr_set);
|
||||
} else {
|
||||
error!("Could not insert, Hash not initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use trust_dns_resolver::error::ResolveError;
|
||||
|
||||
#[tokio::test]
|
||||
async fn new() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await;
|
||||
assert!(resolver.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lookup_ip() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let response = resolver.lookup_ip("www.google.com.").await;
|
||||
assert!(response.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn has_changed() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.google.com.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
let addr_set = response.unwrap();
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_host() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.idontexists.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
assert!(matches!(response, Err(ResolveError { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn incorrect_address() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "w ww.idontexists.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
assert!(matches!(response, Err(ResolveError { .. })));
|
||||
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
// Ok, this test is based on the fact that google does DNS RR
|
||||
// and does not responds with every available ip everytime, so
|
||||
// if I cache here, it will miss after one cache iteration or two.
|
||||
async fn thread() {
|
||||
let config = CachedResolverConfig {
|
||||
dns_max_ttl: 10,
|
||||
enabled: true,
|
||||
};
|
||||
let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
let hostname = "www.google.com.";
|
||||
let response = resolver.lookup_ip(hostname).await;
|
||||
let addr_set = response.unwrap();
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
let resolver_for_refresher = resolver.clone();
|
||||
let _thread_handle = tokio::task::spawn(async move {
|
||||
resolver_for_refresher.refresh_dns_entries_loop().await;
|
||||
});
|
||||
assert!(!resolver.has_changed(hostname, &addr_set));
|
||||
}
|
||||
}
|
||||
109
src/errors.rs
109
src/errors.rs
@@ -1,18 +1,123 @@
|
||||
/// Errors.
|
||||
//! Errors.
|
||||
|
||||
/// Various errors.
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub enum Error {
|
||||
SocketError(String),
|
||||
ClientSocketError(String, ClientIdentifier),
|
||||
ClientGeneralError(String, ClientIdentifier),
|
||||
ClientAuthImpossible(String),
|
||||
ClientAuthPassthroughError(String, ClientIdentifier),
|
||||
ClientBadStartup,
|
||||
ProtocolSyncError(String),
|
||||
BadQuery(String),
|
||||
ServerError,
|
||||
ServerStartupError(String, ServerIdentifier),
|
||||
ServerAuthError(String, ServerIdentifier),
|
||||
BadConfig,
|
||||
AllServersDown,
|
||||
ClientError(String),
|
||||
TlsError,
|
||||
StatementTimeout,
|
||||
DNSCachedError(String),
|
||||
ShuttingDown,
|
||||
ParseBytesError(String),
|
||||
AuthError(String),
|
||||
AuthPassthroughError(String),
|
||||
UnsupportedStatement,
|
||||
QueryRouterParserError(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct ClientIdentifier {
|
||||
pub application_name: String,
|
||||
pub username: String,
|
||||
pub pool_name: String,
|
||||
}
|
||||
|
||||
impl ClientIdentifier {
|
||||
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
|
||||
ClientIdentifier {
|
||||
application_name: application_name.into(),
|
||||
username: username.into(),
|
||||
pool_name: pool_name.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ClientIdentifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{{ application_name: {}, username: {}, pool_name: {} }}",
|
||||
self.application_name, self.username, self.pool_name
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct ServerIdentifier {
|
||||
pub username: String,
|
||||
pub database: String,
|
||||
}
|
||||
|
||||
impl ServerIdentifier {
|
||||
pub fn new(username: &str, database: &str) -> ServerIdentifier {
|
||||
ServerIdentifier {
|
||||
username: username.into(),
|
||||
database: database.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ServerIdentifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{{ username: {}, database: {} }}",
|
||||
self.username, self.database
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match &self {
|
||||
&Error::ClientSocketError(error, client_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} from client {}",
|
||||
error, client_identifier
|
||||
),
|
||||
&Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{} {}", error, client_identifier)
|
||||
}
|
||||
&Error::ClientAuthImpossible(username) => write!(
|
||||
f,
|
||||
"Client auth not possible, \
|
||||
no cleartext password set for username: {} \
|
||||
in config and auth passthrough (query_auth) \
|
||||
is not set up.",
|
||||
username
|
||||
),
|
||||
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
f,
|
||||
"No cleartext password set, \
|
||||
and no auth passthrough could not \
|
||||
obtain the hash from server for {}, \
|
||||
the error was: {}",
|
||||
client_identifier, error
|
||||
),
|
||||
&Error::ServerStartupError(error, server_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} on server startup {}",
|
||||
error, server_identifier,
|
||||
),
|
||||
&Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{} for {}", error, server_identifier,)
|
||||
}
|
||||
|
||||
// The rest can use Debug.
|
||||
err => write!(f, "{:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
pub mod admin;
|
||||
pub mod auth_passthrough;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod dns_cache;
|
||||
pub mod errors;
|
||||
pub mod messages;
|
||||
pub mod mirrors;
|
||||
pub mod multi_logger;
|
||||
pub mod plugins;
|
||||
pub mod pool;
|
||||
pub mod prometheus;
|
||||
pub mod query_router;
|
||||
pub mod scram;
|
||||
pub mod server;
|
||||
pub mod sharding;
|
||||
|
||||
82
src/main.rs
82
src/main.rs
@@ -36,6 +36,7 @@ extern crate sqlparser;
|
||||
extern crate tokio;
|
||||
extern crate tokio_rustls;
|
||||
extern crate toml;
|
||||
extern crate trust_dns_resolver;
|
||||
|
||||
#[cfg(not(target_env = "msvc"))]
|
||||
use jemallocator::Jemalloc;
|
||||
@@ -60,34 +61,19 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
mod admin;
|
||||
mod client;
|
||||
mod config;
|
||||
mod constants;
|
||||
mod errors;
|
||||
mod messages;
|
||||
mod mirrors;
|
||||
mod multi_logger;
|
||||
mod pool;
|
||||
mod prometheus;
|
||||
mod query_router;
|
||||
mod scram;
|
||||
mod server;
|
||||
mod sharding;
|
||||
mod stats;
|
||||
mod tls;
|
||||
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::prometheus::start_metric_server;
|
||||
use crate::stats::{Collector, Reporter, REPORTER};
|
||||
use pgcat::config::{get_config, reload_config, VERSION};
|
||||
use pgcat::dns_cache;
|
||||
use pgcat::messages::configure_socket;
|
||||
use pgcat::pool::{ClientServerMap, ConnectionPool};
|
||||
use pgcat::prometheus::start_metric_server;
|
||||
use pgcat::stats::{Collector, Reporter, REPORTER};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
multi_logger::MultiLogger::init().unwrap();
|
||||
pgcat::multi_logger::MultiLogger::init().unwrap();
|
||||
|
||||
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
|
||||
|
||||
if !query_router::QueryRouter::setup() {
|
||||
if !pgcat::query_router::QueryRouter::setup() {
|
||||
error!("Could not setup query router");
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
@@ -105,7 +91,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
match config::parse(&config_file).await {
|
||||
match pgcat::config::parse(&config_file).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
@@ -162,8 +148,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
// Statistics reporting.
|
||||
let (stats_tx, stats_rx) = mpsc::channel(500_000);
|
||||
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
|
||||
REPORTER.store(Arc::new(Reporter::default()));
|
||||
|
||||
// Starts (if enabled) dns cache before pools initialization
|
||||
match dns_cache::CachedResolver::from_config().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("DNS cache initialization error: {:?}", err),
|
||||
};
|
||||
|
||||
// Connection pool that allows to query all shards and replicas.
|
||||
match ConnectionPool::from_config(client_server_map.clone()).await {
|
||||
@@ -175,20 +166,23 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
};
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
|
||||
let mut stats_collector = Collector::default();
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
|
||||
info!("Config autoreloader: {}", config.general.autoreload);
|
||||
info!("Config autoreloader: {}", match config.general.autoreload {
|
||||
Some(interval) => format!("{} ms", interval),
|
||||
None => "disabled".into(),
|
||||
});
|
||||
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
if let Some(interval) = config.general.autoreload {
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
if config.general.autoreload {
|
||||
info!("Automatically reloading config");
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
debug!("Automatically reloading config");
|
||||
|
||||
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
|
||||
if changed {
|
||||
@@ -196,8 +190,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
|
||||
#[cfg(windows)]
|
||||
let mut term_signal = win_signal::ctrl_close().unwrap();
|
||||
@@ -282,18 +278,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let drain_tx = drain_tx.clone();
|
||||
let client_server_map = client_server_map.clone();
|
||||
|
||||
let tls_certificate = config.general.tls_certificate.clone();
|
||||
let tls_certificate = get_config().general.tls_certificate.clone();
|
||||
|
||||
configure_socket(&socket);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
|
||||
match client::client_entrypoint(
|
||||
match pgcat::client::client_entrypoint(
|
||||
socket,
|
||||
client_server_map,
|
||||
shutdown_rx,
|
||||
drain_tx,
|
||||
admin_only,
|
||||
tls_certificate.clone(),
|
||||
tls_certificate,
|
||||
config.general.log_client_connections,
|
||||
)
|
||||
.await
|
||||
@@ -301,7 +299,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(()) => {
|
||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
if config.general.log_client_disconnections {
|
||||
if get_config().general.log_client_disconnections {
|
||||
info!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
@@ -318,7 +316,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
Err(err) => {
|
||||
match err {
|
||||
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
_ => warn!("Client disconnected with error {:?}", err),
|
||||
}
|
||||
|
||||
|
||||
102
src/messages.rs
102
src/messages.rs
@@ -20,6 +20,10 @@ pub enum DataType {
|
||||
Text,
|
||||
Int4,
|
||||
Numeric,
|
||||
Bool,
|
||||
Oid,
|
||||
AnyArray,
|
||||
Any,
|
||||
}
|
||||
|
||||
impl From<&DataType> for i32 {
|
||||
@@ -28,6 +32,10 @@ impl From<&DataType> for i32 {
|
||||
DataType::Text => 25,
|
||||
DataType::Int4 => 23,
|
||||
DataType::Numeric => 1700,
|
||||
DataType::Bool => 16,
|
||||
DataType::Oid => 26,
|
||||
DataType::AnyArray => 2277,
|
||||
DataType::Any => 2276,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -116,7 +124,10 @@ where
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
/// This tells the server which user we are and what database we want.
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
@@ -150,6 +161,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(12);
|
||||
|
||||
bytes.put_i32(8);
|
||||
bytes.put_i32(80877103);
|
||||
|
||||
match stream.write_all(&bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing SSLRequest to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
@@ -213,7 +239,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let output = md5.finalize_reset();
|
||||
|
||||
// Second pass
|
||||
md5.update(format!("{:x}", output));
|
||||
md5_hash_second_pass(&(format!("{:x}", output)), salt)
|
||||
}
|
||||
|
||||
pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let mut md5 = Md5::new();
|
||||
// Second pass
|
||||
md5.update(hash);
|
||||
md5.update(salt);
|
||||
|
||||
let mut password = format!("md5{:x}", md5.finalize())
|
||||
@@ -247,6 +279,20 @@ where
|
||||
write_all(stream, message).await
|
||||
}
|
||||
|
||||
pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let password = md5_hash_second_pass(hash, salt);
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
message.put_slice(&password[..]);
|
||||
|
||||
write_all(stream, message).await
|
||||
}
|
||||
|
||||
/// Implements a response to our custom `SET SHARDING KEY`
|
||||
/// and `SET SERVER ROLE` commands.
|
||||
/// This tells the client we're ready for the next query.
|
||||
@@ -384,7 +430,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut row_desc = BytesMut::new();
|
||||
|
||||
// how many colums we are storing
|
||||
// how many columns we are storing
|
||||
row_desc.put_i16(columns.len() as i16);
|
||||
|
||||
for (name, data_type) in columns {
|
||||
@@ -405,6 +451,10 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
DataType::Text => -1,
|
||||
DataType::Int4 => 4,
|
||||
DataType::Numeric => -1,
|
||||
DataType::Bool => 1,
|
||||
DataType::Oid => 4,
|
||||
DataType::AnyArray => -1,
|
||||
DataType::Any => -1,
|
||||
};
|
||||
|
||||
row_desc.put_i16(type_size);
|
||||
@@ -443,6 +493,29 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
|
||||
res
|
||||
}
|
||||
|
||||
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut data_row = BytesMut::new();
|
||||
|
||||
data_row.put_i16(row.len() as i16);
|
||||
|
||||
for column in row {
|
||||
if let Some(column) = column {
|
||||
let column = column.as_bytes();
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(column);
|
||||
} else {
|
||||
data_row.put_i32(-1 as i32);
|
||||
}
|
||||
}
|
||||
|
||||
res.put_u8(b'D');
|
||||
res.put_i32(data_row.len() as i32 + 4);
|
||||
res.put(data_row);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Create a CommandComplete message.
|
||||
pub fn command_complete(command: &str) -> BytesMut {
|
||||
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
|
||||
@@ -485,6 +558,29 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a complete message from the socket.
|
||||
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
||||
where
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A mirrored PostgreSQL client.
|
||||
/// Packets arrive to us through a channel from the main client and we send them to the server.
|
||||
use bb8::Pool;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::config::{get_config, Address, Role, User};
|
||||
use crate::pool::{ClientServerMap, ServerPool};
|
||||
use crate::stats::get_reporter;
|
||||
use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool};
|
||||
use crate::stats::PoolStats;
|
||||
use log::{error, info, trace, warn};
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
|
||||
@@ -21,20 +24,25 @@ impl MirroredClient {
|
||||
async fn create_pool(&self) -> Pool<ServerPool> {
|
||||
let config = get_config();
|
||||
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
|
||||
let (connection_timeout, idle_timeout) = match config.pools.get(&self.address.pool_name) {
|
||||
Some(cfg) => (
|
||||
cfg.connect_timeout.unwrap_or(default),
|
||||
cfg.idle_timeout.unwrap_or(default),
|
||||
),
|
||||
None => (default, default),
|
||||
};
|
||||
let (connection_timeout, idle_timeout, cfg) =
|
||||
match config.pools.get(&self.address.pool_name) {
|
||||
Some(cfg) => (
|
||||
cfg.connect_timeout.unwrap_or(default),
|
||||
cfg.idle_timeout.unwrap_or(default),
|
||||
cfg.clone(),
|
||||
),
|
||||
None => (default, default, crate::config::Pool::default()),
|
||||
};
|
||||
|
||||
let identifier = PoolIdentifier::new(&self.database, &self.user.username);
|
||||
|
||||
let manager = ServerPool::new(
|
||||
self.address.clone(),
|
||||
self.user.clone(),
|
||||
self.database.as_str(),
|
||||
ClientServerMap::default(),
|
||||
get_reporter(),
|
||||
Arc::new(PoolStats::new(identifier, cfg.clone())),
|
||||
Arc::new(RwLock::new(None)),
|
||||
);
|
||||
|
||||
Pool::builder()
|
||||
|
||||
@@ -17,7 +17,7 @@ use log::{Level, Log, Metadata, Record, SetLoggerError};
|
||||
//
|
||||
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
|
||||
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, erros will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
//
|
||||
pub struct MultiLogger {
|
||||
stderr_logger: env_logger::Logger,
|
||||
|
||||
288
src/plugins/intercept.rs
Normal file
288
src/plugins/intercept.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
//! The intercept plugin.
|
||||
//!
|
||||
//! It intercepts queries and returns fake results.
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use sqlparser::ast::Statement;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use log::{debug, info};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
config::Intercept as InterceptConfig,
|
||||
errors::Error,
|
||||
messages::{command_complete, data_row_nullable, row_description, DataType},
|
||||
plugins::{Plugin, PluginOutput},
|
||||
pool::{PoolIdentifier, PoolMap},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
|
||||
|
||||
/// Check if the interceptor plugin has been enabled.
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
|
||||
let mut config = HashMap::new();
|
||||
for (identifier, _) in pools.iter() {
|
||||
let mut intercept_config = intercept_config.clone();
|
||||
intercept_config.substitute(&identifier.db, &identifier.user);
|
||||
config.insert(identifier.clone(), intercept_config);
|
||||
}
|
||||
|
||||
CONFIG.store(Arc::new(config));
|
||||
|
||||
info!("Intercepting {} queries", intercept_config.queries.len());
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(HashMap::new()));
|
||||
}
|
||||
|
||||
// TODO: use these structs for deserialization
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Rule {
|
||||
query: String,
|
||||
schema: Vec<Column>,
|
||||
result: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Column {
|
||||
name: String,
|
||||
data_type: String,
|
||||
}
|
||||
|
||||
/// The intercept plugin.
|
||||
pub struct Intercept;
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for Intercept {
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
if ast.is_empty() {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let mut result = BytesMut::new();
|
||||
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
|
||||
&query_router.pool_settings().db,
|
||||
&query_router.pool_settings().user.username,
|
||||
)) {
|
||||
Some(query_map) => query_map.clone(),
|
||||
None => return Ok(PluginOutput::Allow),
|
||||
};
|
||||
|
||||
for q in ast {
|
||||
// Normalization
|
||||
let q = q.to_string().to_ascii_lowercase();
|
||||
|
||||
for (_, target) in query_map.queries.iter() {
|
||||
if target.query.as_str() == q {
|
||||
debug!("Intercepting query: {}", q);
|
||||
|
||||
let rd = target
|
||||
.schema
|
||||
.iter()
|
||||
.map(|row| {
|
||||
let name = &row[0];
|
||||
let data_type = &row[1];
|
||||
(
|
||||
name.as_str(),
|
||||
match data_type.as_str() {
|
||||
"text" => DataType::Text,
|
||||
"anyarray" => DataType::AnyArray,
|
||||
"oid" => DataType::Oid,
|
||||
"bool" => DataType::Bool,
|
||||
"int4" => DataType::Int4,
|
||||
_ => DataType::Any,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<Vec<(&str, DataType)>>();
|
||||
|
||||
result.put(row_description(&rd));
|
||||
|
||||
target.result.iter().for_each(|row| {
|
||||
let row = row
|
||||
.iter()
|
||||
.map(|s| {
|
||||
let s = s.as_str().to_string();
|
||||
|
||||
if s == "" {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Option<String>>>();
|
||||
result.put(data_row_nullable(&row));
|
||||
});
|
||||
|
||||
result.put(command_complete("SELECT"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !result.is_empty() {
|
||||
result.put_u8(b'Z');
|
||||
result.put_i32(5);
|
||||
result.put_u8(b'I');
|
||||
|
||||
return Ok(PluginOutput::Intercept(result));
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make IntelliJ SQL plugin believe it's talking to an actual database
|
||||
/// instead of PgCat.
|
||||
#[allow(dead_code)]
|
||||
fn fool_datagrip(database: &str, user: &str) -> Value {
|
||||
json!([
|
||||
{
|
||||
"query": "select current_database() as a, current_schemas(false) as b",
|
||||
"schema": [
|
||||
{
|
||||
"name": "a",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"data_type": "anyarray",
|
||||
},
|
||||
],
|
||||
|
||||
"result": [
|
||||
[database, "{public}"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select current_database(), current_schema(), current_user",
|
||||
"schema": [
|
||||
{
|
||||
"name": "current_database",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_schema",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "current_user",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
|
||||
"result": [
|
||||
["sharded_db", "public", "sharding_user"],
|
||||
],
|
||||
},
|
||||
{
|
||||
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
|
||||
"schema": [
|
||||
{
|
||||
"name": "id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_template",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "allow_connections",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "owner",
|
||||
"data_type": "text",
|
||||
}
|
||||
],
|
||||
"result": [
|
||||
["16387", database, "", "f", "t", user],
|
||||
]
|
||||
},
|
||||
{
|
||||
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
|
||||
"schema": [
|
||||
{
|
||||
"name": "role_id",
|
||||
"data_type": "oid",
|
||||
},
|
||||
{
|
||||
"name": "role_name",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "is_super",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_inherit",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createrole",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_createdb",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "can_login",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "is_replication",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "conn_limit",
|
||||
"data_type": "int4",
|
||||
},
|
||||
{
|
||||
"name": "valid_until",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "bypass_rls",
|
||||
"data_type": "bool",
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"data_type": "text",
|
||||
},
|
||||
{
|
||||
"name": "description",
|
||||
"data_type": "text",
|
||||
},
|
||||
],
|
||||
"result": [
|
||||
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
|
||||
]
|
||||
}
|
||||
])
|
||||
}
|
||||
43
src/plugins/mod.rs
Normal file
43
src/plugins/mod.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
//! The plugin ecosystem.
|
||||
//!
|
||||
//! Currently plugins only grant access or deny access to the database for a particual query.
|
||||
//! Example use cases:
|
||||
//! - block known bad queries
|
||||
//! - block access to system catalogs
|
||||
//! - block dangerous modifications like `DROP TABLE`
|
||||
//! - etc
|
||||
//!
|
||||
|
||||
pub mod intercept;
|
||||
pub mod query_logger;
|
||||
pub mod table_access;
|
||||
|
||||
use crate::{errors::Error, query_router::QueryRouter};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use sqlparser::ast::Statement;
|
||||
|
||||
pub use intercept::Intercept;
|
||||
pub use query_logger::QueryLogger;
|
||||
pub use table_access::TableAccess;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum PluginOutput {
|
||||
Allow,
|
||||
Deny(String),
|
||||
Overwrite(Vec<Statement>),
|
||||
Intercept(BytesMut),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Plugin {
|
||||
// Run before the query is sent to the server.
|
||||
async fn run(
|
||||
&mut self,
|
||||
query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error>;
|
||||
|
||||
// TODO: run after the result is returned
|
||||
// async fn callback(&mut self, query_router: &QueryRouter);
|
||||
}
|
||||
49
src/plugins/query_logger.rs
Normal file
49
src/plugins/query_logger.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! Log all queries to stdout (or somewhere else, why not).
|
||||
|
||||
use crate::{
|
||||
errors::Error,
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use sqlparser::ast::Statement;
|
||||
use std::sync::Arc;
|
||||
|
||||
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
|
||||
|
||||
pub struct QueryLogger;
|
||||
|
||||
pub fn setup() {
|
||||
ENABLED.store(Arc::new(true));
|
||||
|
||||
info!("Logging queries to stdout");
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
ENABLED.store(Arc::new(false));
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
**ENABLED.load()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for QueryLogger {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
let query = ast
|
||||
.iter()
|
||||
.map(|q| q.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join("; ");
|
||||
info!("{}", query);
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
73
src/plugins/table_access.rs
Normal file
73
src/plugins/table_access.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
//! This query router plugin will check if the user can access a particular
|
||||
//! table as part of their query. If they can't, the query will not be routed.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlparser::ast::{visit_relations, Statement};
|
||||
|
||||
use crate::{
|
||||
config::TableAccess as TableAccessConfig,
|
||||
errors::Error,
|
||||
plugins::{Plugin, PluginOutput},
|
||||
query_router::QueryRouter,
|
||||
};
|
||||
|
||||
use log::{debug, info};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use core::ops::ControlFlow;
|
||||
use once_cell::sync::Lazy;
|
||||
use std::sync::Arc;
|
||||
|
||||
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
|
||||
|
||||
pub fn setup(config: &TableAccessConfig) {
|
||||
CONFIG.store(Arc::new(config.tables.clone()));
|
||||
|
||||
info!("Blocking access to {} tables", config.tables.len());
|
||||
}
|
||||
|
||||
pub fn enabled() -> bool {
|
||||
!CONFIG.load().is_empty()
|
||||
}
|
||||
|
||||
pub fn disable() {
|
||||
CONFIG.store(Arc::new(vec![]));
|
||||
}
|
||||
|
||||
pub struct TableAccess;
|
||||
|
||||
#[async_trait]
|
||||
impl Plugin for TableAccess {
|
||||
async fn run(
|
||||
&mut self,
|
||||
_query_router: &QueryRouter,
|
||||
ast: &Vec<Statement>,
|
||||
) -> Result<PluginOutput, Error> {
|
||||
let mut found = None;
|
||||
let forbidden_tables = CONFIG.load();
|
||||
|
||||
visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
|
||||
if forbidden_tables.contains(&table_name.to_string()) {
|
||||
found = Some(table_name.to_string());
|
||||
ControlFlow::<()>::Break(())
|
||||
} else {
|
||||
ControlFlow::<()>::Continue(())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found) = found {
|
||||
debug!("Blocking access to table \"{}\"", found);
|
||||
|
||||
Ok(PluginOutput::Deny(format!(
|
||||
"permission for table \"{}\" denied",
|
||||
found
|
||||
)))
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
}
|
||||
}
|
||||
245
src/pool.rs
245
src/pool.rs
@@ -20,9 +20,10 @@ use tokio::sync::Notify;
|
||||
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
|
||||
use crate::errors::Error;
|
||||
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::server::Server;
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
|
||||
|
||||
pub type ProcessId = i32;
|
||||
pub type SecretKey = i32;
|
||||
@@ -51,7 +52,7 @@ pub enum BanReason {
|
||||
|
||||
/// An identifier for a PgCat pool,
|
||||
/// a database visible to clients.
|
||||
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Hash, Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct PoolIdentifier {
|
||||
// The name of the database clients want to connect to.
|
||||
pub db: String,
|
||||
@@ -60,6 +61,8 @@ pub struct PoolIdentifier {
|
||||
pub user: String,
|
||||
}
|
||||
|
||||
static POOL_REAPER_RATE: u64 = 30_000; // 30 seconds by default
|
||||
|
||||
impl PoolIdentifier {
|
||||
/// Create a new user/pool identifier.
|
||||
pub fn new(db: &str, user: &str) -> PoolIdentifier {
|
||||
@@ -90,6 +93,7 @@ pub struct PoolSettings {
|
||||
|
||||
// Connecting user.
|
||||
pub user: User,
|
||||
pub db: String,
|
||||
|
||||
// Default server role to connect to.
|
||||
pub default_role: Option<Role>,
|
||||
@@ -123,6 +127,11 @@ pub struct PoolSettings {
|
||||
|
||||
// Limit how much of each query is searched for a potential shard regex match
|
||||
pub regex_search_limit: usize,
|
||||
|
||||
// Auth query parameters
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for PoolSettings {
|
||||
@@ -132,6 +141,7 @@ impl Default for PoolSettings {
|
||||
load_balancing_mode: LoadBalancingMode::Random,
|
||||
shards: 1,
|
||||
user: User::default(),
|
||||
db: String::default(),
|
||||
default_role: None,
|
||||
query_parser_enabled: false,
|
||||
primary_reads_enabled: true,
|
||||
@@ -143,6 +153,9 @@ impl Default for PoolSettings {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,10 +174,6 @@ pub struct ConnectionPool {
|
||||
/// that should not be queried.
|
||||
banlist: BanList,
|
||||
|
||||
/// The statistics aggregator runs in a separate task
|
||||
/// and receives stats from clients, servers, and the pool.
|
||||
stats: Reporter,
|
||||
|
||||
/// The server information (K messages) have to be passed to the
|
||||
/// clients on startup. We pre-connect to all shards and replicas
|
||||
/// on pool creation and save the K messages here.
|
||||
@@ -185,6 +194,11 @@ pub struct ConnectionPool {
|
||||
/// If the pool has been paused or not.
|
||||
paused: Arc<AtomicBool>,
|
||||
paused_waiter: Arc<Notify>,
|
||||
|
||||
pub stats: Arc<PoolStats>,
|
||||
|
||||
/// AuthInfo
|
||||
pub auth_hash: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
@@ -201,6 +215,7 @@ impl ConnectionPool {
|
||||
// There is one pool per database/user pair.
|
||||
for user in pool_config.users.values() {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
@@ -211,10 +226,7 @@ impl ConnectionPool {
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(
|
||||
PoolIdentifier::new(pool_name, &user.username),
|
||||
pool.clone(),
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -234,9 +246,14 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
@@ -266,6 +283,7 @@ impl ConnectionPool {
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
});
|
||||
address_id += 1;
|
||||
}
|
||||
@@ -283,6 +301,7 @@ impl ConnectionPool {
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
@@ -291,12 +310,48 @@ impl ConnectionPool {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
// We assume every server in the pool share user/passwords
|
||||
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
|
||||
{
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!(
|
||||
"Hash is not the same across shards \
|
||||
of the same pool, client auth will \
|
||||
be done using last obtained hash. \
|
||||
Server: {}:{}, Database: {}",
|
||||
server.host, server.port, shard.database,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Could not obtain password hashes \
|
||||
using auth_query config, ignoring. \
|
||||
Error: {:?}",
|
||||
err,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
get_reporter(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
@@ -309,14 +364,31 @@ impl ConnectionPool {
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let server_lifetime = match user.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => config.general.server_lifetime,
|
||||
},
|
||||
};
|
||||
|
||||
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE]
|
||||
.iter()
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
debug!("Pool reaper rate: {}ms", reaper_rate);
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.min_idle(user.min_pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
@@ -328,20 +400,31 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
|
||||
info!(
|
||||
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
|
||||
pool_name, user.username
|
||||
);
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
stats: get_reporter(),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
pool_mode: match user.pool_mode {
|
||||
Some(pool_mode) => pool_mode,
|
||||
None => pool_config.pool_mode,
|
||||
},
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
db: pool_name.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
@@ -364,6 +447,9 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
},
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
@@ -383,11 +469,38 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref plugins) = config.plugins {
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
if intercept.enabled {
|
||||
crate::plugins::intercept::setup(intercept, &new_pools);
|
||||
} else {
|
||||
crate::plugins::intercept::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref table_access) = plugins.table_access {
|
||||
if table_access.enabled {
|
||||
crate::plugins::table_access::setup(table_access);
|
||||
} else {
|
||||
crate::plugins::table_access::disable();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref query_logger) = plugins.query_logger {
|
||||
if query_logger.enabled {
|
||||
crate::plugins::query_logger::setup();
|
||||
} else {
|
||||
crate::plugins::query_logger::disable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
POOLS.store(Arc::new(new_pools.clone()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect to all shards and grab server information.
|
||||
/// Connect to all shards, grab server information, and possibly
|
||||
/// passwords to use in client auth.
|
||||
/// Return server information we will pass to the clients
|
||||
/// when they connect.
|
||||
/// This also warms up the pool for clients that connect when
|
||||
@@ -476,9 +589,9 @@ impl ConnectionPool {
|
||||
/// Get a connection from the pool.
|
||||
pub async fn get(
|
||||
&self,
|
||||
shard: usize, // shard number
|
||||
role: Option<Role>, // primary or replica
|
||||
client_process_id: i32, // client id
|
||||
shard: usize, // shard number
|
||||
role: Option<Role>, // primary or replica
|
||||
client_stats: &ClientStats, // client id
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
let mut candidates: Vec<&Address> = self.addresses[shard]
|
||||
.iter()
|
||||
@@ -517,7 +630,7 @@ impl ConnectionPool {
|
||||
|
||||
// Indicate we're waiting on a server connection from a pool.
|
||||
let now = Instant::now();
|
||||
self.stats.client_waiting(client_process_id);
|
||||
client_stats.waiting();
|
||||
|
||||
// Check if we can connect
|
||||
let mut conn = match self.databases[address.shard][address.address_index]
|
||||
@@ -527,9 +640,10 @@ impl ConnectionPool {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Banning instance {:?}, error: {:?}", address, err);
|
||||
self.ban(address, BanReason::FailedCheckout, client_process_id);
|
||||
self.stats
|
||||
.client_checkout_error(client_process_id, address.id);
|
||||
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
|
||||
address.stats.error();
|
||||
client_stats.idle();
|
||||
client_stats.checkout_error();
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -546,18 +660,18 @@ impl ConnectionPool {
|
||||
// since we last checked the server is ok.
|
||||
// Health checks are pretty expensive.
|
||||
if !require_healthcheck {
|
||||
self.stats.checkout_time(
|
||||
now.elapsed().as_micros(),
|
||||
client_process_id,
|
||||
server.server_id(),
|
||||
);
|
||||
self.stats
|
||||
.server_active(client_process_id, server.server_id());
|
||||
let checkout_time: u64 = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_stats.application_name());
|
||||
server.stats().active(client_stats.application_name());
|
||||
|
||||
return Ok((conn, address.clone()));
|
||||
}
|
||||
|
||||
if self
|
||||
.run_health_check(address, server, now, client_process_id)
|
||||
.run_health_check(address, server, now, client_stats)
|
||||
.await
|
||||
{
|
||||
return Ok((conn, address.clone()));
|
||||
@@ -565,7 +679,6 @@ impl ConnectionPool {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::AllServersDown)
|
||||
}
|
||||
|
||||
@@ -574,11 +687,11 @@ impl ConnectionPool {
|
||||
address: &Address,
|
||||
server: &mut Server,
|
||||
start: Instant,
|
||||
client_process_id: i32,
|
||||
client_info: &ClientStats,
|
||||
) -> bool {
|
||||
debug!("Running health check on server {:?}", address);
|
||||
|
||||
self.stats.server_tested(server.server_id());
|
||||
server.stats().tested();
|
||||
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(self.settings.healthcheck_timeout),
|
||||
@@ -589,13 +702,13 @@ impl ConnectionPool {
|
||||
// Check if health check succeeded.
|
||||
Ok(res) => match res {
|
||||
Ok(_) => {
|
||||
self.stats.checkout_time(
|
||||
start.elapsed().as_micros(),
|
||||
client_process_id,
|
||||
server.server_id(),
|
||||
);
|
||||
self.stats
|
||||
.server_active(client_process_id, server.server_id());
|
||||
let checkout_time: u64 = start.elapsed().as_micros() as u64;
|
||||
client_info.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_info.application_name());
|
||||
server.stats().active(client_info.application_name());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -620,14 +733,14 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
|
||||
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
/// traffic for any new transactions. Existing transactions on that replica
|
||||
/// will finish successfully or error out to the clients.
|
||||
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
|
||||
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
|
||||
// Primary can never be banned
|
||||
if address.role == Role::Primary {
|
||||
return;
|
||||
@@ -636,7 +749,10 @@ impl ConnectionPool {
|
||||
let now = chrono::offset::Utc::now().naive_utc();
|
||||
let mut guard = self.banlist.write();
|
||||
error!("Banning {:?}", address);
|
||||
self.stats.client_ban_error(client_id, address.id);
|
||||
if let Some(client_info) = client_info {
|
||||
client_info.ban_error();
|
||||
address.stats.error();
|
||||
}
|
||||
guard[address.shard].insert(address.clone(), (reason, now));
|
||||
}
|
||||
|
||||
@@ -797,7 +913,8 @@ pub struct ServerPool {
|
||||
user: User,
|
||||
database: String,
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Reporter,
|
||||
stats: Arc<PoolStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl ServerPool {
|
||||
@@ -806,14 +923,16 @@ impl ServerPool {
|
||||
user: User,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Reporter,
|
||||
stats: Arc<PoolStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
) -> ServerPool {
|
||||
ServerPool {
|
||||
address,
|
||||
user,
|
||||
user: user.clone(),
|
||||
database: database.to_string(),
|
||||
client_server_map,
|
||||
stats,
|
||||
auth_hash,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -826,34 +945,32 @@ impl ManageConnection for ServerPool {
|
||||
/// Attempts to create a new connection.
|
||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||
info!("Creating a new server connection {:?}", self.address);
|
||||
let server_id = rand::random::<i32>();
|
||||
|
||||
self.stats.server_register(
|
||||
server_id,
|
||||
self.address.id,
|
||||
self.address.name(),
|
||||
self.address.pool_name.clone(),
|
||||
self.address.username.clone(),
|
||||
);
|
||||
self.stats.server_login(server_id);
|
||||
let stats = Arc::new(ServerStats::new(
|
||||
self.address.clone(),
|
||||
self.stats.clone(),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
|
||||
stats.register(stats.clone());
|
||||
|
||||
// Connect to the PostgreSQL server.
|
||||
match Server::startup(
|
||||
server_id,
|
||||
&self.address,
|
||||
&self.user,
|
||||
&self.database,
|
||||
self.client_server_map.clone(),
|
||||
self.stats.clone(),
|
||||
stats.clone(),
|
||||
self.auth_hash.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(conn) => {
|
||||
self.stats.server_idle(server_id);
|
||||
stats.idle();
|
||||
Ok(conn)
|
||||
}
|
||||
Err(err) => {
|
||||
self.stats.server_disconnecting(server_id);
|
||||
stats.disconnect();
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
@@ -881,11 +998,3 @@ pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
|
||||
pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
|
||||
(*(*POOLS.load())).clone()
|
||||
}
|
||||
|
||||
/// How many total servers we have in the config.
|
||||
pub fn get_number_of_addresses() -> usize {
|
||||
get_all_pools()
|
||||
.iter()
|
||||
.map(|(_, pool)| pool.databases())
|
||||
.sum()
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@ use phf::phf_map;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::Address;
|
||||
use crate::pool::get_all_pools;
|
||||
use crate::stats::{get_address_stats, get_pool_stats, get_server_stats, ServerInformation};
|
||||
use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
|
||||
|
||||
struct MetricHelpType {
|
||||
help: &'static str,
|
||||
@@ -220,7 +222,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
Self::from_name(&format!("servers_{}", name), value, labels)
|
||||
}
|
||||
|
||||
fn from_address(address: &Address, name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
|
||||
fn from_address(address: &Address, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("host", address.host.clone());
|
||||
labels.insert("shard", address.shard.to_string());
|
||||
@@ -231,7 +233,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
Self::from_name(&format!("stats_{}", name), value, labels)
|
||||
}
|
||||
|
||||
fn from_pool(pool: &(String, String), name: &str, value: i64) -> Option<PrometheusMetric<i64>> {
|
||||
fn from_pool(pool: &(String, String), name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("pool", pool.0.clone());
|
||||
labels.insert("user", pool.1.clone());
|
||||
@@ -261,20 +263,18 @@ async fn prometheus_stats(request: Request<Body>) -> Result<Response<Body>, hype
|
||||
|
||||
// Adds metrics shown in a SHOW STATS admin command.
|
||||
fn push_address_stats(lines: &mut Vec<String>) {
|
||||
let address_stats: HashMap<usize, HashMap<String, i64>> = get_address_stats();
|
||||
for (_, pool) in get_all_pools() {
|
||||
for shard in 0..pool.shards() {
|
||||
for server in 0..pool.servers(shard) {
|
||||
let address = pool.address(shard, server);
|
||||
if let Some(address_stats) = address_stats.get(&address.id) {
|
||||
for (key, value) in address_stats.iter() {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<i64>::from_address(address, key, *value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
let stats = &*address.stats;
|
||||
for (key, value) in stats.clone() {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_address(address, &key, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -286,8 +286,9 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
fn push_pool_stats(lines: &mut Vec<String>) {
|
||||
let pool_stats = get_pool_stats();
|
||||
for (pool, stats) in pool_stats.iter() {
|
||||
for (name, value) in stats.iter() {
|
||||
if let Some(prometheus_metric) = PrometheusMetric::<i64>::from_pool(pool, name, *value)
|
||||
let stats = &**stats;
|
||||
for (name, value) in stats.clone() {
|
||||
if let Some(prometheus_metric) = PrometheusMetric::<u64>::from_pool(pool, &name, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
@@ -330,9 +331,9 @@ fn push_database_stats(lines: &mut Vec<String>) {
|
||||
// Adds relevant metrics shown in a SHOW SERVERS admin command.
|
||||
fn push_server_stats(lines: &mut Vec<String>) {
|
||||
let server_stats = get_server_stats();
|
||||
let mut server_stats_by_addresses = HashMap::<String, ServerInformation>::new();
|
||||
for (_, info) in server_stats {
|
||||
server_stats_by_addresses.insert(info.address_name.clone(), info);
|
||||
let mut server_stats_by_addresses = HashMap::<String, Arc<ServerStats>>::new();
|
||||
for (_, stats) in server_stats {
|
||||
server_stats_by_addresses.insert(stats.address_name(), stats);
|
||||
}
|
||||
|
||||
for (_, pool) in get_all_pools() {
|
||||
@@ -341,11 +342,23 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
let address = pool.address(shard, server);
|
||||
if let Some(server_info) = server_stats_by_addresses.get(&address.name()) {
|
||||
let metrics = [
|
||||
("bytes_received", server_info.bytes_received),
|
||||
("bytes_sent", server_info.bytes_sent),
|
||||
("transaction_count", server_info.transaction_count),
|
||||
("query_count", server_info.query_count),
|
||||
("error_count", server_info.error_count),
|
||||
(
|
||||
"bytes_received",
|
||||
server_info.bytes_received.load(Ordering::Relaxed),
|
||||
),
|
||||
("bytes_sent", server_info.bytes_sent.load(Ordering::Relaxed)),
|
||||
(
|
||||
"transaction_count",
|
||||
server_info.transaction_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"query_count",
|
||||
server_info.query_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"error_count",
|
||||
server_info.error_count.load(Ordering::Relaxed),
|
||||
),
|
||||
];
|
||||
for (key, value) in metrics {
|
||||
if let Some(prometheus_metric) =
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/// Route queries automatically based on explicitely requested
|
||||
/// Route queries automatically based on explicitly requested
|
||||
/// or implied query characteristics.
|
||||
use bytes::{Buf, BytesMut};
|
||||
use log::{debug, error};
|
||||
@@ -6,13 +6,19 @@ use once_cell::sync::OnceCell;
|
||||
use regex::{Regex, RegexSet};
|
||||
use sqlparser::ast::Statement::{Query, StartTransaction};
|
||||
use sqlparser::ast::{
|
||||
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
|
||||
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
|
||||
Value,
|
||||
};
|
||||
use sqlparser::dialect::PostgreSqlDialect;
|
||||
use sqlparser::parser::Parser;
|
||||
|
||||
use crate::config::Role;
|
||||
use crate::errors::Error;
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::plugins::{
|
||||
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
|
||||
TableAccess,
|
||||
};
|
||||
use crate::pool::PoolSettings;
|
||||
use crate::sharding::Sharder;
|
||||
|
||||
@@ -129,6 +135,10 @@ impl QueryRouter {
|
||||
self.pool_settings = pool_settings;
|
||||
}
|
||||
|
||||
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
|
||||
&self.pool_settings
|
||||
}
|
||||
|
||||
/// Try to parse a command and execute it.
|
||||
pub fn try_execute_command(&mut self, message_buffer: &BytesMut) -> Option<(Command, String)> {
|
||||
let mut message_cursor = Cursor::new(message_buffer);
|
||||
@@ -324,10 +334,7 @@ impl QueryRouter {
|
||||
Some((command, value))
|
||||
}
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, message: &BytesMut) -> bool {
|
||||
debug!("Inferring role");
|
||||
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
@@ -353,28 +360,29 @@ impl QueryRouter {
|
||||
query
|
||||
}
|
||||
|
||||
_ => return false,
|
||||
_ => return Err(Error::UnsupportedStatement),
|
||||
};
|
||||
|
||||
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
|
||||
Ok(ast) => ast,
|
||||
match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
|
||||
Ok(ast) => Ok(ast),
|
||||
Err(err) => {
|
||||
// SELECT ... FOR UPDATE won't get parsed correctly.
|
||||
debug!("{}: {}", err, query);
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
Err(Error::QueryRouterParserError(err.to_string()))
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
debug!("AST: {:?}", ast);
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
||||
debug!("Inferring role");
|
||||
|
||||
if ast.is_empty() {
|
||||
// That's weird, no idea, let's go to primary
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
return Err(Error::QueryRouterParserError("empty query".into()));
|
||||
}
|
||||
|
||||
for q in &ast {
|
||||
for q in ast {
|
||||
match q {
|
||||
// All transactions go to the primary, probably a write.
|
||||
StartTransaction { .. } => {
|
||||
@@ -418,7 +426,7 @@ impl QueryRouter {
|
||||
};
|
||||
}
|
||||
|
||||
true
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse the shard number from the Bind message
|
||||
@@ -783,6 +791,34 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add your plugins here and execute them.
|
||||
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
|
||||
if query_logger::enabled() {
|
||||
let mut query_logger = QueryLogger {};
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
}
|
||||
|
||||
if intercept::enabled() {
|
||||
let mut intercept = Intercept {};
|
||||
let result = intercept.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
return Ok(PluginOutput::Intercept(output));
|
||||
}
|
||||
}
|
||||
|
||||
if table_access::enabled() {
|
||||
let mut table_access = TableAccess {};
|
||||
let result = table_access.run(&self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
return Ok(PluginOutput::Deny(error));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
|
||||
fn set_sharding_key(&mut self, sharding_key: i64) -> Option<usize> {
|
||||
let sharder = Sharder::new(
|
||||
self.pool_settings.shards,
|
||||
@@ -810,11 +846,22 @@ impl QueryRouter {
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
None => self.pool_settings.query_parser_enabled,
|
||||
Some(value) => value,
|
||||
};
|
||||
None => {
|
||||
debug!(
|
||||
"Using pool settings, query_parser_enabled: {}",
|
||||
self.pool_settings.query_parser_enabled
|
||||
);
|
||||
self.pool_settings.query_parser_enabled
|
||||
}
|
||||
|
||||
debug!("Query parser enabled: {}", enabled);
|
||||
Some(value) => {
|
||||
debug!(
|
||||
"Using query parser override, query_parser_enabled: {}",
|
||||
value
|
||||
);
|
||||
value
|
||||
}
|
||||
};
|
||||
|
||||
enabled
|
||||
}
|
||||
@@ -862,7 +909,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
}
|
||||
@@ -881,7 +928,7 @@ mod test {
|
||||
|
||||
for query in queries {
|
||||
// It's a recognized query
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
}
|
||||
}
|
||||
@@ -893,7 +940,7 @@ mod test {
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
}
|
||||
|
||||
@@ -913,7 +960,7 @@ mod test {
|
||||
res.put(prepared_stmt);
|
||||
res.put_i16(0);
|
||||
|
||||
assert!(qr.infer(&res));
|
||||
assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
}
|
||||
|
||||
@@ -1077,11 +1124,11 @@ mod test {
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
let query = simple_query("INSERT INTO test_table VALUES (1)");
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
|
||||
let query = simple_query("SELECT * FROM test_table");
|
||||
assert!(qr.infer(&query));
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
@@ -1110,6 +1157,10 @@ mod test {
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
assert_eq!(qr.active_role, None);
|
||||
@@ -1139,15 +1190,24 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr.infer(&simple_query("BEGIN; SELECT 1; COMMIT;")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT 1; SELECT 2;")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Replica);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT 123; INSERT INTO t VALUES (5); SELECT 1;"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.role(), Role::Primary);
|
||||
}
|
||||
|
||||
@@ -1171,6 +1231,10 @@ mod test {
|
||||
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
|
||||
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
|
||||
regex_search_limit: 1000,
|
||||
auth_query: None,
|
||||
auth_query_password: None,
|
||||
auth_query_user: None,
|
||||
db: "test".to_string(),
|
||||
};
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings.clone());
|
||||
@@ -1202,47 +1266,84 @@ mod test {
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT one, two, three FROM public.data WHERE id = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM data
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM data
|
||||
INNER JOIN t2 ON data.id = 5
|
||||
AND t2.data_id = data.id
|
||||
WHERE data.id = 5"
|
||||
)));
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
|
||||
// in the query.
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 2);
|
||||
|
||||
// Super unique sharding key
|
||||
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
|
||||
assert!(qr.infer(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
)));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query(
|
||||
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
|
||||
))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
|
||||
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
|
||||
assert!(qr
|
||||
.infer(
|
||||
&QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))
|
||||
.unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(qr.shard(), 0);
|
||||
}
|
||||
|
||||
@@ -1266,11 +1367,40 @@ mod test {
|
||||
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
|
||||
qr.pool_settings.shards = 3;
|
||||
|
||||
assert!(qr.infer(&simple_query(stmt)));
|
||||
assert!(qr
|
||||
.infer(&QueryRouter::parse(&simple_query(stmt)).unwrap())
|
||||
.is_ok());
|
||||
assert_eq!(qr.placeholders.len(), 1);
|
||||
|
||||
assert!(qr.infer_shard_from_bind(&bind));
|
||||
assert_eq!(qr.shard(), 2);
|
||||
assert!(qr.placeholders.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_access_plugin() {
|
||||
use crate::config::TableAccess;
|
||||
let ta = TableAccess {
|
||||
enabled: true,
|
||||
tables: vec![String::from("pg_database")],
|
||||
};
|
||||
|
||||
crate::plugins::table_access::setup(&ta);
|
||||
|
||||
QueryRouter::setup();
|
||||
|
||||
let qr = QueryRouter::new();
|
||||
|
||||
let query = simple_query("SELECT * FROM pg_database");
|
||||
let ast = QueryRouter::parse(&query).unwrap();
|
||||
|
||||
let res = qr.execute_plugins(&ast).await;
|
||||
|
||||
assert_eq!(
|
||||
res,
|
||||
Ok(PluginOutput::Deny(
|
||||
"permission for table \"pg_database\" denied".to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
544
src/server.rs
544
src/server.rs
@@ -1,37 +1,116 @@
|
||||
/// Implementation of the PostgreSQL server (database) protocol.
|
||||
/// Here we are pretending to the a Postgres client.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::scram::ScramSha256;
|
||||
use crate::stats::Reporter;
|
||||
use crate::stats::ServerStats;
|
||||
use std::io::Write;
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
#[pin_project(project = SteamInnerProj)]
|
||||
pub enum StreamInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: TcpStream,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: TlsStream<TcpStream>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsyncWrite for StreamInner {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for StreamInner {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamInner {
|
||||
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
StreamInner::Tls { stream } => {
|
||||
let r = stream.get_mut();
|
||||
let mut w = r.1.writer();
|
||||
w.write(buf)
|
||||
}
|
||||
StreamInner::Plain { stream } => stream.try_write(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
server_id: i32,
|
||||
|
||||
/// Server host, e.g. localhost,
|
||||
/// port, e.g. 5432, and role, e.g. primary or replica.
|
||||
address: Address,
|
||||
|
||||
/// Buffered read socket.
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
|
||||
/// Unbuffered write socket (our client code buffers).
|
||||
write: OwnedWriteHalf,
|
||||
/// Server TCP connection.
|
||||
stream: BufStream<StreamInner>,
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
@@ -62,7 +141,7 @@ pub struct Server {
|
||||
connected_at: chrono::naive::NaiveDateTime,
|
||||
|
||||
/// Reports various metrics, e.g. data sent & received.
|
||||
stats: Reporter,
|
||||
stats: Arc<ServerStats>,
|
||||
|
||||
/// Application name using the server at the moment.
|
||||
application_name: String,
|
||||
@@ -71,19 +150,40 @@ pub struct Server {
|
||||
last_activity: SystemTime,
|
||||
|
||||
mirror_manager: Option<MirroringManager>,
|
||||
|
||||
// Associated addresses used
|
||||
addr_set: Option<AddrSet>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
/// Pretend to be the Postgres client and connect to the server given host, port and credentials.
|
||||
/// Perform the authentication and return the server in a ready for query state.
|
||||
pub async fn startup(
|
||||
server_id: i32,
|
||||
address: &Address,
|
||||
user: &User,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Reporter,
|
||||
stats: Arc<ServerStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
) -> Result<Server, Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let mut addr_set: Option<AddrSet> = None;
|
||||
|
||||
// If we are caching addresses and hostname is not an IP
|
||||
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
|
||||
debug!("Resolving {}", &address.host);
|
||||
addr_set = match cached_resolver.lookup_ip(&address.host).await {
|
||||
Ok(ok) => {
|
||||
debug!("Obtained: {:?}", ok);
|
||||
Some(ok)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut stream =
|
||||
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -95,30 +195,137 @@ impl Server {
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// TCP timeouts.
|
||||
configure_socket(&stream);
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let mut stream = if config.general.server_tls {
|
||||
// Request a TLS connection
|
||||
ssl_request(&mut stream).await?;
|
||||
|
||||
let response = match stream.read_u8().await {
|
||||
Ok(response) => response as char,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Server socket error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match response {
|
||||
// Server supports TLS
|
||||
'S' => {
|
||||
debug!("Connecting to server using TLS");
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
let mut tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
// Equivalent to sslmode=prefer which is fine most places.
|
||||
// If you want verify-full, change `verify_server_certificate` to true.
|
||||
if !config.general.verify_server_certificate {
|
||||
let mut dangerous = tls_config.dangerous();
|
||||
dangerous.set_certificate_verifier(Arc::new(
|
||||
crate::tls::NoCertificateVerification {},
|
||||
));
|
||||
}
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(tls_config));
|
||||
let stream = match connector
|
||||
.connect(address.host.as_str().try_into().unwrap(), stream)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
|
||||
}
|
||||
};
|
||||
|
||||
StreamInner::Tls { stream }
|
||||
}
|
||||
|
||||
// Server does not support TLS
|
||||
'N' => StreamInner::Plain { stream },
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StreamInner::Plain { stream }
|
||||
};
|
||||
|
||||
// let (read, write) = split(stream);
|
||||
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
// StartupMessage
|
||||
startup(&mut stream, &user.username, database).await?;
|
||||
let username = match user.server_username {
|
||||
Some(ref server_username) => server_username,
|
||||
None => &user.username,
|
||||
};
|
||||
|
||||
let password = match user.server_password {
|
||||
Some(ref server_password) => Some(server_password),
|
||||
None => match user.password {
|
||||
Some(ref password) => Some(password),
|
||||
None => None,
|
||||
},
|
||||
};
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
|
||||
// We'll be handling multiple packets, but they will all be structured the same.
|
||||
// We'll loop here until this exchange is complete.
|
||||
let mut scram = ScramSha256::new(&user.password);
|
||||
let mut scram: Option<ScramSha256> = match password {
|
||||
Some(password) => Some(ScramSha256::new(password)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"message code".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"message len".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Message: {}", code);
|
||||
@@ -129,7 +336,12 @@ impl Server {
|
||||
// Determine which kind of authentication is required, if any.
|
||||
let auth_code = match stream.read_i32().await {
|
||||
Ok(auth_code) => auth_code,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"auth code".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Auth: {}", auth_code);
|
||||
@@ -142,32 +354,79 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut salt).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"salt".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
md5_password(&mut stream, &user.username, &user.password, &salt[..])
|
||||
.await?;
|
||||
match password {
|
||||
// Using plaintext password
|
||||
Some(password) => {
|
||||
md5_password(&mut stream, username, password, &salt[..]).await?
|
||||
}
|
||||
|
||||
// Using auth passthrough, in this case we should already have a
|
||||
// hash obtained when the pool was validated. If we reach this point
|
||||
// and don't have a hash, we return an error.
|
||||
None => {
|
||||
let option_hash = (*auth_hash.read()).clone();
|
||||
match option_hash {
|
||||
Some(hash) =>
|
||||
md5_password_with_hash(
|
||||
&mut stream,
|
||||
&hash,
|
||||
&salt[..],
|
||||
)
|
||||
.await?,
|
||||
None => return Err(
|
||||
Error::ServerAuthError(
|
||||
"Auth passthrough (auth_query) failed and no user password is set in cleartext".into(),
|
||||
server_identifier
|
||||
)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AUTHENTICATION_SUCCESSFUL => (),
|
||||
|
||||
SASL => {
|
||||
if scram.is_none() {
|
||||
return Err(Error::ServerAuthError(
|
||||
"SASL auth required and no password specified. \
|
||||
Auth passthrough (auth_query) method is currently \
|
||||
unsupported for SASL auth"
|
||||
.into(),
|
||||
server_identifier,
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Starting SASL authentication");
|
||||
|
||||
let sasl_len = (len - 8) as usize;
|
||||
let mut sasl_auth = vec![0u8; sasl_len];
|
||||
|
||||
match stream.read_exact(&mut sasl_auth).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
if sasl_type.contains(SCRAM_SHA_256) {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Generate client message.
|
||||
let sasl_response = scram.message();
|
||||
let sasl_response = scram.as_mut().unwrap().message();
|
||||
|
||||
// SASLInitialResponse (F)
|
||||
let mut res = BytesMut::new();
|
||||
@@ -186,7 +445,7 @@ impl Server {
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
} else {
|
||||
error!("Unsupported SCRAM version: {}", sasl_type);
|
||||
return Err(Error::ServerError);
|
||||
@@ -200,11 +459,16 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut sasl_data).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl cont message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let msg = BytesMut::from(&sasl_data[..]);
|
||||
let sasl_response = scram.update(&msg)?;
|
||||
let sasl_response = scram.as_mut().unwrap().update(&msg)?;
|
||||
|
||||
// SASLResponse
|
||||
let mut res = BytesMut::new();
|
||||
@@ -212,7 +476,7 @@ impl Server {
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
}
|
||||
|
||||
SASL_FINAL => {
|
||||
@@ -221,10 +485,19 @@ impl Server {
|
||||
let mut sasl_final = vec![0u8; len as usize - 8];
|
||||
match stream.read_exact(&mut sasl_final).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl final message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match scram.finish(&BytesMut::from(&sasl_final[..])) {
|
||||
match scram
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.finish(&BytesMut::from(&sasl_final[..]))
|
||||
{
|
||||
Ok(_) => {
|
||||
debug!("SASL authentication successful");
|
||||
}
|
||||
@@ -247,7 +520,12 @@ impl Server {
|
||||
'E' => {
|
||||
let error_code = match stream.read_u8().await {
|
||||
Ok(error_code) => error_code,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"error code message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Error: {}", error_code);
|
||||
@@ -263,7 +541,12 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"error message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: the error message contains multiple fields; we can decode them and
|
||||
@@ -282,7 +565,12 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"parameter status message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
@@ -299,12 +587,22 @@ impl Server {
|
||||
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
|
||||
process_id = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"process id message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
secret_key = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"secret key message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -314,18 +612,19 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut idle).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"transaction status message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
server_id,
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
@@ -333,6 +632,7 @@ impl Server {
|
||||
bad: false,
|
||||
needs_cleanup: false,
|
||||
client_server_map,
|
||||
addr_set,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats,
|
||||
application_name: String::new(),
|
||||
@@ -377,7 +677,7 @@ impl Server {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
error!("Could not connect to server: {}", err);
|
||||
return Err(Error::SocketError(format!("Error reading cancel message")));
|
||||
return Err(Error::SocketError("Error reading cancel message".into()));
|
||||
}
|
||||
};
|
||||
configure_socket(&stream);
|
||||
@@ -390,15 +690,15 @@ impl Server {
|
||||
bytes.put_i32(process_id);
|
||||
bytes.put_i32(secret_key);
|
||||
|
||||
write_all(&mut stream, bytes).await
|
||||
write_all_flush(&mut stream, &bytes).await
|
||||
}
|
||||
|
||||
/// Send messages to the server from the client.
|
||||
pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> {
|
||||
self.mirror_send(messages);
|
||||
self.stats.data_sent(messages.len(), self.server_id);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_half(&mut self.write, messages).await {
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -417,7 +717,7 @@ impl Server {
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.read).await {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
@@ -545,7 +845,7 @@ impl Server {
|
||||
let bytes = self.buffer.clone();
|
||||
|
||||
// Keep track of how much data we got from the server for stats.
|
||||
self.stats.data_received(bytes.len(), self.server_id);
|
||||
self.stats().data_received(bytes.len());
|
||||
|
||||
// Clear the buffer for next query.
|
||||
self.buffer.clear();
|
||||
@@ -573,7 +873,23 @@ impl Server {
|
||||
/// Server & client are out of sync, we must discard this connection.
|
||||
/// This happens with clients that misbehave.
|
||||
pub fn is_bad(&self) -> bool {
|
||||
self.bad
|
||||
if self.bad {
|
||||
return self.bad;
|
||||
};
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
if cached_resolver.enabled() {
|
||||
if let Some(addr_set) = &self.addr_set {
|
||||
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
|
||||
warn!(
|
||||
"DNS changed for {}, it was {:?}. Dropping server connection.",
|
||||
self.address.host.as_str(),
|
||||
addr_set
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Get server startup information to forward it to the client.
|
||||
@@ -665,18 +981,17 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
/// get Server stats
|
||||
pub fn stats(&self) -> Arc<ServerStats> {
|
||||
self.stats.clone()
|
||||
}
|
||||
|
||||
/// Get the servers address.
|
||||
#[allow(dead_code)]
|
||||
pub fn address(&self) -> Address {
|
||||
self.address.clone()
|
||||
}
|
||||
|
||||
/// Get the server connection identifier
|
||||
/// Used to uniquely identify connection in statistics
|
||||
pub fn server_id(&self) -> i32 {
|
||||
self.server_id
|
||||
}
|
||||
|
||||
// Get server's latest response timestamp
|
||||
pub fn last_activity(&self) -> SystemTime {
|
||||
self.last_activity
|
||||
@@ -700,6 +1015,105 @@ impl Server {
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
|
||||
// This is so we can execute out of band queries to the server.
|
||||
// The connection will be opened, the query executed and closed.
|
||||
pub async fn exec_simple_query(
|
||||
address: &Address,
|
||||
user: &User,
|
||||
query: &str,
|
||||
) -> Result<Vec<String>, Error> {
|
||||
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
debug!("Connecting to server to obtain auth hashes.");
|
||||
let mut server = Server::startup(
|
||||
address,
|
||||
user,
|
||||
&address.database,
|
||||
client_server_map,
|
||||
Arc::new(ServerStats::default()),
|
||||
Arc::new(RwLock::new(None)),
|
||||
)
|
||||
.await?;
|
||||
debug!("Connected!, sending query.");
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv().await?;
|
||||
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> {
|
||||
let mut pair = Vec::<String>::new();
|
||||
match message::backend::Message::parse(message) {
|
||||
Ok(Some(message::backend::Message::RowDescription(_description))) => {}
|
||||
Ok(Some(message::backend::Message::ErrorResponse(err))) => {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Protocol error parsing response. Err: {:?}",
|
||||
err.fields()
|
||||
.iterator()
|
||||
.fold(String::default(), |acc, element| acc
|
||||
+ element.unwrap().value())
|
||||
)))
|
||||
}
|
||||
Ok(_) => {
|
||||
return Err(Error::ProtocolSyncError(
|
||||
"Protocol error, expected Row Description.".to_string(),
|
||||
))
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Protocol error parsing response. Err: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
while !message.is_empty() {
|
||||
match message::backend::Message::parse(message) {
|
||||
Ok(postgres_message) => {
|
||||
match postgres_message {
|
||||
Some(message::backend::Message::DataRow(data)) => {
|
||||
let buf = data.buffer();
|
||||
trace!("Data: {:?}", buf);
|
||||
|
||||
for item in data.ranges().iterator() {
|
||||
match item.as_ref() {
|
||||
Ok(range) => match range {
|
||||
Some(range) => {
|
||||
pair.push(String::from_utf8_lossy(&buf[range.clone()]).to_string());
|
||||
}
|
||||
None => return Err(Error::ProtocolSyncError(String::from(
|
||||
"Data expected while receiving query auth data, found nothing.",
|
||||
))),
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Data error, err: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(message::backend::Message::CommandComplete(_)) => {}
|
||||
Some(message::backend::Message::ReadyForQuery(_)) => {}
|
||||
_ => {
|
||||
return Err(Error::ProtocolSyncError(
|
||||
"Unexpected message while receiving auth query data.".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Parse error, err: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
}
|
||||
Ok(pair)
|
||||
}
|
||||
|
||||
impl Drop for Server {
|
||||
@@ -708,15 +1122,17 @@ impl Drop for Server {
|
||||
/// for a write.
|
||||
fn drop(&mut self) {
|
||||
self.mirror_disconnect();
|
||||
self.stats.server_disconnecting(self.server_id);
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(4);
|
||||
// Update statistics
|
||||
self.stats.disconnect();
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
bytes.put_u8(b'X');
|
||||
bytes.put_i32(4);
|
||||
|
||||
match self.write.try_write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(_) => debug!("Dirty shutdown"),
|
||||
match self.stream.get_mut().try_write(&bytes) {
|
||||
Ok(5) => (),
|
||||
_ => debug!("Dirty shutdown"),
|
||||
};
|
||||
|
||||
// Should not matter.
|
||||
|
||||
1033
src/stats.rs
1033
src/stats.rs
File diff suppressed because it is too large
Load Diff
149
src/stats/address.rs
Normal file
149
src/stats/address.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
use log::warn;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Internal address stats
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AddressStats {
|
||||
pub total_xact_count: Arc<AtomicU64>,
|
||||
pub total_query_count: Arc<AtomicU64>,
|
||||
pub total_received: Arc<AtomicU64>,
|
||||
pub total_sent: Arc<AtomicU64>,
|
||||
pub total_xact_time: Arc<AtomicU64>,
|
||||
pub total_query_time: Arc<AtomicU64>,
|
||||
pub total_wait_time: Arc<AtomicU64>,
|
||||
pub total_errors: Arc<AtomicU64>,
|
||||
pub avg_query_count: Arc<AtomicU64>,
|
||||
pub avg_query_time: Arc<AtomicU64>,
|
||||
pub avg_recv: Arc<AtomicU64>,
|
||||
pub avg_sent: Arc<AtomicU64>,
|
||||
pub avg_errors: Arc<AtomicU64>,
|
||||
pub avg_xact_time: Arc<AtomicU64>,
|
||||
pub avg_xact_count: Arc<AtomicU64>,
|
||||
pub avg_wait_time: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl IntoIterator for AddressStats {
|
||||
type Item = (String, u64);
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
vec![
|
||||
(
|
||||
"total_xact_count".to_string(),
|
||||
self.total_xact_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_query_count".to_string(),
|
||||
self.total_query_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_received".to_string(),
|
||||
self.total_received.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_sent".to_string(),
|
||||
self.total_sent.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_xact_time".to_string(),
|
||||
self.total_xact_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_query_time".to_string(),
|
||||
self.total_query_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_wait_time".to_string(),
|
||||
self.total_wait_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"total_errors".to_string(),
|
||||
self.total_errors.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_xact_count".to_string(),
|
||||
self.avg_xact_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_query_count".to_string(),
|
||||
self.avg_query_count.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_recv".to_string(),
|
||||
self.avg_recv.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_sent".to_string(),
|
||||
self.avg_sent.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_errors".to_string(),
|
||||
self.avg_errors.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_xact_time".to_string(),
|
||||
self.avg_xact_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_query_time".to_string(),
|
||||
self.avg_query_time.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"avg_wait_time".to_string(),
|
||||
self.avg_wait_time.load(Ordering::Relaxed),
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressStats {
|
||||
pub fn error(&self) {
|
||||
self.total_errors.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn update_averages(&self) {
|
||||
let (totals, averages) = self.fields_iterators();
|
||||
for data in totals.iter().zip(averages.iter()) {
|
||||
let (total, average) = data;
|
||||
if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
|
||||
let total = total.load(Ordering::Relaxed);
|
||||
let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
|
||||
Some(avg)
|
||||
}) {
|
||||
warn!("Could not update averages for addresses stats, {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn populate_row(&self, row: &mut Vec<String>) {
|
||||
for (_key, value) in self.clone() {
|
||||
row.push(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
|
||||
let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
|
||||
let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
|
||||
|
||||
totals.push(self.total_xact_count.clone());
|
||||
averages.push(self.avg_xact_count.clone());
|
||||
totals.push(self.total_query_count.clone());
|
||||
averages.push(self.avg_query_count.clone());
|
||||
totals.push(self.total_received.clone());
|
||||
averages.push(self.avg_recv.clone());
|
||||
totals.push(self.total_sent.clone());
|
||||
averages.push(self.avg_sent.clone());
|
||||
totals.push(self.total_xact_time.clone());
|
||||
averages.push(self.avg_xact_time.clone());
|
||||
totals.push(self.total_query_time.clone());
|
||||
averages.push(self.avg_query_time.clone());
|
||||
totals.push(self.total_wait_time.clone());
|
||||
averages.push(self.avg_wait_time.clone());
|
||||
totals.push(self.total_errors.clone());
|
||||
averages.push(self.avg_errors.clone());
|
||||
|
||||
(totals, averages)
|
||||
}
|
||||
}
|
||||
182
src/stats/client.rs
Normal file
182
src/stats/client.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use atomic_enum::atomic_enum;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
/// The various states that a client can be in
|
||||
#[atomic_enum]
|
||||
#[derive(PartialEq)]
|
||||
pub enum ClientState {
|
||||
Idle = 0,
|
||||
Waiting,
|
||||
Active,
|
||||
}
|
||||
impl std::fmt::Display for ClientState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match *self {
|
||||
ClientState::Idle => write!(f, "idle"),
|
||||
ClientState::Waiting => write!(f, "waiting"),
|
||||
ClientState::Active => write!(f, "active"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Information we keep track of which can be queried by SHOW CLIENTS
|
||||
pub struct ClientStats {
|
||||
/// A random integer assigned to the client and used by stats to track the client
|
||||
client_id: i32,
|
||||
|
||||
/// Data associated with the client, not writable, only set when we construct the ClientStat
|
||||
application_name: String,
|
||||
username: String,
|
||||
pool_name: String,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Total time spent waiting for a connection from pool, measures in microseconds
|
||||
pub total_wait_time: Arc<AtomicU64>,
|
||||
|
||||
/// Current state of the client
|
||||
pub state: Arc<AtomicClientState>,
|
||||
|
||||
/// Number of transactions executed by this client
|
||||
pub transaction_count: Arc<AtomicU64>,
|
||||
|
||||
/// Number of queries executed by this client
|
||||
pub query_count: Arc<AtomicU64>,
|
||||
|
||||
/// Number of errors made by this client
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for ClientStats {
|
||||
fn default() -> Self {
|
||||
ClientStats {
|
||||
client_id: 0,
|
||||
connect_time: Instant::now(),
|
||||
application_name: String::new(),
|
||||
username: String::new(),
|
||||
pool_name: String::new(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
total_wait_time: Arc::new(AtomicU64::new(0)),
|
||||
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
|
||||
transaction_count: Arc::new(AtomicU64::new(0)),
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
reporter: get_reporter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientStats {
|
||||
pub fn new(
|
||||
client_id: i32,
|
||||
application_name: &str,
|
||||
username: &str,
|
||||
pool_name: &str,
|
||||
connect_time: Instant,
|
||||
pool_stats: Arc<PoolStats>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
application_name: application_name.to_string(),
|
||||
username: username.to_string(),
|
||||
pool_name: pool_name.to_string(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports a client is disconnecting from the pooler and
|
||||
/// update metrics on the corresponding pool.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.client_disconnecting(self.client_id);
|
||||
self.pool_stats
|
||||
.client_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Register a client with the stats system. The stats system uses client_id
|
||||
/// to track and aggregate statistics from all source that relate to that client
|
||||
pub fn register(&self, stats: Arc<ClientStats>) {
|
||||
self.reporter.client_register(self.client_id, stats);
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
self.pool_stats.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done querying the server and is no longer assigned a server connection
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.client_idle(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is waiting for a connection
|
||||
pub fn waiting(&self) {
|
||||
self.pool_stats
|
||||
.client_waiting(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Waiting, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done waiting for a connection and is about to query the server.
|
||||
pub fn active(&self) {
|
||||
self.pool_stats
|
||||
.client_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Active, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client has failed to obtain a connection from a connection pool
|
||||
pub fn checkout_error(&self) {
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client has had the server assigned to it be banned
|
||||
pub fn ban_error(&self) {
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
self.error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reporters the time spent by a client waiting to get a healthy connection from the pool
|
||||
pub fn checkout_time(&self, microseconds: u64) {
|
||||
self.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
pub fn query(&self) {
|
||||
self.query_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a transaction executed by a client a server
|
||||
/// we report each individual queries outside a transaction as a transaction
|
||||
/// We only count the initial BEGIN as a transaction, all queries within do not
|
||||
/// count as transactions
|
||||
pub fn transaction(&self) {
|
||||
self.transaction_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
// Helper methods for show clients
|
||||
pub fn connect_time(&self) -> Instant {
|
||||
self.connect_time
|
||||
}
|
||||
|
||||
pub fn client_id(&self) -> i32 {
|
||||
self.client_id
|
||||
}
|
||||
|
||||
pub fn application_name(&self) -> String {
|
||||
self.application_name.clone()
|
||||
}
|
||||
|
||||
pub fn username(&self) -> String {
|
||||
self.username.clone()
|
||||
}
|
||||
|
||||
pub fn pool_name(&self) -> String {
|
||||
self.pool_name.clone()
|
||||
}
|
||||
}
|
||||
274
src/stats/pool.rs
Normal file
274
src/stats/pool.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use crate::config::Pool;
|
||||
use crate::config::PoolMode;
|
||||
use crate::pool::PoolIdentifier;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::get_reporter;
|
||||
use super::Reporter;
|
||||
use super::{ClientState, ServerState};
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
/// A struct that holds information about a Pool .
|
||||
pub struct PoolStats {
|
||||
// Pool identifier, cannot be changed after creating the instance
|
||||
identifier: PoolIdentifier,
|
||||
|
||||
// Pool Config, cannot be changed after creating the instance
|
||||
config: Pool,
|
||||
|
||||
// A reference to the global reporter.
|
||||
reporter: Reporter,
|
||||
|
||||
/// Counters (atomics)
|
||||
pub cl_idle: Arc<AtomicU64>,
|
||||
pub cl_active: Arc<AtomicU64>,
|
||||
pub cl_waiting: Arc<AtomicU64>,
|
||||
pub cl_cancel_req: Arc<AtomicU64>,
|
||||
pub sv_active: Arc<AtomicU64>,
|
||||
pub sv_idle: Arc<AtomicU64>,
|
||||
pub sv_used: Arc<AtomicU64>,
|
||||
pub sv_tested: Arc<AtomicU64>,
|
||||
pub sv_login: Arc<AtomicU64>,
|
||||
pub maxwait: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl IntoIterator for PoolStats {
|
||||
type Item = (String, u64);
|
||||
type IntoIter = std::vec::IntoIter<Self::Item>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
vec![
|
||||
("cl_idle".to_string(), self.cl_idle.load(Ordering::Relaxed)),
|
||||
(
|
||||
"cl_active".to_string(),
|
||||
self.cl_active.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_waiting".to_string(),
|
||||
self.cl_waiting.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_cancel_req".to_string(),
|
||||
self.cl_cancel_req.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_active".to_string(),
|
||||
self.sv_active.load(Ordering::Relaxed),
|
||||
),
|
||||
("sv_idle".to_string(), self.sv_idle.load(Ordering::Relaxed)),
|
||||
("sv_used".to_string(), self.sv_used.load(Ordering::Relaxed)),
|
||||
(
|
||||
"sv_tested".to_string(),
|
||||
self.sv_tested.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_login".to_string(),
|
||||
self.sv_login.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"maxwait".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) / 1_000_000,
|
||||
),
|
||||
(
|
||||
"maxwait_us".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) % 1_000_000,
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl PoolStats {
|
||||
pub fn new(identifier: PoolIdentifier, config: Pool) -> Self {
|
||||
Self {
|
||||
identifier,
|
||||
config,
|
||||
reporter: get_reporter(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// Getters
|
||||
pub fn register(&self, stats: Arc<PoolStats>) {
|
||||
self.reporter.pool_register(self.identifier.clone(), stats);
|
||||
}
|
||||
|
||||
pub fn database(&self) -> String {
|
||||
self.identifier.db.clone()
|
||||
}
|
||||
|
||||
pub fn user(&self) -> String {
|
||||
self.identifier.user.clone()
|
||||
}
|
||||
|
||||
pub fn pool_mode(&self) -> PoolMode {
|
||||
self.config.pool_mode
|
||||
}
|
||||
|
||||
/// Populates an array of strings with counters (used by admin in show pools)
|
||||
pub fn populate_row(&self, row: &mut Vec<String>) {
|
||||
for (_key, value) in self.clone() {
|
||||
row.push(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes the maxwait counter, this is done everytime we obtain metrics
|
||||
pub fn clear_maxwait(&self) {
|
||||
self.maxwait.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool enters login state.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_login(&self, from: ServerState) {
|
||||
self.sv_login.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Login {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_active(&self, from: ServerState) {
|
||||
self.sv_active.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Active {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'tested'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_tested(&self, from: ServerState) {
|
||||
self.sv_tested.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Tested {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_idle(&self, from: ServerState) {
|
||||
self.sv_idle.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Idle {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'waiting'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_waiting(&self, from: ClientState) {
|
||||
if from != ClientState::Waiting {
|
||||
self.cl_waiting.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_active(&self, from: ClientState) {
|
||||
if from != ClientState::Active {
|
||||
self.cl_active.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_idle(&self, from: ClientState) {
|
||||
if from != ClientState::Idle {
|
||||
self.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_disconnect(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
ClientState::Active => &self.cl_active,
|
||||
};
|
||||
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
/// Notified when a server disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn server_disconnect(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
// helpers for counter decrease
|
||||
fn decrease_from_server_state(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_from_client_state(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Active => &self.cl_active,
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_counter(value: Arc<AtomicU64>) {
|
||||
if value.load(Ordering::Relaxed) > 0 {
|
||||
value.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decrease() {
|
||||
let stat: PoolStats = PoolStats::default();
|
||||
stat.server_login(ServerState::Login);
|
||||
stat.server_idle(ServerState::Login);
|
||||
assert_eq!(stat.sv_login.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(stat.sv_idle.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
}
|
||||
225
src/stats/server.rs
Normal file
225
src/stats/server.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
use super::AddressStats;
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use crate::config::Address;
|
||||
use atomic_enum::atomic_enum;
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// The various states that a server can be in
|
||||
#[atomic_enum]
|
||||
#[derive(PartialEq)]
|
||||
pub enum ServerState {
|
||||
Login = 0,
|
||||
Active,
|
||||
Tested,
|
||||
Idle,
|
||||
}
|
||||
impl std::fmt::Display for ServerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match *self {
|
||||
ServerState::Login => write!(f, "login"),
|
||||
ServerState::Active => write!(f, "active"),
|
||||
ServerState::Tested => write!(f, "tested"),
|
||||
ServerState::Idle => write!(f, "idle"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information we keep track of which can be queried by SHOW SERVERS
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerStats {
|
||||
/// A random integer assigned to the server and used by stats to track the server
|
||||
server_id: i32,
|
||||
|
||||
/// Context information, only to be read
|
||||
address: Address,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Data
|
||||
pub application_name: Arc<RwLock<String>>,
|
||||
pub state: Arc<AtomicServerState>,
|
||||
pub bytes_sent: Arc<AtomicU64>,
|
||||
pub bytes_received: Arc<AtomicU64>,
|
||||
pub transaction_count: Arc<AtomicU64>,
|
||||
pub query_count: Arc<AtomicU64>,
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for ServerStats {
|
||||
fn default() -> Self {
|
||||
ServerStats {
|
||||
server_id: 0,
|
||||
application_name: Arc::new(RwLock::new(String::new())),
|
||||
address: Address::default(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
connect_time: Instant::now(),
|
||||
state: Arc::new(AtomicServerState::new(ServerState::Login)),
|
||||
bytes_sent: Arc::new(AtomicU64::new(0)),
|
||||
bytes_received: Arc::new(AtomicU64::new(0)),
|
||||
transaction_count: Arc::new(AtomicU64::new(0)),
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
reporter: get_reporter(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerStats {
|
||||
pub fn new(address: Address, pool_stats: Arc<PoolStats>, connect_time: Instant) -> Self {
|
||||
Self {
|
||||
address,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
server_id: rand::random::<i32>(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn server_id(&self) -> i32 {
|
||||
self.server_id
|
||||
}
|
||||
|
||||
/// Register a server connection with the stats system. The stats system uses server_id
|
||||
/// to track and aggregate statistics from all source that relate to that server
|
||||
// Delegates to reporter
|
||||
pub fn register(&self, stats: Arc<ServerStats>) {
|
||||
self.reporter.server_register(self.server_id, stats);
|
||||
self.login();
|
||||
}
|
||||
|
||||
/// Reports a server connection is no longer assigned to a client
|
||||
/// and is available for the next client to pick it up
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.server_idle(self.state.load(Ordering::Relaxed));
|
||||
|
||||
self.state.store(ServerState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a server connection is disconnecting from the pooler.
|
||||
/// Also updates metrics on the pool regarding server usage.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.server_disconnecting(self.server_id);
|
||||
self.pool_stats
|
||||
.server_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Reports a server connection is being tested before being given to a client.
|
||||
pub fn tested(&self) {
|
||||
self.set_undefined_application();
|
||||
self.pool_stats
|
||||
.server_tested(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Tested, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a server connection is attempting to login.
|
||||
pub fn login(&self) {
|
||||
self.pool_stats
|
||||
.server_login(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Login, Ordering::Relaxed);
|
||||
self.set_undefined_application();
|
||||
}
|
||||
|
||||
/// Reports a server connection has been assigned to a client that
|
||||
/// is about to query the server
|
||||
pub fn active(&self, application_name: String) {
|
||||
self.pool_stats
|
||||
.server_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Active, Ordering::Relaxed);
|
||||
self.set_application(application_name);
|
||||
}
|
||||
|
||||
pub fn address_stats(&self) -> Arc<AddressStats> {
|
||||
self.address.stats.clone()
|
||||
}
|
||||
|
||||
// Helper methods for show_servers
|
||||
pub fn pool_name(&self) -> String {
|
||||
self.pool_stats.database()
|
||||
}
|
||||
|
||||
pub fn username(&self) -> String {
|
||||
self.pool_stats.user()
|
||||
}
|
||||
|
||||
pub fn address_name(&self) -> String {
|
||||
self.address.name()
|
||||
}
|
||||
|
||||
pub fn connect_time(&self) -> Instant {
|
||||
self.connect_time
|
||||
}
|
||||
|
||||
fn set_application(&self, name: String) {
|
||||
let mut application_name = self.application_name.write();
|
||||
*application_name = name;
|
||||
}
|
||||
|
||||
fn set_undefined_application(&self) {
|
||||
self.set_application(String::from("Undefined"))
|
||||
}
|
||||
|
||||
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
|
||||
// Update server stats and address aggergation stats
|
||||
self.set_application(application_name);
|
||||
self.address
|
||||
.stats
|
||||
.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
self.pool_stats
|
||||
.maxwait
|
||||
.fetch_max(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
pub fn query(&self, milliseconds: u64, application_name: &str) {
|
||||
self.set_application(application_name.to_string());
|
||||
let address_stats = self.address_stats();
|
||||
address_stats
|
||||
.total_query_count
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
address_stats
|
||||
.total_query_time
|
||||
.fetch_add(milliseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a transaction executed by a client a server
|
||||
/// we report each individual queries outside a transaction as a transaction
|
||||
/// We only count the initial BEGIN as a transaction, all queries within do not
|
||||
/// count as transactions
|
||||
pub fn transaction(&self, application_name: &str) {
|
||||
self.set_application(application_name.to_string());
|
||||
|
||||
self.transaction_count.fetch_add(1, Ordering::Relaxed);
|
||||
self.address
|
||||
.stats
|
||||
.total_xact_count
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report data sent to a server
|
||||
pub fn data_sent(&self, amount_bytes: usize) {
|
||||
self.bytes_sent
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
self.address
|
||||
.stats
|
||||
.total_sent
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report data received from a server
|
||||
pub fn data_received(&self, amount_bytes: usize) {
|
||||
self.bytes_received
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
self.address
|
||||
.stats
|
||||
.total_received
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
23
src/tls.rs
23
src/tls.rs
@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
|
||||
use std::iter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::rustls::{
|
||||
self,
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate, PrivateKey, ServerName,
|
||||
};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
@@ -64,3 +69,19 @@ impl Tls {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoCertificateVerification;
|
||||
|
||||
impl ServerCertVerifier for NoCertificateVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,15 @@ services:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256
|
||||
command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
|
||||
pg5:
|
||||
image: postgres:14
|
||||
network_mode: "service:main"
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
|
||||
command: ["postgres", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-p", "10432"]
|
||||
main:
|
||||
build: .
|
||||
command: ["bash", "/app/tests/docker/run.sh"]
|
||||
|
||||
@@ -37,9 +37,9 @@ describe "Admin" do
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_passsword_url = URI(pgcat_conn_str)
|
||||
bad_passsword_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_passsword_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
@@ -71,15 +71,17 @@ describe "Admin" do
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
@@ -87,7 +89,7 @@ describe "Admin" do
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
@@ -176,6 +178,47 @@ describe "Admin" do
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect normally" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_between = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).not_to eq(clients_between)
|
||||
connections.each(&:close)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect abruptly" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
threads = []
|
||||
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
random_string = (0...8).map { (65 + rand(26)).chr }.join
|
||||
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
|
||||
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
|
||||
sleep(1)
|
||||
# psql starts two processes, we only know the pid of the parent, this
|
||||
# ensure both are killed
|
||||
`pkill -9 -f '#{random_string}'`
|
||||
Process.wait(faulty_client)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients overwhelm server pools" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
@@ -199,7 +242,7 @@ describe "Admin" do
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("4")
|
||||
|
||||
215
tests/ruby/auth_query_spec.rb
Normal file
215
tests/ruby/auth_query_spec.rb
Normal file
@@ -0,0 +1,215 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
require_relative 'spec_helper'
|
||||
require_relative 'helpers/auth_query_helper'
|
||||
|
||||
describe "Auth Query" do
|
||||
let(:configured_instances) {[5432, 10432]}
|
||||
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
|
||||
let(:pg_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
|
||||
let(:processes) { Helpers::AuthQuery.single_shard_auth_query(pool_name: "sharded_db", pg_user: pg_user, config_user: config_user, extra_conf: config, wait_until_ready: wait_until_ready ) }
|
||||
let(:config) { {} }
|
||||
let(:wait_until_ready) { true }
|
||||
|
||||
after do
|
||||
unless @failing_process
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
@failing_process = false
|
||||
end
|
||||
|
||||
context "when auth_query is not configured" do
|
||||
context 'and cleartext passwords are set' do
|
||||
it "uses local passwords" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", config_user['username'], config_user['password']))
|
||||
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
context 'and cleartext passwords are not set' do
|
||||
let(:config_user) { { 'username' => 'sharding_user' } }
|
||||
|
||||
it "does not start because it is not possible to authenticate" do
|
||||
@failing_process = true
|
||||
expect { processes.pgcat }.to raise_error(StandardError, /You have to specify a user password for every pool if auth_query is not specified/)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context 'when auth_query is configured' do
|
||||
context 'with global configuration' do
|
||||
around(:example) do |example|
|
||||
|
||||
# Set up auth query
|
||||
Helpers::AuthQuery.set_up_auth_query_for_user(
|
||||
user: 'md5_auth_user',
|
||||
password: 'secret'
|
||||
);
|
||||
|
||||
example.run
|
||||
|
||||
# Drop auth query support
|
||||
Helpers::AuthQuery.tear_down_auth_query_for_user(
|
||||
user: 'md5_auth_user',
|
||||
password: 'secret'
|
||||
);
|
||||
end
|
||||
|
||||
context 'with correct global parameters' do
|
||||
let(:config) { { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } }
|
||||
context 'and with cleartext passwords set' do
|
||||
it 'it uses local passwords' do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
|
||||
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
context 'and with cleartext passwords not set' do
|
||||
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } }
|
||||
|
||||
it 'it uses obtained passwords' do
|
||||
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
|
||||
conn = PG.connect(connection_string)
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
|
||||
it 'allows passwords to be changed without closing existing connections' do
|
||||
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
|
||||
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
|
||||
expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil
|
||||
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
|
||||
end
|
||||
|
||||
it 'allows passwords to be changed and that new password is needed when reconnecting' do
|
||||
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username']))
|
||||
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
|
||||
newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2'))
|
||||
expect(newconn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context 'with wrong parameters' do
|
||||
let(:config) { { 'general' => { 'auth_query' => 'SELECT 1', 'auth_query_user' => 'wrong_user', 'auth_query_password' => 'wrong' } } }
|
||||
|
||||
context 'and with clear text passwords set' do
|
||||
it "it uses local passwords" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
|
||||
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
context 'and with cleartext passwords not set' do
|
||||
let(:config_user) { { 'username' => 'sharding_user' } }
|
||||
it "it fails to start as it cannot authenticate against servers" do
|
||||
@failing_process = true
|
||||
expect { PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) }.to raise_error(StandardError, /Error trying to obtain password from auth_query/ )
|
||||
end
|
||||
|
||||
context 'and we fix the issue and reload' do
|
||||
let(:wait_until_ready) { false }
|
||||
|
||||
it 'fails in the beginning but starts working after reloading config' do
|
||||
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
|
||||
while !(processes.pgcat.logs =~ /Waiting for clients/) do
|
||||
sleep 0.5
|
||||
end
|
||||
|
||||
expect { PG.connect(connection_string)}.to raise_error(PG::ConnectionBad)
|
||||
expect(processes.pgcat.logs).to match(/Error trying to obtain password from auth_query/)
|
||||
|
||||
current_config = processes.pgcat.current_config
|
||||
config = { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } }
|
||||
processes.pgcat.update_config(current_config.deep_merge(config))
|
||||
processes.pgcat.reload_config
|
||||
|
||||
conn = nil
|
||||
expect { conn = PG.connect(connection_string)}.not_to raise_error
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context 'with per pool configuration' do
|
||||
around(:example) do |example|
|
||||
|
||||
# Set up auth query
|
||||
Helpers::AuthQuery.set_up_auth_query_for_user(
|
||||
user: 'md5_auth_user',
|
||||
password: 'secret'
|
||||
);
|
||||
|
||||
Helpers::AuthQuery.set_up_auth_query_for_user(
|
||||
user: 'md5_auth_user1',
|
||||
password: 'secret',
|
||||
database: 'shard1'
|
||||
);
|
||||
|
||||
example.run
|
||||
|
||||
# Tear down auth query
|
||||
Helpers::AuthQuery.tear_down_auth_query_for_user(
|
||||
user: 'md5_auth_user',
|
||||
password: 'secret'
|
||||
);
|
||||
|
||||
Helpers::AuthQuery.tear_down_auth_query_for_user(
|
||||
user: 'md5_auth_user1',
|
||||
password: 'secret',
|
||||
database: 'shard1'
|
||||
);
|
||||
end
|
||||
|
||||
context 'with correct parameters' do
|
||||
let(:processes) { Helpers::AuthQuery.two_pools_auth_query(pool_names: ["sharded_db0", "sharded_db1"], pg_user: pg_user, config_user: config_user, extra_conf: config ) }
|
||||
let(:config) {
|
||||
{ 'pools' =>
|
||||
{
|
||||
'sharded_db0' => {
|
||||
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
|
||||
'auth_query_user' => 'md5_auth_user',
|
||||
'auth_query_password' => 'secret'
|
||||
},
|
||||
'sharded_db1' => {
|
||||
'auth_query' => "SELECT * FROM public.user_lookup('$1');",
|
||||
'auth_query_user' => 'md5_auth_user1',
|
||||
'auth_query_password' => 'secret'
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
context 'and with cleartext passwords set' do
|
||||
it 'it uses local passwords' do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
|
||||
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password']))
|
||||
expect(conn.exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
context 'and with cleartext passwords not set' do
|
||||
let(:config_user) { { 'username' => 'sharding_user' } }
|
||||
|
||||
it 'it uses obtained passwords' do
|
||||
connection_string = processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])
|
||||
conn = PG.connect(connection_string)
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
connection_string = processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password'])
|
||||
conn = PG.connect(connection_string)
|
||||
expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
BIN
tests/ruby/capture
Normal file
BIN
tests/ruby/capture
Normal file
Binary file not shown.
173
tests/ruby/helpers/auth_query_helper.rb
Normal file
173
tests/ruby/helpers/auth_query_helper.rb
Normal file
@@ -0,0 +1,173 @@
|
||||
module Helpers
|
||||
module AuthQuery
|
||||
def self.single_shard_auth_query(
|
||||
pg_user:,
|
||||
config_user:,
|
||||
pool_name:,
|
||||
extra_conf: {},
|
||||
log_level: 'debug',
|
||||
wait_until_ready: true
|
||||
)
|
||||
|
||||
user = {
|
||||
"pool_size" => 10,
|
||||
"statement_timeout" => 0,
|
||||
}
|
||||
|
||||
pgcat = PgcatProcess.new(log_level)
|
||||
pgcat_cfg = pgcat.current_config.deep_merge(extra_conf)
|
||||
|
||||
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
|
||||
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
|
||||
|
||||
# Main proxy configs
|
||||
pgcat_cfg["pools"] = {
|
||||
"#{pool_name}" => {
|
||||
"default_role" => "any",
|
||||
"pool_mode" => "transaction",
|
||||
"load_balancing_mode" => "random",
|
||||
"primary_reads_enabled" => false,
|
||||
"query_parser_enabled" => false,
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => {
|
||||
"database" => "shard0",
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
}
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg)
|
||||
pgcat.start
|
||||
|
||||
pgcat.wait_until_ready(
|
||||
pgcat.connection_string(
|
||||
"sharded_db",
|
||||
pg_user['username'],
|
||||
pg_user['password']
|
||||
)
|
||||
) if wait_until_ready
|
||||
|
||||
OpenStruct.new.tap do |struct|
|
||||
struct.pgcat = pgcat
|
||||
struct.primary = primary
|
||||
struct.replicas = [replica]
|
||||
struct.all_databases = [primary]
|
||||
end
|
||||
end
|
||||
|
||||
def self.two_pools_auth_query(
|
||||
pg_user:,
|
||||
config_user:,
|
||||
pool_names:,
|
||||
extra_conf: {},
|
||||
log_level: 'debug'
|
||||
)
|
||||
|
||||
user = {
|
||||
"pool_size" => 10,
|
||||
"statement_timeout" => 0,
|
||||
}
|
||||
|
||||
pgcat = PgcatProcess.new(log_level)
|
||||
pgcat_cfg = pgcat.current_config
|
||||
|
||||
primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0")
|
||||
replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0")
|
||||
|
||||
pool_template = Proc.new do |database|
|
||||
{
|
||||
"default_role" => "any",
|
||||
"pool_mode" => "transaction",
|
||||
"load_balancing_mode" => "random",
|
||||
"primary_reads_enabled" => false,
|
||||
"query_parser_enabled" => false,
|
||||
"sharding_function" => "pg_bigint_hash",
|
||||
"shards" => {
|
||||
"0" => {
|
||||
"database" => database,
|
||||
"servers" => [
|
||||
["localhost", primary.port.to_s, "primary"],
|
||||
["localhost", replica.port.to_s, "replica"],
|
||||
]
|
||||
},
|
||||
},
|
||||
"users" => { "0" => user.merge(config_user) }
|
||||
}
|
||||
end
|
||||
# Main proxy configs
|
||||
pgcat_cfg["pools"] = {
|
||||
"#{pool_names[0]}" => pool_template.call("shard0"),
|
||||
"#{pool_names[1]}" => pool_template.call("shard1")
|
||||
}
|
||||
|
||||
pgcat_cfg["general"]["port"] = pgcat.port
|
||||
pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
|
||||
pgcat.start
|
||||
|
||||
pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
|
||||
|
||||
OpenStruct.new.tap do |struct|
|
||||
struct.pgcat = pgcat
|
||||
struct.primary = primary
|
||||
struct.replicas = [replica]
|
||||
struct.all_databases = [primary]
|
||||
end
|
||||
end
|
||||
|
||||
def self.create_query_auth_function(user)
|
||||
return <<-SQL
|
||||
CREATE OR REPLACE FUNCTION public.user_lookup(in i_username text, out uname text, out phash text)
|
||||
RETURNS record AS $$
|
||||
BEGIN
|
||||
SELECT usename, passwd FROM pg_catalog.pg_shadow
|
||||
WHERE usename = i_username INTO uname, phash;
|
||||
RETURN;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
||||
|
||||
GRANT EXECUTE ON FUNCTION public.user_lookup(text) TO #{user};
|
||||
SQL
|
||||
end
|
||||
|
||||
def self.exec_in_instances(query:, instance_ports: [ 5432, 10432 ], database: 'postgres', user: 'postgres', password: 'postgres')
|
||||
instance_ports.each do |port|
|
||||
c = PG.connect("postgres://#{user}:#{password}@localhost:#{port}/#{database}")
|
||||
c.exec(query)
|
||||
c.close
|
||||
end
|
||||
end
|
||||
|
||||
def self.set_up_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
|
||||
instance_ports.each do |port|
|
||||
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
|
||||
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
|
||||
connection.exec("DROP ROLE #{user}") rescue PG::UndefinedObject
|
||||
connection.exec("CREATE ROLE #{user} ENCRYPTED PASSWORD '#{password}' LOGIN;")
|
||||
connection.exec(self.create_query_auth_function(user))
|
||||
connection.close
|
||||
end
|
||||
end
|
||||
|
||||
def self.tear_down_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' )
|
||||
instance_ports.each do |port|
|
||||
connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}")
|
||||
connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction
|
||||
connection.exec("DROP ROLE #{user}")
|
||||
connection.close
|
||||
end
|
||||
end
|
||||
|
||||
def self.drop_query_auth_function(user)
|
||||
return <<-SQL
|
||||
REVOKE ALL ON FUNCTION public.user_lookup(text) FROM public, #{user};
|
||||
DROP FUNCTION public.user_lookup(in i_username text, out uname text, out phash text);
|
||||
SQL
|
||||
end
|
||||
end
|
||||
end
|
||||
259
tests/ruby/helpers/pg_socket.rb
Normal file
259
tests/ruby/helpers/pg_socket.rb
Normal file
@@ -0,0 +1,259 @@
|
||||
require 'socket'
|
||||
require 'digest/md5'
|
||||
|
||||
BACKEND_MESSAGE_CODES = {
|
||||
'Z' => "ReadyForQuery",
|
||||
'C' => "CommandComplete",
|
||||
'T' => "RowDescription",
|
||||
'D' => "DataRow",
|
||||
'1' => "ParseComplete",
|
||||
'2' => "BindComplete",
|
||||
'E' => "ErrorResponse",
|
||||
's' => "PortalSuspended",
|
||||
}
|
||||
|
||||
class PostgresSocket
|
||||
def initialize(host, port)
|
||||
@port = port
|
||||
@host = host
|
||||
@socket = TCPSocket.new @host, @port
|
||||
@parameters = {}
|
||||
@verbose = true
|
||||
end
|
||||
|
||||
def send_md5_password_message(username, password, salt)
|
||||
m = Digest::MD5.hexdigest(password + username)
|
||||
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
|
||||
m = 'md5' + m
|
||||
bytes = (m.split("").map(&:ord) + [0]).flatten
|
||||
message_size = bytes.count + 4
|
||||
|
||||
message = []
|
||||
|
||||
message << 'p'.ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << bytes
|
||||
message.flatten!
|
||||
|
||||
|
||||
@socket.write(message.pack('C*'))
|
||||
end
|
||||
|
||||
def send_startup_message(username, database, password)
|
||||
message = []
|
||||
|
||||
message << [196608].pack('l>').unpack('CCCC') # 4
|
||||
message << "user".split('').map(&:ord) # 4, 8
|
||||
message << 0 # 1, 9
|
||||
message << username.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message << "database".split('').map(&:ord) # 8, 20
|
||||
message << 0 # 1, 21
|
||||
message << database.split('').map(&:ord) # 2, 23
|
||||
message << 0 # 1, 24
|
||||
message << 0 # 1, 25
|
||||
message.flatten!
|
||||
|
||||
total_message_size = message.size + 4
|
||||
|
||||
message_len = [total_message_size].pack('l>').unpack('CCCC')
|
||||
|
||||
@socket.write([message_len + message].flatten.pack('C*'))
|
||||
|
||||
sleep 0.1
|
||||
|
||||
read_startup_response(username, password)
|
||||
end
|
||||
|
||||
def read_startup_response(username, password)
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
while message_code == 'R'
|
||||
auth_code = @socket.recv(4).unpack('l>').pop
|
||||
case auth_code
|
||||
when 5 # md5
|
||||
salt = @socket.recv(4).unpack('CCCC')
|
||||
send_md5_password_message(username, password, salt)
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
when 0 # trust
|
||||
break
|
||||
end
|
||||
end
|
||||
loop do
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
if message_code == 'Z'
|
||||
@socket.recv(1).unpack("a") # most likely I
|
||||
break # We are good to go
|
||||
end
|
||||
if message_code == 'S'
|
||||
actual_message = @socket.recv(message_len - 4).unpack("C*")
|
||||
k,v = actual_message.pack('U*').split(/\x00/)
|
||||
@parameters[k] = v
|
||||
end
|
||||
if message_code == 'K'
|
||||
process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>")
|
||||
@parameters["process_id"] = process_id
|
||||
@parameters["secret_key"] = secret_key
|
||||
end
|
||||
end
|
||||
return @parameters
|
||||
end
|
||||
|
||||
def cancel_query
|
||||
socket = TCPSocket.new @host, @port
|
||||
process_key = @parameters["process_id"]
|
||||
secret_key = @parameters["secret_key"]
|
||||
message = []
|
||||
message << [16].pack('l>').unpack('CCCC') # 4
|
||||
message << [80877102].pack('l>').unpack('CCCC') # 4
|
||||
message << [process_key.to_i].pack('l>').unpack('CCCC') # 4
|
||||
message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
socket.write(message.flatten.pack('C*'))
|
||||
socket.close
|
||||
log "[F] Sent CancelRequest message"
|
||||
end
|
||||
|
||||
def send_query_message(query)
|
||||
query_size = query.length
|
||||
message_size = 1 + 4 + query_size
|
||||
message = []
|
||||
message << "Q".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent Q message (#{query})"
|
||||
end
|
||||
|
||||
def send_parse_message(query)
|
||||
query_size = query.length
|
||||
message_size = 2 + 2 + 4 + query_size
|
||||
message = []
|
||||
message << "P".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message << [0, 0]
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent P message (#{query})"
|
||||
end
|
||||
|
||||
def send_bind_message
|
||||
message = []
|
||||
message << "B".ord
|
||||
message << [12].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << 0 # unnamed statement
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent B message"
|
||||
end
|
||||
|
||||
def send_describe_message(mode)
|
||||
message = []
|
||||
message << "D".ord
|
||||
message << [6].pack('l>').unpack('CCCC') # 4
|
||||
message << mode.ord
|
||||
message << 0 # unnamed statement
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent D message"
|
||||
end
|
||||
|
||||
def send_execute_message(limit=0)
|
||||
message = []
|
||||
message << "E".ord
|
||||
message << [9].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << [limit].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent E message"
|
||||
end
|
||||
|
||||
def send_sync_message
|
||||
message = []
|
||||
message << "S".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent S message"
|
||||
end
|
||||
|
||||
def send_copydone_message
|
||||
message = []
|
||||
message << "c".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent c message"
|
||||
end
|
||||
|
||||
def send_copyfail_message
|
||||
message = []
|
||||
message << "f".ord
|
||||
message << [5].pack('l>').unpack('CCCC') # 4
|
||||
message << 0
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent f message"
|
||||
end
|
||||
|
||||
def send_flush_message
|
||||
message = []
|
||||
message << "H".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent H message"
|
||||
end
|
||||
|
||||
def read_from_server()
|
||||
output_messages = []
|
||||
retry_count = 0
|
||||
message_code = nil
|
||||
message_len = 0
|
||||
loop do
|
||||
begin
|
||||
message_code, message_len = @socket.recv_nonblock(5).unpack("al>")
|
||||
rescue IO::WaitReadable
|
||||
return output_messages if retry_count > 50
|
||||
|
||||
retry_count += 1
|
||||
sleep(0.01)
|
||||
next
|
||||
end
|
||||
message = {
|
||||
code: message_code,
|
||||
len: message_len,
|
||||
bytes: []
|
||||
}
|
||||
log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}"
|
||||
|
||||
actual_message_length = message_len - 4
|
||||
if actual_message_length > 0
|
||||
message[:bytes] = @socket.recv(message_len - 4).unpack("C*")
|
||||
log "\t#{message[:bytes].join(",")}"
|
||||
log "\t#{message[:bytes].map(&:chr).join(" ")}"
|
||||
end
|
||||
output_messages << message
|
||||
return output_messages if message_code == 'Z'
|
||||
end
|
||||
end
|
||||
|
||||
def log(msg)
|
||||
return unless @verbose
|
||||
|
||||
puts msg
|
||||
end
|
||||
|
||||
def close
|
||||
@socket.close
|
||||
end
|
||||
end
|
||||
@@ -2,6 +2,14 @@ require 'json'
|
||||
require 'ostruct'
|
||||
require_relative 'pgcat_process'
|
||||
require_relative 'pg_instance'
|
||||
require_relative 'pg_socket'
|
||||
|
||||
class ::Hash
|
||||
def deep_merge(second)
|
||||
merger = proc { |key, v1, v2| Hash === v1 && Hash === v2 ? v1.merge(v2, &merger) : v2 }
|
||||
self.merge(second, &merger)
|
||||
end
|
||||
end
|
||||
|
||||
module Helpers
|
||||
module Pgcat
|
||||
|
||||
@@ -67,17 +67,21 @@ class PgcatProcess
|
||||
def start
|
||||
raise StandardError, "Process is already started" unless @pid.nil?
|
||||
@pid = Process.spawn(@env, @command, err: @log_filename, out: @log_filename)
|
||||
Process.detach(@pid)
|
||||
ObjectSpace.define_finalizer(@log_filename, proc { PgcatProcess.finalize(@pid, @log_filename, @config_filename) })
|
||||
|
||||
return self
|
||||
end
|
||||
|
||||
def wait_until_ready
|
||||
def wait_until_ready(connection_string = nil)
|
||||
exc = nil
|
||||
10.times do
|
||||
PG::connect(example_connection_string).close
|
||||
Process.kill 0, @pid
|
||||
PG::connect(connection_string || example_connection_string).close
|
||||
|
||||
return self
|
||||
rescue Errno::ESRCH
|
||||
raise StandardError, "Process #{@pid} died. #{logs}"
|
||||
rescue => e
|
||||
exc = e
|
||||
sleep(0.5)
|
||||
@@ -108,13 +112,10 @@ class PgcatProcess
|
||||
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
|
||||
end
|
||||
|
||||
def connection_string(pool_name, username)
|
||||
def connection_string(pool_name, username, password = nil)
|
||||
cfg = current_config
|
||||
|
||||
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
|
||||
password = user_obj["password"]
|
||||
|
||||
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
end
|
||||
|
||||
def example_connection_string
|
||||
|
||||
@@ -65,7 +65,7 @@ describe "Least Outstanding Queries Load Balancing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
context "under homogenous load" do
|
||||
context "under homogeneous load" do
|
||||
it "balances query volume between all instances" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ describe "Query Mirroing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
it "can mirror a query" do
|
||||
xit "can mirror a query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
runs = 15
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
@@ -309,4 +309,58 @@ describe "Miscellaneous" do
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "Idle client timeout" do
|
||||
context "idle transaction timeout set to 0" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
puts(current_configs["general"]["idle_client_in_transaction_timeout"])
|
||||
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 0
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 0
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
|
||||
it "Allow client to be idle in transaction" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("BEGIN")
|
||||
conn.async_exec("SELECT 1")
|
||||
sleep(2)
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
end
|
||||
|
||||
context "idle transaction timeout set to 500ms" do
|
||||
before do
|
||||
current_configs = processes.pgcat.current_config
|
||||
correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
|
||||
current_configs["general"]["idle_client_in_transaction_timeout"] = 500
|
||||
|
||||
processes.pgcat.update_config(current_configs) # with timeout 500
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
|
||||
it "Allow client to be idle in transaction below timeout" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("BEGIN")
|
||||
conn.async_exec("SELECT 1")
|
||||
sleep(0.4) # below 500ms
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Error when client idle in transaction time exceeds timeout" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("BEGIN")
|
||||
conn.async_exec("SELECT 1")
|
||||
sleep(1) # above 500ms
|
||||
expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
|
||||
conn.async_exec("SELECT 1") # should be able to send another query
|
||||
conn.close
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
14
tests/ruby/plugins_spec.rb
Normal file
14
tests/ruby/plugins_spec.rb
Normal file
@@ -0,0 +1,14 @@
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "Plugins" do
|
||||
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) }
|
||||
|
||||
context "intercept" do
|
||||
it "will intercept an intellij query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
res = conn.exec("select current_database() as a, current_schemas(false) as b")
|
||||
expect(res.values).to eq([["sharded_db", "{public}"]])
|
||||
end
|
||||
end
|
||||
end
|
||||
155
tests/ruby/protocol_spec.rb
Normal file
155
tests/ruby/protocol_spec.rb
Normal file
@@ -0,0 +1,155 @@
|
||||
# frozen_string_literal: true
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "Portocol handling" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "session") }
|
||||
let(:sequence) { [] }
|
||||
let(:pgcat_socket) { PostgresSocket.new('localhost', processes.pgcat.port) }
|
||||
let(:pgdb_socket) { PostgresSocket.new('localhost', processes.all_databases.first.port) }
|
||||
|
||||
after do
|
||||
pgdb_socket.close
|
||||
pgcat_socket.close
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
def run_comparison(sequence, socket_a, socket_b)
|
||||
sequence.each do |msg, *args|
|
||||
socket_a.send(msg, *args)
|
||||
socket_b.send(msg, *args)
|
||||
|
||||
compare_messages(
|
||||
socket_a.read_from_server,
|
||||
socket_b.read_from_server
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def compare_messages(msg_arr0, msg_arr1)
|
||||
if msg_arr0.count != msg_arr1.count
|
||||
error_output = []
|
||||
|
||||
error_output << "#{msg_arr0.count} : #{msg_arr1.count}"
|
||||
error_output << "PgCat Messages"
|
||||
error_output += msg_arr0.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
|
||||
error_output << "PgServer Messages"
|
||||
error_output += msg_arr1.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
|
||||
error_desc = error_output.join("\n")
|
||||
raise StandardError, "Message count mismatch #{error_desc}"
|
||||
end
|
||||
|
||||
(0..msg_arr0.count - 1).all? do |i|
|
||||
msg0 = msg_arr0[i]
|
||||
msg1 = msg_arr1[i]
|
||||
|
||||
result = [
|
||||
msg0[:code] == msg1[:code],
|
||||
msg0[:len] == msg1[:len],
|
||||
msg0[:bytes] == msg1[:bytes],
|
||||
].all?
|
||||
|
||||
next result if result
|
||||
|
||||
if result == false
|
||||
error_string = []
|
||||
if msg0[:code] != msg1[:code]
|
||||
error_string << "code #{msg0[:code]} != #{msg1[:code]}"
|
||||
end
|
||||
if msg0[:len] != msg1[:len]
|
||||
error_string << "len #{msg0[:len]} != #{msg1[:len]}"
|
||||
end
|
||||
if msg0[:bytes] != msg1[:bytes]
|
||||
error_string << "bytes #{msg0[:bytes]} != #{msg1[:bytes]}"
|
||||
end
|
||||
err = error_string.join("\n")
|
||||
|
||||
raise StandardError, "Message mismatch #{err}"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.shared_examples "at parity with database" do
|
||||
before do
|
||||
pgcat_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
pgdb_socket.send_startup_message("sharding_user", "shard0", "sharding_user")
|
||||
end
|
||||
|
||||
it "works" do
|
||||
run_comparison(sequence, pgcat_socket, pgdb_socket)
|
||||
end
|
||||
end
|
||||
|
||||
context "Cancel Query" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_query_message, "SELECT pg_sleep(5)"],
|
||||
[:cancel_query]
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Simple query after parse" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 5"],
|
||||
[:send_query_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
# Known to fail due to PgCat not supporting flush
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Flush message" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_flush_message]
|
||||
]
|
||||
}
|
||||
|
||||
# Known to fail due to PgCat not supporting flush
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Bind without parse" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_bind_message]
|
||||
]
|
||||
}
|
||||
# This is known to fail.
|
||||
# Server responds immediately, Proxy buffers the message
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
context "Simple message" do
|
||||
let(:sequence) {
|
||||
[[:send_query_message, "SELECT 1"]]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
context "Extended protocol" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
end
|
||||
@@ -27,7 +27,7 @@ describe "Sharding" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "automatic routing of extended procotol" do
|
||||
describe "automatic routing of extended protocol" do
|
||||
it "can do it" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.exec("SET SERVER ROLE TO 'auto'")
|
||||
|
||||
@@ -19,3 +19,10 @@ ensure
|
||||
STDOUT.reopen(sout)
|
||||
STDERR.reopen(serr)
|
||||
end
|
||||
|
||||
def clients_connected_to_pool(pool_index: 0, processes:)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[pool_index]
|
||||
admin_conn.close
|
||||
results['cl_idle'].to_i + results['cl_active'].to_i + results['cl_waiting'].to_i
|
||||
end
|
||||
|
||||
1
utilities/requirements.txt
Normal file
1
utilities/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
tomli
|
||||
Reference in New Issue
Block a user