From 52b1b438505198f84e1d56fa3f4767dddb77f67d Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Fri, 12 May 2023 09:50:52 -0700 Subject: [PATCH] Prewarmer (#435) * Prewarmer * hmm * Tests * default * fix test * Correct configuration * Added minimal config example * remove connect_timeout --- Cargo.lock | 2 +- Cargo.toml | 2 +- pgcat.minimal.toml | 22 ++++ pgcat.toml | 72 ++++++++++- src/config.rs | 78 +++++++++++- src/mirrors.rs | 1 + src/plugins/intercept.rs | 194 ++--------------------------- src/plugins/mod.rs | 1 + src/plugins/prewarmer.rs | 28 +++++ src/plugins/query_logger.rs | 31 ++--- src/plugins/table_access.rs | 34 ++--- src/pool.rs | 66 +++++----- src/query_router.rs | 69 +++++++--- src/server.rs | 2 + tests/ruby/helpers/pgcat_helper.rb | 19 ++- 15 files changed, 337 insertions(+), 284 deletions(-) create mode 100644 pgcat.minimal.toml create mode 100644 src/plugins/prewarmer.rs diff --git a/Cargo.lock b/Cargo.lock index 9c553d4..4703a0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -897,7 +897,7 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "pgcat" -version = "1.0.2-alpha1" +version = "1.0.2-alpha2" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 5428ab3..af1b9ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "1.0.2-alpha1" +version = "1.0.2-alpha2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/pgcat.minimal.toml b/pgcat.minimal.toml new file mode 100644 index 0000000..4b17a45 --- /dev/null +++ b/pgcat.minimal.toml @@ -0,0 +1,22 @@ +# This is an example of the most basic config +# that will mimic what PgBouncer does in transaction mode with one server. 
+ +[general] + +host = "0.0.0.0" +port = 6433 +admin_username = "pgcat" +admin_password = "pgcat" + +[pools.pgml.users.0] +username = "postgres" +password = "postgres" +pool_size = 10 +min_pool_size = 1 +pool_mode = "transaction" + +[pools.pgml.shards.0] +servers = [ + ["127.0.0.1", 28815, "primary"] +] +database = "postgres" diff --git a/pgcat.toml b/pgcat.toml index ce36632..e6b54b2 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -77,6 +77,58 @@ admin_username = "admin_user" # Password to access the virtual administrative database admin_password = "admin_pass" +# Default plugins that are configured on all pools. +[plugins] + +# Prewarmer plugin that runs queries on server startup, before giving the connection +# to the client. +[plugins.prewarmer] +enabled = false +queries = [ + "SELECT pg_prewarm('pgbench_accounts')", +] + +# Log all queries to stdout. +[plugins.query_logger] +enabled = false + +# Block access to tables that Postgres does not allow us to control. +[plugins.table_access] +enabled = false +tables = [ + "pg_user", + "pg_roles", + "pg_database", +] + +# Intercept user queries and give a fake reply. +[plugins.intercept] +enabled = true + +[plugins.intercept.queries.0] + +query = "select current_database() as a, current_schemas(false) as b" +schema = [ + ["a", "text"], + ["b", "text"], +] +result = [ + ["${DATABASE}", "{public}"], +] + +[plugins.intercept.queries.1] + +query = "select current_database(), current_schema(), current_user" +schema = [ + ["current_database", "text"], + ["current_schema", "text"], + ["current_user", "text"], +] +result = [ + ["${DATABASE}", "public", "${USER}"], +] + + # pool configs are structured as pool. # the pool_name is what clients use as database name when connecting. # For a pool named `sharded_db`, clients access that pool using connection string like @@ -154,12 +206,20 @@ connect_timeout = 3000 # Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`). 
# dns_max_ttl = 30 -[plugins] +# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting, +# so all plugins have to be configured here again. +[pools.sharded_db.plugins] -[plugins.query_logger] +[pools.sharded_db.plugins.prewarmer] +enabled = true +queries = [ + "SELECT pg_prewarm('pgbench_accounts')", +] + +[pools.sharded_db.plugins.query_logger] enabled = false -[plugins.table_access] +[pools.sharded_db.plugins.table_access] enabled = false tables = [ "pg_user", @@ -167,10 +227,10 @@ tables = [ "pg_database", ] -[plugins.intercept] +[pools.sharded_db.plugins.intercept] enabled = true -[plugins.intercept.queries.0] +[pools.sharded_db.plugins.intercept.queries.0] query = "select current_database() as a, current_schemas(false) as b" schema = [ @@ -181,7 +241,7 @@ result = [ ["${DATABASE}", "{public}"], ] -[plugins.intercept.queries.1] +[pools.sharded_db.plugins.intercept.queries.1] query = "select current_database(), current_schema(), current_user" schema = [ diff --git a/src/config.rs b/src/config.rs index c7fce63..f417773 100644 --- a/src/config.rs +++ b/src/config.rs @@ -122,6 +122,16 @@ impl Default for Address { } } +impl std::fmt::Display for Address { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "[address: {}:{}][database: {}][user: {}]", + self.host, self.port, self.database, self.username + ) + } +} + // We need to implement PartialEq by ourselves so we skip stats in the comparison impl PartialEq for Address { fn eq(&self, other: &Self) -> bool { @@ -235,6 +245,8 @@ pub struct General { pub port: u16, pub enable_prometheus_exporter: Option, + + #[serde(default = "General::default_prometheus_exporter_port")] pub prometheus_exporter_port: i16, #[serde(default = "General::default_connect_timeout")] @@ -374,6 +386,10 @@ impl General { pub fn default_validate_config() -> bool { true } + + pub fn default_prometheus_exporter_port() -> i16 { + 9930 + } } impl Default for General { @@ -462,6 
+478,7 @@ pub struct Pool { #[serde(default = "Pool::default_load_balancing_mode")] pub load_balancing_mode: LoadBalancingMode, + #[serde(default = "Pool::default_default_role")] pub default_role: String, #[serde(default)] // False @@ -476,6 +493,7 @@ pub struct Pool { pub server_lifetime: Option, + #[serde(default = "Pool::default_sharding_function")] pub sharding_function: ShardingFunction, #[serde(default = "Pool::default_automatic_sharding_key")] @@ -489,6 +507,7 @@ pub struct Pool { pub auth_query_user: Option, pub auth_query_password: Option, + pub plugins: Option, pub shards: BTreeMap, pub users: BTreeMap, // Note, don't put simple fields below these configs. There's a compatibility issue with TOML that makes it @@ -521,6 +540,14 @@ impl Pool { None } + pub fn default_default_role() -> String { + "any".into() + } + + pub fn default_sharding_function() -> ShardingFunction { + ShardingFunction::PgBigintHash + } + pub fn validate(&mut self) -> Result<(), Error> { match self.default_role.as_ref() { "any" => (), @@ -609,6 +636,7 @@ impl Default for Pool { auth_query_user: None, auth_query_password: None, server_lifetime: None, + plugins: None, } } } @@ -687,30 +715,50 @@ impl Default for Shard { } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct Plugins { pub intercept: Option, pub table_access: Option, pub query_logger: Option, + pub prewarmer: Option, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +impl std::fmt::Display for Plugins { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "interceptor: {}, table_access: {}, query_logger: {}, prewarmer: {}", + self.intercept.is_some(), + self.table_access.is_some(), + self.query_logger.is_some(), + self.prewarmer.is_some(), + ) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct Intercept { pub enabled: 
bool, pub queries: BTreeMap, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct TableAccess { pub enabled: bool, pub tables: Vec, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct QueryLogger { pub enabled: bool, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] +pub struct Prewarmer { + pub enabled: bool, + pub queries: Vec, +} + impl Intercept { pub fn substitute(&mut self, db: &str, user: &str) { for (_, query) in self.queries.iter_mut() { @@ -720,7 +768,7 @@ impl Intercept { } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, Hash, Eq)] pub struct Query { pub query: String, pub schema: Vec>, @@ -754,8 +802,13 @@ pub struct Config { #[serde(default = "Config::default_path")] pub path: String, + // General and global settings. pub general: General, + + // Plugins that should run in all pools. pub plugins: Option, + + // Connection pools. pub pools: HashMap, } @@ -940,6 +993,13 @@ impl Config { "Server TLS certificate verification: {}", self.general.verify_server_certificate ); + info!( + "Plugins: {}", + match self.plugins { + Some(ref plugins) => plugins.to_string(), + None => "not configured".into(), + } + ); for (pool_name, pool_config) in &self.pools { // TODO: Make this output prettier (maybe a table?) 
@@ -1006,6 +1066,14 @@ impl Config { None => "default".to_string(), } ); + info!( + "[pool: {}] Plugins: {}", + pool_name, + match pool_config.plugins { + Some(ref plugins) => plugins.to_string(), + None => "not configured".into(), + } + ); for user in &pool_config.users { info!( diff --git a/src/mirrors.rs b/src/mirrors.rs index 17f91d4..d6d691f 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -43,6 +43,7 @@ impl MirroredClient { ClientServerMap::default(), Arc::new(PoolStats::new(identifier, cfg.clone())), Arc::new(RwLock::new(None)), + None, ); Pool::builder() diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs index 88d24d0..166294b 100644 --- a/src/plugins/intercept.rs +++ b/src/plugins/intercept.rs @@ -2,52 +2,21 @@ //! //! It intercepts queries and returns fake results. -use arc_swap::ArcSwap; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; -use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; use sqlparser::ast::Statement; -use std::collections::HashMap; -use log::{debug, info}; -use std::sync::Arc; +use log::debug; use crate::{ config::Intercept as InterceptConfig, errors::Error, messages::{command_complete, data_row_nullable, row_description, DataType}, plugins::{Plugin, PluginOutput}, - pool::{PoolIdentifier, PoolMap}, query_router::QueryRouter, }; -pub static CONFIG: Lazy>> = - Lazy::new(|| ArcSwap::from_pointee(HashMap::new())); - -/// Check if the interceptor plugin has been enabled. 
-pub fn enabled() -> bool { - !CONFIG.load().is_empty() -} - -pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) { - let mut config = HashMap::new(); - for (identifier, _) in pools.iter() { - let mut intercept_config = intercept_config.clone(); - intercept_config.substitute(&identifier.db, &identifier.user); - config.insert(identifier.clone(), intercept_config); - } - - CONFIG.store(Arc::new(config)); - - info!("Intercepting {} queries", intercept_config.queries.len()); -} - -pub fn disable() { - CONFIG.store(Arc::new(HashMap::new())); -} - // TODO: use these structs for deserialization #[derive(Serialize, Deserialize)] pub struct Rule { @@ -63,33 +32,35 @@ pub struct Column { } /// The intercept plugin. -pub struct Intercept; +pub struct Intercept<'a> { + pub enabled: bool, + pub config: &'a InterceptConfig, +} #[async_trait] -impl Plugin for Intercept { +impl<'a> Plugin for Intercept<'a> { async fn run( &mut self, query_router: &QueryRouter, ast: &Vec, ) -> Result { - if ast.is_empty() { + if !self.enabled || ast.is_empty() { return Ok(PluginOutput::Allow); } - let mut result = BytesMut::new(); - let query_map = match CONFIG.load().get(&PoolIdentifier::new( + let mut config = self.config.clone(); + config.substitute( &query_router.pool_settings().db, &query_router.pool_settings().user.username, - )) { - Some(query_map) => query_map.clone(), - None => return Ok(PluginOutput::Allow), - }; + ); + + let mut result = BytesMut::new(); for q in ast { // Normalization let q = q.to_string().to_ascii_lowercase(); - for (_, target) in query_map.queries.iter() { + for (_, target) in config.queries.iter() { if target.query.as_str() == q { debug!("Intercepting query: {}", q); @@ -147,142 +118,3 @@ impl Plugin for Intercept { } } } - -/// Make IntelliJ SQL plugin believe it's talking to an actual database -/// instead of PgCat. 
-#[allow(dead_code)] -fn fool_datagrip(database: &str, user: &str) -> Value { - json!([ - { - "query": "select current_database() as a, current_schemas(false) as b", - "schema": [ - { - "name": "a", - "data_type": "text", - }, - { - "name": "b", - "data_type": "anyarray", - }, - ], - - "result": [ - [database, "{public}"], - ], - }, - { - "query": "select current_database(), current_schema(), current_user", - "schema": [ - { - "name": "current_database", - "data_type": "text", - }, - { - "name": "current_schema", - "data_type": "text", - }, - { - "name": "current_user", - "data_type": "text", - } - ], - - "result": [ - ["sharded_db", "public", "sharding_user"], - ], - }, - { - "query": "select cast(n.oid as bigint) as id, datname as name, d.description, datistemplate as is_template, datallowconn as allow_connections, pg_catalog.pg_get_userbyid(n.datdba) as \"owner\" from pg_catalog.pg_database as n left join pg_catalog.pg_shdescription as d on n.oid = d.objoid order by case when datname = pg_catalog.current_database() then -cast(1 as bigint) else cast(n.oid as bigint) end", - "schema": [ - { - "name": "id", - "data_type": "oid", - }, - { - "name": "name", - "data_type": "text", - }, - { - "name": "description", - "data_type": "text", - }, - { - "name": "is_template", - "data_type": "bool", - }, - { - "name": "allow_connections", - "data_type": "bool", - }, - { - "name": "owner", - "data_type": "text", - } - ], - "result": [ - ["16387", database, "", "f", "t", user], - ] - }, - { - "query": "select cast(r.oid as bigint) as role_id, rolname as role_name, rolsuper as is_super, rolinherit as is_inherit, rolcreaterole as can_createrole, rolcreatedb as can_createdb, rolcanlogin as can_login, rolreplication as is_replication, rolconnlimit as conn_limit, rolvaliduntil as valid_until, rolbypassrls as bypass_rls, rolconfig as config, d.description from pg_catalog.pg_roles as r left join pg_catalog.pg_shdescription as d on d.objoid = r.oid", - "schema": [ - { - "name": 
"role_id", - "data_type": "oid", - }, - { - "name": "role_name", - "data_type": "text", - }, - { - "name": "is_super", - "data_type": "bool", - }, - { - "name": "is_inherit", - "data_type": "bool", - }, - { - "name": "can_createrole", - "data_type": "bool", - }, - { - "name": "can_createdb", - "data_type": "bool", - }, - { - "name": "can_login", - "data_type": "bool", - }, - { - "name": "is_replication", - "data_type": "bool", - }, - { - "name": "conn_limit", - "data_type": "int4", - }, - { - "name": "valid_until", - "data_type": "text", - }, - { - "name": "bypass_rls", - "data_type": "bool", - }, - { - "name": "config", - "data_type": "text", - }, - { - "name": "description", - "data_type": "text", - }, - ], - "result": [ - ["10", "postgres", "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], - ["16419", user, "f", "t", "f", "f", "t", "f", "-1", "", "f", "", ""], - ] - } - ]) -} diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index 6661ece..5ef6009 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -9,6 +9,7 @@ //! pub mod intercept; +pub mod prewarmer; pub mod query_logger; pub mod table_access; diff --git a/src/plugins/prewarmer.rs b/src/plugins/prewarmer.rs new file mode 100644 index 0000000..a09bbe9 --- /dev/null +++ b/src/plugins/prewarmer.rs @@ -0,0 +1,28 @@ +//! Prewarm new connections before giving them to the client. 
+use crate::{errors::Error, server::Server}; +use log::info; + +pub struct Prewarmer<'a> { + pub enabled: bool, + pub server: &'a mut Server, + pub queries: &'a Vec, +} + +impl<'a> Prewarmer<'a> { + pub async fn run(&mut self) -> Result<(), Error> { + if !self.enabled { + return Ok(()); + } + + for query in self.queries { + info!( + "{} Prewarming with query: `{}`", + self.server.address(), + query + ); + self.server.query(&query).await?; + } + + Ok(()) + } +} diff --git a/src/plugins/query_logger.rs b/src/plugins/query_logger.rs index 2dfda8b..debdf39 100644 --- a/src/plugins/query_logger.rs +++ b/src/plugins/query_logger.rs @@ -5,44 +5,33 @@ use crate::{ plugins::{Plugin, PluginOutput}, query_router::QueryRouter, }; -use arc_swap::ArcSwap; use async_trait::async_trait; use log::info; -use once_cell::sync::Lazy; use sqlparser::ast::Statement; -use std::sync::Arc; -static ENABLED: Lazy> = Lazy::new(|| ArcSwap::from_pointee(false)); - -pub struct QueryLogger; - -pub fn setup() { - ENABLED.store(Arc::new(true)); - - info!("Logging queries to stdout"); -} - -pub fn disable() { - ENABLED.store(Arc::new(false)); -} - -pub fn enabled() -> bool { - **ENABLED.load() +pub struct QueryLogger<'a> { + pub enabled: bool, + pub user: &'a str, + pub db: &'a str, } #[async_trait] -impl Plugin for QueryLogger { +impl<'a> Plugin for QueryLogger<'a> { async fn run( &mut self, _query_router: &QueryRouter, ast: &Vec, ) -> Result { + if !self.enabled { + return Ok(PluginOutput::Allow); + } + let query = ast .iter() .map(|q| q.to_string()) .collect::>() .join("; "); - info!("{}", query); + info!("[pool: {}][user: {}] {}", self.db, self.user, query); Ok(PluginOutput::Allow) } diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs index 4613a4f..79c1260 100644 --- a/src/plugins/table_access.rs +++ b/src/plugins/table_access.rs @@ -5,53 +5,39 @@ use async_trait::async_trait; use sqlparser::ast::{visit_relations, Statement}; use crate::{ - config::TableAccess as 
TableAccessConfig, errors::Error, plugins::{Plugin, PluginOutput}, query_router::QueryRouter, }; -use log::{debug, info}; +use log::debug; -use arc_swap::ArcSwap; use core::ops::ControlFlow; -use once_cell::sync::Lazy; -use std::sync::Arc; -static CONFIG: Lazy>> = Lazy::new(|| ArcSwap::from_pointee(vec![])); - -pub fn setup(config: &TableAccessConfig) { - CONFIG.store(Arc::new(config.tables.clone())); - - info!("Blocking access to {} tables", config.tables.len()); +pub struct TableAccess<'a> { + pub enabled: bool, + pub tables: &'a Vec, } -pub fn enabled() -> bool { - !CONFIG.load().is_empty() -} - -pub fn disable() { - CONFIG.store(Arc::new(vec![])); -} - -pub struct TableAccess; - #[async_trait] -impl Plugin for TableAccess { +impl<'a> Plugin for TableAccess<'a> { async fn run( &mut self, _query_router: &QueryRouter, ast: &Vec, ) -> Result { + if !self.enabled { + return Ok(PluginOutput::Allow); + } + let mut found = None; - let forbidden_tables = CONFIG.load(); visit_relations(ast, |relation| { let relation = relation.to_string(); let parts = relation.split(".").collect::>(); let table_name = parts.last().unwrap(); - if forbidden_tables.contains(&table_name.to_string()) { + if self.tables.contains(&table_name.to_string()) { found = Some(table_name.to_string()); ControlFlow::<()>::Break(()) } else { diff --git a/src/pool.rs b/src/pool.rs index 4664193..8e03ae4 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -17,10 +17,13 @@ use std::sync::{ use std::time::Instant; use tokio::sync::Notify; -use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User}; +use crate::config::{ + get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User, +}; use crate::errors::Error; use crate::auth_passthrough::AuthPassthrough; +use crate::plugins::prewarmer; use crate::server::Server; use crate::sharding::ShardingFunction; use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats}; @@ -132,6 +135,9 @@ pub struct PoolSettings 
{ pub auth_query: Option, pub auth_query_user: Option, pub auth_query_password: Option, + + /// Plugins + pub plugins: Option, } impl Default for PoolSettings { @@ -156,6 +162,7 @@ impl Default for PoolSettings { auth_query: None, auth_query_user: None, auth_query_password: None, + plugins: None, } } } @@ -195,6 +202,7 @@ pub struct ConnectionPool { paused: Arc, paused_waiter: Arc, + /// Statistics. pub stats: Arc, /// AuthInfo @@ -352,6 +360,10 @@ impl ConnectionPool { client_server_map.clone(), pool_stats.clone(), pool_auth_hash.clone(), + match pool_config.plugins { + Some(ref plugins) => Some(plugins.clone()), + None => config.plugins.clone(), + }, ); let connect_timeout = match pool_config.connect_timeout { @@ -377,7 +389,10 @@ impl ConnectionPool { .min() .unwrap(); - debug!("Pool reaper rate: {}ms", reaper_rate); + debug!( + "[pool: {}][user: {}] Pool reaper rate: {}ms", + pool_name, user.username, reaper_rate + ); let pool = Pool::builder() .max_size(user.pool_size) @@ -450,6 +465,10 @@ impl ConnectionPool { auth_query: pool_config.auth_query.clone(), auth_query_user: pool_config.auth_query_user.clone(), auth_query_password: pool_config.auth_query_password.clone(), + plugins: match pool_config.plugins { + Some(ref plugins) => Some(plugins.clone()), + None => config.plugins.clone(), + }, }, validated: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)), @@ -471,32 +490,6 @@ impl ConnectionPool { } } - if let Some(ref plugins) = config.plugins { - if let Some(ref intercept) = plugins.intercept { - if intercept.enabled { - crate::plugins::intercept::setup(intercept, &new_pools); - } else { - crate::plugins::intercept::disable(); - } - } - - if let Some(ref table_access) = plugins.table_access { - if table_access.enabled { - crate::plugins::table_access::setup(table_access); - } else { - crate::plugins::table_access::disable(); - } - } - - if let Some(ref query_logger) = plugins.query_logger { - if query_logger.enabled { - 
crate::plugins::query_logger::setup(); - } else { - crate::plugins::query_logger::disable(); - } - } - } - POOLS.store(Arc::new(new_pools.clone())); Ok(()) } @@ -923,6 +916,7 @@ pub struct ServerPool { client_server_map: ClientServerMap, stats: Arc, auth_hash: Arc>>, + plugins: Option, } impl ServerPool { @@ -933,6 +927,7 @@ impl ServerPool { client_server_map: ClientServerMap, stats: Arc, auth_hash: Arc>>, + plugins: Option, ) -> ServerPool { ServerPool { address, @@ -941,6 +936,7 @@ impl ServerPool { client_server_map, stats, auth_hash, + plugins, } } } @@ -973,7 +969,19 @@ impl ManageConnection for ServerPool { ) .await { - Ok(conn) => { + Ok(mut conn) => { + if let Some(ref plugins) = self.plugins { + if let Some(ref prewarmer) = plugins.prewarmer { + let mut prewarmer = prewarmer::Prewarmer { + enabled: prewarmer.enabled, + server: &mut conn, + queries: &prewarmer.queries, + }; + + prewarmer.run().await?; + } + } + stats.idle(); Ok(conn) } diff --git a/src/query_router.rs b/src/query_router.rs index d995b80..3e3a23a 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -15,10 +15,7 @@ use sqlparser::parser::Parser; use crate::config::Role; use crate::errors::Error; use crate::messages::BytesMutReader; -use crate::plugins::{ - intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger, - TableAccess, -}; +use crate::plugins::{Intercept, Plugin, PluginOutput, QueryLogger, TableAccess}; use crate::pool::PoolSettings; use crate::sharding::Sharder; @@ -793,13 +790,27 @@ impl QueryRouter { /// Add your plugins here and execute them. 
pub async fn execute_plugins(&self, ast: &Vec) -> Result { - if query_logger::enabled() { - let mut query_logger = QueryLogger {}; + let plugins = match self.pool_settings.plugins { + Some(ref plugins) => plugins, + None => return Ok(PluginOutput::Allow), + }; + + if let Some(ref query_logger) = plugins.query_logger { + let mut query_logger = QueryLogger { + enabled: query_logger.enabled, + user: &self.pool_settings.user.username, + db: &self.pool_settings.db, + }; + let _ = query_logger.run(&self, ast).await; } - if intercept::enabled() { - let mut intercept = Intercept {}; + if let Some(ref intercept) = plugins.intercept { + let mut intercept = Intercept { + enabled: intercept.enabled, + config: &intercept, + }; + let result = intercept.run(&self, ast).await; if let Ok(PluginOutput::Intercept(output)) = result { @@ -807,8 +818,12 @@ impl QueryRouter { } } - if table_access::enabled() { - let mut table_access = TableAccess {}; + if let Some(ref table_access) = plugins.table_access { + let mut table_access = TableAccess { + enabled: table_access.enabled, + tables: &table_access.tables, + }; + let result = table_access.run(&self, ast).await; if let Ok(PluginOutput::Deny(error)) = result { @@ -1161,6 +1176,7 @@ mod test { auth_query_password: None, auth_query_user: None, db: "test".to_string(), + plugins: None, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -1235,7 +1251,9 @@ mod test { auth_query_password: None, auth_query_user: None, db: "test".to_string(), + plugins: None, }; + let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings.clone()); @@ -1379,17 +1397,25 @@ mod test { #[tokio::test] async fn test_table_access_plugin() { - use crate::config::TableAccess; - let ta = TableAccess { + use crate::config::{Plugins, TableAccess}; + let table_access = TableAccess { enabled: true, tables: vec![String::from("pg_database")], }; - - crate::plugins::table_access::setup(&ta); + let plugins = Plugins { + table_access: 
Some(table_access), + intercept: None, + query_logger: None, + prewarmer: None, + }; QueryRouter::setup(); + let mut pool_settings = PoolSettings::default(); + pool_settings.query_parser_enabled = true; + pool_settings.plugins = Some(plugins); - let qr = QueryRouter::new(); + let mut qr = QueryRouter::new(); + qr.update_pool_settings(pool_settings); let query = simple_query("SELECT * FROM pg_database"); let ast = QueryRouter::parse(&query).unwrap(); @@ -1403,4 +1429,17 @@ )) ); } + + #[tokio::test] + async fn test_plugins_disabled_by_default() { + QueryRouter::setup(); + let qr = QueryRouter::new(); + + let query = simple_query("SELECT * FROM pg_database"); + let ast = QueryRouter::parse(&query).unwrap(); + + let res = qr.execute_plugins(&ast).await; + + assert_eq!(res, Ok(PluginOutput::Allow)); + } } diff --git a/src/server.rs b/src/server.rs index dceab49..244c06e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -970,6 +970,8 @@ impl Server { /// It will use the simple query protocol. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. 
pub async fn query(&mut self, query: &str) -> Result<(), Error> { + debug!("Running `{}` on server {:?}", query, self.address); + let query = simple_query(query); self.send(&query).await?; diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index ad4c32a..e36801e 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -41,7 +41,24 @@ module Helpers "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] }, "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] }, }, - "users" => { "0" => user } + "users" => { "0" => user }, + "plugins" => { + "intercept" => { + "enabled" => true, + "queries" => { + "0" => { + "query" => "select current_database() as a, current_schemas(false) as b", + "schema" => [ + ["a", "text"], + ["b", "text"], + ], + "result" => [ + ["${DATABASE}", "{public}"], + ] + } + } + } + } } } pgcat.update_config(pgcat_cfg)