it works!

This commit is contained in:
Lev Kokotov
2022-02-03 15:17:04 -08:00
parent 880f1a649f
commit f4d647ce2f
7 changed files with 350 additions and 25 deletions

61
Cargo.lock generated
View File

@@ -8,6 +8,15 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "block-buffer"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1d36a02058e76b040de25a4464ba1c80935655595b661505c8b39b664828b95"
dependencies = [
"generic-array",
]
[[package]]
name = "bytes"
version = "1.1.0"
@@ -20,6 +29,36 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crypto-common"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0"
dependencies = [
"generic-array",
]
[[package]]
name = "digest"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b"
dependencies = [
"block-buffer",
"crypto-common",
"generic-array",
]
[[package]]
name = "generic-array"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "hermit-abi"
version = "0.1.19"
@@ -62,6 +101,15 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "md-5"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6a38fc55c8bbc10058782919516f88826e70320db6d206aebc49611d24216ae"
dependencies = [
"digest",
]
[[package]]
name = "memchr"
version = "2.4.1"
@@ -169,6 +217,7 @@ name = "rabbit"
version = "0.1.0"
dependencies = [
"bytes",
"md-5",
"tokio",
]
@@ -243,12 +292,24 @@ dependencies = [
"syn",
]
[[package]]
name = "typenum"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicode-xid"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "winapi"
version = "0.3.9"

View File

@@ -7,4 +7,5 @@ edition = "2021"
[dependencies]
tokio = { version = "1", features = ["full"] }
bytes = "1"
bytes = "1"
md-5 = "*"

View File

@@ -1,14 +1,14 @@
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
/// PostgreSQL client (frontend).
/// We are pretending to be the backend.
use tokio::net::TcpStream;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::io::{AsyncReadExt, BufReader};
use bytes::{BytesMut, Buf, BufMut};
use bytes::{Buf, BufMut, BytesMut};
use crate::errors::Error;
use crate::messages::*;
use crate::server::Server;
pub struct Client {
read: BufReader<OwnedReadHalf>,
@@ -43,7 +43,7 @@ impl Client {
no.put_u8(b'N');
write_all(&mut stream, no).await?;
},
}
// Regular startup message.
196608 => {
@@ -58,7 +58,7 @@ impl Client {
read: BufReader::new(read),
write: write,
});
},
}
_ => {
return Err(Error::ProtocolSyncError);
@@ -67,22 +67,29 @@ impl Client {
}
}
pub async fn handle(&mut self) -> Result<(), Error> {
pub async fn handle(&mut self, mut server: Server) -> Result<(), Error> {
loop {
let mut message = read_message(&mut self.read).await?;
let original = message.clone(); // To be forwarded to the server
let code = message.get_u8() as char;
let len = message.get_i32() as usize;
let _len = message.get_i32() as usize;
match code {
'Q' => {
println!(">>> Query: {:?}", message);
},
server.send(original).await?;
let response = server.recv().await?;
write_all_half(&mut self.write, response).await?;
}
'X' => {
// Client closing
return Ok(());
}
_ => {
println!(">>> Unexpected code: {}", code);
},
}
}
}
}
}
}

View File

@@ -1,7 +1,8 @@
#[derive(Debug, PartialEq)]
pub enum Error {
SocketError,
ClientDisconneted,
// ClientDisconnected,
ClientBadStartup,
ProtocolSyncError,
}
ServerError,
}

View File

