diff --git a/pgcat.toml b/pgcat.toml index 283ed74..df2ba71 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -66,7 +66,7 @@ tcp_keepalives_interval = 5 # tls_private_key = ".circleci/server.key" # Enable/disable server TLS -server_tls = true +server_tls = false # Verify server certificate is completely authentic. verify_server_certificate = false diff --git a/src/messages.rs b/src/messages.rs index 58b785b..0e980fe 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -523,6 +523,29 @@ where } } +pub async fn write_all_flush(stream: &mut S, buf: &[u8]) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + match stream.write_all(buf).await { + Ok(_) => match stream.flush().await { + Ok(_) => Ok(()), + Err(err) => { + return Err(Error::SocketError(format!( + "Error flushing socket - Error: {:?}", + err + ))) + } + }, + Err(err) => { + return Err(Error::SocketError(format!( + "Error writing to socket - Error: {:?}", + err + ))) + } + } +} + /// Read a complete message from the socket. pub async fn read_message(stream: &mut S) -> Result where diff --git a/src/server.rs b/src/server.rs index be74acc..4c8d90b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use std::time::SystemTime; -use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::TcpStream; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; use tokio_rustls::{client::TlsStream, TlsConnector}; @@ -22,34 +22,23 @@ use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; use crate::scram::ScramSha256; use crate::stats::ServerStats; +use std::io::Write; use pin_project::pin_project; -#[pin_project(project = ReadInnerProj)] -pub enum ReadInner { +#[pin_project(project = SteamInnerProj)] +pub enum StreamInner { Plain { #[pin] - stream: ReadHalf, + stream: TcpStream, }, Tls { #[pin] - stream: ReadHalf>, + stream: TlsStream, }, } -#[pin_project(project = WriteInnerProj)] -pub enum WriteInner { - Plain { - #[pin] - stream: WriteHalf, - }, - Tls { - #[pin] - stream: WriteHalf>, - }, -} - -impl AsyncWrite for WriteInner { +impl AsyncWrite for StreamInner { fn poll_write( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -57,8 +46,8 @@ impl AsyncWrite for WriteInner { ) -> std::task::Poll> { let this = self.project(); match this { - WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf), - WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf), + SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf), + SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf), } } @@ -68,8 +57,8 @@ impl AsyncWrite for WriteInner { ) -> std::task::Poll> { let this = self.project(); match this { - WriteInnerProj::Tls { stream } => stream.poll_flush(cx), - WriteInnerProj::Plain { stream } => stream.poll_flush(cx), + SteamInnerProj::Tls { stream } => stream.poll_flush(cx), + SteamInnerProj::Plain { stream } => stream.poll_flush(cx), } } @@ -79,13 +68,13 @@ impl AsyncWrite for WriteInner { ) -> std::task::Poll> { let this = self.project(); match this { - WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx), - WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx), + SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx), + SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx), } } } -impl AsyncRead for ReadInner { +impl AsyncRead for StreamInner { fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -93,8 +82,21 @@ impl AsyncRead for ReadInner { ) -> std::task::Poll> { let this = self.project(); match this { - ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf), - ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf), + SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf), + SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf), + } + } +} + +impl StreamInner { + pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + StreamInner::Tls { stream } => { + let r = stream.get_mut(); + let mut w = r.1.writer(); + w.write(buf) + } + StreamInner::Plain { stream } => stream.try_write(buf), } } } @@ -105,11 +107,8 @@ pub struct Server { /// port, e.g. 5432, and role, e.g. primary or replica. address: Address, - /// Buffered read socket. - read: BufReader, - - /// Unbuffered write socket (our client code buffers). - write: WriteInner, + /// Server TCP connection. + stream: BufStream, /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, @@ -177,7 +176,7 @@ impl Server { // TCP timeouts. configure_socket(&stream); - let (mut read, mut write) = if get_config().general.server_tls { + let mut stream = if get_config().general.server_tls { // Request a TLS connection ssl_request(&mut stream).await?; @@ -232,21 +231,11 @@ impl Server { } }; - let (read, write) = split(stream); - ( - ReadInner::Tls { stream: read }, - WriteInner::Tls { stream: write }, - ) + StreamInner::Tls { stream } } // Server does not support TLS - 'N' => { - let (read, write) = split(stream); - ( - ReadInner::Plain { stream: read }, - WriteInner::Plain { stream: write }, - ) - } + 'N' => StreamInner::Plain { stream }, // Something else? m => { @@ -257,11 +246,7 @@ impl Server { } } } else { - let (read, write) = split(stream); - ( - ReadInner::Plain { stream: read }, - WriteInner::Plain { stream: write }, - ) + StreamInner::Plain { stream } }; // let (read, write) = split(stream); @@ -283,7 +268,7 @@ impl Server { }, }; - startup(&mut write, username, database).await?; + startup(&mut stream, username, database).await?; let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; @@ -298,7 +283,7 @@ impl Server { }; loop { - let code = match read.read_u8().await { + let code = match stream.read_u8().await { Ok(code) => code as char, Err(_) => { return Err(Error::ServerStartupError( @@ -308,7 +293,7 @@ impl Server { } }; - let len = match read.read_i32().await { + let len = match stream.read_i32().await { Ok(len) => len, Err(_) => { return Err(Error::ServerStartupError( @@ -324,7 +309,7 @@ impl Server { // Authentication 'R' => { // Determine which kind of authentication is required, if any. - let auth_code = match read.read_i32().await { + let auth_code = match stream.read_i32().await { Ok(auth_code) => auth_code, Err(_) => { return Err(Error::ServerStartupError( @@ -342,7 +327,7 @@ impl Server { // See: https://www.postgresql.org/docs/12/protocol-message-formats.html let mut salt = vec![0u8; 4]; - match read.read_exact(&mut salt).await { + match stream.read_exact(&mut salt).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -355,7 +340,7 @@ impl Server { match password { // Using plaintext password Some(password) => { - md5_password(&mut write, username, password, &salt[..]).await? + md5_password(&mut stream, username, password, &salt[..]).await? } // Using auth passthrough, in this case we should already have a @@ -366,7 +351,7 @@ impl Server { match option_hash { Some(hash) => md5_password_with_hash( - &mut write, + &mut stream, &hash, &salt[..], ) @@ -400,7 +385,7 @@ impl Server { let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; - match read.read_exact(&mut sasl_auth).await { + match stream.read_exact(&mut sasl_auth).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -435,7 +420,7 @@ impl Server { res.put_i32(sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut write, res).await?; + write_all_flush(&mut stream, &res).await?; } else { error!("Unsupported SCRAM version: {}", sasl_type); return Err(Error::ServerError); @@ -447,7 +432,7 @@ impl Server { let mut sasl_data = vec![0u8; (len - 8) as usize]; - match read.read_exact(&mut sasl_data).await { + match stream.read_exact(&mut sasl_data).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -466,14 +451,14 @@ impl Server { res.put_i32(4 + sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut write, res).await?; + write_all_flush(&mut stream, &res).await?; } SASL_FINAL => { trace!("Final SASL"); let mut sasl_final = vec![0u8; len as usize - 8]; - match read.read_exact(&mut sasl_final).await { + match stream.read_exact(&mut sasl_final).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -508,7 +493,7 @@ impl Server { // ErrorResponse 'E' => { - let error_code = match read.read_u8().await { + let error_code = match stream.read_u8().await { Ok(error_code) => error_code, Err(_) => { return Err(Error::ServerStartupError( @@ -529,7 +514,7 @@ impl Server { // Read the error message without the terminating null character. let mut error = vec![0u8; len as usize - 4 - 1]; - match read.read_exact(&mut error).await { + match stream.read_exact(&mut error).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -553,7 +538,7 @@ impl Server { 'S' => { let mut param = vec![0u8; len as usize - 4]; - match read.read_exact(&mut param).await { + match stream.read_exact(&mut param).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -575,7 +560,7 @@ impl Server { 'K' => { // The frontend must save these values if it wishes to be able to issue CancelRequest messages later. // See: . - process_id = match read.read_i32().await { + process_id = match stream.read_i32().await { Ok(id) => id, Err(_) => { return Err(Error::ServerStartupError( @@ -585,7 +570,7 @@ impl Server { } }; - secret_key = match read.read_i32().await { + secret_key = match stream.read_i32().await { Ok(id) => id, Err(_) => { return Err(Error::ServerStartupError( @@ -600,7 +585,7 @@ impl Server { 'Z' => { let mut idle = vec![0u8; len as usize - 4]; - match read.read_exact(&mut idle).await { + match stream.read_exact(&mut idle).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -612,8 +597,7 @@ impl Server { let mut server = Server { address: address.clone(), - read: BufReader::new(read), - write, + stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), server_info, process_id, @@ -680,7 +664,7 @@ impl Server { bytes.put_i32(process_id); bytes.put_i32(secret_key); - write_all(&mut stream, bytes).await + write_all_flush(&mut stream, &bytes).await } /// Send messages to the server from the client. @@ -688,7 +672,7 @@ impl Server { self.mirror_send(messages); self.stats().data_sent(messages.len()); - match write_all_half(&mut self.write, messages).await { + match write_all_flush(&mut self.stream, &messages).await { Ok(_) => { // Successfully sent to server self.last_activity = SystemTime::now(); @@ -707,7 +691,7 @@ impl Server { /// in order to receive all data the server has to offer. pub async fn recv(&mut self) -> Result { loop { - let mut message = match read_message(&mut self.read).await { + let mut message = match read_message(&mut self.stream).await { Ok(message) => message, Err(err) => { error!("Terminating server because of: {:?}", err); @@ -1100,14 +1084,14 @@ impl Drop for Server { // Update statistics self.stats.disconnect(); - // let mut bytes = BytesMut::with_capacity(4); - // bytes.put_u8(b'X'); - // bytes.put_i32(4); + let mut bytes = BytesMut::with_capacity(4); + bytes.put_u8(b'X'); + bytes.put_i32(4); - // match self.write.try_write(&bytes) { - // Ok(_) => (), - // Err(_) => debug!("Dirty shutdown"), - // }; + match self.stream.get_mut().try_write(&bytes) { + Ok(_) => (), + Err(_) => debug!("Dirty shutdown"), + }; // Should not matter. self.bad = true;