Compare commits

...

10 Commits

Author SHA1 Message Date
Lev Kokotov
ee23b374ae fetch config once 2023-04-30 09:19:39 -07:00
Lev Kokotov
9dffebccbf remove unused error 2023-04-30 09:07:15 -07:00
Lev Kokotov
4c8358b8b3 skip flakey test 2023-04-30 09:03:32 -07:00
Lev Kokotov
f0d1916a98 dirty shutdown 2023-04-30 08:23:30 -07:00
Lev Kokotov
bba5f10be1 maybe? 2023-04-29 08:38:27 -07:00
Lev Kokotov
a514dbc187 remove dead code 2023-04-28 18:08:20 -07:00
Lev Kokotov
d660e3e565 diff 2023-04-28 18:06:11 -07:00
Lev Kokotov
0d882cc204 thats it 2023-04-28 18:05:28 -07:00
Lev Kokotov
b36746a47b Finish up TLS 2023-04-28 18:02:48 -07:00
Lev Kokotov
9e51b8110f Server TLS 2023-04-28 11:20:49 -07:00
10 changed files with 311 additions and 33 deletions

32
Cargo.lock generated
View File

@@ -762,9 +762,11 @@ dependencies = [
"once_cell", "once_cell",
"parking_lot", "parking_lot",
"phf", "phf",
"pin-project",
"postgres-protocol", "postgres-protocol",
"rand", "rand",
"regex", "regex",
"rustls",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"serde_derive", "serde_derive",
@@ -776,6 +778,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"toml", "toml",
"webpki-roots",
] ]
[[package]] [[package]]
@@ -820,6 +823,26 @@ dependencies = [
"siphasher", "siphasher",
] ]
[[package]]
name = "pin-project"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.9" version = "0.2.9"
@@ -1446,6 +1469,15 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "webpki-roots"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
dependencies = [
"rustls-webpki",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"

View File

@@ -39,6 +39,9 @@ nix = "0.26.2"
atomic_enum = "0.2.0" atomic_enum = "0.2.0"
postgres-protocol = "0.6.5" postgres-protocol = "0.6.5"
fallible-iterator = "0.2" fallible-iterator = "0.2"
pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -61,9 +61,15 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5 tcp_keepalives_interval = 5
# Path to TLS Certificate file to use for TLS connections # Path to TLS Certificate file to use for TLS connections
# tls_certificate = "server.cert" # tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key" # tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = false
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat) # User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..

View File

@@ -539,6 +539,7 @@ where
Some(md5_hash_password(username, password, &salt)) Some(md5_hash_password(username, password, &salt))
} else { } else {
if !get_config().is_auth_query_configured() { if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into())); return Err(Error::ClientAuthImpossible(username.into()));
} }
@@ -565,6 +566,8 @@ where
} }
Err(err) => { Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthPassthroughError( return Err(Error::ClientAuthPassthroughError(
err.to_string(), err.to_string(),
client_identifier, client_identifier,
@@ -587,7 +590,15 @@ where
client_identifier client_identifier
); );
let fetched_hash = refetch_auth_hash(&pool).await?; let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible. // Ok password changed in server an auth is possible.

View File

@@ -281,6 +281,13 @@ pub struct General {
pub tls_certificate: Option<String>, pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>, pub tls_private_key: Option<String>,
#[serde(default)] // false
pub server_tls: bool,
#[serde(default)] // false
pub verify_server_certificate: bool,
pub admin_username: String, pub admin_username: String,
pub admin_password: String, pub admin_password: String,
@@ -373,6 +380,8 @@ impl Default for General {
autoreload: None, autoreload: None,
tls_certificate: None, tls_certificate: None,
tls_private_key: None, tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"), admin_username: String::from("admin"),
admin_password: String::from("admin"), admin_password: String::from("admin"),
auth_query: None, auth_query: None,
@@ -852,6 +861,11 @@ impl Config {
info!("TLS support is disabled"); info!("TLS support is disabled");
} }
}; };
info!("Server TLS enabled: {}", self.general.server_tls);
info!(
"Server TLS certificate verification: {}",
self.general.verify_server_certificate
);
for (pool_name, pool_config) in &self.pools { for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?) // TODO: Make this output prettier (maybe a table?)

View File

@@ -116,7 +116,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want. /// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25); let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number bytes.put_i32(196608); // Protocol number
@@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
} }
} }
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format. /// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> { pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new(); let mut result = HashMap::new();
@@ -505,6 +523,29 @@ where
} }
} }
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket. /// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error> pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where where

View File

