diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index c27bb3b..00d4927 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -56,6 +56,17 @@ psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accou sleep 1 killall psql -s SIGINT +# Pause/resume test. +# Running benches before, during, and after pause/resume. +pgbench -U sharding_user -t 500 -c 2 -h 127.0.0.1 -p 6432 --protocol extended & +BENCH_ONE=$! +PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'PAUSE sharded_db,sharding_user' +pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended & +BENCH_TWO=$! +PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RESUME sharded_db,sharding_user' +wait ${BENCH_ONE} +wait ${BENCH_TWO} + # Reload pool (closing unused server connections) PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' diff --git a/Cargo.lock b/Cargo.lock index 276436e..2b8ce70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -258,6 +258,21 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2022715d62ab30faffd124d40b76f4134a550a87792276512b18d63272333394" +[[package]] +name = "futures" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.25" @@ -265,6 +280,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -273,6 +289,34 @@ version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" +[[package]] +name = "futures-executor" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" + +[[package]] +name = "futures-macro" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.25" @@ -293,7 +337,11 @@ checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" dependencies = [ "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -665,6 +713,7 @@ dependencies = [ "chrono", "env_logger", "exitcode", + "futures", "hmac", "hyper", "jemallocator", diff --git a/Cargo.toml b/Cargo.toml index 5dde943..344b1c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ rustls-pemfile = "1" hyper = { version = "0.14", features = ["full"] } phf = { version = "0.11.1", features = ["macros"] } exitcode = "1.1.2" +futures = "0.3" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/src/admin.rs b/src/admin.rs index 9d4526e..71d3e48 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -7,7 +7,7 @@ use tokio::time::Instant; use crate::config::{get_config, reload_config, VERSION}; use crate::errors::Error; use crate::messages::*; -use crate::pool::get_all_pools; +use crate::pool::{get_all_pools, get_pool}; use crate::stats::{ get_address_stats, 
get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState, }; @@ -44,15 +44,13 @@ where } let len = query.get_i32() as usize; - let query = String::from_utf8_lossy(&query[..len - 5]) - .to_string() - .to_ascii_uppercase(); + let query = String::from_utf8_lossy(&query[..len - 5]).to_string(); trace!("Admin query: {}", query); let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect(); - match query_parts[0] { + match query_parts[0].to_ascii_uppercase().as_str() { "RELOAD" => { trace!("RELOAD"); reload(stream, client_server_map).await @@ -61,7 +59,15 @@ where trace!("SET"); ignore_set(stream).await } - "SHOW" => match query_parts[1] { + "PAUSE" => { + trace!("PAUSE"); + pause(stream, query_parts[1]).await + } + "RESUME" => { + trace!("RESUME"); + resume(stream, query_parts[1]).await + } + "SHOW" => match query_parts[1].to_ascii_uppercase().as_str() { "CONFIG" => { trace!("SHOW CONFIG"); show_config(stream).await @@ -287,6 +293,7 @@ where let address = pool.address(shard, server); let pool_state = pool.pool_state(shard, server); let banned = pool.is_banned(address); + let paused = pool.paused(); res.put(data_row(&vec![ address.name(), // name @@ -300,7 +307,11 @@ where pool_config.pool_mode.to_string(), // pool_mode pool_config.user.pool_size.to_string(), // max_connections pool_state.connections.to_string(), // current_connections - "0".to_string(), // paused + match paused { + // paused + true => "1".to_string(), + false => "0".to_string(), + }, match banned { // disabled true => "1".to_string(), @@ -561,3 +572,97 @@ where write_all_half(stream, &res).await } + +/// Pause a pool. It won't pass any more queries to the backends. 
+async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); + + if parts.len() != 2 { + error_response( + stream, + "PAUSE requires a database and a user, e.g. PAUSE my_db,my_user", + ) + .await + } else { + let database = parts[0]; + let user = parts[1]; + + match get_pool(database, user) { + Some(pool) => { + pool.pause(); + + let mut res = BytesMut::new(); + + res.put(command_complete(&format!("PAUSE {},{}", database, user))); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await + } + + None => { + error_response( + stream, + &format!( + "No pool configured for database: {}, user: {}", + database, user + ), + ) + .await + } + } + } +} + +/// Resume a pool. Queries are allowed again. +async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); + + if parts.len() != 2 { + error_response( + stream, + "RESUME requires a database and a user, e.g. RESUME my_db,my_user", + ) + .await + } else { + let database = parts[0]; + let user = parts[1]; + + match get_pool(database, user) { + Some(pool) => { + pool.resume(); + + let mut res = BytesMut::new(); + + res.put(command_complete(&format!("RESUME {},{}", database, user))); + + // ReadyForQuery + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, &res).await + } + + None => { + error_response( + stream, + &format!( + "No pool configured for database: {}, user: {}", + database, user + ), + ) + .await + } + } + } +} diff --git a/src/client.rs b/src/client.rs index 15fe21d..ed2044d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -476,7 +476,7 @@ where } // Authenticate normal user. 
else { - let pool = match get_pool(pool_name, username) { + let mut pool = match get_pool(pool_name, username) { Some(pool) => pool, None => { error_response( @@ -504,6 +504,25 @@ where let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; + // If the pool hasn't been validated yet, + // connect to the servers and figure out what's what. + if !pool.validated() { + match pool.validate().await { + Ok(_) => (), + Err(err) => { + error_response( + &mut write, + &format!( + "Pool down for database: {:?}, user: {:?}", + pool_name, username + ), + ) + .await?; + return Err(Error::ClientError(format!("Pool down: {:?}", err))); + } + } + } + (transaction_mode, pool.server_info()) }; @@ -674,22 +693,16 @@ where // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config // when starting a query. - let pool = match get_pool(&self.pool_name, &self.username) { - Some(pool) => pool, - None => { - error_response( - &mut self.write, - &format!( - "No pool configured for database: {:?}, user: {:?}", - self.pool_name, self.username - ), - ) - .await?; + let mut pool = self.get_pool().await?; + + // Check if the pool is paused and wait until it's resumed. + if pool.wait_paused().await { + // Refresh pool information, something might have changed. + pool = self.get_pool().await?; + } - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name))); - } - }; query_router.update_pool_settings(pool.settings.clone()); + let current_shard = query_router.shard(); // Handle all custom protocol commands, if any. @@ -1012,6 +1025,29 @@ where } } + /// Retrieve connection pool, if it exists. + /// Return an error to the client otherwise. 
+ async fn get_pool(&mut self) -> Result<ConnectionPool, Error> { + match get_pool(&self.pool_name, &self.username) { + Some(pool) => Ok(pool), + None => { + error_response( + &mut self.write, + &format!( + "No pool configured for database: {}, user: {}", + self.pool_name, self.username + ), + ) + .await?; + + Err(Error::ClientError(format!( + "Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}", + self.pool_name, self.username, self.application_name + ))) + } + } + } + /// Release the server from the client: it can't cancel its queries anymore. pub fn release(&self) { let mut guard = self.client_server_map.lock(); diff --git a/src/pool.rs b/src/pool.rs index 0f92158..702b617 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,7 +1,7 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; -use bytes::BytesMut; +use bytes::{BufMut, BytesMut}; use chrono::naive::NaiveDateTime; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; @@ -9,8 +9,12 @@ use parking_lot::{Mutex, RwLock}; use rand::seq::SliceRandom; use rand::thread_rng; use std::collections::{HashMap, HashSet}; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::time::Instant; +use tokio::sync::Notify; use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User}; use crate::errors::Error; @@ -56,6 +60,12 @@ impl PoolIdentifier { } } +impl From<&Address> for PoolIdentifier { + fn from(address: &Address) -> PoolIdentifier { + PoolIdentifier::new(&address.database, &address.username) + } +} + /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { @@ -136,10 +146,18 @@ pub struct ConnectionPool { /// The server information (K messages) have to be passed to the /// clients on startup. We pre-connect to all shards and replicas /// on pool creation and save the K messages here. - server_info: BytesMut, + server_info: Arc<RwLock<BytesMut>>, /// Pool configuration. 
pub settings: PoolSettings, + + /// If not validated, we need to double check the pool is available before allowing a client + /// to use it. + validated: Arc<AtomicBool>, + + /// If the pool has been paused or not. + paused: Arc<AtomicBool>, + paused_waiter: Arc<Notify>, } impl ConnectionPool { @@ -257,12 +275,12 @@ impl ConnectionPool { assert_eq!(shards.len(), addresses.len()); - let mut pool = ConnectionPool { + let pool = ConnectionPool { databases: shards, addresses, banlist: Arc::new(RwLock::new(banlist)), stats: get_reporter(), - server_info: BytesMut::new(), + server_info: Arc::new(RwLock::new(BytesMut::new())), settings: PoolSettings { pool_mode: pool_config.pool_mode, load_balancing_mode: pool_config.load_balancing_mode, @@ -283,17 +301,18 @@ impl ConnectionPool { healthcheck_timeout: config.general.healthcheck_timeout, ban_time: config.general.ban_time, }, + validated: Arc::new(AtomicBool::new(false)), + paused: Arc::new(AtomicBool::new(false)), + paused_waiter: Arc::new(Notify::new()), }; // Connect to the servers to make sure pool configuration is valid // before setting it globally. - match pool.validate().await { - Ok(_) => (), - Err(err) => { - error!("Could not validate connection pool: {:?}", err); - return Err(err); - } - }; + // Do this async and somewhere else, we don't have to wait here. + let mut validate_pool = pool.clone(); + tokio::task::spawn(async move { + let _ = validate_pool.validate().await; + }); // There is one pool per database/user pair. new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); @@ -311,49 +330,87 @@ impl ConnectionPool { /// when they connect. /// This also warms up the pool for clients that connect when /// the pooler starts up. 
- async fn validate(&mut self) -> Result<(), Error> { - let mut server_infos = Vec::new(); + pub async fn validate(&mut self) -> Result<(), Error> { + let mut futures = Vec::new(); + let validated = Arc::clone(&self.validated); + for shard in 0..self.shards() { for server in 0..self.servers(shard) { - let connection = match self.databases[shard][server].get().await { - Ok(conn) => conn, - Err(err) => { - error!("Shard {} down or misconfigured: {:?}", shard, err); - continue; - } - }; + let databases = self.databases.clone(); + let validated = Arc::clone(&validated); + let pool_server_info = Arc::clone(&self.server_info); - let proxy = connection; - let server = &*proxy; - let server_info = server.server_info(); + let task = tokio::task::spawn(async move { + let connection = match databases[shard][server].get().await { + Ok(conn) => conn, + Err(err) => { + error!("Shard {} down or misconfigured: {:?}", shard, err); + return; + } + }; - if !server_infos.is_empty() { - // Compare against the last server checked. - if server_info != server_infos[server_infos.len() - 1] { - warn!( - "{:?} has different server configuration than the last server", - proxy.address() - ); - } - } + let proxy = connection; + let server = &*proxy; + let server_info = server.server_info(); - server_infos.push(server_info); + let mut guard = pool_server_info.write(); + guard.clear(); + guard.put(server_info.clone()); + validated.store(true, Ordering::Relaxed); + }); + + futures.push(task); } } + futures::future::join_all(futures).await; + // TODO: compare server information to make sure // all shards are running identical configurations. - if server_infos.is_empty() { + if self.server_info.read().is_empty() { + error!("Could not validate connection pool"); return Err(Error::AllServersDown); } - // We're assuming all servers are identical. - // TODO: not true. - self.server_info = server_infos[0].clone(); - Ok(()) } + /// The pool can be used by clients. 
+ /// + /// If not, we need to validate it first by connecting to servers. + /// Call `validate()` to do so. + pub fn validated(&self) -> bool { + self.validated.load(Ordering::Relaxed) + } + + /// Pause the pool, allowing no more queries and make clients wait. + pub fn pause(&self) { + self.paused.store(true, Ordering::Relaxed); + } + + /// Resume the pool, allowing queries and resuming any pending queries. + pub fn resume(&self) { + self.paused.store(false, Ordering::Relaxed); + self.paused_waiter.notify_waiters(); + } + + /// Check if the pool is paused. + pub fn paused(&self) -> bool { + self.paused.load(Ordering::Relaxed) + } + + /// Check if the pool is paused and wait until it's resumed. + pub async fn wait_paused(&self) -> bool { + let waiter = self.paused_waiter.notified(); + let paused = self.paused.load(Ordering::Relaxed); + + if paused { + waiter.await; + } + + paused + } + /// Get a connection from the pool. pub async fn get( &self, @@ -624,7 +681,7 @@ impl ConnectionPool { } pub fn server_info(&self) -> BytesMut { - self.server_info.clone() + self.server_info.read().clone() } fn busy_connection_count(&self, address: &Address) -> u32 { diff --git a/src/server.rs b/src/server.rs index f2a6d38..9600485 100644 --- a/src/server.rs +++ b/src/server.rs @@ -546,6 +546,7 @@ impl Server { /// If the server is still inside a transaction. /// If the client disconnects while the server is in a transaction, we will clean it up. pub fn in_transaction(&self) -> bool { + debug!("Server in transaction: {}", self.in_transaction); self.in_transaction }