diff --git a/Cargo.lock b/Cargo.lock index 75fd0c0..321690f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,30 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "async-trait" +version = "0.1.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061a7acccaa286c011ddc30970520b98fa40e00c9d644633fb26b5fc63a265e3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bb8" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e9f4fa9768efd269499d8fba693260cfc670891cf6de3adc935588447a77cc8" +dependencies = [ + "async-trait", + "futures-channel", + "futures-util", + "parking_lot", + "tokio", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -49,6 +73,41 @@ dependencies = [ "generic-array", ] +[[package]] +name = "futures-channel" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" + +[[package]] +name = "futures-task" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" + +[[package]] +name = "futures-util" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" +dependencies = [ + "futures-channel", + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.5" @@ -194,6 +253,12 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "proc-macro2" version = "1.0.36" @@ -216,6 +281,8 @@ dependencies = [ name = "rabbit" version = "0.1.0" dependencies = [ + "async-trait", + "bb8", "bytes", "md-5", "tokio", @@ -245,6 +312,12 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" + [[package]] name = "smallvec" version = "1.8.0" diff --git a/Cargo.toml b/Cargo.toml index 4d81899..06f6f00 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,4 +8,6 @@ edition = "2021" [dependencies] tokio = { version = "1", features = ["full"] } bytes = "1" -md-5 = "*" \ No newline at end of file +md-5 = "*" +bb8 = "*" +async-trait = "*" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 03a523c..0c83fb6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use tokio::io::{AsyncReadExt, BufReader}; +use tokio::io::{AsyncReadExt, BufReader, Interest}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; /// PostgreSQL client (frontend). /// We are pretending to be the backend. @@ -8,7 +8,9 @@ use bytes::{Buf, BufMut, BytesMut}; use crate::errors::Error; use crate::messages::*; -use crate::server::Server; + +use bb8::Pool; +use crate::pool::ServerPool; pub struct Client { read: BufReader, @@ -69,54 +71,77 @@ impl Client { } } - pub async fn handle(&mut self, mut server: Server) -> Result<(), Error> { + pub async fn handle(&mut self, pool: Pool) -> Result<(), Error> { + loop { - let mut message = read_message(&mut self.read).await?; - let original = message.clone(); // To be forwarded to the server - let code = message.get_u8() as char; - let _len = message.get_i32() as usize; + // Only grab a connection once we have some traffic on the socket + // TODO: this is not the most optimal way to share servers. + match self.read.get_ref().ready(Interest::READABLE).await { + Ok(_) => (), + Err(_) => return Err(Error::ClientDisconnected), + }; - match code { - 'Q' => { - server.send(original).await?; - let response = server.recv().await?; - write_all_half(&mut self.write, response).await?; - } + let mut proxy = pool.get().await.unwrap(); + let server = &mut *proxy; - 'X' => { - // Client closing - return Ok(()); - } + loop { + let mut message = read_message(&mut self.read).await?; + let original = message.clone(); // To be forwarded to the server + let code = message.get_u8() as char; + let _len = message.get_i32() as usize; - 'P' => { - // Extended protocol, let's buffer most of it - self.buffer.put(&original[..]); - } + match code { + 'Q' => { + server.send(original).await?; + let response = server.recv().await?; + write_all_half(&mut self.write, response).await?; - 'B' => { - self.buffer.put(&original[..]); - } + // Release server + if !server.in_transaction() { + break; + } + } - 'D' => { - self.buffer.put(&original[..]); - } + 'X' => { + // Client closing + return Ok(()); + } - 'E' => { - self.buffer.put(&original[..]); - } + 'P' => { + // Extended protocol, let's buffer most of it + self.buffer.put(&original[..]); + } - 'S' => { - // Extended protocol, client requests sync - self.buffer.put(&original[..]); - server.send(self.buffer.clone()).await?; - self.buffer.clear(); + 'B' => { + self.buffer.put(&original[..]); + } - let response = server.recv().await?; - write_all_half(&mut self.write, response).await?; - } + 'D' => { + self.buffer.put(&original[..]); + } - _ => { - println!(">>> Unexpected code: {}", code); + 'E' => { + self.buffer.put(&original[..]); + } + + 'S' => { + // Extended protocol, client requests sync + self.buffer.put(&original[..]); + server.send(self.buffer.clone()).await?; + self.buffer.clear(); + + let response = server.recv().await?; + write_all_half(&mut self.write, response).await?; + + // Release server + if !server.in_transaction() { + break; + } + } + + _ => { + println!(">>> Unexpected code: {}", code); + } } } } diff --git a/src/errors.rs b/src/errors.rs index ec076ee..34f7da2 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,7 +1,7 @@ #[derive(Debug, PartialEq)] pub enum Error { SocketError, - // ClientDisconnected, + ClientDisconnected, ClientBadStartup, ProtocolSyncError, ServerError, diff --git a/src/main.rs b/src/main.rs index c15f66e..d7965fd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,17 @@ extern crate bytes; extern crate md5; extern crate tokio; +extern crate async_trait; +extern crate bb8; use tokio::net::TcpListener; +use bb8::Pool; mod client; mod errors; mod messages; mod server; +mod pool; #[tokio::main] async fn main() { @@ -21,7 +25,12 @@ async fn main() { } }; + let manager = pool::ServerPool::new("127.0.0.1", "5432", "lev", "lev", "lev"); + let pool = Pool::builder().max_size(15).build(manager).await.unwrap(); + loop { + let pool = pool.clone(); + let (socket, addr) = match listener.accept().await { Ok((socket, addr)) => (socket, addr), Err(err) => { @@ -33,19 +42,14 @@ async fn main() { // Client goes to another thread, bye. tokio::task::spawn(async move { println!(">> Client {:?} connected.", addr); + + let pool = pool.clone(); match client::Client::startup(socket).await { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); - let server = - match server::Server::startup("127.0.0.1", "5432", "lev", "lev", "lev") - .await - { - Ok(server) => server, - Err(_) => return, - }; - match client.handle(server).await { + match client.handle(pool).await { Ok(()) => { println!(">> Client {:?} disconnected.", addr); } diff --git a/src/pool.rs b/src/pool.rs new file mode 100644 index 0000000..cd1cc89 --- /dev/null +++ b/src/pool.rs @@ -0,0 +1,48 @@ + +use async_trait::async_trait; +use bb8::{ManageConnection, PooledConnection}; + +use crate::server::Server; +use crate::errors::Error; + +pub struct ServerPool { + host: String, + port: String, + user: String, + password: String, + database: String, +} + +impl ServerPool { + pub fn new(host: &str, port: &str, user: &str, password: &str, database: &str) -> ServerPool { + ServerPool { + host: host.to_string(), + port: port.to_string(), + user: user.to_string(), + password: password.to_string(), + database: database.to_string(), + } + } +} + +#[async_trait] +impl ManageConnection for ServerPool { + type Connection = Server; + type Error = Error; + + /// Attempts to create a new connection. + async fn connect(&self) -> Result { + println!(">> Getting connetion from pool"); + Ok(Server::startup(&self.host, &self.port, &self.user, &self.password, &self.database).await?) + } + + /// Determines if the connection is still connected to the database. + async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { + Ok(()) + } + + /// Synchronously determine if the connection is no longer usable, if possible. + fn has_broken(&self, _conn: &mut Self::Connection) -> bool { + false + } +} \ No newline at end of file diff --git a/src/server.rs b/src/server.rs index 1e7059d..ddcbacf 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,6 +10,7 @@ pub struct Server { read: BufReader, write: OwnedWriteHalf, buffer: BytesMut, + in_transaction: bool, } impl Server { @@ -112,6 +113,7 @@ impl Server { read: BufReader::new(read), write: write, buffer: BytesMut::with_capacity(8196), + in_transaction: false, }); } @@ -135,12 +137,29 @@ impl Server { self.buffer.put(&message[..]); let code = message.get_u8() as char; + let _len = message.get_i32(); match code { 'Z' => { // Ready for query, time to forward buffer to client. + let transaction_state = message.get_u8() as char; + + match transaction_state { + 'T' => { + self.in_transaction = true; + }, + + 'I' => { + self.in_transaction = false; + }, + + _ => { + self.in_transaction = false; + }, + }; + break; - } + }, _ => { // Keep buffering, @@ -153,4 +172,8 @@ impl Server { Ok(bytes) } + + pub fn in_transaction(&self) -> bool { + self.in_transaction + } }