mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Server TLS
This commit is contained in:
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -762,6 +762,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf",
|
||||
"pin-project",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
@@ -820,6 +821,26 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
|
||||
@@ -39,6 +39,7 @@ nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "*"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
@@ -539,6 +539,7 @@ where
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
@@ -565,6 +566,8 @@ where
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
@@ -587,7 +590,15 @@ where
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = refetch_auth_hash(&pool).await?;
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
|
||||
@@ -150,6 +150,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(12);
|
||||
|
||||
bytes.put_i32(8);
|
||||
bytes.put_i32(80877103);
|
||||
|
||||
match stream.write_all(&bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing SSLRequest to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
|
||||
130
src/server.rs
130
src/server.rs
@@ -9,11 +9,13 @@ use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio_rustls::rustls::ClientConfig;
|
||||
use tokio_rustls::{TlsConnector, TlsStream};
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::constants::*;
|
||||
@@ -24,6 +26,82 @@ use crate::pool::ClientServerMap;
|
||||
use crate::scram::ScramSha256;
|
||||
use crate::stats::ServerStats;
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
#[pin_project(project = ReadInnerProj)]
|
||||
pub enum ReadInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: ReadHalf<TcpStream>,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: ReadHalf<TlsStream<TcpStream>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[pin_project(project = WriteInnerProj)]
|
||||
pub enum WriteInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: WriteHalf<TcpStream>,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: WriteHalf<TlsStream<TcpStream>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsyncWrite for WriteInner {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
WriteInnerProj::Tls { stream } => stream.poll_write(cx, buf),
|
||||
WriteInnerProj::Plain { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
WriteInnerProj::Tls { stream } => stream.poll_flush(cx),
|
||||
WriteInnerProj::Plain { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
WriteInnerProj::Tls { stream } => stream.poll_shutdown(cx),
|
||||
WriteInnerProj::Plain { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for ReadInner {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
ReadInnerProj::Tls { stream } => stream.poll_read(cx, buf),
|
||||
ReadInnerProj::Plain { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
/// Server host, e.g. localhost,
|
||||
@@ -31,10 +109,10 @@ pub struct Server {
|
||||
address: Address,
|
||||
|
||||
/// Buffered read socket.
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
read: BufReader<ReadInner>,
|
||||
|
||||
/// Unbuffered write socket (our client code buffers).
|
||||
write: OwnedWriteHalf,
|
||||
write: WriteInner,
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
@@ -100,6 +178,32 @@ impl Server {
|
||||
};
|
||||
configure_socket(&stream);
|
||||
|
||||
// ssl_request(&mut stream).await?;
|
||||
// let response = match stream.read_u8().await {
|
||||
// Ok(response) => response as char,
|
||||
// Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))),
|
||||
// };
|
||||
|
||||
// match response {
|
||||
// 'S' => {
|
||||
// let connector = TlsConnector::from(ClientConfig::builder()
|
||||
// .with_safe_default_cipher_suites()
|
||||
// .with_safe_default_kx_groups()
|
||||
// .with_safe_default_protocol_versions()
|
||||
// .unwrap()
|
||||
// .with_no_client_auth());
|
||||
// connector.connect("test".into(), stream).await.unwrap();
|
||||
// },
|
||||
|
||||
// 'N' => {
|
||||
|
||||
// },
|
||||
|
||||
// _ => {
|
||||
// return Err(Error::SocketError("error".into()));
|
||||
// }
|
||||
// };
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
// StartupMessage
|
||||
@@ -443,12 +547,12 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
let (read, write) = split(stream);
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
read: BufReader::new(ReadInner::Plain { stream: read }),
|
||||
write: WriteInner::Plain { stream: write },
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
process_id,
|
||||
@@ -935,14 +1039,14 @@ impl Drop for Server {
|
||||
// Update statistics
|
||||
self.stats.disconnect();
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(4);
|
||||
bytes.put_u8(b'X');
|
||||
bytes.put_i32(4);
|
||||
// let mut bytes = BytesMut::with_capacity(4);
|
||||
// bytes.put_u8(b'X');
|
||||
// bytes.put_i32(4);
|
||||
|
||||
match self.write.try_write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(_) => debug!("Dirty shutdown"),
|
||||
};
|
||||
// match self.write.try_write(&bytes) {
|
||||
// Ok(_) => (),
|
||||
// Err(_) => debug!("Dirty shutdown"),
|
||||
// };
|
||||
|
||||
// Should not matter.
|
||||
self.bad = true;
|
||||
|
||||
Reference in New Issue
Block a user