From 5972b6fa525261d19ab2bcbf86aa1cf08daaa5e6 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 24 Feb 2022 08:44:41 -0800 Subject: [PATCH] Switch to parking_lot RwLock & Mutex. Use trace! for protocol instead of debug! (#42) * RwLock & parking_lot::Mutex * upgrade to trace --- Cargo.lock | 1 + Cargo.toml | 1 + src/client.rs | 22 +++++++++++----------- src/main.rs | 3 ++- src/messages.rs | 2 +- src/pool.rs | 33 +++++++++++++++++++++++---------- src/server.rs | 27 +++++++++------------------ 7 files changed, 48 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1259b4e..d6e42cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -364,6 +364,7 @@ dependencies = [ "md-5", "num_cpus", "once_cell", + "parking_lot", "rand", "regex", "serde", diff --git a/Cargo.toml b/Cargo.toml index 860955e..d070c61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ sqlparser = "0.14" log = "0.4" arc-swap = "1" env_logger = "0.9" +parking_lot = "0.11" diff --git a/src/client.rs b/src/client.rs index 0961253..3c1eeea 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,7 @@ /// We are pretending to the server in this scenario, /// and this module implements that. use bytes::{Buf, BufMut, BytesMut}; -use log::{debug, error}; +use log::{debug, error, trace}; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, @@ -70,7 +70,7 @@ impl Client { let transaction_mode = config.general.pool_mode.starts_with("t"); drop(config); loop { - debug!("Waiting for StartupMessage"); + trace!("Waiting for StartupMessage"); // Could be StartupMessage or SSLRequest // which makes this variable length. @@ -93,7 +93,7 @@ impl Client { match code { // Client wants SSL. We don't support it at the moment. SSL_REQUEST_CODE => { - debug!("Rejecting SSLRequest"); + trace!("Rejecting SSLRequest"); let mut no = BytesMut::with_capacity(1); no.put_u8(b'N'); @@ -103,7 +103,7 @@ impl Client { // Regular startup message. PROTOCOL_VERSION_NUMBER => { - debug!("Got StartupMessage"); + trace!("Got StartupMessage"); // TODO: perform actual auth. let parameters = parse_startup(bytes.clone())?; @@ -116,7 +116,7 @@ impl Client { write_all(&mut stream, server_info).await?; backend_key_data(&mut stream, process_id, secret_key).await?; ready_for_query(&mut stream).await?; - debug!("Startup OK"); + trace!("Startup OK"); // Split the read and write streams // so we can control buffering. @@ -168,10 +168,10 @@ impl Client { pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { // The client wants to cancel a query it has issued previously. if self.cancel_mode { - debug!("Sending CancelRequest"); + trace!("Sending CancelRequest"); let (process_id, secret_key, address, port) = { - let guard = self.client_server_map.lock().unwrap(); + let guard = self.client_server_map.lock(); match guard.get(&(self.process_id, self.secret_key)) { // Drop the mutex as soon as possible. @@ -202,7 +202,7 @@ impl Client { // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocols. loop { - debug!("Client idle, waiting for message"); + trace!("Client idle, waiting for message"); // Client idle, waiting for messages. self.stats.client_idle(self.process_id); @@ -216,7 +216,7 @@ impl Client { // Avoid taking a server if the client just wants to disconnect. if message[0] as char == 'X' { - debug!("Client disconnecting"); + trace!("Client disconnecting"); return Ok(()); } @@ -333,7 +333,7 @@ impl Client { let code = message.get_u8() as char; let _len = message.get_i32() as usize; - debug!("Message: {}", code); + trace!("Message: {}", code); match code { // ReadyForQuery @@ -514,7 +514,7 @@ impl Client { /// Release the server from being mine. I can't cancel its queries anymore. pub fn release(&self) { - let mut guard = self.client_server_map.lock().unwrap(); + let mut guard = self.client_server_map.lock(); guard.remove(&(self.process_id, self.secret_key)); } } diff --git a/src/main.rs b/src/main.rs index 21da9af..2cf6773 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ extern crate tokio; extern crate toml; use log::{error, info}; +use parking_lot::Mutex; use tokio::net::TcpListener; use tokio::{ signal, @@ -44,7 +45,7 @@ use tokio::{ }; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; mod client; mod config; diff --git a/src/messages.rs b/src/messages.rs index 16b2f84..e48af53 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -38,7 +38,7 @@ pub async fn backend_key_data( Ok(write_all(stream, key_data).await?) } -#[allow(dead_code)] +/// Construct a `Q`: Query message. pub fn simple_query(query: &str) -> BytesMut { let mut res = BytesMut::from(&b"Q"[..]); let query = format!("{}\0", query); diff --git a/src/pool.rs b/src/pool.rs index fb33fd8..b1dda4f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -3,7 +3,8 @@ use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; use bytes::BytesMut; use chrono::naive::NaiveDateTime; -use log::{error, info, warn}; +use log::{debug, error, info, warn}; +use parking_lot::{Mutex, RwLock}; use crate::config::{get_config, Address, Role, User}; use crate::errors::Error; @@ -11,11 +12,11 @@ use crate::server::Server; use crate::stats::Reporter; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Instant; // Banlist: bad servers go in here. -pub type BanList = Arc>>>; +pub type BanList = Arc>>>; pub type ClientServerMap = Arc>>; #[derive(Clone, Debug)] @@ -101,7 +102,7 @@ impl ConnectionPool { databases: shards, addresses: addresses, round_robin: rand::random::() % address_len, // Start at a random replica - banlist: Arc::new(Mutex::new(banlist)), + banlist: Arc::new(RwLock::new(banlist)), stats: stats, } } @@ -161,6 +162,8 @@ impl ConnectionPool { _ => addresses.len(), }; + debug!("Allowed attempts for {:?}: {}", role, allowed_attempts); + let exists = match role { Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0, None => true, @@ -251,14 +254,14 @@ impl ConnectionPool { pub fn ban(&self, address: &Address, shard: usize) { error!("Banning {:?}", address); let now = chrono::offset::Utc::now().naive_utc(); - let mut guard = self.banlist.lock().unwrap(); + let mut guard = self.banlist.write(); guard[shard].insert(address.clone(), now); } /// Clear the replica to receive traffic again. Takes effect immediately /// for all new transactions. pub fn _unban(&self, address: &Address, shard: usize) { - let mut guard = self.banlist.lock().unwrap(); + let mut guard = self.banlist.write(); guard[shard].remove(address); } @@ -274,12 +277,14 @@ impl ConnectionPool { Some(Role::Primary) => return false, // Primary cannot be banned. }; - // If you're not asking for the primary, - // all databases are treated as replicas. - let mut guard = self.banlist.lock().unwrap(); + debug!("Available targets for {:?}: {}", role, replicas_available); + + let guard = self.banlist.read(); // Everything is banned = nothing is banned. if guard[shard].len() == replicas_available { + drop(guard); + let mut guard = self.banlist.write(); guard[shard].clear(); drop(guard); warn!("Unbanning all replicas."); @@ -291,16 +296,24 @@ impl ConnectionPool { Some(timestamp) => { let now = chrono::offset::Utc::now().naive_utc(); let config = get_config(); + // Ban expired. if now.timestamp() - timestamp.timestamp() > config.general.ban_time { + drop(guard); + warn!("Unbanning {:?}", address); + let mut guard = self.banlist.write(); guard[shard].remove(address); false } else { + debug!("{:?} is banned", address); true } } - None => false, + None => { + debug!("{:?} is ok", address); + false + } } } diff --git a/src/server.rs b/src/server.rs index f186c0c..934fc95 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,7 @@ use bytes::{Buf, BufMut, BytesMut}; ///! Implementation of the PostgreSQL server (database) protocol. ///! Here we are pretending to the a Postgres client. -use log::{debug, error, info}; +use log::{debug, error, info, trace}; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, @@ -75,7 +75,7 @@ impl Server { } }; - debug!("Sending StartupMessage"); + trace!("Sending StartupMessage"); // Send the startup packet telling the server we're a normal Postgres client. startup(&mut stream, &user.name, database).await?; @@ -97,7 +97,7 @@ impl Server { Err(_) => return Err(Error::SocketError), }; - debug!("Message: {}", code); + trace!("Message: {}", code); match code { // Authentication @@ -108,7 +108,7 @@ impl Server { Err(_) => return Err(Error::SocketError), }; - debug!("Auth: {}", auth_code); + trace!("Auth: {}", auth_code); match auth_code { MD5_ENCRYPTED_PASSWORD => { @@ -141,7 +141,7 @@ impl Server { Err(_) => return Err(Error::SocketError), }; - debug!("Error: {}", error_code); + trace!("Error: {}", error_code); match error_code { // No error message is present in the message. @@ -300,7 +300,7 @@ impl Server { let code = message.get_u8() as char; let _len = message.get_i32(); - debug!("Message: {}", code); + trace!("Message: {}", code); match code { // ReadyForQuery @@ -415,7 +415,7 @@ impl Server { /// Claim this server as mine for the purposes of query cancellation. pub fn claim(&mut self, process_id: i32, secret_key: i32) { - let mut guard = self.client_server_map.lock().unwrap(); + let mut guard = self.client_server_map.lock(); guard.insert( (process_id, secret_key), ( @@ -431,18 +431,9 @@ impl Server { /// It will use the simple query protocol. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. pub async fn query(&mut self, query: &str) -> Result<(), Error> { - let mut query = BytesMut::from(&query.as_bytes()[..]); - query.put_u8(0); // C-string terminator (NULL character). + let query = simple_query(query); - let len = query.len() as i32 + 4; - - let mut msg = BytesMut::with_capacity(len as usize + 1); - - msg.put_u8(b'Q'); - msg.put_i32(len); - msg.put_slice(&query[..]); - - self.send(msg).await?; + self.send(query).await?; loop { let _ = self.recv().await?;