diff --git a/pgcat.toml b/pgcat.toml index 73afc4b..0187c16 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -44,6 +44,9 @@ log_client_disconnections = false # Reload config automatically if it changes. autoreload = false +# Number of worker threads the Runtime will use (4 by default). +worker_threads = 5 + # TLS # tls_certificate = "server.cert" # tls_private_key = "server.key" diff --git a/src/config.rs b/src/config.rs index 48bd0bb..e8be947 100644 --- a/src/config.rs +++ b/src/config.rs @@ -178,6 +178,9 @@ pub struct General { #[serde(default = "General::default_ban_time")] pub ban_time: i64, + #[serde(default = "General::default_worker_threads")] + pub worker_threads: usize, + #[serde(default)] // False pub autoreload: bool, @@ -219,6 +222,10 @@ impl General { pub fn default_ban_time() -> i64 { 60 } + + pub fn default_worker_threads() -> usize { + 4 + } } impl Default for General { @@ -234,6 +241,7 @@ impl Default for General { healthcheck_timeout: Self::default_healthcheck_timeout(), healthcheck_delay: Self::default_healthcheck_delay(), ban_time: Self::default_ban_time(), + worker_threads: Self::default_worker_threads(), log_client_connections: false, log_client_disconnections: false, autoreload: false, diff --git a/src/main.rs b/src/main.rs index 0b5f732..aac51d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,6 +49,7 @@ use parking_lot::Mutex; use pgcat::format_duration; use tokio::net::TcpListener; use tokio::{ + runtime::Builder, signal::unix::{signal as unix_signal, SignalKind}, sync::mpsc, }; @@ -79,8 +80,7 @@ use crate::pool::{ClientServerMap, ConnectionPool}; use crate::prometheus::start_metric_server; use crate::stats::{Collector, Reporter, REPORTER}; -#[tokio::main(worker_threads = 4)] -async fn main() { +fn main() -> Result<(), Box> { env_logger::builder().format_timestamp_micros().init(); info!("Welcome to PgCat! Meow. (Version {})", VERSION); @@ -98,215 +98,232 @@ async fn main() { String::from("pgcat.toml") }; - match config::parse(&config_file).await { - Ok(_) => (), - Err(err) => { - error!("Config parse error: {:?}", err); - std::process::exit(exitcode::CONFIG); - } - }; + // Create a transient runtime for loading the config for the first time. + { + let runtime = Builder::new_multi_thread().worker_threads(1).build()?; - let config = get_config(); - - if let Some(true) = config.general.enable_prometheus_exporter { - let http_addr_str = format!( - "{}:{}", - config.general.host, config.general.prometheus_exporter_port - ); - let http_addr = match SocketAddr::from_str(&http_addr_str) { - Ok(addr) => addr, - Err(err) => { - error!("Invalid http address: {}", err); - std::process::exit(exitcode::CONFIG); - } - }; - tokio::task::spawn(async move { - start_metric_server(http_addr).await; + runtime.block_on(async { + match config::parse(&config_file).await { + Ok(_) => (), + Err(err) => { + error!("Config parse error: {:?}", err); + std::process::exit(exitcode::CONFIG); + } + }; }); } - let addr = format!("{}:{}", config.general.host, config.general.port); + let config = get_config(); - let listener = match TcpListener::bind(&addr).await { - Ok(sock) => sock, - Err(err) => { - error!("Listener socket error: {:?}", err); - std::process::exit(exitcode::CONFIG); - } - }; + // Create the runtime now we know required worker_threads. + let runtime = Builder::new_multi_thread() + .worker_threads(config.general.worker_threads) + .enable_all() + .build()?; - info!("Running on {}", addr); + runtime.block_on(async move { - config.show(); + if let Some(true) = config.general.enable_prometheus_exporter { + let http_addr_str = format!( + "{}:{}", + config.general.host, config.general.prometheus_exporter_port + ); + let http_addr = match SocketAddr::from_str(&http_addr_str) { + Ok(addr) => addr, + Err(err) => { + error!("Invalid http address: {}", err); + std::process::exit(exitcode::CONFIG); + } + }; + tokio::task::spawn(async move { + start_metric_server(http_addr).await; + }); + } - // Tracks which client is connected to which server for query cancellation. - let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); + let addr = format!("{}:{}", config.general.host, config.general.port); - // Statistics reporting. - let (stats_tx, stats_rx) = mpsc::channel(100_000); - REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); - - // Connection pool that allows to query all shards and replicas. - match ConnectionPool::from_config(client_server_map.clone()).await { - Ok(_) => (), - Err(err) => { - error!("Pool error: {:?}", err); - std::process::exit(exitcode::CONFIG); - } - }; - - tokio::task::spawn(async move { - let mut stats_collector = Collector::new(stats_rx, stats_tx.clone()); - stats_collector.collect().await; - }); - - info!("Config autoreloader: {}", config.general.autoreload); - - let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); - let autoreload_client_server_map = client_server_map.clone(); - tokio::task::spawn(async move { - loop { - autoreload_interval.tick().await; - if config.general.autoreload { - info!("Automatically reloading config"); - - if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await { - if changed { - get_config().show() - } - }; + let listener = match TcpListener::bind(&addr).await { + Ok(sock) => sock, + Err(err) => { + error!("Listener socket error: {:?}", err); + std::process::exit(exitcode::CONFIG); } - } - }); + }; - let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); - let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); - let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap(); - let (shutdown_tx, _) = broadcast::channel::<()>(1); - let (drain_tx, mut drain_rx) = mpsc::channel::(2048); - let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1); + info!("Running on {}", addr); - info!("Waiting for clients"); + config.show(); - let mut admin_only = false; - let mut total_clients = 0; + // Tracks which client is connected to which server for query cancellation. + let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); - loop { - tokio::select! { - // Reload config: - // kill -SIGHUP $(pgrep pgcat) - _ = sighup_signal.recv() => { - info!("Reloading config"); + // Statistics reporting. + let (stats_tx, stats_rx) = mpsc::channel(100_000); + REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); - _ = reload_config(client_server_map.clone()).await; + // Connection pool that allows to query all shards and replicas. + match ConnectionPool::from_config(client_server_map.clone()).await { + Ok(_) => (), + Err(err) => { + error!("Pool error: {:?}", err); + std::process::exit(exitcode::CONFIG); + } + }; - get_config().show(); - }, + tokio::task::spawn(async move { + let mut stats_collector = Collector::new(stats_rx, stats_tx.clone()); + stats_collector.collect().await; + }); - // Initiate graceful shutdown sequence on sig int - _ = interrupt_signal.recv() => { - info!("Got SIGINT, waiting for client connection drain now"); - admin_only = true; + info!("Config autoreloader: {}", config.general.autoreload); - // Broadcast that client tasks need to finish - let _ = shutdown_tx.send(()); - let exit_tx = exit_tx.clone(); - let _ = drain_tx.send(0).await; + let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); + let autoreload_client_server_map = client_server_map.clone(); + tokio::task::spawn(async move { + loop { + autoreload_interval.tick().await; + if config.general.autoreload { + info!("Automatically reloading config"); - tokio::task::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout)); - - // First tick fires immediately. - interval.tick().await; - - // Second one in the interval time. - interval.tick().await; - - // We're done waiting. - error!("Graceful shutdown timed out. {} active clients being closed", total_clients); - - let _ = exit_tx.send(()).await; - }); - }, - - _ = term_signal.recv() => { - info!("Got SIGTERM, closing with {} clients active", total_clients); - break; - }, - - new_client = listener.accept() => { - let (socket, addr) = match new_client { - Ok((socket, addr)) => (socket, addr), - Err(err) => { - error!("{:?}", err); - continue; - } - }; - - let shutdown_rx = shutdown_tx.subscribe(); - let drain_tx = drain_tx.clone(); - let client_server_map = client_server_map.clone(); - - let tls_certificate = config.general.tls_certificate.clone(); - - tokio::task::spawn(async move { - let start = chrono::offset::Utc::now().naive_utc(); - - match client::client_entrypoint( - socket, - client_server_map, - shutdown_rx, - drain_tx, - admin_only, - tls_certificate.clone(), - config.general.log_client_connections, - ) - .await - { - Ok(()) => { - - let duration = chrono::offset::Utc::now().naive_utc() - start; - - if config.general.log_client_disconnections { - info!( - "Client {:?} disconnected, session duration: {}", - addr, - format_duration(&duration) - ); - } else { - debug!( - "Client {:?} disconnected, session duration: {}", - addr, - format_duration(&duration) - ); - } - } - - Err(err) => { - match err { - errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), - _ => warn!("Client disconnected with error {:?}", err), - } - - } + if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await { + if changed { + get_config().show() + } }; - }); + } } + }); - _ = exit_rx.recv() => { - break; + let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); + let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap(); + let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap(); + let (shutdown_tx, _) = broadcast::channel::<()>(1); + let (drain_tx, mut drain_rx) = mpsc::channel::(2048); + let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1); + + info!("Waiting for clients"); + + let mut admin_only = false; + let mut total_clients = 0; + + loop { + tokio::select! { + // Reload config: + // kill -SIGHUP $(pgrep pgcat) + _ = sighup_signal.recv() => { + info!("Reloading config"); + + _ = reload_config(client_server_map.clone()).await; + + get_config().show(); + }, + + // Initiate graceful shutdown sequence on sig int + _ = interrupt_signal.recv() => { + info!("Got SIGINT, waiting for client connection drain now"); + admin_only = true; + + // Broadcast that client tasks need to finish + let _ = shutdown_tx.send(()); + let exit_tx = exit_tx.clone(); + let _ = drain_tx.send(0).await; + + tokio::task::spawn(async move { + let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout)); + + // First tick fires immediately. + interval.tick().await; + + // Second one in the interval time. + interval.tick().await; + + // We're done waiting. + error!("Graceful shutdown timed out. {} active clients being closed", total_clients); + + let _ = exit_tx.send(()).await; + }); + }, + + _ = term_signal.recv() => { + info!("Got SIGTERM, closing with {} clients active", total_clients); + break; + }, + + new_client = listener.accept() => { + let (socket, addr) = match new_client { + Ok((socket, addr)) => (socket, addr), + Err(err) => { + error!("{:?}", err); + continue; + } + }; + + let shutdown_rx = shutdown_tx.subscribe(); + let drain_tx = drain_tx.clone(); + let client_server_map = client_server_map.clone(); + + let tls_certificate = config.general.tls_certificate.clone(); + + tokio::task::spawn(async move { + let start = chrono::offset::Utc::now().naive_utc(); + + match client::client_entrypoint( + socket, + client_server_map, + shutdown_rx, + drain_tx, + admin_only, + tls_certificate.clone(), + config.general.log_client_connections, + ) + .await + { + Ok(()) => { + + let duration = chrono::offset::Utc::now().naive_utc() - start; + + if config.general.log_client_disconnections { + info!( + "Client {:?} disconnected, session duration: {}", + addr, + format_duration(&duration) + ); + } else { + debug!( + "Client {:?} disconnected, session duration: {}", + addr, + format_duration(&duration) + ); + } + } + + Err(err) => { + match err { + errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err), + _ => warn!("Client disconnected with error {:?}", err), + } + + } + }; + }); + } + + _ = exit_rx.recv() => { + break; + } + + client_ping = drain_rx.recv() => { + let client_ping = client_ping.unwrap(); + total_clients += client_ping; + + if total_clients == 0 && admin_only { + let _ = exit_tx.send(()).await; + } + } } + } - client_ping = drain_rx.recv() => { - let client_ping = client_ping.unwrap(); - total_clients += client_ping; - - if total_clients == 0 && admin_only { - let _ = exit_tx.send(()).await; - } - } - } - } - - info!("Shutting down..."); + info!("Shutting down..."); + }); + Ok(()) }