Zero-downtime password rotation

This commit is contained in:
Lev Kokotov
2023-03-30 11:55:27 -07:00
parent 6f768a84ce
commit 5c673b4333
10 changed files with 672 additions and 407 deletions

View File

@@ -58,9 +58,9 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections
# tls_certificate = "server.cert"
tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key"
tls_private_key = ".circleci/server.key"
# User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -129,6 +129,10 @@ connect_timeout = 3000
username = "sharding_user"
# Postgresql password
password = "sharding_user"
# Passwords the client can use to connect. Useful for password rotations.
secrets = [ "secret_one", "secret_two" ]
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.

View File

@@ -780,7 +780,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
match get_pool(database, user, None) {
Some(pool) => {
pool.pause();
@@ -827,7 +827,7 @@ where
let database = parts[0];
let user = parts[1];
match get_pool(database, user) {
match get_pool(database, user, None) {
Some(pool) => {
pool.resume();

382
src/auth.rs Normal file
View File

@@ -0,0 +1,382 @@
//! Module implementing various client authentication mechanisms.
//!
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
use crate::errors::Error;
use crate::tokio::io::AsyncReadExt;
use crate::{
config::get_config,
messages::{error_response, md5_hash_password, write_all, wrong_password, md5_hash_second_pass},
pool::{get_pool, ConnectionPool},
auth_passthrough::AuthPassthrough,
};
use bytes::{BufMut, BytesMut};
use log::debug;
async fn refetch_auth_hash<S>(pool: &ConnectionPool, stream: &mut S, username: &str, pool_name: &str) -> Result<String, Error>
where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send
{
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
error_response(
stream,
&format!(
"No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username
),
).await?;
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
/// Read 'p' message from client.
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
return Err(Error::SocketError(
"Error reading password code from client".to_string(),
))
}
};
if code as char != 'p' {
return Err(Error::SocketError(format!("Expected p, got {}", code)));
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::SocketError(
"Error reading password length from client".to_string(),
))
}
};
let mut response = vec![0; (len - 4) as usize];
match stream.read_exact(&mut response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::SocketError(
"Error reading password from client".to_string(),
))
}
};
Ok(response.to_vec())
}
/// Make sure the pool we authenticated to has at least one server connection
/// that can serve our request.
async fn validate_pool<W>(
stream: &mut W,
mut pool: ConnectionPool,
username: &str,
pool_name: &str,
) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
if !pool.validated() {
match pool.validate().await {
Ok(_) => Ok(()),
Err(err) => {
error_response(
stream,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username,
),
)
.await?;
Err(Error::ClientError(format!("Pool down: {:?}", err)))
}
}
} else {
Ok(())
}
}
/// Clear text authentication.
///
/// The client will send the password in plain text over the wire.
/// To protect against obvious security issues, this is only used over TLS.
///
/// Clear text authentication is used to support zero-downtime password rotation.
/// It allows the client to use multiple passwords when talking to the PgCat
/// while the password is being rotated across multiple app instances.
pub struct ClearText {
username: String,
pool_name: String,
application_name: String,
}
impl ClearText {
/// Create a new ClearText authentication mechanism.
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
ClearText {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
}
}
/// Issue 'R' clear text challenge to client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
debug!("Sending plain challenge");
let mut msg = BytesMut::new();
msg.put_u8(b'R');
msg.put_i32(8);
msg.put_i32(3); // Clear text
write_all(stream, msg).await
}
/// Authenticate client with server password or secret.
pub async fn authenticate<R, W>(
&self,
read: &mut R,
write: &mut W,
) -> Result<Option<String>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let response = response(read).await?;
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
None => match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match pool.settings.user.password {
Some(ref password) => {
if password != &secret {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None)
}
}
None => {
// Server is storing hashes, we can't query it for the plain text password.
error_response(
write,
&format!(
"No server password configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"No server password configured for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
},
Some(pool) => {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(Some(secret))
}
}
}
}
/// MD5 hash authentication.
///
/// Deprecated, but widely used everywhere, and currently required for poolers
/// to authencticate clients without involving Postgres.
///
/// Admin clients are required to use MD5.
pub struct Md5 {
username: String,
pool_name: String,
application_name: String,
salt: [u8; 4],
admin: bool,
}
impl Md5 {
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
Md5 {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
salt,
admin,
}
}
/// Issue a 'R' MD5 challenge to the client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&self.salt[..]);
write_all(stream, res).await
}
/// Authenticate client with MD5. This is used for both admin and normal users.
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let password_hash = response(read).await?;
if self.admin {
let config = get_config();
// Compare server and client hashes.
let our_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&self.salt,
);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
Ok(())
}
} else {
match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match &pool.settings.user.password {
Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(())
}
}
None => {
// Fetch hash from server
let hash = (*pool.auth_hash.read()).clone();
let hash = match hash {
Some(hash) => hash.to_string(),
None => refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?,
};
let our_hash = md5_hash_second_pass(&hash, &self.salt);
// Compare hashes
if our_hash != password_hash {
// Server hash maybe changed
let hash = refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
} else {
(*pool.auth_hash.write()) = Some(hash);
validate_pool(write, pool.clone(), &self.username, &self.pool_name).await?;
Ok(())
}
} else {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name
)));
}
}
}
}
}

View File