@@ -376,8 +376,7 @@ impl ConnectionPool {
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.test_on_check_out(false) .test_on_check_out(false)
.build(manager) .build(manager)
.await .await?;
.unwrap();
pools.push(pool); pools.push(pool);
servers.push(address); servers.push(address);

View File

@@ -9,13 +9,12 @@ use std::collections::HashMap;
use std::io::Read; use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::{ use tokio::net::TcpStream;
tcp::{OwnedReadHalf, OwnedWriteHalf}, use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
TcpStream, use tokio_rustls::{client::TlsStream, TlsConnector};
};
use crate::config::{Address, User}; use crate::config::{get_config, Address, User};
use crate::constants::*; use crate::constants::*;
use crate::errors::{Error, ServerIdentifier}; use crate::errors::{Error, ServerIdentifier};
use crate::messages::*; use crate::messages::*;
@@ -23,6 +22,84 @@ use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap; use crate::pool::ClientServerMap;
use crate::scram::ScramSha256; use crate::scram::ScramSha256;
use crate::stats::ServerStats; use crate::stats::ServerStats;
use std::io::Write;
use pin_project::pin_project;
#[pin_project(project = SteamInnerProj)]
pub enum StreamInner {
Plain {
#[pin]
stream: TcpStream,
},
Tls {
#[pin]
stream: TlsStream<TcpStream>,
},
}
impl AsyncWrite for StreamInner {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
}
}
}
impl AsyncRead for StreamInner {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
match this {
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
}
}
}
impl StreamInner {
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
StreamInner::Tls { stream } => {
let r = stream.get_mut();
let mut w = r.1.writer();
w.write(buf)
}
StreamInner::Plain { stream } => stream.try_write(buf),
}
}
}
/// Server state. /// Server state.
pub struct Server { pub struct Server {
@@ -30,11 +107,8 @@ pub struct Server {
/// port, e.g. 5432, and role, e.g. primary or replica. /// port, e.g. 5432, and role, e.g. primary or replica.
address: Address, address: Address,
/// Buffered read socket. /// Server TCP connection.
read: BufReader<OwnedReadHalf>, stream: BufStream<StreamInner>,
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
/// Our server response buffer. We buffer data before we give it to the client. /// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut, buffer: BytesMut,
@@ -98,8 +172,88 @@ impl Server {
))); )));
} }
}; };
// TCP timeouts.
configure_socket(&stream); configure_socket(&stream);
let config = get_config();
let mut stream = if config.general.server_tls {
// Request a TLS connection
ssl_request(&mut stream).await?;
let response = match stream.read_u8().await {
Ok(response) => response as char,
Err(err) => {
return Err(Error::SocketError(format!(
"Server socket error: {:?}",
err
)))
}
};
match response {
// Server supports TLS
'S' => {
debug!("Connecting to server using TLS");
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// Equivalent to sslmode=prefer which is fine most places.
// If you want verify-full, change `verify_server_certificate` to true.
if !config.general.verify_server_certificate {
let mut dangerous = tls_config.dangerous();
dangerous.set_certificate_verifier(Arc::new(
crate::tls::NoCertificateVerification {},
));
}
let connector = TlsConnector::from(Arc::new(tls_config));
let stream = match connector
.connect(address.host.as_str().try_into().unwrap(), stream)
.await
{
Ok(stream) => stream,
Err(err) => {
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
}
};
StreamInner::Tls { stream }
}
// Server does not support TLS
'N' => StreamInner::Plain { stream },
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
StreamInner::Plain { stream }
};
// let (read, write) = split(stream);
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
// StartupMessage // StartupMessage
@@ -245,7 +399,7 @@ impl Server {
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 { if sasl_type.contains(SCRAM_SHA_256) {
debug!("Using {}", SCRAM_SHA_256); debug!("Using {}", SCRAM_SHA_256);
// Generate client message. // Generate client message.
@@ -268,7 +422,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32); res.put_i32(sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all_flush(&mut stream, &res).await?;
} else { } else {
error!("Unsupported SCRAM version: {}", sasl_type); error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError); return Err(Error::ServerError);
@@ -299,7 +453,7 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32); res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all_flush(&mut stream, &res).await?;
} }
SASL_FINAL => { SASL_FINAL => {
@@ -443,12 +597,9 @@ impl Server {
} }
}; };
let (read, write) = stream.into_split();
let mut server = Server { let mut server = Server {
address: address.clone(), address: address.clone(),
read: BufReader::new(read), stream: BufStream::new(stream),
write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info, server_info,
process_id, process_id,
@@ -515,7 +666,7 @@ impl Server {
bytes.put_i32(process_id); bytes.put_i32(process_id);
bytes.put_i32(secret_key); bytes.put_i32(secret_key);
write_all(&mut stream, bytes).await write_all_flush(&mut stream, &bytes).await
} }
/// Send messages to the server from the client. /// Send messages to the server from the client.
@@ -523,7 +674,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_half(&mut self.write, messages).await { match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -542,7 +693,7 @@ impl Server {
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> { pub async fn recv(&mut self) -> Result<BytesMut, Error> {
loop { loop {
let mut message = match read_message(&mut self.read).await { let mut message = match read_message(&mut self.stream).await {
Ok(message) => message, Ok(message) => message,
Err(err) => { Err(err) => {
error!("Terminating server because of: {:?}", err); error!("Terminating server because of: {:?}", err);
@@ -935,13 +1086,13 @@ impl Drop for Server {
// Update statistics // Update statistics
self.stats.disconnect(); self.stats.disconnect();
let mut bytes = BytesMut::with_capacity(4); let mut bytes = BytesMut::with_capacity(5);
bytes.put_u8(b'X'); bytes.put_u8(b'X');
bytes.put_i32(4); bytes.put_i32(4);
match self.write.try_write(&bytes) { match self.stream.get_mut().try_write(&bytes) {
Ok(_) => (), Ok(5) => (),
Err(_) => debug!("Dirty shutdown"), _ => debug!("Dirty shutdown"),
}; };
// Should not matter. // Should not matter.

View File

@@ -4,7 +4,12 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter; use std::iter;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::config::get_config; use crate::config::get_config;
@@ -64,3 +69,19 @@ impl Tls {
}) })
} }
} }
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}

View File

@@ -25,7 +25,7 @@ describe "Query Mirroing" do
processes.pgcat.shutdown processes.pgcat.shutdown
end end
it "can mirror a query" do xit "can mirror a query" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
runs = 15 runs = 15
runs.times { conn.async_exec("SELECT 1 + 2") } runs.times { conn.async_exec("SELECT 1 + 2") }