This commit is contained in:
Lev Kokotov
2023-04-29 08:38:27 -07:00
parent a514dbc187
commit bba5f10be1
3 changed files with 88 additions and 81 deletions

View File

@@ -66,7 +66,7 @@ tcp_keepalives_interval = 5
# tls_private_key = ".circleci/server.key" # tls_private_key = ".circleci/server.key"
# Enable/disable server TLS # Enable/disable server TLS
server_tls = true server_tls = false
# Verify server certificate is completely authentic. # Verify server certificate is completely authentic.
verify_server_certificate = false verify_server_certificate = false

View File

@@ -523,6 +523,29 @@ where
} }
} }
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket. /// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error> pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where where

View File

@@ -9,7 +9,7 @@ use std::collections::HashMap;
use std::io::Read; use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector}; use tokio_rustls::{client::TlsStream, TlsConnector};
@@ -22,34 +22,23 @@ use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap; use crate::pool::ClientServerMap;
use crate::scram::ScramSha256; use crate::scram::ScramSha256;
use crate::stats::ServerStats; use crate::stats::ServerStats;
use std::io::Write;
use pin_project::pin_project; use pin_project::pin_project;
#[pin_project(project = ReadInnerProj)] #[pin_project(project = SteamInnerProj)]
pub enum ReadInner { pub enum StreamInner {
Plain { Plain {
#[pin] #[pin]
stream: ReadHalf<TcpStream>, stream: TcpStream,
}, },
Tls { Tls {
#[pin] #[pin]
stream: ReadHalf<TlsStream<TcpStream>>, stream: TlsStream<TcpStream>,
}, },
} }
#[pin_project(project = WriteInnerProj)] impl AsyncWrite for StreamInner {
pub enum WriteInner {
Plain {
#[pin]
stream: WriteHalf<TcpStream>,
},
Tls {
#[pin]
stream: WriteHalf<TlsStream<TcpStream>>,
},
}
impl AsyncWrite for WriteInner {
fn poll_write( fn poll_write(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
@@ -57,8 +46,8 @@ impl AsyncWrite for WriteInner {
) -> std::task::Poll<Result<usize, std::io::Error>> { ) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project(); let this = self.project();
match this { match this {
WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf), SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf), SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
} }
} }
@@ -68,8 +57,8 @@ impl AsyncWrite for WriteInner {
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project(); let this = self.project();
match this { match this {
WriteInnerProj::Tls { stream } => stream.poll_flush(cx), SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
WriteInnerProj::Plain { stream } => stream.poll_flush(cx), SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
} }
} }
@@ -79,13 +68,13 @@ impl AsyncWrite for WriteInner {
) -> std::task::Poll<Result<(), std::io::Error>> { ) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project(); let this = self.project();
match this { match this {
WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx), SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx), SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
} }
} }
} }
impl AsyncRead for ReadInner { impl AsyncRead for StreamInner {
fn poll_read( fn poll_read(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
@@ -93,8 +82,21 @@ impl AsyncRead for ReadInner {
) -> std::task::Poll<std::io::Result<()>> { ) -> std::task::Poll<std::io::Result<()>> {
let this = self.project(); let this = self.project();
match this { match this {
ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf), SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf), SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
impl StreamInner {
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamInner::Tls { stream } => {
let r = stream.get_mut();
let mut w = r.1.writer();
w.write(buf)
}
StreamInner::Plain { stream } => stream.try_write(buf),
} }
} }
} }
@@ -105,11 +107,8 @@ pub struct Server {
/// port, e.g. 5432, and role, e.g. primary or replica. /// port, e.g. 5432, and role, e.g. primary or replica.
address: Address, address: Address,
/// Buffered read socket. /// Server TCP connection.
read: BufReader<ReadInner>, stream: BufStream<StreamInner>,
/// Unbuffered write socket (our client code buffers).
write: WriteInner,
/// Our server response buffer. We buffer data before we give it to the client. /// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut, buffer: BytesMut,
@@ -177,7 +176,7 @@ impl Server {
// TCP timeouts. // TCP timeouts.
configure_socket(&stream); configure_socket(&stream);
let (mut read, mut write) = if get_config().general.server_tls { let mut stream = if get_config().general.server_tls {
// Request a TLS connection // Request a TLS connection
ssl_request(&mut stream).await?; ssl_request(&mut stream).await?;
@@ -232,21 +231,11 @@ impl Server {
} }
}; };
let (read, write) = split(stream); StreamInner::Tls { stream }
(
ReadInner::Tls { stream: read },
WriteInner::Tls { stream: write },
)
} }
// Server does not support TLS // Server does not support TLS
'N' => { 'N' => StreamInner::Plain { stream },
let (read, write) = split(stream);
(
ReadInner::Plain { stream: read },
WriteInner::Plain { stream: write },
)
}
// Something else? // Something else?
m => { m => {
@@ -257,11 +246,7 @@ impl Server {
} }
} }
} else { } else {
let (read, write) = split(stream); StreamInner::Plain { stream }
(
ReadInner::Plain { stream: read },
WriteInner::Plain { stream: write },
)
}; };
// let (read, write) = split(stream); // let (read, write) = split(stream);
@@ -283,7 +268,7 @@ impl Server {
}, },
}; };
startup(&mut write, username, database).await?; startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
@@ -298,7 +283,7 @@ impl Server {
}; };
loop { loop {
let code = match read.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -308,7 +293,7 @@ impl Server {
} }
}; };
let len = match read.read_i32().await { let len = match stream.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -324,7 +309,7 @@ impl Server {
// Authentication // Authentication
'R' => { 'R' => {
// Determine which kind of authentication is required, if any. // Determine which kind of authentication is required, if any.
let auth_code = match read.read_i32().await { let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code, Ok(auth_code) => auth_code,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -342,7 +327,7 @@ impl Server {
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html // See: https://www.postgresql.org/docs/12/protocol-message-formats.html
let mut salt = vec![0u8; 4]; let mut salt = vec![0u8; 4];
match read.read_exact(&mut salt).await { match stream.read_exact(&mut salt).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -355,7 +340,7 @@ impl Server {
match password { match password {
// Using plaintext password // Using plaintext password
Some(password) => { Some(password) => {
md5_password(&mut write, username, password, &salt[..]).await? md5_password(&mut stream, username, password, &salt[..]).await?
} }
// Using auth passthrough, in this case we should already have a // Using auth passthrough, in this case we should already have a
@@ -366,7 +351,7 @@ impl Server {
match option_hash { match option_hash {
Some(hash) => Some(hash) =>
md5_password_with_hash( md5_password_with_hash(
&mut write, &mut stream,
&hash, &hash,
&salt[..], &salt[..],
) )
@@ -400,7 +385,7 @@ impl Server {
let sasl_len = (len - 8) as usize; let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len]; let mut sasl_auth = vec![0u8; sasl_len];
match read.read_exact(&mut sasl_auth).await { match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -435,7 +420,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32); res.put_i32(sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut write, res).await?; write_all_flush(&mut stream, &res).await?;
} else { } else {
error!("Unsupported SCRAM version: {}", sasl_type); error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError); return Err(Error::ServerError);
@@ -447,7 +432,7 @@ impl Server {
let mut sasl_data = vec![0u8; (len - 8) as usize]; let mut sasl_data = vec![0u8; (len - 8) as usize];
match read.read_exact(&mut sasl_data).await { match stream.read_exact(&mut sasl_data).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -466,14 +451,14 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32); res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut write, res).await?; write_all_flush(&mut stream, &res).await?;
} }
SASL_FINAL => { SASL_FINAL => {
trace!("Final SASL"); trace!("Final SASL");
let mut sasl_final = vec![0u8; len as usize - 8]; let mut sasl_final = vec![0u8; len as usize - 8];
match read.read_exact(&mut sasl_final).await { match stream.read_exact(&mut sasl_final).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -508,7 +493,7 @@ impl Server {
// ErrorResponse // ErrorResponse
'E' => { 'E' => {
let error_code = match read.read_u8().await { let error_code = match stream.read_u8().await {
Ok(error_code) => error_code, Ok(error_code) => error_code,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -529,7 +514,7 @@ impl Server {
// Read the error message without the terminating null character. // Read the error message without the terminating null character.
let mut error = vec![0u8; len as usize - 4 - 1]; let mut error = vec![0u8; len as usize - 4 - 1];
match read.read_exact(&mut error).await { match stream.read_exact(&mut error).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -553,7 +538,7 @@ impl Server {
'S' => { 'S' => {
let mut param = vec![0u8; len as usize - 4]; let mut param = vec![0u8; len as usize - 4];
match read.read_exact(&mut param).await { match stream.read_exact(&mut param).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -575,7 +560,7 @@ impl Server {
'K' => { 'K' => {
// The frontend must save these values if it wishes to be able to issue CancelRequest messages later. // The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>. // See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match read.read_i32().await { process_id = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -585,7 +570,7 @@ impl Server {
} }
}; };
secret_key = match read.read_i32().await { secret_key = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -600,7 +585,7 @@ impl Server {
'Z' => { 'Z' => {
let mut idle = vec![0u8; len as usize - 4]; let mut idle = vec![0u8; len as usize - 4];
match read.read_exact(&mut idle).await { match stream.read_exact(&mut idle).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -612,8 +597,7 @@ impl Server {
let mut server = Server { let mut server = Server {
address: address.clone(), address: address.clone(),
read: BufReader::new(read), stream: BufStream::new(stream),
write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info, server_info,
process_id, process_id,
@@ -680,7 +664,7 @@ impl Server {
bytes.put_i32(process_id); bytes.put_i32(process_id);
bytes.put_i32(secret_key); bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await write_all_flush(&mut stream, &bytes).await
} }
/// Send messages to the server from the client. /// Send messages to the server from the client.
@@ -688,7 +672,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await { match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -707,7 +691,7 @@ impl Server {
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> { pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop { loop {
let mut message = match read_message(&mut self.read).await { let mut message = match read_message(&mut self.stream).await {
Ok(message) => message, Ok(message) => message,
Err(err) => { Err(err) => {
error!("Terminating server because of: {:?}", err); error!("Terminating server because of: {:?}", err);
@@ -1100,14 +1084,14 @@ impl Drop for Server {
// Update statistics // Update statistics
self.stats.disconnect(); self.stats.disconnect();
// let mut bytes = BytesMut::with_capacity(4); let mut bytes = BytesMut::with_capacity(4);
// bytes.put_u8(b'X'); bytes.put_u8(b'X');
// bytes.put_i32(4); bytes.put_i32(4);
// match self.write.try_write(&bytes) { match self.stream.get_mut().try_write(&bytes) {
// Ok(_) => (), Ok(_) => (),
// Err(_) => debug!("Dirty shutdown"), Err(_) => debug!("Dirty shutdown"),
// }; };
// Should not matter. // Should not matter.
self.bad = true; self.bad = true;