Files
pgcat/src/messages.rs
Mostafa Abdelraouf f1265a5570 Introduce tcp_keepalives to PgCat (#315)
We have encountered a case where PgCat pools were stuck following a database incident. Our best understanding at this point is that the PgCat -> Postgres connections died silently and because Tokio defaults to disabling keepalives, connections in the pool were marked as busy forever. Only when we deployed PgCat did we see recovery.

This PR introduces tcp_keepalives to PgCat. This sets the defaults to be

keepalives_idle: 5        # seconds
keepalives_interval: 5 # seconds
keepalives_count: 5    # a count
These settings can detect the death of an idle connection within 30 seconds of its death. Please note that the connection can remain idle forever (from an application perspective) as long as the keepalive packets are flowing so disconnection will only occur if the other end is not acknowledging keepalive packets (keepalive packet acks are handled by the OS, the application does not need to do anything). I plan to add tcp_user_timeout in a follow-up PR.
2023-02-08 11:35:38 -06:00

592 lines
16 KiB
Rust

/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut};
use log::error;
use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::config::get_config;
use crate::errors::Error;
use std::collections::HashMap;
use std::io::{BufRead, Cursor};
use std::mem;
use std::time::Duration;
/// Postgres data type mappings
/// used in RowDescription ('T') message.
pub enum DataType {
Text,
Int4,
Numeric,
}
impl From<&DataType> for i32 {
fn from(data_type: &DataType) -> i32 {
match data_type {
DataType::Text => 25,
DataType::Int4 => 23,
DataType::Numeric => 1700,
}
}
}
/// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut auth_ok = BytesMut::with_capacity(9);
auth_ok.put_u8(b'R');
auth_ok.put_i32(8);
auth_ok.put_i32(0);
write_all(stream, auth_ok).await
}
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated
/// used in query cancellation.
pub async fn backend_key_data<S>(
stream: &mut S,
backend_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut key_data = BytesMut::from(&b"K"[..]);
key_data.put_i32(12);
key_data.put_i32(backend_id);
key_data.put_i32(secret_key);
write_all(stream, key_data).await
}
/// Construct a `Q`: Query message.
pub fn simple_query(query: &str) -> BytesMut {
let mut res = BytesMut::from(&b"Q"[..]);
let query = format!("{}\0", query);
res.put_i32(query.len() as i32 + 4);
res.put_slice(query.as_bytes());
res
}
/// Tell the client we're ready for another query.
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
write_all(stream, bytes).await
}
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
// User
bytes.put(&b"user\0"[..]);
bytes.put_slice(user.as_bytes());
bytes.put_u8(0);
// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(database.as_bytes());
bytes.put_u8(0);
bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32;
let mut startup = BytesMut::with_capacity(len as usize);
startup.put_i32(len);
startup.put(bytes);
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}",
err
)))
}
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
let mut buf = Vec::new();
let mut tmp = String::new();
while bytes.has_remaining() {
let mut c = bytes.get_u8();
// Null-terminated C-strings.
while c != 0 {
tmp.push(c as char);
c = bytes.get_u8();
}
if !tmp.is_empty() {
buf.push(tmp.clone());
tmp.clear();
}
}
// Expect pairs of name and value
// and at least one pair to be present.
if buf.len() % 2 != 0 || buf.len() < 2 {
return Err(Error::ClientBadStartup);
}
let mut i = 0;
while i < buf.len() {
let name = buf[i].clone();
let value = buf[i + 1].clone();
let _ = result.insert(name, value);
i += 2;
}
Ok(result)
}
/// Parse StartupMessage parameters.
/// e.g. user, database, application_name, etc.
pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let result = parse_params(bytes)?;
// Minimum required parameters
// I want to have the user at the very minimum, according to the protocol spec.
if !result.contains_key("user") {
return Err(Error::ClientBadStartup);
}
Ok(result)
}
/// Create md5 password hash given a salt.
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// First pass
md5.update(&password.as_bytes());
md5.update(&user.as_bytes());
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5.update(salt);
let mut password = format!("md5{:x}", md5.finalize())
.chars()
.map(|x| x as u8)
.collect::<Vec<u8>>();
password.push(0);
password
}
/// Send password challenge response to the server.
/// This is the MD5 challenge.
pub async fn md5_password<S>(
stream: &mut S,
user: &str,
password: &str,
salt: &[u8],
) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]);
write_all(stream, message).await
}
/// Implements a response to our custom `SET SHARDING KEY`
/// and `SET SERVER ROLE` commands.
/// This tells the client we're ready for the next query.
pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut res = BytesMut::with_capacity(25);
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
let len = (set_complete.len() + 4) as i32;
// CommandComplete
res.put_u8(b'C');
res.put_i32(len);
res.put_slice(&set_complete[..]);
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}
/// Send a custom error message to the client.
/// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
ready_for_query(stream).await
}
/// Send a custom error message to the client.
/// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response_terminal<S>(stream: &mut S, message: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut error = BytesMut::new();
// Error level
error.put_u8(b'S');
error.put_slice(&b"FATAL\0"[..]);
// Error level (non-translatable)
error.put_u8(b'V');
error.put_slice(&b"FATAL\0"[..]);
// Error code: not sure how much this matters.
error.put_u8(b'C');
error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A.
// The short error message.
error.put_u8(b'M');
error.put_slice(format!("{}\0", message).as_bytes());
// No more fields follow.
error.put_u8(0);
// Compose the two message reply.
let mut res = BytesMut::with_capacity(error.len() + 5);
res.put_u8(b'E');
res.put_i32(error.len() as i32 + 4);
res.put(error);
write_all_half(stream, &res).await
}
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut error = BytesMut::new();
// Error level
error.put_u8(b'S');
error.put_slice(&b"FATAL\0"[..]);
// Error level (non-translatable)
error.put_u8(b'V');
error.put_slice(&b"FATAL\0"[..]);
// Error code: not sure how much this matters.
error.put_u8(b'C');
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
// The short error message.
error.put_u8(b'M');
error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes());
// No more fields follow.
error.put_u8(0);
// Compose the two message reply.
let mut res = BytesMut::new();
res.put_u8(b'E');
res.put_i32(error.len() as i32 + 4);
res.put(error);
write_all(stream, res).await
}
/// Respond to a SHOW SHARD command.
pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// A SELECT response consists of:
// 1. RowDescription
// 2. One or more DataRow
// 3. CommandComplete
// 4. ReadyForQuery
// The final messages sent to the client
let mut res = BytesMut::new();
// RowDescription
res.put(row_description(&vec![(name, DataType::Text)]));
// DataRow
res.put(data_row(&vec![value.to_string()]));
// CommandComplete
res.put(command_complete("SELECT 1"));
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
let mut res = BytesMut::new();
let mut row_desc = BytesMut::new();
// how many colums we are storing
row_desc.put_i16(columns.len() as i16);
for (name, data_type) in columns {
// Column name
row_desc.put_slice(format!("{}\0", name).as_bytes());
// Doesn't belong to any table
row_desc.put_i32(0);
// Doesn't belong to any table
row_desc.put_i16(0);
// Text
row_desc.put_i32(data_type.into());
// Text size = variable (-1)
let type_size = match data_type {
DataType::Text => -1,
DataType::Int4 => 4,
DataType::Numeric => -1,
};
row_desc.put_i16(type_size);
// Type modifier: none that I know
row_desc.put_i32(-1);
// Format being used: text (0), binary (1)
row_desc.put_i16(0);
}
res.put_u8(b'T');
res.put_i32(row_desc.len() as i32 + 4);
res.put(row_desc);
res
}
/// Create a DataRow message.
pub fn data_row(row: &Vec<String>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(column);
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
/// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
let mut res = BytesMut::new();
res.put_u8(b'C');
res.put_i32(cmd.len() as i32 + 4);
res.put(cmd);
res
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
pub async fn write_all_half<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message code from socket - Error {:?}",
err
)))
}
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message len from socket - Code: {:?}, Error: {:?}",
code, err
)))
}
};
let mut bytes = BytesMut::with_capacity(len as usize + 1);
bytes.put_u8(code);
bytes.put_i32(len);
bytes.resize(bytes.len() + len as usize - mem::size_of::<i32>(), b'0');
match stream
.read_exact(
&mut bytes[mem::size_of::<u8>() + mem::size_of::<i32>()
..mem::size_of::<u8>() + mem::size_of::<i32>() + len as usize
- mem::size_of::<i32>()],
)
.await
{
Ok(_) => (),
Err(err) => {
return Err(Error::SocketError(format!(
"Error reading message from socket - Code: {:?}, Error: {:?}",
code, err
)))
}
};
Ok(bytes)
}
pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {
let mut server_info = BytesMut::new();
let null_byte_size = 1;
let len: usize =
mem::size_of::<i32>() + key.len() + null_byte_size + value.len() + null_byte_size;
server_info.put_slice("S".as_bytes());
server_info.put_i32(len.try_into().unwrap());
server_info.put_slice(key.as_bytes());
server_info.put_bytes(0, 1);
server_info.put_slice(value.as_bytes());
server_info.put_bytes(0, 1);
server_info
}
pub fn configure_socket(stream: &TcpStream) {
let sock_ref = SockRef::from(stream);
let conf = get_config();
match sock_ref.set_keepalive(true) {
Ok(_) => {
match sock_ref.set_tcp_keepalive(
&TcpKeepalive::new()
.with_interval(Duration::from_secs(conf.general.tcp_keepalives_interval))
.with_retries(conf.general.tcp_keepalives_count)
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
) {
Ok(_) => (),
Err(err) => error!("Could not configure socket: {}", err),
}
}
Err(err) => error!("Could not configure socket: {}", err),
}
}
pub trait BytesMutReader {
fn read_string(&mut self) -> Result<String, Error>;
}
impl BytesMutReader for Cursor<&BytesMut> {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let mut buf = vec![];
match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
}
}
}