at least it compiles

This commit is contained in:
Lev
2022-06-27 15:52:01 -07:00
parent b974aacd71
commit eb58920870
5 changed files with 538 additions and 323 deletions

View File

@@ -2,7 +2,7 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{info, trace}; use log::{info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::net::tcp::OwnedWriteHalf; // use tokio::net::tcp::T;
use crate::config::{get_config, reload_config}; use crate::config::{get_config, reload_config};
use crate::errors::Error; use crate::errors::Error;
@@ -12,12 +12,15 @@ use crate::stats::get_stats;
use crate::ClientServerMap; use crate::ClientServerMap;
/// Handle admin client. /// Handle admin client.
pub async fn handle_admin( pub async fn handle_admin<T>(
stream: &mut OwnedWriteHalf, stream: &mut T,
mut query: BytesMut, mut query: BytesMut,
pool: ConnectionPool, pool: ConnectionPool,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
) -> Result<(), Error> { ) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let code = query.get_u8() as char; let code = query.get_u8() as char;
if code != 'Q' { if code != 'Q' {
@@ -61,7 +64,10 @@ pub async fn handle_admin(
} }
/// Column-oriented statistics. /// Column-oriented statistics.
async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_lists<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let stats = get_stats(); let stats = get_stats();
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)]; let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
@@ -128,7 +134,10 @@ async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
} }
/// Show PgCat version. /// Show PgCat version.
async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> { async fn show_version<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&vec![("version", DataType::Text)])); res.put(row_description(&vec![("version", DataType::Text)]));
@@ -143,7 +152,10 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
} }
/// Show utilization of connection pools for each shard and replicas. /// Show utilization of connection pools for each shard and replicas.
async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_pools<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let stats = get_stats(); let stats = get_stats();
let config = get_config(); let config = get_config();
@@ -197,7 +209,10 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
} }
/// Show shards and replicas. /// Show shards and replicas.
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_databases<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let config = get_config(); let config = get_config();
// Columns // Columns
@@ -258,15 +273,18 @@ async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> R
/// Ignore any SET commands the client sends. /// Ignore any SET commands the client sends.
/// This is common initialization done by ORMs. /// This is common initialization done by ORMs.
async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> { async fn ignore_set<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
custom_protocol_response_ok(stream, "SET").await custom_protocol_response_ok(stream, "SET").await
} }
/// Reload the configuration file without restarting the process. /// Reload the configuration file without restarting the process.
async fn reload( async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
stream: &mut OwnedWriteHalf, where
client_server_map: ClientServerMap, T: tokio::io::AsyncWrite + std::marker::Unpin,
) -> Result<(), Error> { {
info!("Reloading config"); info!("Reloading config");
reload_config(client_server_map).await?; reload_config(client_server_map).await?;
@@ -286,7 +304,10 @@ async fn reload(
} }
/// Shows current configuration. /// Shows current configuration.
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { async fn show_config<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let config = &get_config(); let config = &get_config();
let config: HashMap<String, String> = config.into(); let config: HashMap<String, String> = config.into();
@@ -329,7 +350,10 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
} }
/// Show shard and replicas statistics. /// Show shard and replicas statistics.
async fn show_stats(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { async fn show_stats<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let columns = vec![ let columns = vec![
("database", DataType::Text), ("database", DataType::Text),
("total_xact_count", DataType::Numeric), ("total_xact_count", DataType::Numeric),

View File

@@ -2,7 +2,7 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, trace}; use log::{debug, error, trace};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream, TcpStream,
@@ -17,15 +17,25 @@ use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::Server;
use crate::stats::{get_reporter, Reporter}; use crate::stats::{get_reporter, Reporter};
use crate::stream::Tls;
use tokio_rustls::server::TlsStream;
/// Type of connection received from client.
enum ClientConnectionType {
Startup,
Tls,
CancelQuery,
}
/// The client state. One of these is created per client. /// The client state. One of these is created per client.
pub struct Client <T, S> { pub struct Client<S, T> {
/// The reads are buffered (8K by default). /// The reads are buffered (8K by default).
read: BufReader<T>, read: BufReader<S>,
/// We buffer the writes ourselves because we know the protocol /// We buffer the writes ourselves because we know the protocol
/// better than a stock buffer. /// better than a stock buffer.
write: S, write: T,
/// Internal buffer, where we place messages until we have to flush /// Internal buffer, where we place messages until we have to flush
/// them to the backend. /// them to the backend.
@@ -63,163 +73,347 @@ pub struct Client <T, S> {
last_server_id: Option<i32>, last_server_id: Option<i32>,
} }
impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin> Client <T, S> { /// Main client loop.
/// Perform client startup sequence. pub async fn client_loop(
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3> mut stream: TcpStream,
pub async fn startup( client_server_map: ClientServerMap,
mut stream: TcpStream, ) -> Result<(), Error> {
match get_startup::<TcpStream>(&mut stream).await {
Ok((ClientConnectionType::Tls, bytes)) => {
match startup_tls(stream, client_server_map).await {
Ok(mut client) => client.handle().await,
Err(err) => Err(err),
}
}
Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream);
match Client::handle_startup(read, write, bytes, client_server_map).await {
Ok(mut client) => client.handle().await,
Err(err) => Err(err),
}
}
Ok((ClientConnectionType::CancelQuery, bytes)) => {
return Err(Error::ProtocolSyncError);
}
Err(err) => Err(err),
}
}
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
{
// Get startup message length.
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::ClientBadStartup),
};
// Get the rest of the message.
let mut startup = vec![0u8; len as usize - 4];
match stream.read_exact(&mut startup).await {
Ok(_) => (),
Err(_) => return Err(Error::ClientBadStartup),
};
let mut bytes = BytesMut::from(&startup[..]);
let code = bytes.get_i32();
match code {
// Client is requesting SSL (TLS).
SSL_REQUEST_CODE => Ok((ClientConnectionType::Tls, bytes)),
// Client wants to use plain text, requesting regular startup.
PROTOCOL_VERSION_NUMBER => Ok((ClientConnectionType::Startup, bytes)),
// Client is requesting to cancel a running query (plain text connection).
CANCEL_REQUEST_CODE => Ok((ClientConnectionType::CancelQuery, bytes)),
_ => Err(Error::ProtocolSyncError),
}
}
/// Handle TLS connection negotation.
pub async fn startup_tls(
mut stream: TcpStream,
client_server_map: ClientServerMap,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Accept SSL request if SSL is configured.
let mut yes = BytesMut::new();
yes.put_u8(b'S');
write_all(&mut stream, yes).await?;
// Negotiate TLS.
let mut tls = Tls::new().unwrap();
let mut stream = match tls.acceptor.accept(stream).await {
Ok(stream) => stream,
Err(_) => return Err(Error::TlsError),
};
match get_startup::<TlsStream<TcpStream>>(&mut stream).await {
Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream);
Client::handle_startup(read, write, bytes, client_server_map).await
}
_ => Err(Error::ProtocolSyncError),
}
}
impl<S, T> Client<S, T>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
// Perform client startup sequence in TLS.
pub async fn handle_startup(
mut read: S,
mut write: T,
bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
) -> Result<Client<T, S>, Error> { ) -> Result<Client<S, T>, Error> {
let config = get_config(); let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction"; let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter(); let stats = get_reporter();
loop { trace!("Got StartupMessage");
trace!("Waiting for StartupMessage"); let parameters = parse_startup(bytes.clone())?;
// Could be StartupMessage, SSLRequest or CancelRequest. // Generate random backend ID and secret key
let len = match stream.read_i32().await { let process_id: i32 = rand::random();
Ok(len) => len, let secret_key: i32 = rand::random();
Err(_) => return Err(Error::ClientBadStartup),
};
let mut startup = vec![0u8; len as usize - 4]; // Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
match stream.read_exact(&mut startup).await { let code = match read.read_u8().await {
Ok(_) => (), Ok(p) => p,
Err(_) => return Err(Error::ClientBadStartup), Err(_) => return Err(Error::SocketError),
}; };
let mut bytes = BytesMut::from(&startup[..]); // PasswordMessage
let code = bytes.get_i32(); if code as char != 'p' {
debug!("Expected p, got {}", code as char);
match code { return Err(Error::ProtocolSyncError);
// Client wants SSL. We don't support it at the moment.
SSL_REQUEST_CODE => {
trace!("Rejecting SSLRequest");
let mut no = BytesMut::with_capacity(1);
no.put_u8(b'N');
write_all(&mut stream, no).await?;
}
// Regular startup message.
PROTOCOL_VERSION_NUMBER => {
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut stream).await?;
let code = match stream.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
};
// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match stream.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash =
md5_hash_password(&config.user.name, &config.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut stream, &config.user.name).await?;
return Err(Error::ClientError);
}
debug!("Password authentication successful");
auth_ok(&mut stream).await?;
write_all(&mut stream, get_pool().server_info()).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
let (read, write) = stream.into_split();
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters,
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
});
}
// Query cancel request.
CANCEL_REQUEST_CODE => {
let (read, write) = stream.into_split();
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: HashMap::new(),
stats: stats,
admin: false,
last_address_id: None,
last_server_id: None,
});
}
_ => {
return Err(Error::ProtocolSyncError);
}
};
} }
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, &config.user.name).await?;
return Err(Error::ClientError);
}
debug!("Password authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, get_pool().server_info()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
return Ok(Client {
read: BufReader::new(read),
write: write,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters,
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
});
} }
/// Perform client startup sequence.
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
// pub async fn startup(
// mut stream: TcpStream,
// client_server_map: ClientServerMap,
// ) -> Result<Client<ReadHalf<TcpStream>, WriteHalf<TcpStream>>, Error> {
// let config = get_config();
// let transaction_mode = config.general.pool_mode == "transaction";
// let stats = get_reporter();
// loop {
// trace!("Waiting for StartupMessage");
// // Could be StartupMessage, SSLRequest or CancelRequest.
// let len = match stream.read_i32().await {
// Ok(len) => len,
// Err(_) => return Err(Error::ClientBadStartup),
// };
// let mut startup = vec![0u8; len as usize - 4];
// match stream.read_exact(&mut startup).await {
// Ok(_) => (),
// Err(_) => return Err(Error::ClientBadStartup),
// };
// let mut bytes = BytesMut::from(&startup[..]);
// let code = bytes.get_i32();
// match code {
// // Client wants SSL. We don't support it at the moment.
// SSL_REQUEST_CODE => {
// trace!("Rejecting SSLRequest");
// let mut no = BytesMut::with_capacity(1);
// no.put_u8(b'N');
// write_all(&mut stream, no).await?;
// }
// // Regular startup message.
// PROTOCOL_VERSION_NUMBER => {
// trace!("Got StartupMessage");
// let parameters = parse_startup(bytes.clone())?;
// // Generate random backend ID and secret key
// let process_id: i32 = rand::random();
// let secret_key: i32 = rand::random();
// // Perform MD5 authentication.
// // TODO: Add SASL support.
// let salt = md5_challenge(&mut stream).await?;
// let code = match stream.read_u8().await {
// Ok(p) => p,
// Err(_) => return Err(Error::SocketError),
// };
// // PasswordMessage
// if code as char != 'p' {
// debug!("Expected p, got {}", code as char);
// return Err(Error::ProtocolSyncError);
// }
// let len = match stream.read_i32().await {
// Ok(len) => len,
// Err(_) => return Err(Error::SocketError),
// };
// let mut password_response = vec![0u8; (len - 4) as usize];
// match stream.read_exact(&mut password_response).await {
// Ok(_) => (),
// Err(_) => return Err(Error::SocketError),
// };
// // Compare server and client hashes.
// let password_hash =
// md5_hash_password(&config.user.name, &config.user.password, &salt);
// if password_hash != password_response {
// debug!("Password authentication failed");
// wrong_password(&mut stream, &config.user.name).await?;
// return Err(Error::ClientError);
// }
// debug!("Password authentication successful");
// auth_ok(&mut stream).await?;
// write_all(&mut stream, get_pool().server_info()).await?;
// backend_key_data(&mut stream, process_id, secret_key).await?;
// ready_for_query(&mut stream).await?;
// trace!("Startup OK");
// let database = parameters
// .get("database")
// .unwrap_or(parameters.get("user").unwrap());
// let admin = ["pgcat", "pgbouncer"]
// .iter()
// .filter(|db| *db == &database)
// .count()
// == 1;
// // Split the read and write streams
// // so we can control buffering.
// let (read, write) = split(stream);
// return Ok(Client {
// read: BufReader::new(read),
// write: write,
// buffer: BytesMut::with_capacity(8196),
// cancel_mode: false,
// transaction_mode: transaction_mode,
// process_id: process_id,
// secret_key: secret_key,
// client_server_map: client_server_map,
// parameters: parameters,
// stats: stats,
// admin: admin,
// last_address_id: None,
// last_server_id: None,
// });
// }
// // Query cancel request.
// CANCEL_REQUEST_CODE => {
// let (read, write) = split(stream);
// let process_id = bytes.get_i32();
// let secret_key = bytes.get_i32();
// return Ok(Client {
// read: BufReader::new(read),
// write: write,
// buffer: BytesMut::with_capacity(8196),
// cancel_mode: true,
// transaction_mode: transaction_mode,
// process_id: process_id,
// secret_key: secret_key,
// client_server_map: client_server_map,
// parameters: HashMap::new(),
// stats: stats,
// admin: false,
// last_address_id: None,
// last_server_id: None,
// });
// }
// _ => {
// return Err(Error::ProtocolSyncError);
// }
// };
// }
// }
/// Handle a connected and authenticated client. /// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> { pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously. // The client wants to cancel a query it has issued previously.
@@ -414,8 +608,8 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
self.last_server_id = Some(server.process_id()); self.last_server_id = Some(server.process_id());
debug!( debug!(
"Client {:?} talking to server {:?}", "Client stuff talking to server {:?}",
self.write.peer_addr().unwrap(), // self.write.peer_addr().unwrap(),
server.address() server.address()
); );
@@ -650,7 +844,7 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
} }
} }
impl<T, S> Drop for Client <T, S> { impl<S, T> Drop for Client<S, T> {
fn drop(&mut self) { fn drop(&mut self) {
// Update statistics. // Update statistics.
if let Some(address_id) = self.last_address_id { if let Some(address_id) = self.last_address_id {

View File

@@ -28,13 +28,13 @@ extern crate log;
extern crate md5; extern crate md5;
extern crate num_cpus; extern crate num_cpus;
extern crate once_cell; extern crate once_cell;
extern crate rustls_pemfile;
extern crate serde; extern crate serde;
extern crate serde_derive; extern crate serde_derive;
extern crate sqlparser; extern crate sqlparser;
extern crate tokio; extern crate tokio;
extern crate toml;
extern crate tokio_rustls; extern crate tokio_rustls;
extern crate rustls_pemfile; extern crate toml;
use log::{debug, error, info}; use log::{debug, error, info};
use parking_lot::Mutex; use parking_lot::Mutex;
@@ -45,6 +45,11 @@ use tokio::{
sync::mpsc, sync::mpsc,
}; };
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@@ -62,6 +67,7 @@ mod sharding;
mod stats; mod stats;
mod stream; mod stream;
use crate::constants::*;
use config::{get_config, reload_config}; use config::{get_config, reload_config};
use pool::{ClientServerMap, ConnectionPool}; use pool::{ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter, REPORTER}; use stats::{Collector, Reporter, REPORTER};
@@ -153,32 +159,45 @@ async fn main() {
// Handle client. // Handle client.
tokio::task::spawn(async move { tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc(); let start = chrono::offset::Utc::now().naive_utc();
match client::Client::startup(socket, client_server_map).await { // match client::get_startup(&mut socket) {
Ok(mut client) => { // Ok((code, bytes)) => match code {
info!("Client {:?} connected", addr); // SSL_REQUEST_CODE => client::Client::tls_startup<
// }
match client.handle().await { // }
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
info!(
"Client {:?} disconnected, session duration: {}",
addr,
format_duration(&duration)
);
}
Err(err) => {
error!("Client disconnected with error: {:?}", err);
client.release();
}
}
}
match client::client_loop(socket, client_server_map).await {
Ok(_) => (),
Err(err) => { Err(err) => {
debug!("Client failed to login: {:?}", err); debug!("Client failed to login: {:?}", err);
} }
}; };
// match client::Client::<OwnedReadHalf, OwnedWriteHalf>::startup(socket, client_server_map).await {
// Ok(mut client) => {
// info!("Client {:?} connected", addr);
// match client.handle().await {
// Ok(()) => {
// let duration = chrono::offset::Utc::now().naive_utc() - start;
// info!(
// "Client {:?} disconnected, session duration: {}",
// addr,
// format_duration(&duration)
// );
// }
// Err(err) => {
// error!("Client disconnected with error: {:?}", err);
// client.release();
// }
// }
// }
// Err(err) => {
// debug!("Client failed to login: {:?}", err);
// }
// };
}); });
} }
}); });

