diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 3a31240..b6bc422 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -119,6 +119,13 @@ popd start_pgcat "info" +# +# Rust tests +# +cd tests/rust +cargo run +cd ../../ + # Admin tests export PGPASSWORD=admin_pass psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null diff --git a/Cargo.lock b/Cargo.lock index e504397..28f20a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1020,7 +1020,7 @@ checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "pgcat" -version = "1.1.2-dev2" +version = "1.1.2-dev4" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 6485622..f451ffc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "1.1.2-dev2" +version = "1.1.2-dev4" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/client.rs b/src/client.rs index dd89697..bbeb526 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1149,7 +1149,7 @@ where // This reads the first byte without advancing the internal pointer and mutating the bytes let code = *message.first().unwrap() as char; - trace!("Message: {}", code); + trace!("Client message: {}", code); match code { // Query @@ -1188,6 +1188,7 @@ where }; } } + debug!("Sending query to server"); self.send_and_receive_loop( @@ -1320,6 +1321,7 @@ where { match protocol_data { ExtendedProtocolData::Parse { data, metadata } => { + debug!("Have parse in extended buffer"); let (parse, hash) = match metadata { Some(metadata) => metadata, None => { @@ -1656,11 +1658,25 @@ where ) -> Result<(), Error> { match self.prepared_statements.get(&client_name) { Some((parse, hash)) => { - debug!("Prepared statement `{}` found in cache", parse.name); + debug!("Prepared statement `{}` found in cache", client_name); // In this case we want to send the parse message to the server // since pgcat is initiating the prepared statement on this specific server - self.register_parse_to_server_cache(true, hash, parse, pool, server, address) - .await?; + match self + .register_parse_to_server_cache(true, hash, parse, pool, server, address) + .await + { + Ok(_) => (), + Err(err) => match err { + Error::PreparedStatementError => { + debug!("Removed {} from client cache", client_name); + self.prepared_statements.remove(&client_name); + } + + _ => { + return Err(err); + } + }, + } } None => { @@ -1689,11 +1705,20 @@ where // We want to promote this in the pool's LRU pool.promote_prepared_statement_hash(hash); + debug!("Checking for prepared statement {}", parse.name); + if let Err(err) = server .register_prepared_statement(parse, should_send_parse_to_server) .await { - pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats)); + match err { + // Don't ban for this. + Error::PreparedStatementError => (), + _ => { + pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats)); + } + }; + return Err(err); } diff --git a/src/errors.rs b/src/errors.rs index a6aebc5..13047b4 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -29,6 +29,7 @@ pub enum Error { QueryRouterParserError(String), QueryRouterError(String), InvalidShardId(usize), + PreparedStatementError, } #[derive(Clone, PartialEq, Debug)] diff --git a/src/server.rs b/src/server.rs index dff6a76..9089b56 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,7 +7,7 @@ use lru::LruCache; use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use postgres_protocol::message; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::mem; use std::net::IpAddr; use std::num::NonZeroUsize; @@ -325,6 +325,9 @@ pub struct Server { /// Prepared statements prepared_statement_cache: Option>, + + /// Prepared statement being currently registered on the server. + registering_prepared_statement: VecDeque, } impl Server { @@ -827,6 +830,7 @@ impl Server { NonZeroUsize::new(prepared_statement_cache_size).unwrap(), )), }, + registering_prepared_statement: VecDeque::new(), }; return Ok(server); @@ -956,7 +960,6 @@ impl Server { // There is no more data available from the server. self.data_available = false; - break; } @@ -966,6 +969,23 @@ impl Server { self.in_copy_mode = false; } + // Remove the prepared statement from the cache, it has a syntax error or something else bad happened. + if let Some(prepared_stmt_name) = + self.registering_prepared_statement.pop_front() + { + if let Some(ref mut cache) = self.prepared_statement_cache { + if let Some(_removed) = cache.pop(&prepared_stmt_name) { + debug!( + "Removed {} from prepared statement cache", + prepared_stmt_name + ); + } else { + // Shouldn't happen. + debug!("Prepared statement {} was not cached", prepared_stmt_name); + } + } + } + if self.prepared_statement_cache.is_some() { let error_message = PgErrorMsg::parse(&message)?; if error_message.message == "cached plan must not change result type" { @@ -1068,6 +1088,11 @@ impl Server { // Buffer until ReadyForQuery shows up, so don't exit the loop yet. 'c' => (), + // Parse complete successfully + '1' => { + self.registering_prepared_statement.pop_front(); + } + // Anything else, e.g. errors, notices, etc. // Keep buffering until ReadyForQuery shows up. _ => (), @@ -1107,7 +1132,7 @@ impl Server { has_it } - pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option { + fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option { let cache = match &mut self.prepared_statement_cache { Some(cache) => cache, None => return None, @@ -1129,7 +1154,7 @@ impl Server { None } - pub fn remove_prepared_statement_from_cache(&mut self, name: &str) { + fn remove_prepared_statement_from_cache(&mut self, name: &str) { let cache = match &mut self.prepared_statement_cache { Some(cache) => cache, None => return, @@ -1145,6 +1170,9 @@ impl Server { should_send_parse_to_server: bool, ) -> Result<(), Error> { if !self.has_prepared_statement(&parse.name) { + self.registering_prepared_statement + .push_back(parse.name.clone()); + let mut bytes = BytesMut::new(); if should_send_parse_to_server { @@ -1176,7 +1204,13 @@ impl Server { } }; - Ok(()) + // If it's not there, something went bad, I'm guessing bad syntax or permissions error + // on the server. + if !self.has_prepared_statement(&parse.name) { + Err(Error::PreparedStatementError) + } else { + Ok(()) + } } /// If the server is still inside a transaction. @@ -1186,6 +1220,7 @@ impl Server { self.in_transaction } + /// Currently copying data from client to server or vice-versa. pub fn in_copy_mode(&self) -> bool { self.in_copy_mode } diff --git a/tests/rust/src/main.rs b/tests/rust/src/main.rs index 79667bc..c61d48c 100644 --- a/tests/rust/src/main.rs +++ b/tests/rust/src/main.rs @@ -16,7 +16,14 @@ async fn test_prepared_statements() { let pool = pool.clone(); let handle = tokio::task::spawn(async move { for _ in 0..1000 { - sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap(); + match sqlx::query("SELECT one").fetch_all(&pool).await { + Ok(_) => (), + Err(err) => { + if err.to_string().contains("prepared statement") { + panic!("prepared statement error: {}", err); + } + } + } } });