mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-25 18:06:29 +00:00
Graceful shutdown and refactor (#144)
* Graceful shutdown and refactor * ok * _Graceful_ shutdown * Remove hardcoded setting * clean up * end * timeout * hmm * hmm! * bash * bash * hmm * maybe maybe * Adds tests and move non-admin connection rejection to startup (#145) * Move error response * Adds tests and removes unused variable * Adds debug log Co-authored-by: zainkabani <77307340+zainkabani@users.noreply.github.com>
This commit is contained in:
270
src/client.rs
270
src/client.rs
@@ -5,13 +5,14 @@ use std::collections::HashMap;
|
||||
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::config::{get_config, Address};
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::stats::{get_reporter, Reporter};
|
||||
@@ -58,7 +59,6 @@ pub struct Client<S, T> {
|
||||
client_server_map: ClientServerMap,
|
||||
|
||||
/// Client parameters, e.g. user, client_encoding, etc.
|
||||
#[allow(dead_code)]
|
||||
parameters: HashMap<String, String>,
|
||||
|
||||
/// Statistics
|
||||
@@ -77,20 +77,22 @@ pub struct Client<S, T> {
|
||||
connected_to_server: bool,
|
||||
|
||||
/// Name of the server pool for this client (This comes from the database name in the connection string)
|
||||
target_pool_name: String,
|
||||
pool_name: String,
|
||||
|
||||
/// Postgres user for this client (This comes from the user in the connection string)
|
||||
target_user_name: String,
|
||||
username: String,
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown_event_receiver: Receiver<()>,
|
||||
shutdown: Receiver<()>,
|
||||
}
|
||||
|
||||
/// Client entrypoint.
|
||||
pub async fn client_entrypoint(
|
||||
mut stream: TcpStream,
|
||||
client_server_map: ClientServerMap,
|
||||
shutdown_event_receiver: Receiver<()>,
|
||||
shutdown: Receiver<()>,
|
||||
drain: Sender<i8>,
|
||||
admin_only: bool,
|
||||
) -> Result<(), Error> {
|
||||
// Figure out if the client wants TLS or not.
|
||||
let addr = stream.peer_addr().unwrap();
|
||||
@@ -109,11 +111,21 @@ pub async fn client_entrypoint(
|
||||
write_all(&mut stream, yes).await?;
|
||||
|
||||
// Negotiate TLS.
|
||||
match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
|
||||
match startup_tls(stream, client_server_map, shutdown, admin_only).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (TLS)", addr);
|
||||
|
||||
client.handle().await
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(1).await;
|
||||
}
|
||||
|
||||
let result = client.handle().await;
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
@@ -139,14 +151,25 @@ pub async fn client_entrypoint(
|
||||
addr,
|
||||
bytes,
|
||||
client_server_map,
|
||||
shutdown_event_receiver,
|
||||
shutdown,
|
||||
admin_only,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (plain)", addr);
|
||||
|
||||
client.handle().await
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(1).await;
|
||||
}
|
||||
|
||||
let result = client.handle().await;
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
@@ -169,14 +192,25 @@ pub async fn client_entrypoint(
|
||||
addr,
|
||||
bytes,
|
||||
client_server_map,
|
||||
shutdown_event_receiver,
|
||||
shutdown,
|
||||
admin_only,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} connected (plain)", addr);
|
||||
|
||||
client.handle().await
|
||||
if client.is_admin() {
|
||||
let _ = drain.send(1).await;
|
||||
}
|
||||
|
||||
let result = client.handle().await;
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
@@ -187,20 +221,21 @@ pub async fn client_entrypoint(
|
||||
let (read, write) = split(stream);
|
||||
|
||||
// Continue with cancel query request.
|
||||
match Client::cancel(
|
||||
read,
|
||||
write,
|
||||
addr,
|
||||
bytes,
|
||||
client_server_map,
|
||||
shutdown_event_receiver,
|
||||
)
|
||||
.await
|
||||
{
|
||||
match Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await {
|
||||
Ok(mut client) => {
|
||||
info!("Client {:?} issued a cancel query request", addr);
|
||||
|
||||
client.handle().await
|
||||
if client.is_admin() {
|
||||
let _ = drain.send(1).await;
|
||||
}
|
||||
|
||||
let result = client.handle().await;
|
||||
|
||||
if !client.is_admin() {
|
||||
let _ = drain.send(-1).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
Err(err) => Err(err),
|
||||
@@ -253,7 +288,8 @@ where
|
||||
pub async fn startup_tls(
|
||||
stream: TcpStream,
|
||||
client_server_map: ClientServerMap,
|
||||
shutdown_event_receiver: Receiver<()>,
|
||||
shutdown: Receiver<()>,
|
||||
admin_only: bool,
|
||||
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
|
||||
// Negotiate TLS.
|
||||
let tls = Tls::new()?;
|
||||
@@ -283,7 +319,8 @@ pub async fn startup_tls(
|
||||
addr,
|
||||
bytes,
|
||||
client_server_map,
|
||||
shutdown_event_receiver,
|
||||
shutdown,
|
||||
admin_only,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -298,6 +335,10 @@ where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
pub fn is_admin(&self) -> bool {
|
||||
self.admin
|
||||
}
|
||||
|
||||
/// Handle Postgres client startup after TLS negotiation is complete
|
||||
/// or over plain text.
|
||||
pub async fn startup(
|
||||
@@ -306,29 +347,44 @@ where
|
||||
addr: std::net::SocketAddr,
|
||||
bytes: BytesMut, // The rest of the startup message.
|
||||
client_server_map: ClientServerMap,
|
||||
shutdown_event_receiver: Receiver<()>,
|
||||
shutdown: Receiver<()>,
|
||||
admin_only: bool,
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
let config = get_config();
|
||||
let stats = get_reporter();
|
||||
|
||||
trace!("Got StartupMessage");
|
||||
let parameters = parse_startup(bytes.clone())?;
|
||||
let target_pool_name = match parameters.get("database") {
|
||||
|
||||
// These two parameters are mandatory by the protocol.
|
||||
let pool_name = match parameters.get("database") {
|
||||
Some(db) => db,
|
||||
None => return Err(Error::ClientError),
|
||||
};
|
||||
|
||||
let target_user_name = match parameters.get("user") {
|
||||
let username = match parameters.get("user") {
|
||||
Some(user) => user,
|
||||
None => return Err(Error::ClientError),
|
||||
};
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
.filter(|db| *db == &target_pool_name)
|
||||
.filter(|db| *db == &pool_name)
|
||||
.count()
|
||||
== 1;
|
||||
|
||||
// Kick any client that's not admin while we're in admin-only mode.
|
||||
if !admin && admin_only {
|
||||
debug!(
|
||||
"Rejecting non-admin connection to {} when in admin only mode",
|
||||
pool_name
|
||||
);
|
||||
error_response_terminal(
|
||||
&mut write,
|
||||
&format!("terminating connection due to administrator command"),
|
||||
)
|
||||
.await?;
|
||||
return Err(Error::ShuttingDown);
|
||||
}
|
||||
|
||||
// Generate random backend ID and secret key
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
@@ -360,46 +416,55 @@ where
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let correct_user = config.general.admin_username.as_str();
|
||||
let correct_password = config.general.admin_password.as_str();
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
debug!("Password authentication failed");
|
||||
wrong_password(&mut write, target_user_name).await?;
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError);
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
} else {
|
||||
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
|
||||
}
|
||||
// Authenticate normal user.
|
||||
else {
|
||||
let pool = match get_pool(pool_name.clone(), username.clone()) {
|
||||
Some(pool) => pool,
|
||||
None => {
|
||||
error_response(
|
||||
&mut write,
|
||||
&format!(
|
||||
"No pool configured for database: {:?}, user: {:?}",
|
||||
target_pool_name, target_user_name
|
||||
pool_name, username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Err(Error::ClientError);
|
||||
}
|
||||
};
|
||||
let transaction_mode = target_pool.settings.pool_mode == "transaction";
|
||||
let server_info = target_pool.server_info();
|
||||
|
||||
// Compare server and client hashes.
|
||||
let correct_password = target_pool.settings.user.password.as_str();
|
||||
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);
|
||||
let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt);
|
||||
|
||||
if password_hash != password_response {
|
||||
debug!("Password authentication failed");
|
||||
wrong_password(&mut write, &target_user_name).await?;
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError);
|
||||
}
|
||||
(transaction_mode, server_info)
|
||||
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
|
||||
(transaction_mode, pool.server_info())
|
||||
};
|
||||
|
||||
debug!("Password authentication successful");
|
||||
@@ -411,27 +476,24 @@ where
|
||||
|
||||
trace!("Startup OK");
|
||||
|
||||
// Split the read and write streams
|
||||
// so we can control buffering.
|
||||
|
||||
return Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
addr,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: false,
|
||||
transaction_mode: transaction_mode,
|
||||
process_id: process_id,
|
||||
secret_key: secret_key,
|
||||
client_server_map: client_server_map,
|
||||
transaction_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
client_server_map,
|
||||
parameters: parameters.clone(),
|
||||
stats: stats,
|
||||
admin: admin,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
target_pool_name: target_pool_name.clone(),
|
||||
target_user_name: target_user_name.clone(),
|
||||
shutdown_event_receiver: shutdown_event_receiver,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
});
|
||||
}
|
||||
@@ -443,7 +505,7 @@ where
|
||||
addr: std::net::SocketAddr,
|
||||
mut bytes: BytesMut, // The rest of the startup message.
|
||||
client_server_map: ClientServerMap,
|
||||
shutdown_event_receiver: Receiver<()>,
|
||||
shutdown: Receiver<()>,
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
let process_id = bytes.get_i32();
|
||||
let secret_key = bytes.get_i32();
|
||||
@@ -454,17 +516,17 @@ where
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: true,
|
||||
transaction_mode: false,
|
||||
process_id: process_id,
|
||||
secret_key: secret_key,
|
||||
client_server_map: client_server_map,
|
||||
process_id,
|
||||
secret_key,
|
||||
client_server_map,
|
||||
parameters: HashMap::new(),
|
||||
stats: get_reporter(),
|
||||
admin: false,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
target_pool_name: String::from("undefined"),
|
||||
target_user_name: String::from("undefined"),
|
||||
shutdown_event_receiver: shutdown_event_receiver,
|
||||
pool_name: String::from("undefined"),
|
||||
username: String::from("undefined"),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
});
|
||||
}
|
||||
@@ -486,7 +548,7 @@ where
|
||||
process_id.clone(),
|
||||
secret_key.clone(),
|
||||
address.clone(),
|
||||
port.clone(),
|
||||
*port,
|
||||
),
|
||||
|
||||
// The client doesn't know / got the wrong server,
|
||||
@@ -498,7 +560,7 @@ where
|
||||
// Opens a new separate connection to the server, sends the backend_id
|
||||
// and secret_key and then closes it for security reasons. No other interactions
|
||||
// take place.
|
||||
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
|
||||
return Ok(Server::cancel(&address, port, process_id, secret_key).await?);
|
||||
}
|
||||
|
||||
// The query router determines where the query is going to go,
|
||||
@@ -521,9 +583,19 @@ where
|
||||
// SET SHARDING KEY TO 'bigint';
|
||||
|
||||
let mut message = tokio::select! {
|
||||
_ = self.shutdown_event_receiver.recv() => {
|
||||
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?;
|
||||
return Ok(())
|
||||
_ = self.shutdown.recv() => {
|
||||
if !self.admin {
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("terminating connection due to administrator command")
|
||||
).await?;
|
||||
return Ok(())
|
||||
}
|
||||
|
||||
// Admin clients ignore shutdown.
|
||||
else {
|
||||
read_message(&mut self.read).await?
|
||||
}
|
||||
},
|
||||
message_result = read_message(&mut self.read) => message_result?
|
||||
};
|
||||
@@ -544,15 +616,14 @@ where
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
// when starting a query.
|
||||
let pool = match get_pool(self.target_pool_name.clone(), self.target_user_name.clone())
|
||||
{
|
||||
let pool = match get_pool(self.pool_name.clone(), self.username.clone()) {
|
||||
Some(pool) => pool,
|
||||
None => {
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"No pool configured for database: {:?}, user: {:?}",
|
||||
self.target_pool_name, self.target_user_name
|
||||
self.pool_name, self.username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
@@ -649,14 +720,16 @@ where
|
||||
match message[0] as char {
|
||||
'P' | 'B' | 'E' | 'D' => (),
|
||||
_ => {
|
||||
error!("Could not get connection from pool: {:?}", err);
|
||||
error_response(
|
||||
&mut self.write,
|
||||
"could not get connection from the pool",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
error!("Could not get connection from pool: {:?}", err);
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -739,15 +812,8 @@ where
|
||||
'Q' => {
|
||||
debug!("Sending query to server");
|
||||
|
||||
self.send_and_receive_loop(
|
||||
code,
|
||||
original,
|
||||
server,
|
||||
&address,
|
||||
query_router.shard(),
|
||||
&pool,
|
||||
)
|
||||
.await?;
|
||||
self.send_and_receive_loop(code, original, server, &address, &pool)
|
||||
.await?;
|
||||
|
||||
if !server.in_transaction() {
|
||||
// Report transaction executed statistics.
|
||||
@@ -814,7 +880,6 @@ where
|
||||
self.buffer.clone(),
|
||||
server,
|
||||
&address,
|
||||
query_router.shard(),
|
||||
&pool,
|
||||
)
|
||||
.await?;
|
||||
@@ -836,32 +901,18 @@ where
|
||||
'd' => {
|
||||
// Forward the data to the server,
|
||||
// don't buffer it since it can be rather large.
|
||||
self.send_server_message(
|
||||
server,
|
||||
original,
|
||||
&address,
|
||||
query_router.shard(),
|
||||
&pool,
|
||||
)
|
||||
.await?;
|
||||
self.send_server_message(server, original, &address, &pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// CopyDone or CopyFail
|
||||
// Copy is done, successfully or not.
|
||||
'c' | 'f' => {
|
||||
self.send_server_message(
|
||||
server,
|
||||
original,
|
||||
&address,
|
||||
query_router.shard(),
|
||||
&pool,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let response = self
|
||||
.receive_server_message(server, &address, query_router.shard(), &pool)
|
||||
self.send_server_message(server, original, &address, &pool)
|
||||
.await?;
|
||||
|
||||
let response = self.receive_server_message(server, &address, &pool).await?;
|
||||
|
||||
match write_all_half(&mut self.write, response).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
@@ -910,20 +961,17 @@ where
|
||||
message: BytesMut,
|
||||
server: &mut Server,
|
||||
address: &Address,
|
||||
shard: usize,
|
||||
pool: &ConnectionPool,
|
||||
) -> Result<(), Error> {
|
||||
debug!("Sending {} to server", code);
|
||||
|
||||
self.send_server_message(server, message, &address, shard, &pool)
|
||||
self.send_server_message(server, message, &address, &pool)
|
||||
.await?;
|
||||
|
||||
// Read all data the server has to offer, which can be multiple messages
|
||||
// buffered in 8196 bytes chunks.
|
||||
loop {
|
||||
let response = self
|
||||
.receive_server_message(server, &address, shard, &pool)
|
||||
.await?;
|
||||
let response = self.receive_server_message(server, &address, &pool).await?;
|
||||
|
||||
match write_all_half(&mut self.write, response).await {
|
||||
Ok(_) => (),
|
||||
@@ -949,13 +997,12 @@ where
|
||||
server: &mut Server,
|
||||
message: BytesMut,
|
||||
address: &Address,
|
||||
shard: usize,
|
||||
pool: &ConnectionPool,
|
||||
) -> Result<(), Error> {
|
||||
match server.send(message).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
pool.ban(address, shard, self.process_id);
|
||||
pool.ban(address, self.process_id);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
@@ -965,7 +1012,6 @@ where
|
||||
&mut self,
|
||||
server: &mut Server,
|
||||
address: &Address,
|
||||
shard: usize,
|
||||
pool: &ConnectionPool,
|
||||
) -> Result<BytesMut, Error> {
|
||||
if pool.settings.user.statement_timeout > 0 {
|
||||
@@ -978,7 +1024,7 @@ where
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, shard, self.process_id);
|
||||
pool.ban(address, self.process_id);
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
@@ -993,7 +1039,7 @@ where
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, shard, self.process_id);
|
||||
pool.ban(address, self.process_id);
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
@@ -1002,7 +1048,7 @@ where
|
||||
match server.recv().await {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, shard, self.process_id);
|
||||
pool.ban(address, self.process_id);
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
|
||||
Reference in New Issue
Block a user