2022-02-03 15:17:04 -08:00
|
|
|
use bytes::{Buf, BufMut, BytesMut};
|
|
|
|
|
use tokio::io::{AsyncReadExt, BufReader};
|
|
|
|
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
|
|
|
|
|
use crate::errors::Error;
|
|
|
|
|
use crate::messages::*;
|
|
|
|
|
|
|
|
|
|
pub struct Server {
|
|
|
|
|
read: BufReader<OwnedReadHalf>,
|
|
|
|
|
write: OwnedWriteHalf,
|
|
|
|
|
buffer: BytesMut,
|
2022-02-03 16:25:05 -08:00
|
|
|
in_transaction: bool,
|
2022-02-03 17:06:19 -08:00
|
|
|
bad: bool,
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Server {
|
|
|
|
|
pub async fn startup(
|
|
|
|
|
host: &str,
|
|
|
|
|
port: &str,
|
|
|
|
|
user: &str,
|
|
|
|
|
password: &str,
|
|
|
|
|
database: &str,
|
|
|
|
|
) -> Result<Server, Error> {
|
|
|
|
|
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
|
|
|
|
|
Ok(stream) => stream,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
println!(">> Could not connect to server: {}", err);
|
|
|
|
|
return Err(Error::SocketError);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
startup(&mut stream, user, database).await?;
|
|
|
|
|
|
|
|
|
|
loop {
|
|
|
|
|
let code = match stream.read_u8().await {
|
|
|
|
|
Ok(code) => code as char,
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let len = match stream.read_i32().await {
|
|
|
|
|
Ok(len) => len,
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
match code {
|
|
|
|
|
'R' => {
|
|
|
|
|
// Auth can proceed
|
|
|
|
|
let code = match stream.read_i32().await {
|
|
|
|
|
Ok(code) => code,
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
match code {
|
|
|
|
|
// MD5
|
|
|
|
|
5 => {
|
|
|
|
|
let mut salt = vec![0u8; 4];
|
|
|
|
|
|
|
|
|
|
match stream.read_exact(&mut salt).await {
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
md5_password(&mut stream, user, password, &salt[..]).await?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// We're in!
|
|
|
|
|
0 => {
|
|
|
|
|
println!(">> Server authentication successful!");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_ => {
|
|
|
|
|
println!(">> Unsupported authentication mechanism: {}", code);
|
|
|
|
|
return Err(Error::ServerError);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
'E' => {
|
|
|
|
|
println!(">> Database error");
|
|
|
|
|
return Err(Error::ServerError);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
'S' => {
|
|
|
|
|
// Parameter
|
|
|
|
|
let mut param = vec![0u8; len as usize - 4];
|
|
|
|
|
match stream.read_exact(&mut param).await {
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
'K' => {
|
|
|
|
|
// TODO: save cancellation secret
|
|
|
|
|
let mut cancel_secret = vec![0u8; len as usize - 4];
|
|
|
|
|
match stream.read_exact(&mut cancel_secret).await {
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
'Z' => {
|
|
|
|
|
let mut idle = vec![0u8; len as usize - 4];
|
|
|
|
|
|
|
|
|
|
match stream.read_exact(&mut idle).await {
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(_) => return Err(Error::SocketError),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Startup finished
|
|
|
|
|
let (read, write) = stream.into_split();
|
|
|
|
|
|
|
|
|
|
return Ok(Server {
|
|
|
|
|
read: BufReader::new(read),
|
|
|
|
|
write: write,
|
|
|
|
|
buffer: BytesMut::with_capacity(8196),
|
2022-02-03 16:25:05 -08:00
|
|
|
in_transaction: false,
|
2022-02-03 17:06:19 -08:00
|
|
|
bad: false,
|
2022-02-03 15:17:04 -08:00
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_ => {
|
|
|
|
|
println!(">> Unknown code: {}", code);
|
|
|
|
|
return Err(Error::ProtocolSyncError);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
|
2022-02-03 17:06:19 -08:00
|
|
|
match write_all_half(&mut self.write, messages).await {
|
|
|
|
|
Ok(_) => Ok(()),
|
|
|
|
|
Err(err) => {
|
|
|
|
|
println!(">> Terminating server because of: {:?}", err);
|
|
|
|
|
self.bad = true;
|
|
|
|
|
Err(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
|
|
|
|
loop {
|
2022-02-03 17:06:19 -08:00
|
|
|
let mut message = match read_message(&mut self.read).await {
|
|
|
|
|
Ok(message) => message,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
println!(">> Terminating server because of: {:?}", err);
|
|
|
|
|
self.bad = true;
|
|
|
|
|
return Err(err);
|
|
|
|
|
}
|
|
|
|
|
};
|
2022-02-03 15:17:04 -08:00
|
|
|
|
|
|
|
|
// Buffer the message we'll forward to the client in a bit.
|
|
|
|
|
self.buffer.put(&message[..]);
|
|
|
|
|
|
|
|
|
|
let code = message.get_u8() as char;
|
2022-02-03 16:25:05 -08:00
|
|
|
let _len = message.get_i32();
|
2022-02-03 15:33:26 -08:00
|
|
|
|
2022-02-03 15:17:04 -08:00
|
|
|
match code {
|
|
|
|
|
'Z' => {
|
|
|
|
|
// Ready for query, time to forward buffer to client.
|
2022-02-03 16:25:05 -08:00
|
|
|
let transaction_state = message.get_u8() as char;
|
2022-02-03 17:06:19 -08:00
|
|
|
|
2022-02-03 16:25:05 -08:00
|
|
|
match transaction_state {
|
|
|
|
|
'T' => {
|
|
|
|
|
self.in_transaction = true;
|
2022-02-03 17:06:19 -08:00
|
|
|
}
|
2022-02-03 16:25:05 -08:00
|
|
|
|
|
|
|
|
'I' => {
|
|
|
|
|
self.in_transaction = false;
|
2022-02-03 17:06:19 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Error client didn't clean up!
|
|
|
|
|
// We shuold drop this server
|
|
|
|
|
'E' => {
|
2022-02-03 17:36:35 -08:00
|
|
|
self.in_transaction = false;
|
2022-02-03 17:06:19 -08:00
|
|
|
}
|
2022-02-03 16:25:05 -08:00
|
|
|
|
|
|
|
|
_ => {
|
2022-02-03 17:06:19 -08:00
|
|
|
self.bad = true;
|
|
|
|
|
return Err(Error::ProtocolSyncError);
|
|
|
|
|
}
|
2022-02-03 16:25:05 -08:00
|
|
|
};
|
|
|
|
|
|
2022-02-03 15:17:04 -08:00
|
|
|
break;
|
2022-02-03 17:06:19 -08:00
|
|
|
}
|
2022-02-03 15:17:04 -08:00
|
|
|
|
|
|
|
|
_ => {
|
|
|
|
|
// Keep buffering,
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let bytes = self.buffer.clone();
|
|
|
|
|
self.buffer.clear();
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
|
|
|
|
}
|
2022-02-03 16:25:05 -08:00
|
|
|
|
|
|
|
|
pub fn in_transaction(&self) -> bool {
|
|
|
|
|
self.in_transaction
|
|
|
|
|
}
|
2022-02-03 17:06:19 -08:00
|
|
|
|
|
|
|
|
pub fn is_bad(&self) -> bool {
|
|
|
|
|
self.bad
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn mark_bad(&mut self) {
|
|
|
|
|
println!(">> Server marked bad");
|
|
|
|
|
self.bad = true;
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-03 17:32:04 -08:00
|
|
|
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
|
|
|
|
|
let mut query = BytesMut::from(&query.as_bytes()[..]);
|
2022-02-03 17:06:19 -08:00
|
|
|
query.put_u8(0);
|
|
|
|
|
|
|
|
|
|
let len = query.len() as i32 + 4;
|
|
|
|
|
|
|
|
|
|
let mut msg = BytesMut::with_capacity(len as usize + 1);
|
|
|
|
|
|
|
|
|
|
msg.put_u8(b'Q');
|
|
|
|
|
msg.put_i32(len);
|
|
|
|
|
msg.put_slice(&query[..]);
|
|
|
|
|
|
|
|
|
|
self.send(msg).await?;
|
|
|
|
|
let _ = self.recv().await?;
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
2022-02-03 17:32:04 -08:00
|
|
|
|
|
|
|
|
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
|
|
|
|
|
Ok(self
|
|
|
|
|
.query(&format!("SET application_name = '{}'", name))
|
|
|
|
|
.await?)
|
|
|
|
|
}
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|