mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
it works!
This commit is contained in:
61
Cargo.lock
generated
61
Cargo.lock
generated
@@ -8,6 +8,15 @@ version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f1d36a02058e76b040de25a4464ba1c80935655595b661505c8b39b664828b95"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.1.0"
|
||||
@@ -20,6 +29,36 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
@@ -62,6 +101,15 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "md-5"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6a38fc55c8bbc10058782919516f88826e70320db6d206aebc49611d24216ae"
|
||||
dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.4.1"
|
||||
@@ -169,6 +217,7 @@ name = "rabbit"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"md-5",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@@ -243,12 +292,24 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
||||
@@ -8,3 +8,4 @@ edition = "2021"
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = "1"
|
||||
md-5 = "*"
|
||||
@@ -1,14 +1,14 @@
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
/// PostgreSQL client (frontend).
|
||||
/// We are pretending to be the backend.
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
|
||||
use bytes::{BytesMut, Buf, BufMut};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::server::Server;
|
||||
|
||||
pub struct Client {
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
@@ -43,7 +43,7 @@ impl Client {
|
||||
no.put_u8(b'N');
|
||||
|
||||
write_all(&mut stream, no).await?;
|
||||
},
|
||||
}
|
||||
|
||||
// Regular startup message.
|
||||
196608 => {
|
||||
@@ -58,7 +58,7 @@ impl Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
});
|
||||
},
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
@@ -67,21 +67,28 @@ impl Client {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle(&mut self) -> Result<(), Error> {
|
||||
pub async fn handle(&mut self, mut server: Server) -> Result<(), Error> {
|
||||
loop {
|
||||
let mut message = read_message(&mut self.read).await?;
|
||||
let original = message.clone(); // To be forwarded to the server
|
||||
let code = message.get_u8() as char;
|
||||
let len = message.get_i32() as usize;
|
||||
let _len = message.get_i32() as usize;
|
||||
|
||||
match code {
|
||||
'Q' => {
|
||||
println!(">>> Query: {:?}", message);
|
||||
},
|
||||
server.send(original).await?;
|
||||
let response = server.recv().await?;
|
||||
write_all_half(&mut self.write, response).await?;
|
||||
}
|
||||
|
||||
'X' => {
|
||||
// Client closing
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
_ => {
|
||||
println!(">>> Unexpected code: {}", code);
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
SocketError,
|
||||
ClientDisconneted,
|
||||
// ClientDisconnected,
|
||||
ClientBadStartup,
|
||||
ProtocolSyncError,
|
||||
ServerError,
|
||||
}
|
||||
28
src/main.rs
28
src/main.rs
@@ -1,11 +1,13 @@
|
||||
extern crate bytes;
|
||||
extern crate md5;
|
||||
extern crate tokio;
|
||||
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
mod client;
|
||||
mod errors;
|
||||
mod messages;
|
||||
mod client;
|
||||
mod server;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@@ -30,13 +32,29 @@ async fn main() {
|
||||
|
||||
// Client goes to another thread, bye.
|
||||
tokio::task::spawn(async move {
|
||||
println!(">> Client {:?} connected", addr);
|
||||
println!(">> Client {:?} connected.", addr);
|
||||
|
||||
match client::Client::startup(socket).await {
|
||||
Ok(mut client) => {
|
||||
println!(">> Client {:?} connected successfully!", addr);
|
||||
client.handle().await;
|
||||
},
|
||||
println!(">> Client {:?} authenticated successfully!", addr);
|
||||
let server =
|
||||
match server::Server::startup("127.0.0.1", "5432", "lev", "lev", "lev")
|
||||
.await
|
||||
{
|
||||
Ok(server) => server,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
match client.handle(server).await {
|
||||
Ok(()) => {
|
||||
println!(">> Client {:?} disconnected.", addr);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
println!(">> Client disconnected with error: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
println!(">> Error: {:?}", err);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::net::tcp::OwnedReadHalf;
|
||||
use tokio::io::{AsyncWriteExt, BufReader, AsyncReadExt};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use md5::{Digest, Md5};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::errors::Error;
|
||||
|
||||
@@ -26,6 +26,67 @@ pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
Ok(write_all(stream, bytes).await?)
|
||||
}
|
||||
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
|
||||
// User
|
||||
bytes.put(&b"user\0"[..]);
|
||||
bytes.put_slice(&user.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
|
||||
// Database
|
||||
bytes.put(&b"database\0"[..]);
|
||||
bytes.put_slice(&database.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
bytes.put_u8(0); // Null terminator
|
||||
|
||||
let len = bytes.len() as i32 + 4i32;
|
||||
|
||||
let mut startup = BytesMut::with_capacity(len as usize);
|
||||
|
||||
startup.put_i32(len);
|
||||
startup.put(bytes);
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn md5_password(
|
||||
stream: &mut TcpStream,
|
||||
user: &str,
|
||||
password: &str,
|
||||
salt: &[u8],
|
||||
) -> Result<(), Error> {
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// First pass
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user.as_bytes());
|
||||
|
||||
let output = md5.finalize_reset();
|
||||
|
||||
// Second pass
|
||||
md5.update(format!("{:x}", output));
|
||||
md5.update(salt);
|
||||
|
||||
let mut password = format!("md5{:x}", md5.finalize())
|
||||
.chars()
|
||||
.map(|x| x as u8)
|
||||
.collect::<Vec<u8>>();
|
||||
password.push(0);
|
||||
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
message.put_slice(&password[..]);
|
||||
|
||||
Ok(write_all(stream, message).await?)
|
||||
}
|
||||
|
||||
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
@@ -33,6 +94,13 @@ pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Erro
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> {
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a complete message from the socket.
|
||||
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> {
|
||||
let code = match stream.read_u8().await {
|
||||
|
||||
169
src/server.rs
Normal file
169
src/server.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
|
||||
pub struct Server {
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
write: OwnedWriteHalf,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub async fn startup(
|
||||
host: &str,
|
||||
port: &str,
|
||||
user: &str,
|
||||
password: &str,
|
||||
database: &str,
|
||||
) -> Result<Server, Error> {
|
||||
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
println!(">> Could not connect to server: {}", err);
|
||||
return Err(Error::SocketError);
|
||||
}
|
||||
};
|
||||
|
||||
startup(&mut stream, user, database).await?;
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
match code {
|
||||
'R' => {
|
||||
// Auth can proceed
|
||||
let code = match stream.read_i32().await {
|
||||
Ok(code) => code,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
match code {
|
||||
// MD5
|
||||
5 => {
|
||||
let mut salt = vec![0u8; 4];
|
||||
|
||||
match stream.read_exact(&mut salt).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
md5_password(&mut stream, user, password, &salt[..]).await?;
|
||||
}
|
||||
|
||||
// We're in!
|
||||
0 => {
|
||||
println!(">> Server authentication successful!");
|
||||
}
|
||||
|
||||
_ => {
|
||||
println!(">> Unsupported authentication mechanism: {}", code);
|
||||
return Err(Error::ServerError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
'E' => {
|
||||
println!(">> Database error");
|
||||
return Err(Error::ServerError);
|
||||
}
|
||||
|
||||
'S' => {
|
||||
// Parameter
|
||||
let mut param = vec![0u8; len as usize - 4];
|
||||
match stream.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
}
|
||||
|
||||
'K' => {
|
||||
// TODO: save cancellation secret
|
||||
let mut cancel_secret = vec![0u8; len as usize - 4];
|
||||
match stream.read_exact(&mut cancel_secret).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
}
|
||||
|
||||
'Z' => {
|
||||
let mut idle = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut idle).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
// Startup finished
|
||||
let (read, write) = stream.into_split();
|
||||
|
||||
return Ok(Server {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
});
|
||||
}
|
||||
|
||||
_ => {
|
||||
println!(">> Unknown code: {}", code);
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
|
||||
Ok(write_all_half(&mut self.write, messages).await?)
|
||||
}
|
||||
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = read_message(&mut self.read).await?;
|
||||
|
||||
// Buffer the message we'll forward to the client in a bit.
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let code = message.get_u8() as char;
|
||||
match code {
|
||||
'Z' => {
|
||||
// Ready for query, time to forward buffer to client.
|
||||
break;
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Keep buffering,
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let bytes = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
// pub async fn handle(&mut self) -> Result<(), Error> {
|
||||
// loop {
|
||||
// let message = read_message(&mut self.read).await?;
|
||||
// let original = message.clone();
|
||||
|
||||
// let code = message.get_u8() as char;
|
||||
// let len = message.get_i32();
|
||||
|
||||
// match code {
|
||||
// '
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
Reference in New Issue
Block a user