Files
pgcat/src/server.rs

230 lines
6.8 KiB
Rust
Raw Normal View History

2022-02-03 15:17:04 -08:00
use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use crate::errors::Error;
use crate::messages::*;
/// A single connection to a backend Postgres server.
///
/// Created by [`Server::startup`], which performs the connection handshake and
/// then splits the TCP stream into independent read and write halves.
pub struct Server {
    /// Buffered read half of the server socket.
    read: BufReader<OwnedReadHalf>,

    /// Write half of the server socket.
    write: OwnedWriteHalf,

    /// Accumulates server messages until a ready-for-query ('Z') message is seen,
    /// at which point the whole batch is handed to the caller of `recv`.
    buffer: BytesMut,

    /// Transaction state reported by the most recent ready-for-query message.
    in_transaction: bool,

    /// Set once the connection has hit an error (or was flagged via `mark_bad`).
    bad: bool,
}
impl Server {
pub async fn startup(
host: &str,
port: &str,
user: &str,
password: &str,
database: &str,
) -> Result<Server, Error> {
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
Ok(stream) => stream,
Err(err) => {
println!(">> Could not connect to server: {}", err);
return Err(Error::SocketError);
}
};
startup(&mut stream, user, database).await?;
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
match code {
'R' => {
// Auth can proceed
let code = match stream.read_i32().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
};
match code {
// MD5
5 => {
let mut salt = vec![0u8; 4];
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
md5_password(&mut stream, user, password, &salt[..]).await?;
}
// We're in!
0 => {
println!(">> Server authentication successful!");
}
_ => {
println!(">> Unsupported authentication mechanism: {}", code);
return Err(Error::ServerError);
}
}
}
'E' => {
println!(">> Database error");
return Err(Error::ServerError);
}
'S' => {
// Parameter
let mut param = vec![0u8; len as usize - 4];
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
}
'K' => {
// TODO: save cancellation secret
let mut cancel_secret = vec![0u8; len as usize - 4];
match stream.read_exact(&mut cancel_secret).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
}
'Z' => {
let mut idle = vec![0u8; len as usize - 4];
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Startup finished
let (read, write) = stream.into_split();
return Ok(Server {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
2022-02-03 16:25:05 -08:00
in_transaction: false,
bad: false,
2022-02-03 15:17:04 -08:00
});
}
_ => {
println!(">> Unknown code: {}", code);
return Err(Error::ProtocolSyncError);
}
};
}
}
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
match write_all_half(&mut self.write, messages).await {
Ok(_) => Ok(()),
Err(err) => {
println!(">> Terminating server because of: {:?}", err);
self.bad = true;
Err(err)
}
}
2022-02-03 15:17:04 -08:00
}
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop {
let mut message = match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
println!(">> Terminating server because of: {:?}", err);
self.bad = true;
return Err(err);
}
};
2022-02-03 15:17:04 -08:00
// Buffer the message we'll forward to the client in a bit.
self.buffer.put(&message[..]);
let code = message.get_u8() as char;
2022-02-03 16:25:05 -08:00
let _len = message.get_i32();
2022-02-03 15:33:26 -08:00
2022-02-03 15:17:04 -08:00
match code {
'Z' => {
// Ready for query, time to forward buffer to client.
2022-02-03 16:25:05 -08:00
let transaction_state = message.get_u8() as char;
2022-02-03 16:25:05 -08:00
match transaction_state {
'T' => {
self.in_transaction = true;
}
2022-02-03 16:25:05 -08:00
'I' => {
self.in_transaction = false;
}
// Error client didn't clean up!
// We shuold drop this server
'E' => {
self.bad = true;
}
2022-02-03 16:25:05 -08:00
_ => {
self.bad = true;
return Err(Error::ProtocolSyncError);
}
2022-02-03 16:25:05 -08:00
};
2022-02-03 15:17:04 -08:00
break;
}
2022-02-03 15:17:04 -08:00
_ => {
// Keep buffering,
}
};
}
let bytes = self.buffer.clone();
self.buffer.clear();
Ok(bytes)
}
2022-02-03 16:25:05 -08:00
/// Returns the transaction flag last recorded from a ready-for-query message.
pub fn in_transaction(&self) -> bool {
    self.in_transaction
}
/// Returns `true` once this server has been flagged as bad.
pub fn is_bad(&self) -> bool {
    self.bad
}
/// Flags this server as bad and logs that it happened.
pub fn mark_bad(&mut self) {
    self.bad = true;
    println!(">> Server marked bad");
}
/// Sets `application_name` on the server connection to `name`.
///
/// Sends a simple-query ('Q') message containing a `SET application_name`
/// statement and then drains the server's response. The response is discarded
/// without being inspected, so a statement-level failure reported by the server
/// is not surfaced here.
///
/// `name` is sent as an escaped string literal so that names containing spaces,
/// quotes or other punctuation are applied verbatim instead of breaking (or
/// altering) the statement.
///
/// # Errors
///
/// Returns whatever error `send` or `recv` produce.
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
    // Escapes `raw` for use inside an `E'...'` literal: quotes and backslashes
    // are doubled. NUL bytes are dropped because the query text is NUL-terminated
    // on the wire and cannot carry them.
    fn escape_literal(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\0' => {}
                '\'' => escaped.push_str("''"),
                '\\' => escaped.push_str("\\\\"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let sql = format!("SET application_name = E'{}'", escape_literal(name));

    // The length field counts itself (4 bytes), the query text and the trailing NUL.
    let len = sql.len() as i32 + 5;

    // One extra byte for the message code, which the length field does not count.
    let mut msg = BytesMut::with_capacity(len as usize + 1);
    msg.put_u8(b'Q');
    msg.put_i32(len);
    msg.put_slice(sql.as_bytes());
    msg.put_u8(0);

    self.send(msg).await?;
    let _ = self.recv().await?;

    Ok(())
}
2022-02-03 15:17:04 -08:00
}