Compare commits

..

4 Commits

Author SHA1 Message Date
Lev Kokotov
e7265cbf91 fix flakey test 2023-05-03 16:01:48 -07:00
Lev Kokotov
d738ba28b6 fix tests 2023-05-03 15:42:16 -07:00
Lev Kokotov
ff80bb75cc clean up 2023-05-03 15:38:03 -07:00
Lev Kokotov
374a6b138b more plugins 2023-05-03 15:29:16 -07:00
22 changed files with 452 additions and 771 deletions

48
Cargo.lock generated
View File

@@ -250,12 +250,6 @@ dependencies = [
"subtle",
]
[[package]]
name = "either"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
[[package]]
name = "enum-as-inner"
version = "0.5.1"
@@ -314,12 +308,6 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
[[package]]
name = "fnv"
version = "1.0.7"
@@ -670,15 +658,6 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.5"
@@ -903,7 +882,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]]
name = "pgcat"
version = "1.0.2-alpha3"
version = "1.0.2-alpha1"
dependencies = [
"arc-swap",
"async-trait",
@@ -914,11 +893,10 @@ dependencies = [
"chrono",
"env_logger",
"exitcode",
"fallible-iterator 0.3.0",
"fallible-iterator",
"futures",
"hmac",
"hyper",
"itertools",
"jemallocator",
"log",
"md-5",
@@ -1032,7 +1010,7 @@ dependencies = [
"base64",
"byteorder",
"bytes",
"fallible-iterator 0.2.0",
"fallible-iterator",
"hmac",
"md-5",
"memchr",
@@ -1258,9 +1236,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "0.6.2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93107647184f6027e3b7dcb2e11034cf95ffa1e3a682c67951963ac69c1c007d"
checksum = "0efd8caf556a6cebd3b285caf480045fcc1ac04f6bd786b09a6f11af30c4fcf4"
dependencies = [
"serde",
]
@@ -1534,9 +1512,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.7.4"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6135d499e69981f9ff0ef2167955a5333c35e36f6937d382974566b3d5b94ec"
checksum = "b403acf6f2bb0859c93c7f0d967cb4a75a7ac552100f9322faf64dc047669b21"
dependencies = [
"serde",
"serde_spanned",
@@ -1546,18 +1524,18 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "0.6.2"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a76a9312f5ba4c2dec6b9161fdf25d87ad8a09256ccea5a556fef03c706a10f"
checksum = "3ab8ed2edee10b50132aed5f331333428b011c99402b5a534154ed15746f9622"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.19.9"
version = "0.19.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d964908cec0d030b812013af25a0e57fddfadb1e066ecc6681d86253129d4f"
checksum = "08de71aa0d6e348f070457f85af8bd566e2bc452156a423ddf22861b3a953fae"
dependencies = [
"indexmap",
"serde",
@@ -1911,9 +1889,9 @@ checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd"
[[package]]
name = "winnow"
version = "0.4.6"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61de7bac303dc551fe038e2b3cef0f571087a47571ea6e79a87692ac99b99699"
checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
dependencies = [
"memchr",
]

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "1.0.2-alpha3"
version = "1.0.2-alpha1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -38,14 +38,13 @@ socket2 = { version = "0.4.7", features = ["all"] }
nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.5"
fallible-iterator = "0.3"
fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
serde_json = "1"
itertools = "0.10"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -25,7 +25,7 @@ x-common-env-pg:
services:
main:
image: gcr.io/google_containers/pause:3.2
image: kubernetes/pause
ports:
- 6432
@@ -64,7 +64,7 @@ services:
<<: *common-env-pg
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
PGPORT: 10432
command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
toxiproxy:
build: .

View File

@@ -1,22 +0,0 @@
# This is an example of the most basic config
# that will mimic what PgBouncer does in transaction mode with one server.
[general]
host = "0.0.0.0"
port = 6433
admin_username = "pgcat"
admin_password = "pgcat"
[pools.pgml.users.0]
username = "postgres"
password = "postgres"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.pgml.shards.0]
servers = [
["127.0.0.1", 28815, "primary"]
]
database = "postgres"

View File

@@ -77,58 +77,6 @@ admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"
# Default plugins that are configured on all pools.
[plugins]
# Prewarmer plugin that runs queries on server startup, before giving the connection
# to the client.
[plugins.prewarmer]
enabled = false
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
# Log all queries to stdout.
[plugins.query_logger]
enabled = false
# Block access to tables that Postgres does not allow us to control.
[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]
# Intercept user queries and give a fake reply.
[plugins.intercept]
enabled = true
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]
# pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like
@@ -206,20 +154,12 @@ connect_timeout = 3000
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
# so all plugins have to be configured here again.
[pool.sharded_db.plugins]
[plugins]
[pools.sharded_db.plugins.prewarmer]
enabled = true
queries = [
"SELECT pg_prewarm('pgbench_accounts')",
]
[pools.sharded_db.plugins.query_logger]
[plugins.query_logger]
enabled = false
[pools.sharded_db.plugins.table_access]
[plugins.table_access]
enabled = false
tables = [
"pg_user",
@@ -227,10 +167,10 @@ tables = [
"pg_database",
]
[pools.sharded_db.plugins.intercept]
[plugins.intercept]
enabled = true
[pools.sharded_db.plugins.intercept.queries.0]
[plugins.intercept.queries.0]
query = "select current_database() as a, current_schemas(false) as b"
schema = [
@@ -241,7 +181,7 @@ result = [
["${DATABASE}", "{public}"],
]
[pools.sharded_db.plugins.intercept.queries.1]
[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [

View File

@@ -1313,7 +1313,7 @@ where
.receive_server_message(server, &address, &pool, &self.stats.clone())
.await?;
match write_all_flush(&mut self.write, &response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
@@ -1408,7 +1408,7 @@ where
.receive_server_message(server, address, pool, client_stats)
.await?;
match write_all_flush(&mut self.write, &response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();

View File

@@ -122,16 +122,6 @@ impl Default for Address {
}
}
impl std::fmt::Display for Address {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"[address: {}:{}][database: {}][user: {}]",
self.host, self.port, self.database, self.username
)
}
}
// We need to implement PartialEq by ourselves so we skip stats in the comparison
impl PartialEq for Address {
fn eq(&self, other: &Self) -> bool {
@@ -245,8 +235,6 @@ pub struct General {
pub port: u16,
pub enable_prometheus_exporter: Option<bool>,
#[serde(default = "General::default_prometheus_exporter_port")]
pub prometheus_exporter_port: i16,
#[serde(default = "General::default_connect_timeout")]
@@ -310,9 +298,6 @@ pub struct General {
pub admin_username: String,
pub admin_password: String,
#[serde(default = "General::default_validate_config")]
pub validate_config: bool,
// Support for auth query
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
@@ -382,14 +367,6 @@ impl General {
pub fn default_idle_client_in_transaction_timeout() -> u64 {
0
}
pub fn default_validate_config() -> bool {
true
}
pub fn default_prometheus_exporter_port() -> i16 {
9930
}
}
impl Default for General {
@@ -425,7 +402,6 @@ impl Default for General {
auth_query_user: None,
auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours,
validate_config: true,
}
}
}
@@ -478,7 +454,6 @@ pub struct Pool {
#[serde(default = "Pool::default_load_balancing_mode")]
pub load_balancing_mode: LoadBalancingMode,
#[serde(default = "Pool::default_default_role")]
pub default_role: String,
#[serde(default)] // False
@@ -487,18 +462,12 @@ pub struct Pool {
#[serde(default)] // False
pub primary_reads_enabled: bool,
/// Maximum time to allow for establishing a new server connection.
pub connect_timeout: Option<u64>,
/// Close idle connections that have been opened for longer than this.
pub idle_timeout: Option<u64>,
/// Close server connections that have been opened for longer than this.
/// Only applied to idle connections. If the connection is actively used for
/// longer than this period, the pool will not interrupt it.
pub server_lifetime: Option<u64>,
#[serde(default = "Pool::default_sharding_function")]
pub sharding_function: ShardingFunction,
#[serde(default = "Pool::default_automatic_sharding_key")]
@@ -512,10 +481,6 @@ pub struct Pool {
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
#[serde(default = "Pool::default_cleanup_server_connections")]
pub cleanup_server_connections: bool,
pub plugins: Option<Plugins>,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
// Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it
@@ -548,18 +513,6 @@ impl Pool {
None
}
pub fn default_default_role() -> String {
"any".into()
}
pub fn default_sharding_function() -> ShardingFunction {
ShardingFunction::PgBigintHash
}
pub fn default_cleanup_server_connections() -> bool {
true
}
pub fn validate(&mut self) -> Result<(), Error> {
match self.default_role.as_ref() {
"any" => (),
@@ -648,8 +601,6 @@ impl Default for Pool {
auth_query_user: None,
auth_query_password: None,
server_lifetime: None,
plugins: None,
cleanup_server_connections: true,
}
}
}
@@ -728,60 +679,39 @@ impl Default for Shard {
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
pub prewarmer: Option<Prewarmer>,
}
impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}",
self.intercept.is_some(),
self.table_access.is_some(),
self.query_logger.is_some(),
self.prewarmer.is_some(),
)
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
pub enabled: bool,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
pub struct Prewarmer {
pub enabled: bool,
pub queries: Vec<String>,
}
impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
query.query = query.query.to_ascii_lowercase();
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
@@ -815,13 +745,8 @@ pub struct Config {
#[serde(default = "Config::default_path")]
pub path: String,
// General and global settings.
pub general: General,
// Plugins that should run in all pools.
pub plugins: Option<Plugins>,
// Connection pools.
pub pools: HashMap<String, Pool>,
}
@@ -1006,13 +931,6 @@ impl Config {
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
info!(
"Plugins: {}",
match self.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?)
@@ -1079,18 +997,6 @@ impl Config {
None => "default".to_string(),
}
);
info!(
"[pool: {}] Cleanup server connections: {}",
pool_name, pool_config.cleanup_server_connections
);
info!(
"[pool: {}] Plugins: {}",
pool_name,
match pool_config.plugins {
Some(ref plugins) => plugins.to_string(),
None => "not configured".into(),
}
);
for user in &pool_config.users {
info!(

View File

@@ -43,8 +43,6 @@ impl MirroredClient {
ClientServerMap::default(),
Arc::new(PoolStats::new(identifier, cfg.clone())),
Arc::new(RwLock::new(None)),
None,
true,
);
Pool::builder()

View File

@@ -2,21 +2,52 @@
//!
//! It intercepts queries and returns fake results.
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlparser::ast::Statement;
use std::collections::HashMap;
use log::debug;
use log::{debug, info};
use std::sync::Arc;
use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter,
};
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}
CONFIG.store(Arc::new(config));
info!("Intercepting {} queries", intercept_config.queries.len());
}
pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}
// TODO: use these structs for deserialization
#[derive(Serialize, Deserialize)]
pub struct Rule {
@@ -32,35 +63,33 @@ pub struct Column {
}
/// The intercept plugin.
pub struct Intercept<'a> {
pub enabled: bool,
pub config: &'a InterceptConfig,
}
pub struct Intercept;
#[async_trait]
impl<'a> Plugin for Intercept<'a> {
impl Plugin for Intercept {
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled || ast.is_empty() {
if ast.is_empty() {
return Ok(PluginOutput::Allow);
}
let mut config = self.config.clone();
config.substitute(
let mut result = BytesMut::new();
let query_map = match CONFIG.load().get(&PoolIdentifier::new(
&query_router.pool_settings().db,
&query_router.pool_settings().user.username,
);
let mut result = BytesMut::new();
)) {
Some(query_map) => query_map.clone(),
None => return Ok(PluginOutput::Allow),
};
for q in ast {
// Normalization
let q = q.to_string().to_ascii_lowercase();
for (_, target) in config.queries.iter() {
for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);
@@ -118,3 +147,142 @@ impl<'a> Plugin for Intercept<'a> {
}
}
}
/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
"query": "select current_database() as a, current_schemas(false) as b",
"schema": [
{
"name": "a",
"data_type": "text",
},
{
"name": "b",
"data_type": "anyarray",
},
],
"result": [
[database, "{public}"],
],
},
{
"query": "select current_database(), current_schema(), current_user",
"schema": [
{
"name": "current_database",
"data_type": "text",
},
{
"name": "current_schema",
"data_type": "text",
},
{
"name": "current_user",
"data_type": "text",
}
],
"result": [
["sharded_db", "public", "sharding_user"],
],
},
{
"query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end",
"schema": [
{
"name": "id",
"data_type": "oid",
},
{
"name": "name",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
{
"name": "is_template",
"data_type": "bool",
},
{
"name": "allow_connections",
"data_type": "bool",
},
{
"name": "owner",
"data_type": "text",
}
],
"result": [
["16387", database, "", "f", "t", user],
]
},
{
"query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid",
"schema": [
{
"name": "role_id",
"data_type": "oid",
},
{
"name": "role_name",
"data_type": "text",
},
{
"name": "is_super",
"data_type": "bool",
},
{
"name": "is_inherit",
"data_type": "bool",
},
{
"name": "can_createrole",
"data_type": "bool",
},
{
"name": "can_createdb",
"data_type": "bool",
},
{
"name": "can_login",
"data_type": "bool",
},
{
"name": "is_replication",
"data_type": "bool",
},
{
"name": "conn_limit",
"data_type": "int4",
},
{
"name": "valid_until",
"data_type": "text",
},
{
"name": "bypass_rls",
"data_type": "bool",
},
{
"name": "config",
"data_type": "text",
},
{
"name": "description",
"data_type": "text",
},
],
"result": [
["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""],
]
}
])
}

View File

@@ -9,7 +9,6 @@
//!
pub mod intercept;
pub mod prewarmer;
pub mod query_logger;
pub mod table_access;

View File

@@ -1,28 +0,0 @@
//! Prewarm new connections before giving them to the client.
use crate::{errors::Error, server::Server};
use log::info;
pub struct Prewarmer<'a> {
pub enabled: bool,
pub server: &'a mut Server,
pub queries: &'a Vec<String>,
}
impl<'a> Prewarmer<'a> {
pub async fn run(&mut self) -> Result<(), Error> {
if !self.enabled {
return Ok(());
}
for query in self.queries {
info!(
"{} Prewarning with query: `{}`",
self.server.address(),
query
);
self.server.query(&query).await?;
}
Ok(())
}
}

View File

@@ -5,33 +5,44 @@ use crate::{
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::Arc;
pub struct QueryLogger<'a> {
pub enabled: bool,
pub user: &'a str,
pub db: &'a str,
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
pub struct QueryLogger;
pub fn setup() {
ENABLED.store(Arc::new(true));
info!("Logging queries to stdout");
}
pub fn disable() {
ENABLED.store(Arc::new(false));
}
pub fn enabled() -> bool {
**ENABLED.load()
}
#[async_trait]
impl<'a> Plugin for QueryLogger<'a> {
impl Plugin for QueryLogger {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("[pool: {}][user: {}] {}", self.user, self.db, query);
info!("{}", query);
Ok(PluginOutput::Allow)
}

View File

@@ -5,39 +5,53 @@ use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement};
use crate::{
config::TableAccess as TableAccessConfig,
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use log::debug;
use log::{debug, info};
use arc_swap::ArcSwap;
use core::ops::ControlFlow;
use once_cell::sync::Lazy;
use std::sync::Arc;
pub struct TableAccess<'a> {
pub enabled: bool,
pub tables: &'a Vec<String>,
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
pub fn setup(config: &TableAccessConfig) {
CONFIG.store(Arc::new(config.tables.clone()));
info!("Blocking access to {} tables", config.tables.len());
}
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn disable() {
CONFIG.store(Arc::new(vec![]));
}
pub struct TableAccess;
#[async_trait]
impl<'a> Plugin for TableAccess<'a> {
impl Plugin for TableAccess {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
if !self.enabled {
return Ok(PluginOutput::Allow);
}
let mut found = None;
let forbidden_tables = CONFIG.load();
visit_relations(ast, |relation| {
let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {
if forbidden_tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else {

View File

@@ -17,13 +17,10 @@ use std::sync::{
use std::time::Instant;
use tokio::sync::Notify;
use crate::config::{
get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
};
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer;
use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
@@ -135,9 +132,6 @@ pub struct PoolSettings {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
/// Plugins
pub plugins: Option<Plugins>,
}
impl Default for PoolSettings {
@@ -162,7 +156,6 @@ impl Default for PoolSettings {
auth_query: None,
auth_query_user: None,
auth_query_password: None,
plugins: None,
}
}
}
@@ -202,7 +195,6 @@ pub struct ConnectionPool {
paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>,
/// Statistics.
pub stats: Arc<PoolStats>,
/// AuthInfo
@@ -360,11 +352,6 @@ impl ConnectionPool {
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
pool_config.cleanup_server_connections,
);
let connect_timeout = match pool_config.connect_timeout {
@@ -390,10 +377,7 @@ impl ConnectionPool {
.min()
.unwrap();
debug!(
"[pool: {}][user: {}] Pool reaper rate: {}ms",
pool_name, user.username, reaper_rate
);
debug!("Pool reaper rate: {}ms", reaper_rate);
let pool = Pool::builder()
.max_size(user.pool_size)
@@ -402,13 +386,9 @@ impl ConnectionPool {
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
.test_on_check_out(false);
let pool = if config.general.validate_config {
pool.build(manager).await?
} else {
pool.build_unchecked(manager)
};
.test_on_check_out(false)
.build(manager)
.await?;
pools.push(pool);
servers.push(address);
@@ -470,10 +450,6 @@ impl ConnectionPool {
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
plugins: match pool_config.plugins {
Some(ref plugins) => Some(plugins.clone()),
None => config.plugins.clone(),
},
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
@@ -483,18 +459,42 @@ impl ConnectionPool {
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
if config.general.validate_config {
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
}
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
}
}
if let Some(ref plugins) = config.plugins {
if let Some(ref intercept) = plugins.intercept {
if intercept.enabled {
crate::plugins::intercept::setup(intercept, &new_pools);
} else {
crate::plugins::intercept::disable();
}
}
if let Some(ref table_access) = plugins.table_access {
if table_access.enabled {
crate::plugins::table_access::setup(table_access);
} else {
crate::plugins::table_access::disable();
}
}
if let Some(ref query_logger) = plugins.query_logger {
if query_logger.enabled {
crate::plugins::query_logger::setup();
} else {
crate::plugins::query_logger::disable();
}
}
}
POOLS.store(Arc::new(new_pools.clone()));
Ok(())
}
@@ -639,10 +639,7 @@ impl ConnectionPool {
{
Ok(conn) => conn,
Err(err) => {
error!(
"Connection checkout error for instance {:?}, error: {:?}",
address, err
);
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(address, BanReason::FailedCheckout, Some(client_stats));
address.stats.error();
client_stats.idle();
@@ -718,7 +715,7 @@ impl ConnectionPool {
// Health check failed.
Err(err) => {
error!(
"Failed health check on instance {:?}, error: {:?}",
"Banning instance {:?} because of failed health check, {:?}",
address, err
);
}
@@ -727,7 +724,7 @@ impl ConnectionPool {
// Health check timed out.
Err(err) => {
error!(
"Health check timeout on instance {:?}, error: {:?}",
"Banning instance {:?} because of health check timeout, {:?}",
address, err
);
}
@@ -749,16 +746,13 @@ impl ConnectionPool {
return;
}
error!("Banning instance {:?}, reason: {:?}", address, reason);
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
error!("Banning {:?}", address);
if let Some(client_info) = client_info {
client_info.ban_error();
address.stats.error();
}
guard[address.shard].insert(address.clone(), (reason, now));
}
@@ -915,29 +909,12 @@ impl ConnectionPool {
/// Wrapper for the bb8 connection pool.
pub struct ServerPool {
/// Server address.
address: Address,
/// Server Postgres user.
user: User,
/// Server database.
database: String,
/// Client/server mapping.
client_server_map: ClientServerMap,
/// Server statistics.
stats: Arc<PoolStats>,
/// Server auth hash (for auth passthrough).
auth_hash: Arc<RwLock<Option<String>>>,
/// Server plugins.
plugins: Option<Plugins>,
/// Should we clean up dirty connections before putting them into the pool?
cleanup_connections: bool,
}
impl ServerPool {
@@ -948,8 +925,6 @@ impl ServerPool {
client_server_map: ClientServerMap,
stats: Arc<PoolStats>,
auth_hash: Arc<RwLock<Option<String>>>,
plugins: Option<Plugins>,
cleanup_connections: bool,
) -> ServerPool {
ServerPool {
address,
@@ -958,8 +933,6 @@ impl ServerPool {
client_server_map,
stats,
auth_hash,
plugins,
cleanup_connections,
}
}
}
@@ -989,23 +962,10 @@ impl ManageConnection for ServerPool {
self.client_server_map.clone(),
stats.clone(),
self.auth_hash.clone(),
self.cleanup_connections,
)
.await
{
Ok(mut conn) => {
if let Some(ref plugins) = self.plugins {
if let Some(ref prewarmer) = plugins.prewarmer {
let mut prewarmer = prewarmer::Prewarmer {
enabled: prewarmer.enabled,
server: &mut conn,
queries: &prewarmer.queries,
};
prewarmer.run().await?;
}
}
Ok(conn) => {
stats.idle();
Ok(conn)
}

View File

@@ -15,7 +15,10 @@ use sqlparser::parser::Parser;
use crate::config::Role;
use crate::errors::Error;
use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess};
use crate::plugins::{
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
TableAccess,
};
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
@@ -790,27 +793,13 @@ impl QueryRouter {
/// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
let plugins = match self.pool_settings.plugins {
Some(ref plugins) => plugins,
None => return Ok(PluginOutput::Allow),
};
if let Some(ref query_logger) = plugins.query_logger {
let mut query_logger = QueryLogger {
enabled: query_logger.enabled,
user: &self.pool_settings.user.username,
db: &self.pool_settings.db,
};
if query_logger::enabled() {
let mut query_logger = QueryLogger {};
let _ = query_logger.run(&self, ast).await;
}
if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept {
enabled: intercept.enabled,
config: &intercept,
};
if intercept::enabled() {
let mut intercept = Intercept {};
let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
@@ -818,12 +807,8 @@ impl QueryRouter {
}
}
if let Some(ref table_access) = plugins.table_access {
let mut table_access = TableAccess {
enabled: table_access.enabled,
tables: &table_access.tables,
};
if table_access::enabled() {
let mut table_access = TableAccess {};
let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result {
@@ -1176,7 +1161,6 @@ mod test {
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -1251,9 +1235,7 @@ mod test {
auth_query_password: None,
auth_query_user: None,
db: "test".to_string(),
plugins: None,
};
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone());
@@ -1397,25 +1379,17 @@ mod test {
#[tokio::test]
async fn test_table_access_plugin() {
use crate::config::{Plugins, TableAccess};
let table_access = TableAccess {
use crate::config::TableAccess;
let ta = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
let plugins = Plugins {
table_access: Some(table_access),
intercept: None,
query_logger: None,
prewarmer: None,
};
crate::plugins::table_access::setup(&ta);
QueryRouter::setup();
let mut pool_settings = PoolSettings::default();
pool_settings.query_parser_enabled = true;
pool_settings.plugins = Some(plugins);
let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings);
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
@@ -1429,17 +1403,4 @@ mod test {
))
);
}
#[tokio::test]
async fn test_plugins_disabled_by_defaault() {
QueryRouter::setup();
let qr = QueryRouter::new();
let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap();
let res = qr.execute_plugins(&ast).await;
assert_eq!(res, Ok(PluginOutput::Allow));
}
}

View File

@@ -103,48 +103,6 @@ impl StreamInner {
}
}
#[derive(Copy, Clone)]
struct CleanupState {
/// If server connection requires DISCARD ALL before checkin because of set statement
needs_cleanup_set: bool,
/// If server connection requires DISCARD ALL before checkin because of prepare statement
needs_cleanup_prepare: bool,
}
impl CleanupState {
fn new() -> Self {
CleanupState {
needs_cleanup_set: false,
needs_cleanup_prepare: false,
}
}
fn needs_cleanup(&self) -> bool {
self.needs_cleanup_set || self.needs_cleanup_prepare
}
fn set_true(&mut self) {
self.needs_cleanup_set = true;
self.needs_cleanup_prepare = true;
}
fn reset(&mut self) {
self.needs_cleanup_set = false;
self.needs_cleanup_prepare = false;
}
}
impl std::fmt::Display for CleanupState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"SET: {}, PREPARE: {}",
self.needs_cleanup_set, self.needs_cleanup_prepare
)
}
}
/// Server state.
pub struct Server {
/// Server host, e.g. localhost,
@@ -173,8 +131,8 @@ pub struct Server {
/// Is the server broken? We'll remote it from the pool if so.
bad: bool,
/// If server connection requires DISCARD ALL before checkin
cleanup_state: CleanupState,
/// If server connection requires a DISCARD ALL before checkin
needs_cleanup: bool,
/// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap,
@@ -188,16 +146,13 @@ pub struct Server {
/// Application name using the server at the moment.
application_name: String,
/// Last time that a successful server send or response happened
// Last time that a successful server send or response happened
last_activity: SystemTime,
mirror_manager: Option<MirroringManager>,
/// Associated addresses used
// Associated addresses used
addr_set: Option<AddrSet>,
/// Should clean up dirty connections?
cleanup_connections: bool,
}
impl Server {
@@ -210,7 +165,6 @@ impl Server {
client_server_map: ClientServerMap,
stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>,
cleanup_connections: bool,
) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
@@ -676,7 +630,7 @@ impl Server {
in_transaction: false,
data_available: false,
bad: false,
cleanup_state: CleanupState::new(),
needs_cleanup: false,
client_server_map,
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(),
@@ -691,7 +645,6 @@ impl Server {
address.mirrors.clone(),
)),
},
cleanup_connections,
};
server.set_name("pgcat").await?;
@@ -752,10 +705,7 @@ impl Server {
Ok(())
}
Err(err) => {
error!(
"Terminating server {:?} because of: {:?}",
self.address, err
);
error!("Terminating server because of: {:?}", err);
self.bad = true;
Err(err)
}
@@ -770,10 +720,7 @@ impl Server {
let mut message = match read_message(&mut self.stream).await {
Ok(message) => message,
Err(err) => {
error!(
"Terminating server {:?} because of: {:?}",
self.address, err
);
error!("Terminating server because of: {:?}", err);
self.bad = true;
return Err(err);
}
@@ -840,12 +787,12 @@ impl Server {
// This will reduce amount of discard statements sent
if !self.in_transaction {
debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_set = true;
self.needs_cleanup = true;
}
}
"PREPARE\0" => {
debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_prepare = true;
self.needs_cleanup = true;
}
_ => (),
}
@@ -975,8 +922,6 @@ impl Server {
/// It will use the simple query protocol.
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
debug!("Running `{}` on server {:?}", query, self.address);
let query = simple_query(query);
self.send(&query).await?;
@@ -1009,11 +954,10 @@ impl Server {
// to avoid leaking state between clients. For performance reasons we only
// send `DISCARD ALL` if we think the session is altered instead of just sending
// it before each checkin.
if self.cleanup_state.needs_cleanup() && self.cleanup_connections {
warn!("Server returned with session state altered, discarding state ({}) for application {}", self.cleanup_state, self.application_name);
if self.needs_cleanup {
warn!("Server returned with session state altered, discarding state");
self.query("DISCARD ALL").await?;
self.query("RESET ROLE").await?;
self.cleanup_state.reset();
self.needs_cleanup = false;
}
Ok(())
@@ -1025,12 +969,12 @@ impl Server {
self.application_name = name.to_string();
// We don't want `SET application_name` to mark the server connection
// as needing cleanup
let needs_cleanup_before = self.cleanup_state;
let needs_cleanup_before = self.needs_cleanup;
let result = Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?);
self.cleanup_state = needs_cleanup_before;
self.needs_cleanup = needs_cleanup_before;
result
} else {
Ok(())
@@ -1055,7 +999,7 @@ impl Server {
// Marks a connection as needing DISCARD ALL at checkin
pub fn mark_dirty(&mut self) {
self.cleanup_state.set_true();
self.needs_cleanup = true;
}
pub fn mirror_send(&mut self, bytes: &BytesMut) {
@@ -1089,7 +1033,6 @@ impl Server {
client_server_map,
Arc::new(ServerStats::default()),
Arc::new(RwLock::new(None)),
true,
)
.await?;
debug!("Connected!, sending query.");
@@ -1192,18 +1135,14 @@ impl Drop for Server {
_ => debug!("Dirty shutdown"),
};
// Should not matter.
self.bad = true;
let now = chrono::offset::Utc::now().naive_utc();
let duration = now - self.connected_at;
let message = if self.bad {
"Server connection terminated"
} else {
"Server connection closed"
};
info!(
"{} {:?}, session duration: {}",
message,
"Server connection closed {:?}, session duration: {}",
self.address,
crate::format_duration(&duration)
);

View File

@@ -107,20 +107,8 @@ impl Collector {
loop {
interval.tick().await;
// Hold read lock for duration of update to retain all server stats
let server_stats = SERVER_STATS.read();
for stats in server_stats.values() {
if !stats.check_address_stat_average_is_updated_status() {
stats.address_stats().update_averages();
stats.address_stats().reset_current_counts();
stats.set_address_stat_average_is_updated_status(true);
}
}
// Reset to false for next update
for stats in server_stats.values() {
stats.set_address_stat_average_is_updated_status(false);
for stats in SERVER_STATS.read().values() {
stats.address_stats().update_averages();
}
}
});

View File

@@ -1,29 +1,26 @@
use log::warn;
use std::sync::atomic::*;
use std::sync::Arc;
#[derive(Debug, Clone, Default)]
struct AddressStatFields {
xact_count: Arc<AtomicU64>,
query_count: Arc<AtomicU64>,
bytes_received: Arc<AtomicU64>,
bytes_sent: Arc<AtomicU64>,
xact_time: Arc<AtomicU64>,
query_time: Arc<AtomicU64>,
wait_time: Arc<AtomicU64>,
errors: Arc<AtomicU64>,
}
/// Internal address stats
#[derive(Debug, Clone, Default)]
pub struct AddressStats {
total: AddressStatFields,
current: AddressStatFields,
averages: AddressStatFields,
// Determines if the averages have been updated since the last time they were reported
pub averages_updated: Arc<AtomicBool>,
pub total_xact_count: Arc<AtomicU64>,
pub total_query_count: Arc<AtomicU64>,
pub total_received: Arc<AtomicU64>,
pub total_sent: Arc<AtomicU64>,
pub total_xact_time: Arc<AtomicU64>,
pub total_query_time: Arc<AtomicU64>,
pub total_wait_time: Arc<AtomicU64>,
pub total_errors: Arc<AtomicU64>,
pub avg_query_count: Arc<AtomicU64>,
pub avg_query_time: Arc<AtomicU64>,
pub avg_recv: Arc<AtomicU64>,
pub avg_sent: Arc<AtomicU64>,
pub avg_errors: Arc<AtomicU64>,
pub avg_xact_time: Arc<AtomicU64>,
pub avg_xact_count: Arc<AtomicU64>,
pub avg_wait_time: Arc<AtomicU64>,
}
impl IntoIterator for AddressStats {
@@ -34,67 +31,67 @@ impl IntoIterator for AddressStats {
vec![
(
"total_xact_count".to_string(),
self.total.xact_count.load(Ordering::Relaxed),
self.total_xact_count.load(Ordering::Relaxed),
),
(
"total_query_count".to_string(),
self.total.query_count.load(Ordering::Relaxed),
self.total_query_count.load(Ordering::Relaxed),
),
(
"total_received".to_string(),
self.total.bytes_received.load(Ordering::Relaxed),
self.total_received.load(Ordering::Relaxed),
),
(
"total_sent".to_string(),
self.total.bytes_sent.load(Ordering::Relaxed),
self.total_sent.load(Ordering::Relaxed),
),
(
"total_xact_time".to_string(),
self.total.xact_time.load(Ordering::Relaxed),
self.total_xact_time.load(Ordering::Relaxed),
),
(
"total_query_time".to_string(),
self.total.query_time.load(Ordering::Relaxed),
self.total_query_time.load(Ordering::Relaxed),
),
(
"total_wait_time".to_string(),
self.total.wait_time.load(Ordering::Relaxed),
self.total_wait_time.load(Ordering::Relaxed),
),
(
"total_errors".to_string(),
self.total.errors.load(Ordering::Relaxed),
self.total_errors.load(Ordering::Relaxed),
),
(
"avg_xact_count".to_string(),
self.averages.xact_count.load(Ordering::Relaxed),
self.avg_xact_count.load(Ordering::Relaxed),
),
(
"avg_query_count".to_string(),
self.averages.query_count.load(Ordering::Relaxed),
self.avg_query_count.load(Ordering::Relaxed),
),
(
"avg_recv".to_string(),
self.averages.bytes_received.load(Ordering::Relaxed),
self.avg_recv.load(Ordering::Relaxed),
),
(
"avg_sent".to_string(),
self.averages.bytes_sent.load(Ordering::Relaxed),
self.avg_sent.load(Ordering::Relaxed),
),
(
"avg_errors".to_string(),
self.averages.errors.load(Ordering::Relaxed),
self.avg_errors.load(Ordering::Relaxed),
),
(
"avg_xact_time".to_string(),
self.averages.xact_time.load(Ordering::Relaxed),
self.avg_xact_time.load(Ordering::Relaxed),
),
(
"avg_query_time".to_string(),
self.averages.query_time.load(Ordering::Relaxed),
self.avg_query_time.load(Ordering::Relaxed),
),
(
"avg_wait_time".to_string(),
self.averages.wait_time.load(Ordering::Relaxed),
self.avg_wait_time.load(Ordering::Relaxed),
),
]
.into_iter()
@@ -102,120 +99,22 @@ impl IntoIterator for AddressStats {
}
impl AddressStats {
pub fn xact_count_add(&self) {
self.total.xact_count.fetch_add(1, Ordering::Relaxed);
self.current.xact_count.fetch_add(1, Ordering::Relaxed);
}
pub fn query_count_add(&self) {
self.total.query_count.fetch_add(1, Ordering::Relaxed);
self.current.query_count.fetch_add(1, Ordering::Relaxed);
}
pub fn bytes_received_add(&self, bytes: u64) {
self.total
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
self.current
.bytes_received
.fetch_add(bytes, Ordering::Relaxed);
}
pub fn bytes_sent_add(&self, bytes: u64) {
self.total.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
self.current.bytes_sent.fetch_add(bytes, Ordering::Relaxed);
}
pub fn xact_time_add(&self, time: u64) {
self.total.xact_time.fetch_add(time, Ordering::Relaxed);
self.current.xact_time.fetch_add(time, Ordering::Relaxed);
}
pub fn query_time_add(&self, time: u64) {
self.total.query_time.fetch_add(time, Ordering::Relaxed);
self.current.query_time.fetch_add(time, Ordering::Relaxed);
}
pub fn wait_time_add(&self, time: u64) {
self.total.wait_time.fetch_add(time, Ordering::Relaxed);
self.current.wait_time.fetch_add(time, Ordering::Relaxed);
}
pub fn error(&self) {
self.total.errors.fetch_add(1, Ordering::Relaxed);
self.current.errors.fetch_add(1, Ordering::Relaxed);
self.total_errors.fetch_add(1, Ordering::Relaxed);
}
pub fn update_averages(&self) {
let stat_period_per_second = crate::stats::STAT_PERIOD / 1_000;
// xact_count
let current_xact_count = self.current.xact_count.load(Ordering::Relaxed);
let current_xact_time = self.current.xact_time.load(Ordering::Relaxed);
self.averages.xact_count.store(
current_xact_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_xact_count == 0 {
self.averages.xact_time.store(0, Ordering::Relaxed);
} else {
self.averages
.xact_time
.store(current_xact_time / current_xact_count, Ordering::Relaxed);
let (totals, averages) = self.fields_iterators();
for data in totals.iter().zip(averages.iter()) {
let (total, average) = data;
if let Err(err) = average.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
let total = total.load(Ordering::Relaxed);
let avg = (total - avg) / (crate::stats::STAT_PERIOD / 1_000); // Avg / second
Some(avg)
}) {
warn!("Could not update averages for addresses stats, {:?}", err);
}
}
// query_count
let current_query_count = self.current.query_count.load(Ordering::Relaxed);
let current_query_time = self.current.query_time.load(Ordering::Relaxed);
self.averages.query_count.store(
current_query_count / stat_period_per_second,
Ordering::Relaxed,
);
if current_query_count == 0 {
self.averages.query_time.store(0, Ordering::Relaxed);
} else {
self.averages
.query_time
.store(current_query_time / current_query_count, Ordering::Relaxed);
}
// bytes_received
let current_bytes_received = self.current.bytes_received.load(Ordering::Relaxed);
self.averages.bytes_received.store(
current_bytes_received / stat_period_per_second,
Ordering::Relaxed,
);
// bytes_sent
let current_bytes_sent = self.current.bytes_sent.load(Ordering::Relaxed);
self.averages.bytes_sent.store(
current_bytes_sent / stat_period_per_second,
Ordering::Relaxed,
);
// wait_time
let current_wait_time = self.current.wait_time.load(Ordering::Relaxed);
self.averages.wait_time.store(
current_wait_time / stat_period_per_second,
Ordering::Relaxed,
);
// errors
let current_errors = self.current.errors.load(Ordering::Relaxed);
self.averages
.errors
.store(current_errors / stat_period_per_second, Ordering::Relaxed);
}
pub fn reset_current_counts(&self) {
self.current.xact_count.store(0, Ordering::Relaxed);
self.current.xact_time.store(0, Ordering::Relaxed);
self.current.query_count.store(0, Ordering::Relaxed);
self.current.query_time.store(0, Ordering::Relaxed);
self.current.bytes_received.store(0, Ordering::Relaxed);
self.current.bytes_sent.store(0, Ordering::Relaxed);
self.current.wait_time.store(0, Ordering::Relaxed);
self.current.errors.store(0, Ordering::Relaxed);
}
pub fn populate_row(&self, row: &mut Vec<String>) {
@@ -223,4 +122,28 @@ impl AddressStats {
row.push(value.to_string());
}
}
fn fields_iterators(&self) -> (Vec<Arc<AtomicU64>>, Vec<Arc<AtomicU64>>) {
let mut totals: Vec<Arc<AtomicU64>> = Vec::new();
let mut averages: Vec<Arc<AtomicU64>> = Vec::new();
totals.push(self.total_xact_count.clone());
averages.push(self.avg_xact_count.clone());
totals.push(self.total_query_count.clone());
averages.push(self.avg_query_count.clone());
totals.push(self.total_received.clone());
averages.push(self.avg_recv.clone());
totals.push(self.total_sent.clone());
averages.push(self.avg_sent.clone());
totals.push(self.total_xact_time.clone());
averages.push(self.avg_xact_time.clone());
totals.push(self.total_query_time.clone());
averages.push(self.avg_query_time.clone());
totals.push(self.total_wait_time.clone());
averages.push(self.avg_wait_time.clone());
totals.push(self.total_errors.clone());
averages.push(self.avg_errors.clone());
(totals, averages)
}
}

View File

@@ -139,17 +139,6 @@ impl ServerStats {
self.address.stats.clone()
}
pub fn check_address_stat_average_is_updated_status(&self) -> bool {
self.address.stats.averages_updated.load(Ordering::Relaxed)
}
pub fn set_address_stat_average_is_updated_status(&self, is_checked: bool) {
self.address
.stats
.averages_updated
.store(is_checked, Ordering::Relaxed);
}
// Helper methods for show_servers
pub fn pool_name(&self) -> String {
self.pool_stats.database()
@@ -177,9 +166,12 @@ impl ServerStats {
}
pub fn checkout_time(&self, microseconds: u64, application_name: String) {
// Update server stats and address aggregation stats
// Update server stats and address aggergation stats
self.set_application(application_name);
self.address.stats.wait_time_add(microseconds);
self.address
.stats
.total_wait_time
.fetch_add(microseconds, Ordering::Relaxed);
self.pool_stats
.maxwait
.fetch_max(microseconds, Ordering::Relaxed);
@@ -188,8 +180,13 @@ impl ServerStats {
/// Report a query executed by a client against a server
pub fn query(&self, milliseconds: u64, application_name: &str) {
self.set_application(application_name.to_string());
self.address.stats.query_count_add();
self.address.stats.query_time_add(milliseconds);
let address_stats = self.address_stats();
address_stats
.total_query_count
.fetch_add(1, Ordering::Relaxed);
address_stats
.total_query_time
.fetch_add(milliseconds, Ordering::Relaxed);
}
/// Report a transaction executed by a client a server
@@ -200,20 +197,29 @@ impl ServerStats {
self.set_application(application_name.to_string());
self.transaction_count.fetch_add(1, Ordering::Relaxed);
self.address.stats.xact_count_add();
self.address
.stats
.total_xact_count
.fetch_add(1, Ordering::Relaxed);
}
/// Report data sent to a server
pub fn data_sent(&self, amount_bytes: usize) {
self.bytes_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_sent_add(amount_bytes as u64);
self.address
.stats
.total_sent
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
}
/// Report data received from a server
pub fn data_received(&self, amount_bytes: usize) {
self.bytes_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
self.address.stats.bytes_received_add(amount_bytes as u64);
self.address
.stats
.total_received
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
}
}

View File

@@ -14,12 +14,11 @@ describe "Admin" do
describe "SHOW STATS" do
context "clients connect and make one query" do
it "updates *_query_time and *_wait_time" do
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
connections.each do |c|
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
end
sleep(1)
connections.map(&:close)
connection = PG::connect("#{pgcat_conn_str}?application_name=one_query")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.async_exec("SELECT pg_sleep(0.25)")
connection.close
# wait for averages to be calculated, we shouldn't do this too often
sleep(15.5)
@@ -27,7 +26,7 @@ describe "Admin" do
results = admin_conn.async_exec("SHOW STATS")[0]
admin_conn.close
expect(results["total_query_time"].to_i).to be_within(200).of(750)
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
expect(results["avg_query_time"].to_i).to_not eq(0)
expect(results["total_wait_time"].to_i).to_not eq(0)
expect(results["avg_wait_time"].to_i).to_not eq(0)

View File

@@ -41,24 +41,7 @@ module Helpers
"1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
},
"users" => { "0" => user },
"plugins" => {
"intercept" => {
"enabled" => true,
"queries" => {
"0" => {
"query" => "select current_database() as a, current_schemas(false) as b",
"schema" => [
["a", "text"],
["b", "text"],
],
"result" => [
["${DATABASE}", "{public}"],
]
}
}
}
}
"users" => { "0" => user }
}
}
pgcat.update_config(pgcat_cfg)
@@ -118,7 +101,7 @@ module Helpers
end
end
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", pool_settings={})
def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info")
user = {
"password" => "sharding_user",
"pool_size" => pool_size,
@@ -134,32 +117,28 @@ module Helpers
replica1 = PgInstance.new(8432, user["username"], user["password"], "shard0")
replica2 = PgInstance.new(9432, user["username"], user["password"], "shard0")
pool_config = {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica0.port.to_s, "replica"],
["localhost", replica1.port.to_s, "replica"],
["localhost", replica2.port.to_s, "replica"]
]
},
},
"users" => { "0" => user }
}
pool_config = pool_config.merge(pool_settings)
# Main proxy configs
pgcat_cfg["pools"] = {
"#{pool_name}" => pool_config,
"#{pool_name}" => {
"default_role" => "any",
"pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode,
"primary_reads_enabled" => false,
"query_parser_enabled" => false,
"sharding_function" => "pg_bigint_hash",
"shards" => {
"0" => {
"database" => "shard0",
"servers" => [
["localhost", primary.port.to_s, "primary"],
["localhost", replica0.port.to_s, "replica"],
["localhost", replica1.port.to_s, "replica"],
["localhost", replica2.port.to_s, "replica"]
]
},
},
"users" => { "0" => user }
}
}
pgcat_cfg["general"]["port"] = pgcat.port
pgcat.update_config(pgcat_cfg)

View File

@@ -241,18 +241,6 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
end
it "Resets server roles correctly" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("SELECT 1")
conn.async_exec("SET statement_timeout to 5000")
conn.close
end
expect(processes.primary.count_query("RESET ROLE")).to eq(10)
end
end
context "transaction mode" do
@@ -320,31 +308,6 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
end
context "server cleanup disabled" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 1, "transaction", "random", "info", { "cleanup_server_connections" => false }) }
it "will not clean up connection state" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("SET statement_timeout TO 1000")
conn.close
puts processes.pgcat.logs
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
it "will not clean up prepared statements" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
processes.primary.reset_stats
conn.async_exec("PREPARE prepared_q (int) AS SELECT $1")
conn.close
puts processes.pgcat.logs
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
end
end
describe "Idle client timeout" do