diff --git a/README.md b/README.md index dfecc0a..1e77baf 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1' | **`user`** | | | | `name` | The user name. | `sharding_user` | | `password` | The user password in plaintext. | `hunter2` | +| `statement_timeout` | Timeout in milliseconds for how long a query takes to execute | `0 (disabled)` | + | | | | | **`shards`** | Shards are numerically numbered starting from 0; the order in the config is preserved by the pooler to route queries accordingly. | `[shards.0]` | | `servers` | List of servers to connect to and their roles. A server is: `[host, port, role]`, where `role` is either `primary` or `replica`. | `["127.0.0.1", 5432, "primary"]` | diff --git a/src/client.rs b/src/client.rs index 3b0b0ea..4e8556b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,11 +9,11 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; -use crate::config::{get_config, Address}; +use crate::config::{get_config, Address, PoolMode}; use crate::constants::*; use crate::errors::Error; use crate::messages::*; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode}; +use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; use crate::stats::{get_reporter, Reporter}; diff --git a/src/config.rs b/src/config.rs index e4c25a2..893f5b7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -127,6 +127,7 @@ pub struct User { pub username: String, pub password: String, pub pool_size: u32, + #[serde(default)] // 0 pub statement_timeout: u64, } @@ -144,34 +145,81 @@ impl Default for User { /// General configuration. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct General { + #[serde(default = "General::default_host")] pub host: String, + + #[serde(default = "General::default_port")] pub port: i16, + pub enable_prometheus_exporter: Option, pub prometheus_exporter_port: i16, + + #[serde(default = "General::default_connect_timeout")] pub connect_timeout: u64, - pub healthcheck_timeout: u64, + + #[serde(default = "General::default_shutdown_timeout")] pub shutdown_timeout: u64, + + #[serde(default = "General::default_healthcheck_timeout")] + pub healthcheck_timeout: u64, + + #[serde(default = "General::default_healthcheck_delay")] pub healthcheck_delay: u64, + + #[serde(default = "General::default_ban_time")] pub ban_time: i64, + + #[serde(default)] // False pub autoreload: bool, + pub tls_certificate: Option, pub tls_private_key: Option, pub admin_username: String, pub admin_password: String, } +impl General { + fn default_host() -> String { + "0.0.0.0".into() + } + + fn default_port() -> i16 { + 5432 + } + + fn default_connect_timeout() -> u64 { + 1000 + } + + fn default_shutdown_timeout() -> u64 { + 60000 + } + + fn default_healthcheck_timeout() -> u64 { + 1000 + } + + fn default_healthcheck_delay() -> u64 { + 30000 + } + + fn default_ban_time() -> i64 { + 60 + } +} + impl Default for General { fn default() -> General { General { - host: String::from("localhost"), - port: 5432, + host: General::default_host(), + port: General::default_port(), enable_prometheus_exporter: Some(false), prometheus_exporter_port: 9930, - connect_timeout: 5000, - healthcheck_timeout: 1000, - shutdown_timeout: 60000, - healthcheck_delay: 30000, - ban_time: 60, + connect_timeout: General::default_connect_timeout(), + shutdown_timeout: General::default_shutdown_timeout(), + healthcheck_timeout: General::default_healthcheck_timeout(), + healthcheck_delay: General::default_healthcheck_delay(), + ban_time: General::default_ban_time(), autoreload: false, tls_certificate: None, tls_private_key: None, @@ -180,25 +228,61 @@ impl Default for General { } } } + +/// Pool mode: +/// - transaction: server serves one transaction, +/// - session: server is attached to the client. +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Copy)] +pub enum PoolMode { + #[serde(alias = "transaction", alias = "Transaction")] + Transaction, + + #[serde(alias = "session", alias = "Session")] + Session, +} + +impl ToString for PoolMode { + fn to_string(&self) -> String { + match *self { + PoolMode::Transaction => "transaction".to_string(), + PoolMode::Session => "session".to_string(), + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Pool { - pub pool_mode: String, + #[serde(default = "Pool::default_pool_mode")] + pub pool_mode: PoolMode, + pub default_role: String, + + #[serde(default)] // False pub query_parser_enabled: bool, + + #[serde(default)] // False pub primary_reads_enabled: bool, + pub sharding_function: String, pub shards: HashMap, pub users: HashMap, } + +impl Pool { + fn default_pool_mode() -> PoolMode { + PoolMode::Transaction + } +} + impl Default for Pool { fn default() -> Pool { Pool { - pool_mode: String::from("transaction"), + pool_mode: Pool::default_pool_mode(), shards: HashMap::from([(String::from("1"), Shard::default())]), users: HashMap::default(), default_role: String::from("any"), query_parser_enabled: false, - primary_reads_enabled: true, + primary_reads_enabled: false, sharding_function: "pg_bigint_hash".to_string(), } } @@ -231,10 +315,6 @@ impl Default for Shard { } } -fn default_path() -> String { - String::from("pgcat.toml") -} - /// Configuration wrapper. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Config { @@ -249,17 +329,23 @@ pub struct Config { // [main.subconf] // field1_under_subconf = 1 // field3_under_main = 3 # This field will be interpreted as being under subconf and not under main - #[serde(default = "default_path")] + #[serde(default = "Config::default_path")] pub path: String, pub general: General, pub pools: HashMap, } +impl Config { + fn default_path() -> String { + String::from("pgcat.toml") + } +} + impl Default for Config { fn default() -> Config { Config { - path: String::from("pgcat.toml"), + path: Config::default_path(), general: General::default(), pools: HashMap::default(), } @@ -275,7 +361,7 @@ impl From<&Config> for std::collections::HashMap { [ ( format!("pools.{}.pool_mode", pool_name), - pool.pool_mode.clone(), + pool.pool_mode.to_string(), ), ( format!("pools.{}.primary_reads_enabled", pool_name), @@ -383,7 +469,10 @@ impl Config { .sum::() .to_string() ); - info!("[pool: {}] Pool mode: {}", pool_name, pool_config.pool_mode); + info!( + "[pool: {}] Pool mode: {:?}", + pool_name, pool_config.pool_mode + ); info!( "[pool: {}] Sharding function: {}", pool_name, pool_config.sharding_function @@ -513,18 +602,6 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; - match pool.pool_mode.as_ref() { - "transaction" => (), - "session" => (), - other => { - error!( - "pool_mode can be 'session' or 'transaction', got: '{}'", - other - ); - return Err(Error::BadConfig); - } - }; - for shard in &pool.shards { // We use addresses as unique identifiers, // let's make sure they are unique in the config as well. diff --git a/src/pool.rs b/src/pool.rs index f81a1e0..af64fd3 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; -use crate::config::{get_config, Address, Role, User}; +use crate::config::{get_config, Address, PoolMode, Role, User}; use crate::errors::Error; use crate::server::Server; @@ -27,24 +27,6 @@ pub type PoolMap = HashMap<(String, String), ConnectionPool>; /// The pool is recreated dynamically when the config is reloaded. pub static POOLS: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default())); -/// Pool mode: -/// - transaction: server serves one transaction, -/// - session: server is attached to the client. -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum PoolMode { - Session, - Transaction, -} - -impl std::fmt::Display for PoolMode { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match *self { - PoolMode::Session => write!(f, "session"), - PoolMode::Transaction => write!(f, "transaction"), - } - } -} - /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { @@ -199,11 +181,7 @@ impl ConnectionPool { stats: get_reporter(), server_info: BytesMut::new(), settings: PoolSettings { - pool_mode: match pool_config.pool_mode.as_str() { - "transaction" => PoolMode::Transaction, - "session" => PoolMode::Session, - _ => unreachable!(), - }, + pool_mode: pool_config.pool_mode, // shards: pool_config.shards.clone(), shards: shard_ids.len(), user: user.clone(), diff --git a/src/query_router.rs b/src/query_router.rs index 85d4f8b..34745ed 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -359,8 +359,8 @@ impl QueryRouter { #[cfg(test)] mod test { use super::*; + use crate::config::PoolMode; use crate::messages::simple_query; - use crate::pool::PoolMode; use crate::sharding::ShardingFunction; use bytes::BufMut;