Finish up TLS

This commit is contained in:
Lev Kokotov
2023-04-28 18:02:48 -07:00
parent 9e51b8110f
commit b36746a47b
9 changed files with 175 additions and 58 deletions

11
Cargo.lock generated
View File

@@ -766,6 +766,7 @@ dependencies = [
"postgres-protocol", "postgres-protocol",
"rand", "rand",
"regex", "regex",
"rustls",
"rustls-pemfile", "rustls-pemfile",
"serde", "serde",
"serde_derive", "serde_derive",
@@ -777,6 +778,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"toml", "toml",
"webpki-roots",
] ]
[[package]] [[package]]
@@ -1467,6 +1469,15 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "webpki-roots"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
dependencies = [
"rustls-webpki",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"

View File

@@ -39,7 +39,9 @@ nix = "0.26.2"
atomic_enum = "0.2.0" atomic_enum = "0.2.0"
postgres-protocol = "0.6.5" postgres-protocol = "0.6.5"
fallible-iterator = "0.2" fallible-iterator = "0.2"
pin-project = "*" pin-project = "1"
webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -61,9 +61,15 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5 tcp_keepalives_interval = 5
# Path to TLS Certificate file to use for TLS connections # Path to TLS Certificate file to use for TLS connections
# tls_certificate = "server.cert" tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key" tls_private_key = ".circleci/server.key"
# Enable/disable server TLS
server_tls = true
# Verify server certificate is completely authentic.
verify_server_certificate = false
# User name to access the virtual administrative database (pgbouncer or pgcat) # User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..

View File

@@ -281,6 +281,13 @@ pub struct General {
pub tls_certificate: Option<String>, pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>, pub tls_private_key: Option<String>,
#[serde(default)] // false
pub server_tls: bool,
#[serde(default)] // false
pub verify_server_certificate: bool,
pub admin_username: String, pub admin_username: String,
pub admin_password: String, pub admin_password: String,
@@ -373,6 +380,8 @@ impl Default for General {
autoreload: None, autoreload: None,
tls_certificate: None, tls_certificate: None,
tls_private_key: None, tls_private_key: None,
server_tls: false,
verify_server_certificate: false,
admin_username: String::from("admin"), admin_username: String::from("admin"),
admin_password: String::from("admin"), admin_password: String::from("admin"),
auth_query: None, auth_query: None,

View File

@@ -23,6 +23,7 @@ pub enum Error {
ParseBytesError(String), ParseBytesError(String),
AuthError(String), AuthError(String),
AuthPassthroughError(String), AuthPassthroughError(String),
TlsCertificateReadError(String),
} }
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]

View File

@@ -116,7 +116,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want. /// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25); let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number bytes.put_i32(196608); // Protocol number

View File

@@ -376,8 +376,7 @@ impl ConnectionPool {
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
.test_on_check_out(false) .test_on_check_out(false)
.build(manager) .build(manager)
.await .await?;
.unwrap();
pools.push(pool); pools.push(pool);
servers.push(address); servers.push(address);

View File

@@ -10,14 +10,11 @@ use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf};
use tokio::net::{ use tokio::net::TcpStream;
tcp::{OwnedReadHalf, OwnedWriteHalf}, use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
TcpStream, use tokio_rustls::{client::TlsStream, TlsConnector};
};
use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::{TlsConnector, TlsStream};
use crate::config::{Address, User}; use crate::config::{get_config, Address, User};
use crate::constants::*; use crate::constants::*;
use crate::errors::{Error, ServerIdentifier}; use crate::errors::{Error, ServerIdentifier};
use crate::messages::*; use crate::messages::*;
@@ -176,33 +173,97 @@ impl Server {
))); )));
} }
}; };
// TCP timeouts.
configure_socket(&stream); configure_socket(&stream);
// ssl_request(&mut stream).await?; let (mut read, mut write) = if get_config().general.server_tls {
// let response = match stream.read_u8().await { // Request a TLS connection
// Ok(response) => response as char, ssl_request(&mut stream).await?;
// Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))),
// };
// match response { let response = match stream.read_u8().await {
// 'S' => { Ok(response) => response as char,
// let connector = TlsConnector::from(ClientConfig::builder() Err(err) => {
// .with_safe_default_cipher_suites() return Err(Error::SocketError(format!(
// .with_safe_default_kx_groups() "Server socket error: {:?}",
// .with_safe_default_protocol_versions() err
// .unwrap() )))
// .with_no_client_auth()); }
// connector.connect("test".into(), stream).await.unwrap(); };
// },
// 'N' => { match response {
// Server supports TLS
'S' => {
let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
// }, let mut config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
// _ => { // Equivalent to sslmode=prefer which is fine most places.
// return Err(Error::SocketError("error".into())); // If you want verify-full, change `verify_server_certificate` to true.
// } if !get_config().general.verify_server_certificate {
// }; let mut dangerous = config.dangerous();
dangerous.set_certificate_verifier(Arc::new(
crate::tls::NoCertificateVerification {},
));
}
let connector = TlsConnector::from(Arc::new(config));
let stream = match connector
.connect(address.host.as_str().try_into().unwrap(), stream)
.await
{
Ok(stream) => stream,
Err(err) => {
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
}
};
let (read, write) = split(stream);
(
ReadInner::Tls { stream: read },
WriteInner::Tls { stream: write },
)
}
// Server does not support TLS
'N' => {
let (read, write) = split(stream);
(
ReadInner::Plain { stream: read },
WriteInner::Plain { stream: write },
)
}
// Something else?
m => {
return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
}
}
} else {
let (read, write) = split(stream);
(
ReadInner::Plain { stream: read },
WriteInner::Plain { stream: write },
)
};
// let (read, write) = split(stream);
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
@@ -220,7 +281,7 @@ impl Server {
}, },
}; };
startup(&mut stream, username, database).await?; startup(&mut write, username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
@@ -235,7 +296,7 @@ impl Server {
}; };
loop { loop {
let code = match stream.read_u8().await { let code = match read.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -245,7 +306,7 @@ impl Server {
} }
}; };
let len = match stream.read_i32().await { let len = match read.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -261,7 +322,7 @@ impl Server {
// Authentication // Authentication
'R' => { 'R' => {
// Determine which kind of authentication is required, if any. // Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await { let auth_code = match read.read_i32().await {
Ok(auth_code) => auth_code, Ok(auth_code) => auth_code,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -279,7 +340,7 @@ impl Server {
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html // See: https://www.postgresql.org/docs/12/protocol-message-formats.html
let mut salt = vec![0u8; 4]; let mut salt = vec![0u8; 4];
match stream.read_exact(&mut salt).await { match read.read_exact(&mut salt).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -292,7 +353,7 @@ impl Server {
match password { match password {
// Using plaintext password // Using plaintext password
Some(password) => { Some(password) => {
md5_password(&mut stream, username, password, &salt[..]).await? md5_password(&mut write, username, password, &salt[..]).await?
} }
// Using auth passthrough, in this case we should already have a // Using auth passthrough, in this case we should already have a
@@ -303,7 +364,7 @@ impl Server {
match option_hash { match option_hash {
Some(hash) => Some(hash) =>
md5_password_with_hash( md5_password_with_hash(
&mut stream, &mut write,
&hash, &hash,
&salt[..], &salt[..],
) )
@@ -337,7 +398,7 @@ impl Server {
let sasl_len = (len - 8) as usize; let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len]; let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await { match read.read_exact(&mut sasl_auth).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -349,7 +410,7 @@ impl Server {
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 { if sasl_type.contains(SCRAM_SHA_256) {
debug!("Using {}", SCRAM_SHA_256); debug!("Using {}", SCRAM_SHA_256);
// Generate client message. // Generate client message.
@@ -372,7 +433,7 @@ impl Server {
res.put_i32(sasl_response.len() as i32); res.put_i32(sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all(&mut write, res).await?;
} else { } else {
error!("Unsupported SCRAM version: {}", sasl_type); error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError); return Err(Error::ServerError);
@@ -384,7 +445,7 @@ impl Server {
let mut sasl_data = vec![0u8; (len - 8) as usize]; let mut sasl_data = vec![0u8; (len - 8) as usize];
match stream.read_exact(&mut sasl_data).await { match read.read_exact(&mut sasl_data).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -403,14 +464,14 @@ impl Server {
res.put_i32(4 + sasl_response.len() as i32); res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
write_all(&mut stream, res).await?; write_all(&mut write, res).await?;
} }
SASL_FINAL => { SASL_FINAL => {
trace!("Final SASL"); trace!("Final SASL");
let mut sasl_final = vec![0u8; len as usize - 8]; let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await { match read.read_exact(&mut sasl_final).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -445,7 +506,7 @@ impl Server {
// ErrorResponse // ErrorResponse
'E' => { 'E' => {
let error_code = match stream.read_u8().await { let error_code = match read.read_u8().await {
Ok(error_code) => error_code, Ok(error_code) => error_code,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -466,7 +527,7 @@ impl Server {
// Read the error message without the terminating null character. // Read the error message without the terminating null character.
let mut error = vec![0u8; len as usize - 4 - 1]; let mut error = vec![0u8; len as usize - 4 - 1];
match stream.read_exact(&mut error).await { match read.read_exact(&mut error).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -490,7 +551,7 @@ impl Server {
'S' => { 'S' => {
let mut param = vec![0u8; len as usize - 4]; let mut param = vec![0u8; len as usize - 4];
match stream.read_exact(&mut param).await { match read.read_exact(&mut param).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -512,7 +573,7 @@ impl Server {
'K' => { 'K' => {
// The frontend must save these values if it wishes to be able to issue CancelRequest messages later. // The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>. // See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await { process_id = match read.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -522,7 +583,7 @@ impl Server {
} }
}; };
secret_key = match stream.read_i32().await { secret_key = match read.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -537,7 +598,7 @@ impl Server {
'Z' => { 'Z' => {
let mut idle = vec![0u8; len as usize - 4]; let mut idle = vec![0u8; len as usize - 4];
match stream.read_exact(&mut idle).await { match read.read_exact(&mut idle).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -547,12 +608,10 @@ impl Server {
} }
}; };
let (read, write) = split(stream);
let mut server = Server { let mut server = Server {
address: address.clone(), address: address.clone(),
read: BufReader::new(ReadInner::Plain { stream: read }), read: BufReader::new(read),
write: WriteInner::Plain { stream: write }, write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info, server_info,
process_id, process_id,

View File

@@ -4,12 +4,23 @@ use rustls_pemfile::{certs, read_one, Item};
use std::iter; use std::iter;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio_rustls::rustls::{self, Certificate, PrivateKey}; use std::time::SystemTime;
use tokio_rustls::rustls::{
self,
client::{ServerCertVerified, ServerCertVerifier},
Certificate, PrivateKey, ServerName,
};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::config::get_config; use crate::config::get_config;
use crate::errors::Error; use crate::errors::Error;
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Error {
Error::TlsCertificateReadError(err.to_string())
}
}
// TLS // TLS
pub fn load_certs(path: &Path) -> std::io::Result<Vec<Certificate>> { pub fn load_certs(path: &Path) -> std::io::Result<Vec<Certificate>> {
certs(&mut std::io::BufReader::new(std::fs::File::open(path)?)) certs(&mut std::io::BufReader::new(std::fs::File::open(path)?))
@@ -64,3 +75,19 @@ impl Tls {
}) })
} }
} }
pub struct NoCertificateVerification;
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &Certificate,
_intermediates: &[Certificate],
_server_name: &ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: SystemTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}