@@ -1,11 +1,13 @@
extern crate bytes;
extern crate md5;
extern crate tokio;
use tokio::net::TcpListener;
mod client;
mod errors;
mod messages;
mod client;
mod server;
#[tokio::main]
async fn main() {
@@ -30,13 +32,29 @@ async fn main() {
// Client goes to another thread, bye.
tokio::task::spawn(async move {
println!(">> Client {:?} connected", addr);
println!(">> Client {:?} connected.", addr);
match client::Client::startup(socket).await {
Ok(mut client) => {
println!(">> Client {:?} connected successfully!", addr);
client.handle().await;
},
println!(">> Client {:?} authenticated successfully!", addr);
let server =
match server::Server::startup("127.0.0.1", "5432", "lev", "lev", "lev")
.await
{
Ok(server) => server,
Err(_) => return,
};
match client.handle(server).await {
Ok(()) => {
println!(">> Client {:?} disconnected.", addr);
}
Err(err) => {
println!(">> Client disconnected with error: {:?}", err);
}
}
}
Err(err) => {
println!(">> Error: {:?}", err);

View File

@@ -1,8 +1,8 @@
use tokio::net::TcpStream;
use tokio::net::tcp::OwnedReadHalf;
use tokio::io::{AsyncWriteExt, BufReader, AsyncReadExt};
use bytes::{BufMut, BytesMut};
use md5::{Digest, Md5};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use crate::errors::Error;
@@ -26,6 +26,67 @@ pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> {
Ok(write_all(stream, bytes).await?)
}
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
// User
bytes.put(&b"user\0"[..]);
bytes.put_slice(&user.as_bytes());
bytes.put_u8(0);
// Database
bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes());
bytes.put_u8(0);
bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32;
let mut startup = BytesMut::with_capacity(len as usize);
startup.put_i32(len);
startup.put(bytes);
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
}
pub async fn md5_password(
stream: &mut TcpStream,
user: &str,
password: &str,
salt: &[u8],
) -> Result<(), Error> {
let mut md5 = Md5::new();
// First pass
md5.update(&password.as_bytes());
md5.update(&user.as_bytes());
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5.update(salt);
let mut password = format!("md5{:x}", md5.finalize())
.chars()
.map(|x| x as u8)
.collect::<Vec<u8>>();
password.push(0);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]);
Ok(write_all(stream, message).await?)
}
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
@@ -33,6 +94,13 @@ pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Erro
}
}
pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> {
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
}
/// Read a complete message from the socket.
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> {
let code = match stream.read_u8().await {
@@ -59,4 +127,4 @@ pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<Bytes
bytes.put_slice(&buf);
Ok(bytes)
}
}

169
src/server.rs Normal file
View File

@@ -0,0 +1,169 @@
use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use crate::errors::Error;
use crate::messages::*;
pub struct Server {
read: BufReader<OwnedReadHalf>,
write: OwnedWriteHalf,
buffer: BytesMut,
}
impl Server {
pub async fn startup(
host: &str,
port: &str,
user: &str,
password: &str,
database: &str,
) -> Result<Server, Error> {
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
Ok(stream) => stream,
Err(err) => {
println!(">> Could not connect to server: {}", err);
return Err(Error::SocketError);
}
};
startup(&mut stream, user, database).await?;
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
match code {
'R' => {
// Auth can proceed
let code = match stream.read_i32().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
};
match code {
// MD5
5 => {
let mut salt = vec![0u8; 4];
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
md5_password(&mut stream, user, password, &salt[..]).await?;
}
// We're in!
0 => {
println!(">> Server authentication successful!");
}
_ => {
println!(">> Unsupported authentication mechanism: {}", code);
return Err(Error::ServerError);
}
}
}
'E' => {
println!(">> Database error");
return Err(Error::ServerError);
}
'S' => {
// Parameter
let mut param = vec![0u8; len as usize - 4];
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
}
'K' => {
// TODO: save cancellation secret
let mut cancel_secret = vec![0u8; len as usize - 4];
match stream.read_exact(&mut cancel_secret).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
}
'Z' => {
let mut idle = vec![0u8; len as usize - 4];
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Startup finished
let (read, write) = stream.into_split();
return Ok(Server {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
});
}
_ => {
println!(">> Unknown code: {}", code);
return Err(Error::ProtocolSyncError);
}
};
}
}
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
Ok(write_all_half(&mut self.write, messages).await?)
}
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop {
let mut message = read_message(&mut self.read).await?;
// Buffer the message we'll forward to the client in a bit.
self.buffer.put(&message[..]);
let code = message.get_u8() as char;
match code {
'Z' => {
// Ready for query, time to forward buffer to client.
break;
}
_ => {
// Keep buffering,
}
};
}
let bytes = self.buffer.clone();
self.buffer.clear();
Ok(bytes)
}
// pub async fn handle(&mut self) -> Result<(), Error> {
// loop {
// let message = read_message(&mut self.read).await?;
// let original = message.clone();
// let code = message.get_u8() as char;
// let len = message.get_i32();
// match code {
// '
// }
// }
// }
}