@@ -73,6 +73,7 @@ impl AuthPassthrough {
password: Some(self.password.clone()),
pool_size: 1,
statement_timeout: 0,
secrets: None,
};
let user = &address.username;

View File

@@ -2,7 +2,7 @@ use crate::errors::Error;
use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use log::{debug, error, info, trace};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
@@ -90,6 +90,9 @@ pub struct Client<S, T> {
/// Application name for this client (defaults to pgcat)
application_name: String,
/// Which secret the user is using to connect, if any.
secret: Option<String>,
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
}
@@ -290,7 +293,7 @@ pub async fn client_entrypoint(
/// Handle the first message the client sends.
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
{
// Get startup message length.
let len = match stream.read_i32().await {
@@ -377,24 +380,10 @@ pub async fn startup_tls(
}
}
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
impl<S, T> Client<S, T>
where
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
pub fn is_admin(&self) -> bool {
self.admin
@@ -457,161 +446,39 @@ where
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let config = get_config();
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
let secret = if admin {
debug!("Using md5 auth for admin");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else if !config.tls_enabled() {
debug!("Using md5 auth");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else {
debug!("Using plain auth");
let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?
};
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
// Authenticate admin user.
// Authenticated admin user.
let (transaction_mode, server_info) = if admin {
let config = get_config();
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
(false, generate_server_info_for_admin())
}
// Authenticate normal user.
// Authenticated normal user.
else {
let mut pool = match get_pool(pool_name, username) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username)));
}
let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() {
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone());
}
hash = Some(fetched_hash);
}
Err(err) => {
return Err(
Error::ClientError(
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}",
username,
pool_name,
application_name,
err)
)
);
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name);
let fetched_hash = refetch_auth_hash(&pool).await?;
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
}
} else {
wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
}
let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
// If the pool hasn't been validated yet,
// connect to the servers and figure out what's what.
if !pool.validated() {
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error_response(
&mut write,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Pool down: {:?}", err)));
}
}
}
(transaction_mode, pool.server_info())
};
debug!("Password authentication successful");
debug!("Authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
@@ -619,7 +486,7 @@ where
ready_for_query(&mut write).await?;
trace!("Startup OK");
let pool_stats = match get_pool(pool_name, username) {
let pool_stats = match get_pool(pool_name, username, secret.clone()) {
Some(pool) => {
if !admin {
pool.stats
@@ -659,6 +526,7 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
secret,
})
}
@@ -693,6 +561,7 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
secret: None,
})
}
@@ -1200,7 +1069,7 @@ where
/// Retrieve connection pool, if it exists.
/// Return an error to the client otherwise.
async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
match get_pool(&self.pool_name, &self.username) {
match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
Some(pool) => Ok(pool),
None => {
error_response(

View File

@@ -181,6 +181,13 @@ pub struct User {
pub pool_size: u32,
#[serde(default)] // 0
pub statement_timeout: u64,
pub secrets: Option<Vec<String>>,
}
impl User {
fn validate(&self) -> Result<(), Error> {
Ok(())
}
}
impl Default for User {
@@ -190,6 +197,7 @@ impl Default for User {
password: None,
pool_size: 15,
statement_timeout: 0,
secrets: None,
}
}
}
@@ -508,6 +516,10 @@ impl Pool {
None => None,
};
for user in self.users.iter() {
user.1.validate()?;
}
Ok(())
}
}
@@ -657,6 +669,11 @@ impl Config {
}
}
}
/// Checks that we configured TLS.
pub fn tls_enabled(&self) -> bool {
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
}
}
impl Default for Config {

View File

@@ -62,6 +62,7 @@ use tokio::sync::broadcast;
mod admin;
mod auth_passthrough;
mod auth;
mod client;
mod config;
mod constants;

View File

@@ -46,29 +46,6 @@ where
write_all(stream, auth_ok).await
}
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated
/// used in query cancellation.
pub async fn backend_key_data<S>(

View File

@@ -34,7 +34,7 @@ impl MirroredClient {
None => (default, default, crate::config::Pool::default()),
};
let identifier = PoolIdentifier::new(&self.database, &self.user.username);
let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
let manager = ServerPool::new(
self.address.clone(),

View File

@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
/// The username the client connects with. Each user gets its own pool.
pub user: String,
/// The client secret (password).
pub secret: Option<String>,
}
impl PoolIdentifier {
/// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier {
pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
PoolIdentifier {
db: db.to_string(),
user: user.to_string(),
secret,
}
}
}
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
@@ -210,224 +208,240 @@ impl ConnectionPool {
// There is one pool per database/user pair.
for user in pool_config.users.values() {
let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username);
let mut secrets = match &user.secrets {
Some(_) => user
.secrets
.as_ref()
.unwrap()
.iter()
.map(|secret| Some(secret.to_string()))
.collect::<Vec<Option<String>>>(),
None => vec![],
};
secrets.push(None);
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
for secret in secrets {
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
None => (),
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
}
}
}
@@ -924,10 +938,10 @@ impl ManageConnection for ServerPool {
}
/// Get the connection pool
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
(*(*POOLS.load()))
.get(&PoolIdentifier::new(db, user))
.cloned()
pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
let identifier = PoolIdentifier::new(db, user, secret);
(*(*POOLS.load())).get(&identifier).cloned()
}
/// Get a pointer to all configured pools.