diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index ba3b875..d7249f1 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -29,6 +29,9 @@ healthcheck_timeout = 100 # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # Seconds +# +autoreload = true + # # User to use for authentication against the server. [user] diff --git a/Cargo.lock b/Cargo.lock index 87d11be..b8b4bfd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,9 +238,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.117" +version = "0.2.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e74d72e0f9b65b5b4ca49a346af3976df0f9c61d550727f349ecd559f251a26c" +checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" [[package]] name = "lock_api" diff --git a/pgcat.toml b/pgcat.toml index 435dda9..70b2fae 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -29,6 +29,9 @@ healthcheck_timeout = 1000 # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # Seconds +# Reload config automatically if it changes. +autoreload = false + # # User to use for authentication against the server. [user] diff --git a/src/config.rs b/src/config.rs index 96d5a77..f6fa129 100644 --- a/src/config.rs +++ b/src/config.rs @@ -101,7 +101,7 @@ impl Default for User { } /// General configuration. -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct General { pub host: String, pub port: i16, @@ -110,6 +110,7 @@ pub struct General { pub connect_timeout: u64, pub healthcheck_timeout: u64, pub ban_time: i64, + pub autoreload: bool, } impl Default for General { @@ -122,6 +123,7 @@ impl Default for General { connect_timeout: 5000, healthcheck_timeout: 1000, ban_time: 60, + autoreload: false, } } } @@ -143,7 +145,7 @@ impl Default for Shard { } /// Query Router configuration. -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct QueryRouter { pub default_role: String, pub query_parser_enabled: bool, @@ -167,7 +169,7 @@ fn default_path() -> String { } /// Configuration wrapper. -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, PartialEq)] pub struct Config { #[serde(default = "default_path")] pub path: String, @@ -374,7 +376,7 @@ pub async fn parse(path: &str) -> Result<(), Error> { Ok(()) } -pub async fn reload_config(client_server_map: ClientServerMap) -> Result<(), Error> { +pub async fn reload_config(client_server_map: ClientServerMap) -> Result { let old_config = get_config(); match parse(&old_config.path).await { @@ -387,11 +389,14 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<(), Err let new_config = get_config(); - if old_config.shards != new_config.shards { + if old_config.shards != new_config.shards || old_config.user != new_config.user { info!("Sharding configuration changed, re-creating server pools"); - ConnectionPool::from_config(client_server_map).await + ConnectionPool::from_config(client_server_map).await?; + Ok(true) + } else if old_config != new_config { + Ok(true) } else { - Ok(()) + Ok(false) } } diff --git a/src/main.rs b/src/main.rs index 70094d8..93715f4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -127,6 +127,7 @@ async fn main() { // Save these for reloading let reload_client_server_map = client_server_map.clone(); + let autoreload_client_server_map = client_server_map.clone(); let addresses = pool.databases(); tokio::task::spawn(async move { @@ -203,6 +204,26 @@ async fn main() { } }); + if config.general.autoreload { + let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000)); + + tokio::task::spawn(async move { + info!("Config autoreloader started"); + + loop { + interval.tick().await; + match reload_config(autoreload_client_server_map.clone()).await { + Ok(changed) => { + if changed { + get_config().show() + } + } + Err(_) => (), + }; + } + }); + } + // Exit on Ctrl-C (SIGINT) and SIGTERM. let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();