diff --git a/src/admin.rs b/src/admin.rs index 5108346..a19febd 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -2,8 +2,6 @@ use bytes::{Buf, BufMut, BytesMut}; use log::trace; use tokio::net::tcp::OwnedWriteHalf; -use std::collections::HashMap; - use crate::constants::{OID_NUMERIC, OID_TEXT}; use crate::errors::Error; use crate::messages::write_all_half; @@ -50,7 +48,7 @@ pub async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> { "avg_wait_time", ]; - let stats = get_stats().unwrap_or(HashMap::new()); + let stats = get_stats(); let mut res = BytesMut::new(); let mut row_desc = BytesMut::new(); let mut data_row = BytesMut::new(); diff --git a/src/messages.rs b/src/messages.rs index e24ad9d..8f7080f 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -91,9 +91,8 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } } -/// Parse StartupMessage parameters. -/// e.g. user, database, application_name, etc. -pub fn parse_startup(mut bytes: BytesMut) -> Result, Error> { +/// Parse the params the server sends as a key/value format. +pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { let mut result = HashMap::new(); let mut buf = Vec::new(); let mut tmp = String::new(); @@ -115,7 +114,7 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result, Err // Expect pairs of name and value // and at least one pair to be present. - if buf.len() % 2 != 0 && buf.len() >= 2 { + if buf.len() % 2 != 0 || buf.len() < 2 { return Err(Error::ClientBadStartup); } @@ -127,6 +126,14 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result, Err i += 2; } + Ok(result) +} + +/// Parse StartupMessage parameters. +/// e.g. user, database, application_name, etc. +pub fn parse_startup(bytes: BytesMut) -> Result, Error> { + let result = parse_params(bytes)?; + // Minimum required parameters // I want to have the user at the very minimum, according to the protocol spec. if !result.contains_key("user") { diff --git a/src/pool.rs b/src/pool.rs index b1dda4f..26d172b 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -126,10 +126,22 @@ impl ConnectionPool { }; let mut proxy = connection.0; - let _address = connection.1; + let address = connection.1; let server = &mut *proxy; - server_infos.push(server.server_info()); + let server_info = server.server_info(); + + if server_infos.len() > 0 { + // Compare against the last server checked. + if server_info != server_infos[server_infos.len() - 1] { + warn!( + "{:?} has different server configuration than the last server", + address + ); + } + } + + server_infos.push(server_info); } } diff --git a/src/stats.rs b/src/stats.rs index 399bee6..99f709d 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,15 +1,15 @@ -use log::{debug, error, info}; -use once_cell::sync::OnceCell; +use log::{debug, info}; +use once_cell::sync::Lazy; +use parking_lot::Mutex; use statsd::Client; -/// Events collector and publisher. use tokio::sync::mpsc::{Receiver, Sender}; use std::collections::HashMap; -use std::sync::{Arc, Mutex}; use crate::config::get_config; -static LATEST_STATS: OnceCell>>> = OnceCell::new(); +// Stats used in SHOW STATS +static LATEST_STATS: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); static STAT_PERIOD: u64 = 15000; //15 seconds #[derive(Debug, Clone, Copy)] @@ -187,16 +187,6 @@ impl Reporter { let _ = self.tx.try_send(event); } - - // pub fn flush_to_statsd(&self) { - // let event = Event { - // name: EventName::FlushStatsToStatsD, - // value: 0, - // process_id: None, - // }; - - // let _ = self.tx.try_send(event); - // } } pub struct Collector { @@ -217,13 +207,6 @@ impl Collector { pub async fn collect(&mut self) { info!("Events reporter started"); - match LATEST_STATS.set(Arc::new(Mutex::new(HashMap::new()))) { - Ok(_) => (), - Err(_) => { - error!("Latest stats will not be available"); - } - }; - let mut stats = HashMap::from([ ("total_query_count", 0), ("total_xact_count", 0), @@ -400,16 +383,10 @@ impl Collector { debug!("{:?}", stats); // Update latest stats used in SHOW STATS - match LATEST_STATS.get() { - Some(arc) => { - let mut guard = arc.lock().unwrap(); - for (key, value) in &stats { - guard.insert(key.to_string(), value.clone()); - } - } - - None => (), - }; + let mut guard = LATEST_STATS.lock(); + for (key, value) in &stats { + guard.insert(key.to_string(), value.clone()); + } let mut pipeline = self.client.pipeline(); @@ -440,13 +417,6 @@ impl Collector { } } -pub fn get_stats() -> Option> { - match LATEST_STATS.get() { - Some(arc) => { - let guard = arc.lock().unwrap(); - Some(guard.clone()) - } - - None => None, - } +pub fn get_stats() -> HashMap { + LATEST_STATS.lock().clone() }