mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
hmm
This commit is contained in:
@@ -33,8 +33,8 @@ ban_time = 60 # Seconds
|
||||
autoreload = false
|
||||
|
||||
# TLS
|
||||
tls_certificate = "server.cert"
|
||||
tls_private_key = "server.key"
|
||||
# tls_certificate = "server.cert"
|
||||
# tls_private_key = "server.key"
|
||||
|
||||
#
|
||||
# User to use for authentication against the server.
|
||||
|
||||
334
src/client.rs
334
src/client.rs
@@ -1,12 +1,9 @@
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error, trace};
|
||||
use log::{debug, error, trace, info};
|
||||
use std::collections::HashMap;
|
||||
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::admin::handle_admin;
|
||||
use crate::config::get_config;
|
||||
@@ -17,7 +14,7 @@ use crate::pool::{get_pool, ClientServerMap};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
use crate::stream::Tls;
|
||||
use crate::tls::Tls;
|
||||
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
@@ -41,6 +38,9 @@ pub struct Client<S, T> {
|
||||
/// them to the backend.
|
||||
buffer: BytesMut,
|
||||
|
||||
/// Address
|
||||
addr: std::net::SocketAddr,
|
||||
|
||||
/// The client was started with the sole reason to cancel another running query.
|
||||
cancel_mode: bool,
|
||||
|
||||
@@ -73,35 +73,107 @@ pub struct Client<S, T> {
|
||||
last_server_id: Option<i32>,
|
||||
}
|
||||
|
||||
/// Main client loop.
|
||||
pub async fn client_loop(
|
||||
/// Client entrypoint.
|
||||
pub async fn client_entrypoint(
|
||||
mut stream: TcpStream,
|
||||
client_server_map: ClientServerMap,
|
||||
) -> Result<(), Error> {
|
||||
// Figure out if the client wants TLS or not.
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
|
||||
match get_startup::<TcpStream>(&mut stream).await {
|
||||
Ok((ClientConnectionType::Tls, bytes)) => {
|
||||
match startup_tls(stream, client_server_map).await {
|
||||
Ok(mut client) => client.handle().await,
|
||||
Err(err) => Err(err),
|
||||
|
||||
// Client requested a TLS connection.
|
||||
Ok((ClientConnectionType::Tls, _)) => {
|
||||
let config = get_config();
|
||||
|
||||
// TLS settings are configured, will setup TLS now.
|
||||
if config.general.tls_certificate != None {
|
||||
debug!("Accepting TLS request");
|
||||
|
||||
let mut yes = BytesMut::new();
|
||||
yes.put_u8(b'S');
|
||||
write_all(&mut stream, yes).await?;
|
||||
|
||||
// Negotiate TLS.
|
||||
match startup_tls(stream, client_server_map).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (TLS)", addr);
|
||||
|
||||
client.handle().await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
// TLS is not configured, we cannot offer it.
|
||||
else {
|
||||
// Rejecting client request for TLS.
|
||||
let mut no = BytesMut::new();
|
||||
no.put_u8(b'N');
|
||||
write_all(&mut stream, no).await?;
|
||||
|
||||
// Attempting regular startup. Client can disconnect now
|
||||
// if they choose.
|
||||
match get_startup::<TcpStream>(&mut stream).await {
|
||||
// Client accepted unencrypted connection.
|
||||
Ok((ClientConnectionType::Startup, bytes)) => {
|
||||
let (read, write) = split(stream);
|
||||
|
||||
// Continue with regular startup.
|
||||
match Client::startup(read, write, addr, bytes, client_server_map).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (plain)", addr);
|
||||
|
||||
client.handle().await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
// Client probably disconnected rejecting our plain text connection.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Client wants to use plain connection without encryption.
|
||||
Ok((ClientConnectionType::Startup, bytes)) => {
|
||||
let (read, write) = split(stream);
|
||||
match Client::handle_startup(read, write, bytes, client_server_map).await {
|
||||
Ok(mut client) => client.handle().await,
|
||||
|
||||
// Continue with regular startup.
|
||||
match Client::startup(read, write, addr, bytes, client_server_map).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (plain)", addr);
|
||||
|
||||
client.handle().await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
// Client wants to cancel a query.
|
||||
Ok((ClientConnectionType::CancelQuery, bytes)) => {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
let (read, write) = split(stream);
|
||||
|
||||
// Continue with cancel query request.
|
||||
match Client::cancel(read, write, addr, bytes, client_server_map).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} issued a cancel query request", addr);
|
||||
|
||||
client.handle().await
|
||||
}
|
||||
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
// Something failed, probably the socket.
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle the first message the client sends.
|
||||
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
|
||||
where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
|
||||
@@ -131,32 +203,45 @@ where
|
||||
|
||||
// Client is requesting to cancel a running query (plain text connection).
|
||||
CANCEL_REQUEST_CODE => Ok((ClientConnectionType::CancelQuery, bytes)),
|
||||
|
||||
// Something else, probably something is wrong and it's not our fault,
|
||||
// e.g. badly implemented Postgres client.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle TLS connection negotation.
|
||||
pub async fn startup_tls(
|
||||
mut stream: TcpStream,
|
||||
stream: TcpStream,
|
||||
client_server_map: ClientServerMap,
|
||||
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
|
||||
// Accept SSL request if SSL is configured.
|
||||
let mut yes = BytesMut::new();
|
||||
yes.put_u8(b'S');
|
||||
write_all(&mut stream, yes).await?;
|
||||
|
||||
// Negotiate TLS.
|
||||
let mut tls = Tls::new().unwrap();
|
||||
let tls = Tls::new()?;
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
|
||||
let mut stream = match tls.acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
|
||||
// TLS negotitation failed.
|
||||
Err(err) => {
|
||||
error!("TLS negotiation failed: {:?}", err);
|
||||
return Err(Error::TlsError)
|
||||
}
|
||||
};
|
||||
|
||||
// TLS negotitation successful.
|
||||
// Continue with regular startup using encrypted connection.
|
||||
match get_startup::<TlsStream<TcpStream>>(&mut stream).await {
|
||||
|
||||
// Got good startup message, proceeding like normal except we
|
||||
// are encrypted now.
|
||||
Ok((ClientConnectionType::Startup, bytes)) => {
|
||||
let (read, write) = split(stream);
|
||||
Client::handle_startup(read, write, bytes, client_server_map).await
|
||||
|
||||
Client::startup(read, write, addr, bytes, client_server_map).await
|
||||
}
|
||||
|
||||
// Bad Postgres client.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
}
|
||||
}
|
||||
@@ -166,10 +251,12 @@ where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
// Perform client startup sequence in TLS.
|
||||
pub async fn handle_startup(
|
||||
/// Handle Postgres client startup after TLS negotiation is complete
|
||||
/// or over plain text.
|
||||
pub async fn startup(
|
||||
mut read: S,
|
||||
mut write: T,
|
||||
addr: std::net::SocketAddr,
|
||||
bytes: BytesMut, // The rest of the startup message.
|
||||
client_server_map: ClientServerMap,
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
@@ -244,6 +331,7 @@ where
|
||||
return Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
addr,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: false,
|
||||
transaction_mode: transaction_mode,
|
||||
@@ -258,161 +346,38 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
/// Perform client startup sequence.
|
||||
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
|
||||
// pub async fn startup(
|
||||
// mut stream: TcpStream,
|
||||
// client_server_map: ClientServerMap,
|
||||
// ) -> Result<Client<ReadHalf<TcpStream>, WriteHalf<TcpStream>>, Error> {
|
||||
// let config = get_config();
|
||||
// let transaction_mode = config.general.pool_mode == "transaction";
|
||||
// let stats = get_reporter();
|
||||
/// Handle cancel request.
|
||||
pub async fn cancel(
|
||||
read: S,
|
||||
write: T,
|
||||
addr: std::net::SocketAddr,
|
||||
mut bytes: BytesMut, // The rest of the startup message.
|
||||
client_server_map: ClientServerMap,
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
let process_id = bytes.get_i32();
|
||||
let secret_key = bytes.get_i32();
|
||||
|
||||
// loop {
|
||||
// trace!("Waiting for StartupMessage");
|
||||
let config = get_config();
|
||||
let transaction_mode = config.general.pool_mode == "transaction";
|
||||
let stats = get_reporter();
|
||||
|
||||
// // Could be StartupMessage, SSLRequest or CancelRequest.
|
||||
// let len = match stream.read_i32().await {
|
||||
// Ok(len) => len,
|
||||
// Err(_) => return Err(Error::ClientBadStartup),
|
||||
// };
|
||||
|
||||
// let mut startup = vec![0u8; len as usize - 4];
|
||||
|
||||
// match stream.read_exact(&mut startup).await {
|
||||
// Ok(_) => (),
|
||||
// Err(_) => return Err(Error::ClientBadStartup),
|
||||
// };
|
||||
|
||||
// let mut bytes = BytesMut::from(&startup[..]);
|
||||
// let code = bytes.get_i32();
|
||||
|
||||
// match code {
|
||||
// // Client wants SSL. We don't support it at the moment.
|
||||
// SSL_REQUEST_CODE => {
|
||||
// trace!("Rejecting SSLRequest");
|
||||
|
||||
// let mut no = BytesMut::with_capacity(1);
|
||||
// no.put_u8(b'N');
|
||||
|
||||
// write_all(&mut stream, no).await?;
|
||||
// }
|
||||
|
||||
// // Regular startup message.
|
||||
// PROTOCOL_VERSION_NUMBER => {
|
||||
// trace!("Got StartupMessage");
|
||||
// let parameters = parse_startup(bytes.clone())?;
|
||||
|
||||
// // Generate random backend ID and secret key
|
||||
// let process_id: i32 = rand::random();
|
||||
// let secret_key: i32 = rand::random();
|
||||
|
||||
// // Perform MD5 authentication.
|
||||
// // TODO: Add SASL support.
|
||||
// let salt = md5_challenge(&mut stream).await?;
|
||||
|
||||
// let code = match stream.read_u8().await {
|
||||
// Ok(p) => p,
|
||||
// Err(_) => return Err(Error::SocketError),
|
||||
// };
|
||||
|
||||
// // PasswordMessage
|
||||
// if code as char != 'p' {
|
||||
// debug!("Expected p, got {}", code as char);
|
||||
// return Err(Error::ProtocolSyncError);
|
||||
// }
|
||||
|
||||
// let len = match stream.read_i32().await {
|
||||
// Ok(len) => len,
|
||||
// Err(_) => return Err(Error::SocketError),
|
||||
// };
|
||||
|
||||
// let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
// match stream.read_exact(&mut password_response).await {
|
||||
// Ok(_) => (),
|
||||
// Err(_) => return Err(Error::SocketError),
|
||||
// };
|
||||
|
||||
// // Compare server and client hashes.
|
||||
// let password_hash =
|
||||
// md5_hash_password(&config.user.name, &config.user.password, &salt);
|
||||
|
||||
// if password_hash != password_response {
|
||||
// debug!("Password authentication failed");
|
||||
// wrong_password(&mut stream, &config.user.name).await?;
|
||||
// return Err(Error::ClientError);
|
||||
// }
|
||||
|
||||
// debug!("Password authentication successful");
|
||||
|
||||
// auth_ok(&mut stream).await?;
|
||||
// write_all(&mut stream, get_pool().server_info()).await?;
|
||||
// backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||
// ready_for_query(&mut stream).await?;
|
||||
|
||||
// trace!("Startup OK");
|
||||
|
||||
// let database = parameters
|
||||
// .get("database")
|
||||
// .unwrap_or(parameters.get("user").unwrap());
|
||||
// let admin = ["pgcat", "pgbouncer"]
|
||||
// .iter()
|
||||
// .filter(|db| *db == &database)
|
||||
// .count()
|
||||
// == 1;
|
||||
|
||||
// // Split the read and write streams
|
||||
// // so we can control buffering.
|
||||
// let (read, write) = split(stream);
|
||||
|
||||
// return Ok(Client {
|
||||
// read: BufReader::new(read),
|
||||
// write: write,
|
||||
// buffer: BytesMut::with_capacity(8196),
|
||||
// cancel_mode: false,
|
||||
// transaction_mode: transaction_mode,
|
||||
// process_id: process_id,
|
||||
// secret_key: secret_key,
|
||||
// client_server_map: client_server_map,
|
||||
// parameters: parameters,
|
||||
// stats: stats,
|
||||
// admin: admin,
|
||||
// last_address_id: None,
|
||||
// last_server_id: None,
|
||||
// });
|
||||
// }
|
||||
|
||||
// // Query cancel request.
|
||||
// CANCEL_REQUEST_CODE => {
|
||||
// let (read, write) = split(stream);
|
||||
|
||||
// let process_id = bytes.get_i32();
|
||||
// let secret_key = bytes.get_i32();
|
||||
|
||||
// return Ok(Client {
|
||||
// read: BufReader::new(read),
|
||||
// write: write,
|
||||
// buffer: BytesMut::with_capacity(8196),
|
||||
// cancel_mode: true,
|
||||
// transaction_mode: transaction_mode,
|
||||
// process_id: process_id,
|
||||
// secret_key: secret_key,
|
||||
// client_server_map: client_server_map,
|
||||
// parameters: HashMap::new(),
|
||||
// stats: stats,
|
||||
// admin: false,
|
||||
// last_address_id: None,
|
||||
// last_server_id: None,
|
||||
// });
|
||||
// }
|
||||
|
||||
// _ => {
|
||||
// return Err(Error::ProtocolSyncError);
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
// }
|
||||
return Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
addr,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: true,
|
||||
transaction_mode: transaction_mode,
|
||||
process_id: process_id,
|
||||
secret_key: secret_key,
|
||||
client_server_map: client_server_map,
|
||||
parameters: HashMap::new(),
|
||||
stats: stats,
|
||||
admin: false,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
/// Handle a connected and authenticated client.
|
||||
pub async fn handle(&mut self) -> Result<(), Error> {
|
||||
@@ -608,8 +573,8 @@ where
|
||||
self.last_server_id = Some(server.process_id());
|
||||
|
||||
debug!(
|
||||
"Client stuff talking to server {:?}",
|
||||
// self.write.peer_addr().unwrap(),
|
||||
"Client {:?} talking to server {:?}",
|
||||
self.addr,
|
||||
server.address()
|
||||
);
|
||||
|
||||
@@ -846,6 +811,9 @@ where
|
||||
|
||||
impl<S, T> Drop for Client<S, T> {
|
||||
fn drop(&mut self) {
|
||||
let mut guard = self.client_server_map.lock();
|
||||
guard.remove(&(self.process_id, self.secret_key));
|
||||
|
||||
// Update statistics.
|
||||
if let Some(address_id) = self.last_address_id {
|
||||
self.stats.client_disconnecting(self.process_id, address_id);
|
||||
@@ -854,5 +822,7 @@ impl<S, T> Drop for Client<S, T> {
|
||||
self.stats.server_idle(process_id, address_id);
|
||||
}
|
||||
}
|
||||
|
||||
// self.release();
|
||||
}
|
||||
}
|
||||
|
||||
55
src/main.rs
55
src/main.rs
@@ -45,11 +45,6 @@ use tokio::{
|
||||
sync::mpsc,
|
||||
};
|
||||
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -65,9 +60,8 @@ mod scram;
|
||||
mod server;
|
||||
mod sharding;
|
||||
mod stats;
|
||||
mod stream;
|
||||
mod tls;
|
||||
|
||||
use crate::constants::*;
|
||||
use config::{get_config, reload_config};
|
||||
use pool::{ClientServerMap, ConnectionPool};
|
||||
use stats::{Collector, Reporter, REPORTER};
|
||||
@@ -159,45 +153,22 @@ async fn main() {
|
||||
// Handle client.
|
||||
tokio::task::spawn(async move {
|
||||
let start = chrono::offset::Utc::now().naive_utc();
|
||||
// match client::get_startup(&mut socket) {
|
||||
// Ok((code, bytes)) => match code {
|
||||
// SSL_REQUEST_CODE => client::Client::tls_startup<
|
||||
// }
|
||||
// }
|
||||
|
||||
match client::client_loop(socket, client_server_map).await {
|
||||
Ok(_) => (),
|
||||
match client::client_entrypoint(socket, client_server_map).await {
|
||||
Ok(_) => {
|
||||
let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
info!(
|
||||
"Client {:?} disconnected, session duration: {}",
|
||||
addr,
|
||||
format_duration(&duration)
|
||||
);
|
||||
},
|
||||
|
||||
Err(err) => {
|
||||
debug!("Client failed to login: {:?}", err);
|
||||
debug!("Client disconnected with error {:?}", err);
|
||||
}
|
||||
};
|
||||
|
||||
// match client::Client::<OwnedReadHalf, OwnedWriteHalf>::startup(socket, client_server_map).await {
|
||||
// Ok(mut client) => {
|
||||
// info!("Client {:?} connected", addr);
|
||||
|
||||
// match client.handle().await {
|
||||
// Ok(()) => {
|
||||
// let duration = chrono::offset::Utc::now().naive_utc() - start;
|
||||
|
||||
// info!(
|
||||
// "Client {:?} disconnected, session duration: {}",
|
||||
// addr,
|
||||
// format_duration(&duration)
|
||||
// );
|
||||
// }
|
||||
|
||||
// Err(err) => {
|
||||
// error!("Client disconnected with error: {:?}", err);
|
||||
// client.release();
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Err(err) => {
|
||||
// debug!("Client failed to login: {:?}", err);
|
||||
// }
|
||||
// };
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
/// and handle TcpStream (TCP socket).
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use md5::{Digest, Md5};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::errors::Error;
|
||||
use std::collections::HashMap;
|
||||
|
||||
119
src/stream.rs
119
src/stream.rs
@@ -1,119 +0,0 @@
|
||||
// Stream wrapper.
|
||||
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use rustls_pemfile::{certs, rsa_private_keys};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{split, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
// TLS
|
||||
fn load_certs(path: &std::path::Path) -> std::io::Result<Vec<Certificate>> {
|
||||
certs(&mut std::io::BufReader::new(std::fs::File::open(path)?))
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert"))
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())
|
||||
}
|
||||
|
||||
fn load_keys(path: &std::path::Path) -> std::io::Result<Vec<PrivateKey>> {
|
||||
rsa_private_keys(&mut std::io::BufReader::new(std::fs::File::open(path)?))
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid key"))
|
||||
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
|
||||
}
|
||||
|
||||
pub struct Tls {
|
||||
pub acceptor: TlsAcceptor,
|
||||
}
|
||||
|
||||
impl Tls {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let config = get_config();
|
||||
|
||||
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
Ok(certs) => certs,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
Ok(keys) => keys,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let config = match rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, keys.remove(0))
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
Ok(Tls {
|
||||
acceptor: TlsAcceptor::from(Arc::new(config)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct Stream {
|
||||
read: Option<BufReader<OwnedReadHalf>>,
|
||||
write: Option<OwnedWriteHalf>,
|
||||
tls_read: Option<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
|
||||
tls_write: Option<WriteHalf<TlsStream<TcpStream>>>,
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
pub async fn new(stream: TcpStream, tls: Option<Tls>) -> Result<Stream, Error> {
|
||||
let config = get_config();
|
||||
|
||||
match tls {
|
||||
None => {
|
||||
let (read, write) = stream.into_split();
|
||||
let read = BufReader::new(read);
|
||||
Ok(Self {
|
||||
read: Some(read),
|
||||
write: Some(write),
|
||||
tls_read: None,
|
||||
tls_write: None,
|
||||
})
|
||||
}
|
||||
|
||||
Some(tls) => {
|
||||
let mut tls_stream = match tls.acceptor.accept(stream).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let (read, write) = split(tls_stream);
|
||||
|
||||
Ok(Self {
|
||||
read: None,
|
||||
write: None,
|
||||
tls_read: Some(BufReader::new(read)),
|
||||
tls_write: Some(write),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// impl tokio::io::AsyncRead for Stream {
|
||||
// fn poll_read(
|
||||
// mut self: core::pin::Pin<&mut Self>,
|
||||
// cx: &mut core::task::Context<'_>,
|
||||
// buf: &mut tokio::io::ReadBuf<'_>
|
||||
// ) -> core::task::Poll<std::io::Result<()>> {
|
||||
// match &mut self.get_mut().tls_read {
|
||||
// None => core::pin::Pin::new(self.read.as_mut().unwrap()).poll_read(cx, buf),
|
||||
// Some(mut tls) => core::pin::Pin::new(&mut tls).poll_read(cx, buf),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
57
src/tls.rs
Normal file
57
src/tls.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
// Stream wrapper.
|
||||
|
||||
use rustls_pemfile::{certs, rsa_private_keys};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
// TLS
|
||||
fn load_certs(path: &Path) -> std::io::Result<Vec<Certificate>> {
|
||||
certs(&mut std::io::BufReader::new(std::fs::File::open(path)?))
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert"))
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())
|
||||
}
|
||||
|
||||
fn load_keys(path: &Path) -> std::io::Result<Vec<PrivateKey>> {
|
||||
rsa_private_keys(&mut std::io::BufReader::new(std::fs::File::open(path)?))
|
||||
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid key"))
|
||||
.map(|mut keys| keys.drain(..).map(PrivateKey).collect())
|
||||
}
|
||||
|
||||
pub struct Tls {
|
||||
pub acceptor: TlsAcceptor,
|
||||
}
|
||||
|
||||
impl Tls {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let config = get_config();
|
||||
|
||||
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
Ok(certs) => certs,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
Ok(keys) => keys,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let config = match rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, keys.remove(0))
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
Ok(Tls {
|
||||
acceptor: TlsAcceptor::from(Arc::new(config)),
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user