mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
29 Commits
levkk-auth
...
levkk-star
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee23b374ae | ||
|
|
9dffebccbf | ||
|
|
4c8358b8b3 | ||
|
|
f0d1916a98 | ||
|
|
bba5f10be1 | ||
|
|
a514dbc187 | ||
|
|
d660e3e565 | ||
|
|
0d882cc204 | ||
|
|
b36746a47b | ||
|
|
9e51b8110f | ||
|
|
4a87b4807d | ||
|
|
cb5ff40a59 | ||
|
|
62b2d994c1 | ||
|
|
66805d7e77 | ||
|
|
4ccc1e7fa3 | ||
|
|
3dae3d0777 | ||
|
|
a18eb42df5 | ||
|
|
6aacf1fa19 | ||
|
|
8e99e65215 | ||
|
|
5dfbc102a9 | ||
|
|
bae12fca99 | ||
|
|
421c5d4b64 | ||
|
|
d568739db9 | ||
|
|
692353c839 | ||
|
|
a62f6b0eea | ||
|
|
89e15f09b5 | ||
|
|
7ddd23b514 | ||
|
|
faa9c1f64a | ||
|
|
9094988491 |
@@ -39,7 +39,7 @@ log_client_connections = false
|
||||
log_client_disconnections = false
|
||||
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = true
|
||||
autoreload = 15000
|
||||
|
||||
# TLS
|
||||
tls_certificate = ".circleci/server.cert"
|
||||
|
||||
14
.editorconfig
Normal file
14
.editorconfig
Normal file
@@ -0,0 +1,14 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.rs]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
max_line_length = 120
|
||||
|
||||
[*.toml]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
133
CONFIG.md
133
CONFIG.md
@@ -49,6 +49,14 @@ default: 30000 # milliseconds
|
||||
|
||||
How long an idle connection with a server is left open (ms).
|
||||
|
||||
### server_lifetime
|
||||
```
|
||||
path: general.server_lifetime
|
||||
default: 86400000 # 24 hours
|
||||
```
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
@@ -108,7 +116,7 @@ If we should log client disconnections
|
||||
### autoreload
|
||||
```
|
||||
path: general.autoreload
|
||||
default: false
|
||||
default: 15000
|
||||
```
|
||||
|
||||
When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
@@ -152,7 +160,7 @@ default: <UNSET>
|
||||
example: "server.cert"
|
||||
```
|
||||
|
||||
Path to TLS Certficate file to use for TLS connections
|
||||
Path to TLS Certificate file to use for TLS connections
|
||||
|
||||
### tls_private_key
|
||||
```
|
||||
@@ -175,41 +183,11 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT
|
||||
### admin_password
|
||||
```
|
||||
path: general.admin_password
|
||||
default: <UNSET>
|
||||
default: "admin_pass"
|
||||
```
|
||||
|
||||
Password to access the virtual administrative database
|
||||
|
||||
### auth_query (experimental)
|
||||
```
|
||||
path: general.auth_query
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
established using the database configured in the pool. This parameter is inherited by every pool
|
||||
and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_user (experimental)
|
||||
```
|
||||
path: general.auth_query_user
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_password (experimental)
|
||||
```
|
||||
path: general.auth_query_password
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
## `pools.<pool_name>` Section
|
||||
|
||||
### pool_mode
|
||||
@@ -243,7 +221,7 @@ If the client doesn't specify, PgCat routes traffic to this role by default.
|
||||
`replica` round-robin between replicas only without touching the primary,
|
||||
`primary` all queries go to the primary unless otherwise specified.
|
||||
|
||||
### query_parser_enabled (experimental)
|
||||
### query_parser_enabled
|
||||
```
|
||||
path: pools.<pool_name>.query_parser_enabled
|
||||
default: true
|
||||
@@ -264,7 +242,7 @@ If the query parser is enabled and this setting is enabled, the primary will be
|
||||
load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
queries. The primary can always be explicitly selected with our custom protocol.
|
||||
|
||||
### sharding_key_regex (experimental)
|
||||
### sharding_key_regex
|
||||
```
|
||||
path: pools.<pool_name>.sharding_key_regex
|
||||
default: <UNSET>
|
||||
@@ -286,7 +264,40 @@ Current options:
|
||||
`pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
|
||||
`sha1`: A hashing function based on SHA1
|
||||
|
||||
### automatic_sharding_key (experimental)
|
||||
### auth_query
|
||||
```
|
||||
path: pools.<pool_name>.auth_query
|
||||
default: <UNSET>
|
||||
example: "SELECT $1"
|
||||
```
|
||||
|
||||
Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
established using the database configured in the pool. This parameter is inherited by every pool
|
||||
and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_user
|
||||
```
|
||||
path: pools.<pool_name>.auth_query_user
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### auth_query_password
|
||||
```
|
||||
path: pools.<pool_name>.auth_query_password
|
||||
default: <UNSET>
|
||||
example: "sharding_user"
|
||||
```
|
||||
|
||||
Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
|
||||
### automatic_sharding_key
|
||||
```
|
||||
path: pools.<pool_name>.automatic_sharding_key
|
||||
default: <UNSET>
|
||||
@@ -311,30 +322,6 @@ default: 3000
|
||||
|
||||
Connect timeout can be overwritten in the pool
|
||||
|
||||
### auth_query (experimental)
|
||||
```
|
||||
path: general.auth_query
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
Auth query can be overwritten in the pool
|
||||
|
||||
### auth_query_user (experimental)
|
||||
```
|
||||
path: general.auth_query_user
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
Auth query user can be overwritten in the pool
|
||||
|
||||
### auth_query_password (experimental)
|
||||
```
|
||||
path: general.auth_query_password
|
||||
default: <UNSET>
|
||||
```
|
||||
|
||||
Auth query password can be overwritten in the pool
|
||||
|
||||
## `pools.<pool_name>.users.<user_index>` Section
|
||||
|
||||
### username
|
||||
@@ -343,7 +330,8 @@ path: pools.<pool_name>.users.<user_index>.username
|
||||
default: "sharding_user"
|
||||
```
|
||||
|
||||
Postgresql username
|
||||
PostgreSQL username used to authenticate the user and connect to the server
|
||||
if `server_username` is not set.
|
||||
|
||||
### password
|
||||
```
|
||||
@@ -351,7 +339,26 @@ path: pools.<pool_name>.users.<user_index>.password
|
||||
default: "sharding_user"
|
||||
```
|
||||
|
||||
Postgresql password
|
||||
PostgreSQL password used to authenticate the user and connect to the server
|
||||
if `server_password` is not set.
|
||||
|
||||
### server_username
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.server_username
|
||||
default: <UNSET>
|
||||
example: "another_user"
|
||||
```
|
||||
|
||||
PostgreSQL username used to connect to the server.
|
||||
|
||||
### server_password
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.server_password
|
||||
default: <UNSET>
|
||||
example: "another_password"
|
||||
```
|
||||
|
||||
PostgreSQL password used to connect to the server.
|
||||
|
||||
### pool_size
|
||||
```
|
||||
@@ -382,7 +389,7 @@ default: [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
|
||||
|
||||
Array of servers in the shard, each server entry is an array of `[host, port, role]`
|
||||
|
||||
### mirrors (experimental)
|
||||
### mirrors
|
||||
```
|
||||
path: pools.<pool_name>.shards.<shard_index>.mirrors
|
||||
default: <UNSET>
|
||||
|
||||
141
Cargo.lock
generated
141
Cargo.lock
generated
@@ -4,9 +4,9 @@ version = 3
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "0.7.20"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac"
|
||||
checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
@@ -54,12 +54,6 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.0"
|
||||
@@ -283,9 +277,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549"
|
||||
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -298,9 +292,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
|
||||
checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
@@ -308,15 +302,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
||||
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83"
|
||||
checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
@@ -325,38 +319,38 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
|
||||
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
|
||||
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"syn 2.0.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2"
|
||||
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
|
||||
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.27"
|
||||
version = "0.3.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
||||
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -393,9 +387,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.15"
|
||||
version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4"
|
||||
checksum = "66b91535aa35fea1523ad1b86cb6b53c28e0ae566ba4a460f4457e936cad7c6f"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
@@ -482,9 +476,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.25"
|
||||
version = "0.14.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899"
|
||||
checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
@@ -745,12 +739,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"atomic_enum",
|
||||
"base64 0.21.0",
|
||||
"base64",
|
||||
"bb8",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -768,9 +762,11 @@ dependencies = [
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf",
|
||||
"pin-project",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@@ -782,6 +778,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"toml",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -826,6 +823,26 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
@@ -840,11 +857,11 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "postgres-protocol"
|
||||
version = "0.6.4"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c"
|
||||
checksum = "78b7fa9f396f51dffd61546fd8573ee20592287996568e6175ceb0f8699ad75d"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
@@ -921,9 +938,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.7.3"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d"
|
||||
checksum = "ac6cf59af1067a3fb53fbe5c88c053764e930f932be1d71d3ffe032cbe147f59"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
@@ -932,9 +949,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.6.29"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
|
||||
checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5"
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
@@ -967,14 +984,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.20.8"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f"
|
||||
checksum = "07180898a28ed6a7f7ba2311594308f595e3dd2e3c3812fa0a80a47b45f17e5d"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-webpki",
|
||||
"sct",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -983,7 +1000,17 @@ version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b"
|
||||
dependencies = [
|
||||
"base64 0.21.0",
|
||||
"base64",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.100.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1010,15 +1037,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.159"
|
||||
version = "1.0.160"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065"
|
||||
checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c"
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.159"
|
||||
version = "1.0.160"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585"
|
||||
checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1104,9 +1131,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
|
||||
|
||||
[[package]]
|
||||
name = "sqlparser"
|
||||
version = "0.32.0"
|
||||
version = "0.33.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0366f270dbabb5cc2e4c88427dc4c08bba144f81e32fbd459a013f26a4d16aa0"
|
||||
checksum = "355dc4d4b6207ca8a3434fc587db0a8016130a574dbcdbfb93d7f7b5bc5b211a"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
@@ -1223,13 +1250,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.23.4"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
|
||||
checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1444,13 +1470,12 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.22.0"
|
||||
name = "webpki-roots"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
|
||||
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
"rustls-webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
11
Cargo.toml
11
Cargo.toml
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -19,7 +19,7 @@ serde_derive = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
sqlparser = "0.32.0"
|
||||
sqlparser = "0.33.0"
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
@@ -28,7 +28,7 @@ hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
base64 = "0.21"
|
||||
stringprep = "0.1"
|
||||
tokio-rustls = "0.23"
|
||||
tokio-rustls = "0.24"
|
||||
rustls-pemfile = "1"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
phf = { version = "0.11.1", features = ["macros"] }
|
||||
@@ -37,8 +37,11 @@ futures = "0.3"
|
||||
socket2 = { version = "0.4.7", features = ["all"] }
|
||||
nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.4"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "1"
|
||||
webpki-roots = "0.23"
|
||||
rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
46
README.md
46
README.md
@@ -21,21 +21,53 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
|
||||
| Client TLS | **Stable** | Clients can connect to the pooler using TLS/SSL. |
|
||||
| Client/Server authentication | **Stable** | Clients can connect using MD5 authentication, supported by `libpq` and all Postgres client drivers. PgCat can connect to Postgres using MD5 and SCRAM-SHA-256. |
|
||||
| Live configuration reloading | **Stable** | Identical to PgBouncer; all settings can be reloaded dynamically (except `host` and `port`). |
|
||||
| Auth passthrough | **Stable** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file.|
|
||||
| Sharding using extended SQL syntax | **Experimental** | Clients can dynamically configure the pooler to route queries to specific shards. |
|
||||
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
|
||||
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
|
||||
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
|
||||
| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
|
||||
|
||||
|
||||
## Status
|
||||
|
||||
PgCat is stable and used in production to serve hundreds of thousands of queries per second. Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
|
||||
PgCat is stable and used in production to serve hundreds of thousands of queries per second.
|
||||
|
||||
| | |
|
||||
|-|-|
|
||||
|<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f"><img src="./images/instacart.webp" height="70" width="auto"></a>|<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second"><img src="./images/postgresml.webp" height="70" width="auto"></a>|
|
||||
| [Instacart](https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f) | [PostgresML](https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second) |
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
|
||||
<img src="./images/instacart.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
|
||||
<img src="./images/postgresml.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://onesignal.com">
|
||||
<img src="./images/one_signal.webp" height="70" width="auto">
|
||||
</a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<a href="https://tech.instacart.com/adopting-pgcat-a-nextgen-postgres-proxy-3cf284e68c2f">
|
||||
Instacart
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
<a href="https://postgresml.org/blog/scaling-postgresml-to-one-million-requests-per-second">
|
||||
PostgresML
|
||||
</a>
|
||||
</td>
|
||||
<td>
|
||||
OneSignal
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
Some features remain experimental and are being actively developed. They are optional and can be enabled through configuration.
|
||||
|
||||
## Deployment
|
||||
|
||||
@@ -99,7 +131,7 @@ You can open a Docker development environment where you can debug tests easier.
|
||||
./dev/script/console
|
||||
```
|
||||
|
||||
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the contaner (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
|
||||
This will open a terminal in an environment similar to that used in tests. In there, you can compile the pooler, run tests, do some debugging with the test environment, etc. Objects compiled inside the container (and bundled gems) will be placed in `dev/cache` so they don't interfere with what you have on your machine.
|
||||
|
||||
## Usage
|
||||
|
||||
|
||||
@@ -38,9 +38,6 @@ log_client_connections = false
|
||||
# If we should log client disconnections
|
||||
log_client_disconnections = false
|
||||
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = false
|
||||
|
||||
# TLS
|
||||
# tls_certificate = "server.cert"
|
||||
# tls_private_key = "server.key"
|
||||
@@ -76,7 +73,7 @@ query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitely selected with our custom protocol.
|
||||
# queries. The primary can always be explicitly selected with our custom protocol.
|
||||
primary_reads_enabled = true
|
||||
|
||||
# So what if you wanted to implement a different hashing function,
|
||||
|
||||
BIN
images/one_signal.webp
Normal file
BIN
images/one_signal.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
53
pgcat.toml
53
pgcat.toml
@@ -23,6 +23,9 @@ connect_timeout = 5000 # milliseconds
|
||||
# How long an idle connection with a server is left open (ms).
|
||||
idle_timeout = 30000 # milliseconds
|
||||
|
||||
# Max connection lifetime before it's closed, even if actively used.
|
||||
server_lifetime = 86400000 # 24 hours
|
||||
|
||||
# How long a client is allowed to be idle while in a transaction (ms).
|
||||
idle_client_in_transaction_timeout = 0 # milliseconds
|
||||
|
||||
@@ -45,7 +48,7 @@ log_client_connections = false
|
||||
log_client_disconnections = false
|
||||
|
||||
# When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
autoreload = false
|
||||
autoreload = 15000
|
||||
|
||||
# Number of worker threads the Runtime will use (4 by default).
|
||||
worker_threads = 5
|
||||
@@ -57,10 +60,16 @@ tcp_keepalives_count = 5
|
||||
# Number of seconds between keepalive packets.
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certficate file to use for TLS connections
|
||||
# tls_certificate = "server.cert"
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = "server.key"
|
||||
# tls_private_key = ".circleci/server.key"
|
||||
|
||||
# Enable/disable server TLS
|
||||
server_tls = false
|
||||
|
||||
# Verify server certificate is completely authentic.
|
||||
verify_server_certificate = false
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
@@ -113,6 +122,21 @@ primary_reads_enabled = true
|
||||
# `sha1`: A hashing function based on SHA1
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
|
||||
# established using the database configured in the pool. This parameter is inherited by every pool
|
||||
# and can be redefined in pool configuration.
|
||||
# auth_query = "SELECT $1"
|
||||
|
||||
# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
# This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
# auth_query_user = "sharding_user"
|
||||
|
||||
# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
|
||||
# specified in `auth_query_user`. The connection will be established using the database configured in the pool.
|
||||
# This parameter is inherited by every pool and can be redefined in pool configuration.
|
||||
# auth_query_password = "sharding_user"
|
||||
|
||||
# Automatically parse this from queries and route queries to the right shard!
|
||||
# automatic_sharding_key = "data.id"
|
||||
|
||||
@@ -123,17 +147,30 @@ idle_timeout = 40000
|
||||
connect_timeout = 3000
|
||||
|
||||
# User configs are structured as pool.<pool_name>.users.<user_index>
|
||||
# This secion holds the credentials for users that may connect to this cluster
|
||||
# This section holds the credentials for users that may connect to this cluster
|
||||
[pools.sharded_db.users.0]
|
||||
# Postgresql username
|
||||
# PostgreSQL username used to authenticate the user and connect to the server
|
||||
# if `server_username` is not set.
|
||||
username = "sharding_user"
|
||||
# Postgresql password
|
||||
|
||||
# PostgreSQL password used to authenticate the user and connect to the server
|
||||
# if `server_password` is not set.
|
||||
password = "sharding_user"
|
||||
|
||||
pool_mode = "session"
|
||||
|
||||
# PostgreSQL username used to connect to the server.
|
||||
# server_username = "another_user"
|
||||
|
||||
# PostgreSQL password used to connect to the server.
|
||||
# server_password = "another_password"
|
||||
|
||||
# Maximum number of server connections that can be established for this user
|
||||
# The maximum number of connection from a single Pgcat process to any database in the cluster
|
||||
# is the sum of pool_size across all users.
|
||||
pool_size = 9
|
||||
|
||||
|
||||
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
|
||||
# 0 means it is disabled.
|
||||
statement_timeout = 0
|
||||
@@ -178,6 +215,8 @@ sharding_function = "pg_bigint_hash"
|
||||
username = "simple_user"
|
||||
password = "simple_user"
|
||||
pool_size = 5
|
||||
min_pool_size = 3
|
||||
server_lifetime = 60000
|
||||
statement_timeout = 0
|
||||
|
||||
[pools.simple_db.shards.0]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::errors::Error;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::server::Server;
|
||||
use log::debug;
|
||||
|
||||
@@ -71,25 +72,36 @@ impl AuthPassthrough {
|
||||
let auth_user = crate::config::User {
|
||||
username: self.user.clone(),
|
||||
password: Some(self.password.clone()),
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 1,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
min_pool_size: None,
|
||||
};
|
||||
|
||||
let user = &address.username;
|
||||
|
||||
debug!("Connecting to server to obtain auth hashes.");
|
||||
debug!("Connecting to server to obtain auth hashes");
|
||||
|
||||
let auth_query = self.query.replace("$1", user);
|
||||
|
||||
match Server::exec_simple_query(address, &auth_user, &auth_query).await {
|
||||
Ok(password_data) => {
|
||||
if password_data.len() == 2 && password_data.first().unwrap() == user {
|
||||
if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") {
|
||||
Ok(stripped_hash.to_string())
|
||||
}
|
||||
else {
|
||||
Err(Error::AuthPassthroughError(
|
||||
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
|
||||
))
|
||||
}
|
||||
if let Some(stripped_hash) = password_data
|
||||
.last()
|
||||
.unwrap()
|
||||
.to_string()
|
||||
.strip_prefix("md5") {
|
||||
Ok(stripped_hash.to_string())
|
||||
}
|
||||
else {
|
||||
Err(Error::AuthPassthroughError(
|
||||
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(Error::AuthPassthroughError(
|
||||
"Data obtained from query does not follow the scheme 'user','hash'."
|
||||
@@ -98,10 +110,25 @@ impl AuthPassthrough {
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
Err(Error::AuthPassthroughError(
|
||||
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
|
||||
user, err)))
|
||||
Err(Error::AuthPassthroughError(
|
||||
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
|
||||
user, err))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
|
||||
let address = pool.address(0, 0);
|
||||
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
|
||||
let hash = apt.fetch_hash(address).await?;
|
||||
|
||||
return Ok(hash);
|
||||
}
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
|
||||
address.username, address.database
|
||||
)))
|
||||
}
|
||||
|
||||
164
src/client.rs
164
src/client.rs
@@ -1,4 +1,4 @@
|
||||
use crate::errors::Error;
|
||||
use crate::errors::{ClientIdentifier, Error};
|
||||
use crate::pool::BanReason;
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
@@ -12,7 +12,7 @@ use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
@@ -202,7 +202,7 @@ pub async fn client_entrypoint(
|
||||
// Client probably disconnected rejecting our plain text connection.
|
||||
Ok((ClientConnectionType::Tls, _))
|
||||
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
|
||||
format!("Bad postgres client (plain)"),
|
||||
"Bad postgres client (plain)".into(),
|
||||
)),
|
||||
|
||||
Err(err) => Err(err),
|
||||
@@ -369,28 +369,14 @@ pub async fn startup_tls(
|
||||
}
|
||||
|
||||
// Bad Postgres client.
|
||||
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
|
||||
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
|
||||
),
|
||||
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => {
|
||||
Err(Error::ProtocolSyncError("Bad postgres client (tls)".into()))
|
||||
}
|
||||
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
|
||||
let address = pool.address(0, 0);
|
||||
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
|
||||
let hash = apt.fetch_hash(address).await?;
|
||||
|
||||
return Ok(hash);
|
||||
}
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
|
||||
address.username, address.database
|
||||
)))
|
||||
}
|
||||
|
||||
impl<S, T> Client<S, T>
|
||||
where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||
@@ -418,7 +404,7 @@ where
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(Error::ClientError(
|
||||
"Missing user parameter on client startup".to_string(),
|
||||
"Missing user parameter on client startup".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
@@ -433,6 +419,8 @@ where
|
||||
None => "pgcat",
|
||||
};
|
||||
|
||||
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
.filter(|db| *db == pool_name)
|
||||
@@ -463,7 +451,12 @@ where
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
@@ -476,19 +469,30 @@ where
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
@@ -497,10 +501,12 @@ where
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
|
||||
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
@@ -519,7 +525,10 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid pool name".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -530,16 +539,24 @@ where
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username)));
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
if hash.is_none() {
|
||||
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name);
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name);
|
||||
warn!("Password for {}, obtained. Updating.", client_identifier);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
@@ -547,16 +564,14 @@ where
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
return Err(
|
||||
Error::ClientError(
|
||||
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}",
|
||||
username,
|
||||
pool_name,
|
||||
application_name,
|
||||
err)
|
||||
)
|
||||
);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -570,20 +585,39 @@ where
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name);
|
||||
let fetched_hash = refetch_auth_hash(&pool).await?;
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name);
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -753,9 +787,9 @@ where
|
||||
&mut self.write,
|
||||
"terminating connection due to administrator command"
|
||||
).await?;
|
||||
self.stats.disconnect();
|
||||
|
||||
return Ok(())
|
||||
self.stats.disconnect();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Admin clients ignore shutdown.
|
||||
@@ -928,11 +962,26 @@ where
|
||||
error!("Got Sync message but failed to get a connection from the pool");
|
||||
self.buffer.clear();
|
||||
}
|
||||
|
||||
error_response(&mut self.write, "could not get connection from the pool")
|
||||
.await?;
|
||||
|
||||
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
|
||||
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
|
||||
error!(
|
||||
"Could not get connection from pool: \
|
||||
{{ \
|
||||
pool_name: {:?}, \
|
||||
username: {:?}, \
|
||||
shard: {:?}, \
|
||||
role: \"{:?}\", \
|
||||
error: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
self.username,
|
||||
query_router.shard(),
|
||||
query_router.role(),
|
||||
err
|
||||
);
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -999,11 +1048,25 @@ where
|
||||
Err(_) => {
|
||||
// Client idle in transaction timeout
|
||||
error_response(&mut self.write, "idle transaction timeout").await?;
|
||||
error!("Client idle in transaction timeout: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\"}}", self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role());
|
||||
error!(
|
||||
"Client idle in transaction timeout: \
|
||||
{{ \
|
||||
pool_name: {}, \
|
||||
username: {}, \
|
||||
shard: {}, \
|
||||
role: \"{:?}\" \
|
||||
}}",
|
||||
self.pool_name,
|
||||
self.username,
|
||||
query_router.shard(),
|
||||
query_router.role()
|
||||
);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
initial_message = None;
|
||||
message
|
||||
@@ -1076,6 +1139,11 @@ where
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Close the prepared statement.
|
||||
'C' => {
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Execute
|
||||
// Execute a prepared statement prepared in `P` and bound in `B`.
|
||||
'E' => {
|
||||
|
||||
144
src/config.rs
144
src/config.rs
@@ -178,7 +178,12 @@ impl Address {
|
||||
pub struct User {
|
||||
pub username: String,
|
||||
pub password: Option<String>,
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
pub min_pool_size: Option<u32>,
|
||||
pub pool_mode: Option<PoolMode>,
|
||||
pub server_lifetime: Option<u64>,
|
||||
#[serde(default)] // 0
|
||||
pub statement_timeout: u64,
|
||||
}
|
||||
@@ -188,12 +193,37 @@ impl Default for User {
|
||||
User {
|
||||
username: String::from("postgres"),
|
||||
password: None,
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
min_pool_size: None,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// General configuration.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct General {
|
||||
@@ -201,7 +231,7 @@ pub struct General {
|
||||
pub host: String,
|
||||
|
||||
#[serde(default = "General::default_port")]
|
||||
pub port: i16,
|
||||
pub port: u16,
|
||||
|
||||
pub enable_prometheus_exporter: Option<bool>,
|
||||
pub prometheus_exporter_port: i16,
|
||||
@@ -240,14 +270,24 @@ pub struct General {
|
||||
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
|
||||
pub idle_client_in_transaction_timeout: u64,
|
||||
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub autoreload: bool,
|
||||
#[serde(default)] // None
|
||||
pub autoreload: Option<u64>,
|
||||
|
||||
pub tls_certificate: Option<String>,
|
||||
pub tls_private_key: Option<String>,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub server_tls: bool,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub verify_server_certificate: bool,
|
||||
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
@@ -261,17 +301,21 @@ impl General {
|
||||
"0.0.0.0".into()
|
||||
}
|
||||
|
||||
pub fn default_port() -> i16 {
|
||||
pub fn default_port() -> u16 {
|
||||
5432
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
|
||||
// These keepalive defaults should detect a dead connection within 30 seconds.
|
||||
// Tokio defaults to disabling keepalives which keeps dead connections around indefinitely.
|
||||
// This can lead to permenant server pool exhaustion
|
||||
// This can lead to permanent server pool exhaustion
|
||||
pub fn default_tcp_keepalives_idle() -> u64 {
|
||||
5 // 5 seconds
|
||||
}
|
||||
@@ -333,14 +377,17 @@ impl Default for General {
|
||||
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: false,
|
||||
autoreload: None,
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,6 +403,7 @@ pub enum PoolMode {
|
||||
#[serde(alias = "session", alias = "Session")]
|
||||
Session,
|
||||
}
|
||||
|
||||
impl ToString for PoolMode {
|
||||
fn to_string(&self) -> String {
|
||||
match *self {
|
||||
@@ -404,6 +452,8 @@ pub struct Pool {
|
||||
|
||||
pub idle_timeout: Option<u64>,
|
||||
|
||||
pub server_lifetime: Option<u64>,
|
||||
|
||||
pub sharding_function: ShardingFunction,
|
||||
|
||||
#[serde(default = "Pool::default_automatic_sharding_key")]
|
||||
@@ -419,7 +469,7 @@ pub struct Pool {
|
||||
|
||||
pub shards: BTreeMap<String, Shard>,
|
||||
pub users: BTreeMap<String, User>,
|
||||
// Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it
|
||||
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
|
||||
// incompatible to have simple fields in TOML after complex objects. See
|
||||
// https://users.rust-lang.org/t/why-toml-to-string-get-error-valueaftertable/85903
|
||||
}
|
||||
@@ -508,6 +558,10 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -532,6 +586,7 @@ impl Default for Pool {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -581,7 +636,7 @@ impl Shard {
|
||||
|
||||
if primary_count > 1 {
|
||||
error!(
|
||||
"Shard {} has more than on primary configured",
|
||||
"Shard {} has more than one primary configured",
|
||||
self.database
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
@@ -784,6 +839,10 @@ impl Config {
|
||||
);
|
||||
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
|
||||
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
|
||||
info!(
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
@@ -802,6 +861,11 @@ impl Config {
|
||||
info!("TLS support is disabled");
|
||||
}
|
||||
};
|
||||
info!("Server TLS enabled: {}", self.general.server_tls);
|
||||
info!(
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
|
||||
for (pool_name, pool_config) in &self.pools {
|
||||
// TODO: Make this output prettier (maybe a table?)
|
||||
@@ -816,8 +880,9 @@ impl Config {
|
||||
.to_string()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Pool mode: {:?}",
|
||||
pool_name, pool_config.pool_mode
|
||||
"[pool: {}] Default pool mode: {}",
|
||||
pool_name,
|
||||
pool_config.pool_mode.to_string()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Load Balancing mode: {:?}",
|
||||
@@ -859,16 +924,48 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config.users.len()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
for user in &pool_config.users {
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool size: {}",
|
||||
pool_name, user.1.username, user.1.pool_size,
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Minimum pool size: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
user.1.min_pool_size.unwrap_or(0)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Statement timeout: {}",
|
||||
pool_name, user.1.username, user.1.statement_timeout
|
||||
)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool mode: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.pool_mode {
|
||||
Some(pool_mode) => pool_mode.to_string(),
|
||||
None => pool_config.pool_mode.to_string(),
|
||||
}
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -879,7 +976,13 @@ impl Config {
|
||||
&& (self.general.auth_query_user.is_none()
|
||||
|| self.general.auth_query_password.is_none())
|
||||
{
|
||||
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
|
||||
error!(
|
||||
"If auth_query is specified, \
|
||||
you need to provide a value \
|
||||
for `auth_query_user`, \
|
||||
`auth_query_password`"
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -887,7 +990,14 @@ impl Config {
|
||||
if pool.auth_query.is_some()
|
||||
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
|
||||
{
|
||||
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
If auth_query is specified, you need \
|
||||
to provide a value for `auth_query_user`, \
|
||||
`auth_query_password`",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -897,7 +1007,13 @@ impl Config {
|
||||
|| pool.auth_query_user.is_none())
|
||||
&& user_data.password.is_none()
|
||||
{
|
||||
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
You have to specify a user password \
|
||||
for every pool if auth_query is not specified",
|
||||
name
|
||||
);
|
||||
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
102
src/errors.rs
102
src/errors.rs
@@ -1,13 +1,19 @@
|
||||
/// Errors.
|
||||
//! Errors.
|
||||
|
||||
/// Various errors.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
SocketError(String),
|
||||
ClientSocketError(String, ClientIdentifier),
|
||||
ClientGeneralError(String, ClientIdentifier),
|
||||
ClientAuthImpossible(String),
|
||||
ClientAuthPassthroughError(String, ClientIdentifier),
|
||||
ClientBadStartup,
|
||||
ProtocolSyncError(String),
|
||||
BadQuery(String),
|
||||
ServerError,
|
||||
ServerStartupError(String, ServerIdentifier),
|
||||
ServerAuthError(String, ServerIdentifier),
|
||||
BadConfig,
|
||||
AllServersDown,
|
||||
ClientError(String),
|
||||
@@ -18,3 +24,97 @@ pub enum Error {
|
||||
AuthError(String),
|
||||
AuthPassthroughError(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct ClientIdentifier {
|
||||
pub application_name: String,
|
||||
pub username: String,
|
||||
pub pool_name: String,
|
||||
}
|
||||
|
||||
impl ClientIdentifier {
|
||||
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
|
||||
ClientIdentifier {
|
||||
application_name: application_name.into(),
|
||||
username: username.into(),
|
||||
pool_name: pool_name.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ClientIdentifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{{ application_name: {}, username: {}, pool_name: {} }}",
|
||||
self.application_name, self.username, self.pool_name
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
pub struct ServerIdentifier {
|
||||
pub username: String,
|
||||
pub database: String,
|
||||
}
|
||||
|
||||
impl ServerIdentifier {
|
||||
pub fn new(username: &str, database: &str) -> ServerIdentifier {
|
||||
ServerIdentifier {
|
||||
username: username.into(),
|
||||
database: database.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ServerIdentifier {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{{ username: {}, database: {} }}",
|
||||
self.username, self.database
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match &self {
|
||||
&Error::ClientSocketError(error, client_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} from client {}",
|
||||
error, client_identifier
|
||||
),
|
||||
&Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{} {}", error, client_identifier)
|
||||
}
|
||||
&Error::ClientAuthImpossible(username) => write!(
|
||||
f,
|
||||
"Client auth not possible, \
|
||||
no cleartext password set for username: {} \
|
||||
in config and auth passthrough (query_auth) \
|
||||
is not set up.",
|
||||
username
|
||||
),
|
||||
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
f,
|
||||
"No cleartext password set, \
|
||||
and no auth passthrough could not \
|
||||
obtain the hash from server for {}, \
|
||||
the error was: {}",
|
||||
client_identifier, error
|
||||
),
|
||||
&Error::ServerStartupError(error, server_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} on server startup {}",
|
||||
error, server_identifier,
|
||||
),
|
||||
&Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{} for {}", error, server_identifier,)
|
||||
}
|
||||
|
||||
// The rest can use Debug.
|
||||
err => write!(f, "{:?}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
34
src/main.rs
34
src/main.rs
@@ -79,6 +79,7 @@ mod stats;
|
||||
mod tls;
|
||||
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::messages::configure_socket;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::prometheus::start_metric_server;
|
||||
use crate::stats::{Collector, Reporter, REPORTER};
|
||||
@@ -179,16 +180,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
|
||||
info!("Config autoreloader: {}", config.general.autoreload);
|
||||
info!("Config autoreloader: {}", match config.general.autoreload {
|
||||
Some(interval) => format!("{} ms", interval),
|
||||
None => "disabled".into(),
|
||||
});
|
||||
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
if let Some(interval) = config.general.autoreload {
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(interval));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
if config.general.autoreload {
|
||||
info!("Automatically reloading config");
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
debug!("Automatically reloading config");
|
||||
|
||||
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
|
||||
if changed {
|
||||
@@ -196,8 +200,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
|
||||
#[cfg(windows)]
|
||||
let mut term_signal = win_signal::ctrl_close().unwrap();
|
||||
@@ -282,7 +288,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let drain_tx = drain_tx.clone();
|
||||
let client_server_map = client_server_map.clone();
|
||||
|
||||
let tls_certificate = config.general.tls_certificate.clone();
|
||||
let tls_certificate = get_config().general.tls_certificate.clone();
|
||||
|
||||
configure_socket(&socket);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
@@ -293,7 +301,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
shutdown_rx,
|
||||
drain_tx,
|
||||
admin_only,
|
||||
tls_certificate.clone(),
|
||||
tls_certificate,
|
||||
config.general.log_client_connections,
|
||||
)
|
||||
.await
|
||||
@@ -301,7 +309,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(()) => {
|
||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
if config.general.log_client_disconnections {
|
||||
if get_config().general.log_client_disconnections {
|
||||
info!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
|
||||
@@ -116,7 +116,10 @@ where
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
/// This tells the server which user we are and what database we want.
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
@@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(12);
|
||||
|
||||
bytes.put_i32(8);
|
||||
bytes.put_i32(80877103);
|
||||
|
||||
match stream.write_all(&bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing SSLRequest to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
@@ -404,7 +422,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut row_desc = BytesMut::new();
|
||||
|
||||
// how many colums we are storing
|
||||
// how many columns we are storing
|
||||
row_desc.put_i16(columns.len() as i16);
|
||||
|
||||
for (name, data_type) in columns {
|
||||
@@ -505,6 +523,29 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a complete message from the socket.
|
||||
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
||||
where
|
||||
|
||||
@@ -17,7 +17,7 @@ use log::{Level, Log, Metadata, Record, SetLoggerError};
|
||||
//
|
||||
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
|
||||
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, erros will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
//
|
||||
pub struct MultiLogger {
|
||||
stderr_logger: env_logger::Logger,
|
||||
|
||||
61
src/pool.rs
61
src/pool.rs
@@ -311,21 +311,34 @@ impl ConnectionPool {
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
|
||||
{
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!(
|
||||
"Hash is not the same across shards \
|
||||
of the same pool, client auth will \
|
||||
be done using last obtained hash. \
|
||||
Server: {}:{}, Database: {}",
|
||||
server.host, server.port, shard.database,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Could not obtain password hashes \
|
||||
using auth_query config, ignoring. \
|
||||
Error: {:?}",
|
||||
err,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
@@ -347,14 +360,23 @@ impl ConnectionPool {
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let server_lifetime = match user.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => config.general.server_lifetime,
|
||||
},
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.min_idle(user.min_pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
@@ -382,7 +404,10 @@ impl ConnectionPool {
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
pool_mode: match user.pool_mode {
|
||||
Some(pool_mode) => pool_mode,
|
||||
None => pool_config.pool_mode,
|
||||
},
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
/// Route queries automatically based on explicitely requested
|
||||
/// Route queries automatically based on explicitly requested
|
||||
/// or implied query characteristics.
|
||||
use bytes::{Buf, BytesMut};
|
||||
use log::{debug, error};
|
||||
|
||||
342
src/server.rs
342
src/server.rs
@@ -9,20 +9,97 @@ use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::scram::ScramSha256;
|
||||
use crate::stats::ServerStats;
|
||||
use std::io::Write;
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
#[pin_project(project = SteamInnerProj)]
|
||||
pub enum StreamInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: TcpStream,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: TlsStream<TcpStream>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsyncWrite for StreamInner {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for StreamInner {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamInner {
|
||||
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
StreamInner::Tls { stream } => {
|
||||
let r = stream.get_mut();
|
||||
let mut w = r.1.writer();
|
||||
w.write(buf)
|
||||
}
|
||||
StreamInner::Plain { stream } => stream.try_write(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
@@ -30,11 +107,8 @@ pub struct Server {
|
||||
/// port, e.g. 5432, and role, e.g. primary or replica.
|
||||
address: Address,
|
||||
|
||||
/// Buffered read socket.
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
|
||||
/// Unbuffered write socket (our client code buffers).
|
||||
write: OwnedWriteHalf,
|
||||
/// Server TCP connection.
|
||||
stream: BufStream<StreamInner>,
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
@@ -98,33 +172,137 @@ impl Server {
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// TCP timeouts.
|
||||
configure_socket(&stream);
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let mut stream = if config.general.server_tls {
|
||||
// Request a TLS connection
|
||||
ssl_request(&mut stream).await?;
|
||||
|
||||
let response = match stream.read_u8().await {
|
||||
Ok(response) => response as char,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Server socket error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match response {
|
||||
// Server supports TLS
|
||||
'S' => {
|
||||
debug!("Connecting to server using TLS");
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
let mut tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
// Equivalent to sslmode=prefer which is fine most places.
|
||||
// If you want verify-full, change `verify_server_certificate` to true.
|
||||
if !config.general.verify_server_certificate {
|
||||
let mut dangerous = tls_config.dangerous();
|
||||
dangerous.set_certificate_verifier(Arc::new(
|
||||
crate::tls::NoCertificateVerification {},
|
||||
));
|
||||
}
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(tls_config));
|
||||
let stream = match connector
|
||||
.connect(address.host.as_str().try_into().unwrap(), stream)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
|
||||
}
|
||||
};
|
||||
|
||||
StreamInner::Tls { stream }
|
||||
}
|
||||
|
||||
// Server does not support TLS
|
||||
'N' => StreamInner::Plain { stream },
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StreamInner::Plain { stream }
|
||||
};
|
||||
|
||||
// let (read, write) = split(stream);
|
||||
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
// StartupMessage
|
||||
startup(&mut stream, &user.username, database).await?;
|
||||
let username = match user.server_username {
|
||||
Some(ref server_username) => server_username,
|
||||
None => &user.username,
|
||||
};
|
||||
|
||||
let password = match user.server_password {
|
||||
Some(ref server_password) => Some(server_password),
|
||||
None => match user.password {
|
||||
Some(ref password) => Some(password),
|
||||
None => None,
|
||||
},
|
||||
};
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
|
||||
// We'll be handling multiple packets, but they will all be structured the same.
|
||||
// We'll loop here until this exchange is complete.
|
||||
let mut scram: Option<ScramSha256> = None;
|
||||
if let Some(password) = &user.password.clone() {
|
||||
scram = Some(ScramSha256::new(password));
|
||||
}
|
||||
let mut scram: Option<ScramSha256> = match password {
|
||||
Some(password) => Some(ScramSha256::new(password)),
|
||||
None => None,
|
||||
};
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"message code".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"message len".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Message: {}", code);
|
||||
@@ -135,7 +313,12 @@ impl Server {
|
||||
// Determine which kind of authentication is required, if any.
|
||||
let auth_code = match stream.read_i32().await {
|
||||
Ok(auth_code) => auth_code,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"auth code".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Auth: {}", auth_code);
|
||||
@@ -148,14 +331,18 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut salt).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"salt".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match &user.password {
|
||||
match password {
|
||||
// Using plaintext password
|
||||
Some(password) => {
|
||||
md5_password(&mut stream, &user.username, password, &salt[..])
|
||||
.await?
|
||||
md5_password(&mut stream, username, password, &salt[..]).await?
|
||||
}
|
||||
|
||||
// Using auth passthrough, in this case we should already have a
|
||||
@@ -171,8 +358,12 @@ impl Server {
|
||||
&salt[..],
|
||||
)
|
||||
.await?,
|
||||
None =>
|
||||
return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database)))
|
||||
None => return Err(
|
||||
Error::ServerAuthError(
|
||||
"Auth passthrough (auth_query) failed and no user password is set in cleartext".into(),
|
||||
server_identifier
|
||||
)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,21 +373,33 @@ impl Server {
|
||||
|
||||
SASL => {
|
||||
if scram.is_none() {
|
||||
return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database)));
|
||||
return Err(Error::ServerAuthError(
|
||||
"SASL auth required and no password specified. \
|
||||
Auth passthrough (auth_query) method is currently \
|
||||
unsupported for SASL auth"
|
||||
.into(),
|
||||
server_identifier,
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Starting SASL authentication");
|
||||
|
||||
let sasl_len = (len - 8) as usize;
|
||||
let mut sasl_auth = vec![0u8; sasl_len];
|
||||
|
||||
match stream.read_exact(&mut sasl_auth).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
if sasl_type.contains(SCRAM_SHA_256) {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Generate client message.
|
||||
@@ -219,7 +422,7 @@ impl Server {
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
} else {
|
||||
error!("Unsupported SCRAM version: {}", sasl_type);
|
||||
return Err(Error::ServerError);
|
||||
@@ -233,7 +436,12 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut sasl_data).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl cont message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let msg = BytesMut::from(&sasl_data[..]);
|
||||
@@ -245,7 +453,7 @@ impl Server {
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
}
|
||||
|
||||
SASL_FINAL => {
|
||||
@@ -254,7 +462,12 @@ impl Server {
|
||||
let mut sasl_final = vec![0u8; len as usize - 8];
|
||||
match stream.read_exact(&mut sasl_final).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"sasl final message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match scram
|
||||
@@ -284,7 +497,12 @@ impl Server {
|
||||
'E' => {
|
||||
let error_code = match stream.read_u8().await {
|
||||
Ok(error_code) => error_code,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"error code message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Error: {}", error_code);
|
||||
@@ -300,7 +518,12 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"error message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: the error message contains multiple fields; we can decode them and
|
||||
@@ -319,7 +542,12 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"parameter status message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
@@ -336,12 +564,22 @@ impl Server {
|
||||
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
|
||||
process_id = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"process id message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
secret_key = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"secret key message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -351,15 +589,17 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut idle).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
"transaction status message".into(),
|
||||
server_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
process_id,
|
||||
@@ -413,7 +653,7 @@ impl Server {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
error!("Could not connect to server: {}", err);
|
||||
return Err(Error::SocketError(format!("Error reading cancel message")));
|
||||
return Err(Error::SocketError("Error reading cancel message".into()));
|
||||
}
|
||||
};
|
||||
configure_socket(&stream);
|
||||
@@ -426,7 +666,7 @@ impl Server {
|
||||
bytes.put_i32(process_id);
|
||||
bytes.put_i32(secret_key);
|
||||
|
||||
write_all(&mut stream, bytes).await
|
||||
write_all_flush(&mut stream, &bytes).await
|
||||
}
|
||||
|
||||
/// Send messages to the server from the client.
|
||||
@@ -434,7 +674,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_half(&mut self.write, messages).await {
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -453,7 +693,7 @@ impl Server {
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.read).await {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
@@ -846,13 +1086,13 @@ impl Drop for Server {
|
||||
// Update statistics
|
||||
self.stats.disconnect();
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(4);
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
bytes.put_u8(b'X');
|
||||
bytes.put_i32(4);
|
||||
|
||||
match self.write.try_write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(_) => debug!("Dirty shutdown"),
|
||||
match self.stream.get_mut().try_write(&bytes) {
|
||||
Ok(5) => (),
|
||||
_ => debug!("Dirty shutdown"),
|
||||
};
|
||||
|
||||
// Should not matter.
|
||||
|
||||
@@ -66,7 +66,7 @@ impl Reporter {
|
||||
CLIENT_STATS.write().insert(client_id, stats);
|
||||
}
|
||||
|
||||
/// Reports a client is disconecting from the pooler.
|
||||
/// Reports a client is disconnecting from the pooler.
|
||||
fn client_disconnecting(&self, client_id: i32) {
|
||||
CLIENT_STATS.write().remove(&client_id);
|
||||
}
|
||||
@@ -76,7 +76,7 @@ impl Reporter {
|
||||
fn server_register(&self, server_id: i32, stats: Arc<ServerStats>) {
|
||||
SERVER_STATS.write().insert(server_id, stats);
|
||||
}
|
||||
/// Reports a server connection is disconecting from the pooler.
|
||||
/// Reports a server connection is disconnecting from the pooler.
|
||||
fn server_disconnecting(&self, server_id: i32) {
|
||||
SERVER_STATS.write().remove(&server_id);
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ impl ClientStats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports a client is disconecting from the pooler and
|
||||
/// Reports a client is disconnecting from the pooler and
|
||||
/// update metrics on the corresponding pool.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.client_disconnecting(self.client_id);
|
||||
@@ -140,7 +140,7 @@ impl ClientStats {
|
||||
self.error_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reportes the time spent by a client waiting to get a healthy connection from the pool
|
||||
/// Reporters the time spent by a client waiting to get a healthy connection from the pool
|
||||
pub fn checkout_time(&self, microseconds: u64) {
|
||||
self.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
|
||||
@@ -100,10 +100,9 @@ impl ServerStats {
|
||||
.server_idle(self.state.load(Ordering::Relaxed));
|
||||
|
||||
self.state.store(ServerState::Idle, Ordering::Relaxed);
|
||||
self.set_undefined_application();
|
||||
}
|
||||
|
||||
/// Reports a server connection is disconecting from the pooler.
|
||||
/// Reports a server connection is disconnecting from the pooler.
|
||||
/// Also updates metrics on the pool regarding server usage.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.server_disconnecting(self.server_id);
|
||||
|
||||
23
src/tls.rs
23
src/tls.rs
@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
|
||||
use std::iter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::rustls::{
|
||||
self,
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate, PrivateKey, ServerName,
|
||||
};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
@@ -64,3 +69,19 @@ impl Tls {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoCertificateVerification;
|
||||
|
||||
impl ServerCertVerifier for NoCertificateVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,9 +37,9 @@ describe "Admin" do
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_passsword_url = URI(pgcat_conn_str)
|
||||
bad_passsword_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_passsword_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
259
tests/ruby/helpers/pg_socket.rb
Normal file
259
tests/ruby/helpers/pg_socket.rb
Normal file
@@ -0,0 +1,259 @@
|
||||
require 'socket'
|
||||
require 'digest/md5'
|
||||
|
||||
BACKEND_MESSAGE_CODES = {
|
||||
'Z' => "ReadyForQuery",
|
||||
'C' => "CommandComplete",
|
||||
'T' => "RowDescription",
|
||||
'D' => "DataRow",
|
||||
'1' => "ParseComplete",
|
||||
'2' => "BindComplete",
|
||||
'E' => "ErrorResponse",
|
||||
's' => "PortalSuspended",
|
||||
}
|
||||
|
||||
class PostgresSocket
|
||||
def initialize(host, port)
|
||||
@port = port
|
||||
@host = host
|
||||
@socket = TCPSocket.new @host, @port
|
||||
@parameters = {}
|
||||
@verbose = true
|
||||
end
|
||||
|
||||
def send_md5_password_message(username, password, salt)
|
||||
m = Digest::MD5.hexdigest(password + username)
|
||||
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
|
||||
m = 'md5' + m
|
||||
bytes = (m.split("").map(&:ord) + [0]).flatten
|
||||
message_size = bytes.count + 4
|
||||
|
||||
message = []
|
||||
|
||||
message << 'p'.ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << bytes
|
||||
message.flatten!
|
||||
|
||||
|
||||
@socket.write(message.pack('C*'))
|
||||
end
|
||||
|
||||
def send_startup_message(username, database, password)
|
||||
message = []
|
||||
|
||||
message << [196608].pack('l>').unpack('CCCC') # 4
|
||||
message << "user".split('').map(&:ord) # 4, 8
|
||||
message << 0 # 1, 9
|
||||
message << username.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message << "database".split('').map(&:ord) # 8, 20
|
||||
message << 0 # 1, 21
|
||||
message << database.split('').map(&:ord) # 2, 23
|
||||
message << 0 # 1, 24
|
||||
message << 0 # 1, 25
|
||||
message.flatten!
|
||||
|
||||
total_message_size = message.size + 4
|
||||
|
||||
message_len = [total_message_size].pack('l>').unpack('CCCC')
|
||||
|
||||
@socket.write([message_len + message].flatten.pack('C*'))
|
||||
|
||||
sleep 0.1
|
||||
|
||||
read_startup_response(username, password)
|
||||
end
|
||||
|
||||
def read_startup_response(username, password)
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
while message_code == 'R'
|
||||
auth_code = @socket.recv(4).unpack('l>').pop
|
||||
case auth_code
|
||||
when 5 # md5
|
||||
salt = @socket.recv(4).unpack('CCCC')
|
||||
send_md5_password_message(username, password, salt)
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
when 0 # trust
|
||||
break
|
||||
end
|
||||
end
|
||||
loop do
|
||||
message_code, message_len = @socket.recv(5).unpack("al>")
|
||||
if message_code == 'Z'
|
||||
@socket.recv(1).unpack("a") # most likely I
|
||||
break # We are good to go
|
||||
end
|
||||
if message_code == 'S'
|
||||
actual_message = @socket.recv(message_len - 4).unpack("C*")
|
||||
k,v = actual_message.pack('U*').split(/\x00/)
|
||||
@parameters[k] = v
|
||||
end
|
||||
if message_code == 'K'
|
||||
process_id, secret_key = @socket.recv(message_len - 4).unpack("l>l>")
|
||||
@parameters["process_id"] = process_id
|
||||
@parameters["secret_key"] = secret_key
|
||||
end
|
||||
end
|
||||
return @parameters
|
||||
end
|
||||
|
||||
def cancel_query
|
||||
socket = TCPSocket.new @host, @port
|
||||
process_key = @parameters["process_id"]
|
||||
secret_key = @parameters["secret_key"]
|
||||
message = []
|
||||
message << [16].pack('l>').unpack('CCCC') # 4
|
||||
message << [80877102].pack('l>').unpack('CCCC') # 4
|
||||
message << [process_key.to_i].pack('l>').unpack('CCCC') # 4
|
||||
message << [secret_key.to_i].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
socket.write(message.flatten.pack('C*'))
|
||||
socket.close
|
||||
log "[F] Sent CancelRequest message"
|
||||
end
|
||||
|
||||
def send_query_message(query)
|
||||
query_size = query.length
|
||||
message_size = 1 + 4 + query_size
|
||||
message = []
|
||||
message << "Q".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent Q message (#{query})"
|
||||
end
|
||||
|
||||
def send_parse_message(query)
|
||||
query_size = query.length
|
||||
message_size = 2 + 2 + 4 + query_size
|
||||
message = []
|
||||
message << "P".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message << [0, 0]
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent P message (#{query})"
|
||||
end
|
||||
|
||||
def send_bind_message
|
||||
message = []
|
||||
message << "B".ord
|
||||
message << [12].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << 0 # unnamed statement
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent B message"
|
||||
end
|
||||
|
||||
def send_describe_message(mode)
|
||||
message = []
|
||||
message << "D".ord
|
||||
message << [6].pack('l>').unpack('CCCC') # 4
|
||||
message << mode.ord
|
||||
message << 0 # unnamed statement
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent D message"
|
||||
end
|
||||
|
||||
def send_execute_message(limit=0)
|
||||
message = []
|
||||
message << "E".ord
|
||||
message << [9].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << [limit].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent E message"
|
||||
end
|
||||
|
||||
def send_sync_message
|
||||
message = []
|
||||
message << "S".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent S message"
|
||||
end
|
||||
|
||||
def send_copydone_message
|
||||
message = []
|
||||
message << "c".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent c message"
|
||||
end
|
||||
|
||||
def send_copyfail_message
|
||||
message = []
|
||||
message << "f".ord
|
||||
message << [5].pack('l>').unpack('CCCC') # 4
|
||||
message << 0
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent f message"
|
||||
end
|
||||
|
||||
def send_flush_message
|
||||
message = []
|
||||
message << "H".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent H message"
|
||||
end
|
||||
|
||||
def read_from_server()
|
||||
output_messages = []
|
||||
retry_count = 0
|
||||
message_code = nil
|
||||
message_len = 0
|
||||
loop do
|
||||
begin
|
||||
message_code, message_len = @socket.recv_nonblock(5).unpack("al>")
|
||||
rescue IO::WaitReadable
|
||||
return output_messages if retry_count > 50
|
||||
|
||||
retry_count += 1
|
||||
sleep(0.01)
|
||||
next
|
||||
end
|
||||
message = {
|
||||
code: message_code,
|
||||
len: message_len,
|
||||
bytes: []
|
||||
}
|
||||
log "[B] #{BACKEND_MESSAGE_CODES[message_code] || ('UnknownMessage(' + message_code + ')')}"
|
||||
|
||||
actual_message_length = message_len - 4
|
||||
if actual_message_length > 0
|
||||
message[:bytes] = @socket.recv(message_len - 4).unpack("C*")
|
||||
log "\t#{message[:bytes].join(",")}"
|
||||
log "\t#{message[:bytes].map(&:chr).join(" ")}"
|
||||
end
|
||||
output_messages << message
|
||||
return output_messages if message_code == 'Z'
|
||||
end
|
||||
end
|
||||
|
||||
def log(msg)
|
||||
return unless @verbose
|
||||
|
||||
puts msg
|
||||
end
|
||||
|
||||
def close
|
||||
@socket.close
|
||||
end
|
||||
end
|
||||
@@ -2,6 +2,7 @@ require 'json'
|
||||
require 'ostruct'
|
||||
require_relative 'pgcat_process'
|
||||
require_relative 'pg_instance'
|
||||
require_relative 'pg_socket'
|
||||
|
||||
class ::Hash
|
||||
def deep_merge(second)
|
||||
|
||||
@@ -65,7 +65,7 @@ describe "Least Outstanding Queries Load Balancing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
context "under homogenous load" do
|
||||
context "under homogeneous load" do
|
||||
it "balances query volume between all instances" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ describe "Query Mirroing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
it "can mirror a query" do
|
||||
xit "can mirror a query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
runs = 15
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
155
tests/ruby/protocol_spec.rb
Normal file
155
tests/ruby/protocol_spec.rb
Normal file
@@ -0,0 +1,155 @@
|
||||
# frozen_string_literal: true
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "Portocol handling" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "session") }
|
||||
let(:sequence) { [] }
|
||||
let(:pgcat_socket) { PostgresSocket.new('localhost', processes.pgcat.port) }
|
||||
let(:pgdb_socket) { PostgresSocket.new('localhost', processes.all_databases.first.port) }
|
||||
|
||||
after do
|
||||
pgdb_socket.close
|
||||
pgcat_socket.close
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
def run_comparison(sequence, socket_a, socket_b)
|
||||
sequence.each do |msg, *args|
|
||||
socket_a.send(msg, *args)
|
||||
socket_b.send(msg, *args)
|
||||
|
||||
compare_messages(
|
||||
socket_a.read_from_server,
|
||||
socket_b.read_from_server
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def compare_messages(msg_arr0, msg_arr1)
|
||||
if msg_arr0.count != msg_arr1.count
|
||||
error_output = []
|
||||
|
||||
error_output << "#{msg_arr0.count} : #{msg_arr1.count}"
|
||||
error_output << "PgCat Messages"
|
||||
error_output += msg_arr0.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
|
||||
error_output << "PgServer Messages"
|
||||
error_output += msg_arr1.map { |message| "\t#{message[:code]} - #{message[:bytes].map(&:chr).join(" ")}" }
|
||||
error_desc = error_output.join("\n")
|
||||
raise StandardError, "Message count mismatch #{error_desc}"
|
||||
end
|
||||
|
||||
(0..msg_arr0.count - 1).all? do |i|
|
||||
msg0 = msg_arr0[i]
|
||||
msg1 = msg_arr1[i]
|
||||
|
||||
result = [
|
||||
msg0[:code] == msg1[:code],
|
||||
msg0[:len] == msg1[:len],
|
||||
msg0[:bytes] == msg1[:bytes],
|
||||
].all?
|
||||
|
||||
next result if result
|
||||
|
||||
if result == false
|
||||
error_string = []
|
||||
if msg0[:code] != msg1[:code]
|
||||
error_string << "code #{msg0[:code]} != #{msg1[:code]}"
|
||||
end
|
||||
if msg0[:len] != msg1[:len]
|
||||
error_string << "len #{msg0[:len]} != #{msg1[:len]}"
|
||||
end
|
||||
if msg0[:bytes] != msg1[:bytes]
|
||||
error_string << "bytes #{msg0[:bytes]} != #{msg1[:bytes]}"
|
||||
end
|
||||
err = error_string.join("\n")
|
||||
|
||||
raise StandardError, "Message mismatch #{err}"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.shared_examples "at parity with database" do
|
||||
before do
|
||||
pgcat_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
pgdb_socket.send_startup_message("sharding_user", "shard0", "sharding_user")
|
||||
end
|
||||
|
||||
it "works" do
|
||||
run_comparison(sequence, pgcat_socket, pgdb_socket)
|
||||
end
|
||||
end
|
||||
|
||||
context "Cancel Query" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_query_message, "SELECT pg_sleep(5)"],
|
||||
[:cancel_query]
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Simple query after parse" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 5"],
|
||||
[:send_query_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
# Known to fail due to PgCat not supporting flush
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Flush message" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_flush_message]
|
||||
]
|
||||
}
|
||||
|
||||
# Known to fail due to PgCat not supporting flush
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
xcontext "Bind without parse" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_bind_message]
|
||||
]
|
||||
}
|
||||
# This is known to fail.
|
||||
# Server responds immediately, Proxy buffers the message
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
context "Simple message" do
|
||||
let(:sequence) {
|
||||
[[:send_query_message, "SELECT 1"]]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
context "Extended protocol" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
end
|
||||
@@ -27,7 +27,7 @@ describe "Sharding" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "automatic routing of extended procotol" do
|
||||
describe "automatic routing of extended protocol" do
|
||||
it "can do it" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.exec("SET SERVER ROLE TO 'auto'")
|
||||
|
||||
1
utilities/requirements.txt
Normal file
1
utilities/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
tomli
|
||||
Reference in New Issue
Block a user