From 9e51b8110fcadb53f24af342dff403061bab8562 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Fri, 28 Apr 2023 11:20:49 -0700 Subject: [PATCH] Server TLS --- Cargo.lock | 21 ++++++++ Cargo.toml | 1 + src/client.rs | 13 ++++- src/messages.rs | 15 ++++++ src/server.rs | 130 +++++++++++++++++++++++++++++++++++++++++++----- 5 files changed, 166 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf961a6..610a13d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -762,6 +762,7 @@ dependencies = [ "once_cell", "parking_lot", "phf", + "pin-project", "postgres-protocol", "rand", "regex", @@ -820,6 +821,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "pin-project-lite" version = "0.2.9" diff --git a/Cargo.toml b/Cargo.toml index a557351..c1f8f34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ nix = "0.26.2" atomic_enum = "0.2.0" postgres-protocol = "0.6.5" fallible-iterator = "0.2" +pin-project = "*" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/src/client.rs b/src/client.rs index 5098ec6..efde755 100644 --- a/src/client.rs +++ b/src/client.rs @@ -539,6 +539,7 @@ where Some(md5_hash_password(username, password, &salt)) } else { if !get_config().is_auth_query_configured() { + wrong_password(&mut write, username).await?; return Err(Error::ClientAuthImpossible(username.into())); } @@ -565,6 +566,8 @@ where } Err(err) => { + wrong_password(&mut write, username).await?; + return Err(Error::ClientAuthPassthroughError( err.to_string(), client_identifier, @@ -587,7 +590,15 @@ where client_identifier ); - let fetched_hash = refetch_auth_hash(&pool).await?; + let fetched_hash = match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => fetched_hash, + Err(err) => { + wrong_password(&mut write, username).await?; + + return Err(err); + } + }; + let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); // Ok password changed in server an auth is possible. diff --git a/src/messages.rs b/src/messages.rs index ba4818c..58d9b26 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -150,6 +150,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } } +pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> { + let mut bytes = BytesMut::with_capacity(12); + + bytes.put_i32(8); + bytes.put_i32(80877103); + + match stream.write_all(&bytes).await { + Ok(_) => Ok(()), + Err(err) => Err(Error::SocketError(format!( + "Error writing SSLRequest to server socket - Error: {:?}", + err + ))), + } +} + /// Parse the params the server sends as a key/value format. pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { let mut result = HashMap::new(); diff --git a/src/server.rs b/src/server.rs index 84bed6c..722108b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,11 +9,13 @@ use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use std::time::SystemTime; -use tokio::io::{AsyncReadExt, BufReader}; +use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }; +use tokio_rustls::rustls::ClientConfig; +use tokio_rustls::{TlsConnector, TlsStream}; use crate::config::{Address, User}; use crate::constants::*; @@ -24,6 +26,82 @@ use crate::pool::ClientServerMap; use crate::scram::ScramSha256; use crate::stats::ServerStats; +use pin_project::pin_project; + +#[pin_project(project = ReadInnerProj)] +pub enum ReadInner { + Plain { + #[pin] + stream: ReadHalf, + }, + Tls { + #[pin] + stream: ReadHalf>, + }, +} + +#[pin_project(project = WriteInnerProj)] +pub enum WriteInner { + Plain { + #[pin] + stream: WriteHalf, + }, + Tls { + #[pin] + stream: WriteHalf>, + }, +} + +impl AsyncWrite for WriteInner { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let this = self.project(); + match this { + WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf), + WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + WriteInnerProj::Tls { stream } => stream.poll_flush(cx), + WriteInnerProj::Plain { stream } => stream.poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx), + WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx), + } + } +} + +impl AsyncRead for ReadInner { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + let this = self.project(); + match this { + ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf), + ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf), + } + } +} + /// Server state. pub struct Server { /// Server host, e.g. localhost, @@ -31,10 +109,10 @@ pub struct Server { address: Address, /// Buffered read socket. - read: BufReader, + read: BufReader, /// Unbuffered write socket (our client code buffers). - write: OwnedWriteHalf, + write: WriteInner, /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, @@ -100,6 +178,32 @@ impl Server { }; configure_socket(&stream); + // ssl_request(&mut stream).await?; + // let response = match stream.read_u8().await { + // Ok(response) => response as char, + // Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))), + // }; + + // match response { + // 'S' => { + // let connector = TlsConnector::from(ClientConfig::builder() + // .with_safe_default_cipher_suites() + // .with_safe_default_kx_groups() + // .with_safe_default_protocol_versions() + // .unwrap() + // .with_no_client_auth()); + // connector.connect("test".into(), stream).await.unwrap(); + // }, + + // 'N' => { + + // }, + + // _ => { + // return Err(Error::SocketError("error".into())); + // } + // }; + trace!("Sending StartupMessage"); // StartupMessage @@ -443,12 +547,12 @@ impl Server { } }; - let (read, write) = stream.into_split(); + let (read, write) = split(stream); let mut server = Server { address: address.clone(), - read: BufReader::new(read), - write, + read: BufReader::new(ReadInner::Plain { stream: read }), + write: WriteInner::Plain { stream: write }, buffer: BytesMut::with_capacity(8196), server_info, process_id, @@ -935,14 +1039,14 @@ impl Drop for Server { // Update statistics self.stats.disconnect(); - let mut bytes = BytesMut::with_capacity(4); - bytes.put_u8(b'X'); - bytes.put_i32(4); + // let mut bytes = BytesMut::with_capacity(4); + // bytes.put_u8(b'X'); + // bytes.put_i32(4); - match self.write.try_write(&bytes) { - Ok(_) => (), - Err(_) => debug!("Dirty shutdown"), - }; + // match self.write.try_write(&bytes) { + // Ok(_) => (), + // Err(_) => debug!("Dirty shutdown"), + // }; // Should not matter. self.bad = true;