diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 9c3babf..56aa1dd 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -91,11 +91,13 @@ password = "sharding_user" # The maximum number of connection from a single Pgcat process to any database in the cluster # is the sum of pool_size across all users. pool_size = 9 +statement_timeout = 0 [pools.sharded_db.users.1] username = "other_user" password = "other_user" pool_size = 21 +statement_timeout = 30000 # Shard 0 [pools.sharded_db.shards.0] @@ -133,6 +135,7 @@ sharding_function = "pg_bigint_hash" username = "simple_user" password = "simple_user" pool_size = 5 +statement_timeout = 30000 [pools.simple_db.shards.0] servers = [ diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 835bd10..645ff94 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -66,6 +66,18 @@ psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_te # Replica/primary selection & more sharding tests psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null +# Statement timeout tests +sed -i 's/statement_timeout = 0/statement_timeout = 100/' .circleci/pgcat.toml +kill -SIGHUP $(pgrep pgcat) # Reload config +sleep 0.2 + +# This should timeout +(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -c 'select pg_sleep(0.5)') + +# Disable statement timeout +sed -i 's/statement_timeout = 100/statement_timeout = 0/' .circleci/pgcat.toml +kill -SIGHUP $(pgrep pgcat) # Reload config again + # # ActiveRecord tests # diff --git a/pgcat.toml b/pgcat.toml index bc246f4..2976118 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -92,10 +92,14 @@ password = "sharding_user" # is the sum of pool_size across all users. pool_size = 9 +# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way. 
+statement_timeout = 0 + [pools.sharded_db.users.1] username = "other_user" password = "other_user" pool_size = 21 +statement_timeout = 15000 # Shard 0 [pools.sharded_db.shards.0] @@ -133,6 +137,7 @@ sharding_function = "pg_bigint_hash" username = "simple_user" password = "simple_user" pool_size = 5 +statement_timeout = 0 [pools.simple_db.shards.0] servers = [ diff --git a/src/client.rs b/src/client.rs index 1dd1bcc..b36eae0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -499,7 +499,7 @@ where // The query router determines where the query is going to go, // e.g. primary, replica, which shard. let mut query_router = QueryRouter::new(); - let mut round_robin = 0; + let mut round_robin = rand::random(); // Our custom protocol loop. // We expect the client to either start a transaction with regular queries @@ -970,17 +970,54 @@ where } async fn receive_server_message( - &self, + &mut self, server: &mut Server, address: &Address, shard: usize, pool: &ConnectionPool, ) -> Result { - match server.recv().await { - Ok(message) => Ok(message), - Err(err) => { - pool.ban(address, shard, self.process_id); - Err(err) + if pool.settings.user.statement_timeout > 0 { + match tokio::time::timeout( + tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), + server.recv(), + ) + .await + { + Ok(result) => match result { + Ok(message) => Ok(message), + Err(err) => { + pool.ban(address, shard, self.process_id); + error_response_terminal( + &mut self.write, + &format!("error receiving data from server: {:?}", err), + ) + .await?; + Err(err) + } + }, + Err(_) => { + error!( + "Statement timeout while talking to {:?} with user {}", + address, pool.settings.user.username + ); + server.mark_bad(); + pool.ban(address, shard, self.process_id); + error_response_terminal(&mut self.write, "pool statement timeout").await?; + Err(Error::StatementTimeout) + } + } + } else { + match server.recv().await { + Ok(message) => Ok(message), + Err(err) => { + pool.ban(address, shard, 
self.process_id); + error_response_terminal( + &mut self.write, + &format!("error receiving data from server: {:?}", err), + ) + .await?; + Err(err) + } } } } diff --git a/src/config.rs b/src/config.rs index 57b52ae..ae006b3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -100,6 +100,7 @@ pub struct User { pub username: String, pub password: String, pub pool_size: u32, + pub statement_timeout: u64, } impl Default for User { @@ -108,6 +109,7 @@ impl Default for User { username: String::from("postgres"), password: String::new(), pool_size: 15, + statement_timeout: 0, } } } @@ -332,6 +334,7 @@ impl Config { }; for (pool_name, pool_config) in &self.pools { + // TODO: Make this output prettier (maybe a table?) info!("--- Settings for pool {} ---", pool_name); info!( "Pool size from all users: {}", @@ -346,8 +349,17 @@ impl Config { info!("Sharding function: {}", pool_config.sharding_function); info!("Primary reads: {}", pool_config.primary_reads_enabled); info!("Query router: {}", pool_config.query_parser_enabled); + + // TODO: Make this prettier. info!("Number of shards: {}", pool_config.shards.len()); info!("Number of users: {}", pool_config.users.len()); + + for user in &pool_config.users { + info!( + "{} pool size: {}, statement timeout: {}", + user.1.username, user.1.pool_size, user.1.statement_timeout + ); + } } } } diff --git a/src/errors.rs b/src/errors.rs index cc8f65d..06371fd 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -11,4 +11,5 @@ pub enum Error { AllServersDown, ClientError, TlsError, + StatementTimeout, }