From 5c673b4333238037e133493535cfd30511721e19 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 30 Mar 2023 11:55:27 -0700 Subject: [PATCH] Zero-downtime password rotation --- pgcat.toml | 8 +- src/admin.rs | 4 +- src/auth.rs | 382 ++++++++++++++++++++++++++++++++++ src/auth_passthrough.rs | 1 + src/client.rs | 197 +++--------------- src/config.rs | 17 ++ src/main.rs | 1 + src/messages.rs | 23 --- src/mirrors.rs | 2 +- src/pool.rs | 444 +++++++++++++++++++++------------------- 10 files changed, 672 insertions(+), 407 deletions(-) create mode 100644 src/auth.rs diff --git a/pgcat.toml b/pgcat.toml index 0d883a3..b0d2fed 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -58,9 +58,9 @@ tcp_keepalives_count = 5 tcp_keepalives_interval = 5 # Path to TLS Certficate file to use for TLS connections -# tls_certificate = "server.cert" +tls_certificate = ".circleci/server.cert" # Path to TLS private key file to use for TLS connections -# tls_private_key = "server.key" +tls_private_key = ".circleci/server.key" # User name to access the virtual administrative database (pgbouncer or pgcat) # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. @@ -129,6 +129,10 @@ connect_timeout = 3000 username = "sharding_user" # Postgresql password password = "sharding_user" + +# Passwords the client can use to connect. Useful for password rotations. +secrets = [ "secret_one", "secret_two" ] + # Maximum number of server connections that can be established for this user # The maximum number of connection from a single Pgcat process to any database in the cluster # is the sum of pool_size across all users. 
diff --git a/src/admin.rs b/src/admin.rs index 03af755..4df2206 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -780,7 +780,7 @@ where let database = parts[0]; let user = parts[1]; - match get_pool(database, user) { + match get_pool(database, user, None) { Some(pool) => { pool.pause(); @@ -827,7 +827,7 @@ where let database = parts[0]; let user = parts[1]; - match get_pool(database, user) { + match get_pool(database, user, None) { Some(pool) => { pool.resume(); diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..7865e80 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,382 @@ +//! Module implementing various client authentication mechanisms. +//! +//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection). + +use crate::errors::Error; +use crate::tokio::io::AsyncReadExt; +use crate::{ + config::get_config, + messages::{error_response, md5_hash_password, write_all, wrong_password, md5_hash_second_pass}, + pool::{get_pool, ConnectionPool}, + auth_passthrough::AuthPassthrough, +}; +use bytes::{BufMut, BytesMut}; +use log::debug; + +async fn refetch_auth_hash(pool: &ConnectionPool, stream: &mut S, username: &str, pool_name: &str) -> Result +where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send +{ + let address = pool.address(0, 0); + if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { + let hash = apt.fetch_hash(address).await?; + + return Ok(hash); + } + + error_response( + stream, + &format!( + "No password set and auth passthrough failed for database: {}, user: {}", + pool_name, username + ), + ).await?; + + Err(Error::ClientError(format!( + "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", + address.username, address.database + ))) +} + +/// Read 'p' message from client. 
+async fn response(stream: &mut R) -> Result, Error> +where + R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send, +{ + let code = match stream.read_u8().await { + Ok(code) => code, + Err(_) => { + return Err(Error::SocketError( + "Error reading password code from client".to_string(), + )) + } + }; + + if code as char != 'p' { + return Err(Error::SocketError(format!("Expected p, got {}", code))); + } + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => { + return Err(Error::SocketError( + "Error reading password length from client".to_string(), + )) + } + }; + + let mut response = vec![0; (len - 4) as usize]; + + match stream.read_exact(&mut response).await { + Ok(_) => (), + Err(_) => { + return Err(Error::SocketError( + "Error reading password from client".to_string(), + )) + } + }; + + Ok(response.to_vec()) +} + +/// Make sure the pool we authenticated to has at least one server connection +/// that can serve our request. +async fn validate_pool( + stream: &mut W, + mut pool: ConnectionPool, + username: &str, + pool_name: &str, +) -> Result<(), Error> +where + W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, +{ + if !pool.validated() { + match pool.validate().await { + Ok(_) => Ok(()), + Err(err) => { + error_response( + stream, + &format!( + "Pool down for database: {:?}, user: {:?}", + pool_name, username, + ), + ) + .await?; + + Err(Error::ClientError(format!("Pool down: {:?}", err))) + } + } + } else { + Ok(()) + } +} + +/// Clear text authentication. +/// +/// The client will send the password in plain text over the wire. +/// To protect against obvious security issues, this is only used over TLS. +/// +/// Clear text authentication is used to support zero-downtime password rotation. +/// It allows the client to use multiple passwords when talking to the PgCat +/// while the password is being rotated across multiple app instances. 
+pub struct ClearText { + username: String, + pool_name: String, + application_name: String, +} + +impl ClearText { + /// Create a new ClearText authentication mechanism. + pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText { + ClearText { + username: username.to_string(), + pool_name: pool_name.to_string(), + application_name: application_name.to_string(), + } + } + + /// Issue 'R' clear text challenge to client. + pub async fn challenge(&self, stream: &mut W) -> Result<(), Error> + where + W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, + { + debug!("Sending plain challenge"); + + let mut msg = BytesMut::new(); + msg.put_u8(b'R'); + msg.put_i32(8); + msg.put_i32(3); // Clear text + + write_all(stream, msg).await + } + + /// Authenticate client with server password or secret. + pub async fn authenticate( + &self, + read: &mut R, + write: &mut W, + ) -> Result, Error> + where + R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send, + W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, + { + let response = response(read).await?; + + let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string(); + + match get_pool(&self.pool_name, &self.username, Some(secret.clone())) { + None => match get_pool(&self.pool_name, &self.username, None) { + Some(pool) => { + match pool.settings.user.password { + Some(ref password) => { + if password != &secret { + wrong_password(write, &self.username).await?; + Err(Error::ClientError(format!( + "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } + else { + validate_pool(write, pool, &self.username, &self.pool_name).await?; + + Ok(None) + } + } + + None => { + // Server is storing hashes, we can't query it for the plain text password. 
+ error_response( + write, + &format!( + "No server password configured for database: {:?}, user: {:?}", + self.pool_name, self.username + ), + ) + .await?; + + Err(Error::ClientError(format!( + "No server password configured for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } + } + } + + None => { + error_response( + write, + &format!( + "No pool configured for database: {:?}, user: {:?}", + self.pool_name, self.username + ), + ) + .await?; + + Err(Error::ClientError(format!( + "Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } + }, + Some(pool) => { + validate_pool(write, pool, &self.username, &self.pool_name).await?; + Ok(Some(secret)) + } + } + } +} + +/// MD5 hash authentication. +/// +/// Deprecated, but widely used everywhere, and currently required for poolers +/// to authenticate clients without involving Postgres. +/// +/// Admin clients are required to use MD5. +pub struct Md5 { + username: String, + pool_name: String, + application_name: String, + salt: [u8; 4], + admin: bool, +} + +impl Md5 { + pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 { + let salt: [u8; 4] = [ + rand::random(), + rand::random(), + rand::random(), + rand::random(), + ]; + + Md5 { + username: username.to_string(), + pool_name: pool_name.to_string(), + application_name: application_name.to_string(), + salt, + admin, + } + } + + /// Issue a 'R' MD5 challenge to the client. + pub async fn challenge(&self, stream: &mut W) -> Result<(), Error> + where + W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, + { + let mut res = BytesMut::new(); + res.put_u8(b'R'); + res.put_i32(12); + res.put_i32(5); // MD5 + res.put_slice(&self.salt[..]); + + write_all(stream, res).await + } + + /// Authenticate client with MD5. This is used for both admin and normal users. 
+ pub async fn authenticate(&self, read: &mut R, write: &mut W) -> Result<(), Error> + where + R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send, + W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, + { + let password_hash = response(read).await?; + + if self.admin { + let config = get_config(); + + // Compare server and client hashes. + let our_hash = md5_hash_password( + &config.general.admin_username, + &config.general.admin_password, + &self.salt, + ); + + if our_hash != password_hash { + wrong_password(write, &self.username).await?; + Err(Error::ClientError(format!( + "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } else { + Ok(()) + } + } else { + match get_pool(&self.pool_name, &self.username, None) { + Some(pool) => { + match &pool.settings.user.password { + Some(ref password) => { + let our_hash = md5_hash_password(&self.username, password, &self.salt); + + if our_hash != password_hash { + wrong_password(write, &self.username).await?; + + Err(Error::ClientError(format!( + "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } else { + validate_pool(write, pool, &self.username, &self.pool_name).await?; + Ok(()) + } + } + + None => { + // Fetch hash from server + let hash = (*pool.auth_hash.read()).clone(); + + let hash = match hash { + Some(hash) => hash.to_string(), + None => refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?, + }; + + let our_hash = md5_hash_second_pass(&hash, &self.salt); + + // Compare hashes + if our_hash != password_hash { + // Server hash maybe changed + let hash = refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?; + let our_hash = md5_hash_second_pass(&hash, &self.salt); + + if our_hash != password_hash { + wrong_password(write, &self.username).await?; + + 
Err(Error::ClientError(format!( + "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } else { + (*pool.auth_hash.write()) = Some(hash); + + validate_pool(write, pool.clone(), &self.username, &self.pool_name).await?; + + Ok(()) + } + } else { + wrong_password(write, &self.username).await?; + + Err(Error::ClientError(format!( + "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))) + } + } + } + } + + None => { + error_response( + write, + &format!( + "No pool configured for database: {:?}, user: {:?}", + self.pool_name, self.username + ), + ) + .await?; + + return Err(Error::ClientError(format!( + "Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", + self.username, self.pool_name, self.application_name + ))); + } + } + } + } +} diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs index b9f0e97..33a9fb1 100644 --- a/src/auth_passthrough.rs +++ b/src/auth_passthrough.rs @@ -73,6 +73,7 @@ impl AuthPassthrough { password: Some(self.password.clone()), pool_size: 1, statement_timeout: 0, + secrets: None, }; let user = &address.username; diff --git a/src/client.rs b/src/client.rs index d75c069..5ef5fe4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,7 @@ use crate::errors::Error; use crate::pool::BanReason; /// Handle clients by pretending to be a PostgreSQL server. use bytes::{Buf, BufMut, BytesMut}; -use log::{debug, error, info, trace, warn}; +use log::{debug, error, info, trace}; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; @@ -90,6 +90,9 @@ pub struct Client { /// Application name for this client (defaults to pgcat) application_name: String, + /// Which secret the user is using to connect, if any. 
+ secret: Option, + /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, } @@ -290,7 +293,7 @@ pub async fn client_entrypoint( /// Handle the first message the client sends. async fn get_startup(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error> where - S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite, + S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send, { // Get startup message length. let len = match stream.read_i32().await { @@ -377,24 +380,10 @@ pub async fn startup_tls( } } -async fn refetch_auth_hash(pool: &ConnectionPool) -> Result { - let address = pool.address(0, 0); - if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { - let hash = apt.fetch_hash(address).await?; - - return Ok(hash); - } - - Err(Error::ClientError(format!( - "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", - address.username, address.database - ))) -} - impl Client where - S: tokio::io::AsyncRead + std::marker::Unpin, - T: tokio::io::AsyncWrite + std::marker::Unpin, + S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send, + T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send, { pub fn is_admin(&self) -> bool { self.admin @@ -457,161 +446,39 @@ where let process_id: i32 = rand::random(); let secret_key: i32 = rand::random(); - // Perform MD5 authentication. - // TODO: Add SASL support. 
- let salt = md5_challenge(&mut write).await?; + let config = get_config(); - let code = match read.read_u8().await { - Ok(p) => p, - Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), + let secret = if admin { + debug!("Using md5 auth for admin"); + let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true); + auth.challenge(&mut write).await?; + auth.authenticate(&mut read, &mut write).await?; + None + } else if !config.tls_enabled() { + debug!("Using md5 auth"); + let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false); + auth.challenge(&mut write).await?; + auth.authenticate(&mut read, &mut write).await?; + None + } else { + debug!("Using plain auth"); + let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name); + auth.challenge(&mut write).await?; + auth.authenticate(&mut read, &mut write).await? }; - // PasswordMessage - if code as char != 'p' { - return Err(Error::ProtocolSyncError(format!( - "Expected p, got {}", - code as char - ))); - } - - let len = match read.read_i32().await { - Ok(len) => len, - Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), - }; - - let mut password_response = vec![0u8; (len - 4) as usize]; - - match read.read_exact(&mut password_response).await { - Ok(_) => (), - Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), - }; - - // Authenticate admin user. + // Authenticated admin user. let (transaction_mode, server_info) = if admin { - let config = get_config(); - // Compare server and client hashes. 
- let password_hash = md5_hash_password( - &config.general.admin_username, - &config.general.admin_password, - &salt, - ); - - if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); - wrong_password(&mut write, username).await?; - - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); - } - (false, generate_server_info_for_admin()) } - // Authenticate normal user. + // Authenticated normal user. else { - let mut pool = match get_pool(pool_name, username) { - Some(pool) => pool, - None => { - error_response( - &mut write, - &format!( - "No pool configured for database: {:?}, user: {:?}", - pool_name, username - ), - ) - .await?; - - return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); - } - }; - - // Obtain the hash to compare, we give preference to that written in cleartext in config - // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained - // when the pool was created. If there is no hash there, we try to fetch it one more time. - let password_hash = if let Some(password) = &pool.settings.user.password { - Some(md5_hash_password(username, password, &salt)) - } else { - if !get_config().is_auth_query_configured() { - return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); - } - - let mut hash = (*pool.auth_hash.read()).clone(); - - if hash.is_none() { - warn!("Query auth configured but no hash password found for pool {}. 
Will try to refetch it.", pool_name); - match refetch_auth_hash(&pool).await { - Ok(fetched_hash) => { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); - { - let mut pool_auth_hash = pool.auth_hash.write(); - *pool_auth_hash = Some(fetched_hash.clone()); - } - - hash = Some(fetched_hash); - } - Err(err) => { - return Err( - Error::ClientError( - format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", - username, - pool_name, - application_name, - err) - ) - ); - } - } - }; - - Some(md5_hash_second_pass(&hash.unwrap(), &salt)) - }; - - // Once we have the resulting hash, we compare with what the client gave us. - // If they do not match and auth query is set up, we try to refetch the hash one more time - // to see if the password has changed since the pool was created. - // - // @TODO: we could end up fetching again the same password twice (see above). - if password_hash.unwrap() != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); - let fetched_hash = refetch_auth_hash(&pool).await?; - let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); - - // Ok password changed in server an auth is possible. - if new_password_hash == password_response { - warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. 
Updating.", username, pool_name, application_name); - { - let mut pool_auth_hash = pool.auth_hash.write(); - *pool_auth_hash = Some(fetched_hash); - } - } else { - wrong_password(&mut write, username).await?; - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); - } - } - + let pool = get_pool(&pool_name, &username, secret.clone()).unwrap(); let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; - - // If the pool hasn't been validated yet, - // connect to the servers and figure out what's what. - if !pool.validated() { - match pool.validate().await { - Ok(_) => (), - Err(err) => { - error_response( - &mut write, - &format!( - "Pool down for database: {:?}, user: {:?}", - pool_name, username - ), - ) - .await?; - return Err(Error::ClientError(format!("Pool down: {:?}", err))); - } - } - } - (transaction_mode, pool.server_info()) }; - debug!("Password authentication successful"); + debug!("Authentication successful"); auth_ok(&mut write).await?; write_all(&mut write, server_info).await?; @@ -619,7 +486,7 @@ where ready_for_query(&mut write).await?; trace!("Startup OK"); - let pool_stats = match get_pool(pool_name, username) { + let pool_stats = match get_pool(pool_name, username, secret.clone()) { Some(pool) => { if !admin { pool.stats @@ -659,6 +526,7 @@ where application_name: application_name.to_string(), shutdown, connected_to_server: false, + secret, }) } @@ -693,6 +561,7 @@ where application_name: String::from("undefined"), shutdown, connected_to_server: false, + secret: None, }) } @@ -1200,7 +1069,7 @@ where /// Retrieve connection pool, if it exists. /// Return an error to the client otherwise. 
async fn get_pool(&mut self) -> Result { - match get_pool(&self.pool_name, &self.username) { + match get_pool(&self.pool_name, &self.username, self.secret.clone()) { Some(pool) => Ok(pool), None => { error_response( diff --git a/src/config.rs b/src/config.rs index 6545457..c124713 100644 --- a/src/config.rs +++ b/src/config.rs @@ -181,6 +181,13 @@ pub struct User { pub pool_size: u32, #[serde(default)] // 0 pub statement_timeout: u64, + pub secrets: Option>, +} + +impl User { + fn validate(&self) -> Result<(), Error> { + Ok(()) + } } impl Default for User { @@ -190,6 +197,7 @@ impl Default for User { password: None, pool_size: 15, statement_timeout: 0, + secrets: None, } } } @@ -508,6 +516,10 @@ impl Pool { None => None, }; + for user in self.users.iter() { + user.1.validate()?; + } + Ok(()) } } @@ -657,6 +669,11 @@ impl Config { } } } + + /// Checks that we configured TLS. + pub fn tls_enabled(&self) -> bool { + self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some() + } } impl Default for Config { diff --git a/src/main.rs b/src/main.rs index 4c8987f..53179e2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,6 +62,7 @@ use tokio::sync::broadcast; mod admin; mod auth_passthrough; +mod auth; mod client; mod config; mod constants; diff --git a/src/messages.rs b/src/messages.rs index 61c36c6..c9a2e4a 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -46,29 +46,6 @@ where write_all(stream, auth_ok).await } -/// Generate md5 password challenge. 
-pub async fn md5_challenge(stream: &mut S) -> Result<[u8; 4], Error> -where - S: tokio::io::AsyncWrite + std::marker::Unpin, -{ - // let mut rng = rand::thread_rng(); - let salt: [u8; 4] = [ - rand::random(), - rand::random(), - rand::random(), - rand::random(), - ]; - - let mut res = BytesMut::new(); - res.put_u8(b'R'); - res.put_i32(12); - res.put_i32(5); // MD5 - res.put_slice(&salt[..]); - - write_all(stream, res).await?; - Ok(salt) -} - /// Give the client the process_id and secret we generated /// used in query cancellation. pub async fn backend_key_data( diff --git a/src/mirrors.rs b/src/mirrors.rs index 17f91d4..e8918f5 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -34,7 +34,7 @@ impl MirroredClient { None => (default, default, crate::config::Pool::default()), }; - let identifier = PoolIdentifier::new(&self.database, &self.user.username); + let identifier = PoolIdentifier::new(&self.database, &self.user.username, None); let manager = ServerPool::new( self.address.clone(), diff --git a/src/pool.rs b/src/pool.rs index e1ab7cb..c5431ef 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -59,24 +59,22 @@ pub struct PoolIdentifier { /// The username the client connects with. Each user gets its own pool. pub user: String, + + /// The client secret (password). + pub secret: Option, } impl PoolIdentifier { /// Create a new user/pool identifier. - pub fn new(db: &str, user: &str) -> PoolIdentifier { + pub fn new(db: &str, user: &str, secret: Option) -> PoolIdentifier { PoolIdentifier { db: db.to_string(), user: user.to_string(), + secret, } } } -impl From<&Address> for PoolIdentifier { - fn from(address: &Address) -> PoolIdentifier { - PoolIdentifier::new(&address.database, &address.username) - } -} - /// Pool settings. #[derive(Clone, Debug)] pub struct PoolSettings { @@ -210,224 +208,240 @@ impl ConnectionPool { // There is one pool per database/user pair. 
for user in pool_config.users.values() { - let old_pool_ref = get_pool(pool_name, &user.username); - let identifier = PoolIdentifier::new(pool_name, &user.username); + let mut secrets = match &user.secrets { + Some(_) => user + .secrets + .as_ref() + .unwrap() + .iter() + .map(|secret| Some(secret.to_string())) + .collect::>>(), + None => vec![], + }; + + secrets.push(None); - match old_pool_ref { - Some(pool) => { - // If the pool hasn't changed, get existing reference and insert it into the new_pools. - // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). - if pool.config_hash == new_pool_hash_value { - info!( - "[pool: {}][user: {}] has not changed", - pool_name, user.username - ); - new_pools.insert(identifier.clone(), pool.clone()); - continue; - } - } - None => (), - } + for secret in secrets { - info!( - "[pool: {}][user: {}] creating new pool", - pool_name, user.username - ); + let old_pool_ref = get_pool(pool_name, &user.username, secret.clone()); + let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone()); - let mut shards = Vec::new(); - let mut addresses = Vec::new(); - let mut banlist = Vec::new(); - let mut shard_ids = pool_config - .shards - .clone() - .into_keys() - .collect::>(); - let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone())); - - // Allow the pool to be seen in statistics - pool_stats.register(pool_stats.clone()); - - // Sort by shard number to ensure consistency. 
- shard_ids.sort_by_key(|k| k.parse::().unwrap()); - let pool_auth_hash: Arc>> = Arc::new(RwLock::new(None)); - - for shard_idx in &shard_ids { - let shard = &pool_config.shards[shard_idx]; - let mut pools = Vec::new(); - let mut servers = Vec::new(); - let mut replica_number = 0; - - // Load Mirror settings - for (address_index, server) in shard.servers.iter().enumerate() { - let mut mirror_addresses = vec![]; - if let Some(mirror_settings_vec) = &shard.mirrors { - for (mirror_idx, mirror_settings) in - mirror_settings_vec.iter().enumerate() - { - if mirror_settings.mirroring_target_index != address_index { - continue; - } - mirror_addresses.push(Address { - id: address_id, - database: shard.database.clone(), - host: mirror_settings.host.clone(), - port: mirror_settings.port, - role: server.role, - address_index: mirror_idx, - replica_number, - shard: shard_idx.parse::().unwrap(), - username: user.username.clone(), - pool_name: pool_name.clone(), - mirrors: vec![], - stats: Arc::new(AddressStats::default()), - }); - address_id += 1; + match old_pool_ref { + Some(pool) => { + // If the pool hasn't changed, get existing reference and insert it into the new_pools. + // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). 
+ if pool.config_hash == new_pool_hash_value { + info!( + "[pool: {}][user: {}] has not changed", + pool_name, user.username + ); + new_pools.insert(identifier.clone(), pool.clone()); + continue; } } - - let address = Address { - id: address_id, - database: shard.database.clone(), - host: server.host.clone(), - port: server.port, - role: server.role, - address_index, - replica_number, - shard: shard_idx.parse::().unwrap(), - username: user.username.clone(), - pool_name: pool_name.clone(), - mirrors: mirror_addresses, - stats: Arc::new(AddressStats::default()), - }; - - address_id += 1; - - if server.role == Role::Replica { - replica_number += 1; - } - - // We assume every server in the pool share user/passwords - let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); - - if let Some(apt) = &auth_passthrough { - match apt.fetch_hash(&address).await { - Ok(ok) => { - if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { - if ok != *pool_auth_hash_value { - warn!("Hash is not the same across shards of the same pool, client auth will \ - be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); - } - } - debug!("Hash obtained for {:?}", address); - { - let mut pool_auth_hash = pool_auth_hash.write(); - *pool_auth_hash = Some(ok.clone()); - } - }, - Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. 
Error: {:?}", err), - } - } - - let manager = ServerPool::new( - address.clone(), - user.clone(), - &shard.database, - client_server_map.clone(), - pool_stats.clone(), - pool_auth_hash.clone(), - ); - - let connect_timeout = match pool_config.connect_timeout { - Some(connect_timeout) => connect_timeout, - None => config.general.connect_timeout, - }; - - let idle_timeout = match pool_config.idle_timeout { - Some(idle_timeout) => idle_timeout, - None => config.general.idle_timeout, - }; - - let pool = Pool::builder() - .max_size(user.pool_size) - .connection_timeout(std::time::Duration::from_millis(connect_timeout)) - .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) - .test_on_check_out(false) - .build(manager) - .await - .unwrap(); - - pools.push(pool); - servers.push(address); + None => (), } - shards.push(pools); - addresses.push(servers); - banlist.push(HashMap::new()); - } - - assert_eq!(shards.len(), addresses.len()); - if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { info!( - "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", + "[pool: {}][user: {}] creating new pool", pool_name, user.username ); - } - let pool = ConnectionPool { - databases: shards, - stats: pool_stats, - addresses, - banlist: Arc::new(RwLock::new(banlist)), - config_hash: new_pool_hash_value, - server_info: Arc::new(RwLock::new(BytesMut::new())), - auth_hash: pool_auth_hash, - settings: PoolSettings { - pool_mode: pool_config.pool_mode, - load_balancing_mode: pool_config.load_balancing_mode, - // shards: pool_config.shards.clone(), - shards: shard_ids.len(), - user: user.clone(), - default_role: match pool_config.default_role.as_str() { - "any" => None, - "replica" => Some(Role::Replica), - "primary" => Some(Role::Primary), - _ => unreachable!(), + let mut shards = Vec::new(); + let mut addresses = Vec::new(); + let mut banlist = Vec::new(); + let mut shard_ids = pool_config + .shards + .clone() + .into_keys() + .collect::>(); + let 
pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone())); + + // Allow the pool to be seen in statistics + pool_stats.register(pool_stats.clone()); + + // Sort by shard number to ensure consistency. + shard_ids.sort_by_key(|k| k.parse::().unwrap()); + let pool_auth_hash: Arc>> = Arc::new(RwLock::new(None)); + + for shard_idx in &shard_ids { + let shard = &pool_config.shards[shard_idx]; + let mut pools = Vec::new(); + let mut servers = Vec::new(); + let mut replica_number = 0; + + // Load Mirror settings + for (address_index, server) in shard.servers.iter().enumerate() { + let mut mirror_addresses = vec![]; + if let Some(mirror_settings_vec) = &shard.mirrors { + for (mirror_idx, mirror_settings) in + mirror_settings_vec.iter().enumerate() + { + if mirror_settings.mirroring_target_index != address_index { + continue; + } + mirror_addresses.push(Address { + id: address_id, + database: shard.database.clone(), + host: mirror_settings.host.clone(), + port: mirror_settings.port, + role: server.role, + address_index: mirror_idx, + replica_number, + shard: shard_idx.parse::().unwrap(), + username: user.username.clone(), + pool_name: pool_name.clone(), + mirrors: vec![], + stats: Arc::new(AddressStats::default()), + }); + address_id += 1; + } + } + + let address = Address { + id: address_id, + database: shard.database.clone(), + host: server.host.clone(), + port: server.port, + role: server.role, + address_index, + replica_number, + shard: shard_idx.parse::().unwrap(), + username: user.username.clone(), + pool_name: pool_name.clone(), + mirrors: mirror_addresses, + stats: Arc::new(AddressStats::default()), + }; + + address_id += 1; + + if server.role == Role::Replica { + replica_number += 1; + } + + // We assume every server in the pool share user/passwords + let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); + + if let Some(apt) = &auth_passthrough { + match apt.fetch_hash(&address).await { + Ok(ok) => { + if let Some(ref 
pool_auth_hash_value) = *(pool_auth_hash.read()) { + if ok != *pool_auth_hash_value { + warn!("Hash is not the same across shards of the same pool, client auth will \ + be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); + } + } + debug!("Hash obtained for {:?}", address); + { + let mut pool_auth_hash = pool_auth_hash.write(); + *pool_auth_hash = Some(ok.clone()); + } + }, + Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err), + } + } + + let manager = ServerPool::new( + address.clone(), + user.clone(), + &shard.database, + client_server_map.clone(), + pool_stats.clone(), + pool_auth_hash.clone(), + ); + + let connect_timeout = match pool_config.connect_timeout { + Some(connect_timeout) => connect_timeout, + None => config.general.connect_timeout, + }; + + let idle_timeout = match pool_config.idle_timeout { + Some(idle_timeout) => idle_timeout, + None => config.general.idle_timeout, + }; + + let pool = Pool::builder() + .max_size(user.pool_size) + .connection_timeout(std::time::Duration::from_millis(connect_timeout)) + .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) + .test_on_check_out(false) + .build(manager) + .await + .unwrap(); + + pools.push(pool); + servers.push(address); + } + + shards.push(pools); + addresses.push(servers); + banlist.push(HashMap::new()); + } + + assert_eq!(shards.len(), addresses.len()); + if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { + info!( + "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", + pool_name, user.username + ); + } + + let pool = ConnectionPool { + databases: shards, + stats: pool_stats, + addresses, + banlist: Arc::new(RwLock::new(banlist)), + config_hash: new_pool_hash_value, + server_info: Arc::new(RwLock::new(BytesMut::new())), + auth_hash: pool_auth_hash, + settings: PoolSettings { + pool_mode: pool_config.pool_mode, + load_balancing_mode: 
pool_config.load_balancing_mode, + // shards: pool_config.shards.clone(), + shards: shard_ids.len(), + user: user.clone(), + default_role: match pool_config.default_role.as_str() { + "any" => None, + "replica" => Some(Role::Replica), + "primary" => Some(Role::Primary), + _ => unreachable!(), + }, + query_parser_enabled: pool_config.query_parser_enabled, + primary_reads_enabled: pool_config.primary_reads_enabled, + sharding_function: pool_config.sharding_function, + automatic_sharding_key: pool_config.automatic_sharding_key.clone(), + healthcheck_delay: config.general.healthcheck_delay, + healthcheck_timeout: config.general.healthcheck_timeout, + ban_time: config.general.ban_time, + sharding_key_regex: pool_config + .sharding_key_regex + .clone() + .map(|regex| Regex::new(regex.as_str()).unwrap()), + shard_id_regex: pool_config + .shard_id_regex + .clone() + .map(|regex| Regex::new(regex.as_str()).unwrap()), + regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), + auth_query: pool_config.auth_query.clone(), + auth_query_user: pool_config.auth_query_user.clone(), + auth_query_password: pool_config.auth_query_password.clone(), }, - query_parser_enabled: pool_config.query_parser_enabled, - primary_reads_enabled: pool_config.primary_reads_enabled, - sharding_function: pool_config.sharding_function, - automatic_sharding_key: pool_config.automatic_sharding_key.clone(), - healthcheck_delay: config.general.healthcheck_delay, - healthcheck_timeout: config.general.healthcheck_timeout, - ban_time: config.general.ban_time, - sharding_key_regex: pool_config - .sharding_key_regex - .clone() - .map(|regex| Regex::new(regex.as_str()).unwrap()), - shard_id_regex: pool_config - .shard_id_regex - .clone() - .map(|regex| Regex::new(regex.as_str()).unwrap()), - regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), - auth_query: pool_config.auth_query.clone(), - auth_query_user: pool_config.auth_query_user.clone(), - auth_query_password: 
pool_config.auth_query_password.clone(), - }, - validated: Arc::new(AtomicBool::new(false)), - paused: Arc::new(AtomicBool::new(false)), - paused_waiter: Arc::new(Notify::new()), - }; + validated: Arc::new(AtomicBool::new(false)), + paused: Arc::new(AtomicBool::new(false)), + paused_waiter: Arc::new(Notify::new()), + }; - // Connect to the servers to make sure pool configuration is valid - // before setting it globally. - // Do this async and somewhere else, we don't have to wait here. - let mut validate_pool = pool.clone(); - tokio::task::spawn(async move { - let _ = validate_pool.validate().await; - }); + // Connect to the servers to make sure pool configuration is valid + // before setting it globally. + // Do this async and somewhere else, we don't have to wait here. + let mut validate_pool = pool.clone(); + tokio::task::spawn(async move { + let _ = validate_pool.validate().await; + }); - // There is one pool per database/user pair. - new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); + // There is one pool per database/user pair. + new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool); + } } } @@ -924,10 +938,10 @@ impl ManageConnection for ServerPool { } /// Get the connection pool -pub fn get_pool(db: &str, user: &str) -> Option { - (*(*POOLS.load())) - .get(&PoolIdentifier::new(db, user)) - .cloned() +pub fn get_pool(db: &str, user: &str, secret: Option) -> Option { + let identifier = PoolIdentifier::new(db, user, secret); + + (*(*POOLS.load())).get(&identifier).cloned() } /// Get a pointer to all configured pools.