From b974aacd719b076036ae8b1a60f682db2d578e63 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 27 Jun 2022 09:46:33 -0700 Subject: [PATCH] check --- Cargo.lock | 172 ++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 2 + pgcat.toml | 4 ++ src/client.rs | 12 ++-- src/config.rs | 4 ++ src/errors.rs | 1 + src/main.rs | 3 + src/messages.rs | 42 +++++++----- src/stream.rs | 160 ++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 378 insertions(+), 22 deletions(-) create mode 100644 src/stream.rs diff --git a/Cargo.lock b/Cargo.lock index b8b4bfd..cfaa0fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -79,12 +79,24 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ccbd214614c6783386c1af30caf03192f17891059cecc394b4fb119e363de3" + [[package]] name = "bytes" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" +[[package]] +name = "cc" +version = "1.0.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" + [[package]] name = "cfg-if" version = "1.0.0" @@ -236,6 +248,21 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "js-sys" +version = "0.3.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3fac17f7123a73ca62df411b1bf727ccc805daa070338fda671c86dac1bdc27" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.126" @@ -385,6 +412,7 @@ dependencies = [ "parking_lot", "rand", "regex", + "rustls-pemfile", "serde", "serde_derive", "sha-1", @@ -392,6 +420,7 @@ dependencies = [ "sqlparser", "stringprep", "tokio", + "tokio-rustls", "toml", ] @@ -497,12 +526,58 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin", + "untrusted", + "web-sys", + "winapi", +] + +[[package]] +name = "rustls" +version = "0.20.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +dependencies = [ + "log", + "ring", + "sct", + "webpki", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7522c9de787ff061458fe9a829dc790a3f5b22dc571694fc5883f448b94d9a9" +dependencies = [ + "base64", +] + [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "serde" version = "1.0.136" @@ -563,6 +638,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "sqlparser" version = "0.14.0" @@ -664,6 +745,17 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" +dependencies = [ + "rustls", + "tokio", + "webpki", +] + [[package]] name = "toml" version = "0.5.8" @@ -700,6 +792,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "version_check" version = "0.9.4" @@ -712,6 +810,80 @@ version = "0.10.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" +[[package]] +name = "wasm-bindgen" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c53b543413a17a202f4be280a7e5c62a1c69345f5de525ee64f8cfdbc954994" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5491a68ab4500fa6b4d726bd67408630c3dbe9c4fe7bda16d5c82a1fd8c7340a" +dependencies = [ + "bumpalo", + "lazy_static", + "log", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c441e177922bc58f1e12c022624b6216378e5febc2f0533e41ba443d505b80aa" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d94ac45fcf608c1f45ef53e748d35660f168490c10b23704c7779ab8f5c3048" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a89911bd99e5f3659ec4acf9c4d93b0a90fe4a2a11f15328472058edc5261be" + +[[package]] +name = "web-sys" +version = "0.3.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fed94beee57daf8dd7d51f2b15dc2bcde92d7a72304cdf662a4371008b71b90" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index fa63c0e..b193e77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,3 +29,5 @@ hmac = "0.12" sha2 = "0.10" base64 = "0.13" stringprep = "0.1" +tokio-rustls = "*" +rustls-pemfile = "*" diff --git a/pgcat.toml b/pgcat.toml index 70b2fae..b2a6b76 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -32,6 +32,10 @@ ban_time = 60 # Seconds # Reload config automatically if it changes. autoreload = false +# TLS +tls_certificate = "server.cert" +tls_private_key = "server.key" + # # User to use for authentication against the server. [user] diff --git a/src/client.rs b/src/client.rs index ea1bf01..c02fcf8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,13 +19,13 @@ use crate::server::Server; use crate::stats::{get_reporter, Reporter}; /// The client state. One of these is created per client. -pub struct Client { +pub struct Client { /// The reads are buffered (8K by default). - read: BufReader, + read: BufReader, /// We buffer the writes ourselves because we know the protocol /// better than a stock buffer. - write: OwnedWriteHalf, + write: S, /// Internal buffer, where we place messages until we have to flush /// them to the backend. @@ -63,13 +63,13 @@ pub struct Client { last_server_id: Option, } -impl Client { +impl Client { /// Perform client startup sequence. /// See docs: pub async fn startup( mut stream: TcpStream, client_server_map: ClientServerMap, - ) -> Result { + ) -> Result, Error> { let config = get_config(); let transaction_mode = config.general.pool_mode == "transaction"; let stats = get_reporter(); @@ -650,7 +650,7 @@ impl Client { } } -impl Drop for Client { +impl Drop for Client { fn drop(&mut self) { // Update statistics. if let Some(address_id) = self.last_address_id { diff --git a/src/config.rs b/src/config.rs index f6fa129..1700541 100644 --- a/src/config.rs +++ b/src/config.rs @@ -111,6 +111,8 @@ pub struct General { pub healthcheck_timeout: u64, pub ban_time: i64, pub autoreload: bool, + pub tls_certificate: Option, + pub tls_private_key: Option, } impl Default for General { @@ -124,6 +126,8 @@ impl Default for General { healthcheck_timeout: 1000, ban_time: 60, autoreload: false, + tls_certificate: None, + tls_private_key: None, } } } diff --git a/src/errors.rs b/src/errors.rs index b07d508..cc8f65d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -10,4 +10,5 @@ pub enum Error { BadConfig, AllServersDown, ClientError, + TlsError, } diff --git a/src/main.rs b/src/main.rs index 7b78e5b..05bfb23 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,8 @@ extern crate serde_derive; extern crate sqlparser; extern crate tokio; extern crate toml; +extern crate tokio_rustls; +extern crate rustls_pemfile; use log::{debug, error, info}; use parking_lot::Mutex; @@ -58,6 +60,7 @@ mod scram; mod server; mod sharding; mod stats; +mod stream; use config::{get_config, reload_config}; use pool::{ClientServerMap, ConnectionPool}; diff --git a/src/messages.rs b/src/messages.rs index 993545b..960a3b6 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -30,7 +30,8 @@ impl From<&DataType> for i32 { } /// Tell the client that authentication handshake completed successfully. -pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { +pub async fn auth_ok(stream: &mut S) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut auth_ok = BytesMut::with_capacity(9); auth_ok.put_u8(b'R'); @@ -41,7 +42,8 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { } /// Generate md5 password challenge. -pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> { +pub async fn md5_challenge(stream: &mut S) -> Result<[u8; 4], Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { // let mut rng = rand::thread_rng(); let salt: [u8; 4] = [ rand::random(), @@ -62,11 +64,12 @@ pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> { /// Give the client the process_id and secret we generated /// used in query cancellation. -pub async fn backend_key_data( - stream: &mut TcpStream, +pub async fn backend_key_data( + stream: &mut S, backend_id: i32, secret_key: i32, -) -> Result<(), Error> { +) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut key_data = BytesMut::from(&b"K"[..]); key_data.put_i32(12); key_data.put_i32(backend_id); @@ -87,7 +90,8 @@ pub fn simple_query(query: &str) -> BytesMut { } /// Tell the client we're ready for another query. -pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> { +pub async fn ready_for_query(stream: &mut S) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut bytes = BytesMut::with_capacity(5); bytes.put_u8(b'Z'); @@ -205,12 +209,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { /// Send password challenge response to the server. /// This is the MD5 challenge. -pub async fn md5_password( - stream: &mut TcpStream, +pub async fn md5_password( + stream: &mut S, user: &str, password: &str, salt: &[u8], -) -> Result<(), Error> { +) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let password = md5_hash_password(user, password, salt); let mut message = BytesMut::with_capacity(password.len() as usize + 5); @@ -225,10 +230,11 @@ pub async fn md5_password( /// Implements a response to our custom `SET SHARDING KEY` /// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. -pub async fn custom_protocol_response_ok( - stream: &mut OwnedWriteHalf, +pub async fn custom_protocol_response_ok( + stream: &mut S, message: &str, -) -> Result<(), Error> { +) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut res = BytesMut::with_capacity(25); let set_complete = BytesMut::from(&format!("{}\0", message)[..]); @@ -250,7 +256,8 @@ pub async fn custom_protocol_response_ok( /// Send a custom error message to the client. /// Tell the client we are ready for the next query and no rollback is necessary. /// Docs on error codes: . -pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Result<(), Error> { +pub async fn error_response(stream: &mut S, message: &str) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut error = BytesMut::new(); // Error level @@ -291,7 +298,8 @@ pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Resul Ok(write_all_half(stream, res).await?) } -pub async fn wrong_password(stream: &mut TcpStream, user: &str) -> Result<(), Error> { +pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { let mut error = BytesMut::new(); // Error level @@ -430,7 +438,8 @@ pub fn command_complete(command: &str) -> BytesMut { } /// Write all data in the buffer to the TcpStream. -pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> { +pub async fn write_all(stream: &mut S, buf: BytesMut) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), @@ -438,7 +447,8 @@ pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Erro } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut OwnedWriteHalf, buf: BytesMut) -> Result<(), Error> { +pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> +where S: tokio::io::AsyncWrite + std::marker::Unpin { match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError), diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..41d3cd1 --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,160 @@ +// Stream wrapper. + +use bytes::{Buf, BufMut, BytesMut}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, split, ReadHalf, WriteHalf}; +use tokio::net::{ + tcp::{OwnedReadHalf, OwnedWriteHalf}, + TcpStream, +}; +use tokio_rustls::server::TlsStream; +use rustls_pemfile::{certs, rsa_private_keys}; +use tokio_rustls::rustls::{self, Certificate, PrivateKey}; +use tokio_rustls::TlsAcceptor; +use std::sync::Arc; +use std::path::Path; + +use crate::config::get_config; +use crate::errors::Error; + +// TLS +fn load_certs(path: &std::path::Path) -> std::io::Result> { + certs(&mut std::io::BufReader::new(std::fs::File::open(path)?)) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect()) +} + +fn load_keys(path: &std::path::Path) -> std::io::Result> { + rsa_private_keys(&mut std::io::BufReader::new(std::fs::File::open(path)?)) + .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid key")) + .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) +} + +struct Tls { + acceptor: TlsAcceptor, +} + +impl Tls { + pub fn new() -> Result { + let config = get_config(); + + let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) { + Ok(certs) => certs, + Err(_) => return Err(Error::TlsError), + }; + + let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) { + Ok(keys) => keys, + Err(_) => return Err(Error::TlsError), + }; + + let config = match rustls::ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(certs, keys.remove(0)) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) { + Ok(c) => c, + Err(_) => return Err(Error::TlsError) + }; + + Ok(Tls { + acceptor: TlsAcceptor::from(Arc::new(config)), + }) + } +} + +struct Stream { + read: Option>, + write: Option, + tls_read: Option>>>, + tls_write: Option>>, +} + + +impl Stream { + pub async fn new(stream: TcpStream, tls: Option) -> Result { + + let config = get_config(); + + match tls { + None => { + let (read, write) = stream.into_split(); + let read = BufReader::new(read); + Ok( + Self { + read: Some(read), + write: Some(write), + tls_read: None, + tls_write: None, + } + ) + } + + Some(tls) => { + let mut tls_stream = match tls.acceptor.accept(stream).await { + Ok(stream) => stream, + Err(_) => return Err(Error::TlsError), + }; + + let (read, write) = split(tls_stream); + + Ok(Self{ + read: None, + write: None, + tls_read: Some(BufReader::new(read)), + tls_write: Some(write), + }) + } + } + } + + async fn read(stream: &mut S) -> Result + where S: tokio::io::AsyncRead + std::marker::Unpin { + + let code = match stream.read_u8().await { + Ok(code) => code, + Err(_) => return Err(Error::SocketError), + }; + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::SocketError), + }; + + let mut buf = vec![0u8; len as usize - 4]; + + match stream.read_exact(&mut buf).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + let mut bytes = BytesMut::with_capacity(len as usize + 1); + + bytes.put_u8(code); + bytes.put_i32(len); + bytes.put_slice(&buf); + + Ok(bytes) + } + + async fn write(stream: &mut S, buf: &BytesMut) -> Result<(), Error> + where S: tokio::io::AsyncWrite + std::marker::Unpin { + match stream.write_all(buf).await { + Ok(_) => Ok(()), + Err(_) => return Err(Error::SocketError), + } + } + + pub async fn read_message(&mut self) -> Result { + match &self.read { + Some(read) => Self::read(self.read.as_mut().unwrap()).await, + None => Self::read(self.tls_read.as_mut().unwrap()).await, + } + } + + pub async fn write_all(&mut self, buf: &BytesMut) -> Result<(), Error> { + match &self.write { + Some(write) => Self::write(self.write.as_mut().unwrap(), buf).await, + None => Self::write(self.tls_write.as_mut().unwrap(), buf).await, + } + } +} \ No newline at end of file