Server TLS

This commit is contained in:
Lev Kokotov
2023-04-28 11:20:49 -07:00
parent 4a87b4807d
commit 9e51b8110f
5 changed files with 166 additions and 14 deletions

View File

@@ -9,11 +9,13 @@ use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::{TlsConnector, TlsStream};
use crate::config::{Address, User};
use crate::constants::*;
@@ -24,6 +26,82 @@ use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::ServerStats;
use pin_project::pin_project;
#[pin_project(project = ReadInnerProj)]
pub enum ReadInner {
Plain {
#[pin]
stream: ReadHalf<TcpStream>,
},
Tls {
#[pin]
stream: ReadHalf<TlsStream<TcpStream>>,
},
}
#[pin_project(project = WriteInnerProj)]
pub enum WriteInner {
Plain {
#[pin]
stream: WriteHalf<TcpStream>,
},
Tls {
#[pin]
stream: WriteHalf<TlsStream<TcpStream>>,
},
}
impl AsyncWrite for WriteInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf),
WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_flush(cx),
WriteInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx),
WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for ReadInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf),
ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
/// Server state.
pub struct Server {
/// Server host, e.g. localhost,
@@ -31,10 +109,10 @@ pub struct Server {
address: Address,
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
read: BufReader<ReadInner>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
write: WriteInner,
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
@@ -100,6 +178,32 @@ impl Server {
};
configure_socket(&stream);
// ssl_request(&mut stream).await?;
// let response = match stream.read_u8().await {
// Ok(response) => response as char,
// Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))),
// };
// match response {
// 'S' => {
// let connector = TlsConnector::from(ClientConfig::builder()
// .with_safe_default_cipher_suites()
// .with_safe_default_kx_groups()
// .with_safe_default_protocol_versions()
// .unwrap()
// .with_no_client_auth());
// connector.connect("test".into(), stream).await.unwrap();
// },
// 'N' => {
// },
// _ => {
// return Err(Error::SocketError("error".into()));
// }
// };
trace!("Sending StartupMessage");
// StartupMessage
@@ -443,12 +547,12 @@ impl Server {
}
};
let (read, write) = stream.into_split();
let (read, write) = split(stream);
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write,
read: BufReader::new(ReadInner::Plain { stream: read }),
write: WriteInner::Plain { stream: write },
buffer: BytesMut::with_capacity(8196),
server_info,
process_id,
@@ -935,14 +1039,14 @@ impl Drop for Server {
// Update statistics
self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4);
bytes.put_u8(b'X');
bytes.put_i32(4);
// let mut bytes = BytesMut::with_capacity(4);
// bytes.put_u8(b'X');
// bytes.put_i32(4);
match self.write.try_write(&bytes) {
Ok(_) => (),
Err(_) => debug!("Dirty shutdown"),
};
// match self.write.try_write(&bytes) {
// Ok(_) => (),
// Err(_) => debug!("Dirty shutdown"),
// };
// Should not matter.
self.bad = true;