Switch to parking_lot RwLock & Mutex. Use trace! for protocol instead of debug! (#42)

* RwLock & parking_lot::Mutex

* upgrade to trace
This commit is contained in:
Lev Kokotov
2022-02-24 08:44:41 -08:00
committed by GitHub
parent b3c8ca4b8a
commit 5972b6fa52
7 changed files with 48 additions and 41 deletions

1
Cargo.lock generated
View File

@@ -364,6 +364,7 @@ dependencies = [
"md-5", "md-5",
"num_cpus", "num_cpus",
"once_cell", "once_cell",
"parking_lot",
"rand", "rand",
"regex", "regex",
"serde", "serde",

View File

@@ -25,3 +25,4 @@ sqlparser = "0.14"
log = "0.4" log = "0.4"
arc-swap = "1" arc-swap = "1"
env_logger = "0.9" env_logger = "0.9"
parking_lot = "0.11"

View File

@@ -2,7 +2,7 @@
/// We are pretending to the server in this scenario, /// We are pretending to the server in this scenario,
/// and this module implements that. /// and this module implements that.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error}; use log::{debug, error, trace};
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
@@ -70,7 +70,7 @@ impl Client {
let transaction_mode = config.general.pool_mode.starts_with("t"); let transaction_mode = config.general.pool_mode.starts_with("t");
drop(config); drop(config);
loop { loop {
debug!("Waiting for StartupMessage"); trace!("Waiting for StartupMessage");
// Could be StartupMessage or SSLRequest // Could be StartupMessage or SSLRequest
// which makes this variable length. // which makes this variable length.
@@ -93,7 +93,7 @@ impl Client {
match code { match code {
// Client wants SSL. We don't support it at the moment. // Client wants SSL. We don't support it at the moment.
SSL_REQUEST_CODE => { SSL_REQUEST_CODE => {
debug!("Rejecting SSLRequest"); trace!("Rejecting SSLRequest");
let mut no = BytesMut::with_capacity(1); let mut no = BytesMut::with_capacity(1);
no.put_u8(b'N'); no.put_u8(b'N');
@@ -103,7 +103,7 @@ impl Client {
// Regular startup message. // Regular startup message.
PROTOCOL_VERSION_NUMBER => { PROTOCOL_VERSION_NUMBER => {
debug!("Got StartupMessage"); trace!("Got StartupMessage");
// TODO: perform actual auth. // TODO: perform actual auth.
let parameters = parse_startup(bytes.clone())?; let parameters = parse_startup(bytes.clone())?;
@@ -116,7 +116,7 @@ impl Client {
write_all(&mut stream, server_info).await?; write_all(&mut stream, server_info).await?;
backend_key_data(&mut stream, process_id, secret_key).await?; backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?; ready_for_query(&mut stream).await?;
debug!("Startup OK"); trace!("Startup OK");
// Split the read and write streams // Split the read and write streams
// so we can control buffering. // so we can control buffering.
@@ -168,10 +168,10 @@ impl Client {
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously. // The client wants to cancel a query it has issued previously.
if self.cancel_mode { if self.cancel_mode {
debug!("Sending CancelRequest"); trace!("Sending CancelRequest");
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
let guard = self.client_server_map.lock().unwrap(); let guard = self.client_server_map.lock();
match guard.get(&(self.process_id, self.secret_key)) { match guard.get(&(self.process_id, self.secret_key)) {
// Drop the mutex as soon as possible. // Drop the mutex as soon as possible.
@@ -202,7 +202,7 @@ impl Client {
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocols. // or issue commands for our sharding and server selection protocols.
loop { loop {
debug!("Client idle, waiting for message"); trace!("Client idle, waiting for message");
// Client idle, waiting for messages. // Client idle, waiting for messages.
self.stats.client_idle(self.process_id); self.stats.client_idle(self.process_id);
@@ -216,7 +216,7 @@ impl Client {
// Avoid taking a server if the client just wants to disconnect. // Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' { if message[0] as char == 'X' {
debug!("Client disconnecting"); trace!("Client disconnecting");
return Ok(()); return Ok(());
} }
@@ -333,7 +333,7 @@ impl Client {
let code = message.get_u8() as char; let code = message.get_u8() as char;
let _len = message.get_i32() as usize; let _len = message.get_i32() as usize;
debug!("Message: {}", code); trace!("Message: {}", code);
match code { match code {
// ReadyForQuery // ReadyForQuery
@@ -514,7 +514,7 @@ impl Client {
/// Release the server from being mine. I can't cancel its queries anymore. /// Release the server from being mine. I can't cancel its queries anymore.
pub fn release(&self) { pub fn release(&self) {
let mut guard = self.client_server_map.lock().unwrap(); let mut guard = self.client_server_map.lock();
guard.remove(&(self.process_id, self.secret_key)); guard.remove(&(self.process_id, self.secret_key));
} }
} }

View File

@@ -36,6 +36,7 @@ extern crate tokio;
extern crate toml; extern crate toml;
use log::{error, info}; use log::{error, info};
use parking_lot::Mutex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::{ use tokio::{
signal, signal,
@@ -44,7 +45,7 @@ use tokio::{
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
mod client; mod client;
mod config; mod config;

View File

@@ -38,7 +38,7 @@ pub async fn backend_key_data(
Ok(write_all(stream, key_data).await?) Ok(write_all(stream, key_data).await?)
} }
#[allow(dead_code)] /// Construct a `Q`: Query message.
pub fn simple_query(query: &str) -> BytesMut { pub fn simple_query(query: &str) -> BytesMut {
let mut res = BytesMut::from(&b"Q"[..]); let mut res = BytesMut::from(&b"Q"[..]);
let query = format!("{}\0", query); let query = format!("{}\0", query);

View File

@@ -3,7 +3,8 @@ use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection}; use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut; use bytes::BytesMut;
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use log::{error, info, warn}; use log::{debug, error, info, warn};
use parking_lot::{Mutex, RwLock};
use crate::config::{get_config, Address, Role, User}; use crate::config::{get_config, Address, Role, User};
use crate::errors::Error; use crate::errors::Error;
@@ -11,11 +12,11 @@ use crate::server::Server;
use crate::stats::Reporter; use crate::stats::Reporter;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
// Banlist: bad servers go in here. // Banlist: bad servers go in here.
pub type BanList = Arc<Mutex<Vec<HashMap<Address, NaiveDateTime>>>>; pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -101,7 +102,7 @@ impl ConnectionPool {
databases: shards, databases: shards,
addresses: addresses, addresses: addresses,
round_robin: rand::random::<usize>() % address_len, // Start at a random replica round_robin: rand::random::<usize>() % address_len, // Start at a random replica
banlist: Arc::new(Mutex::new(banlist)), banlist: Arc::new(RwLock::new(banlist)),
stats: stats, stats: stats,
} }
} }
@@ -161,6 +162,8 @@ impl ConnectionPool {
_ => addresses.len(), _ => addresses.len(),
}; };
debug!("Allowed attempts for {:?}: {}", role, allowed_attempts);
let exists = match role { let exists = match role {
Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0, Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
None => true, None => true,
@@ -251,14 +254,14 @@ impl ConnectionPool {
pub fn ban(&self, address: &Address, shard: usize) { pub fn ban(&self, address: &Address, shard: usize) {
error!("Banning {:?}", address); error!("Banning {:?}", address);
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.write();
guard[shard].insert(address.clone(), now); guard[shard].insert(address.clone(), now);
} }
/// Clear the replica to receive traffic again. Takes effect immediately /// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions. /// for all new transactions.
pub fn _unban(&self, address: &Address, shard: usize) { pub fn _unban(&self, address: &Address, shard: usize) {
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.write();
guard[shard].remove(address); guard[shard].remove(address);
} }
@@ -274,12 +277,14 @@ impl ConnectionPool {
Some(Role::Primary) => return false, // Primary cannot be banned. Some(Role::Primary) => return false, // Primary cannot be banned.
}; };
// If you're not asking for the primary, debug!("Available targets for {:?}: {}", role, replicas_available);
// all databases are treated as replicas.
let mut guard = self.banlist.lock().unwrap(); let guard = self.banlist.read();
// Everything is banned = nothing is banned. // Everything is banned = nothing is banned.
if guard[shard].len() == replicas_available { if guard[shard].len() == replicas_available {
drop(guard);
let mut guard = self.banlist.write();
guard[shard].clear(); guard[shard].clear();
drop(guard); drop(guard);
warn!("Unbanning all replicas."); warn!("Unbanning all replicas.");
@@ -291,16 +296,24 @@ impl ConnectionPool {
Some(timestamp) => { Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
let config = get_config(); let config = get_config();
// Ban expired. // Ban expired.
if now.timestamp() - timestamp.timestamp() > config.general.ban_time { if now.timestamp() - timestamp.timestamp() > config.general.ban_time {
drop(guard);
warn!("Unbanning {:?}", address);
let mut guard = self.banlist.write();
guard[shard].remove(address); guard[shard].remove(address);
false false
} else { } else {
debug!("{:?} is banned", address);
true true
} }
} }
None => false, None => {
debug!("{:?} is ok", address);
false
}
} }
} }

View File

@@ -1,7 +1,7 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
///! Implementation of the PostgreSQL server (database) protocol. ///! Implementation of the PostgreSQL server (database) protocol.
///! Here we are pretending to the a Postgres client. ///! Here we are pretending to the a Postgres client.
use log::{debug, error, info}; use log::{debug, error, info, trace};
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
@@ -75,7 +75,7 @@ impl Server {
} }
}; };
debug!("Sending StartupMessage"); trace!("Sending StartupMessage");
// Send the startup packet telling the server we're a normal Postgres client. // Send the startup packet telling the server we're a normal Postgres client.
startup(&mut stream, &user.name, database).await?; startup(&mut stream, &user.name, database).await?;
@@ -97,7 +97,7 @@ impl Server {
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
debug!("Message: {}", code); trace!("Message: {}", code);
match code { match code {
// Authentication // Authentication
@@ -108,7 +108,7 @@ impl Server {
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
debug!("Auth: {}", auth_code); trace!("Auth: {}", auth_code);
match auth_code { match auth_code {
MD5_ENCRYPTED_PASSWORD => { MD5_ENCRYPTED_PASSWORD => {
@@ -141,7 +141,7 @@ impl Server {
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
debug!("Error: {}", error_code); trace!("Error: {}", error_code);
match error_code { match error_code {
// No error message is present in the message. // No error message is present in the message.
@@ -300,7 +300,7 @@ impl Server {
let code = message.get_u8() as char; let code = message.get_u8() as char;
let _len = message.get_i32(); let _len = message.get_i32();
debug!("Message: {}", code); trace!("Message: {}", code);
match code { match code {
// ReadyForQuery // ReadyForQuery
@@ -415,7 +415,7 @@ impl Server {
/// Claim this server as mine for the purposes of query cancellation. /// Claim this server as mine for the purposes of query cancellation.
pub fn claim(&mut self, process_id: i32, secret_key: i32) { pub fn claim(&mut self, process_id: i32, secret_key: i32) {
let mut guard = self.client_server_map.lock().unwrap(); let mut guard = self.client_server_map.lock();
guard.insert( guard.insert(
(process_id, secret_key), (process_id, secret_key),
( (
@@ -431,18 +431,9 @@ impl Server {
/// It will use the simple query protocol. /// It will use the simple query protocol.
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> { pub async fn query(&mut self, query: &str) -> Result<(), Error> {
let mut query = BytesMut::from(&query.as_bytes()[..]); let query = simple_query(query);
query.put_u8(0); // C-string terminator (NULL character).
let len = query.len() as i32 + 4; self.send(query).await?;
let mut msg = BytesMut::with_capacity(len as usize + 1);
msg.put_u8(b'Q');
msg.put_i32(len);
msg.put_slice(&query[..]);
self.send(msg).await?;
loop { loop {
let _ = self.recv().await?; let _ = self.recv().await?;