Server TLS

This commit is contained in:
Lev Kokotov
2023-04-28 11:20:49 -07:00
parent 4a87b4807d
commit 9e51b8110f
5 changed files with 166 additions and 14 deletions

21
Cargo.lock generated
View File

@@ -762,6 +762,7 @@ dependencies = [
"once_cell",
"parking_lot",
"phf",
"pin-project",
"postgres-protocol",
"rand",
"regex",
@@ -820,6 +821,26 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "pin-project-lite"
version = "0.2.9"

View File

@@ -39,6 +39,7 @@ nix = "0.26.2"
atomic_enum = "0.2.0"
postgres-protocol = "0.6.5"
fallible-iterator = "0.2"
pin-project = "*"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

View File

@@ -539,6 +539,7 @@ where
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}
@@ -565,6 +566,8 @@ where
}
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
@@ -587,7 +590,15 @@ where
client_identifier
);
let fetched_hash = refetch_auth_hash(&pool).await?;
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.

View File

@@ -150,6 +150,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();

View File

@@ -9,11 +9,13 @@ use std::collections::HashMap;
use std::io::Read;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::{TlsConnector, TlsStream};
use crate::config::{Address, User};
use crate::constants::*;
@@ -24,6 +26,82 @@ use crate::pool::ClientServerMap;
use crate::scram::ScramSha256;
use crate::stats::ServerStats;
use pin_project::pin_project;
#[pin_project(project = ReadInnerProj)]
pub enum ReadInner {
Plain {
#[pin]
stream: ReadHalf<TcpStream>,
},
Tls {
#[pin]
stream: ReadHalf<TlsStream<TcpStream>>,
},
}
#[pin_project(project = WriteInnerProj)]
pub enum WriteInner {
Plain {
#[pin]
stream: WriteHalf<TcpStream>,
},
Tls {
#[pin]
stream: WriteHalf<TlsStream<TcpStream>>,
},
}
impl AsyncWrite for WriteInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf),
WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_flush(cx),
WriteInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx),
WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for ReadInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf),
ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
/// Server state.
pub struct Server {
/// Server host, e.g. localhost,
@@ -31,10 +109,10 @@ pub struct Server {
address: Address,
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
read: BufReader<ReadInner>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
write: WriteInner,
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
@@ -100,6 +178,32 @@ impl Server {
};
configure_socket(&stream);
// ssl_request(&mut stream).await?;
// let response = match stream.read_u8().await {
// Ok(response) => response as char,
// Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))),
// };
// match response {
// 'S' => {
// let connector = TlsConnector::from(ClientConfig::builder()
// .with_safe_default_cipher_suites()
// .with_safe_default_kx_groups()
// .with_safe_default_protocol_versions()
// .unwrap()
// .with_no_client_auth());
// connector.connect("test".into(), stream).await.unwrap();
// },
// 'N' => {
// },
// _ => {
// return Err(Error::SocketError("error".into()));
// }
// };
trace!("Sending StartupMessage");
// StartupMessage
@@ -443,12 +547,12 @@ impl Server {
}
};
let (read, write) = stream.into_split();
let (read, write) = split(stream);
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write,
read: BufReader::new(ReadInner::Plain { stream: read }),
write: WriteInner::Plain { stream: write },
buffer: BytesMut::with_capacity(8196),
server_info,
process_id,
@@ -935,14 +1039,14 @@ impl Drop for Server {
// Update statistics
self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4);
bytes.put_u8(b'X');
bytes.put_i32(4);
// let mut bytes = BytesMut::with_capacity(4);
// bytes.put_u8(b'X');
// bytes.put_i32(4);
match self.write.try_write(&bytes) {
Ok(_) => (),
Err(_) => debug!("Dirty shutdown"),
};
// match self.write.try_write(&bytes) {
// Ok(_) => (),
// Err(_) => debug!("Dirty shutdown"),
// };
// Should not matter.
self.bad = true;