Prevent clients from sticking to old pools after config update (#113)

* Re-acquire pool at the beginning of Protocol loop

* Fix query router + add tests for recycling behavior
This commit is contained in:
Mostafa Abdelraouf
2022-08-09 14:18:27 -05:00
committed by GitHub
parent 3719c22322
commit 7592339092
6 changed files with 168 additions and 38 deletions

View File

@@ -11,7 +11,7 @@ use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
@@ -73,8 +73,13 @@ pub struct Client<S, T> {
/// Last server process id we talked to.
last_server_id: Option<i32>,
target_pool: ConnectionPool,
/// Name of the server pool for this client (This comes from the database name in the connection string)
target_pool_name: String,
/// Postgres user for this client (This comes from the user in the connection string)
target_user_name: String,
/// Used to notify clients about an impending shutdown
shutdown_event_receiver: Receiver<()>,
}
@@ -305,19 +310,19 @@ where
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {
let target_pool_name = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let user = match parameters.get("user") {
let target_user_name = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.filter(|db| *db == &target_pool_name)
.count()
== 1;
@@ -352,7 +357,7 @@ where
Err(_) => return Err(Error::SocketError),
};
let (target_pool, transaction_mode, server_info) = if admin {
let (transaction_mode, server_info) = if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();
@@ -360,23 +365,20 @@ where
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, target_user_name).await?;
return Err(Error::ClientError);
}
(
ConnectionPool::default(),
false,
generate_server_info_for_admin(),
)
(false, generate_server_info_for_admin())
} else {
let target_pool = match get_pool(database.clone(), user.clone()) {
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
target_pool_name, target_user_name
),
)
.await?;
@@ -387,14 +389,14 @@ where
let server_info = target_pool.server_info();
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(user, correct_password, &salt);
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
wrong_password(&mut write, &target_user_name).await?;
return Err(Error::ClientError);
}
(target_pool, transaction_mode, server_info)
(transaction_mode, server_info)
};
debug!("Password authentication successful");
@@ -424,7 +426,8 @@ where
admin: admin,
last_address_id: None,
last_server_id: None,
target_pool: target_pool,
target_pool_name: target_pool_name.clone(),
target_user_name: target_user_name.clone(),
shutdown_event_receiver: shutdown_event_receiver,
});
}
@@ -455,7 +458,8 @@ where
admin: false,
last_address_id: None,
last_server_id: None,
target_pool: ConnectionPool::default(),
target_pool_name: String::from("undefined"),
target_user_name: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver,
});
}
@@ -494,7 +498,7 @@ where
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut query_router = QueryRouter::new();
let mut round_robin = 0;
// Our custom protocol loop.
@@ -520,11 +524,6 @@ where
message_result = read_message(&mut self.read) => message_result?
};
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.target_pool.clone();
// Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' {
debug!("Client disconnecting");
@@ -538,6 +537,25 @@ where
continue;
}
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool =
match get_pool(self.target_pool_name.clone(), self.target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut self.write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.target_pool_name, self.target_user_name
),
)
.await?;
return Err(Error::ClientError);
}
};
query_router.update_pool_settings(pool.settings.clone());
let current_shard = query_router.shard();
// Handle all custom protocol commands, if any.