mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Allow setting the number of runtime workers to be used. (#258)
This change adds a new configuration parameter called `worker_threads` that allows setting the number of worker threads the Tokio runtime will use. It defaults to 4 to maintain backward compatibility. Because the config file is parsed asynchronously, a transient runtime is first created just to read the config; once parsing completes, the actual runtime that PgCat will use for execution is created with the configured number of workers.
This commit is contained in:
@@ -44,6 +44,9 @@ log_client_disconnections = false
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = false
|
||||
|
||||
# Number of worker threads the Runtime will use (4 by default).
|
||||
worker_threads = 5
|
||||
|
||||
# TLS
|
||||
# tls_certificate = "server.cert"
|
||||
# tls_private_key = "server.key"
|
||||
|
||||
@@ -178,6 +178,9 @@ pub struct General {
|
||||
#[serde(default = "General::default_ban_time")]
|
||||
pub ban_time: i64,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub autoreload: bool,
|
||||
|
||||
@@ -219,6 +222,10 @@ impl General {
|
||||
pub fn default_ban_time() -> i64 {
|
||||
60
|
||||
}
|
||||
|
||||
pub fn default_worker_threads() -> usize {
|
||||
4
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for General {
|
||||
@@ -234,6 +241,7 @@ impl Default for General {
|
||||
healthcheck_timeout: Self::default_healthcheck_timeout(),
|
||||
healthcheck_delay: Self::default_healthcheck_delay(),
|
||||
ban_time: Self::default_ban_time(),
|
||||
worker_threads: Self::default_worker_threads(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: false,
|
||||
|
||||
395
src/main.rs
395
src/main.rs
@@ -49,6 +49,7 @@ use parking_lot::Mutex;
|
||||
use pgcat::format_duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::{
|
||||
runtime::Builder,
|
||||
signal::unix::{signal as unix_signal, SignalKind},
|
||||
sync::mpsc,
|
||||
};
|
||||
@@ -79,8 +80,7 @@ use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::prometheus::start_metric_server;
|
||||
use crate::stats::{Collector, Reporter, REPORTER};
|
||||
|
||||
#[tokio::main(worker_threads = 4)]
|
||||
async fn main() {
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
env_logger::builder().format_timestamp_micros().init();
|
||||
|
||||
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
|
||||
@@ -98,215 +98,232 @@ async fn main() {
|
||||
String::from("pgcat.toml")
|
||||
};
|
||||
|
||||
match config::parse(&config_file).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
// Create a transient runtime for loading the config for the first time.
|
||||
{
|
||||
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
|
||||
|
||||
let config = get_config();
|
||||
|
||||
if let Some(true) = config.general.enable_prometheus_exporter {
|
||||
let http_addr_str = format!(
|
||||
"{}:{}",
|
||||
config.general.host, config.general.prometheus_exporter_port
|
||||
);
|
||||
let http_addr = match SocketAddr::from_str(&http_addr_str) {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
error!("Invalid http address: {}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
tokio::task::spawn(async move {
|
||||
start_metric_server(http_addr).await;
|
||||
runtime.block_on(async {
|
||||
match config::parse(&config_file).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
let addr = format!("{}:{}", config.general.host, config.general.port);
|
||||
let config = get_config();
|
||||
|
||||
let listener = match TcpListener::bind(&addr).await {
|
||||
Ok(sock) => sock,
|
||||
Err(err) => {
|
||||
error!("Listener socket error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
// Create the runtime now we know required worker_threads.
|
||||
let runtime = Builder::new_multi_thread()
|
||||
.worker_threads(config.general.worker_threads)
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
info!("Running on {}", addr);
|
||||
runtime.block_on(async move {
|
||||
|
||||
config.show();
|
||||
if let Some(true) = config.general.enable_prometheus_exporter {
|
||||
let http_addr_str = format!(
|
||||
"{}:{}",
|
||||
config.general.host, config.general.prometheus_exporter_port
|
||||
);
|
||||
let http_addr = match SocketAddr::from_str(&http_addr_str) {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => {
|
||||
error!("Invalid http address: {}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
tokio::task::spawn(async move {
|
||||
start_metric_server(http_addr).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Tracks which client is connected to which server for query cancellation.
|
||||
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
let addr = format!("{}:{}", config.general.host, config.general.port);
|
||||
|
||||
// Statistics reporting.
|
||||
let (stats_tx, stats_rx) = mpsc::channel(100_000);
|
||||
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
|
||||
|
||||
// Connection pool that allows to query all shards and replicas.
|
||||
match ConnectionPool::from_config(client_server_map.clone()).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Pool error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
|
||||
info!("Config autoreloader: {}", config.general.autoreload);
|
||||
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
if config.general.autoreload {
|
||||
info!("Automatically reloading config");
|
||||
|
||||
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
|
||||
if changed {
|
||||
get_config().show()
|
||||
}
|
||||
};
|
||||
let listener = match TcpListener::bind(&addr).await {
|
||||
Ok(sock) => sock,
|
||||
Err(err) => {
|
||||
error!("Listener socket error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
|
||||
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
|
||||
let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap();
|
||||
let (shutdown_tx, _) = broadcast::channel::<()>(1);
|
||||
let (drain_tx, mut drain_rx) = mpsc::channel::<i32>(2048);
|
||||
let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1);
|
||||
info!("Running on {}", addr);
|
||||
|
||||
info!("Waiting for clients");
|
||||
config.show();
|
||||
|
||||
let mut admin_only = false;
|
||||
let mut total_clients = 0;
|
||||
// Tracks which client is connected to which server for query cancellation.
|
||||
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Reload config:
|
||||
// kill -SIGHUP $(pgrep pgcat)
|
||||
_ = sighup_signal.recv() => {
|
||||
info!("Reloading config");
|
||||
// Statistics reporting.
|
||||
let (stats_tx, stats_rx) = mpsc::channel(100_000);
|
||||
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
|
||||
|
||||
_ = reload_config(client_server_map.clone()).await;
|
||||
// Connection pool that allows to query all shards and replicas.
|
||||
match ConnectionPool::from_config(client_server_map.clone()).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Pool error: {:?}", err);
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
};
|
||||
|
||||
get_config().show();
|
||||
},
|
||||
tokio::task::spawn(async move {
|
||||
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
|
||||
// Initiate graceful shutdown sequence on sig int
|
||||
_ = interrupt_signal.recv() => {
|
||||
info!("Got SIGINT, waiting for client connection drain now");
|
||||
admin_only = true;
|
||||
info!("Config autoreloader: {}", config.general.autoreload);
|
||||
|
||||
// Broadcast that client tasks need to finish
|
||||
let _ = shutdown_tx.send(());
|
||||
let exit_tx = exit_tx.clone();
|
||||
let _ = drain_tx.send(0).await;
|
||||
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
|
||||
let autoreload_client_server_map = client_server_map.clone();
|
||||
tokio::task::spawn(async move {
|
||||
loop {
|
||||
autoreload_interval.tick().await;
|
||||
if config.general.autoreload {
|
||||
info!("Automatically reloading config");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout));
|
||||
|
||||
// First tick fires immediately.
|
||||
interval.tick().await;
|
||||
|
||||
// Second one in the interval time.
|
||||
interval.tick().await;
|
||||
|
||||
// We're done waiting.
|
||||
error!("Graceful shutdown timed out. {} active clients being closed", total_clients);
|
||||
|
||||
let _ = exit_tx.send(()).await;
|
||||
});
|
||||
},
|
||||
|
||||
_ = term_signal.recv() => {
|
||||
info!("Got SIGTERM, closing with {} clients active", total_clients);
|
||||
break;
|
||||
},
|
||||
|
||||
new_client = listener.accept() => {
|
||||
let (socket, addr) = match new_client {
|
||||
Ok((socket, addr)) => (socket, addr),
|
||||
Err(err) => {
|
||||
error!("{:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let shutdown_rx = shutdown_tx.subscribe();
|
||||
let drain_tx = drain_tx.clone();
|
||||
let client_server_map = client_server_map.clone();
|
||||
|
||||
let tls_certificate = config.general.tls_certificate.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
|
||||
match client::client_entrypoint(
|
||||
socket,
|
||||
client_server_map,
|
||||
shutdown_rx,
|
||||
drain_tx,
|
||||
admin_only,
|
||||
tls_certificate.clone(),
|
||||
config.general.log_client_connections,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
|
||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
if config.general.log_client_disconnections {
|
||||
info!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
format_duration(&duration)
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
format_duration(&duration)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
match err {
|
||||
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
_ => warn!("Client disconnected with error {:?}", err),
|
||||
}
|
||||
|
||||
}
|
||||
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
|
||||
if changed {
|
||||
get_config().show()
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
_ = exit_rx.recv() => {
|
||||
break;
|
||||
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
|
||||
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
|
||||
let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap();
|
||||
let (shutdown_tx, _) = broadcast::channel::<()>(1);
|
||||
let (drain_tx, mut drain_rx) = mpsc::channel::<i32>(2048);
|
||||
let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1);
|
||||
|
||||
info!("Waiting for clients");
|
||||
|
||||
let mut admin_only = false;
|
||||
let mut total_clients = 0;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Reload config:
|
||||
// kill -SIGHUP $(pgrep pgcat)
|
||||
_ = sighup_signal.recv() => {
|
||||
info!("Reloading config");
|
||||
|
||||
_ = reload_config(client_server_map.clone()).await;
|
||||
|
||||
get_config().show();
|
||||
},
|
||||
|
||||
// Initiate graceful shutdown sequence on sig int
|
||||
_ = interrupt_signal.recv() => {
|
||||
info!("Got SIGINT, waiting for client connection drain now");
|
||||
admin_only = true;
|
||||
|
||||
// Broadcast that client tasks need to finish
|
||||
let _ = shutdown_tx.send(());
|
||||
let exit_tx = exit_tx.clone();
|
||||
let _ = drain_tx.send(0).await;
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout));
|
||||
|
||||
// First tick fires immediately.
|
||||
interval.tick().await;
|
||||
|
||||
// Second one in the interval time.
|
||||
interval.tick().await;
|
||||
|
||||
// We're done waiting.
|
||||
error!("Graceful shutdown timed out. {} active clients being closed", total_clients);
|
||||
|
||||
let _ = exit_tx.send(()).await;
|
||||
});
|
||||
},
|
||||
|
||||
_ = term_signal.recv() => {
|
||||
info!("Got SIGTERM, closing with {} clients active", total_clients);
|
||||
break;
|
||||
},
|
||||
|
||||
new_client = listener.accept() => {
|
||||
let (socket, addr) = match new_client {
|
||||
Ok((socket, addr)) => (socket, addr),
|
||||
Err(err) => {
|
||||
error!("{:?}", err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let shutdown_rx = shutdown_tx.subscribe();
|
||||
let drain_tx = drain_tx.clone();
|
||||
let client_server_map = client_server_map.clone();
|
||||
|
||||
let tls_certificate = config.general.tls_certificate.clone();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
|
||||
match client::client_entrypoint(
|
||||
socket,
|
||||
client_server_map,
|
||||
shutdown_rx,
|
||||
drain_tx,
|
||||
admin_only,
|
||||
tls_certificate.clone(),
|
||||
config.general.log_client_connections,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
|
||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
if config.general.log_client_disconnections {
|
||||
info!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
format_duration(&duration)
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
format_duration(&duration)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
match err {
|
||||
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
|
||||
_ => warn!("Client disconnected with error {:?}", err),
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
_ = exit_rx.recv() => {
|
||||
break;
|
||||
}
|
||||
|
||||
client_ping = drain_rx.recv() => {
|
||||
let client_ping = client_ping.unwrap();
|
||||
total_clients += client_ping;
|
||||
|
||||
if total_clients == 0 && admin_only {
|
||||
let _ = exit_tx.send(()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client_ping = drain_rx.recv() => {
|
||||
let client_ping = client_ping.unwrap();
|
||||
total_clients += client_ping;
|
||||
|
||||
if total_clients == 0 && admin_only {
|
||||
let _ = exit_tx.send(()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Shutting down...");
|
||||
info!("Shutting down...");
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user