Live reloading entire config and bug fixes (#84)

* Support reloading the entire config (including sharding logic) without restart.

* Fix incorrect handling of error reporting when the shard is set incorrectly via the SET SHARD TO command:
the wrong shard stayed selected and the connection kept reporting a fatal error (#80).

* Fix total_received and avg_recv admin database statistics.

* Enabling the query parser by default.

* More tests.
This commit is contained in:
Lev Kokotov
2022-06-24 14:52:38 -07:00
committed by GitHub
parent d865d9f9d8
commit b93303eb83
14 changed files with 393 additions and 188 deletions

View File

@@ -87,7 +87,7 @@ default_role = "any"
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = false
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write

View File

@@ -42,7 +42,15 @@ pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
# Query cancellation test
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1
killall psql -s SIGINT
# Reload pool (closing unused server connections)
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD'
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1
killall psql -s SIGINT
# Sharding insert
@@ -94,7 +102,7 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica
start_pgcat "info"
# Test session mode (and config reload)
sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml
sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml
# Reload config test
kill -SIGHUP $(pgrep pgcat)

2
Cargo.lock generated
View File

@@ -368,7 +368,7 @@ dependencies = [
[[package]]
name = "pgcat"
version = "0.2.0-beta1"
version = "0.4.0-beta1"
dependencies = [
"arc-swap",
"async-trait",

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "0.2.1-beta1"
version = "0.4.0-beta1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -87,7 +87,7 @@ default_role = "any"
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = false
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write

View File

@@ -4,17 +4,19 @@ use log::{info, trace};
use std::collections::HashMap;
use tokio::net::tcp::OwnedWriteHalf;
use crate::config::{get_config, parse};
use crate::config::{get_config, reload_config};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ConnectionPool;
use crate::stats::get_stats;
use crate::ClientServerMap;
/// Handle admin client.
pub async fn handle_admin(
stream: &mut OwnedWriteHalf,
mut query: BytesMut,
pool: ConnectionPool,
client_server_map: ClientServerMap,
) -> Result<(), Error> {
let code = query.get_u8() as char;
@@ -34,7 +36,7 @@ pub async fn handle_admin(
show_stats(stream, &pool).await
} else if query.starts_with("RELOAD") {
trace!("RELOAD");
reload(stream).await
reload(stream, client_server_map).await
} else if query.starts_with("SHOW CONFIG") {
trace!("SHOW CONFIG");
show_config(stream).await
@@ -143,10 +145,7 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
/// Show utilization of connection pools for each shard and replicas.
async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats();
let config = {
let guard = get_config();
&*guard.clone()
};
let config = get_config();
let columns = vec![
("database", DataType::Text),
@@ -199,9 +198,7 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
/// Show shards and replicas.
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
drop(guard);
let config = get_config();
// Columns
let columns = vec![
@@ -266,17 +263,15 @@ async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
}
/// Reload the configuration file without restarting the process.
async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
async fn reload(
stream: &mut OwnedWriteHalf,
client_server_map: ClientServerMap,
) -> Result<(), Error> {
info!("Reloading config");
let config = get_config();
let path = config.path.clone().unwrap();
reload_config(client_server_map).await?;
parse(&path).await?;
let config = get_config();
config.show();
get_config().show();
let mut res = BytesMut::new();
@@ -292,10 +287,8 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
/// Shows current configuration.
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
let config = &get_config();
let config: HashMap<String, String> = config.into();
drop(guard);
// Configs that cannot be changed without restarting.
let immutables = ["host", "port", "connect_timeout"];

View File

@@ -13,10 +13,10 @@ use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::Reporter;
use crate::stats::{get_reporter, Reporter};
/// The client state. One of these is created per client.
pub struct Client {
@@ -69,12 +69,11 @@ impl Client {
pub async fn startup(
mut stream: TcpStream,
client_server_map: ClientServerMap,
server_info: BytesMut,
stats: Reporter,
) -> Result<Client, Error> {
let config = get_config().clone();
let transaction_mode = config.general.pool_mode.starts_with("t");
// drop(config);
let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter();
loop {
trace!("Waiting for StartupMessage");
@@ -154,9 +153,10 @@ impl Client {
debug!("Password authentication successful");
auth_ok(&mut stream).await?;
write_all(&mut stream, server_info).await?;
write_all(&mut stream, get_pool().server_info()).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?;
trace!("Startup OK");
let database = parameters
@@ -221,7 +221,7 @@ impl Client {
}
/// Handle a connected and authenticated client.
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
trace!("Sending CancelRequest");
@@ -252,13 +252,19 @@ impl Client {
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
}
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
let mut round_robin = 0;
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
loop {
trace!("Client idle, waiting for message");
trace!(
"Client idle, waiting for message, transaction mode: {}",
self.transaction_mode
);
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
@@ -267,32 +273,63 @@ impl Client {
// SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?;
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = get_pool();
// Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' {
trace!("Client disconnecting");
debug!("Client disconnecting");
return Ok(());
}
// Handle admin database queries.
if self.admin {
trace!("Handling admin command");
handle_admin(&mut self.write, message, pool.clone()).await?;
debug!("Handling admin command");
handle_admin(
&mut self.write,
message,
pool.clone(),
self.client_server_map.clone(),
)
.await?;
continue;
}
let current_shard = query_router.shard();
// Handle all custom protocol commands, if any.
match query_router.try_execute_command(message.clone()) {
// Normal query, not a custom command.
None => {
// Attempt to infer which server we want to query, i.e. primary or replica.
if query_router.query_parser_enabled() && query_router.role() == None {
query_router.infer_role(message.clone());
}
}
None => (),
// SET SHARD TO
Some((Command::SetShard, _)) => {
// Selected shard is not configured.
if query_router.shard() >= pool.shards() {
// Set the shard back to what it was.
query_router.set_shard(current_shard);
error_response(
&mut self.write,
&format!(
"shard {} is more than configured {}, staying on shard {}",
query_router.shard(),
pool.shards(),
current_shard,
),
)
.await?;
} else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
}
continue;
}
// SET PRIMARY READS TO
Some((Command::SetPrimaryReads, _)) => {
custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?;
continue;
}
@@ -319,27 +356,24 @@ impl Client {
show_response(&mut self.write, "shard", &value).await?;
continue;
}
};
// Make sure we selected a valid shard.
if query_router.shard() >= pool.shards() {
error_response(
&mut self.write,
&format!(
"shard {} is more than configured {}",
query_router.shard(),
pool.shards()
),
)
.await?;
// SHOW PRIMARY READS
Some((Command::ShowPrimaryReads, value)) => {
show_response(&mut self.write, "primary reads", &value).await?;
continue;
}
};
debug!("Waiting for connection from pool");
// Grab a server from the pool.
let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.get(
query_router.shard(),
query_router.role(),
self.process_id,
round_robin,
)
.await
{
Ok(conn) => {
@@ -358,6 +392,8 @@ impl Client {
let address = connection.1;
let server = &mut *reference;
round_robin += 1;
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);

View File

@@ -1,5 +1,5 @@
/// Parse the configuration file.
use arc_swap::{ArcSwap, Guard};
use arc_swap::ArcSwap;
use log::{error, info};
use once_cell::sync::Lazy;
use serde_derive::Deserialize;
@@ -10,6 +10,7 @@ use tokio::io::AsyncReadExt;
use toml;
use crate::errors::Error;
use crate::{ClientServerMap, ConnectionPool};
/// Globally available configuration.
static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default()));
@@ -126,7 +127,7 @@ impl Default for General {
}
/// Shard configuration.
#[derive(Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Shard {
pub servers: Vec<(String, u16, String)>,
pub database: String,
@@ -161,10 +162,16 @@ impl Default for QueryRouter {
}
}
/// Fallback location of the configuration file.
///
/// Referenced by the `#[serde(default = "default_path")]` attribute on
/// `Config::path`, so a config file that omits `path` still deserializes.
fn default_path() -> String {
    "pgcat.toml".to_owned()
}
/// Configuration wrapper.
#[derive(Deserialize, Debug, Clone)]
pub struct Config {
pub path: Option<String>,
#[serde(default = "default_path")]
pub path: String,
pub general: General,
pub user: User,
pub shards: HashMap<String, Shard>,
@@ -174,7 +181,7 @@ pub struct Config {
impl Default for Config {
fn default() -> Config {
Config {
path: Some(String::from("pgcat.toml")),
path: String::from("pgcat.toml"),
general: General::default(),
user: User::default(),
shards: HashMap::from([(String::from("1"), Shard::default())]),
@@ -237,6 +244,8 @@ impl Config {
);
info!("Connection timeout: {}ms", self.general.connect_timeout);
info!("Sharding function: {}", self.query_router.sharding_function);
info!("Primary reads: {}", self.query_router.primary_reads_enabled);
info!("Query router: {}", self.query_router.query_parser_enabled);
info!("Number of shards: {}", self.shards.len());
}
}
@@ -244,8 +253,8 @@ impl Config {
/// Get a read-only instance of the configuration
/// from anywhere in the app.
/// ArcSwap makes this cheap and quick.
pub fn get_config() -> Guard<Arc<Config>> {
CONFIG.load()
pub fn get_config() -> Config {
(*(*CONFIG.load())).clone()
}
/// Parse the configuration file located at the path.
@@ -357,7 +366,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
config.path = Some(path.to_string());
config.path = path.to_string();
// Update the configuration globally.
CONFIG.store(Arc::new(config.clone()));
@@ -365,6 +374,27 @@ pub async fn parse(path: &str) -> Result<(), Error> {
Ok(())
}
/// Re-read the configuration file and, if the shard layout changed,
/// rebuild the server connection pools.
///
/// The file is re-parsed from the path recorded in the currently loaded
/// configuration. `parse` stores the new configuration globally on success.
///
/// # Errors
///
/// Returns `Error::BadConfig` when the file cannot be parsed (the underlying
/// error is logged), or whatever error `ConnectionPool::from_config` reports
/// when the pools have to be re-created.
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<(), Error> {
    // Snapshot the configuration before re-parsing so changes can be detected.
    let previous = get_config();

    if let Err(err) = parse(&previous.path).await {
        error!("Config reload error: {:?}", err);
        return Err(Error::BadConfig);
    }

    let current = get_config();

    // Only a change in the shard definitions requires new server pools.
    if previous.shards == current.shards {
        return Ok(());
    }

    info!("Sharding configuration changed, re-creating server pools");
    ConnectionPool::from_config(client_server_map).await
}
#[cfg(test)]
mod test {
use super::*;
@@ -377,6 +407,6 @@ mod test {
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
assert_eq!(get_config().shards["0"].servers[0].2, "primary");
assert_eq!(get_config().query_router.default_role, "any");
assert_eq!(get_config().path, Some("pgcat.toml".to_string()));
assert_eq!(get_config().path, "pgcat.toml".to_string());
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Lev Kokotov <lev@levthe.dev>
// Copyright (c) 2022 Lev Kokotov <hi@levthe.dev>
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
@@ -34,7 +34,7 @@ extern crate sqlparser;
extern crate tokio;
extern crate toml;
use log::{error, info};
use log::{debug, error, info};
use parking_lot::Mutex;
use tokio::net::TcpListener;
use tokio::{
@@ -59,9 +59,9 @@ mod server;
mod sharding;
mod stats;
use config::get_config;
use pool::{ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter};
use config::{get_config, reload_config};
use pool::{get_pool, ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter, REPORTER};
#[tokio::main(worker_threads = 4)]
async fn main() {
@@ -109,37 +109,39 @@ async fn main() {
// Statistics reporting.
let (tx, rx) = mpsc::channel(100);
REPORTER.store(Arc::new(Reporter::new(tx.clone())));
// Connection pool that allows to query all shards and replicas.
let mut pool =
ConnectionPool::from_config(client_server_map.clone(), Reporter::new(tx.clone())).await;
match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (),
Err(err) => {
error!("Pool error: {:?}", err);
return;
}
};
let pool = get_pool();
// Statistics collector task.
let collector_tx = tx.clone();
// Save these for reloading
let reload_client_server_map = client_server_map.clone();
let addresses = pool.databases();
tokio::task::spawn(async move {
let mut stats_collector = Collector::new(rx, collector_tx);
stats_collector.collect(addresses).await;
});
// Connect to all servers and validate their versions.
let server_info = match pool.validate().await {
Ok(info) => info,
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return;
}
};
info!("Waiting for clients");
drop(pool);
// Client connection loop.
tokio::task::spawn(async move {
loop {
let pool = pool.clone();
let client_server_map = client_server_map.clone();
let server_info = server_info.clone();
let reporter = Reporter::new(tx.clone());
let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr),
@@ -152,12 +154,11 @@ async fn main() {
// Handle client.
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::Client::startup(socket, client_server_map, server_info, reporter)
.await
{
match client::Client::startup(socket, client_server_map).await {
Ok(mut client) => {
info!("Client {:?} connected", addr);
match client.handle(pool).await {
match client.handle().await {
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
@@ -176,7 +177,7 @@ async fn main() {
}
Err(err) => {
error!("Client failed to login: {:?}", err);
debug!("Client failed to login: {:?}", err);
}
};
});
@@ -190,16 +191,15 @@ async fn main() {
loop {
stream.recv().await;
info!("Reloading config");
match config::parse("pgcat.toml").await {
Ok(_) => {
get_config().show();
}
Err(err) => {
error!("{:?}", err);
return;
}
match reload_config(reload_client_server_map.clone()).await {
Ok(_) => (),
Err(_) => continue,
};
get_config().show();
}
});

View File

@@ -1,9 +1,10 @@
/// Pooling, failover and banlist.
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut;
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::Arc;
@@ -12,28 +13,47 @@ use std::time::Instant;
use crate::config::{get_config, Address, Role, User};
use crate::errors::Error;
use crate::server::Server;
use crate::stats::Reporter;
use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOL: Lazy<ArcSwap<ConnectionPool>> =
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
/// The globally accessible connection pool.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct ConnectionPool {
/// The pools handled internally by bb8.
databases: Vec<Vec<Pool<ServerPool>>>,
/// The addresses (host, port, role) to handle
/// failover and load balancing deterministically.
addresses: Vec<Vec<Address>>,
round_robin: usize,
/// List of banned addresses (see above)
/// that should not be queried.
banlist: BanList,
/// The statistics aggregator runs in a separate task
/// and receives stats from clients, servers, and the pool.
stats: Reporter,
/// The server information (K messages) have to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: BytesMut,
}
impl ConnectionPool {
/// Construct the connection pool from the configuration.
pub async fn from_config(
client_server_map: ClientServerMap,
stats: Reporter,
) -> ConnectionPool {
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let reporter = get_reporter();
let config = get_config();
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
@@ -44,6 +64,8 @@ impl ConnectionPool {
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids {
@@ -82,7 +104,7 @@ impl ConnectionPool {
config.user.clone(),
&shard.database,
client_server_map.clone(),
stats.clone(),
reporter.clone(),
);
let pool = Pool::builder()
@@ -105,15 +127,28 @@ impl ConnectionPool {
}
assert_eq!(shards.len(), addresses.len());
let address_len = addresses.len();
ConnectionPool {
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
round_robin: rand::random::<usize>() % address_len, // Start at a random replica
banlist: Arc::new(RwLock::new(banlist)),
stats: stats,
stats: reporter,
server_info: BytesMut::new(),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
POOL.store(Arc::new(pool.clone()));
Ok(())
}
/// Connect to all shards and grab server information.
@@ -121,16 +156,18 @@ impl ConnectionPool {
/// when they connect.
/// This also warms up the pool for clients that connect when
/// the pooler starts up.
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new();
let stats = self.stats.clone();
for shard in 0..self.shards() {
let mut round_robin = 0;
for _ in 0..self.servers(shard) {
// To keep stats consistent.
let fake_process_id = 0;
let connection = match self.get(shard, None, fake_process_id).await {
let connection = match self.get(shard, None, fake_process_id, round_robin).await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -138,10 +175,9 @@ impl ConnectionPool {
}
};
let mut proxy = connection.0;
let proxy = connection.0;
let address = connection.1;
let server = &mut *proxy;
let server = &*proxy;
let server_info = server.server_info();
stats.client_disconnecting(fake_process_id, address.id);
@@ -157,6 +193,7 @@ impl ConnectionPool {
}
server_infos.push(server_info);
round_robin += 1;
}
}
@@ -166,15 +203,18 @@ impl ConnectionPool {
return Err(Error::AllServersDown);
}
Ok(server_infos[0].clone())
self.server_info = server_infos[0].clone();
Ok(())
}
/// Get a connection from the pool.
pub async fn get(
&mut self,
shard: usize,
role: Option<Role>,
process_id: i32,
shard: usize, // shard number
role: Option<Role>, // primary or replica
process_id: i32, // client id
mut round_robin: usize, // round robin offset
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now();
let addresses = &self.addresses[shard];
@@ -204,9 +244,9 @@ impl ConnectionPool {
while allowed_attempts > 0 {
// Round-robin replicas.
self.round_robin += 1;
round_robin += 1;
let index = self.round_robin % addresses.len();
let index = round_robin % addresses.len();
let address = &addresses[index];
// Make sure you're getting a primary or a replica
@@ -218,6 +258,7 @@ impl ConnectionPool {
allowed_attempts -= 1;
// Don't attempt to connect to banned servers.
if self.is_banned(address, shard, role) {
continue;
}
@@ -390,6 +431,10 @@ impl ConnectionPool {
pub fn address(&self, shard: usize, server: usize) -> &Address {
&self.addresses[shard][server]
}
pub fn server_info(&self) -> BytesMut {
self.server_info.clone()
}
}
/// Wrapper for the bb8 connection pool.
@@ -470,3 +515,8 @@ impl ManageConnection for ServerPool {
conn.is_bad()
}
}
/// Get the connection pool
pub fn get_pool() -> ConnectionPool {
(*(*POOL.load())).clone()
}

View File

@@ -12,12 +12,14 @@ use crate::config::{get_config, Role};
use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 5] = [
const CUSTOM_SQL_REGEXES: [&str; 7] = [
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
r"(?i)^ *SHOW SHARD *;? *$",
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
r"(?i)^ *SHOW SERVER ROLE *;? *$",
r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$",
r"(?i)^ *SHOW PRIMARY READS *;? *$",
];
/// Custom commands.
@@ -28,6 +30,8 @@ pub enum Command {
ShowShard,
SetServerRole,
ShowServerRole,
SetPrimaryReads,
ShowPrimaryReads,
}
/// Quickly test for match when a query is received.
@@ -38,27 +42,17 @@ static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router.
pub struct QueryRouter {
/// By default, queries go here, unless we have better information
/// about what the client wants.
default_server_role: Option<Role>,
/// Number of shards in the cluster.
shards: usize,
/// Which shard we should be talking to right now.
active_shard: Option<usize>,
/// Which server should we be talking to.
active_role: Option<Role>,
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
/// Should we try to parse queries to route them to replicas or primary automatically.
/// Should we try to parse queries to route them to replicas or primary automatically
query_parser_enabled: bool,
/// Which sharding function we're using.
sharding_function: ShardingFunction,
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
}
impl QueryRouter {
@@ -97,28 +91,11 @@ impl QueryRouter {
pub fn new() -> QueryRouter {
let config = get_config();
let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let sharding_function = match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
QueryRouter {
default_server_role: default_server_role,
shards: config.shards.len(),
active_role: default_server_role,
active_shard: None,
primary_reads_enabled: config.query_router.primary_reads_enabled,
active_role: None,
query_parser_enabled: config.query_router.query_parser_enabled,
sharding_function,
primary_reads_enabled: config.query_router.primary_reads_enabled,
}
}
@@ -146,21 +123,48 @@ impl QueryRouter {
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
// This is not a custom query, try to infer which
// server it'll go to if the query parser is enabled.
if matches.len() != 1 {
debug!("Regular query");
if self.query_parser_enabled && self.role() == None {
debug!("Inferring role");
self.infer_role(buf.clone());
}
return None;
}
let config = get_config();
let sharding_function = match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let command = match matches[0] {
0 => Command::SetShardingKey,
1 => Command::SetShard,
2 => Command::ShowShard,
3 => Command::SetServerRole,
4 => Command::ShowServerRole,
5 => Command::SetPrimaryReads,
6 => Command::ShowPrimaryReads,
_ => unreachable!(),
};
let mut value = match command {
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => {
Command::SetShardingKey
| Command::SetShard
| Command::SetServerRole
| Command::SetPrimaryReads => {
// Capture value. I know this re-runs the regex engine, but I haven't
// figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
@@ -187,11 +191,16 @@ impl QueryRouter {
}
}
},
Command::ShowPrimaryReads => match self.primary_reads_enabled {
true => String::from("on"),
false => String::from("off"),
},
};
match command {
Command::SetShardingKey => {
let sharder = Sharder::new(self.shards, self.sharding_function);
let sharder = Sharder::new(config.shards.len(), sharding_function);
let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard);
value = shard.to_string();
@@ -199,7 +208,7 @@ impl QueryRouter {
Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % self.shards),
"ANY" => Some(rand::random::<usize>() % config.shards.len()),
_ => Some(value.parse::<usize>().unwrap()),
};
}
@@ -227,8 +236,8 @@ impl QueryRouter {
}
"default" => {
self.active_role = self.default_server_role;
self.query_parser_enabled = get_config().query_router.query_parser_enabled;
self.active_role = default_server_role;
self.query_parser_enabled = config.query_router.query_parser_enabled;
self.active_role
}
@@ -236,6 +245,19 @@ impl QueryRouter {
};
}
Command::SetPrimaryReads => {
if value == "on" {
debug!("Setting primary reads to on");
self.primary_reads_enabled = true;
} else if value == "off" {
debug!("Setting primary reads to off");
self.primary_reads_enabled = false;
} else if value == "default" {
debug!("Setting primary reads to default");
self.primary_reads_enabled = config.query_router.primary_reads_enabled;
}
}
_ => (),
}
@@ -330,23 +352,15 @@ impl QueryRouter {
}
}
/// Reset the router back to defaults.
/// This must be called at the end of every transaction in transaction mode.
pub fn _reset(&mut self) {
self.active_role = self.default_server_role;
self.active_shard = None;
pub fn set_shard(&mut self, shard: usize) {
self.active_shard = Some(shard);
}
/// Should we attempt to parse queries?
#[allow(dead_code)]
pub fn query_parser_enabled(&self) -> bool {
self.query_parser_enabled
}
/// Allows to toggle primary reads in tests.
#[allow(dead_code)]
pub fn toggle_primary_reads(&mut self, value: bool) {
self.primary_reads_enabled = value;
}
}
#[cfg(test)]
@@ -369,7 +383,8 @@ mod test {
let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true);
qr.toggle_primary_reads(false);
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
@@ -410,7 +425,7 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
qr.toggle_primary_reads(true);
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer_role(query));
assert_eq!(qr.role(), None);
@@ -421,7 +436,7 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
qr.toggle_primary_reads(false);
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -450,6 +465,10 @@ mod test {
"SET SERVER ROLE TO 'any'",
"SET SERVER ROLE TO 'auto'",
"SHOW SERVER ROLE",
"SET PRIMARY READS TO 'on'",
"SET PRIMARY READS TO 'off'",
"SET PRIMARY READS TO 'default'",
"SHOW PRIMARY READS",
// Lower case
"set sharding key to '1'",
"set shard to '1'",
@@ -459,9 +478,13 @@ mod test {
"set server role to 'any'",
"set server role to 'auto'",
"show server role",
"set primary reads to 'on'",
"set primary reads to 'OFF'",
"set primary reads to 'deFaUlt'",
// No quotes
"SET SHARDING KEY TO 11235",
"SET SHARD TO 15",
"SET PRIMARY READS TO off",
// Spaces and semicolon
" SET SHARDING KEY TO 11235 ; ",
" SET SHARD TO 15; ",
@@ -469,18 +492,23 @@ mod test {
" SET SERVER ROLE TO 'primary'; ",
" SET SERVER ROLE TO 'primary' ; ",
" SET SERVER ROLE TO 'primary' ;",
" SET PRIMARY READS TO 'off' ;",
];
// Which regexes it'll match to in the list
let matches = [
0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3,
0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0,
3, 3, 3, 5,
];
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
for (i, test) in tests.iter().enumerate() {
assert!(list[matches[i]].is_match(test));
if !list[matches[i]].is_match(test) {
println!("{} does not match {}", test, list[matches[i]]);
assert!(false);
}
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
}
@@ -549,6 +577,26 @@ mod test {
Some((Command::ShowServerRole, String::from(*role)))
);
}
let primary_reads = ["on", "off", "default"];
let primary_reads_enabled = ["on", "off", "on"];
for (idx, primary_reads) in primary_reads.iter().enumerate() {
assert_eq!(
qr.try_execute_command(simple_query(&format!(
"SET PRIMARY READS TO {}",
primary_reads
))),
Some((Command::SetPrimaryReads, String::from(*primary_reads)))
);
assert_eq!(
qr.try_execute_command(simple_query("SHOW PRIMARY READS")),
Some((
Command::ShowPrimaryReads,
String::from(primary_reads_enabled[idx])
))
);
}
}
#[test]
@@ -556,7 +604,7 @@ mod test {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'");
qr.toggle_primary_reads(false);
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
@@ -573,6 +621,6 @@ mod test {
assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(query) != None);
assert!(!qr.query_parser_enabled());
assert!(qr.query_parser_enabled());
}
}

View File

@@ -1,9 +1,13 @@
use arc_swap::ArcSwap;
/// Statistics and reporting.
use log::info;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::collections::HashMap;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::mpsc::{channel, Receiver, Sender};
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
Lazy::new(|| ArcSwap::from_pointee(Reporter::default()));
/// Latest stats updated every second; used in SHOW STATS and other admin commands.
static LATEST_STATS: Lazy<Mutex<HashMap<usize, HashMap<String, i64>>>> =
@@ -60,6 +64,13 @@ pub struct Reporter {
tx: Sender<Event>,
}
/// Placeholder reporter used to initialize the global `REPORTER`.
impl Default for Reporter {
    /// Build a reporter around a fresh channel of capacity 5.
    ///
    /// The receiving half is dropped when this function returns, so nothing
    /// consumes events sent through a default reporter.
    fn default() -> Reporter {
        let (sender, _receiver) = channel(5);
        Reporter { tx: sender }
    }
}
impl Reporter {
/// Create a new Reporter instance.
pub fn new(tx: Sender<Event>) -> Reporter {
@@ -289,7 +300,7 @@ impl Collector {
("avg_query_time", 0),
("avg_xact_count", 0),
("avg_sent", 0),
("avg_received", 0),
("avg_recv", 0),
("avg_wait_time", 0),
("maxwait_us", 0),
("maxwait", 0),
@@ -493,10 +504,14 @@ impl Collector {
"avg_query_count",
"avgxact_count",
"avg_sent",
"avg_received",
"avg_recv",
"avg_wait_time",
] {
let total_name = stat.replace("avg_", "total_");
let total_name = match stat {
&"avg_recv" => "total_received".to_string(), // Because PgBouncer is saving bytes
_ => stat.replace("avg_", "total_"),
};
let old_value = old_stats.entry(total_name.clone()).or_insert(0);
let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned();
let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second
@@ -515,3 +530,8 @@ impl Collector {
pub fn get_stats() -> HashMap<usize, HashMap<String, i64>> {
LATEST_STATS.lock().clone()
}
/// Get the statistics reporter used to update stats across the pools/clients.
pub fn get_reporter() -> Reporter {
(*(*REPORTER.load())).clone()
}

View File

@@ -12,6 +12,8 @@
SET SHARD TO :shard;
SET SERVER ROLE TO 'auto';
BEGIN;
UPDATE pgbench_accounts SET abalance = abalance + :delta WHERE aid = :aid;
@@ -26,3 +28,12 @@ INSERT INTO pgbench_history (tid, bid, aid, delta, mtime) VALUES (:tid, :bid, :a
END;
SET SHARDING KEY TO :aid;
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;

View File

@@ -151,3 +151,12 @@ SELECT 1;
set server role to 'replica';
SeT SeRver Role TO 'PrImARY';
select 1;
SET PRIMARY READS TO 'on';
SELECT 1;
SET PRIMARY READS TO 'off';
SELECT 1;
SET PRIMARY READS TO 'default';
SELECT 1;