Startup improvements & PAUSE/RESUME (#300)

* Dont require servers to be online to start pooler

* PAUSE/RESUME

* fix

* Refresh pool

* Fixes

* lint
This commit is contained in:
Lev Kokotov
2023-01-28 15:36:35 -08:00
committed by GitHub
parent 2e3eb2663e
commit 24e79dcf05
7 changed files with 322 additions and 62 deletions

View File

@@ -1,7 +1,7 @@
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut;
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
@@ -9,8 +9,12 @@ use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::time::Instant;
use tokio::sync::Notify;
use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User};
use crate::errors::Error;
@@ -56,6 +60,12 @@ impl PoolIdentifier {
}
}
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
@@ -136,10 +146,18 @@ pub struct ConnectionPool {
/// The server information (K messages) have to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: BytesMut,
server_info: Arc<RwLock<BytesMut>>,
/// Pool configuration.
pub settings: PoolSettings,
/// If not validated, we need to double check the pool is available before allowing a client
/// to use it.
validated: Arc<AtomicBool>,
/// If the pool has been paused or not.
paused: Arc<AtomicBool>,
paused_waiter: Arc<Notify>,
}
impl ConnectionPool {
@@ -257,12 +275,12 @@ impl ConnectionPool {
assert_eq!(shards.len(), addresses.len());
let mut pool = ConnectionPool {
let pool = ConnectionPool {
databases: shards,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(),
server_info: BytesMut::new(),
server_info: Arc::new(RwLock::new(BytesMut::new())),
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
@@ -283,17 +301,18 @@ impl ConnectionPool {
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
@@ -311,49 +330,87 @@ impl ConnectionPool {
/// when they connect.
/// This also warms up the pool for clients that connect when
/// the pooler starts up.
async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new();
pub async fn validate(&mut self) -> Result<(), Error> {
let mut futures = Vec::new();
let validated = Arc::clone(&self.validated);
for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let connection = match self.databases[shard][server].get().await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
continue;
}
};
let databases = self.databases.clone();
let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info);
let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();
let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
return;
}
};
if !server_infos.is_empty() {
// Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] {
warn!(
"{:?} has different server configuration than the last server",
proxy.address()
);
}
}
let proxy = connection;
let server = &*proxy;
let server_info = server.server_info();
server_infos.push(server_info);
let mut guard = pool_server_info.write();
guard.clear();
guard.put(server_info.clone());
validated.store(true, Ordering::Relaxed);
});
futures.push(task);
}
}
futures::future::join_all(futures).await;
// TODO: compare server information to make sure
// all shards are running identical configurations.
if server_infos.is_empty() {
if self.server_info.read().is_empty() {
error!("Could not validate connection pool");
return Err(Error::AllServersDown);
}
// We're assuming all servers are identical.
// TODO: not true.
self.server_info = server_infos[0].clone();
Ok(())
}
/// The pool can be used by clients.
///
/// If not, we need to validate it first by connecting to servers.
/// Call `validate()` to do so.
pub fn validated(&self) -> bool {
self.validated.load(Ordering::Relaxed)
}
/// Pause the pool, allowing no more queries and make clients wait.
pub fn pause(&self) {
self.paused.store(true, Ordering::Relaxed);
}
/// Resume the pool, allowing queries and resuming any pending queries.
pub fn resume(&self) {
self.paused.store(false, Ordering::Relaxed);
self.paused_waiter.notify_waiters();
}
/// Check if the pool is paused.
pub fn paused(&self) -> bool {
self.paused.load(Ordering::Relaxed)
}
/// Check if the pool is paused and wait until it's resumed.
pub async fn wait_paused(&self) -> bool {
let waiter = self.paused_waiter.notified();
let paused = self.paused.load(Ordering::Relaxed);
if paused {
waiter.await;
}
paused
}
/// Get a connection from the pool.
pub async fn get(
&self,
@@ -624,7 +681,7 @@ impl ConnectionPool {
}
pub fn server_info(&self) -> BytesMut {
self.server_info.clone()
self.server_info.read().clone()
}
fn busy_connection_count(&self, address: &Address) -> u32 {