mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-26 18:36:28 +00:00
Zero-downtime password rotation
This commit is contained in:
444
src/pool.rs
444
src/pool.rs
@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
|
||||
|
||||
/// The username the client connects with. Each user gets its own pool.
|
||||
pub user: String,
|
||||
|
||||
/// The client secret (password).
|
||||
pub secret: Option<String>,
|
||||
}
|
||||
|
||||
impl PoolIdentifier {
|
||||
/// Create a new user/pool identifier.
|
||||
pub fn new(db: &str, user: &str) -> PoolIdentifier {
|
||||
pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
|
||||
PoolIdentifier {
|
||||
db: db.to_string(),
|
||||
user: user.to_string(),
|
||||
secret,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Address> for PoolIdentifier {
|
||||
fn from(address: &Address) -> PoolIdentifier {
|
||||
PoolIdentifier::new(&address.database, &address.username)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pool settings.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PoolSettings {
|
||||
@@ -210,224 +208,240 @@ impl ConnectionPool {
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
for user in pool_config.users.values() {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
let mut secrets = match &user.secrets {
|
||||
Some(_) => user
|
||||
.secrets
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|secret| Some(secret.to_string()))
|
||||
.collect::<Vec<Option<String>>>(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
secrets.push(None);
|
||||
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
for secret in secrets {
|
||||
|
||||
info!(
|
||||
"[pool: {}][user: {}] creating new pool",
|
||||
pool_name, user.username
|
||||
);
|
||||
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
|
||||
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
let mut shard_ids = pool_config
|
||||
.shards
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut replica_number = 0;
|
||||
|
||||
// Load Mirror settings
|
||||
for (address_index, server) in shard.servers.iter().enumerate() {
|
||||
let mut mirror_addresses = vec![];
|
||||
if let Some(mirror_settings_vec) = &shard.mirrors {
|
||||
for (mirror_idx, mirror_settings) in
|
||||
mirror_settings_vec.iter().enumerate()
|
||||
{
|
||||
if mirror_settings.mirroring_target_index != address_index {
|
||||
continue;
|
||||
}
|
||||
mirror_addresses.push(Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: mirror_settings.host.clone(),
|
||||
port: mirror_settings.port,
|
||||
role: server.role,
|
||||
address_index: mirror_idx,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
});
|
||||
address_id += 1;
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let address = Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: server.host.clone(),
|
||||
port: server.port,
|
||||
role: server.role,
|
||||
address_index,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
|
||||
if server.role == Role::Replica {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
// We assume every server in the pool share user/passwords
|
||||
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
Some(connect_timeout) => connect_timeout,
|
||||
None => config.general.connect_timeout,
|
||||
};
|
||||
|
||||
let idle_timeout = match pool_config.idle_timeout {
|
||||
Some(idle_timeout) => idle_timeout,
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
None => (),
|
||||
}
|
||||
|
||||
shards.push(pools);
|
||||
addresses.push(servers);
|
||||
banlist.push(HashMap::new());
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
|
||||
info!(
|
||||
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
|
||||
"[pool: {}][user: {}] creating new pool",
|
||||
pool_name, user.username
|
||||
);
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
let mut shard_ids = pool_config
|
||||
.shards
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut replica_number = 0;
|
||||
|
||||
// Load Mirror settings
|
||||
for (address_index, server) in shard.servers.iter().enumerate() {
|
||||
let mut mirror_addresses = vec![];
|
||||
if let Some(mirror_settings_vec) = &shard.mirrors {
|
||||
for (mirror_idx, mirror_settings) in
|
||||
mirror_settings_vec.iter().enumerate()
|
||||
{
|
||||
if mirror_settings.mirroring_target_index != address_index {
|
||||
continue;
|
||||
}
|
||||
mirror_addresses.push(Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: mirror_settings.host.clone(),
|
||||
port: mirror_settings.port,
|
||||
role: server.role,
|
||||
address_index: mirror_idx,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
});
|
||||
address_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let address = Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: server.host.clone(),
|
||||
port: server.port,
|
||||
role: server.role,
|
||||
address_index,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
|
||||
if server.role == Role::Replica {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
// We assume every server in the pool share user/passwords
|
||||
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
Some(connect_timeout) => connect_timeout,
|
||||
None => config.general.connect_timeout,
|
||||
};
|
||||
|
||||
let idle_timeout = match pool_config.idle_timeout {
|
||||
Some(idle_timeout) => idle_timeout,
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
}
|
||||
|
||||
shards.push(pools);
|
||||
addresses.push(servers);
|
||||
banlist.push(HashMap::new());
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
|
||||
info!(
|
||||
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
|
||||
pool_name, user.username
|
||||
);
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
healthcheck_delay: config.general.healthcheck_delay,
|
||||
healthcheck_timeout: config.general.healthcheck_timeout,
|
||||
ban_time: config.general.ban_time,
|
||||
sharding_key_regex: pool_config
|
||||
.sharding_key_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
shard_id_regex: pool_config
|
||||
.shard_id_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
healthcheck_delay: config.general.healthcheck_delay,
|
||||
healthcheck_timeout: config.general.healthcheck_timeout,
|
||||
ban_time: config.general.ban_time,
|
||||
sharding_key_regex: pool_config
|
||||
.sharding_key_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
shard_id_regex: pool_config
|
||||
.shard_id_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
},
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
paused_waiter: Arc::new(Notify::new()),
|
||||
};
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
paused_waiter: Arc::new(Notify::new()),
|
||||
};
|
||||
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,10 +938,10 @@ impl ManageConnection for ServerPool {
|
||||
}
|
||||
|
||||
/// Get the connection pool
|
||||
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
|
||||
(*(*POOLS.load()))
|
||||
.get(&PoolIdentifier::new(db, user))
|
||||
.cloned()
|
||||
pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
|
||||
let identifier = PoolIdentifier::new(db, user, secret);
|
||||
|
||||
(*(*POOLS.load())).get(&identifier).cloned()
|
||||
}
|
||||
|
||||
/// Get a pointer to all configured pools.
|
||||
|
||||
Reference in New Issue
Block a user