View File

@@ -31,7 +31,9 @@ impl From<&DataType> for i32 {
/// Tell the client that authentication handshake completed successfully. /// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error> pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut auth_ok = BytesMut::with_capacity(9); let mut auth_ok = BytesMut::with_capacity(9);
auth_ok.put_u8(b'R'); auth_ok.put_u8(b'R');
@@ -43,7 +45,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
/// Generate md5 password challenge. /// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error> pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng(); // let mut rng = rand::thread_rng();
let salt: [u8; 4] = [ let salt: [u8; 4] = [
rand::random(), rand::random(),
@@ -69,7 +73,9 @@ pub async fn backend_key_data<S>(
backend_id: i32, backend_id: i32,
secret_key: i32, secret_key: i32,
) -> Result<(), Error> ) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut key_data = BytesMut::from(&b"K"[..]); let mut key_data = BytesMut::from(&b"K"[..]);
key_data.put_i32(12); key_data.put_i32(12);
key_data.put_i32(backend_id); key_data.put_i32(backend_id);
@@ -91,7 +97,9 @@ pub fn simple_query(query: &str) -> BytesMut {
/// Tell the client we're ready for another query. /// Tell the client we're ready for another query.
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error> pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(5); let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'Z'); bytes.put_u8(b'Z');
@@ -215,7 +223,9 @@ pub async fn md5_password<S>(
password: &str, password: &str,
salt: &[u8], salt: &[u8],
) -> Result<(), Error> ) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let password = md5_hash_password(user, password, salt); let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() as usize + 5);
@@ -230,11 +240,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
/// Implements a response to our custom `SET SHARDING KEY` /// Implements a response to our custom `SET SHARDING KEY`
/// and `SET SERVER ROLE` commands. /// and `SET SERVER ROLE` commands.
/// This tells the client we're ready for the next query. /// This tells the client we're ready for the next query.
pub async fn custom_protocol_response_ok<S>( pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error>
stream: &mut S, where
message: &str, S: tokio::io::AsyncWrite + std::marker::Unpin,
) -> Result<(), Error> {
where S: tokio::io::AsyncWrite + std::marker::Unpin {
let mut res = BytesMut::with_capacity(25); let mut res = BytesMut::with_capacity(25);
let set_complete = BytesMut::from(&format!("{}\0", message)[..]); let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
@@ -257,7 +266,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
/// Tell the client we are ready for the next query and no rollback is necessary. /// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. /// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error> pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut error = BytesMut::new(); let mut error = BytesMut::new();
// Error level // Error level
@@ -299,7 +310,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
} }
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error> pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut error = BytesMut::new(); let mut error = BytesMut::new();
// Error level // Error level
@@ -333,11 +346,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
} }
/// Respond to a SHOW SHARD command. /// Respond to a SHOW SHARD command.
pub async fn show_response( pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error>
stream: &mut OwnedWriteHalf, where
name: &str, S: tokio::io::AsyncWrite + std::marker::Unpin,
value: &str, {
) -> Result<(), Error> {
// A SELECT response consists of: // A SELECT response consists of:
// 1. RowDescription // 1. RowDescription
// 2. One or more DataRow // 2. One or more DataRow
@@ -439,7 +451,9 @@ pub fn command_complete(command: &str) -> BytesMut {
/// Write all data in the buffer to the TcpStream. /// Write all data in the buffer to the TcpStream.
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error> pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
@@ -448,7 +462,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error> pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin { where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
@@ -456,7 +472,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
} }
/// Read a complete message from the socket. /// Read a complete message from the socket.
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> { pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
{
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code, Ok(code) => code,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),

View File

@@ -1,17 +1,17 @@
// Stream wrapper. // Stream wrapper.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, split, ReadHalf, WriteHalf}; use rustls_pemfile::{certs, rsa_private_keys};
use std::path::Path;
use std::sync::Arc;
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::{ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf}, tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream, TcpStream,
}; };
use tokio_rustls::server::TlsStream;
use rustls_pemfile::{certs, rsa_private_keys};
use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use tokio_rustls::rustls::{self, Certificate, PrivateKey};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use std::sync::Arc;
use std::path::Path;
use crate::config::get_config; use crate::config::get_config;
use crate::errors::Error; use crate::errors::Error;
@@ -29,132 +29,91 @@ fn load_keys(path: &std::path::Path) -> std::io::Result<Vec<PrivateKey>> {
.map(|mut keys| keys.drain(..).map(PrivateKey).collect()) .map(|mut keys| keys.drain(..).map(PrivateKey).collect())
} }
struct Tls { pub struct Tls {
acceptor: TlsAcceptor, pub acceptor: TlsAcceptor,
} }
impl Tls { impl Tls {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let config = get_config(); let config = get_config();
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) { let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
Ok(certs) => certs, Ok(certs) => certs,
Err(_) => return Err(Error::TlsError), Err(_) => return Err(Error::TlsError),
}; };
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) { let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
Ok(keys) => keys, Ok(keys) => keys,
Err(_) => return Err(Error::TlsError), Err(_) => return Err(Error::TlsError),
}; };
let config = match rustls::ServerConfig::builder() let config = match rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certs, keys.remove(0)) .with_single_cert(certs, keys.remove(0))
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) { .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
Ok(c) => c, {
Err(_) => return Err(Error::TlsError) Ok(c) => c,
Err(_) => return Err(Error::TlsError),
}; };
Ok(Tls { Ok(Tls {
acceptor: TlsAcceptor::from(Arc::new(config)), acceptor: TlsAcceptor::from(Arc::new(config)),
}) })
} }
} }
struct Stream { struct Stream {
read: Option<BufReader<OwnedReadHalf>>, read: Option<BufReader<OwnedReadHalf>>,
write: Option<OwnedWriteHalf>, write: Option<OwnedWriteHalf>,
tls_read: Option<BufReader<ReadHalf<TlsStream<TcpStream>>>>, tls_read: Option<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
tls_write: Option<WriteHalf<TlsStream<TcpStream>>>, tls_write: Option<WriteHalf<TlsStream<TcpStream>>>,
} }
impl Stream { impl Stream {
pub async fn new(stream: TcpStream, tls: Option<Tls>) -> Result<Stream, Error> { pub async fn new(stream: TcpStream, tls: Option<Tls>) -> Result<Stream, Error> {
let config = get_config();
let config = get_config(); match tls {
None => {
let (read, write) = stream.into_split();
let read = BufReader::new(read);
Ok(Self {
read: Some(read),
write: Some(write),
tls_read: None,
tls_write: None,
})
}
match tls { Some(tls) => {
None => { let mut tls_stream = match tls.acceptor.accept(stream).await {
let (read, write) = stream.into_split(); Ok(stream) => stream,
let read = BufReader::new(read); Err(_) => return Err(Error::TlsError),
Ok( };
Self {
read: Some(read),
write: Some(write),
tls_read: None,
tls_write: None,
}
)
}
Some(tls) => { let (read, write) = split(tls_stream);
let mut tls_stream = match tls.acceptor.accept(stream).await {
Ok(stream) => stream,
Err(_) => return Err(Error::TlsError),
};
let (read, write) = split(tls_stream); Ok(Self {
read: None,
Ok(Self{ write: None,
read: None, tls_read: Some(BufReader::new(read)),
write: None, tls_write: Some(write),
tls_read: Some(BufReader::new(read)), })
tls_write: Some(write), }
}) }
} }
}
}
async fn read<S>(stream: &mut S) -> Result<BytesMut, Error>
where S: tokio::io::AsyncRead + std::marker::Unpin {
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut buf = vec![0u8; len as usize - 4];
match stream.read_exact(&mut buf).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
let mut bytes = BytesMut::with_capacity(len as usize + 1);
bytes.put_u8(code);
bytes.put_i32(len);
bytes.put_slice(&buf);
Ok(bytes)
}
async fn write<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin {
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError),
}
}
pub async fn read_message(&mut self) -> Result<BytesMut, Error> {
match &self.read {
Some(read) => Self::read(self.read.as_mut().unwrap()).await,
None => Self::read(self.tls_read.as_mut().unwrap()).await,
}
}
pub async fn write_all(&mut self, buf: &BytesMut) -> Result<(), Error> {
match &self.write {
Some(write) => Self::write(self.write.as_mut().unwrap(), buf).await,
None => Self::write(self.tls_write.as_mut().unwrap(), buf).await,
}
}
} }
// impl tokio::io::AsyncRead for Stream {
// fn poll_read(
// mut self: core::pin::Pin<&mut Self>,
// cx: &mut core::task::Context<'_>,
// buf: &mut tokio::io::ReadBuf<'_>
// ) -> core::task::Poll<std::io::Result<()>> {
// match &mut self.get_mut().tls_read {
// None => core::pin::Pin::new(self.read.as_mut().unwrap()).poll_read(cx, buf),
// Some(mut tls) => core::pin::Pin::new(&mut tls).poll_read(cx, buf),
// }
// }
// }