mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Finish up TLS
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -766,6 +766,7 @@ dependencies = [
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@@ -777,6 +778,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"toml",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1467,6 +1469,15 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
|
||||
dependencies = [
|
||||
"rustls-webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
||||
@@ -39,7 +39,9 @@ nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "*"
|
||||
pin-project = "1"
|
||||
webpki-roots = "0.23"
|
||||
rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
10
pgcat.toml
10
pgcat.toml
@@ -61,9 +61,15 @@ tcp_keepalives_count = 5
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = "server.cert"
|
||||
tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = "server.key"
|
||||
tls_private_key = ".circleci/server.key"
|
||||
|
||||
# Enable/disable server TLS
|
||||
server_tls = true
|
||||
|
||||
# Verify server certificate is completely authentic.
|
||||
verify_server_certificate = false
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
|
||||
@@ -281,6 +281,13 @@ pub struct General {
|
||||
|
||||
pub tls_certificate: Option<String>,
|
||||
pub tls_private_key: Option<String>,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub server_tls: bool,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub verify_server_certificate: bool,
|
||||
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
@@ -373,6 +380,8 @@ impl Default for General {
|
||||
autoreload: None,
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
auth_query: None,
|
||||
|
||||
@@ -23,6 +23,7 @@ pub enum Error {
|
||||
ParseBytesError(String),
|
||||
AuthError(String),
|
||||
AuthPassthroughError(String),
|
||||
TlsCertificateReadError(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
|
||||
@@ -116,7 +116,10 @@ where
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
/// This tells the server which user we are and what database we want.
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
|
||||
@@ -376,8 +376,7 @@ impl ConnectionPool {
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
|
||||
161
src/server.rs
161
src/server.rs
@@ -10,14 +10,11 @@ use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{split, AsyncRead, AsyncReadExt, AsyncWrite, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use tokio_rustls::rustls::ClientConfig;
|
||||
use tokio_rustls::{TlsConnector, TlsStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{Address, User};
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::*;
|
||||
@@ -176,33 +173,97 @@ impl Server {
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// TCP timeouts.
|
||||
configure_socket(&stream);
|
||||
|
||||
// ssl_request(&mut stream).await?;
|
||||
// let response = match stream.read_u8().await {
|
||||
// Ok(response) => response as char,
|
||||
// Err(err) => return Err(Error::SocketError(format!("Server socket error: {:?}", err))),
|
||||
// };
|
||||
let (mut read, mut write) = if get_config().general.server_tls {
|
||||
// Request a TLS connection
|
||||
ssl_request(&mut stream).await?;
|
||||
|
||||
// match response {
|
||||
// 'S' => {
|
||||
// let connector = TlsConnector::from(ClientConfig::builder()
|
||||
// .with_safe_default_cipher_suites()
|
||||
// .with_safe_default_kx_groups()
|
||||
// .with_safe_default_protocol_versions()
|
||||
// .unwrap()
|
||||
// .with_no_client_auth());
|
||||
// connector.connect("test".into(), stream).await.unwrap();
|
||||
// },
|
||||
let response = match stream.read_u8().await {
|
||||
Ok(response) => response as char,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Server socket error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// 'N' => {
|
||||
match response {
|
||||
// Server supports TLS
|
||||
'S' => {
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
// },
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
// _ => {
|
||||
// return Err(Error::SocketError("error".into()));
|
||||
// }
|
||||
// };
|
||||
// Equivalent to sslmode=prefer which is fine most places.
|
||||
// If you want verify-full, change `verify_server_certificate` to true.
|
||||
if !get_config().general.verify_server_certificate {
|
||||
let mut dangerous = config.dangerous();
|
||||
dangerous.set_certificate_verifier(Arc::new(
|
||||
crate::tls::NoCertificateVerification {},
|
||||
));
|
||||
}
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
let stream = match connector
|
||||
.connect(address.host.as_str().try_into().unwrap(), stream)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = split(stream);
|
||||
(
|
||||
ReadInner::Tls { stream: read },
|
||||
WriteInner::Tls { stream: write },
|
||||
)
|
||||
}
|
||||
|
||||
// Server does not support TLS
|
||||
'N' => {
|
||||
let (read, write) = split(stream);
|
||||
(
|
||||
ReadInner::Plain { stream: read },
|
||||
WriteInner::Plain { stream: write },
|
||||
)
|
||||
}
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let (read, write) = split(stream);
|
||||
(
|
||||
ReadInner::Plain { stream: read },
|
||||
WriteInner::Plain { stream: write },
|
||||
)
|
||||
};
|
||||
|
||||
// let (read, write) = split(stream);
|
||||
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
@@ -220,7 +281,7 @@ impl Server {
|
||||
},
|
||||
};
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
startup(&mut write, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
@@ -235,7 +296,7 @@ impl Server {
|
||||
};
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
let code = match read.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -245,7 +306,7 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -261,7 +322,7 @@ impl Server {
|
||||
// Authentication
|
||||
'R' => {
|
||||
// Determine which kind of authentication is required, if any.
|
||||
let auth_code = match stream.read_i32().await {
|
||||
let auth_code = match read.read_i32().await {
|
||||
Ok(auth_code) => auth_code,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -279,7 +340,7 @@ impl Server {
|
||||
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html
|
||||
let mut salt = vec![0u8; 4];
|
||||
|
||||
match stream.read_exact(&mut salt).await {
|
||||
match read.read_exact(&mut salt).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -292,7 +353,7 @@ impl Server {
|
||||
match password {
|
||||
// Using plaintext password
|
||||
Some(password) => {
|
||||
md5_password(&mut stream, username, password, &salt[..]).await?
|
||||
md5_password(&mut write, username, password, &salt[..]).await?
|
||||
}
|
||||
|
||||
// Using auth passthrough, in this case we should already have a
|
||||
@@ -303,7 +364,7 @@ impl Server {
|
||||
match option_hash {
|
||||
Some(hash) =>
|
||||
md5_password_with_hash(
|
||||
&mut stream,
|
||||
&mut write,
|
||||
&hash,
|
||||
&salt[..],
|
||||
)
|
||||
@@ -337,7 +398,7 @@ impl Server {
|
||||
let sasl_len = (len - 8) as usize;
|
||||
let mut sasl_auth = vec![0u8; sasl_len];
|
||||
|
||||
match stream.read_exact(&mut sasl_auth).await {
|
||||
match read.read_exact(&mut sasl_auth).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -349,7 +410,7 @@ impl Server {
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
if sasl_type.contains(SCRAM_SHA_256) {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Generate client message.
|
||||
@@ -372,7 +433,7 @@ impl Server {
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all(&mut write, res).await?;
|
||||
} else {
|
||||
error!("Unsupported SCRAM version: {}", sasl_type);
|
||||
return Err(Error::ServerError);
|
||||
@@ -384,7 +445,7 @@ impl Server {
|
||||
|
||||
let mut sasl_data = vec![0u8; (len - 8) as usize];
|
||||
|
||||
match stream.read_exact(&mut sasl_data).await {
|
||||
match read.read_exact(&mut sasl_data).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -403,14 +464,14 @@ impl Server {
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all(&mut stream, res).await?;
|
||||
write_all(&mut write, res).await?;
|
||||
}
|
||||
|
||||
SASL_FINAL => {
|
||||
trace!("Final SASL");
|
||||
|
||||
let mut sasl_final = vec![0u8; len as usize - 8];
|
||||
match stream.read_exact(&mut sasl_final).await {
|
||||
match read.read_exact(&mut sasl_final).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -445,7 +506,7 @@ impl Server {
|
||||
|
||||
// ErrorResponse
|
||||
'E' => {
|
||||
let error_code = match stream.read_u8().await {
|
||||
let error_code = match read.read_u8().await {
|
||||
Ok(error_code) => error_code,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -466,7 +527,7 @@ impl Server {
|
||||
// Read the error message without the terminating null character.
|
||||
let mut error = vec![0u8; len as usize - 4 - 1];
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
match read.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -490,7 +551,7 @@ impl Server {
|
||||
'S' => {
|
||||
let mut param = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
match read.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -512,7 +573,7 @@ impl Server {
|
||||
'K' => {
|
||||
// The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
|
||||
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
|
||||
process_id = match stream.read_i32().await {
|
||||
process_id = match read.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -522,7 +583,7 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
secret_key = match stream.read_i32().await {
|
||||
secret_key = match read.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -537,7 +598,7 @@ impl Server {
|
||||
'Z' => {
|
||||
let mut idle = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut idle).await {
|
||||
match read.read_exact(&mut idle).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -547,12 +608,10 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = split(stream);
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(ReadInner::Plain { stream: read }),
|
||||
write: WriteInner::Plain { stream: write },
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
process_id,
|
||||
|
||||
29
src/tls.rs
29
src/tls.rs
@@ -4,12 +4,23 @@ use rustls_pemfile::{certs, read_one, Item};
|
||||
use std::iter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::rustls::{
|
||||
self,
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate, PrivateKey, ServerName,
|
||||
};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(err: std::io::Error) -> Error {
|
||||
Error::TlsCertificateReadError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// TLS
|
||||
pub fn load_certs(path: &Path) -> std::io::Result<Vec<Certificate>> {
|
||||
certs(&mut std::io::BufReader::new(std::fs::File::open(path)?))
|
||||
@@ -64,3 +75,19 @@ impl Tls {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoCertificateVerification;
|
||||
|
||||
impl ServerCertVerifier for NoCertificateVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user