Compare commits

...

18 Commits

Author SHA1 Message Date
Lev Kokotov
5872354c3e remove debug 2022-08-28 17:29:13 -07:00
Lev Kokotov
48bb6ebeef Support setting a custom search path 2022-08-28 17:23:28 -07:00
Mostafa Abdelraouf
3bc4f9351c Exit with failure codes if configs are bad (#146)
* Exit with failure codes if configs are bad

* fmt
2022-08-25 18:56:18 -07:00
Lev Kokotov
9d84d6f131 Graceful shutdown and refactor (#144)
* Graceful shutdown and refactor

* ok

* _Graceful_ shutdown

* Remove hardcoded setting

* clean up

* end

* timeout

* hmm

* hmm!

* bash

* bash

* hmm

* maybe maybe

* Adds tests and move non-admin connection rejection to startup (#145)

* Move error response

* Adds tests and removes unused variable

* Adds debug log

Co-authored-by: zainkabani <77307340+zainkabani@users.noreply.github.com>
2022-08-25 06:40:56 -07:00
Mostafa Abdelraouf
c054ff068d Avoid sending Z packet in the middle of extended protocol packet sequence if we fail to get connection from pool (#137)
* Failing test

* maybe

* try fail

* try

* add message

* pool size

* correct user

* more

* debug

* try fix

* see stdout

* stick?

* fix configs

* modify

* types

* m

* maybe

* make tests idempotent

* hopefully fails

* Add client fix

* revert pgcat.toml change

* Fix tests
2022-08-23 11:02:23 -07:00
Lev Kokotov
5a0cea6a24 Really fix idle servers (#141) 2022-08-22 11:56:40 -07:00
Lev Kokotov
d0e8171b1b Fix too many idle servers (#140)
* Fix too many idle servers

* oops
2022-08-22 11:52:34 -07:00
Lev Kokotov
069d76029f Fix incorrect routing for replicas (#139)
* Fix incorrect routing for replicas

* name
2022-08-21 22:40:49 -07:00
Lev Kokotov
902fafd8d7 Random lb (#138) 2022-08-21 22:20:31 -07:00
Mostafa Abdelraouf
5f5b5e2543 Random instance selection (#136)
* wip

* revert some'

* revert more

* poor-man's integration test

* remove test

* fmt

* --workspace

* fix build

* fix integration test

* another stab

* log

* run after integration

* cargo test after integration

* revert

* revert more

* Refactor + clean up

* more clean up
2022-08-21 22:15:20 -07:00
zainkabani
5948fef6cf Minor Refactoring of re-used code and server stat reporting (#129)
* Minor changes to stats reporting and reduce re-used code

* fmt
2022-08-18 05:12:38 -07:00
Mostafa Abdelraouf
790898c20e Add pool name and username to address object (#128)
* Add pool name and username to address object

* Fix address name

* fmt
2022-08-17 08:40:47 -07:00
Pradeep Chhetri
d64f6793c1 Minor cleanup in admin command (#126)
* Minor cleanup in admin command

* Typo correction

* fix when the admin query is ending with semicolon
2022-08-16 10:01:46 -07:00
Lev Kokotov
cea35db35c Fix lost statistics (#125)
* Lost events

* more logging
2022-08-15 23:54:49 -07:00
Mostafa Abdelraouf
a3aefabb47 Add cl_idle to SHOW POOLS (#124) 2022-08-15 20:51:37 -07:00
Lev Kokotov
3285006440 Statement timeout + replica imbalance fix (#122)
* Statement timeout

* send error message too

* Correct error messages

* Fix replica imbalance

* disable stmt timeout by default

* Redundant mark_bad

* revert healthcheck delay

* tests

* set it to 0

* reload config again
2022-08-13 13:45:58 -07:00
Pradeep Chhetri
52303cc808 Make prometheus port configurable (#121)
* Make prometheus port configurable

* Update circleci config
2022-08-13 10:25:14 -07:00
Lev Kokotov
be254cedd9 Fix debug log (#120) 2022-08-11 22:47:47 -07:00
21 changed files with 1069 additions and 637 deletions

View File

@@ -11,9 +11,12 @@ host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example. # Port to run on, same as PgBouncer used in this example.
port = 6432 port = 6432
# enable prometheus exporter on port 9930 # Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true enable_prometheus_exporter = true
# Port on which the prometheus exporter listens.
prometheus_exporter_port = 9930
# How long to wait before aborting a server connection (ms). # How long to wait before aborting a server connection (ms).
connect_timeout = 100 connect_timeout = 100
@@ -88,11 +91,13 @@ password = "sharding_user"
# The maximum number of connection from a single Pgcat process to any database in the cluster # The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users. # is the sum of pool_size across all users.
pool_size = 9 pool_size = 9
statement_timeout = 0
[pools.sharded_db.users.1] [pools.sharded_db.users.1]
username = "other_user" username = "other_user"
password = "other_user" password = "other_user"
pool_size = 21 pool_size = 21
statement_timeout = 30000
# Shard 0 # Shard 0
[pools.sharded_db.shards.0] [pools.sharded_db.shards.0]
@@ -103,6 +108,7 @@ servers = [
] ]
# Database name (e.g. "postgres") # Database name (e.g. "postgres")
database = "shard0" database = "shard0"
search_path = "\"$user\",public"
[pools.sharded_db.shards.1] [pools.sharded_db.shards.1]
servers = [ servers = [
@@ -130,6 +136,7 @@ sharding_function = "pg_bigint_hash"
username = "simple_user" username = "simple_user"
password = "simple_user" password = "simple_user"
pool_size = 5 pool_size = 5
statement_timeout = 30000
[pools.simple_db.shards.0] [pools.simple_db.shards.0]
servers = [ servers = [

View File

@@ -3,6 +3,9 @@
set -e set -e
set -o xtrace set -o xtrace
# non-zero exit code if we provide bad configs
(! ./target/debug/pgcat "fake_configs" 2>/dev/null)
# Start PgCat with a particular log level # Start PgCat with a particular log level
# for inspection. # for inspection.
function start_pgcat() { function start_pgcat() {
@@ -19,8 +22,8 @@ PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
# Install Toxiproxy to simulate a downed/slow database # Install Toxiproxy to simulate a downed/slow database
wget -O toxiproxy-2.1.4.deb https://github.com/Shopify/toxiproxy/releases/download/v2.1.4/toxiproxy_2.1.4_amd64.deb wget -O toxiproxy-2.4.0.deb https://github.com/Shopify/toxiproxy/releases/download/v2.4.0/toxiproxy_2.4.0_linux_$(dpkg --print-architecture).deb
sudo dpkg -i toxiproxy-2.1.4.deb sudo dpkg -i toxiproxy-2.4.0.deb
# Start Toxiproxy # Start Toxiproxy
toxiproxy-server & toxiproxy-server &
@@ -66,6 +69,18 @@ psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_te
# Replica/primary selection & more sharding tests # Replica/primary selection & more sharding tests
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
# Statement timeout tests
sed -i 's/statement_timeout = 0/statement_timeout = 100/' .circleci/pgcat.toml
kill -SIGHUP $(pgrep pgcat) # Reload config
sleep 0.2
# This should timeout
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -c 'select pg_sleep(0.5)')
# Disable statement timeout
sed -i 's/statement_timeout = 100/statement_timeout = 0/' .circleci/pgcat.toml
kill -SIGHUP $(pgrep pgcat) # Reload config again
# #
# ActiveRecord tests # ActiveRecord tests
# #
@@ -117,11 +132,14 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica
start_pgcat "info" start_pgcat "info"
# Test session mode (and config reload) # Test session mode (and config reload)
sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml sed -i '0,/simple_db/s/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml
# Reload config test # Reload config test
kill -SIGHUP $(pgrep pgcat) kill -SIGHUP $(pgrep pgcat)
# Revert settings after reload. Makes test runs idempotent
sed -i '0,/simple_db/s/pool_mode = "session"/pool_mode = "transaction"/' .circleci/pgcat.toml
sleep 1 sleep 1
# Prepared statements that will only work in session mode # Prepared statements that will only work in session mode

7
Cargo.lock generated
View File

@@ -159,6 +159,12 @@ dependencies = [
"termcolor", "termcolor",
] ]
[[package]]
name = "exitcode"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193"
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@@ -515,6 +521,7 @@ dependencies = [
"bytes", "bytes",
"chrono", "chrono",
"env_logger", "env_logger",
"exitcode",
"hmac", "hmac",
"hyper", "hyper",
"log", "log",

View File

@@ -33,3 +33,4 @@ tokio-rustls = "0.23"
rustls-pemfile = "1" rustls-pemfile = "1"
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.10", features = ["macros"] } phf = { version = "0.10", features = ["macros"] }
exitcode = "1.1.2"

View File

@@ -15,7 +15,7 @@ PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover su
| Session pooling | :white_check_mark: | Identical to PgBouncer. | | Session pooling | :white_check_mark: | Identical to PgBouncer. |
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. | | `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. | | Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). | | Load balancing of read queries | :white_check_mark: | Using random selection between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. | | Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. | | Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. | | Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
@@ -38,30 +38,34 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1'
### Config ### Config
| **Name** | **Description** | **Examples** | | **Name** | **Description** | **Examples** |
|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------| |------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------|
| **`general`** | | | | **`general`** | | |
| `host` | The pooler will run on this host, 0.0.0.0 means accessible from everywhere. | `0.0.0.0` | | `host` | The pooler will run on this host, 0.0.0.0 means accessible from everywhere. | `0.0.0.0` |
| `port` | The pooler will run on this port. | `6432` | | `port` | The pooler will run on this port. | `6432` |
| `pool_size` | Maximum allowed server connections per pool. Pools are separated for each user/shard/server role. The connections are allocated as needed. | `15` | | `enable_prometheus_exporter` | Enable prometheus exporter which will export metrics in prometheus exposition format. | `true` |
| `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` | | `prometheus_exporter_port` | Port on which the prometheus exporter listens. | `9930` |
| `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` | | `pool_size` | Maximum allowed server connections per pool. Pools are separated for each user/shard/server role. The connections are allocated as needed. | `15` |
| `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` | | `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` |
| `shutdown_timeout` | Maximum time to give clients during shutdown before forcibly killing client connections (ms). | `60000` | | `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` |
| `healthcheck_delay` | How long to keep connection available for immediate re-use, without running a healthcheck query on it | `30000` | | `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` |
| `ban_time` | Ban time for a server (seconds). It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` | | `shutdown_timeout` | Maximum time to give clients during shutdown before forcibly killing client connections (ms). | `60000` |
| | | | | `healthcheck_delay` | How long to keep connection available for immediate re-use, without running a healthcheck query on it | `30000` |
| **`user`** | | | | `ban_time` | Ban time for a server (seconds). It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` |
| `name` | The user name. | `sharding_user` | | `autoreload` | Enable auto-reload of config after fixed time-interval. | `false` |
| `password` | The user password in plaintext. | `hunter2` | | | | |
| | | | | **`user`** | | |
| **`shards`** | Shards are numerically numbered starting from 0; the order in the config is preserved by the pooler to route queries accordingly. | `[shards.0]` | | `name` | The user name. | `sharding_user` |
| `servers` | List of servers to connect to and their roles. A server is: `[host, port, role]`, where `role` is either `primary` or `replica`. | `["127.0.0.1", 5432, "primary"]` | | `password` | The user password in plaintext. | `hunter2` |
| `database` | The name of the database to connect to. This is the same on all servers that are part of one shard. | | | | | |
| **`query_router`** | | | | **`shards`** | Shards are numerically numbered starting from 0; the order in the config is preserved by the pooler to route queries accordingly. | `[shards.0]` |
| `default_role` | Traffic is routed to this role by default (round-robin), unless the client specifies otherwise. Default is `any`, for any role available. | `any`, `primary`, `replica` | | `servers` | List of servers to connect to and their roles. A server is: `[host, port, role]`, where `role` is either `primary` or `replica`. | `["127.0.0.1", 5432, "primary"]` |
| `query_parser_enabled` | Enable the query parser which will inspect incoming queries and route them to a primary or replicas. | `false` | | `database` | The name of the database to connect to. This is the same on all servers that are part of one shard. | |
| `primary_reads_enabled` | Enable this to allow read queries on the primary; otherwise read queries are routed to the replicas. | `true` | | | | |
| **`query_router`** | | |
| `default_role` | Traffic is routed to this role by default (random), unless the client specifies otherwise. Default is `any`, for any role available. | `any`, `primary`, `replica` |
| `query_parser_enabled` | Enable the query parser which will inspect incoming queries and route them to a primary or replicas. | `false` |
| `primary_reads_enabled` | Enable this to allow read queries on the primary; otherwise read queries are routed to the replicas. | `true` |
## Local development ## Local development
@@ -108,7 +112,7 @@ In transaction mode, a client talks to one server for the duration of a single t
This mode is enabled by default. This mode is enabled by default.
### Load balancing of read queries ### Load balancing of read queries
All queries are load balanced against the configured servers using the round-robin algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries. All queries are load balanced against the configured servers using the random algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary. If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
@@ -147,18 +151,18 @@ Failover behavior can get pretty interesting (read complex) when multiple config
| **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** | | **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** |
|---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------| |---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the round-robin loop. | | Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the random loop. |
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the round-robin loop. | | Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the random loop. |
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the round-robin loop. | | Read query | unset (any) | true | true | up | Query is routed to the first instance in the random loop. |
| Read query | replica | false | false | up | Query is routed to the first replica instance in the round-robin loop. | | Read query | replica | false | false | up | Query is routed to the first replica instance in the random loop. |
| Read query | primary | false | false | up | Query is routed to the primary. | | Read query | primary | false | false | up | Query is routed to the primary. |
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the round-robin loop is attempted. | | Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the random loop is attempted. |
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. | | Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the round-robin loop. | | Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the random loop. |
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. | | Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. | | Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. |
| | | | | | | | | | | | | |
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the round-robin loop. If the instance is a replica, the query fails and the client receives an error. | | Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the random loop. If the instance is a replica, the query fails and the client receives an error. |
| Write query | unset (any) | true | false | up | The query is routed to the primary. | | Write query | unset (any) | true | false | up | The query is routed to the primary. |
| Write query | unset (any) | true | true | up | The query is routed to the primary. | | Write query | unset (any) | true | true | up | The query is routed to the primary. |
| Write query | primary | false | false | up | The query is routed to the primary. | | Write query | primary | false | false | up | The query is routed to the primary. |

View File

@@ -1,7 +1,7 @@
version: "3" version: "3"
services: services:
postgres: postgres:
image: postgres:13 image: postgres:14
environment: environment:
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_HOST_AUTH_METHOD: md5 POSTGRES_HOST_AUTH_METHOD: md5

View File

@@ -11,9 +11,12 @@ host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example. # Port to run on, same as PgBouncer used in this example.
port = 6432 port = 6432
# enable prometheus exporter on port 9930 # Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true enable_prometheus_exporter = true
# Port on which the prometheus exporter listens.
prometheus_exporter_port = 9930
# How long to wait before aborting a server connection (ms). # How long to wait before aborting a server connection (ms).
connect_timeout = 5000 connect_timeout = 5000
@@ -89,10 +92,14 @@ password = "postgres"
# is the sum of pool_size across all users. # is the sum of pool_size across all users.
pool_size = 9 pool_size = 9
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
statement_timeout = 0
[pools.sharded.users.1] [pools.sharded.users.1]
username = "postgres" username = "postgres"
password = "postgres" password = "postgres"
pool_size = 21 pool_size = 21
statement_timeout = 15000
# Shard 0 # Shard 0
[pools.sharded.shards.0] [pools.sharded.shards.0]
@@ -130,6 +137,7 @@ sharding_function = "pg_bigint_hash"
username = "postgres" username = "postgres"
password = "postgres" password = "postgres"
pool_size = 5 pool_size = 5
statement_timeout = 0
[pools.simple_db.shards.0] [pools.simple_db.shards.0]
servers = [ servers = [

View File

@@ -11,9 +11,12 @@ host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example. # Port to run on, same as PgBouncer used in this example.
port = 6432 port = 6432
# enable prometheus exporter on port 9930 # Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true enable_prometheus_exporter = true
# Port on which the prometheus exporter listens.
prometheus_exporter_port = 9930
# How long to wait before aborting a server connection (ms). # How long to wait before aborting a server connection (ms).
connect_timeout = 5000 connect_timeout = 5000
@@ -89,10 +92,14 @@ password = "sharding_user"
# is the sum of pool_size across all users. # is the sum of pool_size across all users.
pool_size = 9 pool_size = 9
# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
statement_timeout = 0
[pools.sharded_db.users.1] [pools.sharded_db.users.1]
username = "other_user" username = "other_user"
password = "other_user" password = "other_user"
pool_size = 21 pool_size = 21
statement_timeout = 15000
# Shard 0 # Shard 0
[pools.sharded_db.shards.0] [pools.sharded_db.shards.0]
@@ -130,6 +137,7 @@ sharding_function = "pg_bigint_hash"
username = "simple_user" username = "simple_user"
password = "simple_user" password = "simple_user"
pool_size = 5 pool_size = 5
statement_timeout = 0
[pools.simple_db.shards.0] [pools.simple_db.shards.0]
servers = [ servers = [

View File

@@ -44,32 +44,45 @@ where
trace!("Admin query: {}", query); trace!("Admin query: {}", query);
if query.starts_with("SHOW STATS") { let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();
trace!("SHOW STATS");
show_stats(stream).await match query_parts[0] {
} else if query.starts_with("RELOAD") { "RELOAD" => {
trace!("RELOAD"); trace!("RELOAD");
reload(stream, client_server_map).await reload(stream, client_server_map).await
} else if query.starts_with("SHOW CONFIG") { }
trace!("SHOW CONFIG"); "SET" => {
show_config(stream).await trace!("SET");
} else if query.starts_with("SHOW DATABASES") { ignore_set(stream).await
trace!("SHOW DATABASES"); }
show_databases(stream).await "SHOW" => match query_parts[1] {
} else if query.starts_with("SHOW POOLS") { "CONFIG" => {
trace!("SHOW POOLS"); trace!("SHOW CONFIG");
show_pools(stream).await show_config(stream).await
} else if query.starts_with("SHOW LISTS") { }
trace!("SHOW LISTS"); "DATABASES" => {
show_lists(stream).await trace!("SHOW DATABASES");
} else if query.starts_with("SHOW VERSION") { show_databases(stream).await
trace!("SHOW VERSION"); }
show_version(stream).await "LISTS" => {
} else if query.starts_with("SET ") { trace!("SHOW LISTS");
trace!("SET"); show_lists(stream).await
ignore_set(stream).await }
} else { "POOLS" => {
error_response(stream, "Unsupported query against the admin database").await trace!("SHOW POOLS");
show_pools(stream).await
}
"STATS" => {
trace!("SHOW STATS");
show_stats(stream).await
}
"VERSION" => {
trace!("SHOW VERSION");
show_version(stream).await
}
_ => error_response(stream, "Unsupported SHOW query against the admin database").await,
},
_ => error_response(stream, "Unsupported query against the admin database").await,
} }
} }
@@ -174,6 +187,7 @@ where
let columns = vec![ let columns = vec![
("database", DataType::Text), ("database", DataType::Text),
("user", DataType::Text), ("user", DataType::Text),
("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric), ("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric), ("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric), ("cl_cancel_req", DataType::Numeric),
@@ -251,11 +265,11 @@ where
for (_, pool) in get_all_pools() { for (_, pool) in get_all_pools() {
let pool_config = pool.settings.clone(); let pool_config = pool.settings.clone();
for shard in 0..pool.shards() { for shard in 0..pool.shards() {
let database_name = &pool_config.shards[&shard.to_string()].database; let database_name = &pool.address(shard, 0).database;
for server in 0..pool.servers(shard) { for server in 0..pool.servers(shard) {
let address = pool.address(shard, server); let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server); let pool_state = pool.pool_state(shard, server);
let banned = pool.is_banned(address, shard, Some(address.role)); let banned = pool.is_banned(address, Some(address.role));
res.put(data_row(&vec![ res.put(data_row(&vec![
address.name(), // name address.name(), // name

View File

@@ -5,13 +5,14 @@ use std::collections::HashMap;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver; use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address}; use crate::config::{get_config, Address};
use crate::constants::*; use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::Server;
use crate::stats::{get_reporter, Reporter}; use crate::stats::{get_reporter, Reporter};
@@ -58,7 +59,6 @@ pub struct Client<S, T> {
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
/// Client parameters, e.g. user, client_encoding, etc. /// Client parameters, e.g. user, client_encoding, etc.
#[allow(dead_code)]
parameters: HashMap<String, String>, parameters: HashMap<String, String>,
/// Statistics /// Statistics
@@ -73,21 +73,26 @@ pub struct Client<S, T> {
/// Last server process id we talked to. /// Last server process id we talked to.
last_server_id: Option<i32>, last_server_id: Option<i32>,
/// Connected to server
connected_to_server: bool,
/// Name of the server pool for this client (This comes from the database name in the connection string) /// Name of the server pool for this client (This comes from the database name in the connection string)
target_pool_name: String, pool_name: String,
/// Postgres user for this client (This comes from the user in the connection string) /// Postgres user for this client (This comes from the user in the connection string)
target_user_name: String, username: String,
/// Used to notify clients about an impending shutdown /// Used to notify clients about an impending shutdown
shutdown_event_receiver: Receiver<()>, shutdown: Receiver<()>,
} }
/// Client entrypoint. /// Client entrypoint.
pub async fn client_entrypoint( pub async fn client_entrypoint(
mut stream: TcpStream, mut stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>, shutdown: Receiver<()>,
drain: Sender<i8>,
admin_only: bool,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Figure out if the client wants TLS or not. // Figure out if the client wants TLS or not.
let addr = stream.peer_addr().unwrap(); let addr = stream.peer_addr().unwrap();
@@ -106,11 +111,21 @@ pub async fn client_entrypoint(
write_all(&mut stream, yes).await?; write_all(&mut stream, yes).await?;
// Negotiate TLS. // Negotiate TLS.
match startup_tls(stream, client_server_map, shutdown_event_receiver).await { match startup_tls(stream, client_server_map, shutdown, admin_only).await {
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected (TLS)", addr); info!("Client {:?} connected (TLS)", addr);
client.handle().await if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
} }
Err(err) => Err(err), Err(err) => Err(err),
} }
@@ -136,14 +151,25 @@ pub async fn client_entrypoint(
addr, addr,
bytes, bytes,
client_server_map, client_server_map,
shutdown_event_receiver, shutdown,
admin_only,
) )
.await .await
{ {
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected (plain)", addr); info!("Client {:?} connected (plain)", addr);
client.handle().await if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
} }
Err(err) => Err(err), Err(err) => Err(err),
} }
@@ -166,14 +192,25 @@ pub async fn client_entrypoint(
addr, addr,
bytes, bytes,
client_server_map, client_server_map,
shutdown_event_receiver, shutdown,
admin_only,
) )
.await .await
{ {
Ok(mut client) => {
    info!("Client {:?} connected (plain)", addr);

    // Announce a non-admin client on `drain` before serving it.
    // The guard must be the same as the one after `handle()` (and the
    // same as in the TLS branch), otherwise the +1/-1 messages sent
    // over `drain` do not balance. Presumably the receiver keeps a
    // running count of connected non-admin clients -- confirm there.
    if !client.is_admin() {
        let _ = drain.send(1).await;
    }

    let result = client.handle().await;

    // The client is done; send the matching -1.
    if !client.is_admin() {
        let _ = drain.send(-1).await;
    }

    result
}
Err(err) => Err(err), Err(err) => Err(err),
} }
@@ -184,20 +221,21 @@ pub async fn client_entrypoint(
let (read, write) = split(stream); let (read, write) = split(stream);
// Continue with cancel query request. // Continue with cancel query request.
match Client::cancel( match Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await {
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
    info!("Client {:?} issued a cancel query request", addr);

    // `Client::cancel` builds the client with `admin: false`, so with
    // the previous `if client.is_admin()` guard the +1 was never sent
    // while the -1 below always was. Use the same non-admin guard on
    // both sides so the messages on `drain` balance.
    if !client.is_admin() {
        let _ = drain.send(1).await;
    }

    let result = client.handle().await;

    // The cancel request has been handled; send the matching -1.
    if !client.is_admin() {
        let _ = drain.send(-1).await;
    }

    result
}
Err(err) => Err(err), Err(err) => Err(err),
@@ -250,7 +288,8 @@ where
pub async fn startup_tls( pub async fn startup_tls(
stream: TcpStream, stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>, shutdown: Receiver<()>,
admin_only: bool,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> { ) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Negotiate TLS. // Negotiate TLS.
let tls = Tls::new()?; let tls = Tls::new()?;
@@ -280,7 +319,8 @@ pub async fn startup_tls(
addr, addr,
bytes, bytes,
client_server_map, client_server_map,
shutdown_event_receiver, shutdown,
admin_only,
) )
.await .await
} }
@@ -295,6 +335,10 @@ where
S: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
/// Returns the client's `admin` flag.
///
/// `startup` sets the flag when the requested database is `pgcat` or
/// `pgbouncer`; `cancel` always sets it to `false`.
pub fn is_admin(&self) -> bool {
    self.admin
}
/// Handle Postgres client startup after TLS negotiation is complete /// Handle Postgres client startup after TLS negotiation is complete
/// or over plain text. /// or over plain text.
pub async fn startup( pub async fn startup(
@@ -303,29 +347,44 @@ where
addr: std::net::SocketAddr, addr: std::net::SocketAddr,
bytes: BytesMut, // The rest of the startup message. bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>, shutdown: Receiver<()>,
admin_only: bool,
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let config = get_config(); let config = get_config();
let stats = get_reporter(); let stats = get_reporter();
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?; let parameters = parse_startup(bytes.clone())?;
let target_pool_name = match parameters.get("database") {
// These two parameters are mandatory by the protocol.
let pool_name = match parameters.get("database") {
Some(db) => db, Some(db) => db,
None => return Err(Error::ClientError), None => return Err(Error::ClientError),
}; };
let target_user_name = match parameters.get("user") { let username = match parameters.get("user") {
Some(user) => user, Some(user) => user,
None => return Err(Error::ClientError), None => return Err(Error::ClientError),
}; };
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
.filter(|db| *db == &target_pool_name) .filter(|db| *db == &pool_name)
.count() .count()
== 1; == 1;
// Kick any client that's not admin while we're in admin-only mode.
if !admin && admin_only {
debug!(
"Rejecting non-admin connection to {} when in admin only mode",
pool_name
);
error_response_terminal(
&mut write,
&format!("terminating connection due to administrator command"),
)
.await?;
return Err(Error::ShuttingDown);
}
// Generate random backend ID and secret key // Generate random backend ID and secret key
let process_id: i32 = rand::random(); let process_id: i32 = rand::random();
let secret_key: i32 = rand::random(); let secret_key: i32 = rand::random();
@@ -357,46 +416,55 @@ where
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
// Authenticate admin user.
let (transaction_mode, server_info) = if admin { let (transaction_mode, server_info) = if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();
// Compare server and client hashes. // Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt); let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response { if password_hash != password_response {
debug!("Password authentication failed"); debug!("Password authentication failed");
wrong_password(&mut write, target_user_name).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError); return Err(Error::ClientError);
} }
(false, generate_server_info_for_admin()) (false, generate_server_info_for_admin())
} else { }
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) { // Authenticate normal user.
else {
let pool = match get_pool(pool_name.clone(), username.clone()) {
Some(pool) => pool, Some(pool) => pool,
None => { None => {
error_response( error_response(
&mut write, &mut write,
&format!( &format!(
"No pool configured for database: {:?}, user: {:?}", "No pool configured for database: {:?}, user: {:?}",
target_pool_name, target_user_name pool_name, username
), ),
) )
.await?; .await?;
return Err(Error::ClientError); return Err(Error::ClientError);
} }
}; };
let transaction_mode = target_pool.settings.pool_mode == "transaction";
let server_info = target_pool.server_info();
// Compare server and client hashes. // Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str(); let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt);
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);
if password_hash != password_response { if password_hash != password_response {
debug!("Password authentication failed"); debug!("Password authentication failed");
wrong_password(&mut write, &target_user_name).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError); return Err(Error::ClientError);
} }
(transaction_mode, server_info)
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
(transaction_mode, pool.server_info())
}; };
debug!("Password authentication successful"); debug!("Password authentication successful");
@@ -408,27 +476,25 @@ where
trace!("Startup OK"); trace!("Startup OK");
// Split the read and write streams
// so we can control buffering.
return Ok(Client { return Ok(Client {
read: BufReader::new(read), read: BufReader::new(read),
write: write, write: write,
addr, addr,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
cancel_mode: false, cancel_mode: false,
transaction_mode: transaction_mode, transaction_mode,
process_id: process_id, process_id,
secret_key: secret_key, secret_key,
client_server_map: client_server_map, client_server_map,
parameters: parameters.clone(), parameters: parameters.clone(),
stats: stats, stats: stats,
admin: admin, admin: admin,
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool_name: target_pool_name.clone(), pool_name: pool_name.clone(),
target_user_name: target_user_name.clone(), username: username.clone(),
shutdown_event_receiver: shutdown_event_receiver, shutdown,
connected_to_server: false,
}); });
} }
@@ -439,7 +505,7 @@ where
addr: std::net::SocketAddr, addr: std::net::SocketAddr,
mut bytes: BytesMut, // The rest of the startup message. mut bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>, shutdown: Receiver<()>,
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32(); let process_id = bytes.get_i32();
let secret_key = bytes.get_i32(); let secret_key = bytes.get_i32();
@@ -450,17 +516,18 @@ where
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
cancel_mode: true, cancel_mode: true,
transaction_mode: false, transaction_mode: false,
process_id: process_id, process_id,
secret_key: secret_key, secret_key,
client_server_map: client_server_map, client_server_map,
parameters: HashMap::new(), parameters: HashMap::new(),
stats: get_reporter(), stats: get_reporter(),
admin: false, admin: false,
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool_name: String::from("undefined"), pool_name: String::from("undefined"),
target_user_name: String::from("undefined"), username: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver, shutdown,
connected_to_server: false,
}); });
} }
@@ -481,7 +548,7 @@ where
process_id.clone(), process_id.clone(),
secret_key.clone(), secret_key.clone(),
address.clone(), address.clone(),
port.clone(), *port,
), ),
// The client doesn't know / got the wrong server, // The client doesn't know / got the wrong server,
@@ -493,13 +560,12 @@ where
// Opens a new separate connection to the server, sends the backend_id // Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions // and secret_key and then closes it for security reasons. No other interactions
// take place. // take place.
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); return Ok(Server::cancel(&address, port, process_id, secret_key).await?);
} }
// The query router determines where the query is going to go, // The query router determines where the query is going to go,
// e.g. primary, replica, which shard. // e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(); let mut query_router = QueryRouter::new();
let mut round_robin = 0;
// Our custom protocol loop. // Our custom protocol loop.
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
@@ -517,9 +583,19 @@ where
// SET SHARDING KEY TO 'bigint'; // SET SHARDING KEY TO 'bigint';
let mut message = tokio::select! { let mut message = tokio::select! {
_ = self.shutdown_event_receiver.recv() => { _ = self.shutdown.recv() => {
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?; if !self.admin {
return Ok(()) error_response_terminal(
&mut self.write,
&format!("terminating connection due to administrator command")
).await?;
return Ok(())
}
// Admin clients ignore shutdown.
else {
read_message(&mut self.read).await?
}
}, },
message_result = read_message(&mut self.read) => message_result? message_result = read_message(&mut self.read) => message_result?
}; };
@@ -540,15 +616,14 @@ where
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config
// when starting a query. // when starting a query.
let pool = match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) let pool = match get_pool(self.pool_name.clone(), self.username.clone()) {
{
Some(pool) => pool, Some(pool) => pool,
None => { None => {
error_response( error_response(
&mut self.write, &mut self.write,
&format!( &format!(
"No pool configured for database: {:?}, user: {:?}", "No pool configured for database: {:?}, user: {:?}",
self.target_pool_name, self.target_user_name self.pool_name, self.username
), ),
) )
.await?; .await?;
@@ -569,8 +644,8 @@ where
// SET SHARD TO // SET SHARD TO
Some((Command::SetShard, _)) => { Some((Command::SetShard, _)) => {
// Selected shard is not configured. let shard = query_router.shard();
if query_router.shard() >= pool.shards() { if shard >= pool.shards() {
// Set the shard back to what it was. // Set the shard back to what it was.
query_router.set_shard(current_shard); query_router.set_shard(current_shard);
@@ -578,7 +653,7 @@ where
&mut self.write, &mut self.write,
&format!( &format!(
"shard {} is more than configured {}, staying on shard {}", "shard {} is more than configured {}, staying on shard {}",
query_router.shard(), shard,
pool.shards(), pool.shards(),
current_shard, current_shard,
), ),
@@ -631,12 +706,7 @@ where
// Grab a server from the pool. // Grab a server from the pool.
let connection = match pool let connection = match pool
.get( .get(query_router.shard(), query_router.role(), self.process_id)
query_router.shard(),
query_router.role(),
self.process_id,
round_robin,
)
.await .await
{ {
Ok(conn) => { Ok(conn) => {
@@ -644,9 +714,22 @@ where
conn conn
} }
Err(err) => { Err(err) => {
// Clients do not expect to get SystemError followed by ReadyForQuery in the middle
// of extended protocol submission. So we will hold off on sending the actual error
// message to the client until we get 'S' message
match message[0] as char {
'P' | 'B' | 'E' | 'D' => (),
_ => {
error_response(
&mut self.write,
"could not get connection from the pool",
)
.await?;
}
};
error!("Could not get connection from pool: {:?}", err); error!("Could not get connection from pool: {:?}", err);
error_response(&mut self.write, "could not get connection from the pool")
.await?;
continue; continue;
} }
}; };
@@ -655,11 +738,10 @@ where
let address = connection.1; let address = connection.1;
let server = &mut *reference; let server = &mut *reference;
round_robin += 1;
// Server is assigned to the client in case the client wants to // Server is assigned to the client in case the client wants to
// cancel a query later. // cancel a query later.
server.claim(self.process_id, self.secret_key); server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics. // Update statistics.
if let Some(last_address_id) = self.last_address_id { if let Some(last_address_id) = self.last_address_id {
@@ -667,7 +749,6 @@ where
.client_disconnecting(self.process_id, last_address_id); .client_disconnecting(self.process_id, last_address_id);
} }
self.stats.client_active(self.process_id, address.id); self.stats.client_active(self.process_id, address.id);
self.stats.server_active(server.process_id(), address.id);
self.last_address_id = Some(address.id); self.last_address_id = Some(address.id);
self.last_server_id = Some(server.process_id()); self.last_server_id = Some(server.process_id());
@@ -731,43 +812,8 @@ where
'Q' => { 'Q' => {
debug!("Sending query to server"); debug!("Sending query to server");
self.send_server_message( self.send_and_receive_loop(code, original, server, &address, &pool)
server, .await?;
original,
&address,
query_router.shard(),
&pool,
)
.await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self
.receive_server_message(
server,
&address,
query_router.shard(),
&pool,
)
.await?;
// Send server reply to the client.
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.is_data_available() {
break;
}
}
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
if !server.in_transaction() { if !server.in_transaction() {
// Report transaction executed statistics. // Report transaction executed statistics.
@@ -776,7 +822,6 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break; break;
} }
} }
@@ -830,52 +875,23 @@ where
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
self.send_server_message( self.send_and_receive_loop(
server, code,
self.buffer.clone(), self.buffer.clone(),
server,
&address, &address,
query_router.shard(),
&pool, &pool,
) )
.await?; .await?;
self.buffer.clear(); self.buffer.clear();
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self
.receive_server_message(
server,
&address,
query_router.shard(),
&pool,
)
.await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.is_data_available() {
break;
}
}
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
if !server.in_transaction() { if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id); self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break; break;
} }
} }
@@ -885,32 +901,18 @@ where
'd' => { 'd' => {
// Forward the data to the server, // Forward the data to the server,
// don't buffer it since it can be rather large. // don't buffer it since it can be rather large.
self.send_server_message( self.send_server_message(server, original, &address, &pool)
server, .await?;
original,
&address,
query_router.shard(),
&pool,
)
.await?;
} }
// CopyDone or CopyFail // CopyDone or CopyFail
// Copy is done, successfully or not. // Copy is done, successfully or not.
'c' | 'f' => { 'c' | 'f' => {
self.send_server_message( self.send_server_message(server, original, &address, &pool)
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
let response = self
.receive_server_message(server, &address, query_router.shard(), &pool)
.await?; .await?;
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
@@ -925,7 +927,6 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break; break;
} }
} }
@@ -941,6 +942,8 @@ where
// The server is no longer bound to us, we can't cancel its queries anymore. // The server is no longer bound to us, we can't cancel its queries anymore.
debug!("Releasing server back into the pool"); debug!("Releasing server back into the pool");
self.stats.server_idle(server.process_id(), address.id);
self.connected_to_server = false;
self.release(); self.release();
self.stats.client_idle(self.process_id, address.id); self.stats.client_idle(self.process_id, address.id);
} }
@@ -952,35 +955,107 @@ where
guard.remove(&(self.process_id, self.secret_key)); guard.remove(&(self.process_id, self.secret_key));
} }
/// Forward `message` to `server`, then relay the server's responses to
/// the client until the server reports that no more data is available.
///
/// `code` is only used in the debug log line. If writing a response to
/// the client fails, the server is marked bad and the write error is
/// returned. On success one executed query is reported to the stats.
async fn send_and_receive_loop(
    &mut self,
    code: char,
    message: BytesMut,
    server: &mut Server,
    address: &Address,
    pool: &ConnectionPool,
) -> Result<(), Error> {
    debug!("Sending {} to server", code);

    self.send_server_message(server, message, address, pool)
        .await?;

    // Read all data the server has to offer, which can be multiple messages
    // buffered in 8196 bytes chunks. At least one response is always read.
    loop {
        let response = self.receive_server_message(server, address, pool).await?;

        if let Err(err) = write_all_half(&mut self.write, response).await {
            server.mark_bad();
            return Err(err);
        }

        if !server.is_data_available() {
            break;
        }
    }

    // Report query executed statistics.
    self.stats.query(self.process_id, address.id);

    Ok(())
}
async fn send_server_message( async fn send_server_message(
&self, &self,
server: &mut Server, server: &mut Server,
message: BytesMut, message: BytesMut,
address: &Address, address: &Address,
shard: usize,
pool: &ConnectionPool, pool: &ConnectionPool,
) -> Result<(), Error> { ) -> Result<(), Error> {
match server.send(message).await { match server.send(message).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => {
pool.ban(address, shard, self.process_id); pool.ban(address, self.process_id);
Err(err) Err(err)
} }
} }
} }
async fn receive_server_message( async fn receive_server_message(
&self, &mut self,
server: &mut Server, server: &mut Server,
address: &Address, address: &Address,
shard: usize,
pool: &ConnectionPool, pool: &ConnectionPool,
) -> Result<BytesMut, Error> { ) -> Result<BytesMut, Error> {
match server.recv().await { if pool.settings.user.statement_timeout > 0 {
Ok(message) => Ok(message), match tokio::time::timeout(
Err(err) => { tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
pool.ban(address, shard, self.process_id); server.recv(),
Err(err) )
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
} else {
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
} }
} }
} }
@@ -991,15 +1066,16 @@ impl<S, T> Drop for Client<S, T> {
let mut guard = self.client_server_map.lock(); let mut guard = self.client_server_map.lock();
guard.remove(&(self.process_id, self.secret_key)); guard.remove(&(self.process_id, self.secret_key));
// Update statistics. // Dirty shutdown
// TODO: refactor, this is not the best way to handle state management.
if let Some(address_id) = self.last_address_id { if let Some(address_id) = self.last_address_id {
self.stats.client_disconnecting(self.process_id, address_id); self.stats.client_disconnecting(self.process_id, address_id);
if let Some(process_id) = self.last_server_id { if self.connected_to_server {
self.stats.server_idle(process_id, address_id); if let Some(process_id) = self.last_server_id {
self.stats.server_idle(process_id, address_id);
}
} }
} }
// self.release();
} }
} }

View File

@@ -57,13 +57,38 @@ impl PartialEq<Role> for Option<Role> {
/// Address identifying a PostgreSQL server uniquely. /// Address identifying a PostgreSQL server uniquely.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address { pub struct Address {
/// Unique ID per addressable Postgres server.
pub id: usize, pub id: usize,
/// Server host.
pub host: String, pub host: String,
pub port: String,
/// Server port.
pub port: u16,
/// Shard number of this Postgres server.
pub shard: usize, pub shard: usize,
/// The name of the Postgres database.
pub database: String, pub database: String,
/// Default search_path.
pub search_path: Option<String>,
/// Server role: replica, primary.
pub role: Role, pub role: Role,
/// If it's a replica, number it for reference and failover.
pub replica_number: usize, pub replica_number: usize,
/// Position of the server in the pool for failover.
pub address_index: usize,
/// The name of the user configured to use this pool.
pub username: String,
/// The name of this pool (i.e. database name visible to the client).
pub pool_name: String,
} }
impl Default for Address { impl Default for Address {
@@ -71,11 +96,15 @@ impl Default for Address {
Address { Address {
id: 0, id: 0,
host: String::from("127.0.0.1"), host: String::from("127.0.0.1"),
port: String::from("5432"), port: 5432,
shard: 0, shard: 0,
address_index: 0,
replica_number: 0, replica_number: 0,
database: String::from("database"), database: String::from("database"),
search_path: None,
role: Role::Replica, role: Role::Replica,
username: String::from("username"),
pool_name: String::from("pool_name"),
} }
} }
} }
@@ -84,11 +113,11 @@ impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`. /// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String { pub fn name(&self) -> String {
match self.role { match self.role {
Role::Primary => format!("{}_shard_{}_primary", self.database, self.shard), Role::Primary => format!("{}_shard_{}_primary", self.pool_name, self.shard),
Role::Replica => format!( Role::Replica => format!(
"{}_shard_{}_replica_{}", "{}_shard_{}_replica_{}",
self.database, self.shard, self.replica_number self.pool_name, self.shard, self.replica_number
), ),
} }
} }
@@ -100,6 +129,7 @@ pub struct User {
pub username: String, pub username: String,
pub password: String, pub password: String,
pub pool_size: u32, pub pool_size: u32,
pub statement_timeout: u64,
} }
impl Default for User { impl Default for User {
@@ -108,6 +138,7 @@ impl Default for User {
username: String::from("postgres"), username: String::from("postgres"),
password: String::new(), password: String::new(),
pool_size: 15, pool_size: 15,
statement_timeout: 0,
} }
} }
} }
@@ -118,6 +149,7 @@ pub struct General {
pub host: String, pub host: String,
pub port: i16, pub port: i16,
pub enable_prometheus_exporter: Option<bool>, pub enable_prometheus_exporter: Option<bool>,
pub prometheus_exporter_port: i16,
pub connect_timeout: u64, pub connect_timeout: u64,
pub healthcheck_timeout: u64, pub healthcheck_timeout: u64,
pub shutdown_timeout: u64, pub shutdown_timeout: u64,
@@ -136,6 +168,7 @@ impl Default for General {
host: String::from("localhost"), host: String::from("localhost"),
port: 5432, port: 5432,
enable_prometheus_exporter: Some(false), enable_prometheus_exporter: Some(false),
prometheus_exporter_port: 9930,
connect_timeout: 5000, connect_timeout: 5000,
healthcheck_timeout: 1000, healthcheck_timeout: 1000,
shutdown_timeout: 60000, shutdown_timeout: 60000,
@@ -177,6 +210,7 @@ impl Default for Pool {
/// Configuration of a single shard.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Shard {
    /// Database name for this shard (defaults to `postgres`).
    pub database: String,
    /// Optional search_path for this shard; `None` by default.
    pub search_path: Option<String>,
    /// Servers of this shard. NOTE(review): the default entry is
    /// `("localhost", 5432, "primary")`, so the tuple is presumably
    /// (host, port, role) -- confirm against the config parser.
    pub servers: Vec<(String, u16, String)>,
}
@@ -184,6 +218,7 @@ impl Default for Shard {
fn default() -> Shard { fn default() -> Shard {
Shard { Shard {
servers: vec![(String::from("localhost"), 5432, String::from("primary"))], servers: vec![(String::from("localhost"), 5432, String::from("primary"))],
search_path: None,
database: String::from("postgres"), database: String::from("postgres"),
} }
} }
@@ -271,6 +306,10 @@ impl From<&Config> for std::collections::HashMap<String, String> {
let mut static_settings = vec![ let mut static_settings = vec![
("host".to_string(), config.general.host.to_string()), ("host".to_string(), config.general.host.to_string()),
("port".to_string(), config.general.port.to_string()), ("port".to_string(), config.general.port.to_string()),
(
"prometheus_exporter_port".to_string(),
config.general.prometheus_exporter_port.to_string(),
),
( (
"connect_timeout".to_string(), "connect_timeout".to_string(),
config.general.connect_timeout.to_string(), config.general.connect_timeout.to_string(),
@@ -326,9 +365,10 @@ impl Config {
}; };
for (pool_name, pool_config) in &self.pools { for (pool_name, pool_config) in &self.pools {
info!("--- Settings for pool {} ---", pool_name); // TODO: Make this output prettier (maybe a table?)
info!( info!(
"Pool size from all users: {}", "[pool: {}] Maximum user connections: {}",
pool_name,
pool_config pool_config
.users .users
.iter() .iter()
@@ -336,12 +376,40 @@ impl Config {
.sum::<u32>() .sum::<u32>()
.to_string() .to_string()
); );
info!("Pool mode: {}", pool_config.pool_mode); info!("[pool: {}] Pool mode: {}", pool_name, pool_config.pool_mode);
info!("Sharding function: {}", pool_config.sharding_function); info!(
info!("Primary reads: {}", pool_config.primary_reads_enabled); "[pool: {}] Sharding function: {}",
info!("Query router: {}", pool_config.query_parser_enabled); pool_name, pool_config.sharding_function
info!("Number of shards: {}", pool_config.shards.len()); );
info!("Number of users: {}", pool_config.users.len()); info!(
"[pool: {}] Primary reads: {}",
pool_name, pool_config.primary_reads_enabled
);
info!(
"[pool: {}] Query router: {}",
pool_name, pool_config.query_parser_enabled
);
info!(
"[pool: {}] Number of shards: {}",
pool_name,
pool_config.shards.len()
);
info!(
"[pool: {}] Number of users: {}",
pool_name,
pool_config.users.len()
);
for user in &pool_config.users {
info!(
"[pool: {}][user: {}] Pool size: {}",
pool_name, user.1.username, user.1.pool_size,
);
info!(
"[pool: {}][user: {}] Statement timeout: {}",
pool_name, user.1.username, user.1.statement_timeout
)
}
} }
} }
} }
@@ -438,6 +506,18 @@ pub async fn parse(path: &str) -> Result<(), Error> {
} }
}; };
match pool.pool_mode.as_ref() {
"transaction" => (),
"session" => (),
other => {
error!(
"pool_mode can be 'session' or 'transaction', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
for shard in &pool.shards { for shard in &pool.shards {
// We use addresses as unique identifiers, // We use addresses as unique identifiers,
// let's make sure they are unique in the config as well. // let's make sure they are unique in the config as well.

View File

@@ -11,4 +11,6 @@ pub enum Error {
AllServersDown, AllServersDown,
ClientError, ClientError,
TlsError, TlsError,
StatementTimeout,
ShuttingDown,
} }

View File

@@ -24,6 +24,7 @@ extern crate async_trait;
extern crate bb8; extern crate bb8;
extern crate bytes; extern crate bytes;
extern crate env_logger; extern crate env_logger;
extern crate exitcode;
extern crate log; extern crate log;
extern crate md5; extern crate md5;
extern crate num_cpus; extern crate num_cpus;
@@ -66,6 +67,7 @@ mod stats;
mod tls; mod tls;
use crate::config::{get_config, reload_config, VERSION}; use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server; use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER}; use crate::stats::{Collector, Reporter, REPORTER};
@@ -77,7 +79,7 @@ async fn main() {
if !query_router::QueryRouter::setup() { if !query_router::QueryRouter::setup() {
error!("Could not setup query router"); error!("Could not setup query router");
return; std::process::exit(exitcode::CONFIG);
} }
let args = std::env::args().collect::<Vec<String>>(); let args = std::env::args().collect::<Vec<String>>();
@@ -92,19 +94,22 @@ async fn main() {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
error!("Config parse error: {:?}", err); error!("Config parse error: {:?}", err);
return; std::process::exit(exitcode::CONFIG);
} }
}; };
let config = get_config(); let config = get_config();
if let Some(true) = config.general.enable_prometheus_exporter { if let Some(true) = config.general.enable_prometheus_exporter {
let http_addr_str = format!("{}:{}", config.general.host, crate::prometheus::HTTP_PORT); let http_addr_str = format!(
"{}:{}",
config.general.host, config.general.prometheus_exporter_port
);
let http_addr = match SocketAddr::from_str(&http_addr_str) { let http_addr = match SocketAddr::from_str(&http_addr_str) {
Ok(addr) => addr, Ok(addr) => addr,
Err(err) => { Err(err) => {
error!("Invalid http address: {}", err); error!("Invalid http address: {}", err);
return; std::process::exit(exitcode::CONFIG);
} }
}; };
tokio::task::spawn(async move { tokio::task::spawn(async move {
@@ -118,7 +123,7 @@ async fn main() {
Ok(sock) => sock, Ok(sock) => sock,
Err(err) => { Err(err) => {
error!("Listener socket error: {:?}", err); error!("Listener socket error: {:?}", err);
return; std::process::exit(exitcode::CONFIG);
} }
}; };
@@ -130,171 +135,160 @@ async fn main() {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Statistics reporting. // Statistics reporting.
let (tx, rx) = mpsc::channel(100); let (stats_tx, stats_rx) = mpsc::channel(100_000);
REPORTER.store(Arc::new(Reporter::new(tx.clone()))); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
// Connection pool that allows to query all shards and replicas. // Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await { match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
error!("Pool error: {:?}", err); error!("Pool error: {:?}", err);
return; std::process::exit(exitcode::CONFIG);
} }
}; };
// Statistics collector task.
let collector_tx = tx.clone();
// Save these for reloading
let reload_client_server_map = client_server_map.clone();
let autoreload_client_server_map = client_server_map.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
let mut stats_collector = Collector::new(rx, collector_tx); let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
stats_collector.collect().await; stats_collector.collect().await;
}); });
info!("Waiting for clients"); info!("Config autoreloader: {}", config.general.autoreload);
let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1);
let shutdown_event_tx_clone = shutdown_event_tx.clone();
// Client connection loop.
tokio::task::spawn(async move {
// Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast
let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe();
loop {
let client_server_map = client_server_map.clone();
// Listen for shutdown event and client connection at the same time
let (socket, addr) = tokio::select! {
_ = listener_shutdown_event_rx.recv() => {
// Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone
break;
}
listener_response = listener.accept() => {
match listener_response {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
}
}
}
};
// Used to signal shutdown
let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe();
// Used to signal that the task has completed
let dummy_tx = shutdown_event_tx_clone.clone();
// Handle client.
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
socket,
client_server_map,
client_shutdown_handler_rx,
)
.await
{
Ok(_) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
info!(
"Client {:?} disconnected, session duration: {}",
addr,
format_duration(&duration)
);
}
Err(err) => {
debug!("Client disconnected with error {:?}", err);
}
};
// Drop this transmitter so receiver knows that the task is completed
drop(dummy_tx);
});
}
});
// Reload config:
// kill -SIGHUP $(pgrep pgcat)
tokio::task::spawn(async move {
let mut stream = unix_signal(SignalKind::hangup()).unwrap();
loop {
stream.recv().await;
info!("Reloading config");
match reload_config(reload_client_server_map.clone()).await {
Ok(_) => (),
Err(_) => continue,
};
get_config().show();
}
});
if config.general.autoreload {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
tokio::task::spawn(async move {
info!("Config autoreloader started");
loop {
interval.tick().await;
match reload_config(autoreload_client_server_map.clone()).await {
Ok(changed) => {
if changed {
get_config().show()
}
}
Err(_) => (),
};
}
});
}
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap();
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
let (shutdown_tx, _) = broadcast::channel::<()>(1);
let (drain_tx, mut drain_rx) = mpsc::channel::<i8>(2048);
let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1);
tokio::select! { info!("Waiting for clients");
// Initiate graceful shutdown sequence on sig int
_ = interrupt_signal.recv() => {
info!("Got SIGINT, waiting for client connection drain now");
// Broadcast that client tasks need to finish let mut admin_only = false;
shutdown_event_tx.send(()).unwrap(); let mut total_clients = 0;
// Closes transmitter
drop(shutdown_event_tx);
// This is in a loop because the first event that the receiver receives will be the shutdown event loop {
// This is not what we are waiting for instead, we want the receiver to send an error once all senders are closed which is reached after the shutdown event is received tokio::select! {
loop { // Reload config:
match tokio::time::timeout( // kill -SIGHUP $(pgrep pgcat)
tokio::time::Duration::from_millis(config.general.shutdown_timeout), _ = sighup_signal.recv() => {
shutdown_event_rx.recv(), info!("Reloading config");
)
.await match reload_config(client_server_map.clone()).await {
{ Ok(_) => (),
Ok(res) => match res { Err(_) => (),
Ok(_) => {} };
Err(_) => break,
}, get_config().show();
Err(_) => { },
info!("Timed out while waiting for clients to shutdown");
break; _ = autoreload_interval.tick() => {
if config.general.autoreload {
info!("Automatically reloading config");
match reload_config(client_server_map.clone()).await {
Ok(changed) => {
if changed {
get_config().show()
}
}
Err(_) => (),
};
}
},
// Initiate graceful shutdown sequence on sig int
_ = interrupt_signal.recv() => {
info!("Got SIGINT, waiting for client connection drain now");
admin_only = true;
// Broadcast that client tasks need to finish
let _ = shutdown_tx.send(());
let exit_tx = exit_tx.clone();
let _ = drain_tx.send(0).await;
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout));
// First tick fires immediately.
interval.tick().await;
// Second one in the interval time.
interval.tick().await;
// We're done waiting.
error!("Timed out waiting for clients");
let _ = exit_tx.send(()).await;
});
},
_ = term_signal.recv() => break,
new_client = listener.accept() => {
let (socket, addr) = match new_client {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
} }
};
let shutdown_rx = shutdown_tx.subscribe();
let drain_tx = drain_tx.clone();
let client_server_map = client_server_map.clone();
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
socket,
client_server_map,
shutdown_rx,
drain_tx,
admin_only,
)
.await
{
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
info!(
"Client {:?} disconnected, session duration: {}",
addr,
format_duration(&duration)
);
}
Err(err) => {
match err {
// Don't count the clients we rejected.
Error::ShuttingDown => (),
_ => {
// drain_tx.send(-1).await.unwrap();
}
}
debug!("Client disconnected with error {:?}", err);
}
};
});
}
_ = exit_rx.recv() => {
break;
}
client_ping = drain_rx.recv() => {
let client_ping = client_ping.unwrap();
total_clients += client_ping;
if total_clients == 0 && admin_only {
let _ = exit_tx.send(()).await;
} }
} }
}, }
_ = term_signal.recv() => (),
} }
info!("Shutting down..."); info!("Shutting down...");

