mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-26 10:26:30 +00:00
Live reloading entire config and bug fixes (#84)
* Support reloading the entire config (including sharding logic) without restart. * Fix bug incorrectly handing error reporting when the shard is set incorrectly via SET SHARD TO command. selected wrong shard and the connection keep reporting fatal #80. * Fix total_received and avg_recv admin database statistics. * Enabling the query parser by default. * More tests.
This commit is contained in:
102
src/pool.rs
102
src/pool.rs
@@ -1,9 +1,10 @@
|
||||
/// Pooling, failover and banlist.
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use bytes::BytesMut;
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
@@ -12,28 +13,47 @@ use std::time::Instant;
|
||||
use crate::config::{get_config, Address, Role, User};
|
||||
use crate::errors::Error;
|
||||
use crate::server::Server;
|
||||
use crate::stats::Reporter;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
|
||||
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
|
||||
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
|
||||
|
||||
/// The connection pool, globally available.
|
||||
/// This is atomic and safe and read-optimized.
|
||||
/// The pool is recreated dynamically when the config is reloaded.
|
||||
pub static POOL: Lazy<ArcSwap<ConnectionPool>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
|
||||
|
||||
/// The globally accessible connection pool.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ConnectionPool {
|
||||
/// The pools handled internally by bb8.
|
||||
databases: Vec<Vec<Pool<ServerPool>>>,
|
||||
|
||||
/// The addresses (host, port, role) to handle
|
||||
/// failover and load balancing deterministically.
|
||||
addresses: Vec<Vec<Address>>,
|
||||
round_robin: usize,
|
||||
|
||||
/// List of banned addresses (see above)
|
||||
/// that should not be queried.
|
||||
banlist: BanList,
|
||||
|
||||
/// The statistics aggregator runs in a separate task
|
||||
/// and receives stats from clients, servers, and the pool.
|
||||
stats: Reporter,
|
||||
|
||||
/// The server information (K messages) have to be passed to the
|
||||
/// clients on startup. We pre-connect to all shards and replicas
|
||||
/// on pool creation and save the K messages here.
|
||||
server_info: BytesMut,
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
/// Construct the connection pool from the configuration.
|
||||
pub async fn from_config(
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Reporter,
|
||||
) -> ConnectionPool {
|
||||
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
|
||||
let reporter = get_reporter();
|
||||
let config = get_config();
|
||||
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
@@ -44,6 +64,8 @@ impl ConnectionPool {
|
||||
.into_keys()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
|
||||
for shard_idx in shard_ids {
|
||||
@@ -82,7 +104,7 @@ impl ConnectionPool {
|
||||
config.user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
stats.clone(),
|
||||
reporter.clone(),
|
||||
);
|
||||
|
||||
let pool = Pool::builder()
|
||||
@@ -105,15 +127,28 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
let address_len = addresses.len();
|
||||
|
||||
ConnectionPool {
|
||||
let mut pool = ConnectionPool {
|
||||
databases: shards,
|
||||
addresses: addresses,
|
||||
round_robin: rand::random::<usize>() % address_len, // Start at a random replica
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
stats: stats,
|
||||
}
|
||||
stats: reporter,
|
||||
server_info: BytesMut::new(),
|
||||
};
|
||||
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
match pool.validate().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Could not validate connection pool: {:?}", err);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
POOL.store(Arc::new(pool.clone()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect to all shards and grab server information.
|
||||
@@ -121,16 +156,18 @@ impl ConnectionPool {
|
||||
/// when they connect.
|
||||
/// This also warms up the pool for clients that connect when
|
||||
/// the pooler starts up.
|
||||
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
|
||||
async fn validate(&mut self) -> Result<(), Error> {
|
||||
let mut server_infos = Vec::new();
|
||||
|
||||
let stats = self.stats.clone();
|
||||
|
||||
for shard in 0..self.shards() {
|
||||
let mut round_robin = 0;
|
||||
|
||||
for _ in 0..self.servers(shard) {
|
||||
// To keep stats consistent.
|
||||
let fake_process_id = 0;
|
||||
|
||||
let connection = match self.get(shard, None, fake_process_id).await {
|
||||
let connection = match self.get(shard, None, fake_process_id, round_robin).await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Shard {} down or misconfigured: {:?}", shard, err);
|
||||
@@ -138,10 +175,9 @@ impl ConnectionPool {
|
||||
}
|
||||
};
|
||||
|
||||
let mut proxy = connection.0;
|
||||
let proxy = connection.0;
|
||||
let address = connection.1;
|
||||
let server = &mut *proxy;
|
||||
|
||||
let server = &*proxy;
|
||||
let server_info = server.server_info();
|
||||
|
||||
stats.client_disconnecting(fake_process_id, address.id);
|
||||
@@ -157,6 +193,7 @@ impl ConnectionPool {
|
||||
}
|
||||
|
||||
server_infos.push(server_info);
|
||||
round_robin += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,15 +203,18 @@ impl ConnectionPool {
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
|
||||
Ok(server_infos[0].clone())
|
||||
self.server_info = server_infos[0].clone();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a connection from the pool.
|
||||
pub async fn get(
|
||||
&mut self,
|
||||
shard: usize,
|
||||
role: Option<Role>,
|
||||
process_id: i32,
|
||||
shard: usize, // shard number
|
||||
role: Option<Role>, // primary or replica
|
||||
process_id: i32, // client id
|
||||
mut round_robin: usize, // round robin offset
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
let now = Instant::now();
|
||||
let addresses = &self.addresses[shard];
|
||||
@@ -204,9 +244,9 @@ impl ConnectionPool {
|
||||
|
||||
while allowed_attempts > 0 {
|
||||
// Round-robin replicas.
|
||||
self.round_robin += 1;
|
||||
round_robin += 1;
|
||||
|
||||
let index = self.round_robin % addresses.len();
|
||||
let index = round_robin % addresses.len();
|
||||
let address = &addresses[index];
|
||||
|
||||
// Make sure you're getting a primary or a replica
|
||||
@@ -218,6 +258,7 @@ impl ConnectionPool {
|
||||
|
||||
allowed_attempts -= 1;
|
||||
|
||||
// Don't attempt to connect to banned servers.
|
||||
if self.is_banned(address, shard, role) {
|
||||
continue;
|
||||
}
|
||||
@@ -390,6 +431,10 @@ impl ConnectionPool {
|
||||
pub fn address(&self, shard: usize, server: usize) -> &Address {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for the bb8 connection pool.
|
||||
@@ -470,3 +515,8 @@ impl ManageConnection for ServerPool {
|
||||
conn.is_bad()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the connection pool
|
||||
pub fn get_pool() -> ConnectionPool {
|
||||
(*(*POOL.load())).clone()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user