diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index d1dd19d..ba3b875 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -87,7 +87,7 @@ default_role = "any" # every incoming query to determine if it's a read or a write. # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # we'll direct it to the primary. -query_parser_enabled = false +query_parser_enabled = true # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # load balancing of read queries. Otherwise, the primary will only be used for write diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 66fcb79..a0e23f0 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -42,7 +42,15 @@ pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null # Query cancellation test -(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) & +(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) & +sleep 1 +killall psql -s SIGINT + +# Reload pool (closing unused server connections) +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' + +(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) & +sleep 1 killall psql -s SIGINT # Sharding insert @@ -94,7 +102,7 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica start_pgcat "info" # Test session mode (and config reload) -sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml +sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml # Reload config test kill -SIGHUP $(pgrep pgcat) diff --git a/Cargo.lock b/Cargo.lock index 668f421..87d11be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -368,7 +368,7 @@ dependencies = [ [[package]] name = "pgcat" -version = "0.2.0-beta1" +version = "0.4.0-beta1" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 924b9cb..fa63c0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "0.2.1-beta1" +version = "0.4.0-beta1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -28,4 +28,4 @@ parking_lot = "0.11" hmac = "0.12" sha2 = "0.10" base64 = "0.13" -stringprep = "0.1" \ No newline at end of file +stringprep = "0.1" diff --git a/pgcat.toml b/pgcat.toml index 5311bd7..435dda9 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -87,7 +87,7 @@ default_role = "any" # every incoming query to determine if it's a read or a write. # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # we'll direct it to the primary. -query_parser_enabled = false +query_parser_enabled = true # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # load balancing of read queries. Otherwise, the primary will only be used for write diff --git a/src/admin.rs b/src/admin.rs index c467570..b7a5b6f 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -4,17 +4,19 @@ use log::{info, trace}; use std::collections::HashMap; use tokio::net::tcp::OwnedWriteHalf; -use crate::config::{get_config, parse}; +use crate::config::{get_config, reload_config}; use crate::errors::Error; use crate::messages::*; use crate::pool::ConnectionPool; use crate::stats::get_stats; +use crate::ClientServerMap; /// Handle admin client. pub async fn handle_admin( stream: &mut OwnedWriteHalf, mut query: BytesMut, pool: ConnectionPool, + client_server_map: ClientServerMap, ) -> Result<(), Error> { let code = query.get_u8() as char; @@ -34,7 +36,7 @@ pub async fn handle_admin( show_stats(stream, &pool).await } else if query.starts_with("RELOAD") { trace!("RELOAD"); - reload(stream).await + reload(stream, client_server_map).await } else if query.starts_with("SHOW CONFIG") { trace!("SHOW CONFIG"); show_config(stream).await @@ -143,10 +145,7 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> { /// Show utilization of connection pools for each shard and replicas. async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { let stats = get_stats(); - let config = { - let guard = get_config(); - &*guard.clone() - }; + let config = get_config(); let columns = vec![ ("database", DataType::Text), @@ -199,9 +198,7 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul /// Show shards and replicas. async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { - let guard = get_config(); - let config = &*guard.clone(); - drop(guard); + let config = get_config(); // Columns let columns = vec![ @@ -266,17 +263,15 @@ async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> { } /// Reload the configuration file without restarting the process. -async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn reload( + stream: &mut OwnedWriteHalf, + client_server_map: ClientServerMap, +) -> Result<(), Error> { info!("Reloading config"); - let config = get_config(); - let path = config.path.clone().unwrap(); + reload_config(client_server_map).await?; - parse(&path).await?; - - let config = get_config(); - - config.show(); + get_config().show(); let mut res = BytesMut::new(); @@ -292,10 +287,8 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { /// Shows current configuration. async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { - let guard = get_config(); - let config = &*guard.clone(); + let config = &get_config(); let config: HashMap = config.into(); - drop(guard); // Configs that cannot be changed without restarting. let immutables = ["host", "port", "connect_timeout"]; diff --git a/src/client.rs b/src/client.rs index 08b7049..b53bd33 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,10 +13,10 @@ use crate::config::get_config; use crate::constants::*; use crate::errors::Error; use crate::messages::*; -use crate::pool::{ClientServerMap, ConnectionPool}; +use crate::pool::{get_pool, ClientServerMap}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; -use crate::stats::Reporter; +use crate::stats::{get_reporter, Reporter}; /// The client state. One of these is created per client. pub struct Client { @@ -69,12 +69,11 @@ impl Client { pub async fn startup( mut stream: TcpStream, client_server_map: ClientServerMap, - server_info: BytesMut, - stats: Reporter, ) -> Result { - let config = get_config().clone(); - let transaction_mode = config.general.pool_mode.starts_with("t"); - // drop(config); + let config = get_config(); + let transaction_mode = config.general.pool_mode == "transaction"; + let stats = get_reporter(); + loop { trace!("Waiting for StartupMessage"); @@ -154,9 +153,10 @@ impl Client { debug!("Password authentication successful"); auth_ok(&mut stream).await?; - write_all(&mut stream, server_info).await?; + write_all(&mut stream, get_pool().server_info()).await?; backend_key_data(&mut stream, process_id, secret_key).await?; ready_for_query(&mut stream).await?; + trace!("Startup OK"); let database = parameters @@ -221,7 +221,7 @@ impl Client { } /// Handle a connected and authenticated client. - pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { + pub async fn handle(&mut self) -> Result<(), Error> { // The client wants to cancel a query it has issued previously. if self.cancel_mode { trace!("Sending CancelRequest"); @@ -252,13 +252,19 @@ impl Client { return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } + // The query router determines where the query is going to go, + // e.g. primary, replica, which shard. let mut query_router = QueryRouter::new(); + let mut round_robin = 0; // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. loop { - trace!("Client idle, waiting for message"); + trace!( + "Client idle, waiting for message, transaction mode: {}", + self.transaction_mode + ); // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). @@ -267,32 +273,63 @@ impl Client { // SET SHARDING KEY TO 'bigint'; let mut message = read_message(&mut self.read).await?; + // Get a pool instance referenced by the most up-to-date + // pointer. This ensures we always read the latest config + // when starting a query. + let mut pool = get_pool(); + // Avoid taking a server if the client just wants to disconnect. if message[0] as char == 'X' { - trace!("Client disconnecting"); + debug!("Client disconnecting"); return Ok(()); } // Handle admin database queries. if self.admin { - trace!("Handling admin command"); - handle_admin(&mut self.write, message, pool.clone()).await?; + debug!("Handling admin command"); + handle_admin( + &mut self.write, + message, + pool.clone(), + self.client_server_map.clone(), + ) + .await?; continue; } + let current_shard = query_router.shard(); + // Handle all custom protocol commands, if any. match query_router.try_execute_command(message.clone()) { // Normal query, not a custom command. - None => { - // Attempt to infer which server we want to query, i.e. primary or replica. - if query_router.query_parser_enabled() && query_router.role() == None { - query_router.infer_role(message.clone()); - } - } + None => (), // SET SHARD TO Some((Command::SetShard, _)) => { - custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + // Selected shard is not configured. + if query_router.shard() >= pool.shards() { + // Set the shard back to what it was. + query_router.set_shard(current_shard); + + error_response( + &mut self.write, + &format!( + "shard {} is more than configured {}, staying on shard {}", + query_router.shard(), + pool.shards(), + current_shard, + ), + ) + .await?; + } else { + custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + } + continue; + } + + // SET PRIMARY READS TO + Some((Command::SetPrimaryReads, _)) => { + custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; continue; } @@ -319,27 +356,24 @@ impl Client { show_response(&mut self.write, "shard", &value).await?; continue; } - }; - // Make sure we selected a valid shard. - if query_router.shard() >= pool.shards() { - error_response( - &mut self.write, - &format!( - "shard {} is more than configured {}", - query_router.shard(), - pool.shards() - ), - ) - .await?; - continue; - } + // SHOW PRIMARY READS + Some((Command::ShowPrimaryReads, value)) => { + show_response(&mut self.write, "primary reads", &value).await?; + continue; + } + }; debug!("Waiting for connection from pool"); // Grab a server from the pool. let connection = match pool - .get(query_router.shard(), query_router.role(), self.process_id) + .get( + query_router.shard(), + query_router.role(), + self.process_id, + round_robin, + ) .await { Ok(conn) => { @@ -358,6 +392,8 @@ impl Client { let address = connection.1; let server = &mut *reference; + round_robin += 1; + // Server is assigned to the client in case the client wants to // cancel a query later. server.claim(self.process_id, self.secret_key); diff --git a/src/config.rs b/src/config.rs index f11aa31..96d5a77 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,5 @@ /// Parse the configuration file. -use arc_swap::{ArcSwap, Guard}; +use arc_swap::ArcSwap; use log::{error, info}; use once_cell::sync::Lazy; use serde_derive::Deserialize; @@ -10,6 +10,7 @@ use tokio::io::AsyncReadExt; use toml; use crate::errors::Error; +use crate::{ClientServerMap, ConnectionPool}; /// Globally available configuration. static CONFIG: Lazy> = Lazy::new(|| ArcSwap::from_pointee(Config::default())); @@ -126,7 +127,7 @@ impl Default for General { } /// Shard configuration. -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct Shard { pub servers: Vec<(String, u16, String)>, pub database: String, @@ -161,10 +162,16 @@ impl Default for QueryRouter { } } +fn default_path() -> String { + String::from("pgcat.toml") +} + /// Configuration wrapper. #[derive(Deserialize, Debug, Clone)] pub struct Config { - pub path: Option, + #[serde(default = "default_path")] + pub path: String, + pub general: General, pub user: User, pub shards: HashMap, @@ -174,7 +181,7 @@ pub struct Config { impl Default for Config { fn default() -> Config { Config { - path: Some(String::from("pgcat.toml")), + path: String::from("pgcat.toml"), general: General::default(), user: User::default(), shards: HashMap::from([(String::from("1"), Shard::default())]), @@ -237,6 +244,8 @@ impl Config { ); info!("Connection timeout: {}ms", self.general.connect_timeout); info!("Sharding function: {}", self.query_router.sharding_function); + info!("Primary reads: {}", self.query_router.primary_reads_enabled); + info!("Query router: {}", self.query_router.query_parser_enabled); info!("Number of shards: {}", self.shards.len()); } } @@ -244,8 +253,8 @@ impl Config { /// Get a read-only instance of the configuration /// from anywhere in the app. /// ArcSwap makes this cheap and quick. -pub fn get_config() -> Guard> { - CONFIG.load() +pub fn get_config() -> Config { + (*(*CONFIG.load())).clone() } /// Parse the configuration file located at the path. @@ -357,7 +366,7 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; - config.path = Some(path.to_string()); + config.path = path.to_string(); // Update the configuration globally. CONFIG.store(Arc::new(config.clone())); @@ -365,6 +374,27 @@ pub async fn parse(path: &str) -> Result<(), Error> { Ok(()) } +pub async fn reload_config(client_server_map: ClientServerMap) -> Result<(), Error> { + let old_config = get_config(); + + match parse(&old_config.path).await { + Ok(()) => (), + Err(err) => { + error!("Config reload error: {:?}", err); + return Err(Error::BadConfig); + } + }; + + let new_config = get_config(); + + if old_config.shards != new_config.shards { + info!("Sharding configuration changed, re-creating server pools"); + ConnectionPool::from_config(client_server_map).await + } else { + Ok(()) + } +} + #[cfg(test)] mod test { use super::*; @@ -377,6 +407,6 @@ mod test { assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1"); assert_eq!(get_config().shards["0"].servers[0].2, "primary"); assert_eq!(get_config().query_router.default_role, "any"); - assert_eq!(get_config().path, Some("pgcat.toml".to_string())); + assert_eq!(get_config().path, "pgcat.toml".to_string()); } } diff --git a/src/main.rs b/src/main.rs index c22391e..70094d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Lev Kokotov +// Copyright (c) 2022 Lev Kokotov // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -34,7 +34,7 @@ extern crate sqlparser; extern crate tokio; extern crate toml; -use log::{error, info}; +use log::{debug, error, info}; use parking_lot::Mutex; use tokio::net::TcpListener; use tokio::{ @@ -59,9 +59,9 @@ mod server; mod sharding; mod stats; -use config::get_config; -use pool::{ClientServerMap, ConnectionPool}; -use stats::{Collector, Reporter}; +use config::{get_config, reload_config}; +use pool::{get_pool, ClientServerMap, ConnectionPool}; +use stats::{Collector, Reporter, REPORTER}; #[tokio::main(worker_threads = 4)] async fn main() { @@ -109,37 +109,39 @@ async fn main() { // Statistics reporting. let (tx, rx) = mpsc::channel(100); + REPORTER.store(Arc::new(Reporter::new(tx.clone()))); // Connection pool that allows to query all shards and replicas. - let mut pool = - ConnectionPool::from_config(client_server_map.clone(), Reporter::new(tx.clone())).await; + match ConnectionPool::from_config(client_server_map.clone()).await { + Ok(_) => (), + Err(err) => { + error!("Pool error: {:?}", err); + return; + } + }; + + let pool = get_pool(); // Statistics collector task. let collector_tx = tx.clone(); + + // Save these for reloading + let reload_client_server_map = client_server_map.clone(); + let addresses = pool.databases(); tokio::task::spawn(async move { let mut stats_collector = Collector::new(rx, collector_tx); stats_collector.collect(addresses).await; }); - // Connect to all servers and validate their versions. - let server_info = match pool.validate().await { - Ok(info) => info, - Err(err) => { - error!("Could not validate connection pool: {:?}", err); - return; - } - }; - info!("Waiting for clients"); + drop(pool); + // Client connection loop. tokio::task::spawn(async move { loop { - let pool = pool.clone(); let client_server_map = client_server_map.clone(); - let server_info = server_info.clone(); - let reporter = Reporter::new(tx.clone()); let (socket, addr) = match listener.accept().await { Ok((socket, addr)) => (socket, addr), @@ -152,12 +154,11 @@ async fn main() { // Handle client. tokio::task::spawn(async move { let start = chrono::offset::Utc::now().naive_utc(); - match client::Client::startup(socket, client_server_map, server_info, reporter) - .await - { + match client::Client::startup(socket, client_server_map).await { Ok(mut client) => { info!("Client {:?} connected", addr); - match client.handle(pool).await { + + match client.handle().await { Ok(()) => { let duration = chrono::offset::Utc::now().naive_utc() - start; @@ -176,7 +177,7 @@ async fn main() { } Err(err) => { - error!("Client failed to login: {:?}", err); + debug!("Client failed to login: {:?}", err); } }; }); @@ -190,16 +191,15 @@ async fn main() { loop { stream.recv().await; + info!("Reloading config"); - match config::parse("pgcat.toml").await { - Ok(_) => { - get_config().show(); - } - Err(err) => { - error!("{:?}", err); - return; - } + + match reload_config(reload_client_server_map.clone()).await { + Ok(_) => (), + Err(_) => continue, }; + + get_config().show(); } }); diff --git a/src/pool.rs b/src/pool.rs index 53803e7..3f2f364 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,9 +1,10 @@ -/// Pooling, failover and banlist. +use arc_swap::ArcSwap; use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; use bytes::BytesMut; use chrono::naive::NaiveDateTime; use log::{debug, error, info, warn}; +use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use std::collections::HashMap; use std::sync::Arc; @@ -12,28 +13,47 @@ use std::time::Instant; use crate::config::{get_config, Address, Role, User}; use crate::errors::Error; use crate::server::Server; -use crate::stats::Reporter; +use crate::stats::{get_reporter, Reporter}; pub type BanList = Arc>>>; pub type ClientServerMap = Arc>>; +/// The connection pool, globally available. +/// This is atomic and safe and read-optimized. +/// The pool is recreated dynamically when the config is reloaded. +pub static POOL: Lazy> = + Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default())); + /// The globally accessible connection pool. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct ConnectionPool { + /// The pools handled internally by bb8. databases: Vec>>, + + /// The addresses (host, port, role) to handle + /// failover and load balancing deterministically. addresses: Vec>, - round_robin: usize, + + /// List of banned addresses (see above) + /// that should not be queried. banlist: BanList, + + /// The statistics aggregator runs in a separate task + /// and receives stats from clients, servers, and the pool. stats: Reporter, + + /// The server information (K messages) have to be passed to the + /// clients on startup. We pre-connect to all shards and replicas + /// on pool creation and save the K messages here. + server_info: BytesMut, } impl ConnectionPool { /// Construct the connection pool from the configuration. - pub async fn from_config( - client_server_map: ClientServerMap, - stats: Reporter, - ) -> ConnectionPool { + pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> { + let reporter = get_reporter(); let config = get_config(); + let mut shards = Vec::new(); let mut addresses = Vec::new(); let mut banlist = Vec::new(); @@ -44,6 +64,8 @@ impl ConnectionPool { .into_keys() .map(|x| x.to_string()) .collect::>(); + + // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); for shard_idx in shard_ids { @@ -82,7 +104,7 @@ impl ConnectionPool { config.user.clone(), &shard.database, client_server_map.clone(), - stats.clone(), + reporter.clone(), ); let pool = Pool::builder() @@ -105,15 +127,28 @@ impl ConnectionPool { } assert_eq!(shards.len(), addresses.len()); - let address_len = addresses.len(); - ConnectionPool { + let mut pool = ConnectionPool { databases: shards, addresses: addresses, - round_robin: rand::random::() % address_len, // Start at a random replica banlist: Arc::new(RwLock::new(banlist)), - stats: stats, - } + stats: reporter, + server_info: BytesMut::new(), + }; + + // Connect to the servers to make sure pool configuration is valid + // before setting it globally. + match pool.validate().await { + Ok(_) => (), + Err(err) => { + error!("Could not validate connection pool: {:?}", err); + return Err(err); + } + }; + + POOL.store(Arc::new(pool.clone())); + + Ok(()) } /// Connect to all shards and grab server information. @@ -121,16 +156,18 @@ impl ConnectionPool { /// when they connect. /// This also warms up the pool for clients that connect when /// the pooler starts up. - pub async fn validate(&mut self) -> Result { + async fn validate(&mut self) -> Result<(), Error> { let mut server_infos = Vec::new(); - let stats = self.stats.clone(); + for shard in 0..self.shards() { + let mut round_robin = 0; + for _ in 0..self.servers(shard) { // To keep stats consistent. let fake_process_id = 0; - let connection = match self.get(shard, None, fake_process_id).await { + let connection = match self.get(shard, None, fake_process_id, round_robin).await { Ok(conn) => conn, Err(err) => { error!("Shard {} down or misconfigured: {:?}", shard, err); @@ -138,10 +175,9 @@ impl ConnectionPool { } }; - let mut proxy = connection.0; + let proxy = connection.0; let address = connection.1; - let server = &mut *proxy; - + let server = &*proxy; let server_info = server.server_info(); stats.client_disconnecting(fake_process_id, address.id); @@ -157,6 +193,7 @@ impl ConnectionPool { } server_infos.push(server_info); + round_robin += 1; } } @@ -166,15 +203,18 @@ impl ConnectionPool { return Err(Error::AllServersDown); } - Ok(server_infos[0].clone()) + self.server_info = server_infos[0].clone(); + + Ok(()) } /// Get a connection from the pool. pub async fn get( &mut self, - shard: usize, - role: Option, - process_id: i32, + shard: usize, // shard number + role: Option, // primary or replica + process_id: i32, // client id + mut round_robin: usize, // round robin offset ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { let now = Instant::now(); let addresses = &self.addresses[shard]; @@ -204,9 +244,9 @@ impl ConnectionPool { while allowed_attempts > 0 { // Round-robin replicas. - self.round_robin += 1; + round_robin += 1; - let index = self.round_robin % addresses.len(); + let index = round_robin % addresses.len(); let address = &addresses[index]; // Make sure you're getting a primary or a replica @@ -218,6 +258,7 @@ impl ConnectionPool { allowed_attempts -= 1; + // Don't attempt to connect to banned servers. if self.is_banned(address, shard, role) { continue; } @@ -390,6 +431,10 @@ impl ConnectionPool { pub fn address(&self, shard: usize, server: usize) -> &Address { &self.addresses[shard][server] } + + pub fn server_info(&self) -> BytesMut { + self.server_info.clone() + } } /// Wrapper for the bb8 connection pool. @@ -470,3 +515,8 @@ impl ManageConnection for ServerPool { conn.is_bad() } } + +/// Get the connection pool +pub fn get_pool() -> ConnectionPool { + (*(*POOL.load())).clone() +} diff --git a/src/query_router.rs b/src/query_router.rs index 1f20079..995a25c 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -12,12 +12,14 @@ use crate::config::{get_config, Role}; use crate::sharding::{Sharder, ShardingFunction}; /// Regexes used to parse custom commands. -const CUSTOM_SQL_REGEXES: [&str; 5] = [ +const CUSTOM_SQL_REGEXES: [&str; 7] = [ r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$", r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$", r"(?i)^ *SHOW SHARD *;? *$", r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$", r"(?i)^ *SHOW SERVER ROLE *;? *$", + r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$", + r"(?i)^ *SHOW PRIMARY READS *;? *$", ]; /// Custom commands. @@ -28,6 +30,8 @@ pub enum Command { ShowShard, SetServerRole, ShowServerRole, + SetPrimaryReads, + ShowPrimaryReads, } /// Quickly test for match when a query is received. @@ -38,27 +42,17 @@ static CUSTOM_SQL_REGEX_LIST: OnceCell> = OnceCell::new(); /// The query router. pub struct QueryRouter { - /// By default, queries go here, unless we have better information - /// about what the client wants. - default_server_role: Option, - - /// Number of shards in the cluster. - shards: usize, - /// Which shard we should be talking to right now. active_shard: Option, /// Which server should we be talking to. active_role: Option, - /// Include the primary into the replica pool for reads. - primary_reads_enabled: bool, - - /// Should we try to parse queries to route them to replicas or primary automatically. + /// Should we try to parse queries to route them to replicas or primary automatically query_parser_enabled: bool, - /// Which sharding function we're using. - sharding_function: ShardingFunction, + /// Include the primary into the replica pool for reads. + primary_reads_enabled: bool, } impl QueryRouter { @@ -97,28 +91,11 @@ impl QueryRouter { pub fn new() -> QueryRouter { let config = get_config(); - let default_server_role = match config.query_router.default_role.as_ref() { - "any" => None, - "primary" => Some(Role::Primary), - "replica" => Some(Role::Replica), - _ => unreachable!(), - }; - - let sharding_function = match config.query_router.sharding_function.as_ref() { - "pg_bigint_hash" => ShardingFunction::PgBigintHash, - "sha1" => ShardingFunction::Sha1, - _ => unreachable!(), - }; - QueryRouter { - default_server_role: default_server_role, - shards: config.shards.len(), - - active_role: default_server_role, active_shard: None, - primary_reads_enabled: config.query_router.primary_reads_enabled, + active_role: None, query_parser_enabled: config.query_router.query_parser_enabled, - sharding_function, + primary_reads_enabled: config.query_router.primary_reads_enabled, } } @@ -146,21 +123,48 @@ impl QueryRouter { let matches: Vec<_> = regex_set.matches(&query).into_iter().collect(); + // This is not a custom query, try to infer which + // server it'll go to if the query parser is enabled. if matches.len() != 1 { + debug!("Regular query"); + if self.query_parser_enabled && self.role() == None { + debug!("Inferring role"); + self.infer_role(buf.clone()); + } return None; } + let config = get_config(); + + let sharding_function = match config.query_router.sharding_function.as_ref() { + "pg_bigint_hash" => ShardingFunction::PgBigintHash, + "sha1" => ShardingFunction::Sha1, + _ => unreachable!(), + }; + + let default_server_role = match config.query_router.default_role.as_ref() { + "any" => None, + "primary" => Some(Role::Primary), + "replica" => Some(Role::Replica), + _ => unreachable!(), + }; + let command = match matches[0] { 0 => Command::SetShardingKey, 1 => Command::SetShard, 2 => Command::ShowShard, 3 => Command::SetServerRole, 4 => Command::ShowServerRole, + 5 => Command::SetPrimaryReads, + 6 => Command::ShowPrimaryReads, _ => unreachable!(), }; let mut value = match command { - Command::SetShardingKey | Command::SetShard | Command::SetServerRole => { + Command::SetShardingKey + | Command::SetShard + | Command::SetServerRole + | Command::SetPrimaryReads => { // Capture value. I know this re-runs the regex engine, but I haven't // figured out a better way just yet. I think I can write a single Regex // that matches all 5 custom SQL patterns, but maybe that's not very legible? @@ -187,11 +191,16 @@ impl QueryRouter { } } }, + + Command::ShowPrimaryReads => match self.primary_reads_enabled { + true => String::from("on"), + false => String::from("off"), + }, }; match command { Command::SetShardingKey => { - let sharder = Sharder::new(self.shards, self.sharding_function); + let sharder = Sharder::new(config.shards.len(), sharding_function); let shard = sharder.shard(value.parse::().unwrap()); self.active_shard = Some(shard); value = shard.to_string(); @@ -199,7 +208,7 @@ impl QueryRouter { Command::SetShard => { self.active_shard = match value.to_ascii_uppercase().as_ref() { - "ANY" => Some(rand::random::() % self.shards), + "ANY" => Some(rand::random::() % config.shards.len()), _ => Some(value.parse::().unwrap()), }; } @@ -227,8 +236,8 @@ impl QueryRouter { } "default" => { - self.active_role = self.default_server_role; - self.query_parser_enabled = get_config().query_router.query_parser_enabled; + self.active_role = default_server_role; + self.query_parser_enabled = config.query_router.query_parser_enabled; self.active_role } @@ -236,6 +245,19 @@ impl QueryRouter { }; } + Command::SetPrimaryReads => { + if value == "on" { + debug!("Setting primary reads to on"); + self.primary_reads_enabled = true; + } else if value == "off" { + debug!("Setting primary reads to off"); + self.primary_reads_enabled = false; + } else if value == "default" { + debug!("Setting primary reads to default"); + self.primary_reads_enabled = config.query_router.primary_reads_enabled; + } + } + _ => (), } @@ -330,23 +352,15 @@ impl QueryRouter { } } - /// Reset the router back to defaults. - /// This must be called at the end of every transaction in transaction mode. - pub fn _reset(&mut self) { - self.active_role = self.default_server_role; - self.active_shard = None; + pub fn set_shard(&mut self, shard: usize) { + self.active_shard = Some(shard); } /// Should we attempt to parse queries? + #[allow(dead_code)] pub fn query_parser_enabled(&self) -> bool { self.query_parser_enabled } - - /// Allows to toggle primary reads in tests. - #[allow(dead_code)] - pub fn toggle_primary_reads(&mut self, value: bool) { - self.primary_reads_enabled = value; - } } #[cfg(test)] @@ -369,7 +383,8 @@ mod test { let mut qr = QueryRouter::new(); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert_eq!(qr.query_parser_enabled(), true); - qr.toggle_primary_reads(false); + + assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -410,7 +425,7 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - qr.toggle_primary_reads(true); + assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); assert!(qr.infer_role(query)); assert_eq!(qr.role(), None); @@ -421,7 +436,7 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); - qr.toggle_primary_reads(false); + assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -450,6 +465,10 @@ mod test { "SET SERVER ROLE TO 'any'", "SET SERVER ROLE TO 'auto'", "SHOW SERVER ROLE", + "SET PRIMARY READS TO 'on'", + "SET PRIMARY READS TO 'off'", + "SET PRIMARY READS TO 'default'", + "SHOW PRIMARY READS", // Lower case "set sharding key to '1'", "set shard to '1'", @@ -459,9 +478,13 @@ mod test { "set server role to 'any'", "set server role to 'auto'", "show server role", + "set primary reads to 'on'", + "set primary reads to 'OFF'", + "set primary reads to 'deFaUlt'", // No quotes "SET SHARDING KEY TO 11235", "SET SHARD TO 15", + "SET PRIMARY READS TO off", // Spaces and semicolon " SET SHARDING KEY TO 11235 ; ", " SET SHARD TO 15; ", @@ -469,18 +492,23 @@ mod test { " SET SERVER ROLE TO 'primary'; ", " SET SERVER ROLE TO 'primary' ; ", " SET SERVER ROLE TO 'primary' ;", + " SET PRIMARY READS TO 'off' ;", ]; // Which regexes it'll match to in the list let matches = [ - 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3, + 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0, + 3, 3, 3, 5, ]; let list = CUSTOM_SQL_REGEX_LIST.get().unwrap(); let set = CUSTOM_SQL_REGEX_SET.get().unwrap(); for (i, test) in tests.iter().enumerate() { - assert!(list[matches[i]].is_match(test)); + if !list[matches[i]].is_match(test) { + println!("{} does not match {}", test, list[matches[i]]); + assert!(false); + } assert_eq!(set.matches(test).into_iter().collect::>().len(), 1); } @@ -549,6 +577,26 @@ mod test { Some((Command::ShowServerRole, String::from(*role))) ); } + + let primary_reads = ["on", "off", "default"]; + let primary_reads_enabled = ["on", "off", "on"]; + + for (idx, primary_reads) in primary_reads.iter().enumerate() { + assert_eq!( + qr.try_execute_command(simple_query(&format!( + "SET PRIMARY READS TO {}", + primary_reads + ))), + Some((Command::SetPrimaryReads, String::from(*primary_reads))) + ); + assert_eq!( + qr.try_execute_command(simple_query("SHOW PRIMARY READS")), + Some(( + Command::ShowPrimaryReads, + String::from(primary_reads_enabled[idx]) + )) + ); + } } #[test] @@ -556,7 +604,7 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - qr.toggle_primary_reads(false); + assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); assert!(qr.try_execute_command(query) != None); assert!(qr.query_parser_enabled()); @@ -573,6 +621,6 @@ mod test { assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); assert!(qr.try_execute_command(query) != None); - assert!(!qr.query_parser_enabled()); + assert!(qr.query_parser_enabled()); } } diff --git a/src/stats.rs b/src/stats.rs index e44578d..7454c58 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,9 +1,13 @@ +use arc_swap::ArcSwap; /// Statistics and reporting. use log::info; use once_cell::sync::Lazy; use parking_lot::Mutex; use std::collections::HashMap; -use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + +pub static REPORTER: Lazy> = + Lazy::new(|| ArcSwap::from_pointee(Reporter::default())); /// Latest stats updated every second; used in SHOW STATS and other admin commands. static LATEST_STATS: Lazy>>> = @@ -60,6 +64,13 @@ pub struct Reporter { tx: Sender, } +impl Default for Reporter { + fn default() -> Reporter { + let (tx, _rx) = channel(5); + Reporter { tx } + } +} + impl Reporter { /// Create a new Reporter instance. pub fn new(tx: Sender) -> Reporter { @@ -289,7 +300,7 @@ impl Collector { ("avg_query_time", 0), ("avg_xact_count", 0), ("avg_sent", 0), - ("avg_received", 0), + ("avg_recv", 0), ("avg_wait_time", 0), ("maxwait_us", 0), ("maxwait", 0), @@ -493,10 +504,14 @@ impl Collector { "avg_query_count", "avgxact_count", "avg_sent", - "avg_received", + "avg_recv", "avg_wait_time", ] { - let total_name = stat.replace("avg_", "total_"); + let total_name = match stat { + &"avg_recv" => "total_received".to_string(), // Because PgBouncer is saving bytes + _ => stat.replace("avg_", "total_"), + }; + let old_value = old_stats.entry(total_name.clone()).or_insert(0); let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned(); let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second @@ -515,3 +530,8 @@ impl Collector { pub fn get_stats() -> HashMap> { LATEST_STATS.lock().clone() } + +/// Get the statistics reporter used to update stats across the pools/clients. +pub fn get_reporter() -> Reporter { + (*(*REPORTER.load())).clone() +} diff --git a/tests/pgbench/simple.sql b/tests/pgbench/simple.sql index 0a283ba..ad5e613 100644 --- a/tests/pgbench/simple.sql +++ b/tests/pgbench/simple.sql @@ -12,6 +12,8 @@ SET SHARD TO :shard; +SET SERVER ROLE TO 'auto'; + BEGIN; UPDATE pgbench_accounts SET abalance = abalance + :delta WHERE aid = :aid; @@ -26,3 +28,12 @@ INSERT INTO pgbench_history (tid, bid, aid, delta, mtime) VALUES (:tid, :bid, :a END; +SET SHARDING KEY TO :aid; + +-- Read load balancing +SELECT abalance FROM pgbench_accounts WHERE aid = :aid; + +SET SERVER ROLE TO 'replica'; + +-- Read load balancing +SELECT abalance FROM pgbench_accounts WHERE aid = :aid; diff --git a/tests/sharding/query_routing_test_primary_replica.sql b/tests/sharding/query_routing_test_primary_replica.sql index db05ba7..5fe3cbe 100644 --- a/tests/sharding/query_routing_test_primary_replica.sql +++ b/tests/sharding/query_routing_test_primary_replica.sql @@ -151,3 +151,12 @@ SELECT 1; set server role to 'replica'; SeT SeRver Role TO 'PrImARY'; select 1; + +SET PRIMARY READS TO 'on'; +SELECT 1; + +SET PRIMARY READS TO 'off'; +SELECT 1; + +SET PRIMARY READS TO 'default'; +SELECT 1;