mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
504 lines
13 KiB
Rust
504 lines
13 KiB
Rust
//! Helper functions to send one-off protocol messages
//! and handle TcpStream (TCP socket).
|
|
use bytes::{Buf, BufMut, BytesMut};
|
|
use md5::{Digest, Md5};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::{
|
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
|
TcpStream,
|
|
};
|
|
|
|
use crate::errors::Error;
|
|
use std::collections::HashMap;
|
|
|
|
/// Postgres data type mappings
/// used in RowDescription ('T') message.
///
/// Convert a reference to a variant into its numeric type OID with
/// `i32::from(&data_type)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Text,
    Int4,
    Numeric,
}

impl From<&DataType> for i32 {
    /// Map a data type to the type OID written into RowDescription.
    fn from(data_type: &DataType) -> i32 {
        match data_type {
            DataType::Text => 25,
            DataType::Int4 => 23,
            DataType::Numeric => 1700,
        }
    }
}
|
|
|
|
/// Tell the client that authentication handshake completed successfully.
|
|
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut auth_ok = BytesMut::with_capacity(9);
|
|
|
|
auth_ok.put_u8(b'R');
|
|
auth_ok.put_i32(8);
|
|
auth_ok.put_i32(0);
|
|
|
|
Ok(write_all(stream, auth_ok).await?)
|
|
}
|
|
|
|
/// Generate md5 password challenge.
|
|
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
// let mut rng = rand::thread_rng();
|
|
let salt: [u8; 4] = [
|
|
rand::random(),
|
|
rand::random(),
|
|
rand::random(),
|
|
rand::random(),
|
|
];
|
|
|
|
let mut res = BytesMut::new();
|
|
res.put_u8(b'R');
|
|
res.put_i32(12);
|
|
res.put_i32(5); // MD5
|
|
res.put_slice(&salt[..]);
|
|
|
|
write_all(stream, res).await?;
|
|
Ok(salt)
|
|
}
|
|
|
|
/// Give the client the process_id and secret we generated
|
|
/// used in query cancellation.
|
|
pub async fn backend_key_data<S>(
|
|
stream: &mut S,
|
|
backend_id: i32,
|
|
secret_key: i32,
|
|
) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut key_data = BytesMut::from(&b"K"[..]);
|
|
key_data.put_i32(12);
|
|
key_data.put_i32(backend_id);
|
|
key_data.put_i32(secret_key);
|
|
|
|
Ok(write_all(stream, key_data).await?)
|
|
}
|
|
|
|
/// Construct a `Q`: Query message.
|
|
pub fn simple_query(query: &str) -> BytesMut {
|
|
let mut res = BytesMut::from(&b"Q"[..]);
|
|
let query = format!("{}\0", query);
|
|
|
|
res.put_i32(query.len() as i32 + 4);
|
|
res.put_slice(&query.as_bytes());
|
|
|
|
res
|
|
}
|
|
|
|
/// Tell the client we're ready for another query.
|
|
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut bytes = BytesMut::with_capacity(5);
|
|
|
|
bytes.put_u8(b'Z');
|
|
bytes.put_i32(5);
|
|
bytes.put_u8(b'I'); // Idle
|
|
|
|
Ok(write_all(stream, bytes).await?)
|
|
}
|
|
|
|
/// Send the startup packet the server. We're pretending we're a Pg client.
|
|
/// This tells the server which user we are and what database we want.
|
|
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
|
let mut bytes = BytesMut::with_capacity(25);
|
|
|
|
bytes.put_i32(196608); // Protocol number
|
|
|
|
// User
|
|
bytes.put(&b"user\0"[..]);
|
|
bytes.put_slice(&user.as_bytes());
|
|
bytes.put_u8(0);
|
|
|
|
// Database
|
|
bytes.put(&b"database\0"[..]);
|
|
bytes.put_slice(&database.as_bytes());
|
|
bytes.put_u8(0);
|
|
bytes.put_u8(0); // Null terminator
|
|
|
|
let len = bytes.len() as i32 + 4i32;
|
|
|
|
let mut startup = BytesMut::with_capacity(len as usize);
|
|
|
|
startup.put_i32(len);
|
|
startup.put(bytes);
|
|
|
|
match stream.write_all(&startup).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
/// Parse the params the server sends as a key/value format.
|
|
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
|
let mut result = HashMap::new();
|
|
let mut buf = Vec::new();
|
|
let mut tmp = String::new();
|
|
|
|
while bytes.has_remaining() {
|
|
let mut c = bytes.get_u8();
|
|
|
|
// Null-terminated C-strings.
|
|
while c != 0 {
|
|
tmp.push(c as char);
|
|
c = bytes.get_u8();
|
|
}
|
|
|
|
if tmp.len() > 0 {
|
|
buf.push(tmp.clone());
|
|
tmp.clear();
|
|
}
|
|
}
|
|
|
|
// Expect pairs of name and value
|
|
// and at least one pair to be present.
|
|
if buf.len() % 2 != 0 || buf.len() < 2 {
|
|
return Err(Error::ClientBadStartup);
|
|
}
|
|
|
|
let mut i = 0;
|
|
while i < buf.len() {
|
|
let name = buf[i].clone();
|
|
let value = buf[i + 1].clone();
|
|
let _ = result.insert(name, value);
|
|
i += 2;
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Parse StartupMessage parameters.
|
|
/// e.g. user, database, application_name, etc.
|
|
pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
|
let result = parse_params(bytes)?;
|
|
|
|
// Minimum required parameters
|
|
// I want to have the user at the very minimum, according to the protocol spec.
|
|
if !result.contains_key("user") {
|
|
return Err(Error::ClientBadStartup);
|
|
}
|
|
|
|
Ok(result)
|
|
}
|
|
|
|
/// Create md5 password hash given a salt.
|
|
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
|
let mut md5 = Md5::new();
|
|
|
|
// First pass
|
|
md5.update(&password.as_bytes());
|
|
md5.update(&user.as_bytes());
|
|
|
|
let output = md5.finalize_reset();
|
|
|
|
// Second pass
|
|
md5.update(format!("{:x}", output));
|
|
md5.update(salt);
|
|
|
|
let mut password = format!("md5{:x}", md5.finalize())
|
|
.chars()
|
|
.map(|x| x as u8)
|
|
.collect::<Vec<u8>>();
|
|
password.push(0);
|
|
|
|
password
|
|
}
|
|
|
|
/// Send password challenge response to the server.
|
|
/// This is the MD5 challenge.
|
|
pub async fn md5_password<S>(
|
|
stream: &mut S,
|
|
user: &str,
|
|
password: &str,
|
|
salt: &[u8],
|
|
) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let password = md5_hash_password(user, password, salt);
|
|
|
|
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
|
|
|
message.put_u8(b'p');
|
|
message.put_i32(password.len() as i32 + 4);
|
|
message.put_slice(&password[..]);
|
|
|
|
Ok(write_all(stream, message).await?)
|
|
}
|
|
|
|
/// Implements a response to our custom `SET SHARDING KEY`
|
|
/// and `SET SERVER ROLE` commands.
|
|
/// This tells the client we're ready for the next query.
|
|
pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut res = BytesMut::with_capacity(25);
|
|
|
|
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
|
|
let len = (set_complete.len() + 4) as i32;
|
|
|
|
// CommandComplete
|
|
res.put_u8(b'C');
|
|
res.put_i32(len);
|
|
res.put_slice(&set_complete[..]);
|
|
|
|
// ReadyForQuery (idle)
|
|
res.put_u8(b'Z');
|
|
res.put_i32(5);
|
|
res.put_u8(b'I');
|
|
|
|
write_all_half(stream, res).await
|
|
}
|
|
|
|
/// Send a custom error message to the client.
|
|
/// Tell the client we are ready for the next query and no rollback is necessary.
|
|
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
|
|
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut error = BytesMut::new();
|
|
|
|
// Error level
|
|
error.put_u8(b'S');
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
// Error level (non-translatable)
|
|
error.put_u8(b'V');
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
// Error code: not sure how much this matters.
|
|
error.put_u8(b'C');
|
|
error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A.
|
|
|
|
// The short error message.
|
|
error.put_u8(b'M');
|
|
error.put_slice(&format!("{}\0", message).as_bytes());
|
|
|
|
// No more fields follow.
|
|
error.put_u8(0);
|
|
|
|
// Ready for query, no rollback needed (I = idle).
|
|
let mut ready_for_query = BytesMut::new();
|
|
|
|
ready_for_query.put_u8(b'Z');
|
|
ready_for_query.put_i32(5);
|
|
ready_for_query.put_u8(b'I');
|
|
|
|
// Compose the two message reply.
|
|
let mut res = BytesMut::with_capacity(error.len() + ready_for_query.len() + 5);
|
|
|
|
res.put_u8(b'E');
|
|
res.put_i32(error.len() as i32 + 4);
|
|
|
|
res.put(error);
|
|
res.put(ready_for_query);
|
|
|
|
Ok(write_all_half(stream, res).await?)
|
|
}
|
|
|
|
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
let mut error = BytesMut::new();
|
|
|
|
// Error level
|
|
error.put_u8(b'S');
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
// Error level (non-translatable)
|
|
error.put_u8(b'V');
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
// Error code: not sure how much this matters.
|
|
error.put_u8(b'C');
|
|
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
|
|
|
|
// The short error message.
|
|
error.put_u8(b'M');
|
|
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes());
|
|
|
|
// No more fields follow.
|
|
error.put_u8(0);
|
|
|
|
// Compose the two message reply.
|
|
let mut res = BytesMut::new();
|
|
|
|
res.put_u8(b'E');
|
|
res.put_i32(error.len() as i32 + 4);
|
|
|
|
res.put(error);
|
|
|
|
write_all(stream, res).await
|
|
}
|
|
|
|
/// Respond to a SHOW SHARD command.
|
|
pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
// A SELECT response consists of:
|
|
// 1. RowDescription
|
|
// 2. One or more DataRow
|
|
// 3. CommandComplete
|
|
// 4. ReadyForQuery
|
|
|
|
// The final messages sent to the client
|
|
let mut res = BytesMut::new();
|
|
|
|
// RowDescription
|
|
res.put(row_description(&vec![(name, DataType::Text)]));
|
|
|
|
// DataRow
|
|
res.put(data_row(&vec![value.to_string()]));
|
|
|
|
// CommandComplete
|
|
res.put(command_complete("SELECT 1"));
|
|
|
|
// ReadyForQuery
|
|
res.put_u8(b'Z');
|
|
res.put_i32(5);
|
|
res.put_u8(b'I');
|
|
|
|
write_all_half(stream, res).await
|
|
}
|
|
|
|
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
|
let mut res = BytesMut::new();
|
|
let mut row_desc = BytesMut::new();
|
|
|
|
// how many colums we are storing
|
|
row_desc.put_i16(columns.len() as i16);
|
|
|
|
for (name, data_type) in columns {
|
|
// Column name
|
|
row_desc.put_slice(&format!("{}\0", name).as_bytes());
|
|
|
|
// Doesn't belong to any table
|
|
row_desc.put_i32(0);
|
|
|
|
// Doesn't belong to any table
|
|
row_desc.put_i16(0);
|
|
|
|
// Text
|
|
row_desc.put_i32(data_type.into());
|
|
|
|
// Text size = variable (-1)
|
|
let type_size = match data_type {
|
|
DataType::Text => -1,
|
|
DataType::Int4 => 4,
|
|
DataType::Numeric => -1,
|
|
};
|
|
|
|
row_desc.put_i16(type_size);
|
|
|
|
// Type modifier: none that I know
|
|
row_desc.put_i32(-1);
|
|
|
|
// Format being used: text (0), binary (1)
|
|
row_desc.put_i16(0);
|
|
}
|
|
|
|
res.put_u8(b'T');
|
|
res.put_i32(row_desc.len() as i32 + 4);
|
|
res.put(row_desc);
|
|
|
|
res
|
|
}
|
|
|
|
/// Create a DataRow message.
|
|
pub fn data_row(row: &Vec<String>) -> BytesMut {
|
|
let mut res = BytesMut::new();
|
|
let mut data_row = BytesMut::new();
|
|
|
|
data_row.put_i16(row.len() as i16);
|
|
|
|
for column in row {
|
|
let column = column.as_bytes();
|
|
data_row.put_i32(column.len() as i32);
|
|
data_row.put_slice(&column);
|
|
}
|
|
|
|
res.put_u8(b'D');
|
|
res.put_i32(data_row.len() as i32 + 4);
|
|
res.put(data_row);
|
|
|
|
res
|
|
}
|
|
|
|
/// Create a CommandComplete message.
|
|
pub fn command_complete(command: &str) -> BytesMut {
|
|
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
|
|
let mut res = BytesMut::new();
|
|
res.put_u8(b'C');
|
|
res.put_i32(cmd.len() as i32 + 4);
|
|
res.put(cmd);
|
|
res
|
|
}
|
|
|
|
/// Write all data in the buffer to the TcpStream.
|
|
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
match stream.write_all(&buf).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
|
|
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
|
where
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
{
|
|
match stream.write_all(&buf).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
/// Read a complete message from the socket.
|
|
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
|
where
|
|
S: tokio::io::AsyncRead + std::marker::Unpin,
|
|
{
|
|
let code = match stream.read_u8().await {
|
|
Ok(code) => code,
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let len = match stream.read_i32().await {
|
|
Ok(len) => len,
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let mut buf = vec![0u8; len as usize - 4];
|
|
|
|
match stream.read_exact(&mut buf).await {
|
|
Ok(_) => (),
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let mut bytes = BytesMut::with_capacity(len as usize + 1);
|
|
|
|
bytes.put_u8(code);
|
|
bytes.put_i32(len);
|
|
bytes.put_slice(&buf);
|
|
|
|
Ok(bytes)
|
|
}
|