View File

@@ -111,7 +111,12 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want. /// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { pub async fn startup(
stream: &mut TcpStream,
user: &str,
database: &str,
search_path: Option<&String>,
) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25); let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number bytes.put_i32(196608); // Protocol number
@@ -125,6 +130,17 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
bytes.put(&b"database\0"[..]); bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes()); bytes.put_slice(&database.as_bytes());
bytes.put_u8(0); bytes.put_u8(0);
// search_path
match search_path {
Some(search_path) => {
bytes.put(&b"options\0"[..]);
bytes.put_slice(&format!("-c search_path={}", search_path).as_bytes());
bytes.put_u8(0);
}
None => (),
};
bytes.put_u8(0); // Null terminator bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32; let len = bytes.len() as i32 + 4i32;

View File

@@ -6,44 +6,80 @@ use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use crate::config::{get_config, Address, Role, Shard, User}; use crate::config::{get_config, Address, Role, User};
use crate::errors::Error; use crate::errors::Error;
use crate::server::Server; use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{get_reporter, Reporter}; use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>; pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, u16)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>; pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available. /// The connection pool, globally available.
/// This is atomic and safe and read-optimized. /// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded. /// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
/// Pool mode:
/// - transaction: server serves one transaction,
/// - session: server is attached to the client.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PoolMode {
Session,
Transaction,
}
impl std::fmt::Display for PoolMode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
PoolMode::Session => write!(f, "session"),
PoolMode::Transaction => write!(f, "transaction"),
}
}
}
/// Pool settings.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PoolSettings { pub struct PoolSettings {
pub pool_mode: String, /// Transaction or Session.
pub shards: HashMap<String, Shard>, pub pool_mode: PoolMode,
// Number of shards.
pub shards: usize,
// Connecting user.
pub user: User, pub user: User,
pub default_role: String,
// Default server role to connect to.
pub default_role: Option<Role>,
// Enable/disable query parser.
pub query_parser_enabled: bool, pub query_parser_enabled: bool,
// Read from the primary as well or not.
pub primary_reads_enabled: bool, pub primary_reads_enabled: bool,
pub sharding_function: String,
// Sharding function.
pub sharding_function: ShardingFunction,
} }
impl Default for PoolSettings { impl Default for PoolSettings {
fn default() -> PoolSettings { fn default() -> PoolSettings {
PoolSettings { PoolSettings {
pool_mode: String::from("transaction"), pool_mode: PoolMode::Transaction,
shards: HashMap::from([(String::from("1"), Shard::default())]), shards: 1,
user: User::default(), user: User::default(),
default_role: String::from("any"), default_role: None,
query_parser_enabled: false, query_parser_enabled: false,
primary_reads_enabled: true, primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(), sharding_function: ShardingFunction::PgBigintHash,
} }
} }
} }
@@ -71,6 +107,7 @@ pub struct ConnectionPool {
/// on pool creation and save the K messages here. /// on pool creation and save the K messages here.
server_info: BytesMut, server_info: BytesMut,
/// Pool configuration.
pub settings: PoolSettings, pub settings: PoolSettings,
} }
@@ -78,11 +115,13 @@ impl ConnectionPool {
/// Construct the connection pool from the configuration. /// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> { pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let config = get_config(); let config = get_config();
let mut new_pools = PoolMap::default();
let mut new_pools = HashMap::new();
let mut address_id = 0; let mut address_id = 0;
for (pool_name, pool_config) in &config.pools { for (pool_name, pool_config) in &config.pools {
for (_user_index, user_info) in &pool_config.users { // There is one pool per database/user pair.
for (_, user) in &pool_config.users {
let mut shards = Vec::new(); let mut shards = Vec::new();
let mut addresses = Vec::new(); let mut addresses = Vec::new();
let mut banlist = Vec::new(); let mut banlist = Vec::new();
@@ -96,10 +135,11 @@ impl ConnectionPool {
// Sort by shard number to ensure consistency. // Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()); shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids { for shard_idx in &shard_ids {
let shard = &pool_config.shards[&shard_idx]; let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new(); let mut pools = Vec::new();
let mut servers = Vec::new(); let mut servers = Vec::new();
let mut address_index = 0;
let mut replica_number = 0; let mut replica_number = 0;
for server in shard.servers.iter() { for server in shard.servers.iter() {
@@ -114,15 +154,20 @@ impl ConnectionPool {
let address = Address { let address = Address {
id: address_id, id: address_id,
database: pool_name.clone(), database: shard.database.clone(),
search_path: shard.search_path.clone(),
host: server.0.clone(), host: server.0.clone(),
port: server.1.to_string(), port: server.1 as u16,
role: role, role: role,
address_index,
replica_number, replica_number,
shard: shard_idx.parse::<usize>().unwrap(), shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
}; };
address_id += 1; address_id += 1;
address_index += 1;
if role == Role::Replica { if role == Role::Replica {
replica_number += 1; replica_number += 1;
@@ -130,14 +175,14 @@ impl ConnectionPool {
let manager = ServerPool::new( let manager = ServerPool::new(
address.clone(), address.clone(),
user_info.clone(), user.clone(),
&shard.database, &shard.database,
client_server_map.clone(), client_server_map.clone(),
get_reporter(), get_reporter(),
); );
let pool = Pool::builder() let pool = Pool::builder()
.max_size(user_info.pool_size) .max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis( .connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout, config.general.connect_timeout,
)) ))
@@ -164,13 +209,27 @@ impl ConnectionPool {
stats: get_reporter(), stats: get_reporter(),
server_info: BytesMut::new(), server_info: BytesMut::new(),
settings: PoolSettings { settings: PoolSettings {
pool_mode: pool_config.pool_mode.clone(), pool_mode: match pool_config.pool_mode.as_str() {
shards: pool_config.shards.clone(), "transaction" => PoolMode::Transaction,
user: user_info.clone(), "session" => PoolMode::Session,
default_role: pool_config.default_role.clone(), _ => unreachable!(),
},
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled.clone(), query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled, primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function.clone(), sharding_function: match pool_config.sharding_function.as_str() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
},
}, },
}; };
@@ -183,7 +242,9 @@ impl ConnectionPool {
return Err(err); return Err(err);
} }
}; };
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
// There is one pool per database/user pair.
new_pools.insert((pool_name.clone(), user.username.clone()), pool);
} }
} }
@@ -199,16 +260,9 @@ impl ConnectionPool {
/// the pooler starts up. /// the pooler starts up.
async fn validate(&mut self) -> Result<(), Error> { async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new(); let mut server_infos = Vec::new();
let stats = self.stats.clone();
for shard in 0..self.shards() { for shard in 0..self.shards() {
let mut round_robin = 0; for server in 0..self.servers(shard) {
let connection = match self.databases[shard][server].get().await {
for _ in 0..self.servers(shard) {
// To keep stats consistent.
let fake_process_id = 0;
let connection = match self.get(shard, None, fake_process_id, round_robin).await {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err); error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -216,25 +270,21 @@ impl ConnectionPool {
} }
}; };
let proxy = connection.0; let proxy = connection;
let address = connection.1;
let server = &*proxy; let server = &*proxy;
let server_info = server.server_info(); let server_info = server.server_info();
stats.client_disconnecting(fake_process_id, address.id);
if server_infos.len() > 0 { if server_infos.len() > 0 {
// Compare against the last server checked. // Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] { if server_info != server_infos[server_infos.len() - 1] {
warn!( warn!(
"{:?} has different server configuration than the last server", "{:?} has different server configuration than the last server",
address proxy.address()
); );
} }
} }
server_infos.push(server_info); server_infos.push(server_info);
round_robin += 1;
} }
} }
@@ -244,6 +294,8 @@ impl ConnectionPool {
return Err(Error::AllServersDown); return Err(Error::AllServersDown);
} }
// We're assuming all servers are identical.
// TODO: not true.
self.server_info = server_infos[0].clone(); self.server_info = server_infos[0].clone();
Ok(()) Ok(())
@@ -252,58 +304,31 @@ impl ConnectionPool {
/// Get a connection from the pool. /// Get a connection from the pool.
pub async fn get( pub async fn get(
&self, &self,
shard: usize, // shard number shard: usize, // shard number
role: Option<Role>, // primary or replica role: Option<Role>, // primary or replica
process_id: i32, // client id process_id: i32, // client id
mut round_robin: usize, // round robin offset
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now(); let now = Instant::now();
let addresses = &self.addresses[shard]; let mut candidates: Vec<&Address> = self.addresses[shard]
.iter()
.filter(|address| address.role == role)
.collect();
let mut allowed_attempts = match role { // Random load balancing
// Primary-specific queries get one attempt, if the primary is down, candidates.shuffle(&mut thread_rng());
// nothing we should do about it I think. It's dangerous to retry
// write queries.
Some(Role::Primary) => 1,
// Replicas get to try as many times as there are replicas
// and connections in the pool.
_ => addresses.len(),
};
debug!("Allowed attempts for {:?}: {}", role, allowed_attempts);
let exists = match role {
Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
None => true,
};
if !exists {
error!("Requested role {:?}, but none are configured", role);
return Err(Error::BadConfig);
}
let healthcheck_timeout = get_config().general.healthcheck_timeout; let healthcheck_timeout = get_config().general.healthcheck_timeout;
let healthcheck_delay = get_config().general.healthcheck_delay as u128; let healthcheck_delay = get_config().general.healthcheck_delay as u128;
while allowed_attempts > 0 { while !candidates.is_empty() {
// Round-robin replicas. // Get the next candidate
round_robin += 1; let address = match candidates.pop() {
Some(address) => address,
None => break,
};
let index = round_robin % addresses.len(); if self.is_banned(&address, role) {
let address = &addresses[index]; debug!("Address {:?} is banned", address);
// Make sure you're getting a primary or a replica
// as per request. If no specific role is requested, the first
// available will be chosen.
if address.role != role {
continue;
}
allowed_attempts -= 1;
// Don't attempt to connect to banned servers.
if self.is_banned(address, shard, role) {
continue; continue;
} }
@@ -311,12 +336,14 @@ impl ConnectionPool {
self.stats.client_waiting(process_id, address.id); self.stats.client_waiting(process_id, address.id);
// Check if we can connect // Check if we can connect
let mut conn = match self.databases[shard][index].get().await { let mut conn = match self.databases[address.shard][address.address_index]
.get()
.await
{
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
error!("Banning replica {}, error: {:?}", index, err); error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, shard, process_id); self.ban(&address, process_id);
self.stats.client_disconnecting(process_id, address.id);
self.stats self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id); .checkout_time(now.elapsed().as_micros(), process_id, address.id);
continue; continue;
@@ -330,19 +357,23 @@ impl ConnectionPool {
let require_healthcheck = let require_healthcheck =
server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay; server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay;
// Do not issue a health check unless it's been a little while
// since we last checked the server is ok.
// Health checks are pretty expensive.
if !require_healthcheck { if !require_healthcheck {
self.stats self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id); .checkout_time(now.elapsed().as_micros(), process_id, address.id);
self.stats.server_idle(conn.process_id(), address.id); self.stats.server_active(conn.process_id(), address.id);
return Ok((conn, address.clone())); return Ok((conn, address.clone()));
} }
debug!("Running health check for replica {}, {:?}", index, address); debug!("Running health check on server {:?}", address);
self.stats.server_tested(server.process_id(), address.id); self.stats.server_tested(server.process_id(), address.id);
match tokio::time::timeout( match tokio::time::timeout(
tokio::time::Duration::from_millis(healthcheck_timeout), tokio::time::Duration::from_millis(healthcheck_timeout),
server.query(";"), server.query(";"), // Cheap query (query parser not used in PG)
) )
.await .await
{ {
@@ -351,67 +382,72 @@ impl ConnectionPool {
Ok(_) => { Ok(_) => {
self.stats self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id); .checkout_time(now.elapsed().as_micros(), process_id, address.id);
self.stats.server_idle(conn.process_id(), address.id); self.stats.server_active(conn.process_id(), address.id);
return Ok((conn, address.clone())); return Ok((conn, address.clone()));
} }
// Health check failed. // Health check failed.
Err(_) => { Err(err) => {
error!("Banning replica {} because of failed health check", index); error!(
"Banning instance {:?} because of failed health check, {:?}",
address, err
);
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(address, shard, process_id); self.ban(&address, process_id);
continue; continue;
} }
}, },
// Health check timed out. // Health check timed out.
Err(_) => { Err(err) => {
error!("Banning replica {} because of health check timeout", index); error!(
"Banning instance {:?} because of health check timeout, {:?}",
address, err
);
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(address, shard, process_id); self.ban(&address, process_id);
continue; continue;
} }
} }
} }
return Err(Error::AllServersDown); Err(Error::AllServersDown)
} }
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica /// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients. /// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, shard: usize, process_id: i32) { pub fn ban(&self, address: &Address, process_id: i32) {
self.stats.client_disconnecting(process_id, address.id); self.stats.client_disconnecting(process_id, address.id);
self.stats
.checkout_time(Instant::now().elapsed().as_micros(), process_id, address.id);
error!("Banning {:?}", address); error!("Banning {:?}", address);
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write(); let mut guard = self.banlist.write();
guard[shard].insert(address.clone(), now); guard[address.shard].insert(address.clone(), now);
} }
/// Clear the replica to receive traffic again. Takes effect immediately /// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions. /// for all new transactions.
pub fn _unban(&self, address: &Address, shard: usize) { pub fn _unban(&self, address: &Address) {
let mut guard = self.banlist.write(); let mut guard = self.banlist.write();
guard[shard].remove(address); guard[address.shard].remove(address);
} }
/// Check if a replica can serve traffic. If all replicas are banned, /// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try then not to. /// we unban all of them. Better to try then not to.
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool { pub fn is_banned(&self, address: &Address, role: Option<Role>) -> bool {
let replicas_available = match role { let replicas_available = match role {
Some(Role::Replica) => self.addresses[shard] Some(Role::Replica) => self.addresses[address.shard]
.iter() .iter()
.filter(|addr| addr.role == Role::Replica) .filter(|addr| addr.role == Role::Replica)
.count(), .count(),
None => self.addresses[shard].len(), None => self.addresses[address.shard].len(),
Some(Role::Primary) => return false, // Primary cannot be banned. Some(Role::Primary) => return false, // Primary cannot be banned.
}; };
@@ -420,17 +456,17 @@ impl ConnectionPool {
let guard = self.banlist.read(); let guard = self.banlist.read();
// Everything is banned = nothing is banned. // Everything is banned = nothing is banned.
if guard[shard].len() == replicas_available { if guard[address.shard].len() == replicas_available {
drop(guard); drop(guard);
let mut guard = self.banlist.write(); let mut guard = self.banlist.write();
guard[shard].clear(); guard[address.shard].clear();
drop(guard); drop(guard);
warn!("Unbanning all replicas."); warn!("Unbanning all replicas.");
return false; return false;
} }
// I expect this to miss 99.9999% of the time. // I expect this to miss 99.9999% of the time.
match guard[shard].get(address) { match guard[address.shard].get(address) {
Some(timestamp) => { Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let config = get_config(); let config = get_config();
@@ -440,7 +476,7 @@ impl ConnectionPool {
drop(guard); drop(guard);
warn!("Unbanning {:?}", address); warn!("Unbanning {:?}", address);
let mut guard = self.banlist.write(); let mut guard = self.banlist.write();
guard[shard].remove(address); guard[address.shard].remove(address);
false false
} else { } else {
debug!("{:?} is banned", address); debug!("{:?} is banned", address);
@@ -577,6 +613,7 @@ pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
} }
} }
/// How many total servers we have in the config.
pub fn get_number_of_addresses() -> usize { pub fn get_number_of_addresses() -> usize {
get_all_pools() get_all_pools()
.iter() .iter()
@@ -584,6 +621,7 @@ pub fn get_number_of_addresses() -> usize {
.sum() .sum()
} }
/// Get a pointer to all configured pools.
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> { pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone(); return (*(*POOLS.load())).clone();
} }

View File

@@ -10,8 +10,6 @@ use crate::config::Address;
use crate::pool::get_all_pools; use crate::pool::get_all_pools;
use crate::stats::get_stats; use crate::stats::get_stats;
pub const HTTP_PORT: usize = 9930;
struct MetricHelpType { struct MetricHelpType {
help: &'static str, help: &'static str,
ty: &'static str, ty: &'static str,

View File

@@ -10,7 +10,7 @@ use sqlparser::parser::Parser;
use crate::config::Role; use crate::config::Role;
use crate::pool::PoolSettings; use crate::pool::PoolSettings;
use crate::sharding::{Sharder, ShardingFunction}; use crate::sharding::Sharder;
/// Regexes used to parse custom commands. /// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [ const CUSTOM_SQL_REGEXES: [&str; 7] = [
@@ -55,11 +55,13 @@ pub struct QueryRouter {
/// Include the primary into the replica pool for reads. /// Include the primary into the replica pool for reads.
primary_reads_enabled: bool, primary_reads_enabled: bool,
/// Pool configuration.
pool_settings: PoolSettings, pool_settings: PoolSettings,
} }
impl QueryRouter { impl QueryRouter {
/// One-time initialization of regexes. /// One-time initialization of regexes
/// that parse our custom SQL protocol.
pub fn setup() -> bool { pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx, Ok(rgx) => rgx,
@@ -74,10 +76,7 @@ impl QueryRouter {
.map(|rgx| Regex::new(rgx).unwrap()) .map(|rgx| Regex::new(rgx).unwrap())
.collect(); .collect();
// Impossible assert_eq!(list.len(), set.len());
if list.len() != set.len() {
return false;
}
match CUSTOM_SQL_REGEX_LIST.set(list) { match CUSTOM_SQL_REGEX_LIST.set(list) {
Ok(_) => true, Ok(_) => true,
@@ -90,7 +89,8 @@ impl QueryRouter {
} }
} }
/// Create a new instance of the query router. Each client gets its own. /// Create a new instance of the query router.
/// Each client gets its own.
pub fn new() -> QueryRouter { pub fn new() -> QueryRouter {
QueryRouter { QueryRouter {
active_shard: None, active_shard: None,
@@ -101,6 +101,7 @@ impl QueryRouter {
} }
} }
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) { pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings; self.pool_settings = pool_settings;
} }
@@ -136,19 +137,6 @@ impl QueryRouter {
return None; return None;
} }
let sharding_function = match self.pool_settings.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
let default_server_role = match self.pool_settings.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let command = match matches[0] { let command = match matches[0] {
0 => Command::SetShardingKey, 0 => Command::SetShardingKey,
1 => Command::SetShard, 1 => Command::SetShard,
@@ -200,7 +188,10 @@ impl QueryRouter {
match command { match command {
Command::SetShardingKey => { Command::SetShardingKey => {
let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function); let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
let shard = sharder.shard(value.parse::<i64>().unwrap()); let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard); self.active_shard = Some(shard);
value = shard.to_string(); value = shard.to_string();
@@ -208,7 +199,7 @@ impl QueryRouter {
Command::SetShard => { Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() { self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards.len()), "ANY" => Some(rand::random::<usize>() % self.pool_settings.shards),
_ => Some(value.parse::<usize>().unwrap()), _ => Some(value.parse::<usize>().unwrap()),
}; };
} }
@@ -236,7 +227,7 @@ impl QueryRouter {
} }
"default" => { "default" => {
self.active_role = default_server_role; self.active_role = self.pool_settings.default_role;
self.query_parser_enabled = self.query_parser_enabled; self.query_parser_enabled = self.query_parser_enabled;
self.active_role self.active_role
} }
@@ -367,10 +358,10 @@ impl QueryRouter {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::collections::HashMap;
use super::*; use super::*;
use crate::messages::simple_query; use crate::messages::simple_query;
use crate::pool::PoolMode;
use crate::sharding::ShardingFunction;
use bytes::BufMut; use bytes::BufMut;
#[test] #[test]
@@ -633,13 +624,13 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let pool_settings = PoolSettings { let pool_settings = PoolSettings {
pool_mode: "transaction".to_string(), pool_mode: PoolMode::Transaction,
shards: HashMap::default(), shards: 0,
user: crate::config::User::default(), user: crate::config::User::default(),
default_role: Role::Replica.to_string(), default_role: Some(Role::Replica),
query_parser_enabled: true, query_parser_enabled: true,
primary_reads_enabled: false, primary_reads_enabled: false,
sharding_function: "pg_bigint_hash".to_string(), sharding_function: ShardingFunction::PgBigintHash,
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None); assert_eq!(qr.active_role, None);
@@ -661,9 +652,6 @@ mod test {
let q2 = simple_query("SET SERVER ROLE TO 'default'"); let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None); assert!(qr.try_execute_command(q2) != None);
assert_eq!( assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role);
qr.active_role.unwrap().to_string(),
pool_settings.clone().default_role
);
} }
} }

View File

@@ -75,7 +75,7 @@ impl Server {
stats: Reporter, stats: Reporter,
) -> Result<Server, Error> { ) -> Result<Server, Error> {
let mut stream = let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, &address.port)).await { match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => { Err(err) => {
error!("Could not connect to server: {}", err); error!("Could not connect to server: {}", err);
@@ -86,7 +86,13 @@ impl Server {
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
// StartupMessage // StartupMessage
startup(&mut stream, &user.username, database).await?; startup(
&mut stream,
&user.username,
database,
address.search_path.as_ref(),
)
.await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
@@ -342,7 +348,7 @@ impl Server {
/// Uses a separate connection that's not part of the connection pool. /// Uses a separate connection that's not part of the connection pool.
pub async fn cancel( pub async fn cancel(
host: &str, host: &str,
port: &str, port: u16,
process_id: i32, process_id: i32,
secret_key: i32, secret_key: i32,
) -> Result<(), Error> { ) -> Result<(), Error> {
@@ -529,7 +535,7 @@ impl Server {
self.process_id, self.process_id,
self.secret_key, self.secret_key,
self.address.host.clone(), self.address.host.clone(),
self.address.port.clone(), self.address.port,
), ),
); );
} }

View File

@@ -1,9 +1,10 @@
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
/// Statistics and reporting. /// Statistics and reporting.
use log::info; use log::{error, info, trace};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use crate::pool::get_number_of_addresses; use crate::pool::get_number_of_addresses;
@@ -43,7 +44,7 @@ enum EventName {
/// Event data sent to the collector /// Event data sent to the collector
/// from clients and servers. /// from clients and servers.
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Event { pub struct Event {
/// The name of the event being reported. /// The name of the event being reported.
name: EventName, name: EventName,
@@ -79,6 +80,25 @@ impl Reporter {
Reporter { tx: tx } Reporter { tx: tx }
} }
/// Send statistics to the task keeping track of stats.
fn send(&self, event: Event) {
let name = event.name;
let result = self.tx.try_send(event);
match result {
Ok(_) => trace!(
"{:?} event reported successfully, capacity: {}",
name,
self.tx.capacity()
),
Err(err) => match err {
TrySendError::Full { .. } => error!("{:?} event dropped, buffer full", name),
TrySendError::Closed { .. } => error!("{:?} event dropped, channel closed", name),
},
};
}
/// Report a query executed by a client against /// Report a query executed by a client against
/// a server identified by the `address_id`. /// a server identified by the `address_id`.
pub fn query(&self, process_id: i32, address_id: usize) { pub fn query(&self, process_id: i32, address_id: usize) {
@@ -89,7 +109,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event);
} }
/// Report a transaction executed by a client against /// Report a transaction executed by a client against
@@ -102,7 +122,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Report data sent to a server identified by `address_id`. /// Report data sent to a server identified by `address_id`.
@@ -115,7 +135,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Report data received from a server identified by `address_id`. /// Report data received from a server identified by `address_id`.
@@ -128,7 +148,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Time spent waiting to get a healthy connection from the pool /// Time spent waiting to get a healthy connection from the pool
@@ -142,7 +162,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a client identified by `process_id` waiting for a connection /// Reports a client identified by `process_id` waiting for a connection
@@ -155,7 +175,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a client identified by `process_id` is done waiting for a connection /// Reports a client identified by `process_id` is done waiting for a connection
@@ -168,7 +188,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a client identified by `process_id` is done querying the server /// Reports a client identified by `process_id` is done querying the server
@@ -181,7 +201,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a client identified by `process_id` is disconecting from the pooler. /// Reports a client identified by `process_id` is disconecting from the pooler.
@@ -194,7 +214,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a server connection identified by `process_id` for /// Reports a server connection identified by `process_id` for
@@ -208,7 +228,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a server connection identified by `process_id` for /// Reports a server connection identified by `process_id` for
@@ -222,7 +242,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a server connection identified by `process_id` for /// Reports a server connection identified by `process_id` for
@@ -236,7 +256,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a server connection identified by `process_id` for /// Reports a server connection identified by `process_id` for
@@ -250,7 +270,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
/// Reports a server connection identified by `process_id` is disconecting from the pooler. /// Reports a server connection identified by `process_id` is disconecting from the pooler.
@@ -263,7 +283,7 @@ impl Reporter {
address_id: address_id, address_id: address_id,
}; };
let _ = self.tx.try_send(event); self.send(event)
} }
} }

View File

@@ -14,6 +14,7 @@ PGCAT_PORT = "6432"
def pgcat_start(): def pgcat_start():
pg_cat_send_signal(signal.SIGTERM) pg_cat_send_signal(signal.SIGTERM)
os.system("./target/debug/pgcat .circleci/pgcat.toml &") os.system("./target/debug/pgcat .circleci/pgcat.toml &")
time.sleep(2)
def pg_cat_send_signal(signal: signal.Signals): def pg_cat_send_signal(signal: signal.Signals):
@@ -27,11 +28,23 @@ def pg_cat_send_signal(signal: signal.Signals):
raise Exception("pgcat not closed after SIGTERM") raise Exception("pgcat not closed after SIGTERM")
def connect_normal_db( def connect_db(
autocommit: bool = False, autocommit: bool = True,
admin: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: ) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
password = "admin_pass"
db = "pgcat"
else:
user = "sharding_user"
password = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect( conn = psycopg2.connect(
f"postgres://sharding_user:sharding_user@{PGCAT_HOST}:{PGCAT_PORT}/sharded_db?application_name=testing_pgcat" f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
) )
conn.autocommit = autocommit conn.autocommit = autocommit
cur = conn.cursor() cur = conn.cursor()
@@ -45,7 +58,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
def test_normal_db_access(): def test_normal_db_access():
conn, cur = connect_normal_db() conn, cur = connect_db(autocommit=False)
cur.execute("SELECT 1") cur.execute("SELECT 1")
res = cur.fetchall() res = cur.fetchall()
print(res) print(res)
@@ -53,11 +66,7 @@ def test_normal_db_access():
def test_admin_db_access(): def test_admin_db_access():
conn = psycopg2.connect( conn, cur = connect_db(admin=True)
f"postgres://admin_user:admin_pass@{PGCAT_HOST}:{PGCAT_PORT}/pgcat"
)
conn.autocommit = True # BEGIN/COMMIT is not supported by admin db
cur = conn.cursor()
cur.execute("SHOW POOLS") cur.execute("SHOW POOLS")
res = cur.fetchall() res = cur.fetchall()
@@ -67,15 +76,14 @@ def test_admin_db_access():
def test_shutdown_logic(): def test_shutdown_logic():
##### NO ACTIVE QUERIES SIGINT HANDLING ##### # - - - - - - - - - - - - - - - - - -
# NO ACTIVE QUERIES SIGINT HANDLING
# Start pgcat # Start pgcat
pgcat_start() pgcat_start()
# Wait for server to fully start up
time.sleep(2)
# Create client connection and send query (not in transaction) # Create client connection and send query (not in transaction)
conn, cur = connect_normal_db(True) conn, cur = connect_db()
cur.execute("BEGIN;") cur.execute("BEGIN;")
cur.execute("SELECT 1;") cur.execute("SELECT 1;")
@@ -97,17 +105,14 @@ def test_shutdown_logic():
cleanup_conn(conn, cur) cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM) pg_cat_send_signal(signal.SIGTERM)
##### END ##### # - - - - - - - - - - - - - - - - - -
# HANDLE TRANSACTION WITH SIGINT
##### HANDLE TRANSACTION WITH SIGINT #####
# Start pgcat # Start pgcat
pgcat_start() pgcat_start()
# Wait for server to fully start up
time.sleep(2)
# Create client connection and begin transaction # Create client connection and begin transaction
conn, cur = connect_normal_db(True) conn, cur = connect_db()
cur.execute("BEGIN;") cur.execute("BEGIN;")
cur.execute("SELECT 1;") cur.execute("SELECT 1;")
@@ -126,17 +131,97 @@ def test_shutdown_logic():
cleanup_conn(conn, cur) cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM) pg_cat_send_signal(signal.SIGTERM)
##### END ##### # - - - - - - - - - - - - - - - - - -
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
##### HANDLE SHUTDOWN TIMEOUT WITH SIGINT #####
# Start pgcat # Start pgcat
pgcat_start() pgcat_start()
# Wait for server to fully start up # Create client connection and begin transaction
time.sleep(3) transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
start = time.perf_counter()
try:
conn, cur = connect_db()
cur.execute("SELECT 1;")
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
time_taken = time.perf_counter() - start
if time_taken > 0.1:
raise Exception(
"Failed to reject connection within 0.1 seconds, got", time_taken, "seconds")
pass
else:
raise Exception("Able connect to database during shutdown")
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
# Start pgcat
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
conn, cur = connect_db(admin=True)
cur.execute("SHOW DATABASES;")
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
raise Exception(e)
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
# Start pgcat
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
admin_conn, admin_cur = connect_db(admin=True)
admin_cur.execute("SHOW DATABASES;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
admin_cur.execute("SHOW DATABASES;")
except psycopg2.OperationalError as e:
raise Exception("Could not execute admin command:", e)
cleanup_conn(transaction_conn, transaction_cur)
cleanup_conn(admin_conn, admin_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
# Start pgcat
pgcat_start()
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
conn, cur = connect_normal_db(True) conn, cur = connect_db()
cur.execute("BEGIN;") cur.execute("BEGIN;")
cur.execute("SELECT 1;") cur.execute("SELECT 1;")
@@ -159,7 +244,7 @@ def test_shutdown_logic():
cleanup_conn(conn, cur) cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM) pg_cat_send_signal(signal.SIGTERM)
##### END ##### # - - - - - - - - - - - - - - - - - -
test_normal_db_access() test_normal_db_access()

View File

@@ -5,6 +5,89 @@ require 'pg'
require 'toml' require 'toml'
$stdout.sync = true $stdout.sync = true
$stderr.sync = true
class ConfigEditor
def initialize
@original_config_text = File.read('../../.circleci/pgcat.toml')
text_to_load = @original_config_text.gsub("5432", "\"5432\"")
@original_configs = TOML.load(text_to_load)
end
def original_configs
TOML.load(TOML::Generator.new(@original_configs).body)
end
def with_modified_configs(new_configs)
text_to_write = TOML::Generator.new(new_configs).body
text_to_write = text_to_write.gsub("\"5432\"", "5432")
File.write('../../.circleci/pgcat.toml', text_to_write)
yield
ensure
File.write('../../.circleci/pgcat.toml', @original_config_text)
end
end
def with_captured_stdout_stderr
sout = STDOUT.clone
serr = STDERR.clone
STDOUT.reopen("/tmp/out.txt", "w+")
STDERR.reopen("/tmp/err.txt", "w+")
STDOUT.sync = true
STDERR.sync = true
yield
return File.read('/tmp/out.txt'), File.read('/tmp/err.txt')
ensure
STDOUT.reopen(sout)
STDERR.reopen(serr)
end
def test_extended_protocol_pooler_errors
admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat")
conf_editor = ConfigEditor.new
new_configs = conf_editor.original_configs
# shorter timeouts
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
new_configs["pools"]["sharded_db"]["users"]["1"]["pool_size"] = 1
conf_editor.with_modified_configs(new_configs) { admin_conn.async_exec("RELOAD") }
conn_str = "postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db"
10.times do
Thread.new do
conn = PG::connect(conn_str)
conn.async_exec("SELECT pg_sleep(5)") rescue PG::SystemError
ensure
conn&.close
end
end
sleep(0.5)
conn_under_test = PG::connect(conn_str)
stdout, stderr = with_captured_stdout_stderr do
5.times do |i|
conn_under_test.async_exec("SELECT 1") rescue PG::SystemError
conn_under_test.exec_params("SELECT #{i} + $1", [i]) rescue PG::SystemError
sleep 1
end
end
raise StandardError, "Libpq got unexpected messages while idle" if stderr.include?("arrived from server while idle")
puts "Pool checkout errors not breaking clients passed"
ensure
sleep 1
admin_conn.async_exec("RELOAD") # Reset state
conn_under_test&.close
end
test_extended_protocol_pooler_errors
# Uncomment these two to see all queries. # Uncomment these two to see all queries.
# ActiveRecord.verbose_query_logs = true # ActiveRecord.verbose_query_logs = true
@@ -144,30 +227,6 @@ def test_server_parameters
end end
class ConfigEditor
def initialize
@original_config_text = File.read('../../.circleci/pgcat.toml')
text_to_load = @original_config_text.gsub("5432", "\"5432\"")
@original_configs = TOML.load(text_to_load)
end
def original_configs
TOML.load(TOML::Generator.new(@original_configs).body)
end
def with_modified_configs(new_configs)
text_to_write = TOML::Generator.new(new_configs).body
text_to_write = text_to_write.gsub("\"5432\"", "5432")
File.write('../../.circleci/pgcat.toml', text_to_write)
yield
ensure
File.write('../../.circleci/pgcat.toml', @original_config_text)
end
end
def test_reload_pool_recycling def test_reload_pool_recycling
admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat")
server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat")
@@ -201,3 +260,6 @@ ensure
end end
test_reload_pool_recycling test_reload_pool_recycling