/// Helper functions to send one-off protocol messages /// and handle TcpStream (TCP socket). use bytes::{Buf, BufMut, BytesMut}; use log::error; use md5::{Digest, Md5}; use socket2::{SockRef, TcpKeepalive}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use crate::client::PREPARED_STATEMENT_COUNTER; use crate::config::get_config; use crate::errors::Error; use std::collections::HashMap; use std::ffi::CString; use std::io::{BufRead, Cursor}; use std::mem; use std::sync::atomic::Ordering; use std::time::Duration; /// Postgres data type mappings /// used in RowDescription ('T') message. pub enum DataType { Text, Int4, Numeric, Bool, Oid, AnyArray, Any, } impl From<&DataType> for i32 { fn from(data_type: &DataType) -> i32 { match data_type { DataType::Text => 25, DataType::Int4 => 23, DataType::Numeric => 1700, DataType::Bool => 16, DataType::Oid => 26, DataType::AnyArray => 2277, DataType::Any => 2276, } } } /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut S) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut auth_ok = BytesMut::with_capacity(9); auth_ok.put_u8(b'R'); auth_ok.put_i32(8); auth_ok.put_i32(0); write_all(stream, auth_ok).await } /// Generate md5 password challenge. pub async fn md5_challenge(stream: &mut S) -> Result<[u8; 4], Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { // let mut rng = rand::thread_rng(); let salt: [u8; 4] = [ rand::random(), rand::random(), rand::random(), rand::random(), ]; let mut res = BytesMut::new(); res.put_u8(b'R'); res.put_i32(12); res.put_i32(5); // MD5 res.put_slice(&salt[..]); write_all(stream, res).await?; Ok(salt) } /// Give the client the process_id and secret we generated /// used in query cancellation. pub async fn backend_key_data( stream: &mut S, backend_id: i32, secret_key: i32, ) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut key_data = BytesMut::from(&b"K"[..]); key_data.put_i32(12); key_data.put_i32(backend_id); key_data.put_i32(secret_key); write_all(stream, key_data).await } /// Construct a `Q`: Query message. pub fn simple_query(query: &str) -> BytesMut { let mut res = BytesMut::from(&b"Q"[..]); let query = format!("{}\0", query); res.put_i32(query.len() as i32 + 4); res.put_slice(query.as_bytes()); res } /// Tell the client we're ready for another query. pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut bytes = BytesMut::with_capacity( mem::size_of::() + mem::size_of::() + mem::size_of::(), ); bytes.put_u8(b'Z'); bytes.put_i32(5); bytes.put_u8(b'I'); // Idle write_all(stream, bytes).await } /// Send the startup packet the server. We're pretending we're a Pg client. /// This tells the server which user we are and what database we want. pub async fn startup(stream: &mut S, user: &str, database: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut bytes = BytesMut::with_capacity(25); bytes.put_i32(196608); // Protocol number // User bytes.put(&b"user\0"[..]); bytes.put_slice(user.as_bytes()); bytes.put_u8(0); // Database bytes.put(&b"database\0"[..]); bytes.put_slice(database.as_bytes()); bytes.put_u8(0); bytes.put_u8(0); // Null terminator let len = bytes.len() as i32 + 4i32; let mut startup = BytesMut::with_capacity(len as usize); startup.put_i32(len); startup.put(bytes); match stream.write_all(&startup).await { Ok(_) => Ok(()), Err(err) => { return Err(Error::SocketError(format!( "Error writing startup to server socket - Error: {:?}", err ))) } } } pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> { let mut bytes = BytesMut::with_capacity(12); bytes.put_i32(8); bytes.put_i32(80877103); match stream.write_all(&bytes).await { Ok(_) => Ok(()), Err(err) => Err(Error::SocketError(format!( "Error writing SSLRequest to server socket - Error: {:?}", err ))), } } /// Parse the params the server sends as a key/value format. pub fn parse_params(mut bytes: BytesMut) -> Result, Error> { let mut result = HashMap::new(); let mut buf = Vec::new(); let mut tmp = String::new(); while bytes.has_remaining() { let mut c = bytes.get_u8(); // Null-terminated C-strings. while c != 0 { tmp.push(c as char); c = bytes.get_u8(); } if !tmp.is_empty() { buf.push(tmp.clone()); tmp.clear(); } } // Expect pairs of name and value // and at least one pair to be present. if buf.len() % 2 != 0 || buf.len() < 2 { return Err(Error::ClientBadStartup); } let mut i = 0; while i < buf.len() { let name = buf[i].clone(); let value = buf[i + 1].clone(); let _ = result.insert(name, value); i += 2; } Ok(result) } /// Parse StartupMessage parameters. /// e.g. user, database, application_name, etc. pub fn parse_startup(bytes: BytesMut) -> Result, Error> { let result = parse_params(bytes)?; // Minimum required parameters // I want to have the user at the very minimum, according to the protocol spec. if !result.contains_key("user") { return Err(Error::ClientBadStartup); } Ok(result) } /// Create md5 password hash given a salt. pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let mut md5 = Md5::new(); // First pass md5.update(&password.as_bytes()); md5.update(&user.as_bytes()); let output = md5.finalize_reset(); // Second pass md5_hash_second_pass(&(format!("{:x}", output)), salt) } pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec { let mut md5 = Md5::new(); // Second pass md5.update(hash); md5.update(salt); let mut password = format!("md5{:x}", md5.finalize()) .chars() .map(|x| x as u8) .collect::>(); password.push(0); password } /// Send password challenge response to the server. /// This is the MD5 challenge. pub async fn md5_password( stream: &mut S, user: &str, password: &str, salt: &[u8], ) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let password = md5_hash_password(user, password, salt); let mut message = BytesMut::with_capacity(password.len() as usize + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); write_all(stream, message).await } pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let password = md5_hash_second_pass(hash, salt); let mut message = BytesMut::with_capacity(password.len() as usize + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); message.put_slice(&password[..]); write_all(stream, message).await } /// Implements a response to our custom `SET SHARDING KEY` /// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. pub async fn custom_protocol_response_ok(stream: &mut S, message: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&format!("{}\0", message)[..]); let len = (set_complete.len() + 4) as i32; // CommandComplete res.put_u8(b'C'); res.put_i32(len); res.put_slice(&set_complete[..]); write_all_half(stream, &res).await?; ready_for_query(stream).await } /// Send a custom error message to the client. /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { error_response_terminal(stream, message).await?; ready_for_query(stream).await } /// Send a custom error message to the client. /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . pub async fn error_response_terminal(stream: &mut S, message: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut error = BytesMut::new(); // Error level error.put_u8(b'S'); error.put_slice(&b"FATAL\0"[..]); // Error level (non-translatable) error.put_u8(b'V'); error.put_slice(&b"FATAL\0"[..]); // Error code: not sure how much this matters. error.put_u8(b'C'); error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A. // The short error message. error.put_u8(b'M'); error.put_slice(format!("{}\0", message).as_bytes()); // No more fields follow. error.put_u8(0); // Compose the two message reply. let mut res = BytesMut::with_capacity(error.len() + 5); res.put_u8(b'E'); res.put_i32(error.len() as i32 + 4); res.put(error); write_all_half(stream, &res).await } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { let mut error = BytesMut::new(); // Error level error.put_u8(b'S'); error.put_slice(&b"FATAL\0"[..]); // Error level (non-translatable) error.put_u8(b'V'); error.put_slice(&b"FATAL\0"[..]); // Error code: not sure how much this matters. error.put_u8(b'C'); error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A. // The short error message. error.put_u8(b'M'); error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes()); // No more fields follow. error.put_u8(0); // Compose the two message reply. let mut res = BytesMut::new(); res.put_u8(b'E'); res.put_i32(error.len() as i32 + 4); res.put(error); write_all(stream, res).await } /// Respond to a SHOW SHARD command. pub async fn show_response(stream: &mut S, name: &str, value: &str) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { // A SELECT response consists of: // 1. RowDescription // 2. One or more DataRow // 3. CommandComplete // 4. ReadyForQuery // The final messages sent to the client let mut res = BytesMut::new(); // RowDescription res.put(row_description(&vec![(name, DataType::Text)])); // DataRow res.put(data_row(&vec![value.to_string()])); // CommandComplete res.put(command_complete("SELECT 1")); write_all_half(stream, &res).await?; ready_for_query(stream).await } pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { let mut res = BytesMut::new(); let mut row_desc = BytesMut::new(); // how many columns we are storing row_desc.put_i16(columns.len() as i16); for (name, data_type) in columns { // Column name row_desc.put_slice(format!("{}\0", name).as_bytes()); // Doesn't belong to any table row_desc.put_i32(0); // Doesn't belong to any table row_desc.put_i16(0); // Text row_desc.put_i32(data_type.into()); // Text size = variable (-1) let type_size = match data_type { DataType::Text => -1, DataType::Int4 => 4, DataType::Numeric => -1, DataType::Bool => 1, DataType::Oid => 4, DataType::AnyArray => -1, DataType::Any => -1, }; row_desc.put_i16(type_size); // Type modifier: none that I know row_desc.put_i32(-1); // Format being used: text (0), binary (1) row_desc.put_i16(0); } res.put_u8(b'T'); res.put_i32(row_desc.len() as i32 + 4); res.put(row_desc); res } /// Create a DataRow message. pub fn data_row(row: &Vec) -> BytesMut { let mut res = BytesMut::new(); let mut data_row = BytesMut::new(); data_row.put_i16(row.len() as i16); for column in row { let column = column.as_bytes(); data_row.put_i32(column.len() as i32); data_row.put_slice(column); } res.put_u8(b'D'); res.put_i32(data_row.len() as i32 + 4); res.put(data_row); res } pub fn data_row_nullable(row: &Vec>) -> BytesMut { let mut res = BytesMut::new(); let mut data_row = BytesMut::new(); data_row.put_i16(row.len() as i16); for column in row { if let Some(column) = column { let column = column.as_bytes(); data_row.put_i32(column.len() as i32); data_row.put_slice(column); } else { data_row.put_i32(-1 as i32); } } res.put_u8(b'D'); res.put_i32(data_row.len() as i32 + 4); res.put(data_row); res } /// Create a CommandComplete message. pub fn command_complete(command: &str) -> BytesMut { let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); let mut res = BytesMut::new(); res.put_u8(b'C'); res.put_i32(cmd.len() as i32 + 4); res.put(cmd); res } pub fn flush() -> BytesMut { let mut bytes = BytesMut::new(); bytes.put_u8(b'H'); bytes.put_i32(4); bytes } /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(err) => { return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) } } } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { match stream.write_all(buf).await { Ok(_) => Ok(()), Err(err) => { return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) } } } pub async fn write_all_flush(stream: &mut S, buf: &[u8]) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { match stream.write_all(buf).await { Ok(_) => match stream.flush().await { Ok(_) => Ok(()), Err(err) => { return Err(Error::SocketError(format!( "Error flushing socket - Error: {:?}", err ))) } }, Err(err) => { return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) } } } /// Read a complete message from the socket. pub async fn read_message(stream: &mut S) -> Result where S: tokio::io::AsyncRead + std::marker::Unpin, { let code = match stream.read_u8().await { Ok(code) => code, Err(err) => { return Err(Error::SocketError(format!( "Error reading message code from socket - Error {:?}", err ))) } }; let len = match stream.read_i32().await { Ok(len) => len, Err(err) => { return Err(Error::SocketError(format!( "Error reading message len from socket - Code: {:?}, Error: {:?}", code, err ))) } }; let mut bytes = BytesMut::with_capacity(len as usize + 1); bytes.put_u8(code); bytes.put_i32(len); bytes.resize(bytes.len() + len as usize - mem::size_of::(), b'0'); let slice_start = mem::size_of::() + mem::size_of::(); let slice_end = slice_start + len as usize - mem::size_of::(); // Avoids a panic if slice_end < slice_start { return Err(Error::SocketError(format!( "Error reading message from socket - Code: {:?} - Length {:?}, Error: {:?}", code, len, "Unexpected length value for message" ))); } match stream.read_exact(&mut bytes[slice_start..slice_end]).await { Ok(_) => (), Err(err) => { return Err(Error::SocketError(format!( "Error reading message from socket - Code: {:?}, Error: {:?}", code, err ))) } }; Ok(bytes) } pub fn server_parameter_message(key: &str, value: &str) -> BytesMut { let mut server_info = BytesMut::new(); let null_byte_size = 1; let len: usize = mem::size_of::() + key.len() + null_byte_size + value.len() + null_byte_size; server_info.put_slice("S".as_bytes()); server_info.put_i32(len.try_into().unwrap()); server_info.put_slice(key.as_bytes()); server_info.put_bytes(0, 1); server_info.put_slice(value.as_bytes()); server_info.put_bytes(0, 1); server_info } pub fn configure_socket(stream: &TcpStream) { let sock_ref = SockRef::from(stream); let conf = get_config(); match sock_ref.set_keepalive(true) { Ok(_) => { match sock_ref.set_tcp_keepalive( &TcpKeepalive::new() .with_interval(Duration::from_secs(conf.general.tcp_keepalives_interval)) .with_retries(conf.general.tcp_keepalives_count) .with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)), ) { Ok(_) => (), Err(err) => error!("Could not configure socket: {}", err), } } Err(err) => error!("Could not configure socket: {}", err), } } pub trait BytesMutReader { fn read_string(&mut self) -> Result; } impl BytesMutReader for Cursor<&BytesMut> { /// Should only be used when reading strings from the message protocol. /// Can be used to read multiple strings from the same message which are separated by the null byte fn read_string(&mut self) -> Result { let mut buf = vec![]; match self.read_until(b'\0', &mut buf) { Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), Err(err) => return Err(Error::ParseBytesError(err.to_string())), } } } /// Parse (F) message. /// See: #[derive(Clone, Debug)] pub struct Parse { code: char, #[allow(dead_code)] len: i32, pub name: String, pub generated_name: String, query: String, num_params: i16, param_types: Vec, } impl TryFrom<&BytesMut> for Parse { type Error = Error; fn try_from(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); let code = cursor.get_u8() as char; let len = cursor.get_i32(); let name = cursor.read_string()?; let query = cursor.read_string()?; let num_params = cursor.get_i16(); let mut param_types = Vec::new(); for _ in 0..num_params { param_types.push(cursor.get_i32()); } Ok(Parse { code, len, name, generated_name: prepared_statement_name(), query, num_params, param_types, }) } } impl TryFrom for BytesMut { type Error = Error; fn try_from(parse: Parse) -> Result { let mut bytes = BytesMut::new(); let name_binding = CString::new(parse.name)?; let name = name_binding.as_bytes_with_nul(); let query_binding = CString::new(parse.query)?; let query = query_binding.as_bytes_with_nul(); // Recompute length of the message. let len = 4 // self + name.len() + query.len() + 2 + 4 * parse.num_params as usize; bytes.put_u8(parse.code as u8); bytes.put_i32(len as i32); bytes.put_slice(name); bytes.put_slice(query); bytes.put_i16(parse.num_params); for param in parse.param_types { bytes.put_i32(param); } Ok(bytes) } } impl TryFrom<&Parse> for BytesMut { type Error = Error; fn try_from(parse: &Parse) -> Result { parse.clone().try_into() } } impl Parse { pub fn rename(mut self) -> Self { self.name = self.generated_name.to_string(); self } pub fn anonymous(&self) -> bool { self.name.is_empty() } } /// Bind (B) message. /// See: #[derive(Clone, Debug)] pub struct Bind { code: char, #[allow(dead_code)] len: i64, portal: String, pub prepared_statement: String, num_param_format_codes: i16, param_format_codes: Vec, num_param_values: i16, param_values: Vec<(i32, BytesMut)>, num_result_column_format_codes: i16, result_columns_format_codes: Vec, } impl TryFrom<&BytesMut> for Bind { type Error = Error; fn try_from(buf: &BytesMut) -> Result { let mut cursor = Cursor::new(buf); let code = cursor.get_u8() as char; let len = cursor.get_i32(); let portal = cursor.read_string()?; let prepared_statement = cursor.read_string()?; let num_param_format_codes = cursor.get_i16(); let mut param_format_codes = Vec::new(); for _ in 0..num_param_format_codes { param_format_codes.push(cursor.get_i16()); } let num_param_values = cursor.get_i16(); let mut param_values = Vec::new(); for _ in 0..num_param_values { let param_len = cursor.get_i32(); let mut param = BytesMut::with_capacity(param_len as usize); param.resize(param_len as usize, b'0'); cursor.copy_to_slice(&mut param); param_values.push((param_len, param)); } let num_result_column_format_codes = cursor.get_i16(); let mut result_columns_format_codes = Vec::new(); for _ in 0..num_result_column_format_codes { result_columns_format_codes.push(cursor.get_i16()); } Ok(Bind { code, len: len as i64, portal, prepared_statement, num_param_format_codes, param_format_codes, num_param_values, param_values, num_result_column_format_codes, result_columns_format_codes, }) } } impl TryFrom for BytesMut { type Error = Error; fn try_from(bind: Bind) -> Result { let mut bytes = BytesMut::new(); let portal_binding = CString::new(bind.portal)?; let portal = portal_binding.as_bytes_with_nul(); let prepared_statement_binding = CString::new(bind.prepared_statement)?; let prepared_statement = prepared_statement_binding.as_bytes_with_nul(); let mut len = 4 // self + portal.len() + prepared_statement.len() + 2 // num_param_format_codes + 2 * bind.num_param_format_codes as usize // num_param_format_codes + 2; // num_param_values for (param_len, _) in &bind.param_values { len += 4 + *param_len as usize; } len += 2; // num_result_column_format_codes len += 2 * bind.num_result_column_format_codes as usize; bytes.put_u8(bind.code as u8); bytes.put_i32(len as i32); bytes.put_slice(portal); bytes.put_slice(prepared_statement); bytes.put_i16(bind.num_param_format_codes); for param_format_code in bind.param_format_codes { bytes.put_i16(param_format_code); } bytes.put_i16(bind.num_param_values); for (param_len, param) in bind.param_values { bytes.put_i32(param_len); bytes.put_slice(¶m); } bytes.put_i16(bind.num_result_column_format_codes); for result_column_format_code in bind.result_columns_format_codes { bytes.put_i16(result_column_format_code); } Ok(bytes) } } impl Bind { pub fn reassign(mut self, parse: &Parse) -> Self { self.prepared_statement = parse.name.clone(); self } pub fn anonymous(&self) -> bool { self.prepared_statement.is_empty() } } #[derive(Debug, Clone)] pub struct Describe { code: char, #[allow(dead_code)] len: i32, target: char, pub statement_name: String, } impl TryFrom<&BytesMut> for Describe { type Error = Error; fn try_from(bytes: &BytesMut) -> Result { let mut cursor = Cursor::new(bytes); let code = cursor.get_u8() as char; let len = cursor.get_i32(); let target = cursor.get_u8() as char; let statement_name = cursor.read_string()?; Ok(Describe { code, len, target, statement_name, }) } } impl TryFrom for BytesMut { type Error = Error; fn try_from(describe: Describe) -> Result { let mut bytes = BytesMut::new(); let statement_name_binding = CString::new(describe.statement_name)?; let statement_name = statement_name_binding.as_bytes_with_nul(); let len = 4 + 1 + statement_name.len(); bytes.put_u8(describe.code as u8); bytes.put_i32(len as i32); bytes.put_u8(describe.target as u8); bytes.put_slice(statement_name); Ok(bytes) } } impl Describe { pub fn rename(mut self, name: &str) -> Self { self.statement_name = name.to_string(); self } pub fn anonymous(&self) -> bool { self.statement_name.is_empty() } } pub fn prepared_statement_name() -> String { format!( "P_{}", PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst) ) }