diff --git a/Cargo.lock b/Cargo.lock index 4cc882b..75fd0c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1d36a02058e76b040de25a4464ba1c80935655595b661505c8b39b664828b95" +dependencies = [ + "generic-array", +] + [[package]] name = "bytes" version = "1.1.0" @@ -20,6 +29,36 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "crypto-common" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0" +dependencies = [ + "generic-array", +] + +[[package]] +name = "digest" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b" +dependencies = [ + "block-buffer", + "crypto-common", + "generic-array", +] + +[[package]] +name = "generic-array" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "hermit-abi" version = "0.1.19" @@ -62,6 +101,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "md-5" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a38fc55c8bbc10058782919516f88826e70320db6d206aebc49611d24216ae" +dependencies = [ + "digest", +] + [[package]] name = "memchr" version = "2.4.1" @@ -169,6 +217,7 @@ name = "rabbit" version = "0.1.0" dependencies = [ "bytes", + "md-5", "tokio", ] @@ -243,12 +292,24 @@ dependencies = [ "syn", ] +[[package]] +name = "typenum" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" + [[package]] name = "unicode-xid" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index d3772c2..4d81899 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,5 @@ edition = "2021" [dependencies] tokio = { version = "1", features = ["full"] } -bytes = "1" \ No newline at end of file +bytes = "1" +md-5 = "*" \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 894a196..6cebeea 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,14 +1,14 @@ +use tokio::io::{AsyncReadExt, BufReader}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; /// PostgreSQL client (frontend). /// We are pretending to be the backend. - use tokio::net::TcpStream; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::io::{AsyncReadExt, BufReader}; -use bytes::{BytesMut, Buf, BufMut}; +use bytes::{Buf, BufMut, BytesMut}; use crate::errors::Error; use crate::messages::*; +use crate::server::Server; pub struct Client { read: BufReader, @@ -43,7 +43,7 @@ impl Client { no.put_u8(b'N'); write_all(&mut stream, no).await?; - }, + } // Regular startup message. 196608 => { @@ -58,7 +58,7 @@ impl Client { read: BufReader::new(read), write: write, }); - }, + } _ => { return Err(Error::ProtocolSyncError); @@ -67,22 +67,29 @@ impl Client { } } - pub async fn handle(&mut self) -> Result<(), Error> { + pub async fn handle(&mut self, mut server: Server) -> Result<(), Error> { loop { let mut message = read_message(&mut self.read).await?; let original = message.clone(); // To be forwarded to the server let code = message.get_u8() as char; - let len = message.get_i32() as usize; + let _len = message.get_i32() as usize; match code { 'Q' => { - println!(">>> Query: {:?}", message); - }, + server.send(original).await?; + let response = server.recv().await?; + write_all_half(&mut self.write, response).await?; + } + + 'X' => { + // Client closing + return Ok(()); + } _ => { println!(">>> Unexpected code: {}", code); - }, + } } } } -} \ No newline at end of file +} diff --git a/src/errors.rs b/src/errors.rs index 4e3e026..ec076ee 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,7 +1,8 @@ #[derive(Debug, PartialEq)] pub enum Error { SocketError, - ClientDisconneted, + // ClientDisconnected, ClientBadStartup, ProtocolSyncError, -} \ No newline at end of file + ServerError, +} diff --git a/src/main.rs b/src/main.rs index 2121a13..c15f66e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,13 @@ extern crate bytes; +extern crate md5; extern crate tokio; use tokio::net::TcpListener; +mod client; mod errors; mod messages; -mod client; +mod server; #[tokio::main] async fn main() { @@ -30,13 +32,29 @@ async fn main() { // Client goes to another thread, bye. tokio::task::spawn(async move { - println!(">> Client {:?} connected", addr); + println!(">> Client {:?} connected.", addr); match client::Client::startup(socket).await { Ok(mut client) => { - println!(">> Client {:?} connected successfully!", addr); - client.handle().await; - }, + println!(">> Client {:?} authenticated successfully!", addr); + let server = + match server::Server::startup("127.0.0.1", "5432", "lev", "lev", "lev") + .await + { + Ok(server) => server, + Err(_) => return, + }; + + match client.handle(server).await { + Ok(()) => { + println!(">> Client {:?} disconnected.", addr); + } + + Err(err) => { + println!(">> Client disconnected with error: {:?}", err); + } + } + } Err(err) => { println!(">> Error: {:?}", err); diff --git a/src/messages.rs b/src/messages.rs index 754226b..86ad681 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,8 +1,8 @@ - -use tokio::net::TcpStream; -use tokio::net::tcp::OwnedReadHalf; -use tokio::io::{AsyncWriteExt, BufReader, AsyncReadExt}; use bytes::{BufMut, BytesMut}; +use md5::{Digest, Md5}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::TcpStream; use crate::errors::Error; @@ -26,6 +26,67 @@ pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> { Ok(write_all(stream, bytes).await?) } +pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { + let mut bytes = BytesMut::with_capacity(25); + + bytes.put_i32(196608); // Protocol number + + // User + bytes.put(&b"user\0"[..]); + bytes.put_slice(&user.as_bytes()); + bytes.put_u8(0); + + // Database + bytes.put(&b"database\0"[..]); + bytes.put_slice(&database.as_bytes()); + bytes.put_u8(0); + bytes.put_u8(0); // Null terminator + + let len = bytes.len() as i32 + 4i32; + + let mut startup = BytesMut::with_capacity(len as usize); + + startup.put_i32(len); + startup.put(bytes); + + match stream.write_all(&startup).await { + Ok(_) => Ok(()), + Err(_) => return Err(Error::SocketError), + } +} + +pub async fn md5_password( + stream: &mut TcpStream, + user: &str, + password: &str, + salt: &[u8], +) -> Result<(), Error> { + let mut md5 = Md5::new(); + + // First pass + md5.update(&password.as_bytes()); + md5.update(&user.as_bytes()); + + let output = md5.finalize_reset(); + + // Second pass + md5.update(format!("{:x}", output)); + md5.update(salt); + + let mut password = format!("md5{:x}", md5.finalize()) + .chars() + .map(|x| x as u8) + .collect::>(); + password.push(0); + + let mut message = BytesMut::with_capacity(password.len() as usize + 5); + message.put_u8(b'p'); + message.put_i32(password.len() as i32 + 4); + message.put_slice(&password[..]); + + Ok(write_all(stream, message).await?) +} + pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> { match stream.write_all(&buf).await { Ok(_) => Ok(()), @@ -33,6 +94,13 @@ pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Erro } } +pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> { + match stream.write_all(&buf).await { + Ok(_) => Ok(()), + Err(_) => return Err(Error::SocketError), + } +} + /// Read a complete message from the socket. pub async fn read_message(stream: &mut BufReader) -> Result { let code = match stream.read_u8().await { @@ -59,4 +127,4 @@ pub async fn read_message(stream: &mut BufReader) -> Result, + write: OwnedWriteHalf, + buffer: BytesMut, +} + +impl Server { + pub async fn startup( + host: &str, + port: &str, + user: &str, + password: &str, + database: &str, + ) -> Result { + let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await { + Ok(stream) => stream, + Err(err) => { + println!(">> Could not connect to server: {}", err); + return Err(Error::SocketError); + } + }; + + startup(&mut stream, user, database).await?; + + loop { + let code = match stream.read_u8().await { + Ok(code) => code as char, + Err(_) => return Err(Error::SocketError), + }; + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::SocketError), + }; + + match code { + 'R' => { + // Auth can proceed + let code = match stream.read_i32().await { + Ok(code) => code, + Err(_) => return Err(Error::SocketError), + }; + + match code { + // MD5 + 5 => { + let mut salt = vec![0u8; 4]; + + match stream.read_exact(&mut salt).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + md5_password(&mut stream, user, password, &salt[..]).await?; + } + + // We're in! + 0 => { + println!(">> Server authentication successful!"); + } + + _ => { + println!(">> Unsupported authentication mechanism: {}", code); + return Err(Error::ServerError); + } + } + } + + 'E' => { + println!(">> Database error"); + return Err(Error::ServerError); + } + + 'S' => { + // Parameter + let mut param = vec![0u8; len as usize - 4]; + match stream.read_exact(&mut param).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + } + + 'K' => { + // TODO: save cancellation secret + let mut cancel_secret = vec![0u8; len as usize - 4]; + match stream.read_exact(&mut cancel_secret).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + } + + 'Z' => { + let mut idle = vec![0u8; len as usize - 4]; + + match stream.read_exact(&mut idle).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + // Startup finished + let (read, write) = stream.into_split(); + + return Ok(Server { + read: BufReader::new(read), + write: write, + buffer: BytesMut::with_capacity(8196), + }); + } + + _ => { + println!(">> Unknown code: {}", code); + return Err(Error::ProtocolSyncError); + } + }; + } + } + + pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { + Ok(write_all_half(&mut self.write, messages).await?) + } + + pub async fn recv(&mut self) -> Result { + loop { + let mut message = read_message(&mut self.read).await?; + + // Buffer the message we'll forward to the client in a bit. + self.buffer.put(&message[..]); + + let code = message.get_u8() as char; + match code { + 'Z' => { + // Ready for query, time to forward buffer to client. + break; + } + + _ => { + // Keep buffering, + } + }; + } + + let bytes = self.buffer.clone(); + self.buffer.clear(); + + Ok(bytes) + } + + // pub async fn handle(&mut self) -> Result<(), Error> { + // loop { + // let message = read_message(&mut self.read).await?; + // let original = message.clone(); + + // let code = message.get_u8() as char; + // let len = message.get_i32(); + + // match code { + // ' + // } + // } + // } +}