Compare commits

..

4 Commits

Author SHA1 Message Date
Lev Kokotov
e7265cbf91 fix flakey test 2023-05-03 16:01:48 -07:00
Lev Kokotov
d738ba28b6 fix tests 2023-05-03 15:42:16 -07:00
Lev Kokotov
ff80bb75cc clean up 2023-05-03 15:38:03 -07:00
Lev Kokotov
374a6b138b more plugins 2023-05-03 15:29:16 -07:00
21 changed files with 358 additions and 639 deletions

138
Cargo.lock generated
View File

@@ -250,12 +250,6 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "either"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
[[package]] [[package]]
name = "enum-as-inner" name = "enum-as-inner"
version = "0.5.1" version = "0.5.1"
@@ -562,7 +556,7 @@ dependencies = [
"httpdate", "httpdate",
"itoa", "itoa",
"pin-project-lite", "pin-project-lite",
"socket2 0.4.9", "socket2",
"tokio", "tokio",
"tower-service", "tower-service",
"tracing", "tracing",
@@ -631,7 +625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@@ -640,7 +634,7 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be" checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
dependencies = [ dependencies = [
"socket2 0.4.9", "socket2",
"widestring", "widestring",
"winapi", "winapi",
"winreg", "winreg",
@@ -661,16 +655,7 @@ dependencies = [
"hermit-abi 0.3.1", "hermit-abi 0.3.1",
"io-lifetimes", "io-lifetimes",
"rustix", "rustix",
"windows-sys 0.45.0", "windows-sys",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
] ]
[[package]] [[package]]
@@ -716,9 +701,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.144" version = "0.2.139"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
[[package]] [[package]]
name = "link-cplusplus" name = "link-cplusplus"
@@ -814,7 +799,7 @@ dependencies = [
"libc", "libc",
"log", "log",
"wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@@ -886,7 +871,7 @@ dependencies = [
"libc", "libc",
"redox_syscall", "redox_syscall",
"smallvec", "smallvec",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@@ -897,7 +882,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]] [[package]]
name = "pgcat" name = "pgcat"
version = "1.0.2-alpha3" version = "1.0.2-alpha1"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"async-trait", "async-trait",
@@ -912,7 +897,6 @@ dependencies = [
"futures", "futures",
"hmac", "hmac",
"hyper", "hyper",
"itertools",
"jemallocator", "jemallocator",
"log", "log",
"md-5", "md-5",
@@ -932,7 +916,7 @@ dependencies = [
"serde_json", "serde_json",
"sha-1", "sha-1",
"sha2", "sha2",
"socket2 0.5.3", "socket2",
"sqlparser", "sqlparser",
"stringprep", "stringprep",
"tokio", "tokio",
@@ -1157,7 +1141,7 @@ dependencies = [
"io-lifetimes", "io-lifetimes",
"libc", "libc",
"linux-raw-sys", "linux-raw-sys",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@@ -1313,24 +1297,14 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.4.9" version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd"
dependencies = [ dependencies = [
"libc", "libc",
"winapi", "winapi",
] ]
[[package]]
name = "socket2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877"
dependencies = [
"libc",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.5.2"
@@ -1472,9 +1446,9 @@ dependencies = [
"parking_lot", "parking_lot",
"pin-project-lite", "pin-project-lite",
"signal-hook-registry", "signal-hook-registry",
"socket2 0.4.9", "socket2",
"tokio-macros", "tokio-macros",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@@ -1853,16 +1827,7 @@ version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [ dependencies = [
"windows-targets 0.42.1", "windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets 0.48.0",
] ]
[[package]] [[package]]
@@ -1871,28 +1836,13 @@ version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7"
dependencies = [ dependencies = [
"windows_aarch64_gnullvm 0.42.1", "windows_aarch64_gnullvm",
"windows_aarch64_msvc 0.42.1", "windows_aarch64_msvc",
"windows_i686_gnu 0.42.1", "windows_i686_gnu",
"windows_i686_msvc 0.42.1", "windows_i686_msvc",
"windows_x86_64_gnu 0.42.1", "windows_x86_64_gnu",
"windows_x86_64_gnullvm 0.42.1", "windows_x86_64_gnullvm",
"windows_x86_64_msvc 0.42.1", "windows_x86_64_msvc",
]
[[package]]
name = "windows-targets"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5"
dependencies = [
"windows_aarch64_gnullvm 0.48.0",
"windows_aarch64_msvc 0.48.0",
"windows_i686_gnu 0.48.0",
"windows_i686_msvc 0.48.0",
"windows_x86_64_gnu 0.48.0",
"windows_x86_64_gnullvm 0.48.0",
"windows_x86_64_msvc 0.48.0",
] ]
[[package]] [[package]]
@@ -1901,84 +1851,42 @@ version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640"
[[package]]
name = "windows_i686_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605"
[[package]]
name = "windows_i686_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
[[package]] [[package]]
name = "windows_x86_64_gnullvm" name = "windows_x86_64_gnullvm"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.42.1" version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.3.3" version = "0.3.3"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "pgcat" name = "pgcat"
version = "1.0.2-alpha3" version = "1.0.2-alpha1"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -34,7 +34,7 @@ hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.11.1", features = ["macros"] } phf = { version = "0.11.1", features = ["macros"] }
exitcode = "1.1.2" exitcode = "1.1.2"
futures = "0.3" futures = "0.3"
socket2 = { version = "0.5.3", features = ["all"] } socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2" nix = "0.26.2"
atomic_enum = "0.2.0" atomic_enum = "0.2.0"
postgres-protocol = "0.6.5" postgres-protocol = "0.6.5"
@@ -45,7 +45,6 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0" trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2" tokio-test = "0.4.2"
serde_json = "1" serde_json = "1"
itertools = "0.10"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -25,7 +25,7 @@ x-common-env-pg:
services: services:
main: main:
image: gcr.io/google_containers/pause:3.2 image: kubernetes/pause
ports: ports:
- 6432 - 6432
@@ -64,7 +64,7 @@ services:
<<: *common-env-pg <<: *common-env-pg
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
PGPORT: 10432 PGPORT: 10432
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
toxiproxy: toxiproxy:
build: . build: .

View File

@@ -1,22 +0,0 @@
# This is an example of the most basic config
# that will mimic what PgBouncer does in transaction mode with one server.
[general]
host = "0.0.0.0"
port = 6433
admin_username = "pgcat"
admin_password = "pgcat"
[pools.pgml.users.0]
username = "postgres"
password = "postgres"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.pgml.shards.0]
servers = [
["127.0.0.1", 28815, "primary"]
]
database = "postgres"

View File

@@ -77,58 +77,6 @@ admin_username = "admin_user"
# Password to access the virtual administrative database # Password to access the virtual administrative database
admin_password = "admin_pass" admin_password = "admin_pass"
# Default plugins that are configured on all pools.
[plugins]
# Prewarmer plugin that runs queries on server startup, before giving the connection
# to the client.
[plugins.prewarmer]
enabled = false
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
# Log all queries to stdout.
[plugins.query_logger]
enabled = false
# Block access to tables that Postgres does not allow us to control.
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
# Intercept user queries and give a fake reply.
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# pool configs are structured as pool.<pool_name> # pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting. # the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like # For a pool named `sharded_db`, clients access that pool using connection string like
@@ -206,20 +154,12 @@ connect_timeout = 3000
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`). # Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30 # dns_max_ttl = 30
# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting, [plugins]
# so all plugins have to be configured here again.
[pool.sharded_db.plugins]
[pools.sharded_db.plugins.prewarmer] [plugins.query_logger]
enabled = true
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
[pools.sharded_db.plugins.query_logger]
enabled = false enabled = false
[pools.sharded_db.plugins.table_access] [plugins.table_access]
enabled = false enabled = false
tables = [ tables = [
"pg_user", "pg_user",
@@ -227,10 +167,10 @@ tables = [
"pg_database", "pg_database",
] ]
[pools.sharded_db.plugins.intercept] [plugins.intercept]
enabled = true enabled = true
[pools.sharded_db.plugins.intercept.queries.0] [plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b" query = "select current_database() as a, current_schemas(false) as b"
schema = [ schema = [
@@ -241,7 +181,7 @@ result = [
["${DATABASE}", "{public}"], ["${DATABASE}", "{public}"],
] ]
[pools.sharded_db.plugins.intercept.queries.1] [plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user" query = "select current_database(), current_schema(), current_user"
schema = [ schema = [

View File

@@ -122,16 +122,6 @@ impl Default for Address {
} }
} }
impl std::fmt::Display for Address {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"[address: {}:{}][database: {}][user: {}]",
self.host, self.port, self.database, self.username
)
}
}
// We need to implement PartialEq by ourselves so we skip stats in the comparison // We need to implement PartialEq by ourselves so we skip stats in the comparison
impl PartialEq for Address { impl PartialEq for Address {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
@@ -245,8 +235,6 @@ pub struct General {
pub port: u16, pub port: u16,
pub enable_prometheus_exporter: Option<bool>, pub enable_prometheus_exporter: Option<bool>,
#[serde(default = "General::default_prometheus_exporter_port")]
pub prometheus_exporter_port: i16, pub prometheus_exporter_port: i16,
#[serde(default = "General::default_connect_timeout")] #[serde(default = "General::default_connect_timeout")]
@@ -310,9 +298,6 @@ pub struct General {
pub admin_username: String, pub admin_username: String,
pub admin_password: String, pub admin_password: String,
#[serde(default = "General::default_validate_config")]
pub validate_config: bool,
// Support for auth query // Support for auth query
pub auth_query: Option<String>, pub auth_query: Option<String>,
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
@@ -382,14 +367,6 @@ impl General {
pub fn default_idle_client_in_transaction_timeout() -> u64 { pub fn default_idle_client_in_transaction_timeout() -> u64 {
0 0
} }
pub fn default_validate_config() -> bool {
true
}
pub fn default_prometheus_exporter_port() -> i16 {
9930
}
} }
impl Default for General { impl Default for General {
@@ -425,7 +402,6 @@ impl Default for General {
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours, server_lifetime: 1000 * 3600 * 24, // 24 hours,
validate_config: true,
} }
} }
} }
@@ -478,7 +454,6 @@ pub struct Pool {
#[serde(default = "Pool::default_load_balancing_mode")] #[serde(default = "Pool::default_load_balancing_mode")]
pub load_balancing_mode: LoadBalancingMode, pub load_balancing_mode: LoadBalancingMode,
#[serde(default = "Pool::default_default_role")]
pub default_role: String, pub default_role: String,
#[serde(default)] // False #[serde(default)] // False
@@ -493,7 +468,6 @@ pub struct Pool {
pub server_lifetime: Option<u64>, pub server_lifetime: Option<u64>,
#[serde(default = "Pool::default_sharding_function")]
pub sharding_function: ShardingFunction, pub sharding_function: ShardingFunction,
#[serde(default = "Pool::default_automatic_sharding_key")] #[serde(default = "Pool::default_automatic_sharding_key")]
@@ -507,7 +481,6 @@ pub struct Pool {
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>, pub auth_query_password: Option<String>,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>, pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>, pub users: BTreeMap<String, User>,
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it // Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
@@ -540,14 +513,6 @@ impl Pool {
None None
} }
pub fn default_default_role() -> String {
"any".into()
}
pub fn default_sharding_function() -> ShardingFunction {
ShardingFunction::PgBigintHash
}
pub fn validate(&mut self) -> Result<(), Error> { pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() { match self.default_role.as_ref() {
"any" => (), "any" => (),
@@ -636,7 +601,6 @@ impl Default for Pool {
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: None, server_lifetime: None,
plugins: None,
} }
} }
} }
@@ -715,60 +679,39 @@ impl Default for Shard {
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins { pub struct Plugins {
pub intercept: Option<Intercept>, pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>, pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>, pub query_logger: Option<QueryLogger>,
pub prewarmer: Option<Prewarmer>,
} }
impl std::fmt::Display for Plugins { #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
self.intercept.is_some(),
self.table_access.is_some(),
self.query_logger.is_some(),
self.prewarmer.is_some(),
)
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Intercept { pub struct Intercept {
pub enabled: bool, pub enabled: bool,
pub queries: BTreeMap<String, Query>, pub queries: BTreeMap<String, Query>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess { pub struct TableAccess {
pub enabled: bool, pub enabled: bool,
pub tables: Vec<String>, pub tables: Vec<String>,
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger { pub struct QueryLogger {
pub enabled: bool, pub enabled: bool,
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
}
impl Intercept { impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) { pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() { for (_, query) in self.queries.iter_mut() {
query.substitute(db, user); query.substitute(db, user);
query.query = query.query.to_ascii_lowercase();
} }
} }
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query { pub struct Query {
pub query: String, pub query: String,
pub schema: Vec<Vec<String>>, pub schema: Vec<Vec<String>>,
@@ -802,13 +745,8 @@ pub struct Config {
#[serde(default = "Config::default_path")] #[serde(default = "Config::default_path")]
pub path: String, pub path: String,
// General and global settings.
pub general: General, pub general: General,
// Plugins that should run in all pools.
pub plugins: Option<Plugins>, pub plugins: Option<Plugins>,
// Connection pools.
pub pools: HashMap<String, Pool>, pub pools: HashMap<String, Pool>,
} }
@@ -993,13 +931,6 @@ impl Config {
"Server TLS certificate verification: {}", "Server TLS certificate verification: {}",
self.general.verify_server_certificate self.general.verify_server_certificate
); );
info!(
"Plugins: {}",
match self.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for (pool_name, pool_config) in &self.pools { for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?) // TODO: Make this output prettier (maybe a table?)
@@ -1066,14 +997,6 @@ impl Config {
None => "default".to_string(), None => "default".to_string(),
} }
); );
info!(
"[pool: {}] Plugins: {}",
pool_name,
match pool_config.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for user in &pool_config.users { for user in &pool_config.users {
info!( info!(

View File

@@ -43,7 +43,6 @@ impl MirroredClient {
ClientServerMap::default(), ClientServerMap::default(),
Arc::new(PoolStats::new(identifier, cfg.clone())), Arc::new(PoolStats::new(identifier, cfg.clone())),
Arc::new(RwLock::new(None)), Arc::new(RwLock::new(None)),
None,
); );
Pool::builder() Pool::builder()

View File

@@ -2,21 +2,52 @@
//! //!
//! It intercepts queries and returns fake results. //! It intercepts queries and returns fake results.
use arc_swap::ArcSwap;
use async_trait::async_trait; use async_trait::async_trait;
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlparser::ast::Statement; use sqlparser::ast::Statement;
use std::collections::HashMap;
use log::debug; use log::{debug, info};
use std::sync::Arc;
use crate::{ use crate::{
config::Intercept as InterceptConfig, config::Intercept as InterceptConfig,
errors::Error, errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType}, messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput}, plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter, query_router::QueryRouter,
}; };
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}
CONFIG.store(Arc::new(config));
info!("Intercepting {} queries", intercept_config.queries.len());
}
pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}
// TODO: use these structs for deserialization // TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct Rule { pub struct Rule {
@@ -32,35 +63,33 @@ pub struct Column {
} }
/// The intercept plugin. /// The intercept plugin.
pub struct Intercept<'a> { pub struct Intercept;
pub enabled: bool,
pub config: &'a InterceptConfig,
}
#[async_trait] #[async_trait]
impl<'a> Plugin for Intercept<'a> { impl Plugin for Intercept {
async fn run( async fn run(
&mut self, &mut self,
query_router: &QueryRouter, query_router: &QueryRouter,
ast: &Vec<Statement>, ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> { ) -> Result<PluginOutput, Error> {
if !self.enabled || ast.is_empty() { if ast.is_empty() {
return Ok(PluginOutput::Allow); return Ok(PluginOutput::Allow);
} }
let mut config = self.config.clone(); let mut result = BytesMut::new();
config.substitute( let query_map = match CONFIG.load().get(&PoolIdentifier::new(
&query_router.pool_settings().db, &query_router.pool_settings().db,
&query_router.pool_settings().user.username, &query_router.pool_settings().user.username,
); )) {
Some(query_map) => query_map.clone(),
let mut result = BytesMut::new(); None => return Ok(PluginOutput::Allow),
};
for q in ast { for q in ast {
// Normalization // Normalization
let q = q.to_string().to_ascii_lowercase(); let q = q.to_string().to_ascii_lowercase();
for (_, target) in config.queries.iter() { for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q { if target.query.as_str() == q {
debug!("Intercepting query: {}", q); debug!("Intercepting query: {}", q);
@@ -118,3 +147,142 @@ impl<'a> Plugin for Intercept<'a> {
} }
} }
} }
/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
"query": "select current_database() as a, current_schemas(false) as b",
"schema": [
{
"name": "a",
"data_type": "text",
},
{
"name": "b",
"data_type": "anyarray",
},
],
"result": [
[database, "{public}"],
],
},
{
"query": "select current_database(), current_schema(), current_user",
"schema": [
{
"name": "current_database",
"data_type": "text",
},
{
"name": "current_schema",
"data_type": "text",
},
{
"name": "current_user",
"data_type": "text",
}
],
"result": [
["sharded_db", "public", "sharding_user"],
],
},
{
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
"schema": [
{
"name": "id",
"data_type": "oid",
},
{
"name": "name",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
{
"name": "is_template",
"data_type": "bool",
},
{
"name": "allow_connections",
"data_type": "bool",
},
{
"name": "owner",
"data_type": "text",
}
],
"result": [
["16387", database, "", "f", "t", user],
]
},
{
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
"schema": [
{
"name": "role_id",
"data_type": "oid",
},
{
"name": "role_name",
"data_type": "text",
},
{
"name": "is_super",
"data_type": "bool",
},
{
"name": "is_inherit",
"data_type": "bool",
},
{
"name": "can_createrole",
"data_type": "bool",
},
{
"name": "can_createdb",
"data_type": "bool",
},
{
"name": "can_login",
"data_type": "bool",
},
{
"name": "is_replication",
"data_type": "bool",
},
{
"name": "conn_limit",
"data_type": "int4",
},
{
"name": "valid_until",
"data_type": "text",
},
{
"name": "bypass_rls",
"data_type": "bool",
},
{
"name": "config",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
],
"result": [
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
]
}
])
}

View File

@@ -9,7 +9,6 @@
//! //!
pub mod intercept; pub mod intercept;
pub mod prewarmer;
pub mod query_logger; pub mod query_logger;
pub mod table_access; pub mod table_access;

View File

@@ -1,28 +0,0 @@
//! Prewarm new connections before giving them to the client.
use crate::{errors::Error, server::Server};
use log::info;
pub struct Prewarmer<'a> {
pub enabled: bool,
pub server: &'a mut Server,
pub queries: &'a Vec<String>,
}
impl<'a> Prewarmer<'a> {
pub async fn run(&mut self) -> Result<(), Error> {
if !self.enabled {
return Ok(());
}
for query in self.queries {
info!(
"{} Prewarning with query: `{}`",
self.server.address(),
query
);
self.server.query(&query).await?;
}
Ok(())
}
}

View File

@@ -5,33 +5,44 @@ use crate::{
plugins::{Plugin, PluginOutput}, plugins::{Plugin, PluginOutput},
query_router::QueryRouter, query_router::QueryRouter,
}; };
use arc_swap::ArcSwap;
use async_trait::async_trait; use async_trait::async_trait;
use log::info; use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement; use sqlparser::ast::Statement;
use std::sync::Arc;
pub struct QueryLogger<'a> { static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
pub enabled: bool,
pub user: &'a str, pub struct QueryLogger;
pub db: &'a str,
pub fn setup() {
ENABLED.store(Arc::new(true));
info!("Logging queries to stdout");
}
pub fn disable() {
ENABLED.store(Arc::new(false));
}
pub fn enabled() -> bool {
**ENABLED.load()
} }
#[async_trait] #[async_trait]
impl<'a> Plugin for QueryLogger<'a> { impl Plugin for QueryLogger {
async fn run( async fn run(
&mut self, &mut self,
_query_router: &QueryRouter, _query_router: &QueryRouter,
ast: &Vec<Statement>, ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> { ) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let query = ast let query = ast
.iter() .iter()
.map(|q| q.to_string()) .map(|q| q.to_string())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("; "); .join("; ");
info!("[pool: {}][user: {}] {}", self.user, self.db, query); info!("{}", query);
Ok(PluginOutput::Allow) Ok(PluginOutput::Allow)
} }

View File

@@ -5,39 +5,53 @@ use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement}; use sqlparser::ast::{visit_relations, Statement};
use crate::{ use crate::{
config::TableAccess as TableAccessConfig,
errors::Error, errors::Error,
plugins::{Plugin, PluginOutput}, plugins::{Plugin, PluginOutput},
query_router::QueryRouter, query_router::QueryRouter,
}; };
use log::debug; use log::{debug, info};
use arc_swap::ArcSwap;
use core::ops::ControlFlow; use core::ops::ControlFlow;
use once_cell::sync::Lazy;
use std::sync::Arc;
pub struct TableAccess<'a> { static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
pub enabled: bool,
pub tables: &'a Vec<String>, pub fn setup(config: &TableAccessConfig) {
CONFIG.store(Arc::new(config.tables.clone()));
info!("Blocking access to {} tables", config.tables.len());
} }
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn disable() {
CONFIG.store(Arc::new(vec![]));
}
pub struct TableAccess;
#[async_trait] #[async_trait]
impl<'a> Plugin for TableAccess<'a> { impl Plugin for TableAccess {
async fn run( async fn run(
&mut self, &mut self,
_query_router: &QueryRouter, _query_router: &QueryRouter,
ast: &Vec<Statement>, ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> { ) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let mut found = None; let mut found = None;
let forbidden_tables = CONFIG.load();
visit_relations(ast, |relation| { visit_relations(ast, |relation| {
let relation = relation.to_string(); let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>(); let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap(); let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) { if forbidden_tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string()); found = Some(table_name.to_string());
ControlFlow::<()>::Break(()) ControlFlow::<()>::Break(())
} else { } else {

View File

@@ -17,13 +17,10 @@ use std::sync::{
use std::time::Instant; use std::time::Instant;
use tokio::sync::Notify; use tokio::sync::Notify;
use crate::config::{ use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
};
use crate::errors::Error; use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough; use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server; use crate::server::Server;
use crate::sharding::ShardingFunction; use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats}; use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
@@ -135,9 +132,6 @@ pub struct PoolSettings {
pub auth_query: Option<String>, pub auth_query: Option<String>,
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>, pub auth_query_password: Option<String>,
/// Plugins
pub plugins: Option<Plugins>,
} }
impl Default for PoolSettings { impl Default for PoolSettings {
@@ -162,7 +156,6 @@ impl Default for PoolSettings {
auth_query: None, auth_query: None,
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
plugins: None,
} }
} }
} }
@@ -202,7 +195,6 @@ pub struct ConnectionPool {
paused: Arc<AtomicBool>, paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>, paused_waiter: Arc<Notify>,
/// Statistics.
pub stats: Arc<PoolStats>, pub stats: Arc<PoolStats>,
/// AuthInfo /// AuthInfo
@@ -360,10 +352,6 @@ impl ConnectionPool {
client_server_map.clone(), client_server_map.clone(),
pool_stats.clone(), pool_stats.clone(),
pool_auth_hash.clone(), pool_auth_hash.clone(),
match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
); );
let connect_timeout = match pool_config.connect_timeout { let connect_timeout = match pool_config.connect_timeout {
@@ -389,10 +377,7 @@ impl ConnectionPool {
.min() .min()
.unwrap(); .unwrap();
debug!( debug!("Pool reaper rate: {}ms", reaper_rate);
"[pool: {}][user: {}] Pool reaper rate: {}ms",
pool_name, user.username, reaper_rate
);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(user.pool_size) .max_size(user.pool_size)
@@ -401,13 +386,9 @@ impl ConnectionPool {
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate)) .reaper_rate(std::time::Duration::from_millis(reaper_rate))
.test_on_check_out(false); .test_on_check_out(false)
.build(manager)
let pool = if config.general.validate_config { .await?;
pool.build(manager).await?
} else {
pool.build_unchecked(manager)
};
pools.push(pool); pools.push(pool);
servers.push(address); servers.push(address);
@@ -469,10 +450,6 @@ impl ConnectionPool {
auth_query: pool_config.auth_query.clone(), auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(), auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(), auth_query_password: pool_config.auth_query_password.clone(),
plugins: match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
}, },
validated: Arc::new(AtomicBool::new(false)), validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)),
@@ -482,18 +459,42 @@ impl ConnectionPool {
// Connect to the servers to make sure pool configuration is valid // Connect to the servers to make sure pool configuration is valid
// before setting it globally. // before setting it globally.
// Do this async and somewhere else, we don't have to wait here. // Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config { let mut validate_pool = pool.clone();
let mut validate_pool = pool.clone(); tokio::task::spawn(async move {
tokio::task::spawn(async move { let _ = validate_pool.validate().await;
let _ = validate_pool.validate().await; });
});
}
// There is one pool per database/user pair. // There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
} }
} }
if let Some(ref plugins) = config.plugins {
if let Some(ref intercept) = plugins.intercept {
if intercept.enabled {
crate::plugins::intercept::setup(intercept, &new_pools);
} else {
crate::plugins::intercept::disable();
}
}
if let Some(ref table_access) = plugins.table_access {
if table_access.enabled {
crate::plugins::table_access::setup(table_access);
} else {
crate::plugins::table_access::disable();
}
}
if let Some(ref query_logger) = plugins.query_logger {
if query_logger.enabled {
crate::plugins::query_logger::setup();
} else {
crate::plugins::query_logger::disable();
}
}
}
POOLS.store(Arc::new(new_pools.clone())); POOLS.store(Arc::new(new_pools.clone()));
Ok(()) Ok(())
} }
@@ -638,10 +639,7 @@ impl ConnectionPool {
{ {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
error!( error!("Banning instance {:?}, error: {:?}", address, err);
"Connection checkout error for instance {:?}, error: {:?}",
address, err
);
self.ban(address, BanReason::FailedCheckout, Some(client_stats)); self.ban(address, BanReason::FailedCheckout, Some(client_stats));
address.stats.error(); address.stats.error();
client_stats.idle(); client_stats.idle();
@@ -717,7 +715,7 @@ impl ConnectionPool {
// Health check failed. // Health check failed.
Err(err) => { Err(err) => {
error!( error!(
"Failed health check on instance {:?}, error: {:?}", "Banning instance {:?} because of failed health check, {:?}",
address, err address, err
); );
} }
@@ -726,7 +724,7 @@ impl ConnectionPool {
// Health check timed out. // Health check timed out.
Err(err) => { Err(err) => {
error!( error!(
"Health check timeout on instance {:?}, error: {:?}", "Banning instance {:?} because of health check timeout, {:?}",
address, err address, err
); );
} }
@@ -748,16 +746,13 @@ impl ConnectionPool {
return; return;
} }
error!("Banning instance {:?}, reason: {:?}", address, reason);
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write(); let mut guard = self.banlist.write();
error!("Banning {:?}", address);
if let Some(client_info) = client_info { if let Some(client_info) = client_info {
client_info.ban_error(); client_info.ban_error();
address.stats.error(); address.stats.error();
} }
guard[address.shard].insert(address.clone(), (reason, now)); guard[address.shard].insert(address.clone(), (reason, now));
} }
@@ -920,7 +915,6 @@ pub struct ServerPool {
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
stats: Arc<PoolStats>, stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>, auth_hash: Arc<RwLock<Option<String>>>,
plugins: Option<Plugins>,
} }
impl ServerPool { impl ServerPool {
@@ -931,7 +925,6 @@ impl ServerPool {
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
stats: Arc<PoolStats>, stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>, auth_hash: Arc<RwLock<Option<String>>>,
plugins: Option<Plugins>,
) -> ServerPool { ) -> ServerPool {
ServerPool { ServerPool {
address, address,
@@ -940,7 +933,6 @@ impl ServerPool {
client_server_map, client_server_map,
stats, stats,
auth_hash, auth_hash,
plugins,
} }
} }
} }
@@ -973,19 +965,7 @@ impl ManageConnection for ServerPool {
) )
.await .await
{ {
Ok(mut conn) => { Ok(conn) => {
if let Some(ref plugins) = self.plugins {
if let Some(ref prewarmer) = plugins.prewarmer {
let mut prewarmer = prewarmer::Prewarmer {
enabled: prewarmer.enabled,
server: &mut conn,
queries: &prewarmer.queries,
};
prewarmer.run().await?;
}
}
stats.idle(); stats.idle();
Ok(conn) Ok(conn)
} }

View File

@@ -15,7 +15,10 @@ use sqlparser::parser::Parser;
use crate::config::Role; use crate::config::Role;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::BytesMutReader; use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess}; use crate::plugins::{
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
TableAccess,
};
use crate::pool::PoolSettings; use crate::pool::PoolSettings;
use crate::sharding::Sharder; use crate::sharding::Sharder;
@@ -790,27 +793,13 @@ impl QueryRouter {
/// Add your plugins here and execute them. /// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> { pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
let plugins = match self.pool_settings.plugins { if query_logger::enabled() {
Some(ref plugins) => plugins, let mut query_logger = QueryLogger {};
None => return Ok(PluginOutput::Allow),
};
if let Some(ref query_logger) = plugins.query_logger {
let mut query_logger = QueryLogger {
enabled: query_logger.enabled,
user: &self.pool_settings.user.username,
db: &self.pool_settings.db,
};
let _ = query_logger.run(&self, ast).await; let _ = query_logger.run(&self, ast).await;
} }
if let Some(ref intercept) = plugins.intercept { if intercept::enabled() {
let mut intercept = Intercept { let mut intercept = Intercept {};
enabled: intercept.enabled,
config: &intercept,
};
let result = intercept.run(&self, ast).await; let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result { if let Ok(PluginOutput::Intercept(output)) = result {
@@ -818,12 +807,8 @@ impl QueryRouter {
} }
} }
if let Some(ref table_access) = plugins.table_access { if table_access::enabled() {
let mut table_access = TableAccess { let mut table_access = TableAccess {};
enabled: table_access.enabled,
tables: &table_access.tables,
};
let result = table_access.run(&self, ast).await; let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result { if let Ok(PluginOutput::Deny(error)) = result {
@@ -1176,7 +1161,6 @@ mod test {
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(), db: "test".to_string(),
plugins: None,
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None); assert_eq!(qr.active_role, None);
@@ -1251,9 +1235,7 @@ mod test {
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(), db: "test".to_string(),
plugins: None,
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone()); qr.update_pool_settings(pool_settings.clone());
@@ -1397,25 +1379,17 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_table_access_plugin() { async fn test_table_access_plugin() {
use crate::config::{Plugins, TableAccess}; use crate::config::TableAccess;
let table_access = TableAccess { let ta = TableAccess {
enabled: true, enabled: true,
tables: vec![String::from("pg_database")], tables: vec![String::from("pg_database")],
}; };
let plugins = Plugins {
table_access: Some(table_access), crate::plugins::table_access::setup(&ta);
intercept: None,
query_logger: None,
prewarmer: None,
};
QueryRouter::setup(); QueryRouter::setup();
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let mut qr = QueryRouter::new(); let qr = QueryRouter::new();
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database"); let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap(); let ast = QueryRouter::parse(&query).unwrap();
@@ -1429,17 +1403,4 @@ mod test {
)) ))
); );
} }
#[tokio::test]
async fn test_plugins_disabled_by_defaault() {
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(res, Ok(PluginOutput::Allow));
}
} }

View File

@@ -103,48 +103,6 @@ impl StreamInner {
} }
} }
#[derive(Copy, Clone)]
struct CleanupState {
/// If server connection requires DISCARD ALL before checkin because of set statement
needs_cleanup_set: bool,
/// If server connection requires DISCARD ALL before checkin because of prepare statement
needs_cleanup_prepare: bool,
}
impl CleanupState {
fn new() -> Self {
CleanupState {
needs_cleanup_set: false,
needs_cleanup_prepare: false,
}
}
fn needs_cleanup(&self) -> bool {
self.needs_cleanup_set || self.needs_cleanup_prepare
}
fn set_true(&mut self) {
self.needs_cleanup_set = true;
self.needs_cleanup_prepare = true;
}
fn reset(&mut self) {
self.needs_cleanup_set = false;
self.needs_cleanup_prepare = false;
}
}
impl std::fmt::Display for CleanupState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"SET: {}, PREPARE: {}",
self.needs_cleanup_set, self.needs_cleanup_prepare
)
}
}
/// Server state. /// Server state.
pub struct Server { pub struct Server {
/// Server host, e.g. localhost, /// Server host, e.g. localhost,
@@ -173,8 +131,8 @@ pub struct Server {
/// Is the server broken? We'll remote it from the pool if so. /// Is the server broken? We'll remote it from the pool if so.
bad: bool, bad: bool,
/// If server connection requires DISCARD ALL before checkin /// If server connection requires a DISCARD ALL before checkin
cleanup_state: CleanupState, needs_cleanup: bool,
/// Mapping of clients and servers used for query cancellation. /// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
@@ -672,7 +630,7 @@ impl Server {
in_transaction: false, in_transaction: false,
data_available: false, data_available: false,
bad: false, bad: false,
cleanup_state: CleanupState::new(), needs_cleanup: false,
client_server_map, client_server_map,
addr_set, addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(), connected_at: chrono::offset::Utc::now().naive_utc(),
@@ -747,10 +705,7 @@ impl Server {
Ok(()) Ok(())
} }
Err(err) => { Err(err) => {
error!( error!("Terminating server because of: {:?}", err);
"Terminating server {:?} because of: {:?}",
self.address, err
);
self.bad = true; self.bad = true;
Err(err) Err(err)
} }
@@ -765,10 +720,7 @@ impl Server {
let mut message = match read_message(&mut self.stream).await { let mut message = match read_message(&mut self.stream).await {
Ok(message) => message, Ok(message) => message,
Err(err) => { Err(err) => {
error!( error!("Terminating server because of: {:?}", err);
"Terminating server {:?} because of: {:?}",
self.address, err
);
self.bad = true; self.bad = true;
return Err(err); return Err(err);
} }
@@ -835,12 +787,12 @@ impl Server {
// This will reduce amount of discard statements sent // This will reduce amount of discard statements sent
if !self.in_transaction { if !self.in_transaction {
debug!("Server connection marked for clean up"); debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_set = true; self.needs_cleanup = true;
} }
} }
"PREPARE\0" => { "PREPARE\0" => {
debug!("Server connection marked for clean up"); debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_prepare = true; self.needs_cleanup = true;
} }
_ => (), _ => (),
} }
@@ -970,8 +922,6 @@ impl Server {
/// It will use the simple query protocol. /// It will use the simple query protocol.
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> { pub async fn query(&mut self, query: &str) -> Result<(), Error> {
debug!("Running `{}` on server {:?}", query, self.address);
let query = simple_query(query); let query = simple_query(query);
self.send(&query).await?; self.send(&query).await?;
@@ -1004,11 +954,10 @@ impl Server {
// to avoid leaking state between clients. For performance reasons we only // to avoid leaking state between clients. For performance reasons we only
// send `DISCARD ALL` if we think the session is altered instead of just sending // send `DISCARD ALL` if we think the session is altered instead of just sending
// it before each checkin. // it before each checkin.
if self.cleanup_state.needs_cleanup() { if self.needs_cleanup {
warn!("Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name); warn!("Server returned with session state altered, discarding state");
self.query("DISCARD ALL").await?; self.query("DISCARD ALL").await?;
self.query("RESET ROLE").await?; self.needs_cleanup = false;
self.cleanup_state.reset();
} }
Ok(()) Ok(())
@@ -1020,12 +969,12 @@ impl Server {
self.application_name = name.to_string(); self.application_name = name.to_string();
// We don't want `SET application_name` to mark the server connection // We don't want `SET application_name` to mark the server connection
// as needing cleanup // as needing cleanup
let needs_cleanup_before = self.cleanup_state; let needs_cleanup_before = self.needs_cleanup;
let result = Ok(self let result = Ok(self
.query(&format!("SET application_name = '{}'", name)) .query(&format!("SET application_name = '{}'", name))
.await?); .await?);
self.cleanup_state = needs_cleanup_before; self.needs_cleanup = needs_cleanup_before;
result result
} else { } else {
Ok(()) Ok(())
@@ -1050,7 +999,7 @@ impl Server {
// Marks a connection as needing DISCARD ALL at checkin // Marks a connection as needing DISCARD ALL at checkin
pub fn mark_dirty(&mut self) { pub fn mark_dirty(&mut self) {
self.cleanup_state.set_true(); self.needs_cleanup = true;
} }
pub fn mirror_send(&mut self, bytes: &BytesMut) { pub fn mirror_send(&mut self, bytes: &BytesMut) {
@@ -1186,18 +1135,14 @@ impl Drop for Server {
_ => debug!("Dirty shutdown"), _ => debug!("Dirty shutdown"),
}; };
// Should not matter.
self.bad = true;
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let duration = now - self.connected_at; let duration = now - self.connected_at;
let message = if self.bad {
"Server connection terminated"
} else {
"Server connection closed"
};
info!( info!(
"{} {:?}, session duration: {}", "Server connection closed {:?}, session duration: {}",
message,
self.address, self.address,
crate::format_duration(&duration) crate::format_duration(&duration)
); );

View File

@@ -107,19 +107,8 @@ impl Collector {
loop { loop {
interval.tick().await; interval.tick().await;
// Hold read lock for duration of update to retain all server stats for stats in SERVER_STATS.read().values() {
let server_stats = SERVER_STATS.read(); stats.address_stats().update_averages();
for stats in server_stats.values() {
if !stats.check_address_stat_average_is_updated_status() {
stats.address_stats().update_averages();
stats.set_address_stat_average_is_updated_status(true);
}
}
// Reset to false for next update
for stats in server_stats.values() {
stats.set_address_stat_average_is_updated_status(false);
} }
} }
}); });

View File

@@ -1,3 +1,4 @@
use log::warn;
use std::sync::atomic::*; use std::sync::atomic::*;
use std::sync::Arc; use std::sync::Arc;
@@ -12,16 +13,6 @@ pub struct AddressStats {
pub total_query_time: Arc<AtomicU64>, pub total_query_time: Arc<AtomicU64>,
pub total_wait_time: Arc<AtomicU64>, pub total_wait_time: Arc<AtomicU64>,
pub total_errors: Arc<AtomicU64>, pub total_errors: Arc<AtomicU64>,
pub old_total_xact_count: Arc<AtomicU64>,
pub old_total_query_count: Arc<AtomicU64>,
pub old_total_received: Arc<AtomicU64>,
pub old_total_sent: Arc<AtomicU64>,
pub old_total_xact_time: Arc<AtomicU64>,
pub old_total_query_time: Arc<AtomicU64>,
pub old_total_wait_time: Arc<AtomicU64>,
pub old_total_errors: Arc<AtomicU64>,
pub avg_query_count: Arc<AtomicU64>, pub avg_query_count: Arc<AtomicU64>,
pub avg_query_time: Arc<AtomicU64>, pub avg_query_time: Arc<AtomicU64>,
pub avg_recv: Arc<AtomicU64>, pub avg_recv: Arc<AtomicU64>,
@@ -30,9 +21,6 @@ pub struct AddressStats {
pub avg_xact_time: Arc<AtomicU64>, pub avg_xact_time: Arc<AtomicU64>,
pub avg_xact_count: Arc<AtomicU64>, pub avg_xact_count: Arc<AtomicU64>,
pub avg_wait_time: Arc<AtomicU64>, pub avg_wait_time: Arc<AtomicU64>,
// Determines if the averages have been updated since the last time they were reported
pub averages_updated: Arc<AtomicBool>,
} }
impl IntoIterator for AddressStats { impl IntoIterator for AddressStats {
@@ -116,15 +104,16 @@ impl AddressStats {
} }
pub fn update_averages(&self) { pub fn update_averages(&self) {
let (totals, averages, old_totals) = self.fields_iterators(); let (totals, averages) = self.fields_iterators();
for (total, average, old_total) in itertools::izip!(totals, averages, old_totals) { for data in totals.iter().zip(averages.iter()) {
let total_value = total.load(Ordering::Relaxed); let (total, average) = data;
let old_total_value = old_total.load(Ordering::Relaxed); if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
average.store( let total = total.load(Ordering::Relaxed);
(total_value - old_total_value) / (crate::stats::STAT_PERIOD / 1_000), let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
Ordering::Relaxed, Some(avg)
); // Avg / second }) {
old_total.store(total_value, Ordering::Relaxed); warn!("Could not update averages for addresses stats, {:?}", err);
}
} }
} }
@@ -134,42 +123,27 @@ impl AddressStats {
} }
} }
fn fields_iterators( fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
&self,
) -> (
Vec<Arc<AtomicU64>>,
Vec<Arc<AtomicU64>>,
Vec<Arc<AtomicU64>>,
) {
let mut totals: Vec<Arc<AtomicU64>> = Vec::new(); let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
let mut averages: Vec<Arc<AtomicU64>> = Vec::new(); let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
let mut old_totals: Vec<Arc<AtomicU64>> = Vec::new();
totals.push(self.total_xact_count.clone()); totals.push(self.total_xact_count.clone());
old_totals.push(self.old_total_xact_count.clone());
averages.push(self.avg_xact_count.clone()); averages.push(self.avg_xact_count.clone());
totals.push(self.total_query_count.clone()); totals.push(self.total_query_count.clone());
old_totals.push(self.old_total_query_count.clone());
averages.push(self.avg_query_count.clone()); averages.push(self.avg_query_count.clone());
totals.push(self.total_received.clone()); totals.push(self.total_received.clone());
old_totals.push(self.old_total_received.clone());
averages.push(self.avg_recv.clone()); averages.push(self.avg_recv.clone());
totals.push(self.total_sent.clone()); totals.push(self.total_sent.clone());
old_totals.push(self.old_total_sent.clone());
averages.push(self.avg_sent.clone()); averages.push(self.avg_sent.clone());
totals.push(self.total_xact_time.clone()); totals.push(self.total_xact_time.clone());
old_totals.push(self.old_total_xact_time.clone());
averages.push(self.avg_xact_time.clone()); averages.push(self.avg_xact_time.clone());
totals.push(self.total_query_time.clone()); totals.push(self.total_query_time.clone());
old_totals.push(self.old_total_query_time.clone());
averages.push(self.avg_query_time.clone()); averages.push(self.avg_query_time.clone());
totals.push(self.total_wait_time.clone()); totals.push(self.total_wait_time.clone());
old_totals.push(self.old_total_wait_time.clone());
averages.push(self.avg_wait_time.clone()); averages.push(self.avg_wait_time.clone());
totals.push(self.total_errors.clone()); totals.push(self.total_errors.clone());
old_totals.push(self.old_total_errors.clone());
averages.push(self.avg_errors.clone()); averages.push(self.avg_errors.clone());
(totals, averages, old_totals) (totals, averages)
} }
} }

View File

@@ -139,17 +139,6 @@ impl ServerStats {
self.address.stats.clone() self.address.stats.clone()
} }
pub fn check_address_stat_average_is_updated_status(&self) -> bool {
self.address.stats.averages_updated.load(Ordering::Relaxed)
}
pub fn set_address_stat_average_is_updated_status(&self, is_checked: bool) {
self.address
.stats
.averages_updated
.store(is_checked, Ordering::Relaxed);
}
// Helper methods for show_servers // Helper methods for show_servers
pub fn pool_name(&self) -> String { pub fn pool_name(&self) -> String {
self.pool_stats.database() self.pool_stats.database()

View File

@@ -14,12 +14,11 @@ describe "Admin" do
describe "SHOW STATS" do describe "SHOW STATS" do
context "clients connect and make one query" do context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do it "updates *_query_time and *_wait_time" do
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") } connection = PG::connect("#{pgcat_conn_str}?application_name=one_query")
connections.each do |c| connection.async_exec("SELECT pg_sleep(0.25)")
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") } connection.async_exec("SELECT pg_sleep(0.25)")
end connection.async_exec("SELECT pg_sleep(0.25)")
sleep(1) connection.close
connections.map(&:close)
# wait for averages to be calculated, we shouldn't do this too often # wait for averages to be calculated, we shouldn't do this too often
sleep(15.5) sleep(15.5)
@@ -27,7 +26,7 @@ describe "Admin" do
results = admin_conn.async_exec("SHOW STATS")[0] results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750) expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to be_within(20).of(50) expect(results["avg_query_time"].to_i).to_not eq(0)
expect(results["total_wait_time"].to_i).to_not eq(0) expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0) expect(results["avg_wait_time"].to_i).to_not eq(0)

View File

@@ -41,24 +41,7 @@ module Helpers
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] }, "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] }, "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
}, },
"users" => { "0" => user }, "users" => { "0" => user }
"plugins" => {
"intercept" => {
"enabled" => true,
"queries" => {
"0" => {
"query" => "select current_database() as a, current_schemas(false) as b",
"schema" => [
["a", "text"],
["b", "text"],
],
"result" => [
["${DATABASE}", "{public}"],
]
}
}
}
}
} }
} }
pgcat.update_config(pgcat_cfg) pgcat.update_config(pgcat_cfg)

View File

@@ -241,18 +241,6 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(10) expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
end end
it "Resets server roles correctly" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("SELECT 1")
conn.async_exec("SET statement_timeout to 5000")
conn.close
end
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
end
end end
context "transaction mode" do context "transaction mode" do