From 6deb7b1162d5302e2e6c8d871954bd7a8aabd98c Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 3 Feb 2022 17:06:19 -0800 Subject: [PATCH] name servers & dont leave open servers with bad state --- Cargo.lock | 64 +++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 ++- src/client.rs | 46 ++++++++++++++++++++++++++++++----- src/main.rs | 10 ++++---- src/pool.rs | 20 ++++++++++------ src/server.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++------- 6 files changed, 182 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 321690f..9f87edc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418d37c8b1d42553c93648be529cb70f920d3baf8ef469b74b9638df426e0b4c" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -259,6 +270,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "proc-macro2" version = "1.0.36" @@ -285,9 +302,50 @@ dependencies = [ "bb8", "bytes", "md-5", + "rand", "tokio", ] +[[package]] +name = "rand" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", + "rand_hc", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_hc" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7" +dependencies = [ + "rand_core", +] + [[package]] name = "redox_syscall" version = "0.2.10" @@ -383,6 +441,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wasi" +version = "0.10.2+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 06f6f00..28a079d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,4 +10,5 @@ tokio = { version = "1", features = ["full"] } bytes = "1" md-5 = "*" bb8 = "*" -async-trait = "*" \ No newline at end of file +async-trait = "*" +rand = "*" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 0c83fb6..c2e5a13 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,13 +9,15 @@ use bytes::{Buf, BufMut, BytesMut}; use crate::errors::Error; use crate::messages::*; -use bb8::Pool; use crate::pool::ServerPool; +use bb8::Pool; +use rand::{distributions::Alphanumeric, Rng}; pub struct Client { read: BufReader, write: OwnedWriteHalf, buffer: BytesMut, + name: String, } impl Client { @@ -57,10 +59,17 @@ impl Client { let (read, write) = stream.into_split(); + let name: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + return Ok(Client { read: BufReader::new(read), write: write, buffer: BytesMut::with_capacity(8196), + name: name, }); } @@ -72,7 +81,6 @@ impl Client { } pub async fn handle(&mut self, pool: Pool) -> Result<(), Error> { - loop { // Only grab a connection once we have some traffic on the socket // TODO: this is not the most optimal way to share servers. @@ -84,8 +92,20 @@ impl Client { let mut proxy = pool.get().await.unwrap(); let server = &mut *proxy; + server.set_name(&self.name).await?; + loop { - let mut message = read_message(&mut self.read).await?; + let mut message = match read_message(&mut self.read).await { + Ok(message) => message, + Err(err) => { + if server.in_transaction() { + server.mark_bad(); + } + + return Err(err); + } + }; + let original = message.clone(); // To be forwarded to the server let code = message.get_u8() as char; let _len = message.get_i32() as usize; @@ -94,10 +114,17 @@ impl Client { 'Q' => { server.send(original).await?; let response = server.recv().await?; - write_all_half(&mut self.write, response).await?; + match write_all_half(&mut self.write, response).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; - // Release server + // Release server if !server.in_transaction() { + drop(server); break; } } @@ -131,10 +158,17 @@ impl Client { self.buffer.clear(); let response = server.recv().await?; - write_all_half(&mut self.write, response).await?; + match write_all_half(&mut self.write, response).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; // Release server if !server.in_transaction() { + drop(server); break; } } diff --git a/src/main.rs b/src/main.rs index d7965fd..1a0ad67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,17 @@ +extern crate async_trait; +extern crate bb8; extern crate bytes; extern crate md5; extern crate tokio; -extern crate async_trait; -extern crate bb8; -use tokio::net::TcpListener; use bb8::Pool; +use tokio::net::TcpListener; mod client; mod errors; mod messages; -mod server; mod pool; +mod server; #[tokio::main] async fn main() { @@ -42,7 +42,7 @@ async fn main() { // Client goes to another thread, bye. tokio::task::spawn(async move { println!(">> Client {:?} connected.", addr); - + let pool = pool.clone(); match client::Client::startup(socket).await { diff --git a/src/pool.rs b/src/pool.rs index cd1cc89..11ca9a8 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,9 +1,8 @@ - use async_trait::async_trait; use bb8::{ManageConnection, PooledConnection}; -use crate::server::Server; use crate::errors::Error; +use crate::server::Server; pub struct ServerPool { host: String, @@ -30,10 +29,17 @@ impl ManageConnection for ServerPool { type Connection = Server; type Error = Error; - /// Attempts to create a new connection. + /// Attempts to create a new connection. async fn connect(&self) -> Result { println!(">> Getting connetion from pool"); - Ok(Server::startup(&self.host, &self.port, &self.user, &self.password, &self.database).await?) + Ok(Server::startup( + &self.host, + &self.port, + &self.user, + &self.password, + &self.database, + ) + .await?) } /// Determines if the connection is still connected to the database. @@ -42,7 +48,7 @@ impl ManageConnection for ServerPool { } /// Synchronously determine if the connection is no longer usable, if possible. - fn has_broken(&self, _conn: &mut Self::Connection) -> bool { - false + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.is_bad() } -} \ No newline at end of file +} diff --git a/src/server.rs b/src/server.rs index ddcbacf..7c4eced 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,6 +11,7 @@ pub struct Server { write: OwnedWriteHalf, buffer: BytesMut, in_transaction: bool, + bad: bool, } impl Server { @@ -114,6 +115,7 @@ impl Server { write: write, buffer: BytesMut::with_capacity(8196), in_transaction: false, + bad: false, }); } @@ -126,12 +128,26 @@ impl Server { } pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { - Ok(write_all_half(&mut self.write, messages).await?) + match write_all_half(&mut self.write, messages).await { + Ok(_) => Ok(()), + Err(err) => { + println!(">> Terminating server because of: {:?}", err); + self.bad = true; + Err(err) + } + } } pub async fn recv(&mut self) -> Result { loop { - let mut message = read_message(&mut self.read).await?; + let mut message = match read_message(&mut self.read).await { + Ok(message) => message, + Err(err) => { + println!(">> Terminating server because of: {:?}", err); + self.bad = true; + return Err(err); + } + }; // Buffer the message we'll forward to the client in a bit. self.buffer.put(&message[..]); @@ -143,23 +159,30 @@ impl Server { 'Z' => { // Ready for query, time to forward buffer to client. let transaction_state = message.get_u8() as char; - + match transaction_state { 'T' => { self.in_transaction = true; - }, + } 'I' => { self.in_transaction = false; - }, + } + + // Error client didn't clean up! + // We shuold drop this server + 'E' => { + self.bad = true; + } _ => { - self.in_transaction = false; - }, + self.bad = true; + return Err(Error::ProtocolSyncError); + } }; break; - }, + } _ => { // Keep buffering, @@ -176,4 +199,31 @@ impl Server { pub fn in_transaction(&self) -> bool { self.in_transaction } + + pub fn is_bad(&self) -> bool { + self.bad + } + + pub fn mark_bad(&mut self) { + println!(">> Server marked bad"); + self.bad = true; + } + + pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { + let mut query = BytesMut::from(&format!("SET application_name = {}", name).as_bytes()[..]); + query.put_u8(0); + + let len = query.len() as i32 + 4; + + let mut msg = BytesMut::with_capacity(len as usize + 1); + + msg.put_u8(b'Q'); + msg.put_i32(len); + msg.put_slice(&query[..]); + + self.send(msg).await?; + let _ = self.recv().await?; + + Ok(()) + } }