From 880f1a649f1fced244abaeff4c3191102ec5b3a0 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 3 Feb 2022 13:54:07 -0800 Subject: [PATCH] working q --- src/client.rs | 30 ++++++++++++++++-- src/main.rs | 7 +++-- src/messages.rs | 82 +++++++++++++++++++------------------------------ 3 files changed, 62 insertions(+), 57 deletions(-) diff --git a/src/client.rs b/src/client.rs index 0142dd3..894a196 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,8 @@ /// We are pretending to be the backend. use tokio::net::TcpStream; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::io::{AsyncReadExt, BufReader}; use bytes::{BytesMut, Buf, BufMut}; @@ -10,7 +11,8 @@ use crate::errors::Error; use crate::messages::*; pub struct Client { - stream: TcpStream, + read: BufReader, + write: OwnedWriteHalf, } impl Client { @@ -50,8 +52,11 @@ impl Client { auth_ok(&mut stream).await?; ready_for_query(&mut stream).await?; + let (read, write) = stream.into_split(); + return Ok(Client { - stream: stream, + read: BufReader::new(read), + write: write, }); }, @@ -61,4 +66,23 @@ impl Client { }; } } + + pub async fn handle(&mut self) -> Result<(), Error> { + loop { + let mut message = read_message(&mut self.read).await?; + let original = message.clone(); // To be forwarded to the server + let code = message.get_u8() as char; + let len = message.get_i32() as usize; + + match code { + 'Q' => { + println!(">>> Query: {:?}", message); + }, + + _ => { + println!(">>> Unexpected code: {}", code); + }, + } + } + } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index cd6e69b..2121a13 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,8 +20,8 @@ async fn main() { }; loop { - let (mut socket, addr) = match listener.accept().await { - Ok((mut socket, addr)) => (socket, addr), + let (socket, addr) = match listener.accept().await { + Ok((socket, addr)) => (socket, addr), Err(err) => { println!("> Listener: {:?}", err); continue; @@ -33,8 +33,9 @@ async fn main() { println!(">> Client {:?} connected", addr); match client::Client::startup(socket).await { - Ok(client) => { + Ok(mut client) => { println!(">> Client {:?} connected successfully!", addr); + client.handle().await; }, Err(err) => { diff --git a/src/messages.rs b/src/messages.rs index c43da68..754226b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,59 +1,11 @@ use tokio::net::TcpStream; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use bytes::{Buf, BufMut, BytesMut}; +use tokio::net::tcp::OwnedReadHalf; +use tokio::io::{AsyncWriteExt, BufReader, AsyncReadExt}; +use bytes::{BufMut, BytesMut}; use crate::errors::Error; -/// Handle the startup phase for the client. -/// This one is special because Startup and SSLRequest -/// packages don't start with a u8 letter code. -pub async fn handle_client_startup(stream: &mut TcpStream) -> Result<(), Error> { - loop { - // Could be StartupMessage or SSLRequest - // which makes this variable length. - let len = match stream.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::ClientBadStartup), - }; - - // Read whatever is left. - let mut startup = vec![0u8; len as usize - 4]; - - match stream.read_exact(&mut startup).await { - Ok(_) => (), - Err(_) => return Err(Error::ClientBadStartup), - }; - - let mut bytes = BytesMut::from(&startup[..]); - let code = bytes.get_i32(); - - match code { - // Client wants SSL. We don't support it at the moment. - 80877103 => { - let mut no = BytesMut::with_capacity(1); - no.put_u8(b'N'); - - write_all(stream, no).await?; - }, - - // Regular startup message. - 196608 => { - // TODO: perform actual auth. - // TODO: record startup parameters client sends over. - auth_ok(stream).await?; - ready_for_query(stream).await?; - return Ok(()); - }, - - _ => { - return Err(Error::ProtocolSyncError); - } - }; - } -} - - pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { let mut auth_ok = BytesMut::with_capacity(9); @@ -79,4 +31,32 @@ pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Erro Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), } +} + +/// Read a complete message from the socket. +pub async fn read_message(stream: &mut BufReader) -> Result { + let code = match stream.read_u8().await { + Ok(code) => code, + Err(_) => return Err(Error::SocketError), + }; + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::SocketError), + }; + + let mut buf = vec![0u8; len as usize - 4]; + + match stream.read_exact(&mut buf).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + let mut bytes = BytesMut::with_capacity(len as usize + 1); + + bytes.put_u8(code); + bytes.put_i32(len); + bytes.put_slice(&buf); + + Ok(bytes) } \ No newline at end of file