Files
pgcat/src/messages.rs

275 lines
7.4 KiB
Rust
Raw Normal View History

/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
2022-02-14 05:11:53 -08:00
use bytes::{Buf, BufMut, BytesMut};
2022-02-03 15:17:04 -08:00
use md5::{Digest, Md5};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
2022-02-03 13:35:40 -08:00
2022-02-14 05:11:53 -08:00
use std::collections::HashMap;
2022-02-03 13:35:40 -08:00
use crate::errors::Error;
2022-02-08 09:33:20 -08:00
/// Tell the client that authentication handshake completed successfully.
2022-02-03 13:35:40 -08:00
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
let mut auth_ok = BytesMut::with_capacity(9);
auth_ok.put_u8(b'R');
auth_ok.put_i32(8);
auth_ok.put_i32(0);
Ok(write_all(stream, auth_ok).await?)
}
2022-02-08 09:33:20 -08:00
/// Give the client the process_id and secret we generated
/// used in query cancellation.
2022-02-04 09:28:52 -08:00
pub async fn backend_key_data(
stream: &mut TcpStream,
backend_id: i32,
secret_key: i32,
) -> Result<(), Error> {
let mut key_data = BytesMut::from(&b"K"[..]);
key_data.put_i32(12);
key_data.put_i32(backend_id);
key_data.put_i32(secret_key);
Ok(write_all(stream, key_data).await?)
}
2022-02-08 09:33:20 -08:00
/// Tell the client we're ready for another query.
2022-02-03 13:35:40 -08:00
pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'Z');
bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle
Ok(write_all(stream, bytes).await?)
}
2022-02-08 09:33:20 -08:00
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
2022-02-03 15:17:04 -08:00
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
// User
bytes.put(&b"user\0"[..]);
bytes.put_slice(&user.as_bytes());
bytes.put_u8(0);
// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes());
bytes.put_u8(0);
bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32;
let mut startup = BytesMut::with_capacity(len as usize);
startup.put_i32(len);
startup.put(bytes);
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
}
2022-02-14 05:11:53 -08:00
/// Parse StartupMessage parameters.
/// e.g. user, database, application_name, etc.
pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
let mut buf = Vec::new();
let mut tmp = String::new();
while bytes.has_remaining() {
let mut c = bytes.get_u8();
// Null-terminated C-strings.
while c != 0 {
tmp.push(c as char);
c = bytes.get_u8();
}
if tmp.len() > 0 {
buf.push(tmp.clone());
tmp.clear();
}
}
// Expect pairs of name and value
// and at least one pair to be present.
if buf.len() % 2 != 0 && buf.len() >= 2 {
return Err(Error::ClientBadStartup);
}
let mut i = 0;
while i < buf.len() {
let name = buf[i].clone();
let value = buf[i + 1].clone();
let _ = result.insert(name, value);
i += 2;
}
// Minimum required parameters
// I want to have the user at the very minimum, according to the protocol spec.
if !result.contains_key("user") {
return Err(Error::ClientBadStartup);
}
Ok(result)
}
2022-02-08 09:33:20 -08:00
/// Send password challenge response to the server.
/// This is the MD5 challenge.
2022-02-03 15:17:04 -08:00
pub async fn md5_password(
stream: &mut TcpStream,
user: &str,
password: &str,
salt: &[u8],
) -> Result<(), Error> {
let mut md5 = Md5::new();
// First pass
md5.update(&password.as_bytes());
md5.update(&user.as_bytes());
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5.update(salt);
let mut password = format!("md5{:x}", md5.finalize())
.chars()
.map(|x| x as u8)
.collect::<Vec<u8>>();
password.push(0);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
2022-02-03 15:17:04 -08:00
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]);
Ok(write_all(stream, message).await?)
}
2022-02-09 20:02:20 -08:00
/// Implements a response to our custom `SET SHARDING KEY`
/// and `SET SERVER ROLE` commands.
2022-02-09 06:51:31 -08:00
/// This tells the client we're ready for the next query.
2022-02-09 20:02:20 -08:00
pub async fn custom_protocol_response_ok(
stream: &mut OwnedWriteHalf,
message: &str,
) -> Result<(), Error> {
2022-02-08 13:11:50 -08:00
let mut res = BytesMut::with_capacity(25);
2022-02-09 20:02:20 -08:00
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
2022-02-08 13:11:50 -08:00
let len = (set_complete.len() + 4) as i32;
2022-02-09 06:51:31 -08:00
// CommandComplete
2022-02-08 13:11:50 -08:00
res.put_u8(b'C');
res.put_i32(len);
res.put_slice(&set_complete[..]);
2022-02-09 06:51:31 -08:00
// ReadyForQuery (idle)
2022-02-08 13:11:50 -08:00
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// Send a custom error message to the client.
/// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: https://www.postgresql.org/docs/12/errcodes-appendix.html
///
/// A single write carries two messages:
/// - an `ErrorResponse` (tag `'E'`) with severity `FATAL`, code `58000` and
///   `message` as its text;
/// - a `ReadyForQuery` (tag `'Z'`) with status `'I'`.
pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Result<(), Error> {
    // Each field is one identifying byte followed by a NUL-terminated string.
    let fixed_fields: [(u8, &[u8]); 3] = [
        (b'S', &b"FATAL\0"[..]), // Error level
        (b'V', &b"FATAL\0"[..]), // Error level (non-translatable)
        (b'C', &b"58000\0"[..]), // system_error, see Appendix A.
    ];

    let mut fields = BytesMut::new();
    for &(field_type, value) in fixed_fields.iter() {
        fields.put_u8(field_type);
        fields.put_slice(value);
    }
    // The short error message.
    fields.put_u8(b'M');
    fields.put_slice(message.as_bytes());
    fields.put_u8(0);
    // No more fields follow.
    fields.put_u8(0);

    let mut reply = BytesMut::with_capacity(5 + fields.len() + 6);
    reply.put_u8(b'E');
    reply.put_i32(fields.len() as i32 + 4);
    reply.put_slice(&fields);

    // Ready for query, no rollback needed (I = idle).
    reply.put_u8(b'Z');
    reply.put_i32(5);
    reply.put_u8(b'I');

    write_all_half(stream, reply).await
}
2022-02-08 09:33:20 -08:00
/// Write all data in the buffer to the TcpStream.
2022-02-03 13:35:40 -08:00
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
2022-02-03 13:54:07 -08:00
}
2022-02-08 09:33:20 -08:00
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
2022-02-03 15:17:04 -08:00
pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> {
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
}
2022-02-03 13:54:07 -08:00
/// Read a complete message from the socket.
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> {
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut buf = vec![0u8; len as usize - 4];
match stream.read_exact(&mut buf).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
let mut bytes = BytesMut::with_capacity(len as usize + 1);
bytes.put_u8(code);
bytes.put_i32(len);
bytes.put_slice(&buf);
Ok(bytes)
2022-02-03 15:17:04 -08:00
}