mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-27 18:56:30 +00:00
at least it compiles
This commit is contained in:
54
src/admin.rs
54
src/admin.rs
@@ -2,7 +2,7 @@
|
|||||||
use bytes::{Buf, BufMut, BytesMut};
|
use bytes::{Buf, BufMut, BytesMut};
|
||||||
use log::{info, trace};
|
use log::{info, trace};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::net::tcp::OwnedWriteHalf;
|
// use tokio::net::tcp::T;
|
||||||
|
|
||||||
use crate::config::{get_config, reload_config};
|
use crate::config::{get_config, reload_config};
|
||||||
use crate::errors::Error;
|
use crate::errors::Error;
|
||||||
@@ -12,12 +12,15 @@ use crate::stats::get_stats;
|
|||||||
use crate::ClientServerMap;
|
use crate::ClientServerMap;
|
||||||
|
|
||||||
/// Handle admin client.
|
/// Handle admin client.
|
||||||
pub async fn handle_admin(
|
pub async fn handle_admin<T>(
|
||||||
stream: &mut OwnedWriteHalf,
|
stream: &mut T,
|
||||||
mut query: BytesMut,
|
mut query: BytesMut,
|
||||||
pool: ConnectionPool,
|
pool: ConnectionPool,
|
||||||
client_server_map: ClientServerMap,
|
client_server_map: ClientServerMap,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let code = query.get_u8() as char;
|
let code = query.get_u8() as char;
|
||||||
|
|
||||||
if code != 'Q' {
|
if code != 'Q' {
|
||||||
@@ -61,7 +64,10 @@ pub async fn handle_admin(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Column-oriented statistics.
|
/// Column-oriented statistics.
|
||||||
async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
async fn show_lists<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let stats = get_stats();
|
let stats = get_stats();
|
||||||
|
|
||||||
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
|
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
|
||||||
@@ -128,7 +134,10 @@ async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Show PgCat version.
|
/// Show PgCat version.
|
||||||
async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
async fn show_version<T>(stream: &mut T) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut res = BytesMut::new();
|
let mut res = BytesMut::new();
|
||||||
|
|
||||||
res.put(row_description(&vec![("version", DataType::Text)]));
|
res.put(row_description(&vec![("version", DataType::Text)]));
|
||||||
@@ -143,7 +152,10 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Show utilization of connection pools for each shard and replicas.
|
/// Show utilization of connection pools for each shard and replicas.
|
||||||
async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
async fn show_pools<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let stats = get_stats();
|
let stats = get_stats();
|
||||||
let config = get_config();
|
let config = get_config();
|
||||||
|
|
||||||
@@ -197,7 +209,10 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Show shards and replicas.
|
/// Show shards and replicas.
|
||||||
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
async fn show_databases<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let config = get_config();
|
let config = get_config();
|
||||||
|
|
||||||
// Columns
|
// Columns
|
||||||
@@ -258,15 +273,18 @@ async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> R
|
|||||||
|
|
||||||
/// Ignore any SET commands the client sends.
|
/// Ignore any SET commands the client sends.
|
||||||
/// This is common initialization done by ORMs.
|
/// This is common initialization done by ORMs.
|
||||||
async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
async fn ignore_set<T>(stream: &mut T) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
custom_protocol_response_ok(stream, "SET").await
|
custom_protocol_response_ok(stream, "SET").await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reload the configuration file without restarting the process.
|
/// Reload the configuration file without restarting the process.
|
||||||
async fn reload(
|
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
|
||||||
stream: &mut OwnedWriteHalf,
|
where
|
||||||
client_server_map: ClientServerMap,
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
) -> Result<(), Error> {
|
{
|
||||||
info!("Reloading config");
|
info!("Reloading config");
|
||||||
|
|
||||||
reload_config(client_server_map).await?;
|
reload_config(client_server_map).await?;
|
||||||
@@ -286,7 +304,10 @@ async fn reload(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Shows current configuration.
|
/// Shows current configuration.
|
||||||
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
async fn show_config<T>(stream: &mut T) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let config = &get_config();
|
let config = &get_config();
|
||||||
let config: HashMap<String, String> = config.into();
|
let config: HashMap<String, String> = config.into();
|
||||||
|
|
||||||
@@ -329,7 +350,10 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Show shard and replicas statistics.
|
/// Show shard and replicas statistics.
|
||||||
async fn show_stats(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
async fn show_stats<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let columns = vec![
|
let columns = vec![
|
||||||
("database", DataType::Text),
|
("database", DataType::Text),
|
||||||
("total_xact_count", DataType::Numeric),
|
("total_xact_count", DataType::Numeric),
|
||||||
|
|||||||
498
src/client.rs
498
src/client.rs
@@ -2,7 +2,7 @@
|
|||||||
use bytes::{Buf, BufMut, BytesMut};
|
use bytes::{Buf, BufMut, BytesMut};
|
||||||
use log::{debug, error, trace};
|
use log::{debug, error, trace};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokio::io::{AsyncReadExt, BufReader};
|
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||||
use tokio::net::{
|
use tokio::net::{
|
||||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||||
TcpStream,
|
TcpStream,
|
||||||
@@ -17,15 +17,25 @@ use crate::pool::{get_pool, ClientServerMap};
|
|||||||
use crate::query_router::{Command, QueryRouter};
|
use crate::query_router::{Command, QueryRouter};
|
||||||
use crate::server::Server;
|
use crate::server::Server;
|
||||||
use crate::stats::{get_reporter, Reporter};
|
use crate::stats::{get_reporter, Reporter};
|
||||||
|
use crate::stream::Tls;
|
||||||
|
|
||||||
|
use tokio_rustls::server::TlsStream;
|
||||||
|
|
||||||
|
/// Type of connection received from client.
|
||||||
|
enum ClientConnectionType {
|
||||||
|
Startup,
|
||||||
|
Tls,
|
||||||
|
CancelQuery,
|
||||||
|
}
|
||||||
|
|
||||||
/// The client state. One of these is created per client.
|
/// The client state. One of these is created per client.
|
||||||
pub struct Client <T, S> {
|
pub struct Client<S, T> {
|
||||||
/// The reads are buffered (8K by default).
|
/// The reads are buffered (8K by default).
|
||||||
read: BufReader<T>,
|
read: BufReader<S>,
|
||||||
|
|
||||||
/// We buffer the writes ourselves because we know the protocol
|
/// We buffer the writes ourselves because we know the protocol
|
||||||
/// better than a stock buffer.
|
/// better than a stock buffer.
|
||||||
write: S,
|
write: T,
|
||||||
|
|
||||||
/// Internal buffer, where we place messages until we have to flush
|
/// Internal buffer, where we place messages until we have to flush
|
||||||
/// them to the backend.
|
/// them to the backend.
|
||||||
@@ -63,163 +73,347 @@ pub struct Client <T, S> {
|
|||||||
last_server_id: Option<i32>,
|
last_server_id: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin> Client <T, S> {
|
/// Main client loop.
|
||||||
/// Perform client startup sequence.
|
pub async fn client_loop(
|
||||||
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
|
mut stream: TcpStream,
|
||||||
pub async fn startup(
|
client_server_map: ClientServerMap,
|
||||||
mut stream: TcpStream,
|
) -> Result<(), Error> {
|
||||||
|
match get_startup::<TcpStream>(&mut stream).await {
|
||||||
|
Ok((ClientConnectionType::Tls, bytes)) => {
|
||||||
|
match startup_tls(stream, client_server_map).await {
|
||||||
|
Ok(mut client) => client.handle().await,
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((ClientConnectionType::Startup, bytes)) => {
|
||||||
|
let (read, write) = split(stream);
|
||||||
|
match Client::handle_startup(read, write, bytes, client_server_map).await {
|
||||||
|
Ok(mut client) => client.handle().await,
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok((ClientConnectionType::CancelQuery, bytes)) => {
|
||||||
|
return Err(Error::ProtocolSyncError);
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
|
||||||
|
{
|
||||||
|
// Get startup message length.
|
||||||
|
let len = match stream.read_i32().await {
|
||||||
|
Ok(len) => len,
|
||||||
|
Err(_) => return Err(Error::ClientBadStartup),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get the rest of the message.
|
||||||
|
let mut startup = vec![0u8; len as usize - 4];
|
||||||
|
match stream.read_exact(&mut startup).await {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(_) => return Err(Error::ClientBadStartup),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut bytes = BytesMut::from(&startup[..]);
|
||||||
|
let code = bytes.get_i32();
|
||||||
|
|
||||||
|
match code {
|
||||||
|
// Client is requesting SSL (TLS).
|
||||||
|
SSL_REQUEST_CODE => Ok((ClientConnectionType::Tls, bytes)),
|
||||||
|
|
||||||
|
// Client wants to use plain text, requesting regular startup.
|
||||||
|
PROTOCOL_VERSION_NUMBER => Ok((ClientConnectionType::Startup, bytes)),
|
||||||
|
|
||||||
|
// Client is requesting to cancel a running query (plain text connection).
|
||||||
|
CANCEL_REQUEST_CODE => Ok((ClientConnectionType::CancelQuery, bytes)),
|
||||||
|
_ => Err(Error::ProtocolSyncError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle TLS connection negotation.
|
||||||
|
pub async fn startup_tls(
|
||||||
|
mut stream: TcpStream,
|
||||||
|
client_server_map: ClientServerMap,
|
||||||
|
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
|
||||||
|
// Accept SSL request if SSL is configured.
|
||||||
|
let mut yes = BytesMut::new();
|
||||||
|
yes.put_u8(b'S');
|
||||||
|
write_all(&mut stream, yes).await?;
|
||||||
|
|
||||||
|
// Negotiate TLS.
|
||||||
|
let mut tls = Tls::new().unwrap();
|
||||||
|
let mut stream = match tls.acceptor.accept(stream).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(_) => return Err(Error::TlsError),
|
||||||
|
};
|
||||||
|
|
||||||
|
match get_startup::<TlsStream<TcpStream>>(&mut stream).await {
|
||||||
|
Ok((ClientConnectionType::Startup, bytes)) => {
|
||||||
|
let (read, write) = split(stream);
|
||||||
|
Client::handle_startup(read, write, bytes, client_server_map).await
|
||||||
|
}
|
||||||
|
_ => Err(Error::ProtocolSyncError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, T> Client<S, T>
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||||
|
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
|
// Perform client startup sequence in TLS.
|
||||||
|
pub async fn handle_startup(
|
||||||
|
mut read: S,
|
||||||
|
mut write: T,
|
||||||
|
bytes: BytesMut, // The rest of the startup message.
|
||||||
client_server_map: ClientServerMap,
|
client_server_map: ClientServerMap,
|
||||||
) -> Result<Client<T, S>, Error> {
|
) -> Result<Client<S, T>, Error> {
|
||||||
let config = get_config();
|
let config = get_config();
|
||||||
let transaction_mode = config.general.pool_mode == "transaction";
|
let transaction_mode = config.general.pool_mode == "transaction";
|
||||||
let stats = get_reporter();
|
let stats = get_reporter();
|
||||||
|
|
||||||
loop {
|
trace!("Got StartupMessage");
|
||||||
trace!("Waiting for StartupMessage");
|
let parameters = parse_startup(bytes.clone())?;
|
||||||
|
|
||||||
// Could be StartupMessage, SSLRequest or CancelRequest.
|
// Generate random backend ID and secret key
|
||||||
let len = match stream.read_i32().await {
|
let process_id: i32 = rand::random();
|
||||||
Ok(len) => len,
|
let secret_key: i32 = rand::random();
|
||||||
Err(_) => return Err(Error::ClientBadStartup),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut startup = vec![0u8; len as usize - 4];
|
// Perform MD5 authentication.
|
||||||
|
// TODO: Add SASL support.
|
||||||
|
let salt = md5_challenge(&mut write).await?;
|
||||||
|
|
||||||
match stream.read_exact(&mut startup).await {
|
let code = match read.read_u8().await {
|
||||||
Ok(_) => (),
|
Ok(p) => p,
|
||||||
Err(_) => return Err(Error::ClientBadStartup),
|
Err(_) => return Err(Error::SocketError),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut bytes = BytesMut::from(&startup[..]);
|
// PasswordMessage
|
||||||
let code = bytes.get_i32();
|
if code as char != 'p' {
|
||||||
|
debug!("Expected p, got {}", code as char);
|
||||||
match code {
|
return Err(Error::ProtocolSyncError);
|
||||||
// Client wants SSL. We don't support it at the moment.
|
|
||||||
SSL_REQUEST_CODE => {
|
|
||||||
trace!("Rejecting SSLRequest");
|
|
||||||
|
|
||||||
let mut no = BytesMut::with_capacity(1);
|
|
||||||
no.put_u8(b'N');
|
|
||||||
|
|
||||||
write_all(&mut stream, no).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regular startup message.
|
|
||||||
PROTOCOL_VERSION_NUMBER => {
|
|
||||||
trace!("Got StartupMessage");
|
|
||||||
let parameters = parse_startup(bytes.clone())?;
|
|
||||||
|
|
||||||
// Generate random backend ID and secret key
|
|
||||||
let process_id: i32 = rand::random();
|
|
||||||
let secret_key: i32 = rand::random();
|
|
||||||
|
|
||||||
// Perform MD5 authentication.
|
|
||||||
// TODO: Add SASL support.
|
|
||||||
let salt = md5_challenge(&mut stream).await?;
|
|
||||||
|
|
||||||
let code = match stream.read_u8().await {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
// PasswordMessage
|
|
||||||
if code as char != 'p' {
|
|
||||||
debug!("Expected p, got {}", code as char);
|
|
||||||
return Err(Error::ProtocolSyncError);
|
|
||||||
}
|
|
||||||
|
|
||||||
let len = match stream.read_i32().await {
|
|
||||||
Ok(len) => len,
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
|
||||||
|
|
||||||
match stream.read_exact(&mut password_response).await {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Compare server and client hashes.
|
|
||||||
let password_hash =
|
|
||||||
md5_hash_password(&config.user.name, &config.user.password, &salt);
|
|
||||||
|
|
||||||
if password_hash != password_response {
|
|
||||||
debug!("Password authentication failed");
|
|
||||||
wrong_password(&mut stream, &config.user.name).await?;
|
|
||||||
return Err(Error::ClientError);
|
|
||||||
}
|
|
||||||
|
|
||||||
debug!("Password authentication successful");
|
|
||||||
|
|
||||||
auth_ok(&mut stream).await?;
|
|
||||||
write_all(&mut stream, get_pool().server_info()).await?;
|
|
||||||
backend_key_data(&mut stream, process_id, secret_key).await?;
|
|
||||||
ready_for_query(&mut stream).await?;
|
|
||||||
|
|
||||||
trace!("Startup OK");
|
|
||||||
|
|
||||||
let database = parameters
|
|
||||||
.get("database")
|
|
||||||
.unwrap_or(parameters.get("user").unwrap());
|
|
||||||
let admin = ["pgcat", "pgbouncer"]
|
|
||||||
.iter()
|
|
||||||
.filter(|db| *db == &database)
|
|
||||||
.count()
|
|
||||||
== 1;
|
|
||||||
|
|
||||||
// Split the read and write streams
|
|
||||||
// so we can control buffering.
|
|
||||||
let (read, write) = stream.into_split();
|
|
||||||
|
|
||||||
return Ok(Client {
|
|
||||||
read: BufReader::new(read),
|
|
||||||
write: write,
|
|
||||||
buffer: BytesMut::with_capacity(8196),
|
|
||||||
cancel_mode: false,
|
|
||||||
transaction_mode: transaction_mode,
|
|
||||||
process_id: process_id,
|
|
||||||
secret_key: secret_key,
|
|
||||||
client_server_map: client_server_map,
|
|
||||||
parameters: parameters,
|
|
||||||
stats: stats,
|
|
||||||
admin: admin,
|
|
||||||
last_address_id: None,
|
|
||||||
last_server_id: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query cancel request.
|
|
||||||
CANCEL_REQUEST_CODE => {
|
|
||||||
let (read, write) = stream.into_split();
|
|
||||||
|
|
||||||
let process_id = bytes.get_i32();
|
|
||||||
let secret_key = bytes.get_i32();
|
|
||||||
|
|
||||||
return Ok(Client {
|
|
||||||
read: BufReader::new(read),
|
|
||||||
write: write,
|
|
||||||
buffer: BytesMut::with_capacity(8196),
|
|
||||||
cancel_mode: true,
|
|
||||||
transaction_mode: transaction_mode,
|
|
||||||
process_id: process_id,
|
|
||||||
secret_key: secret_key,
|
|
||||||
client_server_map: client_server_map,
|
|
||||||
parameters: HashMap::new(),
|
|
||||||
stats: stats,
|
|
||||||
admin: false,
|
|
||||||
last_address_id: None,
|
|
||||||
last_server_id: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
return Err(Error::ProtocolSyncError);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let len = match read.read_i32().await {
|
||||||
|
Ok(len) => len,
|
||||||
|
Err(_) => return Err(Error::SocketError),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||||
|
|
||||||
|
match read.read_exact(&mut password_response).await {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(_) => return Err(Error::SocketError),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compare server and client hashes.
|
||||||
|
let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt);
|
||||||
|
|
||||||
|
if password_hash != password_response {
|
||||||
|
debug!("Password authentication failed");
|
||||||
|
wrong_password(&mut write, &config.user.name).await?;
|
||||||
|
return Err(Error::ClientError);
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Password authentication successful");
|
||||||
|
|
||||||
|
auth_ok(&mut write).await?;
|
||||||
|
write_all(&mut write, get_pool().server_info()).await?;
|
||||||
|
backend_key_data(&mut write, process_id, secret_key).await?;
|
||||||
|
ready_for_query(&mut write).await?;
|
||||||
|
|
||||||
|
trace!("Startup OK");
|
||||||
|
|
||||||
|
let database = parameters
|
||||||
|
.get("database")
|
||||||
|
.unwrap_or(parameters.get("user").unwrap());
|
||||||
|
let admin = ["pgcat", "pgbouncer"]
|
||||||
|
.iter()
|
||||||
|
.filter(|db| *db == &database)
|
||||||
|
.count()
|
||||||
|
== 1;
|
||||||
|
|
||||||
|
// Split the read and write streams
|
||||||
|
// so we can control buffering.
|
||||||
|
|
||||||
|
return Ok(Client {
|
||||||
|
read: BufReader::new(read),
|
||||||
|
write: write,
|
||||||
|
buffer: BytesMut::with_capacity(8196),
|
||||||
|
cancel_mode: false,
|
||||||
|
transaction_mode: transaction_mode,
|
||||||
|
process_id: process_id,
|
||||||
|
secret_key: secret_key,
|
||||||
|
client_server_map: client_server_map,
|
||||||
|
parameters: parameters,
|
||||||
|
stats: stats,
|
||||||
|
admin: admin,
|
||||||
|
last_address_id: None,
|
||||||
|
last_server_id: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Perform client startup sequence.
|
||||||
|
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
|
||||||
|
// pub async fn startup(
|
||||||
|
// mut stream: TcpStream,
|
||||||
|
// client_server_map: ClientServerMap,
|
||||||
|
// ) -> Result<Client<ReadHalf<TcpStream>, WriteHalf<TcpStream>>, Error> {
|
||||||
|
// let config = get_config();
|
||||||
|
// let transaction_mode = config.general.pool_mode == "transaction";
|
||||||
|
// let stats = get_reporter();
|
||||||
|
|
||||||
|
// loop {
|
||||||
|
// trace!("Waiting for StartupMessage");
|
||||||
|
|
||||||
|
// // Could be StartupMessage, SSLRequest or CancelRequest.
|
||||||
|
// let len = match stream.read_i32().await {
|
||||||
|
// Ok(len) => len,
|
||||||
|
// Err(_) => return Err(Error::ClientBadStartup),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// let mut startup = vec![0u8; len as usize - 4];
|
||||||
|
|
||||||
|
// match stream.read_exact(&mut startup).await {
|
||||||
|
// Ok(_) => (),
|
||||||
|
// Err(_) => return Err(Error::ClientBadStartup),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// let mut bytes = BytesMut::from(&startup[..]);
|
||||||
|
// let code = bytes.get_i32();
|
||||||
|
|
||||||
|
// match code {
|
||||||
|
// // Client wants SSL. We don't support it at the moment.
|
||||||
|
// SSL_REQUEST_CODE => {
|
||||||
|
// trace!("Rejecting SSLRequest");
|
||||||
|
|
||||||
|
// let mut no = BytesMut::with_capacity(1);
|
||||||
|
// no.put_u8(b'N');
|
||||||
|
|
||||||
|
// write_all(&mut stream, no).await?;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Regular startup message.
|
||||||
|
// PROTOCOL_VERSION_NUMBER => {
|
||||||
|
// trace!("Got StartupMessage");
|
||||||
|
// let parameters = parse_startup(bytes.clone())?;
|
||||||
|
|
||||||
|
// // Generate random backend ID and secret key
|
||||||
|
// let process_id: i32 = rand::random();
|
||||||
|
// let secret_key: i32 = rand::random();
|
||||||
|
|
||||||
|
// // Perform MD5 authentication.
|
||||||
|
// // TODO: Add SASL support.
|
||||||
|
// let salt = md5_challenge(&mut stream).await?;
|
||||||
|
|
||||||
|
// let code = match stream.read_u8().await {
|
||||||
|
// Ok(p) => p,
|
||||||
|
// Err(_) => return Err(Error::SocketError),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// // PasswordMessage
|
||||||
|
// if code as char != 'p' {
|
||||||
|
// debug!("Expected p, got {}", code as char);
|
||||||
|
// return Err(Error::ProtocolSyncError);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let len = match stream.read_i32().await {
|
||||||
|
// Ok(len) => len,
|
||||||
|
// Err(_) => return Err(Error::SocketError),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// let mut password_response = vec![0u8; (len - 4) as usize];
|
||||||
|
|
||||||
|
// match stream.read_exact(&mut password_response).await {
|
||||||
|
// Ok(_) => (),
|
||||||
|
// Err(_) => return Err(Error::SocketError),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// // Compare server and client hashes.
|
||||||
|
// let password_hash =
|
||||||
|
// md5_hash_password(&config.user.name, &config.user.password, &salt);
|
||||||
|
|
||||||
|
// if password_hash != password_response {
|
||||||
|
// debug!("Password authentication failed");
|
||||||
|
// wrong_password(&mut stream, &config.user.name).await?;
|
||||||
|
// return Err(Error::ClientError);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// debug!("Password authentication successful");
|
||||||
|
|
||||||
|
// auth_ok(&mut stream).await?;
|
||||||
|
// write_all(&mut stream, get_pool().server_info()).await?;
|
||||||
|
// backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||||
|
// ready_for_query(&mut stream).await?;
|
||||||
|
|
||||||
|
// trace!("Startup OK");
|
||||||
|
|
||||||
|
// let database = parameters
|
||||||
|
// .get("database")
|
||||||
|
// .unwrap_or(parameters.get("user").unwrap());
|
||||||
|
// let admin = ["pgcat", "pgbouncer"]
|
||||||
|
// .iter()
|
||||||
|
// .filter(|db| *db == &database)
|
||||||
|
// .count()
|
||||||
|
// == 1;
|
||||||
|
|
||||||
|
// // Split the read and write streams
|
||||||
|
// // so we can control buffering.
|
||||||
|
// let (read, write) = split(stream);
|
||||||
|
|
||||||
|
// return Ok(Client {
|
||||||
|
// read: BufReader::new(read),
|
||||||
|
// write: write,
|
||||||
|
// buffer: BytesMut::with_capacity(8196),
|
||||||
|
// cancel_mode: false,
|
||||||
|
// transaction_mode: transaction_mode,
|
||||||
|
// process_id: process_id,
|
||||||
|
// secret_key: secret_key,
|
||||||
|
// client_server_map: client_server_map,
|
||||||
|
// parameters: parameters,
|
||||||
|
// stats: stats,
|
||||||
|
// admin: admin,
|
||||||
|
// last_address_id: None,
|
||||||
|
// last_server_id: None,
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Query cancel request.
|
||||||
|
// CANCEL_REQUEST_CODE => {
|
||||||
|
// let (read, write) = split(stream);
|
||||||
|
|
||||||
|
// let process_id = bytes.get_i32();
|
||||||
|
// let secret_key = bytes.get_i32();
|
||||||
|
|
||||||
|
// return Ok(Client {
|
||||||
|
// read: BufReader::new(read),
|
||||||
|
// write: write,
|
||||||
|
// buffer: BytesMut::with_capacity(8196),
|
||||||
|
// cancel_mode: true,
|
||||||
|
// transaction_mode: transaction_mode,
|
||||||
|
// process_id: process_id,
|
||||||
|
// secret_key: secret_key,
|
||||||
|
// client_server_map: client_server_map,
|
||||||
|
// parameters: HashMap::new(),
|
||||||
|
// stats: stats,
|
||||||
|
// admin: false,
|
||||||
|
// last_address_id: None,
|
||||||
|
// last_server_id: None,
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
|
||||||
|
// _ => {
|
||||||
|
// return Err(Error::ProtocolSyncError);
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
/// Handle a connected and authenticated client.
|
/// Handle a connected and authenticated client.
|
||||||
pub async fn handle(&mut self) -> Result<(), Error> {
|
pub async fn handle(&mut self) -> Result<(), Error> {
|
||||||
// The client wants to cancel a query it has issued previously.
|
// The client wants to cancel a query it has issued previously.
|
||||||
@@ -414,8 +608,8 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
|
|||||||
self.last_server_id = Some(server.process_id());
|
self.last_server_id = Some(server.process_id());
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"Client {:?} talking to server {:?}",
|
"Client stuff talking to server {:?}",
|
||||||
self.write.peer_addr().unwrap(),
|
// self.write.peer_addr().unwrap(),
|
||||||
server.address()
|
server.address()
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -650,7 +844,7 @@ impl<T: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncWrite + st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, S> Drop for Client <T, S> {
|
impl<S, T> Drop for Client<S, T> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
// Update statistics.
|
// Update statistics.
|
||||||
if let Some(address_id) = self.last_address_id {
|
if let Some(address_id) = self.last_address_id {
|
||||||
|
|||||||
65
src/main.rs
65
src/main.rs
@@ -28,13 +28,13 @@ extern crate log;
|
|||||||
extern crate md5;
|
extern crate md5;
|
||||||
extern crate num_cpus;
|
extern crate num_cpus;
|
||||||
extern crate once_cell;
|
extern crate once_cell;
|
||||||
|
extern crate rustls_pemfile;
|
||||||
extern crate serde;
|
extern crate serde;
|
||||||
extern crate serde_derive;
|
extern crate serde_derive;
|
||||||
extern crate sqlparser;
|
extern crate sqlparser;
|
||||||
extern crate tokio;
|
extern crate tokio;
|
||||||
extern crate toml;
|
|
||||||
extern crate tokio_rustls;
|
extern crate tokio_rustls;
|
||||||
extern crate rustls_pemfile;
|
extern crate toml;
|
||||||
|
|
||||||
use log::{debug, error, info};
|
use log::{debug, error, info};
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
@@ -45,6 +45,11 @@ use tokio::{
|
|||||||
sync::mpsc,
|
sync::mpsc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use tokio::net::{
|
||||||
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||||
|
TcpStream,
|
||||||
|
};
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -62,6 +67,7 @@ mod sharding;
|
|||||||
mod stats;
|
mod stats;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
|
||||||
|
use crate::constants::*;
|
||||||
use config::{get_config, reload_config};
|
use config::{get_config, reload_config};
|
||||||
use pool::{ClientServerMap, ConnectionPool};
|
use pool::{ClientServerMap, ConnectionPool};
|
||||||
use stats::{Collector, Reporter, REPORTER};
|
use stats::{Collector, Reporter, REPORTER};
|
||||||
@@ -153,32 +159,45 @@ async fn main() {
|
|||||||
// Handle client.
|
// Handle client.
|
||||||
tokio::task::spawn(async move {
|
tokio::task::spawn(async move {
|
||||||
let start = chrono::offset::Utc::now().naive_utc();
|
let start = chrono::offset::Utc::now().naive_utc();
|
||||||
match client::Client::startup(socket, client_server_map).await {
|
// match client::get_startup(&mut socket) {
|
||||||
Ok(mut client) => {
|
// Ok((code, bytes)) => match code {
|
||||||
info!("Client {:?} connected", addr);
|
// SSL_REQUEST_CODE => client::Client::tls_startup<
|
||||||
|
// }
|
||||||
match client.handle().await {
|
// }
|
||||||
Ok(()) => {
|
|
||||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"Client {:?} disconnected, session duration: {}",
|
|
||||||
addr,
|
|
||||||
format_duration(&duration)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Err(err) => {
|
|
||||||
error!("Client disconnected with error: {:?}", err);
|
|
||||||
client.release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
match client::client_loop(socket, client_server_map).await {
|
||||||
|
Ok(_) => (),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
debug!("Client failed to login: {:?}", err);
|
debug!("Client failed to login: {:?}", err);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// match client::Client::<OwnedReadHalf, OwnedWriteHalf>::startup(socket, client_server_map).await {
|
||||||
|
// Ok(mut client) => {
|
||||||
|
// info!("Client {:?} connected", addr);
|
||||||
|
|
||||||
|
// match client.handle().await {
|
||||||
|
// Ok(()) => {
|
||||||
|
// let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||||
|
|
||||||
|
// info!(
|
||||||
|
// "Client {:?} disconnected, session duration: {}",
|
||||||
|
// addr,
|
||||||
|
// format_duration(&duration)
|
||||||
|
// );
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Err(err) => {
|
||||||
|
// error!("Client disconnected with error: {:?}", err);
|
||||||
|
// client.release();
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Err(err) => {
|
||||||
|
// debug!("Client failed to login: {:?}", err);
|
||||||
|
// }
|
||||||
|
// };
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ impl From<&DataType> for i32 {
|
|||||||
|
|
||||||
/// Tell the client that authentication handshake completed successfully.
|
/// Tell the client that authentication handshake completed successfully.
|
||||||
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
|
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut auth_ok = BytesMut::with_capacity(9);
|
let mut auth_ok = BytesMut::with_capacity(9);
|
||||||
|
|
||||||
auth_ok.put_u8(b'R');
|
auth_ok.put_u8(b'R');
|
||||||
@@ -43,7 +45,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
|
|
||||||
/// Generate md5 password challenge.
|
/// Generate md5 password challenge.
|
||||||
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
|
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
// let mut rng = rand::thread_rng();
|
// let mut rng = rand::thread_rng();
|
||||||
let salt: [u8; 4] = [
|
let salt: [u8; 4] = [
|
||||||
rand::random(),
|
rand::random(),
|
||||||
@@ -69,7 +73,9 @@ pub async fn backend_key_data<S>(
|
|||||||
backend_id: i32,
|
backend_id: i32,
|
||||||
secret_key: i32,
|
secret_key: i32,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut key_data = BytesMut::from(&b"K"[..]);
|
let mut key_data = BytesMut::from(&b"K"[..]);
|
||||||
key_data.put_i32(12);
|
key_data.put_i32(12);
|
||||||
key_data.put_i32(backend_id);
|
key_data.put_i32(backend_id);
|
||||||
@@ -91,7 +97,9 @@ pub fn simple_query(query: &str) -> BytesMut {
|
|||||||
|
|
||||||
/// Tell the client we're ready for another query.
|
/// Tell the client we're ready for another query.
|
||||||
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut bytes = BytesMut::with_capacity(5);
|
let mut bytes = BytesMut::with_capacity(5);
|
||||||
|
|
||||||
bytes.put_u8(b'Z');
|
bytes.put_u8(b'Z');
|
||||||
@@ -215,7 +223,9 @@ pub async fn md5_password<S>(
|
|||||||
password: &str,
|
password: &str,
|
||||||
salt: &[u8],
|
salt: &[u8],
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let password = md5_hash_password(user, password, salt);
|
let password = md5_hash_password(user, password, salt);
|
||||||
|
|
||||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||||
@@ -230,11 +240,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
/// Implements a response to our custom `SET SHARDING KEY`
|
/// Implements a response to our custom `SET SHARDING KEY`
|
||||||
/// and `SET SERVER ROLE` commands.
|
/// and `SET SERVER ROLE` commands.
|
||||||
/// This tells the client we're ready for the next query.
|
/// This tells the client we're ready for the next query.
|
||||||
pub async fn custom_protocol_response_ok<S>(
|
pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
||||||
stream: &mut S,
|
where
|
||||||
message: &str,
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
) -> Result<(), Error>
|
{
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|
||||||
let mut res = BytesMut::with_capacity(25);
|
let mut res = BytesMut::with_capacity(25);
|
||||||
|
|
||||||
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
|
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
|
||||||
@@ -257,7 +266,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
/// Tell the client we are ready for the next query and no rollback is necessary.
|
/// Tell the client we are ready for the next query and no rollback is necessary.
|
||||||
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
|
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
|
||||||
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut error = BytesMut::new();
|
let mut error = BytesMut::new();
|
||||||
|
|
||||||
// Error level
|
// Error level
|
||||||
@@ -299,7 +310,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
|
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
let mut error = BytesMut::new();
|
let mut error = BytesMut::new();
|
||||||
|
|
||||||
// Error level
|
// Error level
|
||||||
@@ -333,11 +346,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Respond to a SHOW SHARD command.
|
/// Respond to a SHOW SHARD command.
|
||||||
pub async fn show_response(
|
pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error>
|
||||||
stream: &mut OwnedWriteHalf,
|
where
|
||||||
name: &str,
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
value: &str,
|
{
|
||||||
) -> Result<(), Error> {
|
|
||||||
// A SELECT response consists of:
|
// A SELECT response consists of:
|
||||||
// 1. RowDescription
|
// 1. RowDescription
|
||||||
// 2. One or more DataRow
|
// 2. One or more DataRow
|
||||||
@@ -439,7 +451,9 @@ pub fn command_complete(command: &str) -> BytesMut {
|
|||||||
|
|
||||||
/// Write all data in the buffer to the TcpStream.
|
/// Write all data in the buffer to the TcpStream.
|
||||||
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
match stream.write_all(&buf).await {
|
match stream.write_all(&buf).await {
|
||||||
Ok(_) => Ok(()),
|
Ok(_) => Ok(()),
|
||||||
Err(_) => return Err(Error::SocketError),
|
Err(_) => return Err(Error::SocketError),
|
||||||
@@ -448,7 +462,9 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
|
|
||||||
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
|
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
|
||||||
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
where
|
||||||
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||||
|
{
|
||||||
match stream.write_all(&buf).await {
|
match stream.write_all(&buf).await {
|
||||||
Ok(_) => Ok(()),
|
Ok(_) => Ok(()),
|
||||||
Err(_) => return Err(Error::SocketError),
|
Err(_) => return Err(Error::SocketError),
|
||||||
@@ -456,7 +472,10 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Read a complete message from the socket.
|
/// Read a complete message from the socket.
|
||||||
pub async fn read_message(stream: &mut BufReader<OwnedReadHalf>) -> Result<BytesMut, Error> {
|
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
||||||
|
where
|
||||||
|
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||||
|
{
|
||||||
let code = match stream.read_u8().await {
|
let code = match stream.read_u8().await {
|
||||||
Ok(code) => code,
|
Ok(code) => code,
|
||||||
Err(_) => return Err(Error::SocketError),
|
Err(_) => return Err(Error::SocketError),
|
||||||
|
|||||||
185
src/stream.rs
185
src/stream.rs
@@ -1,17 +1,17 @@
|
|||||||
// Stream wrapper.
|
// Stream wrapper.
|
||||||
|
|
||||||
use bytes::{Buf, BufMut, BytesMut};
|
use bytes::{Buf, BufMut, BytesMut};
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, split, ReadHalf, WriteHalf};
|
use rustls_pemfile::{certs, rsa_private_keys};
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||||
use tokio::net::{
|
use tokio::net::{
|
||||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||||
TcpStream,
|
TcpStream,
|
||||||
};
|
};
|
||||||
use tokio_rustls::server::TlsStream;
|
|
||||||
use rustls_pemfile::{certs, rsa_private_keys};
|
|
||||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||||
|
use tokio_rustls::server::TlsStream;
|
||||||
use tokio_rustls::TlsAcceptor;
|
use tokio_rustls::TlsAcceptor;
|
||||||
use std::sync::Arc;
|
|
||||||
use std::path::Path;
|
|
||||||
|
|
||||||
use crate::config::get_config;
|
use crate::config::get_config;
|
||||||
use crate::errors::Error;
|
use crate::errors::Error;
|
||||||
@@ -29,132 +29,91 @@ fn load_keys(path: &std::path::Path) -> std::io::Result<Vec<PrivateKey>> {
|
|||||||
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
|
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Tls {
|
pub struct Tls {
|
||||||
acceptor: TlsAcceptor,
|
pub acceptor: TlsAcceptor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tls {
|
impl Tls {
|
||||||
pub fn new() -> Result<Self, Error> {
|
pub fn new() -> Result<Self, Error> {
|
||||||
let config = get_config();
|
let config = get_config();
|
||||||
|
|
||||||
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
|
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
|
||||||
Ok(certs) => certs,
|
Ok(certs) => certs,
|
||||||
Err(_) => return Err(Error::TlsError),
|
Err(_) => return Err(Error::TlsError),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
|
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
|
||||||
Ok(keys) => keys,
|
Ok(keys) => keys,
|
||||||
Err(_) => return Err(Error::TlsError),
|
Err(_) => return Err(Error::TlsError),
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = match rustls::ServerConfig::builder()
|
let config = match rustls::ServerConfig::builder()
|
||||||
.with_safe_defaults()
|
.with_safe_defaults()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
.with_single_cert(certs, keys.remove(0))
|
.with_single_cert(certs, keys.remove(0))
|
||||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) {
|
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
|
||||||
Ok(c) => c,
|
{
|
||||||
Err(_) => return Err(Error::TlsError)
|
Ok(c) => c,
|
||||||
|
Err(_) => return Err(Error::TlsError),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Tls {
|
Ok(Tls {
|
||||||
acceptor: TlsAcceptor::from(Arc::new(config)),
|
acceptor: TlsAcceptor::from(Arc::new(config)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Stream {
|
struct Stream {
|
||||||
read: Option<BufReader<OwnedReadHalf>>,
|
read: Option<BufReader<OwnedReadHalf>>,
|
||||||
write: Option<OwnedWriteHalf>,
|
write: Option<OwnedWriteHalf>,
|
||||||
tls_read: Option<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
|
tls_read: Option<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
|
||||||
tls_write: Option<WriteHalf<TlsStream<TcpStream>>>,
|
tls_write: Option<WriteHalf<TlsStream<TcpStream>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Stream {
|
impl Stream {
|
||||||
pub async fn new(stream: TcpStream, tls: Option<Tls>) -> Result<Stream, Error> {
|
pub async fn new(stream: TcpStream, tls: Option<Tls>) -> Result<Stream, Error> {
|
||||||
|
let config = get_config();
|
||||||
|
|
||||||
let config = get_config();
|
match tls {
|
||||||
|
None => {
|
||||||
|
let (read, write) = stream.into_split();
|
||||||
|
let read = BufReader::new(read);
|
||||||
|
Ok(Self {
|
||||||
|
read: Some(read),
|
||||||
|
write: Some(write),
|
||||||
|
tls_read: None,
|
||||||
|
tls_write: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
match tls {
|
Some(tls) => {
|
||||||
None => {
|
let mut tls_stream = match tls.acceptor.accept(stream).await {
|
||||||
let (read, write) = stream.into_split();
|
Ok(stream) => stream,
|
||||||
let read = BufReader::new(read);
|
Err(_) => return Err(Error::TlsError),
|
||||||
Ok(
|
};
|
||||||
Self {
|
|
||||||
read: Some(read),
|
|
||||||
write: Some(write),
|
|
||||||
tls_read: None,
|
|
||||||
tls_write: None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(tls) => {
|
let (read, write) = split(tls_stream);
|
||||||
let mut tls_stream = match tls.acceptor.accept(stream).await {
|
|
||||||
Ok(stream) => stream,
|
|
||||||
Err(_) => return Err(Error::TlsError),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (read, write) = split(tls_stream);
|
Ok(Self {
|
||||||
|
read: None,
|
||||||
Ok(Self{
|
write: None,
|
||||||
read: None,
|
tls_read: Some(BufReader::new(read)),
|
||||||
write: None,
|
tls_write: Some(write),
|
||||||
tls_read: Some(BufReader::new(read)),
|
})
|
||||||
tls_write: Some(write),
|
}
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn read<S>(stream: &mut S) -> Result<BytesMut, Error>
|
|
||||||
where S: tokio::io::AsyncRead + std::marker::Unpin {
|
|
||||||
|
|
||||||
let code = match stream.read_u8().await {
|
|
||||||
Ok(code) => code,
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
let len = match stream.read_i32().await {
|
|
||||||
Ok(len) => len,
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut buf = vec![0u8; len as usize - 4];
|
|
||||||
|
|
||||||
match stream.read_exact(&mut buf).await {
|
|
||||||
Ok(_) => (),
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut bytes = BytesMut::with_capacity(len as usize + 1);
|
|
||||||
|
|
||||||
bytes.put_u8(code);
|
|
||||||
bytes.put_i32(len);
|
|
||||||
bytes.put_slice(&buf);
|
|
||||||
|
|
||||||
Ok(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn write<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
|
|
||||||
where S: tokio::io::AsyncWrite + std::marker::Unpin {
|
|
||||||
match stream.write_all(buf).await {
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(_) => return Err(Error::SocketError),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn read_message(&mut self) -> Result<BytesMut, Error> {
|
|
||||||
match &self.read {
|
|
||||||
Some(read) => Self::read(self.read.as_mut().unwrap()).await,
|
|
||||||
None => Self::read(self.tls_read.as_mut().unwrap()).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn write_all(&mut self, buf: &BytesMut) -> Result<(), Error> {
|
|
||||||
match &self.write {
|
|
||||||
Some(write) => Self::write(self.write.as_mut().unwrap(), buf).await,
|
|
||||||
None => Self::write(self.tls_write.as_mut().unwrap(), buf).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// impl tokio::io::AsyncRead for Stream {
|
||||||
|
// fn poll_read(
|
||||||
|
// mut self: core::pin::Pin<&mut Self>,
|
||||||
|
// cx: &mut core::task::Context<'_>,
|
||||||
|
// buf: &mut tokio::io::ReadBuf<'_>
|
||||||
|
// ) -> core::task::Poll<std::io::Result<()>> {
|
||||||
|
// match &mut self.get_mut().tls_read {
|
||||||
|
// None => core::pin::Pin::new(self.read.as_mut().unwrap()).poll_read(cx, buf),
|
||||||
|
// Some(mut tls) => core::pin::Pin::new(&mut tls).poll_read(cx, buf),
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|||||||
Reference in New Issue
Block a user