at least it compiles

This commit is contained in:
Lev
2022-06-27 15:52:01 -07:00
parent b974aacd71
commit eb58920870
5 changed files with 538 additions and 323 deletions

View File

@@ -2,7 +2,7 @@
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, trace};
use std::collections::HashMap;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
@@ -17,15 +17,25 @@ use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
use crate::stream::Tls;
use tokio_rustls::server::TlsStream;
/// Type of connection received from client.
enum ClientConnectionType {
Startup,
Tls,
CancelQuery,
}
/// The client state. One of these is created per client.
pub struct Client <T, S> {
pub struct Client<S, T> {
/// The reads are buffered (8K by default).
read: BufReader<T>,
read: BufReader<S>,
/// We buffer the writes ourselves because we know the protocol
/// better than a stock buffer.
write: S,
write: T,
/// Internal buffer, where we place messages until we have to flush
/// them to the backend.
@@ -63,163 +73,347 @@ pub struct Client <T, S> {
last_server_id: Option<i32>,
}
impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin> Client <T, S> {
/// Perform client startup sequence.
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
pub async fn startup(
mut stream: TcpStream,
/// Main client loop.
pub async fn client_loop(
mut stream: TcpStream,
client_server_map: ClientServerMap,
) -> Result<(), Error> {
match get_startup::<TcpStream>(&mut stream).await {
Ok((ClientConnectionType::Tls, bytes)) => {
match startup_tls(stream, client_server_map).await {
Ok(mut client) => client.handle().await,
Err(err) => Err(err),
}
}
Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream);
match Client::handle_startup(read, write, bytes, client_server_map).await {
Ok(mut client) => client.handle().await,
Err(err) => Err(err),
}
}
Ok((ClientConnectionType::CancelQuery, bytes)) => {
return Err(Error::ProtocolSyncError);
}
Err(err) => Err(err),
}
}
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
{
// Get startup message length.
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::ClientBadStartup),
};
// Get the rest of the message.
let mut startup = vec![0u8; len as usize - 4];
match stream.read_exact(&mut startup).await {
Ok(_) => (),
Err(_) => return Err(Error::ClientBadStartup),
};
let mut bytes = BytesMut::from(&startup[..]);
let code = bytes.get_i32();
match code {
// Client is requesting SSL (TLS).
SSL_REQUEST_CODE => Ok((ClientConnectionType::Tls, bytes)),
// Client wants to use plain text, requesting regular startup.
PROTOCOL_VERSION_NUMBER => Ok((ClientConnectionType::Startup, bytes)),
// Client is requesting to cancel a running query (plain text connection).
CANCEL_REQUEST_CODE => Ok((ClientConnectionType::CancelQuery, bytes)),
_ => Err(Error::ProtocolSyncError),
}
}
/// Handle TLS connection negotation.
pub async fn startup_tls(
mut stream: TcpStream,
client_server_map: ClientServerMap,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Accept SSL request if SSL is configured.
let mut yes = BytesMut::new();
yes.put_u8(b'S');
write_all(&mut stream, yes).await?;
// Negotiate TLS.
let mut tls = Tls::new().unwrap();
let mut stream = match tls.acceptor.accept(stream).await {
Ok(stream) => stream,
Err(_) => return Err(Error::TlsError),
};
match get_startup::<TlsStream<TcpStream>>(&mut stream).await {
Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream);
Client::handle_startup(read, write, bytes, client_server_map).await
}
_ => Err(Error::ProtocolSyncError),
}
}
impl<S, T> Client<S, T>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
// Perform client startup sequence in TLS.
pub async fn handle_startup(
mut read: S,
mut write: T,
bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
) -> Result<Client<T, S>, Error> {
) -> Result<Client<S, T>, Error> {
let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter();
loop {
trace!("Waiting for StartupMessage");
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
// Could be StartupMessage, SSLRequest or CancelRequest.
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::ClientBadStartup),
};
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
let mut startup = vec![0u8; len as usize - 4];
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
match stream.read_exact(&mut startup).await {
Ok(_) => (),
Err(_) => return Err(Error::ClientBadStartup),
};
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
};
let mut bytes = BytesMut::from(&startup[..]);
let code = bytes.get_i32();
match code {
// Client wants SSL. We don't support it at the moment.
SSL_REQUEST_CODE => {
trace!("Rejecting SSLRequest");
let mut no = BytesMut::with_capacity(1);
no.put_u8(b'N');
write_all(&mut stream, no).await?;
}
// Regular startup message.
PROTOCOL_VERSION_NUMBER => {
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut stream).await?;
let code = match stream.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
};
// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match stream.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash =
md5_hash_password(&config.user.name, &config.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut stream, &config.user.name).await?;
return Err(Error::ClientError);
}
debug!("Password authentication successful");
auth_ok(&mut stream).await?;
write_all(&mut stream, get_pool().server_info()).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
let (read, write) = stream.into_split();
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters,
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
});
}
// Query cancel request.
CANCEL_REQUEST_CODE => {
let (read, write) = stream.into_split();
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: HashMap::new(),
stats: stats,
admin: false,
last_address_id: None,
last_server_id: None,
});
}
_ => {
return Err(Error::ProtocolSyncError);
}
};
// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, &config.user.name).await?;
return Err(Error::ClientError);
}
debug!("Password authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, get_pool().server_info()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters,
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
});
}
/// Perform client startup sequence.
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
// pub async fn startup(
// mut stream: TcpStream,
// client_server_map: ClientServerMap,
// ) -> Result<Client<ReadHalf<TcpStream>, WriteHalf<TcpStream>>, Error> {
// let config = get_config();
// let transaction_mode = config.general.pool_mode == "transaction";
// let stats = get_reporter();
// loop {
// trace!("Waiting for StartupMessage");
// // Could be StartupMessage, SSLRequest or CancelRequest.
// let len = match stream.read_i32().await {
// Ok(len) => len,
// Err(_) => return Err(Error::ClientBadStartup),
// };
// let mut startup = vec![0u8; len as usize - 4];
// match stream.read_exact(&mut startup).await {
// Ok(_) => (),
// Err(_) => return Err(Error::ClientBadStartup),
// };
// let mut bytes = BytesMut::from(&startup[..]);
// let code = bytes.get_i32();
// match code {
// // Client wants SSL. We don't support it at the moment.
// SSL_REQUEST_CODE => {
// trace!("Rejecting SSLRequest");
// let mut no = BytesMut::with_capacity(1);
// no.put_u8(b'N');
// write_all(&mut stream, no).await?;
// }
// // Regular startup message.
// PROTOCOL_VERSION_NUMBER => {
// trace!("Got StartupMessage");
// let parameters = parse_startup(bytes.clone())?;
// // Generate random backend ID and secret key
// let process_id: i32 = rand::random();
// let secret_key: i32 = rand::random();
// // Perform MD5 authentication.
// // TODO: Add SASL support.
// let salt = md5_challenge(&mut stream).await?;
// let code = match stream.read_u8().await {
// Ok(p) => p,
// Err(_) => return Err(Error::SocketError),
// };
// // PasswordMessage
// if code as char != 'p' {
// debug!("Expected p, got {}", code as char);
// return Err(Error::ProtocolSyncError);
// }
// let len = match stream.read_i32().await {
// Ok(len) => len,
// Err(_) => return Err(Error::SocketError),
// };
// let mut password_response = vec![0u8; (len - 4) as usize];
// match stream.read_exact(&mut password_response).await {
// Ok(_) => (),
// Err(_) => return Err(Error::SocketError),
// };
// // Compare server and client hashes.
// let password_hash =
// md5_hash_password(&config.user.name, &config.user.password, &salt);
// if password_hash != password_response {
// debug!("Password authentication failed");
// wrong_password(&mut stream, &config.user.name).await?;
// return Err(Error::ClientError);
// }
// debug!("Password authentication successful");
// auth_ok(&mut stream).await?;
// write_all(&mut stream, get_pool().server_info()).await?;
// backend_key_data(&mut stream, process_id, secret_key).await?;
// ready_for_query(&mut stream).await?;
// trace!("Startup OK");
// let database = parameters
// .get("database")
// .unwrap_or(parameters.get("user").unwrap());
// let admin = ["pgcat", "pgbouncer"]
// .iter()
// .filter(|db| *db == &database)
// .count()
// == 1;
// // Split the read and write streams
// // so we can control buffering.
// let (read, write) = split(stream);
// return Ok(Client {
// read: BufReader::new(read),
// write: write,
// buffer: BytesMut::with_capacity(8196),
// cancel_mode: false,
// transaction_mode: transaction_mode,
// process_id: process_id,
// secret_key: secret_key,
// client_server_map: client_server_map,
// parameters: parameters,
// stats: stats,
// admin: admin,
// last_address_id: None,
// last_server_id: None,
// });
// }
// // Query cancel request.
// CANCEL_REQUEST_CODE => {
// let (read, write) = split(stream);
// let process_id = bytes.get_i32();
// let secret_key = bytes.get_i32();
// return Ok(Client {
// read: BufReader::new(read),
// write: write,
// buffer: BytesMut::with_capacity(8196),
// cancel_mode: true,
// transaction_mode: transaction_mode,
// process_id: process_id,
// secret_key: secret_key,
// client_server_map: client_server_map,
// parameters: HashMap::new(),
// stats: stats,
// admin: false,
// last_address_id: None,
// last_server_id: None,
// });
// }
// _ => {
// return Err(Error::ProtocolSyncError);
// }
// };
// }
// }
/// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
@@ -414,8 +608,8 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
self.last_server_id = Some(server.process_id());
debug!(
"Client {:?} talking to server {:?}",
self.write.peer_addr().unwrap(),
"Client stuff talking to server {:?}",
// self.write.peer_addr().unwrap(),
server.address()
);
@@ -650,7 +844,7 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
}
}
impl<T, S> Drop for Client <T, S> {
impl<S, T> Drop for Client<S, T> {
fn drop(&mut self) {
// Update statistics.
if let Some(address_id) = self.last_address_id {