Zero-downtime password rotation

This commit is contained in:
Lev Kokotov
2023-03-30 11:55:27 -07:00
parent 6f768a84ce
commit 5c673b4333
10 changed files with 672 additions and 407 deletions

View File

@@ -58,9 +58,9 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5 tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections # Path to TLS Certficate file to use for TLS connections
# tls_certificate = "server.cert" tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key" tls_private_key = ".circleci/server.key"
# User name to access the virtual administrative database (pgbouncer or pgcat) # User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -129,6 +129,10 @@ connect_timeout = 3000
username = "sharding_user" username = "sharding_user"
# Postgresql password # Postgresql password
password = "sharding_user" password = "sharding_user"
# Passwords the client can use to connect. Useful for password rotations.
secrets = [ "secret_one", "secret_two" ]
# Maximum number of server connections that can be established for this user # Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster # The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users. # is the sum of pool_size across all users.

View File

@@ -780,7 +780,7 @@ where
let database = parts[0]; let database = parts[0];
let user = parts[1]; let user = parts[1];
match get_pool(database, user) { match get_pool(database, user, None) {
Some(pool) => { Some(pool) => {
pool.pause(); pool.pause();
@@ -827,7 +827,7 @@ where
let database = parts[0]; let database = parts[0];
let user = parts[1]; let user = parts[1];
match get_pool(database, user) { match get_pool(database, user, None) {
Some(pool) => { Some(pool) => {
pool.resume(); pool.resume();

382
src/auth.rs Normal file
View File

@@ -0,0 +1,382 @@
//! Module implementing various client authentication mechanisms.
//!
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
use crate::errors::Error;
use crate::tokio::io::AsyncReadExt;
use crate::{
config::get_config,
messages::{error_response, md5_hash_password, write_all, wrong_password, md5_hash_second_pass},
pool::{get_pool, ConnectionPool},
auth_passthrough::AuthPassthrough,
};
use bytes::{BufMut, BytesMut};
use log::debug;
async fn refetch_auth_hash<S>(pool: &ConnectionPool, stream: &mut S, username: &str, pool_name: &str) -> Result<String, Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send
{
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
error_response(
stream,
&format!(
"No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username
),
).await?;
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
/// Read 'p' message from client.
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
return Err(Error::SocketError(
"Error reading password code from client".to_string(),
))
}
};
if code as char != 'p' {
return Err(Error::SocketError(format!("Expected p, got {}", code)));
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::SocketError(
"Error reading password length from client".to_string(),
))
}
};
let mut response = vec![0; (len - 4) as usize];
match stream.read_exact(&mut response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::SocketError(
"Error reading password from client".to_string(),
))
}
};
Ok(response.to_vec())
}
/// Make sure the pool we authenticated to has at least one server connection
/// that can serve our request.
async fn validate_pool<W>(
stream: &mut W,
mut pool: ConnectionPool,
username: &str,
pool_name: &str,
) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
if !pool.validated() {
match pool.validate().await {
Ok(_) => Ok(()),
Err(err) => {
error_response(
stream,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username,
),
)
.await?;
Err(Error::ClientError(format!("Pool down: {:?}", err)))
}
}
} else {
Ok(())
}
}
/// Clear text authentication.
///
/// The client will send the password in plain text over the wire.
/// To protect against obvious security issues, this is only used over TLS.
///
/// Clear text authentication is used to support zero-downtime password rotation.
/// It allows the client to use multiple passwords when talking to the PgCat
/// while the password is being rotated across multiple app instances.
pub struct ClearText {
username: String,
pool_name: String,
application_name: String,
}
impl ClearText {
/// Create a new ClearText authentication mechanism.
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
ClearText {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
}
}
/// Issue 'R' clear text challenge to client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
debug!("Sending plain challenge");
let mut msg = BytesMut::new();
msg.put_u8(b'R');
msg.put_i32(8);
msg.put_i32(3); // Clear text
write_all(stream, msg).await
}
/// Authenticate client with server password or secret.
pub async fn authenticate<R, W>(
&self,
read: &mut R,
write: &mut W,
) -> Result<Option<String>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let response = response(read).await?;
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
None => match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match pool.settings.user.password {
Some(ref password) => {
if password != &secret {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None)
}
}
None => {
// Server is storing hashes, we can't query it for the plain text password.
error_response(
write,
&format!(
"No server password configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"No server password configured for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
},
Some(pool) => {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(Some(secret))
}
}
}
}
/// MD5 hash authentication.
///
/// Deprecated, but widely used everywhere, and currently required for poolers
/// to authencticate clients without involving Postgres.
///
/// Admin clients are required to use MD5.
pub struct Md5 {
username: String,
pool_name: String,
application_name: String,
salt: [u8; 4],
admin: bool,
}
impl Md5 {
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
Md5 {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
salt,
admin,
}
}
/// Issue a 'R' MD5 challenge to the client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&self.salt[..]);
write_all(stream, res).await
}
/// Authenticate client with MD5. This is used for both admin and normal users.
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let password_hash = response(read).await?;
if self.admin {
let config = get_config();
// Compare server and client hashes.
let our_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&self.salt,
);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
Ok(())
}
} else {
match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match &pool.settings.user.password {
Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(())
}
}
None => {
// Fetch hash from server
let hash = (*pool.auth_hash.read()).clone();
let hash = match hash {
Some(hash) => hash.to_string(),
None => refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?,
};
let our_hash = md5_hash_second_pass(&hash, &self.salt);
// Compare hashes
if our_hash != password_hash {
// Server hash maybe changed
let hash = refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
(*pool.auth_hash.write()) = Some(hash);
validate_pool(write, pool.clone(), &self.username, &self.pool_name).await?;
Ok(())
}
} else {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)));
}
}
}
}
}

