diff --git a/src/admin.rs b/src/admin.rs index b7a5b6f..622c14a 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -2,7 +2,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{info, trace}; use std::collections::HashMap; -use tokio::net::tcp::OwnedWriteHalf; +// use tokio::net::tcp::T; use crate::config::{get_config, reload_config}; use crate::errors::Error; @@ -12,12 +12,15 @@ use crate::stats::get_stats; use crate::ClientServerMap; /// Handle admin client. -pub async fn handle_admin( - stream: &mut OwnedWriteHalf, +pub async fn handle_admin( + stream: &mut T, mut query: BytesMut, pool: ConnectionPool, client_server_map: ClientServerMap, -) -> Result<(), Error> { +) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let code = query.get_u8() as char; if code != 'Q' { @@ -61,7 +64,10 @@ pub async fn handle_admin( } /// Column-oriented statistics. -async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { +async fn show_lists(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let stats = get_stats(); let columns = vec![("list", DataType::Text), ("items", DataType::Int4)]; @@ -128,7 +134,10 @@ async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul } /// Show PgCat version. -async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn show_version(stream: &mut T) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut res = BytesMut::new(); res.put(row_description(&vec![("version", DataType::Text)])); @@ -143,7 +152,10 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> { } /// Show utilization of connection pools for each shard and replicas. -async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { +async fn show_pools(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let stats = get_stats(); let config = get_config(); @@ -197,7 +209,10 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul } /// Show shards and replicas. -async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { +async fn show_databases(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let config = get_config(); // Columns @@ -258,15 +273,18 @@ async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> R /// Ignore any SET commands the client sends. /// This is common initialization done by ORMs. -async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn ignore_set(stream: &mut T) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ custom_protocol_response_ok(stream, "SET").await } /// Reload the configuration file without restarting the process. -async fn reload( - stream: &mut OwnedWriteHalf, - client_server_map: ClientServerMap, -) -> Result<(), Error> { +async fn reload(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ info!("Reloading config"); reload_config(client_server_map).await?; @@ -286,7 +304,10 @@ async fn reload( } /// Shows current configuration. -async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn show_config(stream: &mut T) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let config = &get_config(); let config: HashMap = config.into(); @@ -329,7 +350,10 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { } /// Show shard and replicas statistics. -async fn show_stats(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { +async fn show_stats(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> +where + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ let columns = vec![ ("database", DataType::Text), ("total_xact_count", DataType::Numeric), diff --git a/src/client.rs b/src/client.rs index c02fcf8..4e8ba73 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, trace}; use std::collections::HashMap; -use tokio::io::{AsyncReadExt, BufReader}; +use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, @@ -17,15 +17,25 @@ use crate::pool::{get_pool, ClientServerMap}; use crate::query_router::{Command, QueryRouter}; use crate::server::Server; use crate::stats::{get_reporter, Reporter}; +use crate::stream::Tls; + +use tokio_rustls::server::TlsStream; + +/// Type of connection received from client. +enum ClientConnectionType { + Startup, + Tls, + CancelQuery, +} /// The client state. One of these is created per client. -pub struct Client { +pub struct Client { /// The reads are buffered (8K by default). - read: BufReader, + read: BufReader, /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: S, + write: T, /// Internal buffer, where we place messages until we have to flush /// them to the backend. @@ -63,163 +73,347 @@ pub struct Client { last_server_id: Option, } -impl Client { - /// Perform client startup sequence. - /// See docs: - pub async fn startup( - mut stream: TcpStream, +/// Main client loop. +pub async fn client_loop( + mut stream: TcpStream, + client_server_map: ClientServerMap, +) -> Result<(), Error> { + match get_startup::(&mut stream).await { + Ok((ClientConnectionType::Tls, bytes)) => { + match startup_tls(stream, client_server_map).await { + Ok(mut client) => client.handle().await, + Err(err) => Err(err), + } + } + + Ok((ClientConnectionType::Startup, bytes)) => { + let (read, write) = split(stream); + match Client::handle_startup(read, write, bytes, client_server_map).await { + Ok(mut client) => client.handle().await, + Err(err) => Err(err), + } + } + + Ok((ClientConnectionType::CancelQuery, bytes)) => { + return Err(Error::ProtocolSyncError); + } + + Err(err) => Err(err), + } +} + +async fn get_startup(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error> +where + S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite, +{ + // Get startup message length. + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::ClientBadStartup), + }; + + // Get the rest of the message. + let mut startup = vec![0u8; len as usize - 4]; + match stream.read_exact(&mut startup).await { + Ok(_) => (), + Err(_) => return Err(Error::ClientBadStartup), + }; + + let mut bytes = BytesMut::from(&startup[..]); + let code = bytes.get_i32(); + + match code { + // Client is requesting SSL (TLS). + SSL_REQUEST_CODE => Ok((ClientConnectionType::Tls, bytes)), + + // Client wants to use plain text, requesting regular startup. + PROTOCOL_VERSION_NUMBER => Ok((ClientConnectionType::Startup, bytes)), + + // Client is requesting to cancel a running query (plain text connection). + CANCEL_REQUEST_CODE => Ok((ClientConnectionType::CancelQuery, bytes)), + _ => Err(Error::ProtocolSyncError), + } +} + +/// Handle TLS connection negotation. +pub async fn startup_tls( + mut stream: TcpStream, + client_server_map: ClientServerMap, +) -> Result>, WriteHalf>>, Error> { + // Accept SSL request if SSL is configured. + let mut yes = BytesMut::new(); + yes.put_u8(b'S'); + write_all(&mut stream, yes).await?; + + // Negotiate TLS. + let mut tls = Tls::new().unwrap(); + let mut stream = match tls.acceptor.accept(stream).await { + Ok(stream) => stream, + Err(_) => return Err(Error::TlsError), + }; + + match get_startup::>(&mut stream).await { + Ok((ClientConnectionType::Startup, bytes)) => { + let (read, write) = split(stream); + Client::handle_startup(read, write, bytes, client_server_map).await + } + _ => Err(Error::ProtocolSyncError), + } +} + +impl Client +where + S: tokio::io::AsyncRead + std::marker::Unpin, + T: tokio::io::AsyncWrite + std::marker::Unpin, +{ + // Perform client startup sequence in TLS. + pub async fn handle_startup( + mut read: S, + mut write: T, + bytes: BytesMut, // The rest of the startup message. client_server_map: ClientServerMap, - ) -> Result, Error> { + ) -> Result, Error> { let config = get_config(); let transaction_mode = config.general.pool_mode == "transaction"; let stats = get_reporter(); - loop { - trace!("Waiting for StartupMessage"); + trace!("Got StartupMessage"); + let parameters = parse_startup(bytes.clone())?; - // Could be StartupMessage, SSLRequest or CancelRequest. - let len = match stream.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::ClientBadStartup), - }; + // Generate random backend ID and secret key + let process_id: i32 = rand::random(); + let secret_key: i32 = rand::random(); - let mut startup = vec![0u8; len as usize - 4]; + // Perform MD5 authentication. + // TODO: Add SASL support. + let salt = md5_challenge(&mut write).await?; - match stream.read_exact(&mut startup).await { - Ok(_) => (), - Err(_) => return Err(Error::ClientBadStartup), - }; + let code = match read.read_u8().await { + Ok(p) => p, + Err(_) => return Err(Error::SocketError), + }; - let mut bytes = BytesMut::from(&startup[..]); - let code = bytes.get_i32(); - - match code { - // Client wants SSL. We don't support it at the moment. - SSL_REQUEST_CODE => { - trace!("Rejecting SSLRequest"); - - let mut no = BytesMut::with_capacity(1); - no.put_u8(b'N'); - - write_all(&mut stream, no).await?; - } - - // Regular startup message. - PROTOCOL_VERSION_NUMBER => { - trace!("Got StartupMessage"); - let parameters = parse_startup(bytes.clone())?; - - // Generate random backend ID and secret key - let process_id: i32 = rand::random(); - let secret_key: i32 = rand::random(); - - // Perform MD5 authentication. - // TODO: Add SASL support. - let salt = md5_challenge(&mut stream).await?; - - let code = match stream.read_u8().await { - Ok(p) => p, - Err(_) => return Err(Error::SocketError), - }; - - // PasswordMessage - if code as char != 'p' { - debug!("Expected p, got {}", code as char); - return Err(Error::ProtocolSyncError); - } - - let len = match stream.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::SocketError), - }; - - let mut password_response = vec![0u8; (len - 4) as usize]; - - match stream.read_exact(&mut password_response).await { - Ok(_) => (), - Err(_) => return Err(Error::SocketError), - }; - - // Compare server and client hashes. - let password_hash = - md5_hash_password(&config.user.name, &config.user.password, &salt); - - if password_hash != password_response { - debug!("Password authentication failed"); - wrong_password(&mut stream, &config.user.name).await?; - return Err(Error::ClientError); - } - - debug!("Password authentication successful"); - - auth_ok(&mut stream).await?; - write_all(&mut stream, get_pool().server_info()).await?; - backend_key_data(&mut stream, process_id, secret_key).await?; - ready_for_query(&mut stream).await?; - - trace!("Startup OK"); - - let database = parameters - .get("database") - .unwrap_or(parameters.get("user").unwrap()); - let admin = ["pgcat", "pgbouncer"] - .iter() - .filter(|db| *db == &database) - .count() - == 1; - - // Split the read and write streams - // so we can control buffering. - let (read, write) = stream.into_split(); - - return Ok(Client { - read: BufReader::new(read), - write: write, - buffer: BytesMut::with_capacity(8196), - cancel_mode: false, - transaction_mode: transaction_mode, - process_id: process_id, - secret_key: secret_key, - client_server_map: client_server_map, - parameters: parameters, - stats: stats, - admin: admin, - last_address_id: None, - last_server_id: None, - }); - } - - // Query cancel request. - CANCEL_REQUEST_CODE => { - let (read, write) = stream.into_split(); - - let process_id = bytes.get_i32(); - let secret_key = bytes.get_i32(); - - return Ok(Client { - read: BufReader::new(read), - write: write, - buffer: BytesMut::with_capacity(8196), - cancel_mode: true, - transaction_mode: transaction_mode, - process_id: process_id, - secret_key: secret_key, - client_server_map: client_server_map, - parameters: HashMap::new(), - stats: stats, - admin: false, - last_address_id: None, - last_server_id: None, - }); - } - - _ => { - return Err(Error::ProtocolSyncError); - } - }; + // PasswordMessage + if code as char != 'p' { + debug!("Expected p, got {}", code as char); + return Err(Error::ProtocolSyncError); } + + let len = match read.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::SocketError), + }; + + let mut password_response = vec![0u8; (len - 4) as usize]; + + match read.read_exact(&mut password_response).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + // Compare server and client hashes. + let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt); + + if password_hash != password_response { + debug!("Password authentication failed"); + wrong_password(&mut write, &config.user.name).await?; + return Err(Error::ClientError); + } + + debug!("Password authentication successful"); + + auth_ok(&mut write).await?; + write_all(&mut write, get_pool().server_info()).await?; + backend_key_data(&mut write, process_id, secret_key).await?; + ready_for_query(&mut write).await?; + + trace!("Startup OK"); + + let database = parameters + .get("database") + .unwrap_or(parameters.get("user").unwrap()); + let admin = ["pgcat", "pgbouncer"] + .iter() + .filter(|db| *db == &database) + .count() + == 1; + + // Split the read and write streams + // so we can control buffering. + + return Ok(Client { + read: BufReader::new(read), + write: write, + buffer: BytesMut::with_capacity(8196), + cancel_mode: false, + transaction_mode: transaction_mode, + process_id: process_id, + secret_key: secret_key, + client_server_map: client_server_map, + parameters: parameters, + stats: stats, + admin: admin, + last_address_id: None, + last_server_id: None, + }); } + /// Perform client startup sequence. + /// See docs: + // pub async fn startup( + // mut stream: TcpStream, + // client_server_map: ClientServerMap, + // ) -> Result, WriteHalf>, Error> { + // let config = get_config(); + // let transaction_mode = config.general.pool_mode == "transaction"; + // let stats = get_reporter(); + + // loop { + // trace!("Waiting for StartupMessage"); + + // // Could be StartupMessage, SSLRequest or CancelRequest. + // let len = match stream.read_i32().await { + // Ok(len) => len, + // Err(_) => return Err(Error::ClientBadStartup), + // }; + + // let mut startup = vec![0u8; len as usize - 4]; + + // match stream.read_exact(&mut startup).await { + // Ok(_) => (), + // Err(_) => return Err(Error::ClientBadStartup), + // }; + + // let mut bytes = BytesMut::from(&startup[..]); + // let code = bytes.get_i32(); + + // match code { + // // Client wants SSL. We don't support it at the moment. + // SSL_REQUEST_CODE => { + // trace!("Rejecting SSLRequest"); + + // let mut no = BytesMut::with_capacity(1); + // no.put_u8(b'N'); + + // write_all(&mut stream, no).await?; + // } + + // // Regular startup message. + // PROTOCOL_VERSION_NUMBER => { + // trace!("Got StartupMessage"); + // let parameters = parse_startup(bytes.clone())?; + + // // Generate random backend ID and secret key + // let process_id: i32 = rand::random(); + // let secret_key: i32 = rand::random(); + + // // Perform MD5 authentication. + // // TODO: Add SASL support. + // let salt = md5_challenge(&mut stream).await?; + + // let code = match stream.read_u8().await { + // Ok(p) => p, + // Err(_) => return Err(Error::SocketError), + // }; + + // // PasswordMessage + // if code as char != 'p' { + // debug!("Expected p, got {}", code as char); + // return Err(Error::ProtocolSyncError); + // } + + // let len = match stream.read_i32().await { + // Ok(len) => len, + // Err(_) => return Err(Error::SocketError), + // }; + + // let mut password_response = vec![0u8; (len - 4) as usize]; + + // match stream.read_exact(&mut password_response).await { + // Ok(_) => (), + // Err(_) => return Err(Error::SocketError), + // }; + + // // Compare server and client hashes. + // let password_hash = + // md5_hash_password(&config.user.name, &config.user.password, &salt); + + // if password_hash != password_response { + // debug!("Password authentication failed"); + // wrong_password(&mut stream, &config.user.name).await?; + // return Err(Error::ClientError); + // } + + // debug!("Password authentication successful"); + + // auth_ok(&mut stream).await?; + // write_all(&mut stream, get_pool().server_info()).await?; + // backend_key_data(&mut stream, process_id, secret_key).await?; + // ready_for_query(&mut stream).await?; + + // trace!("Startup OK"); + + // let database = parameters + // .get("database") + // .unwrap_or(parameters.get("user").unwrap()); + // let admin = ["pgcat", "pgbouncer"] + // .iter() + // .filter(|db| *db == &database) + // .count() + // == 1; + + // // Split the read and write streams + // // so we can control buffering. + // let (read, write) = split(stream); + + // return Ok(Client { + // read: BufReader::new(read), + // write: write, + // buffer: BytesMut::with_capacity(8196), + // cancel_mode: false, + // transaction_mode: transaction_mode, + // process_id: process_id, + // secret_key: secret_key, + // client_server_map: client_server_map, + // parameters: parameters, + // stats: stats, + // admin: admin, + // last_address_id: None, + // last_server_id: None, + // }); + // } + + // // Query cancel request. + // CANCEL_REQUEST_CODE => { + // let (read, write) = split(stream); + + // let process_id = bytes.get_i32(); + // let secret_key = bytes.get_i32(); + + // return Ok(Client { + // read: BufReader::new(read), + // write: write, + // buffer: BytesMut::with_capacity(8196), + // cancel_mode: true, + // transaction_mode: transaction_mode, + // process_id: process_id, + // secret_key: secret_key, + // client_server_map: client_server_map, + // parameters: HashMap::new(), + // stats: stats, + // admin: false, + // last_address_id: None, + // last_server_id: None, + // }); + // } + + // _ => { + // return Err(Error::ProtocolSyncError); + // } + // }; + // } + // } + /// Handle a connected and authenticated client. pub async fn handle(&mut self) -> Result<(), Error> { // The client wants to cancel a query it has issued previously. @@ -414,8 +608,8 @@ impl Drop for Client { +impl Drop for Client { fn drop(&mut self) { // Update statistics. if let Some(address_id) = self.last_address_id { diff --git a/src/main.rs b/src/main.rs index 05bfb23..439c01a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,13 +28,13 @@ extern crate log; extern crate md5; extern crate num_cpus; extern crate once_cell; +extern crate rustls_pemfile; extern crate serde; extern crate serde_derive; extern crate sqlparser; extern crate tokio; -extern crate toml; extern crate tokio_rustls; -extern crate rustls_pemfile; +extern crate toml; use log::{debug, error, info}; use parking_lot::Mutex; @@ -45,6 +45,11 @@ use tokio::{ sync::mpsc, }; +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, +}; + use std::collections::HashMap; use std::sync::Arc; @@ -62,6 +67,7 @@ mod sharding; mod stats; mod stream; +use crate::constants::*; use config::{get_config, reload_config}; use pool::{ClientServerMap, ConnectionPool}; use stats::{Collector, Reporter, REPORTER}; @@ -153,32 +159,45 @@ async fn main() { // Handle client. tokio::task::spawn(async move { let start = chrono::offset::Utc::now().naive_utc(); - match client::Client::startup(socket, client_server_map).await { - Ok(mut client) => { - info!("Client {:?} connected", addr); - - match client.handle().await { - Ok(()) => { - let duration = chrono::offset::Utc::now().naive_utc() - start; - - info!( - "Client {:?} disconnected, session duration: {}", - addr, - format_duration(&duration) - ); - } - - Err(err) => { - error!("Client disconnected with error: {:?}", err); - client.release(); - } - } - } + // match client::get_startup(&mut socket) { + // Ok((code, bytes)) => match code { + // SSL_REQUEST_CODE => client::Client::tls_startup< + // } + // } + match client::client_loop(socket, client_server_map).await { + Ok(_) => (), Err(err) => { debug!("Client failed to login: {:?}", err); } }; + + // match client::Client::::startup(socket, client_server_map).await { + // Ok(mut client) => { + // info!("Client {:?} connected", addr); + + // match client.handle().await { + // Ok(()) => { + // let duration = chrono::offset::Utc::now().naive_utc() - start; + + // info!( + // "Client {:?} disconnected, session duration: {}", + // addr, + // format_duration(&duration) + // ); + // } + + // Err(err) => { + // error!("Client disconnected with error: {:?}", err); + // client.release(); + // } + // } + // } + + // Err(err) => { + // debug!("Client failed to login: {:?}", err); + // } + // }; }); } }); diff --git a/src/messages.rs b/src/messages.rs index 960a3b6..7b04792 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -31,7 +31,9 @@ impl From<&DataType> for i32 { /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut S) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut auth_ok = BytesMut::with_capacity(9); auth_ok.put_u8(b'R'); @@ -43,7 +45,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { /// Generate md5 password challenge. pub async fn md5_challenge(stream: &mut S) -> Result<[u8; 4], Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ // let mut rng = rand::thread_rng(); let salt: [u8; 4] = [ rand::random(), @@ -69,7 +73,9 @@ pub async fn backend_key_data( backend_id: i32, secret_key: i32, ) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut key_data = BytesMut::from(&b"K"[..]); key_data.put_i32(12); key_data.put_i32(backend_id); @@ -91,7 +97,9 @@ pub fn simple_query(query: &str) -> BytesMut { /// Tell the client we're ready for another query. pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut bytes = BytesMut::with_capacity(5); bytes.put_u8(b'Z'); @@ -215,7 +223,9 @@ pub async fn md5_password( password: &str, salt: &[u8], ) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let password = md5_hash_password(user, password, salt); let mut message = BytesMut::with_capacity(password.len() as usize + 5); @@ -230,11 +240,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { /// Implements a response to our custom `SET SHARDING KEY` /// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. -pub async fn custom_protocol_response_ok( - stream: &mut S, - message: &str, -) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +pub async fn custom_protocol_response_ok(stream: &mut S, message: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&format!("{}\0", message)[..]); @@ -257,7 +266,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut error = BytesMut::new(); // Error level @@ -299,7 +310,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut error = BytesMut::new(); // Error level @@ -333,11 +346,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { } /// Respond to a SHOW SHARD command. -pub async fn show_response( - stream: &mut OwnedWriteHalf, - name: &str, - value: &str, -) -> Result<(), Error> { +pub async fn show_response(stream: &mut S, name: &str, value: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ // A SELECT response consists of: // 1. RowDescription // 2. One or more DataRow @@ -439,7 +451,9 @@ pub fn command_complete(command: &str) -> BytesMut { /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), @@ -448,7 +462,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> -where S: tokio::io::AsyncWrite + std::marker::Unpin { +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), @@ -456,7 +472,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin { } /// Read a complete message from the socket. -pub async fn read_message(stream: &mut BufReader) -> Result { +pub async fn read_message(stream: &mut S) -> Result +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ let code = match stream.read_u8().await { Ok(code) => code, Err(_) => return Err(Error::SocketError), diff --git a/src/stream.rs b/src/stream.rs index 41d3cd1..9c19b89 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,17 +1,17 @@ // Stream wrapper. use bytes::{Buf, BufMut, BytesMut}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, split, ReadHalf, WriteHalf}; +use rustls_pemfile::{certs, rsa_private_keys}; +use std::path::Path; +use std::sync::Arc; +use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }; -use tokio_rustls::server::TlsStream; -use rustls_pemfile::{certs, rsa_private_keys}; use tokio_rustls::rustls::{self, Certificate, PrivateKey}; +use tokio_rustls::server::TlsStream; use tokio_rustls::TlsAcceptor; -use std::sync::Arc; -use std::path::Path; use crate::config::get_config; use crate::errors::Error; @@ -29,132 +29,91 @@ fn load_keys(path: &std::path::Path) -> std::io::Result> { .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) } -struct Tls { - acceptor: TlsAcceptor, +pub struct Tls { + pub acceptor: TlsAcceptor, } impl Tls { - pub fn new() -> Result { - let config = get_config(); + pub fn new() -> Result { + let config = get_config(); - let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) { - Ok(certs) => certs, - Err(_) => return Err(Error::TlsError), - }; + let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) { + Ok(certs) => certs, + Err(_) => return Err(Error::TlsError), + }; - let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) { - Ok(keys) => keys, - Err(_) => return Err(Error::TlsError), - }; + let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) { + Ok(keys) => keys, + Err(_) => return Err(Error::TlsError), + }; - let config = match rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(certs, keys.remove(0)) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) { - Ok(c) => c, - Err(_) => return Err(Error::TlsError) + let config = match rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) + { + Ok(c) => c, + Err(_) => return Err(Error::TlsError), }; Ok(Tls { - acceptor: TlsAcceptor::from(Arc::new(config)), + acceptor: TlsAcceptor::from(Arc::new(config)), }) - } + } } struct Stream { - read: Option>, - write: Option, - tls_read: Option>>>, - tls_write: Option>>, + read: Option>, + write: Option, + tls_read: Option>>>, + tls_write: Option>>, } - impl Stream { - pub async fn new(stream: TcpStream, tls: Option) -> Result { + pub async fn new(stream: TcpStream, tls: Option) -> Result { + let config = get_config(); - let config = get_config(); + match tls { + None => { + let (read, write) = stream.into_split(); + let read = BufReader::new(read); + Ok(Self { + read: Some(read), + write: Some(write), + tls_read: None, + tls_write: None, + }) + } - match tls { - None => { - let (read, write) = stream.into_split(); - let read = BufReader::new(read); - Ok( - Self { - read: Some(read), - write: Some(write), - tls_read: None, - tls_write: None, - } - ) - } + Some(tls) => { + let mut tls_stream = match tls.acceptor.accept(stream).await { + Ok(stream) => stream, + Err(_) => return Err(Error::TlsError), + }; - Some(tls) => { - let mut tls_stream = match tls.acceptor.accept(stream).await { - Ok(stream) => stream, - Err(_) => return Err(Error::TlsError), - }; + let (read, write) = split(tls_stream); - let (read, write) = split(tls_stream); + Ok(Self { + read: None, + write: None, + tls_read: Some(BufReader::new(read)), + tls_write: Some(write), + }) + } + } + } +} - Ok(Self{ - read: None, - write: None, - tls_read: Some(BufReader::new(read)), - tls_write: Some(write), - }) - } - } - } - - async fn read(stream: &mut S) -> Result - where S: tokio::io::AsyncRead + std::marker::Unpin { - - let code = match stream.read_u8().await { - Ok(code) => code, - Err(_) => return Err(Error::SocketError), - }; - - let len = match stream.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::SocketError), - }; - - let mut buf = vec![0u8; len as usize - 4]; - - match stream.read_exact(&mut buf).await { - Ok(_) => (), - Err(_) => return Err(Error::SocketError), - }; - - let mut bytes = BytesMut::with_capacity(len as usize + 1); - - bytes.put_u8(code); - bytes.put_i32(len); - bytes.put_slice(&buf); - - Ok(bytes) - } - - async fn write(stream: &mut S, buf: &BytesMut) -> Result<(), Error> - where S: tokio::io::AsyncWrite + std::marker::Unpin { - match stream.write_all(buf).await { - Ok(_) => Ok(()), - Err(_) => return Err(Error::SocketError), - } - } - - pub async fn read_message(&mut self) -> Result { - match &self.read { - Some(read) => Self::read(self.read.as_mut().unwrap()).await, - None => Self::read(self.tls_read.as_mut().unwrap()).await, - } - } - - pub async fn write_all(&mut self, buf: &BytesMut) -> Result<(), Error> { - match &self.write { - Some(write) => Self::write(self.write.as_mut().unwrap(), buf).await, - None => Self::write(self.tls_write.as_mut().unwrap(), buf).await, - } - } -} \ No newline at end of file +// impl tokio::io::AsyncRead for Stream { +// fn poll_read( +// mut self: core::pin::Pin<&mut Self>, +// cx: &mut core::task::Context<'_>, +// buf: &mut tokio::io::ReadBuf<'_> +// ) -> core::task::Poll> { +// match &mut self.get_mut().tls_read { +// None => core::pin::Pin::new(self.read.as_mut().unwrap()).poll_read(cx, buf), +// Some(mut tls) => core::pin::Pin::new(&mut tls).poll_read(cx, buf), +// } +// } +// }