From b36746a47bc8b48c71c58aeffb01dd3516b77f5c Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Fri, 28 Apr 2023 18:02:48 -0700 Subject: [PATCH] Finish up TLS --- Cargo.lock | 11 ++++ Cargo.toml | 4 +- pgcat.toml | 10 ++- src/config.rs | 9 +++ src/errors.rs | 1 + src/messages.rs | 5 +- src/pool.rs | 3 +- src/server.rs | 161 +++++++++++++++++++++++++++++++++--------------- src/tls.rs | 29 ++++++++- 9 files changed, 175 insertions(+), 58 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 610a13d..7991667 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -766,6 +766,7 @@ dependencies = [ "postgres-protocol", "rand", "regex", + "rustls", "rustls-pemfile", "serde", "serde_derive", @@ -777,6 +778,7 @@ dependencies = [ "tokio", "tokio-rustls", "toml", + "webpki-roots", ] [[package]] @@ -1467,6 +1469,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125" +dependencies = [ + "rustls-webpki", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index c1f8f34..28e94a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,9 @@ nix = "0.26.2" atomic_enum = "0.2.0" postgres-protocol = "0.6.5" fallible-iterator = "0.2" -pin-project = "*" +pin-project = "1" +webpki-roots = "0.23" +rustls = { version = "0.21", features = ["dangerous_configuration"] } [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/pgcat.toml b/pgcat.toml index 9203cb6..52cd100 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -61,9 +61,15 @@ tcp_keepalives_count = 5 tcp_keepalives_interval = 5 # Path to TLS Certificate file to use for TLS connections -# tls_certificate = "server.cert" +tls_certificate = ".circleci/server.cert" # Path to TLS private key file to use for TLS connections -# tls_private_key = "server.key" +tls_private_key = ".circleci/server.key" + +# Enable/disable server TLS +server_tls = true + +# Verify server certificate is completely authentic. +verify_server_certificate = false # User name to access the virtual administrative database (pgbouncer or pgcat) # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. diff --git a/src/config.rs b/src/config.rs index d822486..061a758 100644 --- a/src/config.rs +++ b/src/config.rs @@ -281,6 +281,13 @@ pub struct General { pub tls_certificate: Option, pub tls_private_key: Option, + + #[serde(default)] // false + pub server_tls: bool, + + #[serde(default)] // false + pub verify_server_certificate: bool, + pub admin_username: String, pub admin_password: String, @@ -373,6 +380,8 @@ impl Default for General { autoreload: None, tls_certificate: None, tls_private_key: None, + server_tls: false, + verify_server_certificate: false, admin_username: String::from("admin"), admin_password: String::from("admin"), auth_query: None, diff --git a/src/errors.rs b/src/errors.rs index 0930ab8..4868f9e 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -23,6 +23,7 @@ pub enum Error { ParseBytesError(String), AuthError(String), AuthPassthroughError(String), + TlsCertificateReadError(String), } #[derive(Clone, PartialEq, Debug)] diff --git a/src/messages.rs b/src/messages.rs index 58d9b26..58b785b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -116,7 +116,10 @@ where /// Send the startup packet the server. We're pretending we're a Pg client. /// This tells the server which user we are and what database we want. -pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { +pub async fn startup(stream: &mut S, user: &str, database: &str) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ let mut bytes = BytesMut::with_capacity(25); bytes.put_i32(196608); // Protocol number diff --git a/src/pool.rs b/src/pool.rs index 8ec8860..ee8de44 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -376,8 +376,7 @@ impl ConnectionPool { .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime))) .test_on_check_out(false) .build(manager) - .await - .unwrap(); + .await?; pools.push(pool); servers.push(address); diff --git a/src/server.rs b/src/server.rs index 722108b..a04dc5f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,14 +10,11 @@ use std::io::Read; use std::sync::Arc; use std::time::SystemTime; use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf}; -use tokio::net::{ - tcp::{OwnedReadHalf, OwnedWriteHalf}, - TcpStream, -}; -use tokio_rustls::rustls::ClientConfig; -use tokio_rustls::{TlsConnector, TlsStream}; +use tokio::net::TcpStream; +use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::{client::TlsStream, TlsConnector}; -use crate::config::{Address, User}; +use crate::config::{get_config, Address, User}; use crate::constants::*; use crate::errors::{Error, ServerIdentifier}; use crate::messages::*; @@ -176,33 +173,97 @@ impl Server { ))); } }; + + // TCP timeouts. configure_socket(&stream); - // ssl_request(&mut stream).await?; - // let response = match stream.read_u8().await { - // Ok(response) => response as char, - // Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))), - // }; + let (mut read, mut write) = if get_config().general.server_tls { + // Request a TLS connection + ssl_request(&mut stream).await?; - // match response { - // 'S' => { - // let connector = TlsConnector::from(ClientConfig::builder() - // .with_safe_default_cipher_suites() - // .with_safe_default_kx_groups() - // .with_safe_default_protocol_versions() - // .unwrap() - // .with_no_client_auth()); - // connector.connect("test".into(), stream).await.unwrap(); - // }, + let response = match stream.read_u8().await { + Ok(response) => response as char, + Err(err) => { + return Err(Error::SocketError(format!( + "Server socket error: {:?}", + err + ))) + } + }; - // 'N' => { + match response { + // Server supports TLS + 'S' => { + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors( + webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }), + ); - // }, + let mut config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); - // _ => { - // return Err(Error::SocketError("error".into())); - // } - // }; + // Equivalent to sslmode=prefer which is fine most places. + // If you want verify-full, change `verify_server_certificate` to true. + if !get_config().general.verify_server_certificate { + let mut dangerous = config.dangerous(); + dangerous.set_certificate_verifier(Arc::new( + crate::tls::NoCertificateVerification {}, + )); + } + + let connector = TlsConnector::from(Arc::new(config)); + let stream = match connector + .connect(address.host.as_str().try_into().unwrap(), stream) + .await + { + Ok(stream) => stream, + Err(err) => { + return Err(Error::SocketError(format!("Server TLS error: {:?}", err))) + } + }; + + let (read, write) = split(stream); + ( + ReadInner::Tls { stream: read }, + WriteInner::Tls { stream: write }, + ) + } + + // Server does not support TLS + 'N' => { + let (read, write) = split(stream); + ( + ReadInner::Plain { stream: read }, + WriteInner::Plain { stream: write }, + ) + } + + // Something else? + m => { + return Err(Error::SocketError(format!( + "Unknown message: {}", + m as char + ))); + } + } + } else { + let (read, write) = split(stream); + ( + ReadInner::Plain { stream: read }, + WriteInner::Plain { stream: write }, + ) + }; + + // let (read, write) = split(stream); + // let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write }); trace!("Sending StartupMessage"); @@ -220,7 +281,7 @@ impl Server { }, }; - startup(&mut stream, username, database).await?; + startup(&mut write, username, database).await?; let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; @@ -235,7 +296,7 @@ impl Server { }; loop { - let code = match stream.read_u8().await { + let code = match read.read_u8().await { Ok(code) => code as char, Err(_) => { return Err(Error::ServerStartupError( @@ -245,7 +306,7 @@ impl Server { } }; - let len = match stream.read_i32().await { + let len = match read.read_i32().await { Ok(len) => len, Err(_) => { return Err(Error::ServerStartupError( @@ -261,7 +322,7 @@ impl Server { // Authentication 'R' => { // Determine which kind of authentication is required, if any. - let auth_code = match stream.read_i32().await { + let auth_code = match read.read_i32().await { Ok(auth_code) => auth_code, Err(_) => { return Err(Error::ServerStartupError( @@ -279,7 +340,7 @@ impl Server { // See: https://www.postgresql.org/docs/12/protocol-message-formats.html let mut salt = vec![0u8; 4]; - match stream.read_exact(&mut salt).await { + match read.read_exact(&mut salt).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -292,7 +353,7 @@ impl Server { match password { // Using plaintext password Some(password) => { - md5_password(&mut stream, username, password, &salt[..]).await? + md5_password(&mut write, username, password, &salt[..]).await? } // Using auth passthrough, in this case we should already have a @@ -303,7 +364,7 @@ impl Server { match option_hash { Some(hash) => md5_password_with_hash( - &mut stream, + &mut write, &hash, &salt[..], ) @@ -337,7 +398,7 @@ impl Server { let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; - match stream.read_exact(&mut sasl_auth).await { + match read.read_exact(&mut sasl_auth).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -349,7 +410,7 @@ impl Server { let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); - if sasl_type == SCRAM_SHA_256 { + if sasl_type.contains(SCRAM_SHA_256) { debug!("Using {}", SCRAM_SHA_256); // Generate client message. @@ -372,7 +433,7 @@ impl Server { res.put_i32(sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut stream, res).await?; + write_all(&mut write, res).await?; } else { error!("Unsupported SCRAM version: {}", sasl_type); return Err(Error::ServerError); @@ -384,7 +445,7 @@ impl Server { let mut sasl_data = vec![0u8; (len - 8) as usize]; - match stream.read_exact(&mut sasl_data).await { + match read.read_exact(&mut sasl_data).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -403,14 +464,14 @@ impl Server { res.put_i32(4 + sasl_response.len() as i32); res.put(sasl_response); - write_all(&mut stream, res).await?; + write_all(&mut write, res).await?; } SASL_FINAL => { trace!("Final SASL"); let mut sasl_final = vec![0u8; len as usize - 8]; - match stream.read_exact(&mut sasl_final).await { + match read.read_exact(&mut sasl_final).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -445,7 +506,7 @@ impl Server { // ErrorResponse 'E' => { - let error_code = match stream.read_u8().await { + let error_code = match read.read_u8().await { Ok(error_code) => error_code, Err(_) => { return Err(Error::ServerStartupError( @@ -466,7 +527,7 @@ impl Server { // Read the error message without the terminating null character. let mut error = vec![0u8; len as usize - 4 - 1]; - match stream.read_exact(&mut error).await { + match read.read_exact(&mut error).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -490,7 +551,7 @@ impl Server { 'S' => { let mut param = vec![0u8; len as usize - 4]; - match stream.read_exact(&mut param).await { + match read.read_exact(&mut param).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -512,7 +573,7 @@ impl Server { 'K' => { // The frontend must save these values if it wishes to be able to issue CancelRequest messages later. // See: . - process_id = match stream.read_i32().await { + process_id = match read.read_i32().await { Ok(id) => id, Err(_) => { return Err(Error::ServerStartupError( @@ -522,7 +583,7 @@ impl Server { } }; - secret_key = match stream.read_i32().await { + secret_key = match read.read_i32().await { Ok(id) => id, Err(_) => { return Err(Error::ServerStartupError( @@ -537,7 +598,7 @@ impl Server { 'Z' => { let mut idle = vec![0u8; len as usize - 4]; - match stream.read_exact(&mut idle).await { + match read.read_exact(&mut idle).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -547,12 +608,10 @@ impl Server { } }; - let (read, write) = split(stream); - let mut server = Server { address: address.clone(), - read: BufReader::new(ReadInner::Plain { stream: read }), - write: WriteInner::Plain { stream: write }, + read: BufReader::new(read), + write, buffer: BytesMut::with_capacity(8196), server_info, process_id, diff --git a/src/tls.rs b/src/tls.rs index fbfbae7..019b145 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -4,12 +4,23 @@ use rustls_pemfile::{certs, read_one, Item}; use std::iter; use std::path::Path; use std::sync::Arc; -use tokio_rustls::rustls::{self, Certificate, PrivateKey}; +use std::time::SystemTime; +use tokio_rustls::rustls::{ + self, + client::{ServerCertVerified, ServerCertVerifier}, + Certificate, PrivateKey, ServerName, +}; use tokio_rustls::TlsAcceptor; use crate::config::get_config; use crate::errors::Error; +impl From for Error { + fn from(err: std::io::Error) -> Error { + Error::TlsCertificateReadError(err.to_string()) + } +} + // TLS pub fn load_certs(path: &Path) -> std::io::Result> { certs(&mut std::io::BufReader::new(std::fs::File::open(path)?)) @@ -64,3 +75,19 @@ impl Tls { }) } } + +pub struct NoCertificateVerification; + +impl ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +}