mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-24 17:56:29 +00:00
131 lines
3.3 KiB
Rust
131 lines
3.3 KiB
Rust
use bytes::{BufMut, BytesMut};
|
|
use md5::{Digest, Md5};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
|
use tokio::net::TcpStream;
|
|
|
|
use crate::errors::Error;
|
|
|
|
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
|
let mut auth_ok = BytesMut::with_capacity(9);
|
|
|
|
auth_ok.put_u8(b'R');
|
|
auth_ok.put_i32(8);
|
|
auth_ok.put_i32(0);
|
|
|
|
Ok(write_all(stream, auth_ok).await?)
|
|
}
|
|
|
|
pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> {
|
|
let mut bytes = BytesMut::with_capacity(5);
|
|
|
|
bytes.put_u8(b'Z');
|
|
bytes.put_i32(5);
|
|
bytes.put_u8(b'I'); // Idle
|
|
|
|
Ok(write_all(stream, bytes).await?)
|
|
}
|
|
|
|
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
|
let mut bytes = BytesMut::with_capacity(25);
|
|
|
|
bytes.put_i32(196608); // Protocol number
|
|
|
|
// User
|
|
bytes.put(&b"user\0"[..]);
|
|
bytes.put_slice(&user.as_bytes());
|
|
bytes.put_u8(0);
|
|
|
|
// Database
|
|
bytes.put(&b"database\0"[..]);
|
|
bytes.put_slice(&database.as_bytes());
|
|
bytes.put_u8(0);
|
|
bytes.put_u8(0); // Null terminator
|
|
|
|
let len = bytes.len() as i32 + 4i32;
|
|
|
|
let mut startup = BytesMut::with_capacity(len as usize);
|
|
|
|
startup.put_i32(len);
|
|
startup.put(bytes);
|
|
|
|
match stream.write_all(&startup).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
pub async fn md5_password(
|
|
stream: &mut TcpStream,
|
|
user: &str,
|
|
password: &str,
|
|
salt: &[u8],
|
|
) -> Result<(), Error> {
|
|
let mut md5 = Md5::new();
|
|
|
|
// First pass
|
|
md5.update(&password.as_bytes());
|
|
md5.update(&user.as_bytes());
|
|
|
|
let output = md5.finalize_reset();
|
|
|
|
// Second pass
|
|
md5.update(format!("{:x}", output));
|
|
md5.update(salt);
|
|
|
|
let mut password = format!("md5{:x}", md5.finalize())
|
|
.chars()
|
|
.map(|x| x as u8)
|
|
.collect::<Vec<u8>>();
|
|
password.push(0);
|
|
|
|
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
|
message.put_u8(b'p');
|
|
message.put_i32(password.len() as i32 + 4);
|
|
message.put_slice(&password[..]);
|
|
|
|
Ok(write_all(stream, message).await?)
|
|
}
|
|
|
|
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
|
|
match stream.write_all(&buf).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> {
|
|
match stream.write_all(&buf).await {
|
|
Ok(_) => Ok(()),
|
|
Err(_) => return Err(Error::SocketError),
|
|
}
|
|
}
|
|
|
|
/// Read a complete message from the socket.
|
|
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> {
|
|
let code = match stream.read_u8().await {
|
|
Ok(code) => code,
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let len = match stream.read_i32().await {
|
|
Ok(len) => len,
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let mut buf = vec![0u8; len as usize - 4];
|
|
|
|
match stream.read_exact(&mut buf).await {
|
|
Ok(_) => (),
|
|
Err(_) => return Err(Error::SocketError),
|
|
};
|
|
|
|
let mut bytes = BytesMut::with_capacity(len as usize + 1);
|
|
|
|
bytes.put_u8(code);
|
|
bytes.put_i32(len);
|
|
bytes.put_slice(&buf);
|
|
|
|
Ok(bytes)
|
|
}
|