View File

@@ -73,6 +73,7 @@ impl AuthPassthrough {
password: Some(self.password.clone()), password: Some(self.password.clone()),
pool_size: 1, pool_size: 1,
statement_timeout: 0, statement_timeout: 0,
secrets: None,
}; };
let user = &address.username; let user = &address.username;

View File

@@ -2,7 +2,7 @@ use crate::errors::Error;
use crate::pool::BanReason; use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server. /// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@@ -90,6 +90,9 @@ pub struct Client<S, T> {
/// Application name for this client (defaults to pgcat) /// Application name for this client (defaults to pgcat)
application_name: String, application_name: String,
/// Which secret the user is using to connect, if any.
secret: Option<String>,
/// Used to notify clients about an impending shutdown /// Used to notify clients about an impending shutdown
shutdown: Receiver<()>, shutdown: Receiver<()>,
} }
@@ -290,7 +293,7 @@ pub async fn client_entrypoint(
/// Handle the first message the client sends. /// Handle the first message the client sends.
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error> async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite, S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
{ {
// Get startup message length. // Get startup message length.
let len = match stream.read_i32().await { let len = match stream.read_i32().await {
@@ -377,24 +380,10 @@ pub async fn startup_tls(
} }
} }
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
impl<S, T> Client<S, T> impl<S, T> Client<S, T>
where where
S: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{ {
pub fn is_admin(&self) -> bool { pub fn is_admin(&self) -> bool {
self.admin self.admin
@@ -457,161 +446,39 @@ where
let process_id: i32 = rand::random(); let process_id: i32 = rand::random();
let secret_key: i32 = rand::random(); let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let config = get_config(); let config = get_config();
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response { let secret = if admin {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); debug!("Using md5 auth for admin");
wrong_password(&mut write, username).await?; let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
auth.challenge(&mut write).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); auth.authenticate(&mut read, &mut write).await?;
} None
} else if !config.tls_enabled() {
debug!("Using md5 auth");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else {
debug!("Using plain auth");
let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?
};
// Authenticated admin user.
let (transaction_mode, server_info) = if admin {
(false, generate_server_info_for_admin()) (false, generate_server_info_for_admin())
} }
// Authenticate normal user. // Authenticated normal user.
else { else {
let mut pool = match get_pool(pool_name, username) { let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username)));
}
let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() {
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone());
}
hash = Some(fetched_hash);
}
Err(err) => {
return Err(
Error::ClientError(
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}",
username,
pool_name,
application_name,
err)
)
);
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name);
let fetched_hash = refetch_auth_hash(&pool).await?;
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
}
} else {
wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
// If the pool hasn't been validated yet,
// connect to the servers and figure out what's what.
if !pool.validated() {
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error_response(
&mut write,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Pool down: {:?}", err)));
}
}
}
(transaction_mode, pool.server_info()) (transaction_mode, pool.server_info())
}; };
debug!("Password authentication successful"); debug!("Authentication successful");
auth_ok(&mut write).await?; auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?; write_all(&mut write, server_info).await?;
@@ -619,7 +486,7 @@ where
ready_for_query(&mut write).await?; ready_for_query(&mut write).await?;
trace!("Startup OK"); trace!("Startup OK");
let pool_stats = match get_pool(pool_name, username) { let pool_stats = match get_pool(pool_name, username, secret.clone()) {
Some(pool) => { Some(pool) => {
if !admin { if !admin {
pool.stats pool.stats
@@ -659,6 +526,7 @@ where
application_name: application_name.to_string(), application_name: application_name.to_string(),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
secret,
}) })
} }
@@ -693,6 +561,7 @@ where
application_name: String::from("undefined"), application_name: String::from("undefined"),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
secret: None,
}) })
} }
@@ -1200,7 +1069,7 @@ where
/// Retrieve connection pool, if it exists. /// Retrieve connection pool, if it exists.
/// Return an error to the client otherwise. /// Return an error to the client otherwise.
async fn get_pool(&mut self) -> Result<ConnectionPool, Error> { async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
match get_pool(&self.pool_name, &self.username) { match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
Some(pool) => Ok(pool), Some(pool) => Ok(pool),
None => { None => {
error_response( error_response(

View File

@@ -181,6 +181,13 @@ pub struct User {
pub pool_size: u32, pub pool_size: u32,
#[serde(default)] // 0 #[serde(default)] // 0
pub statement_timeout: u64, pub statement_timeout: u64,
pub secrets: Option<Vec<String>>,
}
impl User {
fn validate(&self) -> Result<(), Error> {
Ok(())
}
} }
impl Default for User { impl Default for User {
@@ -190,6 +197,7 @@ impl Default for User {
password: None, password: None,
pool_size: 15, pool_size: 15,
statement_timeout: 0, statement_timeout: 0,
secrets: None,
} }
} }
} }
@@ -508,6 +516,10 @@ impl Pool {
None => None, None => None,
}; };
for user in self.users.iter() {
user.1.validate()?;
}
Ok(()) Ok(())
} }
} }
@@ -657,6 +669,11 @@ impl Config {
} }
} }
} }
/// Checks that we configured TLS.
pub fn tls_enabled(&self) -> bool {
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
}
} }
impl Default for Config { impl Default for Config {

View File

@@ -62,6 +62,7 @@ use tokio::sync::broadcast;
mod admin; mod admin;
mod auth_passthrough; mod auth_passthrough;
mod auth;
mod client; mod client;
mod config; mod config;
mod constants; mod constants;

View File

@@ -46,29 +46,6 @@ where
write_all(stream, auth_ok).await write_all(stream, auth_ok).await
} }
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated /// Give the client the process_id and secret we generated
/// used in query cancellation. /// used in query cancellation.
pub async fn backend_key_data<S>( pub async fn backend_key_data<S>(

View File

@@ -34,7 +34,7 @@ impl MirroredClient {
None => (default, default, crate::config::Pool::default()), None => (default, default, crate::config::Pool::default()),
}; };
let identifier = PoolIdentifier::new(&self.database, &self.user.username); let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
let manager = ServerPool::new( let manager = ServerPool::new(
self.address.clone(), self.address.clone(),

View File

@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
/// The username the client connects with. Each user gets its own pool. /// The username the client connects with. Each user gets its own pool.
pub user: String, pub user: String,
/// The client secret (password).
pub secret: Option<String>,
} }
impl PoolIdentifier { impl PoolIdentifier {
/// Create a new user/pool identifier. /// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier { pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
PoolIdentifier { PoolIdentifier {
db: db.to_string(), db: db.to_string(),
user: user.to_string(), user: user.to_string(),
secret,
} }
} }
} }
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings. /// Pool settings.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PoolSettings { pub struct PoolSettings {
@@ -210,8 +208,23 @@ impl ConnectionPool {
// There is one pool per database/user pair. // There is one pool per database/user pair.
for user in pool_config.users.values() { for user in pool_config.users.values() {
let old_pool_ref = get_pool(pool_name, &user.username); let mut secrets = match &user.secrets {
let identifier = PoolIdentifier::new(pool_name, &user.username); Some(_) => user
.secrets
.as_ref()
.unwrap()
.iter()
.map(|secret| Some(secret.to_string()))
.collect::<Vec<Option<String>>>(),
None => vec![],
};
secrets.push(None);
for secret in secrets {
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
match old_pool_ref { match old_pool_ref {
Some(pool) => { Some(pool) => {
@@ -427,7 +440,8 @@ impl ConnectionPool {
}); });
// There is one pool per database/user pair. // There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
}
} }
} }
@@ -924,10 +938,10 @@ impl ManageConnection for ServerPool {
} }
/// Get the connection pool /// Get the connection pool
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> { pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
(*(*POOLS.load())) let identifier = PoolIdentifier::new(db, user, secret);
.get(&PoolIdentifier::new(db, user))
.cloned() (*(*POOLS.load())).get(&identifier).cloned()
} }
/// Get a pointer to all configured pools. /// Get a pointer to all configured pools.