Live reloading entire config and bug fixes (#84)

* Support reloading the entire config (including sharding logic) without restart.

* Fix a bug that incorrectly handled error reporting when the shard was set incorrectly via the SET SHARD TO command:
selecting a wrong shard caused the connection to keep reporting a fatal error (#80).

* Fix total_received and avg_recv admin database statistics.

* Enable the query parser by default.

* More tests.
This commit is contained in:
Lev Kokotov
2022-06-24 14:52:38 -07:00
committed by GitHub
parent d865d9f9d8
commit b93303eb83
14 changed files with 393 additions and 188 deletions

View File

@@ -87,7 +87,7 @@ default_role = "any"
# every incoming query to determine if it's a read or a write. # every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary. # we'll direct it to the primary.
query_parser_enabled = false query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write # load balancing of read queries. Otherwise, the primary will only be used for write

View File

@@ -42,7 +42,15 @@ pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
# Query cancellation test # Query cancellation test
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) & (psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1
killall psql -s SIGINT
# Reload pool (closing unused server connections)
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD'
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1
killall psql -s SIGINT killall psql -s SIGINT
# Sharding insert # Sharding insert
@@ -94,7 +102,7 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica
start_pgcat "info" start_pgcat "info"
# Test session mode (and config reload) # Test session mode (and config reload)
sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml
# Reload config test # Reload config test
kill -SIGHUP $(pgrep pgcat) kill -SIGHUP $(pgrep pgcat)

2
Cargo.lock generated
View File

@@ -368,7 +368,7 @@ dependencies = [
[[package]] [[package]]
name = "pgcat" name = "pgcat"
version = "0.2.0-beta1" version = "0.4.0-beta1"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"async-trait", "async-trait",

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "pgcat" name = "pgcat"
version = "0.2.1-beta1" version = "0.4.0-beta1"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -28,4 +28,4 @@ parking_lot = "0.11"
hmac = "0.12" hmac = "0.12"
sha2 = "0.10" sha2 = "0.10"
base64 = "0.13" base64 = "0.13"
stringprep = "0.1" stringprep = "0.1"

View File

@@ -87,7 +87,7 @@ default_role = "any"
# every incoming query to determine if it's a read or a write. # every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary. # we'll direct it to the primary.
query_parser_enabled = false query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write # load balancing of read queries. Otherwise, the primary will only be used for write

View File

@@ -4,17 +4,19 @@ use log::{info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::net::tcp::OwnedWriteHalf; use tokio::net::tcp::OwnedWriteHalf;
use crate::config::{get_config, parse}; use crate::config::{get_config, reload_config};
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::ConnectionPool; use crate::pool::ConnectionPool;
use crate::stats::get_stats; use crate::stats::get_stats;
use crate::ClientServerMap;
/// Handle admin client. /// Handle admin client.
pub async fn handle_admin( pub async fn handle_admin(
stream: &mut OwnedWriteHalf, stream: &mut OwnedWriteHalf,
mut query: BytesMut, mut query: BytesMut,
pool: ConnectionPool, pool: ConnectionPool,
client_server_map: ClientServerMap,
) -> Result<(), Error> { ) -> Result<(), Error> {
let code = query.get_u8() as char; let code = query.get_u8() as char;
@@ -34,7 +36,7 @@ pub async fn handle_admin(
show_stats(stream, &pool).await show_stats(stream, &pool).await
} else if query.starts_with("RELOAD") { } else if query.starts_with("RELOAD") {
trace!("RELOAD"); trace!("RELOAD");
reload(stream).await reload(stream, client_server_map).await
} else if query.starts_with("SHOW CONFIG") { } else if query.starts_with("SHOW CONFIG") {
trace!("SHOW CONFIG"); trace!("SHOW CONFIG");
show_config(stream).await show_config(stream).await
@@ -143,10 +145,7 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
/// Show utilization of connection pools for each shard and replicas. /// Show utilization of connection pools for each shard and replicas.
async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats(); let stats = get_stats();
let config = { let config = get_config();
let guard = get_config();
&*guard.clone()
};
let columns = vec![ let columns = vec![
("database", DataType::Text), ("database", DataType::Text),
@@ -199,9 +198,7 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
/// Show shards and replicas. /// Show shards and replicas.
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let guard = get_config(); let config = get_config();
let config = &*guard.clone();
drop(guard);
// Columns // Columns
let columns = vec![ let columns = vec![
@@ -266,17 +263,15 @@ async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
} }
/// Reload the configuration file without restarting the process. /// Reload the configuration file without restarting the process.
async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { async fn reload(
stream: &mut OwnedWriteHalf,
client_server_map: ClientServerMap,
) -> Result<(), Error> {
info!("Reloading config"); info!("Reloading config");
let config = get_config(); reload_config(client_server_map).await?;
let path = config.path.clone().unwrap();
parse(&path).await?; get_config().show();
let config = get_config();
config.show();
let mut res = BytesMut::new(); let mut res = BytesMut::new();
@@ -292,10 +287,8 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
/// Shows current configuration. /// Shows current configuration.
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let guard = get_config(); let config = &get_config();
let config = &*guard.clone();
let config: HashMap<String, String> = config.into(); let config: HashMap<String, String> = config.into();
drop(guard);
// Configs that cannot be changed without restarting. // Configs that cannot be changed without restarting.
let immutables = ["host", "port", "connect_timeout"]; let immutables = ["host", "port", "connect_timeout"];

View File

@@ -13,10 +13,10 @@ use crate::config::get_config;
use crate::constants::*; use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::Server;
use crate::stats::Reporter; use crate::stats::{get_reporter, Reporter};
/// The client state. One of these is created per client. /// The client state. One of these is created per client.
pub struct Client { pub struct Client {
@@ -69,12 +69,11 @@ impl Client {
pub async fn startup( pub async fn startup(
mut stream: TcpStream, mut stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
server_info: BytesMut,
stats: Reporter,
) -> Result<Client, Error> { ) -> Result<Client, Error> {
let config = get_config().clone(); let config = get_config();
let transaction_mode = config.general.pool_mode.starts_with("t"); let transaction_mode = config.general.pool_mode == "transaction";
// drop(config); let stats = get_reporter();
loop { loop {
trace!("Waiting for StartupMessage"); trace!("Waiting for StartupMessage");
@@ -154,9 +153,10 @@ impl Client {
debug!("Password authentication successful"); debug!("Password authentication successful");
auth_ok(&mut stream).await?; auth_ok(&mut stream).await?;
write_all(&mut stream, server_info).await?; write_all(&mut stream, get_pool().server_info()).await?;
backend_key_data(&mut stream, process_id, secret_key).await?; backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?; ready_for_query(&mut stream).await?;
trace!("Startup OK"); trace!("Startup OK");
let database = parameters let database = parameters
@@ -221,7 +221,7 @@ impl Client {
} }
/// Handle a connected and authenticated client. /// Handle a connected and authenticated client.
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously. // The client wants to cancel a query it has issued previously.
if self.cancel_mode { if self.cancel_mode {
trace!("Sending CancelRequest"); trace!("Sending CancelRequest");
@@ -252,13 +252,19 @@ impl Client {
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
} }
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(); let mut query_router = QueryRouter::new();
let mut round_robin = 0;
// Our custom protocol loop. // Our custom protocol loop.
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol. // or issue commands for our sharding and server selection protocol.
loop { loop {
trace!("Client idle, waiting for message"); trace!(
"Client idle, waiting for message, transaction mode: {}",
self.transaction_mode
);
// Read a complete message from the client, which normally would be // Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol). // either a `Q` (query) or `P` (prepare, extended protocol).
@@ -267,32 +273,63 @@ impl Client {
// SET SHARDING KEY TO 'bigint'; // SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?; let mut message = read_message(&mut self.read).await?;
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = get_pool();
// Avoid taking a server if the client just wants to disconnect. // Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' { if message[0] as char == 'X' {
trace!("Client disconnecting"); debug!("Client disconnecting");
return Ok(()); return Ok(());
} }
// Handle admin database queries. // Handle admin database queries.
if self.admin { if self.admin {
trace!("Handling admin command"); debug!("Handling admin command");
handle_admin(&mut self.write, message, pool.clone()).await?; handle_admin(
&mut self.write,
message,
pool.clone(),
self.client_server_map.clone(),
)
.await?;
continue; continue;
} }
let current_shard = query_router.shard();
// Handle all custom protocol commands, if any. // Handle all custom protocol commands, if any.
match query_router.try_execute_command(message.clone()) { match query_router.try_execute_command(message.clone()) {
// Normal query, not a custom command. // Normal query, not a custom command.
None => { None => (),
// Attempt to infer which server we want to query, i.e. primary or replica.
if query_router.query_parser_enabled() && query_router.role() == None {
query_router.infer_role(message.clone());
}
}
// SET SHARD TO // SET SHARD TO
Some((Command::SetShard, _)) => { Some((Command::SetShard, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; // Selected shard is not configured.
if query_router.shard() >= pool.shards() {
// Set the shard back to what it was.
query_router.set_shard(current_shard);
error_response(
&mut self.write,
&format!(
"shard {} is more than configured {}, staying on shard {}",
query_router.shard(),
pool.shards(),
current_shard,
),
)
.await?;
} else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
}
continue;
}
// SET PRIMARY READS TO
Some((Command::SetPrimaryReads, _)) => {
custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?;
continue; continue;
} }
@@ -319,27 +356,24 @@ impl Client {
show_response(&mut self.write, "shard", &value).await?; show_response(&mut self.write, "shard", &value).await?;
continue; continue;
} }
};
// Make sure we selected a valid shard. // SHOW PRIMARY READS
if query_router.shard() >= pool.shards() { Some((Command::ShowPrimaryReads, value)) => {
error_response( show_response(&mut self.write, "primary reads", &value).await?;
&mut self.write, continue;
&format!( }
"shard {} is more than configured {}", };
query_router.shard(),
pool.shards()
),
)
.await?;
continue;
}
debug!("Waiting for connection from pool"); debug!("Waiting for connection from pool");
// Grab a server from the pool. // Grab a server from the pool.
let connection = match pool let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id) .get(
query_router.shard(),
query_router.role(),
self.process_id,
round_robin,
)
.await .await
{ {
Ok(conn) => { Ok(conn) => {
@@ -358,6 +392,8 @@ impl Client {
let address = connection.1; let address = connection.1;
let server = &mut *reference; let server = &mut *reference;
round_robin += 1;
// Server is assigned to the client in case the client wants to // Server is assigned to the client in case the client wants to
// cancel a query later. // cancel a query later.
server.claim(self.process_id, self.secret_key); server.claim(self.process_id, self.secret_key);

View File

@@ -1,5 +1,5 @@
/// Parse the configuration file. /// Parse the configuration file.
use arc_swap::{ArcSwap, Guard}; use arc_swap::ArcSwap;
use log::{error, info}; use log::{error, info};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde_derive::Deserialize; use serde_derive::Deserialize;
@@ -10,6 +10,7 @@ use tokio::io::AsyncReadExt;
use toml; use toml;
use crate::errors::Error; use crate::errors::Error;
use crate::{ClientServerMap, ConnectionPool};
/// Globally available configuration. /// Globally available configuration.
static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default())); static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default()));
@@ -126,7 +127,7 @@ impl Default for General {
} }
/// Shard configuration. /// Shard configuration.
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Shard { pub struct Shard {
pub servers: Vec<(String, u16, String)>, pub servers: Vec<(String, u16, String)>,
pub database: String, pub database: String,
@@ -161,10 +162,16 @@ impl Default for QueryRouter {
} }
} }
fn default_path() -> String {
String::from("pgcat.toml")
}
/// Configuration wrapper. /// Configuration wrapper.
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct Config { pub struct Config {
pub path: Option<String>, #[serde(default = "default_path")]
pub path: String,
pub general: General, pub general: General,
pub user: User, pub user: User,
pub shards: HashMap<String, Shard>, pub shards: HashMap<String, Shard>,
@@ -174,7 +181,7 @@ pub struct Config {
impl Default for Config { impl Default for Config {
fn default() -> Config { fn default() -> Config {
Config { Config {
path: Some(String::from("pgcat.toml")), path: String::from("pgcat.toml"),
general: General::default(), general: General::default(),
user: User::default(), user: User::default(),
shards: HashMap::from([(String::from("1"), Shard::default())]), shards: HashMap::from([(String::from("1"), Shard::default())]),
@@ -237,6 +244,8 @@ impl Config {
); );
info!("Connection timeout: {}ms", self.general.connect_timeout); info!("Connection timeout: {}ms", self.general.connect_timeout);
info!("Sharding function: {}", self.query_router.sharding_function); info!("Sharding function: {}", self.query_router.sharding_function);
info!("Primary reads: {}", self.query_router.primary_reads_enabled);
info!("Query router: {}", self.query_router.query_parser_enabled);
info!("Number of shards: {}", self.shards.len()); info!("Number of shards: {}", self.shards.len());
} }
} }
@@ -244,8 +253,8 @@ impl Config {
/// Get a read-only instance of the configuration /// Get a read-only instance of the configuration
/// from anywhere in the app. /// from anywhere in the app.
/// ArcSwap makes this cheap and quick. /// ArcSwap makes this cheap and quick.
pub fn get_config() -> Guard<Arc<Config>> { pub fn get_config() -> Config {
CONFIG.load() (*(*CONFIG.load())).clone()
} }
/// Parse the configuration file located at the path. /// Parse the configuration file located at the path.
@@ -357,7 +366,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
} }
}; };
config.path = Some(path.to_string()); config.path = path.to_string();
// Update the configuration globally. // Update the configuration globally.
CONFIG.store(Arc::new(config.clone())); CONFIG.store(Arc::new(config.clone()));
@@ -365,6 +374,27 @@ pub async fn parse(path: &str) -> Result<(), Error> {
Ok(()) Ok(())
} }
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let old_config = get_config();
match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
error!("Config reload error: {:?}", err);
return Err(Error::BadConfig);
}
};
let new_config = get_config();
if old_config.shards != new_config.shards {
info!("Sharding configuration changed, re-creating server pools");
ConnectionPool::from_config(client_server_map).await
} else {
Ok(())
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
@@ -377,6 +407,6 @@ mod test {
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1"); assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
assert_eq!(get_config().shards["0"].servers[0].2, "primary"); assert_eq!(get_config().shards["0"].servers[0].2, "primary");
assert_eq!(get_config().query_router.default_role, "any"); assert_eq!(get_config().query_router.default_role, "any");
assert_eq!(get_config().path, Some("pgcat.toml".to_string())); assert_eq!(get_config().path, "pgcat.toml".to_string());
} }
} }

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Lev Kokotov <lev@levthe.dev> // Copyright (c) 2022 Lev Kokotov <hi@levthe.dev>
// Permission is hereby granted, free of charge, to any person obtaining // Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the // a copy of this software and associated documentation files (the
@@ -34,7 +34,7 @@ extern crate sqlparser;
extern crate tokio; extern crate tokio;
extern crate toml; extern crate toml;
use log::{error, info}; use log::{debug, error, info};
use parking_lot::Mutex; use parking_lot::Mutex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::{ use tokio::{
@@ -59,9 +59,9 @@ mod server;
mod sharding; mod sharding;
mod stats; mod stats;
use config::get_config; use config::{get_config, reload_config};
use pool::{ClientServerMap, ConnectionPool}; use pool::{get_pool, ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter}; use stats::{Collector, Reporter, REPORTER};
#[tokio::main(worker_threads = 4)] #[tokio::main(worker_threads = 4)]
async fn main() { async fn main() {
@@ -109,37 +109,39 @@ async fn main() {
// Statistics reporting. // Statistics reporting.
let (tx, rx) = mpsc::channel(100); let (tx, rx) = mpsc::channel(100);
REPORTER.store(Arc::new(Reporter::new(tx.clone())));
// Connection pool that allows to query all shards and replicas. // Connection pool that allows to query all shards and replicas.
let mut pool = match ConnectionPool::from_config(client_server_map.clone()).await {
ConnectionPool::from_config(client_server_map.clone(), Reporter::new(tx.clone())).await; Ok(_) => (),
Err(err) => {
error!("Pool error: {:?}", err);
return;
}
};
let pool = get_pool();
// Statistics collector task. // Statistics collector task.
let collector_tx = tx.clone(); let collector_tx = tx.clone();
// Save these for reloading
let reload_client_server_map = client_server_map.clone();
let addresses = pool.databases(); let addresses = pool.databases();
tokio::task::spawn(async move { tokio::task::spawn(async move {
let mut stats_collector = Collector::new(rx, collector_tx); let mut stats_collector = Collector::new(rx, collector_tx);
stats_collector.collect(addresses).await; stats_collector.collect(addresses).await;
}); });
// Connect to all servers and validate their versions.
let server_info = match pool.validate().await {
Ok(info) => info,
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return;
}
};
info!("Waiting for clients"); info!("Waiting for clients");
drop(pool);
// Client connection loop. // Client connection loop.
tokio::task::spawn(async move { tokio::task::spawn(async move {
loop { loop {
let pool = pool.clone();
let client_server_map = client_server_map.clone(); let client_server_map = client_server_map.clone();
let server_info = server_info.clone();
let reporter = Reporter::new(tx.clone());
let (socket, addr) = match listener.accept().await { let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr), Ok((socket, addr)) => (socket, addr),
@@ -152,12 +154,11 @@ async fn main() {
// Handle client. // Handle client.
tokio::task::spawn(async move { tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc(); let start = chrono::offset::Utc::now().naive_utc();
match client::Client::startup(socket, client_server_map, server_info, reporter) match client::Client::startup(socket, client_server_map).await {
.await
{
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected", addr); info!("Client {:?} connected", addr);
match client.handle(pool).await {
match client.handle().await {
Ok(()) => { Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start; let duration = chrono::offset::Utc::now().naive_utc() - start;
@@ -176,7 +177,7 @@ async fn main() {
} }
Err(err) => { Err(err) => {
error!("Client failed to login: {:?}", err); debug!("Client failed to login: {:?}", err);
} }
}; };
}); });
@@ -190,16 +191,15 @@ async fn main() {
loop { loop {
stream.recv().await; stream.recv().await;
info!("Reloading config"); info!("Reloading config");
match config::parse("pgcat.toml").await {
Ok(_) => { match reload_config(reload_client_server_map.clone()).await {
get_config().show(); Ok(_) => (),
} Err(_) => continue,
Err(err) => {
error!("{:?}", err);
return;
}
}; };
get_config().show();
} }
}); });

View File

@@ -1,9 +1,10 @@
/// Pooling, failover and banlist. use arc_swap::ArcSwap;
use async_trait::async_trait; use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection}; use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut; use bytes::BytesMut;
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@@ -12,28 +13,47 @@ use std::time::Instant;
use crate::config::{get_config, Address, Role, User}; use crate::config::{get_config, Address, Role, User};
use crate::errors::Error; use crate::errors::Error;
use crate::server::Server; use crate::server::Server;
use crate::stats::Reporter; use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>; pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOL: Lazy<ArcSwap<ConnectionPool>> =
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
/// The globally accessible connection pool. /// The globally accessible connection pool.
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct ConnectionPool { pub struct ConnectionPool {
/// The pools handled internally by bb8.
databases: Vec<Vec<Pool<ServerPool>>>, databases: Vec<Vec<Pool<ServerPool>>>,
/// The addresses (host, port, role) to handle
/// failover and load balancing deterministically.
addresses: Vec<Vec<Address>>, addresses: Vec<Vec<Address>>,
round_robin: usize,
/// List of banned addresses (see above)
/// that should not be queried.
banlist: BanList, banlist: BanList,
/// The statistics aggregator runs in a separate task
/// and receives stats from clients, servers, and the pool.
stats: Reporter, stats: Reporter,
/// The server information (K messages) have to be passed to the
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: BytesMut,
} }
impl ConnectionPool { impl ConnectionPool {
/// Construct the connection pool from the configuration. /// Construct the connection pool from the configuration.
pub async fn from_config( pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
client_server_map: ClientServerMap, let reporter = get_reporter();
stats: Reporter,
) -> ConnectionPool {
let config = get_config(); let config = get_config();
let mut shards = Vec::new(); let mut shards = Vec::new();
let mut addresses = Vec::new(); let mut addresses = Vec::new();
let mut banlist = Vec::new(); let mut banlist = Vec::new();
@@ -44,6 +64,8 @@ impl ConnectionPool {
.into_keys() .into_keys()
.map(|x| x.to_string()) .map(|x| x.to_string())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()); shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids { for shard_idx in shard_ids {
@@ -82,7 +104,7 @@ impl ConnectionPool {
config.user.clone(), config.user.clone(),
&shard.database, &shard.database,
client_server_map.clone(), client_server_map.clone(),
stats.clone(), reporter.clone(),
); );
let pool = Pool::builder() let pool = Pool::builder()
@@ -105,15 +127,28 @@ impl ConnectionPool {
} }
assert_eq!(shards.len(), addresses.len()); assert_eq!(shards.len(), addresses.len());
let address_len = addresses.len();
ConnectionPool { let mut pool = ConnectionPool {
databases: shards, databases: shards,
addresses: addresses, addresses: addresses,
round_robin: rand::random::<usize>() % address_len, // Start at a random replica
banlist: Arc::new(RwLock::new(banlist)), banlist: Arc::new(RwLock::new(banlist)),
stats: stats, stats: reporter,
} server_info: BytesMut::new(),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
POOL.store(Arc::new(pool.clone()));
Ok(())
} }
/// Connect to all shards and grab server information. /// Connect to all shards and grab server information.
@@ -121,16 +156,18 @@ impl ConnectionPool {
/// when they connect. /// when they connect.
/// This also warms up the pool for clients that connect when /// This also warms up the pool for clients that connect when
/// the pooler starts up. /// the pooler starts up.
pub async fn validate(&mut self) -> Result<BytesMut, Error> { async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new(); let mut server_infos = Vec::new();
let stats = self.stats.clone(); let stats = self.stats.clone();
for shard in 0..self.shards() { for shard in 0..self.shards() {
let mut round_robin = 0;
for _ in 0..self.servers(shard) { for _ in 0..self.servers(shard) {
// To keep stats consistent. // To keep stats consistent.
let fake_process_id = 0; let fake_process_id = 0;
let connection = match self.get(shard, None, fake_process_id).await { let connection = match self.get(shard, None, fake_process_id, round_robin).await {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err); error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -138,10 +175,9 @@ impl ConnectionPool {
} }
}; };
let mut proxy = connection.0; let proxy = connection.0;
let address = connection.1; let address = connection.1;
let server = &mut *proxy; let server = &*proxy;
let server_info = server.server_info(); let server_info = server.server_info();
stats.client_disconnecting(fake_process_id, address.id); stats.client_disconnecting(fake_process_id, address.id);
@@ -157,6 +193,7 @@ impl ConnectionPool {
} }
server_infos.push(server_info); server_infos.push(server_info);
round_robin += 1;
} }
} }
@@ -166,15 +203,18 @@ impl ConnectionPool {
return Err(Error::AllServersDown); return Err(Error::AllServersDown);
} }
Ok(server_infos[0].clone()) self.server_info = server_infos[0].clone();
Ok(())
} }
/// Get a connection from the pool. /// Get a connection from the pool.
pub async fn get( pub async fn get(
&mut self, &mut self,
shard: usize, shard: usize, // shard number
role: Option<Role>, role: Option<Role>, // primary or replica
process_id: i32, process_id: i32, // client id
mut round_robin: usize, // round robin offset
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now(); let now = Instant::now();
let addresses = &self.addresses[shard]; let addresses = &self.addresses[shard];
@@ -204,9 +244,9 @@ impl ConnectionPool {
while allowed_attempts > 0 { while allowed_attempts > 0 {
// Round-robin replicas. // Round-robin replicas.
self.round_robin += 1; round_robin += 1;
let index = self.round_robin % addresses.len(); let index = round_robin % addresses.len();
let address = &addresses[index]; let address = &addresses[index];
// Make sure you're getting a primary or a replica // Make sure you're getting a primary or a replica
@@ -218,6 +258,7 @@ impl ConnectionPool {
allowed_attempts -= 1; allowed_attempts -= 1;
// Don't attempt to connect to banned servers.
if self.is_banned(address, shard, role) { if self.is_banned(address, shard, role) {
continue; continue;
} }
@@ -390,6 +431,10 @@ impl ConnectionPool {
pub fn address(&self, shard: usize, server: usize) -> &Address { pub fn address(&self, shard: usize, server: usize) -> &Address {
&self.addresses[shard][server] &self.addresses[shard][server]
} }
pub fn server_info(&self) -> BytesMut {
self.server_info.clone()
}
} }
/// Wrapper for the bb8 connection pool. /// Wrapper for the bb8 connection pool.
@@ -470,3 +515,8 @@ impl ManageConnection for ServerPool {
conn.is_bad() conn.is_bad()
} }
} }
/// Get the connection pool
pub fn get_pool() -> ConnectionPool {
(*(*POOL.load())).clone()
}

View File

@@ -12,12 +12,14 @@ use crate::config::{get_config, Role};
use crate::sharding::{Sharder, ShardingFunction}; use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands. /// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 5] = [ const CUSTOM_SQL_REGEXES: [&str; 7] = [
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$", r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$", r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
r"(?i)^ *SHOW SHARD *;? *$", r"(?i)^ *SHOW SHARD *;? *$",
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$", r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
r"(?i)^ *SHOW SERVER ROLE *;? *$", r"(?i)^ *SHOW SERVER ROLE *;? *$",
r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$",
r"(?i)^ *SHOW PRIMARY READS *;? *$",
]; ];
/// Custom commands. /// Custom commands.
@@ -28,6 +30,8 @@ pub enum Command {
ShowShard, ShowShard,
SetServerRole, SetServerRole,
ShowServerRole, ShowServerRole,
SetPrimaryReads,
ShowPrimaryReads,
} }
/// Quickly test for match when a query is received. /// Quickly test for match when a query is received.
@@ -38,27 +42,17 @@ static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router. /// The query router.
pub struct QueryRouter { pub struct QueryRouter {
/// By default, queries go here, unless we have better information
/// about what the client wants.
default_server_role: Option<Role>,
/// Number of shards in the cluster.
shards: usize,
/// Which shard we should be talking to right now. /// Which shard we should be talking to right now.
active_shard: Option<usize>, active_shard: Option<usize>,
/// Which server should we be talking to. /// Which server should we be talking to.
active_role: Option<Role>, active_role: Option<Role>,
/// Include the primary into the replica pool for reads. /// Should we try to parse queries to route them to replicas or primary automatically
primary_reads_enabled: bool,
/// Should we try to parse queries to route them to replicas or primary automatically.
query_parser_enabled: bool, query_parser_enabled: bool,
/// Which sharding function we're using. /// Include the primary into the replica pool for reads.
sharding_function: ShardingFunction, primary_reads_enabled: bool,
} }
impl QueryRouter { impl QueryRouter {
@@ -97,28 +91,11 @@ impl QueryRouter {
pub fn new() -> QueryRouter { pub fn new() -> QueryRouter {
let config = get_config(); let config = get_config();
let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let sharding_function = match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
QueryRouter { QueryRouter {
default_server_role: default_server_role,
shards: config.shards.len(),
active_role: default_server_role,
active_shard: None, active_shard: None,
primary_reads_enabled: config.query_router.primary_reads_enabled, active_role: None,
query_parser_enabled: config.query_router.query_parser_enabled, query_parser_enabled: config.query_router.query_parser_enabled,
sharding_function, primary_reads_enabled: config.query_router.primary_reads_enabled,
} }
} }
@@ -146,21 +123,48 @@ impl QueryRouter {
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect(); let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
// This is not a custom query, try to infer which
// server it'll go to if the query parser is enabled.
if matches.len() != 1 { if matches.len() != 1 {
debug!("Regular query");
if self.query_parser_enabled && self.role() == None {
debug!("Inferring role");
self.infer_role(buf.clone());
}
return None; return None;
} }
let config = get_config();
let sharding_function = match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let command = match matches[0] { let command = match matches[0] {
0 => Command::SetShardingKey, 0 => Command::SetShardingKey,
1 => Command::SetShard, 1 => Command::SetShard,
2 => Command::ShowShard, 2 => Command::ShowShard,
3 => Command::SetServerRole, 3 => Command::SetServerRole,
4 => Command::ShowServerRole, 4 => Command::ShowServerRole,
5 => Command::SetPrimaryReads,
6 => Command::ShowPrimaryReads,
_ => unreachable!(), _ => unreachable!(),
}; };
let mut value = match command { let mut value = match command {
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => { Command::SetShardingKey
| Command::SetShard
| Command::SetServerRole
| Command::SetPrimaryReads => {
// Capture value. I know this re-runs the regex engine, but I haven't // Capture value. I know this re-runs the regex engine, but I haven't
// figured out a better way just yet. I think I can write a single Regex // figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible? // that matches all 5 custom SQL patterns, but maybe that's not very legible?
@@ -187,11 +191,16 @@ impl QueryRouter {
} }
} }
}, },
Command::ShowPrimaryReads => match self.primary_reads_enabled {
true => String::from("on"),
false => String::from("off"),
},
}; };
match command { match command {
Command::SetShardingKey => { Command::SetShardingKey => {
let sharder = Sharder::new(self.shards, self.sharding_function); let sharder = Sharder::new(config.shards.len(), sharding_function);
let shard = sharder.shard(value.parse::<i64>().unwrap()); let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard); self.active_shard = Some(shard);
value = shard.to_string(); value = shard.to_string();
@@ -199,7 +208,7 @@ impl QueryRouter {
Command::SetShard => { Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() { self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % self.shards), "ANY" => Some(rand::random::<usize>() % config.shards.len()),
_ => Some(value.parse::<usize>().unwrap()), _ => Some(value.parse::<usize>().unwrap()),
}; };
} }
@@ -227,8 +236,8 @@ impl QueryRouter {
} }
"default" => { "default" => {
self.active_role = self.default_server_role; self.active_role = default_server_role;
self.query_parser_enabled = get_config().query_router.query_parser_enabled; self.query_parser_enabled = config.query_router.query_parser_enabled;
self.active_role self.active_role
} }
@@ -236,6 +245,19 @@ impl QueryRouter {
}; };
} }
Command::SetPrimaryReads => {
if value == "on" {
debug!("Setting primary reads to on");
self.primary_reads_enabled = true;
} else if value == "off" {
debug!("Setting primary reads to off");
self.primary_reads_enabled = false;
} else if value == "default" {
debug!("Setting primary reads to default");
self.primary_reads_enabled = config.query_router.primary_reads_enabled;
}
}
_ => (), _ => (),
} }
@@ -330,23 +352,15 @@ impl QueryRouter {
} }
} }
/// Reset the router back to defaults. pub fn set_shard(&mut self, shard: usize) {
/// This must be called at the end of every transaction in transaction mode. self.active_shard = Some(shard);
pub fn _reset(&mut self) {
self.active_role = self.default_server_role;
self.active_shard = None;
} }
/// Should we attempt to parse queries? /// Should we attempt to parse queries?
#[allow(dead_code)]
pub fn query_parser_enabled(&self) -> bool { pub fn query_parser_enabled(&self) -> bool {
self.query_parser_enabled self.query_parser_enabled
} }
/// Allows to toggle primary reads in tests.
#[allow(dead_code)]
pub fn toggle_primary_reads(&mut self, value: bool) {
self.primary_reads_enabled = value;
}
} }
#[cfg(test)] #[cfg(test)]
@@ -369,7 +383,8 @@ mod test {
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true); assert_eq!(qr.query_parser_enabled(), true);
qr.toggle_primary_reads(false);
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
let queries = vec![ let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"), simple_query("SELECT * FROM items WHERE id = 5"),
@@ -410,7 +425,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
qr.toggle_primary_reads(true); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer_role(query)); assert!(qr.infer_role(query));
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -421,7 +436,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
qr.toggle_primary_reads(false); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
let prepared_stmt = BytesMut::from( let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -450,6 +465,10 @@ mod test {
"SET SERVER ROLE TO 'any'", "SET SERVER ROLE TO 'any'",
"SET SERVER ROLE TO 'auto'", "SET SERVER ROLE TO 'auto'",
"SHOW SERVER ROLE", "SHOW SERVER ROLE",
"SET PRIMARY READS TO 'on'",
"SET PRIMARY READS TO 'off'",
"SET PRIMARY READS TO 'default'",
"SHOW PRIMARY READS",
// Lower case // Lower case
"set sharding key to '1'", "set sharding key to '1'",
"set shard to '1'", "set shard to '1'",
@@ -459,9 +478,13 @@ mod test {
"set server role to 'any'", "set server role to 'any'",
"set server role to 'auto'", "set server role to 'auto'",
"show server role", "show server role",
"set primary reads to 'on'",
"set primary reads to 'OFF'",
"set primary reads to 'deFaUlt'",
// No quotes // No quotes
"SET SHARDING KEY TO 11235", "SET SHARDING KEY TO 11235",
"SET SHARD TO 15", "SET SHARD TO 15",
"SET PRIMARY READS TO off",
// Spaces and semicolon // Spaces and semicolon
" SET SHARDING KEY TO 11235 ; ", " SET SHARDING KEY TO 11235 ; ",
" SET SHARD TO 15; ", " SET SHARD TO 15; ",
@@ -469,18 +492,23 @@ mod test {
" SET SERVER ROLE TO 'primary'; ", " SET SERVER ROLE TO 'primary'; ",
" SET SERVER ROLE TO 'primary' ; ", " SET SERVER ROLE TO 'primary' ; ",
" SET SERVER ROLE TO 'primary' ;", " SET SERVER ROLE TO 'primary' ;",
" SET PRIMARY READS TO 'off' ;",
]; ];
// Which regexes it'll match to in the list // Which regexes it'll match to in the list
let matches = [ let matches = [
0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 6, 0, 1, 2, 3, 3, 3, 3, 4, 5, 5, 5, 0, 1, 5, 0, 1, 0,
3, 3, 3, 5,
]; ];
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap(); let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
let set = CUSTOM_SQL_REGEX_SET.get().unwrap(); let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
for (i, test) in tests.iter().enumerate() { for (i, test) in tests.iter().enumerate() {
assert!(list[matches[i]].is_match(test)); if !list[matches[i]].is_match(test) {
println!("{} does not match {}", test, list[matches[i]]);
assert!(false);
}
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1); assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
} }
@@ -549,6 +577,26 @@ mod test {
Some((Command::ShowServerRole, String::from(*role))) Some((Command::ShowServerRole, String::from(*role)))
); );
} }
let primary_reads = ["on", "off", "default"];
let primary_reads_enabled = ["on", "off", "on"];
for (idx, primary_reads) in primary_reads.iter().enumerate() {
assert_eq!(
qr.try_execute_command(simple_query(&format!(
"SET PRIMARY READS TO {}",
primary_reads
))),
Some((Command::SetPrimaryReads, String::from(*primary_reads)))
);
assert_eq!(
qr.try_execute_command(simple_query("SHOW PRIMARY READS")),
Some((
Command::ShowPrimaryReads,
String::from(primary_reads_enabled[idx])
))
);
}
} }
#[test] #[test]
@@ -556,7 +604,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'"); let query = simple_query("SET SERVER ROLE TO 'auto'");
qr.toggle_primary_reads(false); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
assert!(qr.try_execute_command(query) != None); assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
@@ -573,6 +621,6 @@ mod test {
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'"); let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(query) != None); assert!(qr.try_execute_command(query) != None);
assert!(!qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
} }
} }

View File

@@ -1,9 +1,13 @@
use arc_swap::ArcSwap;
/// Statistics and reporting. /// Statistics and reporting.
use log::info; use log::info;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
Lazy::new(|| ArcSwap::from_pointee(Reporter::default()));
/// Latest stats updated every second; used in SHOW STATS and other admin commands. /// Latest stats updated every second; used in SHOW STATS and other admin commands.
static LATEST_STATS: Lazy<Mutex<HashMap<usize, HashMap<String, i64>>>> = static LATEST_STATS: Lazy<Mutex<HashMap<usize, HashMap<String, i64>>>> =
@@ -60,6 +64,13 @@ pub struct Reporter {
tx: Sender<Event>, tx: Sender<Event>,
} }
impl Default for Reporter {
fn default() -> Reporter {
let (tx, _rx) = channel(5);
Reporter { tx }
}
}
impl Reporter { impl Reporter {
/// Create a new Reporter instance. /// Create a new Reporter instance.
pub fn new(tx: Sender<Event>) -> Reporter { pub fn new(tx: Sender<Event>) -> Reporter {
@@ -289,7 +300,7 @@ impl Collector {
("avg_query_time", 0), ("avg_query_time", 0),
("avg_xact_count", 0), ("avg_xact_count", 0),
("avg_sent", 0), ("avg_sent", 0),
("avg_received", 0), ("avg_recv", 0),
("avg_wait_time", 0), ("avg_wait_time", 0),
("maxwait_us", 0), ("maxwait_us", 0),
("maxwait", 0), ("maxwait", 0),
@@ -493,10 +504,14 @@ impl Collector {
"avg_query_count", "avg_query_count",
"avgxact_count", "avgxact_count",
"avg_sent", "avg_sent",
"avg_received", "avg_recv",
"avg_wait_time", "avg_wait_time",
] { ] {
let total_name = stat.replace("avg_", "total_"); let total_name = match stat {
&"avg_recv" => "total_received".to_string(), // Because PgBouncer is saving bytes
_ => stat.replace("avg_", "total_"),
};
let old_value = old_stats.entry(total_name.clone()).or_insert(0); let old_value = old_stats.entry(total_name.clone()).or_insert(0);
let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned(); let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned();
let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second
@@ -515,3 +530,8 @@ impl Collector {
pub fn get_stats() -> HashMap<usize, HashMap<String, i64>> { pub fn get_stats() -> HashMap<usize, HashMap<String, i64>> {
LATEST_STATS.lock().clone() LATEST_STATS.lock().clone()
} }
/// Get the statistics reporter used to update stats across the pools/clients.
pub fn get_reporter() -> Reporter {
(*(*REPORTER.load())).clone()
}

View File

@@ -12,6 +12,8 @@
SET SHARD TO :shard; SET SHARD TO :shard;
SET SERVER ROLE TO 'auto';
BEGIN; BEGIN;
UPDATE pgbench_accounts SET abalance = abalance + :delta WHERE aid = :aid; UPDATE pgbench_accounts SET abalance = abalance + :delta WHERE aid = :aid;
@@ -26,3 +28,12 @@ INSERT INTO pgbench_history (tid, bid, aid, delta, mtime) VALUES (:tid, :bid, :a
END; END;
SET SHARDING KEY TO :aid;
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;
SET SERVER ROLE TO 'replica';
-- Read load balancing
SELECT abalance FROM pgbench_accounts WHERE aid = :aid;

View File

@@ -151,3 +151,12 @@ SELECT 1;
set server role to 'replica'; set server role to 'replica';
SeT SeRver Role TO 'PrImARY'; SeT SeRver Role TO 'PrImARY';
select 1; select 1;
SET PRIMARY READS TO 'on';
SELECT 1;
SET PRIMARY READS TO 'off';
SELECT 1;
SET PRIMARY READS TO 'default';
SELECT 1;