Actually plugins (#421)

* more plugins

* clean up

* fix tests

* fix flakey test
This commit is contained in:
Lev Kokotov
2023-05-03 16:13:45 -07:00
committed by GitHub
parent d5e329fec5
commit 811885f464
11 changed files with 265 additions and 71 deletions

View File

@@ -48,3 +48,4 @@ serde_json = "1"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -77,9 +77,6 @@ admin_username = "admin_user"
# Password to access the virtual administrative database # Password to access the virtual administrative database
admin_password = "admin_pass" admin_password = "admin_pass"
# Plugins!!
# query_router_plugins = ["pg_table_access", "intercept"]
# pool configs are structured as pool.<pool_name> # pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting. # the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like # For a pool named `sharded_db`, clients access that pool using connection string like
@@ -157,6 +154,45 @@ connect_timeout = 3000
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`). # Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30 # dns_max_ttl = 30
# Plugin configuration. Plugins run in the query router for every pool;
# each one is switched on/off independently with its `enabled` flag.
[plugins]

# Log all queries to stdout.
[plugins.query_logger]
enabled = false

# Deny queries that reference any of the listed tables.
[plugins.table_access]
enabled = false
tables = [
    "pg_user",
    "pg_roles",
    "pg_database",
]

# Answer matching queries directly from the pooler with a canned result
# instead of sending them to a server.
[plugins.intercept]
enabled = true

# Intercepted queries are keyed by an arbitrary index.
[plugins.intercept.queries.0]
# Exact query text to match (the plugin normalizes queries to lowercase).
query = "select current_database() as a, current_schemas(false) as b"
# Result columns as [name, data_type] pairs.
schema = [
    ["a", "text"],
    ["b", "text"],
]
# Result rows. ${DATABASE} and ${USER} are substituted with the pool's
# database and user names when the plugin is set up.
result = [
    ["${DATABASE}", "{public}"],
]

[plugins.intercept.queries.1]
query = "select current_database(), current_schema(), current_user"
schema = [
    ["current_database", "text"],
    ["current_schema", "text"],
    ["current_user", "text"],
]
result = [
    ["${DATABASE}", "public", "${USER}"],
]
# User configs are structured as pool.<pool_name>.users.<user_index> # User configs are structured as pool.<pool_name>.users.<user_index>
# This section holds the credentials for users that may connect to this cluster # This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0] [pools.sharded_db.users.0]

View File

@@ -302,8 +302,6 @@ pub struct General {
pub auth_query: Option<String>, pub auth_query: Option<String>,
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>, pub auth_query_password: Option<String>,
pub query_router_plugins: Option<Vec<String>>,
} }
impl General { impl General {
@@ -404,7 +402,6 @@ impl Default for General {
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours, server_lifetime: 1000 * 3600 * 24, // 24 hours,
query_router_plugins: None,
} }
} }
} }
@@ -682,6 +679,55 @@ impl Default for Shard {
} }
} }
/// Top-level `[plugins]` section of the config.
/// Each plugin is optional; a missing entry leaves that plugin unconfigured.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
    /// `[plugins.intercept]` — answer matching queries with canned results.
    pub intercept: Option<Intercept>,
    /// `[plugins.table_access]` — deny queries touching listed tables.
    pub table_access: Option<TableAccess>,
    /// `[plugins.query_logger]` — log queries to stdout.
    pub query_logger: Option<QueryLogger>,
}
/// Config for the intercept plugin: queries matching a configured entry
/// are answered directly by the pooler instead of reaching a server.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
    /// Whether the plugin is active.
    pub enabled: bool,
    /// Queries to intercept, keyed by an arbitrary index
    /// (e.g. `[plugins.intercept.queries.0]` in the TOML config).
    pub queries: BTreeMap<String, Query>,
}
/// Config for the table access plugin, which denies queries
/// that reference any of the listed tables.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
    /// Whether the plugin is active.
    pub enabled: bool,
    /// Table names (unqualified) that clients may not access.
    pub tables: Vec<String>,
}
/// Config for the query logger plugin, which logs all queries to stdout.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
    /// Whether the plugin is active.
    pub enabled: bool,
}
impl Intercept {
    /// Replace the `${DATABASE}` and `${USER}` placeholders in every
    /// configured query's result rows with the given database and user names.
    pub fn substitute(&mut self, db: &str, user: &str) {
        // Keys are irrelevant here, so iterate over values only.
        self.queries
            .values_mut()
            .for_each(|query| query.substitute(db, user));
    }
}
/// One intercepted query together with the hardcoded response
/// returned to the client.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
    /// Exact query text to match (the plugin compares against the
    /// lowercase-normalized incoming query).
    pub query: String,
    /// Result columns as `[name, data_type]` pairs.
    pub schema: Vec<Vec<String>>,
    /// Result rows; `${DATABASE}` and `${USER}` placeholders are
    /// substituted at setup time. An empty string is sent as NULL.
    pub result: Vec<Vec<String>>,
}
impl Query {
    /// Replace the `${DATABASE}` and `${USER}` placeholders in the
    /// configured result rows with the actual database and user names.
    pub fn substitute(&mut self, db: &str, user: &str) {
        // `result` is a list of rows, each a list of column values.
        // Iterate mutably instead of indexing (the original used an
        // index loop over rows misleadingly named `col`; clippy:
        // needless_range_loop). Behavior is unchanged.
        for row in self.result.iter_mut() {
            for value in row.iter_mut() {
                *value = value.replace("${USER}", user).replace("${DATABASE}", db);
            }
        }
    }
}
/// Configuration wrapper. /// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config { pub struct Config {
@@ -700,6 +746,7 @@ pub struct Config {
pub path: String, pub path: String,
pub general: General, pub general: General,
pub plugins: Option<Plugins>,
pub pools: HashMap<String, Pool>, pub pools: HashMap<String, Pool>,
} }
@@ -737,6 +784,7 @@ impl Default for Config {
path: Self::default_path(), path: Self::default_path(),
general: General::default(), general: General::default(),
pools: HashMap::default(), pools: HashMap::default(),
plugins: None,
} }
} }
} }
@@ -1128,6 +1176,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> { pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config(); let old_config = get_config();
match parse(&old_config.path).await { match parse(&old_config.path).await {
Ok(()) => (), Ok(()) => (),
Err(err) => { Err(err) => {
@@ -1135,18 +1184,18 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
}; };
let new_config = get_config(); let new_config = get_config();
match CachedResolver::from_config().await { match CachedResolver::from_config().await {
Ok(_) => (), Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err), Err(err) => error!("DNS cache reinitialization error: {:?}", err),
}; };
if old_config.pools != new_config.pools { if old_config != new_config {
info!("Pool configuration changed"); info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?; ConnectionPool::from_config(client_server_map).await?;
Ok(true) Ok(true)
} else if old_config != new_config {
Ok(true)
} else { } else {
Ok(false) Ok(false)
} }

View File

@@ -11,10 +11,11 @@ use serde_json::{json, Value};
use sqlparser::ast::Statement; use sqlparser::ast::Statement;
use std::collections::HashMap; use std::collections::HashMap;
use log::debug; use log::{debug, info};
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
config::Intercept as InterceptConfig,
errors::Error, errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType}, messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput}, plugins::{Plugin, PluginOutput},
@@ -22,19 +23,29 @@ use crate::{
query_router::QueryRouter, query_router::QueryRouter,
}; };
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> = pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new())); Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
/// Configure the intercept plugin. /// Check if the interceptor plugin has been enabled.
pub fn configure(pools: &PoolMap) { pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}
pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new(); let mut config = HashMap::new();
for (identifier, _) in pools.iter() { for (identifier, _) in pools.iter() {
// TODO: make this configurable from a text config. let mut intercept_config = intercept_config.clone();
let value = fool_datagrip(&identifier.db, &identifier.user); intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), value); config.insert(identifier.clone(), intercept_config);
} }
CONFIG.store(Arc::new(config)); CONFIG.store(Arc::new(config));
info!("Intercepting {} queries", intercept_config.queries.len());
}
pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
} }
// TODO: use these structs for deserialization // TODO: use these structs for deserialization
@@ -78,19 +89,19 @@ impl Plugin for Intercept {
// Normalization // Normalization
let q = q.to_string().to_ascii_lowercase(); let q = q.to_string().to_ascii_lowercase();
for target in query_map.as_array().unwrap().iter() { for (_, target) in query_map.queries.iter() {
if target["query"].as_str().unwrap() == q { if target.query.as_str() == q {
debug!("Query matched: {}", q); debug!("Intercepting query: {}", q);
let rd = target["schema"] let rd = target
.as_array() .schema
.unwrap()
.iter() .iter()
.map(|row| { .map(|row| {
let row = row.as_object().unwrap(); let name = &row[0];
let data_type = &row[1];
( (
row["name"].as_str().unwrap(), name.as_str(),
match row["data_type"].as_str().unwrap() { match data_type.as_str() {
"text" => DataType::Text, "text" => DataType::Text,
"anyarray" => DataType::AnyArray, "anyarray" => DataType::AnyArray,
"oid" => DataType::Oid, "oid" => DataType::Oid,
@@ -104,13 +115,11 @@ impl Plugin for Intercept {
result.put(row_description(&rd)); result.put(row_description(&rd));
target["result"].as_array().unwrap().iter().for_each(|row| { target.result.iter().for_each(|row| {
let row = row let row = row
.as_array()
.unwrap()
.iter() .iter()
.map(|s| { .map(|s| {
let s = s.as_str().unwrap().to_string(); let s = s.as_str().to_string();
if s == "" { if s == "" {
None None
@@ -141,6 +150,7 @@ impl Plugin for Intercept {
/// Make IntelliJ SQL plugin believe it's talking to an actual database /// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat. /// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value { fn fool_datagrip(database: &str, user: &str) -> Value {
json!([ json!([
{ {

View File

@@ -9,6 +9,7 @@
//! //!
pub mod intercept; pub mod intercept;
pub mod query_logger;
pub mod table_access; pub mod table_access;
use crate::{errors::Error, query_router::QueryRouter}; use crate::{errors::Error, query_router::QueryRouter};
@@ -17,6 +18,7 @@ use bytes::BytesMut;
use sqlparser::ast::Statement; use sqlparser::ast::Statement;
pub use intercept::Intercept; pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess; pub use table_access::TableAccess;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@@ -29,12 +31,13 @@ pub enum PluginOutput {
#[async_trait] #[async_trait]
pub trait Plugin { pub trait Plugin {
// Custom output is allowed because we want to extend this system // Run before the query is sent to the server.
// to rewriting queries some day. So an output of a plugin could be
// a rewritten AST.
async fn run( async fn run(
&mut self, &mut self,
query_router: &QueryRouter, query_router: &QueryRouter,
ast: &Vec<Statement>, ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>; ) -> Result<PluginOutput, Error>;
// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
} }

View File

@@ -0,0 +1,49 @@
//! Log all queries to stdout (or somewhere else, why not).
use crate::{
    errors::Error,
    plugins::{Plugin, PluginOutput},
    query_router::QueryRouter,
};

use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));
pub struct QueryLogger;
pub fn setup() {
ENABLED.store(Arc::new(true));
info!("Logging queries to stdout");
}
pub fn disable() {
ENABLED.store(Arc::new(false));
}
pub fn enabled() -> bool {
**ENABLED.load()
}
#[async_trait]
impl Plugin for QueryLogger {
    /// Render every statement in the parsed query back to SQL text,
    /// log the combined result via the `log` crate, and allow the
    /// query to proceed unchanged.
    async fn run(
        &mut self,
        _query_router: &QueryRouter,
        ast: &Vec<Statement>,
    ) -> Result<PluginOutput, Error> {
        let mut statements = Vec::with_capacity(ast.len());
        for statement in ast {
            statements.push(statement.to_string());
        }
        info!("{}", statements.join("; "));
        Ok(PluginOutput::Allow)
    }
}

View File

@@ -5,17 +5,37 @@ use async_trait::async_trait;
use sqlparser::ast::{visit_relations, Statement}; use sqlparser::ast::{visit_relations, Statement};
use crate::{ use crate::{
config::TableAccess as TableAccessConfig,
errors::Error, errors::Error,
plugins::{Plugin, PluginOutput}, plugins::{Plugin, PluginOutput},
query_router::QueryRouter, query_router::QueryRouter,
}; };
use core::ops::ControlFlow; use log::{debug, info};
pub struct TableAccess { use arc_swap::ArcSwap;
pub forbidden_tables: Vec<String>, use core::ops::ControlFlow;
use once_cell::sync::Lazy;
use std::sync::Arc;
static CONFIG: Lazy<ArcSwap<Vec<String>>> = Lazy::new(|| ArcSwap::from_pointee(vec![]));
pub fn setup(config: &TableAccessConfig) {
CONFIG.store(Arc::new(config.tables.clone()));
info!("Blocking access to {} tables", config.tables.len());
} }
/// Check if the table access plugin has been enabled, i.e. a non-empty
/// blocked-table list was installed via `setup`.
/// NOTE(review): a plugin configured with `enabled = true` but an empty
/// `tables` list reports as disabled here — confirm that is intended.
pub fn enabled() -> bool {
    !CONFIG.load().is_empty()
}

/// Turn the plugin off by clearing the blocked-table list.
pub fn disable() {
    CONFIG.store(Arc::new(vec![]));
}

/// Plugin that denies queries referencing any of the configured tables.
pub struct TableAccess;
#[async_trait] #[async_trait]
impl Plugin for TableAccess { impl Plugin for TableAccess {
async fn run( async fn run(
@@ -24,13 +44,14 @@ impl Plugin for TableAccess {
ast: &Vec<Statement>, ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> { ) -> Result<PluginOutput, Error> {
let mut found = None; let mut found = None;
let forbidden_tables = CONFIG.load();
visit_relations(ast, |relation| { visit_relations(ast, |relation| {
let relation = relation.to_string(); let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>(); let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap(); let table_name = parts.last().unwrap();
if self.forbidden_tables.contains(&table_name.to_string()) { if forbidden_tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string()); found = Some(table_name.to_string());
ControlFlow::<()>::Break(()) ControlFlow::<()>::Break(())
} else { } else {
@@ -39,6 +60,8 @@ impl Plugin for TableAccess {
}); });
if let Some(found) = found { if let Some(found) = found {
debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!( Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied", "permission for table \"{}\" denied",
found found

View File

@@ -132,8 +132,6 @@ pub struct PoolSettings {
pub auth_query: Option<String>, pub auth_query: Option<String>,
pub auth_query_user: Option<String>, pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>, pub auth_query_password: Option<String>,
pub plugins: Option<Vec<String>>,
} }
impl Default for PoolSettings { impl Default for PoolSettings {
@@ -158,7 +156,6 @@ impl Default for PoolSettings {
auth_query: None, auth_query: None,
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
plugins: None,
} }
} }
} }
@@ -453,7 +450,6 @@ impl ConnectionPool {
auth_query: pool_config.auth_query.clone(), auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(), auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(), auth_query_password: pool_config.auth_query_password.clone(),
plugins: config.general.query_router_plugins.clone(),
}, },
validated: Arc::new(AtomicBool::new(false)), validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)),
@@ -473,10 +469,29 @@ impl ConnectionPool {
} }
} }
// Initialize plugins here if required. if let Some(ref plugins) = config.plugins {
if let Some(plugins) = config.general.query_router_plugins { if let Some(ref intercept) = plugins.intercept {
if plugins.contains(&String::from("intercept")) { if intercept.enabled {
crate::plugins::intercept::configure(&new_pools); crate::plugins::intercept::setup(intercept, &new_pools);
} else {
crate::plugins::intercept::disable();
}
}
if let Some(ref table_access) = plugins.table_access {
if table_access.enabled {
crate::plugins::table_access::setup(table_access);
} else {
crate::plugins::table_access::disable();
}
}
if let Some(ref query_logger) = plugins.query_logger {
if query_logger.enabled {
crate::plugins::query_logger::setup();
} else {
crate::plugins::query_logger::disable();
}
} }
} }

View File

@@ -15,7 +15,10 @@ use sqlparser::parser::Parser;
use crate::config::Role; use crate::config::Role;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::BytesMutReader; use crate::messages::BytesMutReader;
use crate::plugins::{Intercept, Plugin, PluginOutput, TableAccess}; use crate::plugins::{
intercept, query_logger, table_access, Intercept, Plugin, PluginOutput, QueryLogger,
TableAccess,
};
use crate::pool::PoolSettings; use crate::pool::PoolSettings;
use crate::sharding::Sharder; use crate::sharding::Sharder;
@@ -790,24 +793,26 @@ impl QueryRouter {
/// Add your plugins here and execute them. /// Add your plugins here and execute them.
pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> { pub async fn execute_plugins(&self, ast: &Vec<Statement>) -> Result<PluginOutput, Error> {
if let Some(plugins) = &self.pool_settings.plugins { if query_logger::enabled() {
if plugins.contains(&String::from("intercept")) { let mut query_logger = QueryLogger {};
let mut intercept = Intercept {}; let _ = query_logger.run(&self, ast).await;
let result = intercept.run(&self, ast).await; }
if let Ok(PluginOutput::Intercept(output)) = result { if intercept::enabled() {
return Ok(PluginOutput::Intercept(output)); let mut intercept = Intercept {};
} let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output));
} }
}
if plugins.contains(&String::from("pg_table_access")) { if table_access::enabled() {
let mut table_access = TableAccess { let mut table_access = TableAccess {};
forbidden_tables: vec![String::from("pg_database"), String::from("pg_roles")], let result = table_access.run(&self, ast).await;
};
if let Ok(PluginOutput::Deny(error)) = table_access.run(&self, ast).await { if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error)); return Ok(PluginOutput::Deny(error));
}
} }
} }
@@ -1156,7 +1161,6 @@ mod test {
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(), db: "test".to_string(),
plugins: None,
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None); assert_eq!(qr.active_role, None);
@@ -1231,7 +1235,6 @@ mod test {
auth_query_password: None, auth_query_password: None,
auth_query_user: None, auth_query_user: None,
db: "test".to_string(), db: "test".to_string(),
plugins: None,
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone()); qr.update_pool_settings(pool_settings.clone());
@@ -1376,13 +1379,17 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_table_access_plugin() { async fn test_table_access_plugin() {
use crate::config::TableAccess;
let ta = TableAccess {
enabled: true,
tables: vec![String::from("pg_database")],
};
crate::plugins::table_access::setup(&ta);
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let qr = QueryRouter::new();
let mut pool_settings = PoolSettings::default();
pool_settings.plugins = Some(vec![String::from("pg_table_access")]);
qr.update_pool_settings(pool_settings);
let query = simple_query("SELECT * FROM pg_database"); let query = simple_query("SELECT * FROM pg_database");
let ast = QueryRouter::parse(&query).unwrap(); let ast = QueryRouter::parse(&query).unwrap();

View File

@@ -71,15 +71,17 @@ describe "Admin" do
context "client connects but issues no queries" do context "client connects but issues no queries" do
it "only affects cl_idle stats" do it "only affects cl_idle stats" do
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
connections = Array.new(20) { PG::connect(pgcat_conn_str) } connections = Array.new(20) { PG::connect(pgcat_conn_str) }
sleep(1) sleep(1)
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
results = admin_conn.async_exec("SHOW POOLS")[0] results = admin_conn.async_exec("SHOW POOLS")[0]
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s| %w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0" raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end end
expect(results["cl_idle"]).to eq("20") expect(results["cl_idle"]).to eq("20")
expect(results["sv_idle"]).to eq("1") expect(results["sv_idle"]).to eq(before_test)
connections.map(&:close) connections.map(&:close)
sleep(1.1) sleep(1.1)
@@ -87,7 +89,7 @@ describe "Admin" do
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s| %w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0" raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
end end
expect(results["sv_idle"]).to eq("1") expect(results["sv_idle"]).to eq(before_test)
end end
end end

View File

@@ -27,7 +27,6 @@ module Helpers
primary2 = PgInstance.new(8432, user["username"], user["password"], "shard2") primary2 = PgInstance.new(8432, user["username"], user["password"], "shard2")
pgcat_cfg = pgcat.current_config pgcat_cfg = pgcat.current_config
pgcat_cfg["general"]["query_router_plugins"] = ["intercept"]
pgcat_cfg["pools"] = { pgcat_cfg["pools"] = {
"#{pool_name}" => { "#{pool_name}" => {
"default_role" => "any", "default_role" => "any",