Graceful shutdown and refactor (#144)

* Graceful shutdown and refactor

* ok

* _Graceful_ shutdown

* Remove hardcoded setting

* clean up

* end

* timeout

* hmm

* hmm!

* bash

* bash

* hmm

* maybe maybe

* Adds tests and moves non-admin connection rejection to startup (#145)

* Move error response

* Adds tests and removes unused variable

* Adds debug log

Co-authored-by: zainkabani <77307340+zainkabani@users.noreply.github.com>
This commit is contained in:
Lev Kokotov
2022-08-25 06:40:56 -07:00
committed by GitHub
parent c054ff068d
commit 9d84d6f131
9 changed files with 602 additions and 382 deletions

View File

@@ -5,13 +5,14 @@ use std::collections::HashMap;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address};
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
@@ -58,7 +59,6 @@ pub struct Client<S, T> {
client_server_map: ClientServerMap,
/// Client parameters, e.g. user, client_encoding, etc.
#[allow(dead_code)]
parameters: HashMap<String, String>,
/// Statistics
@@ -77,20 +77,22 @@ pub struct Client<S, T> {
connected_to_server: bool,
/// Name of the server pool for this client (This comes from the database name in the connection string)
target_pool_name: String,
pool_name: String,
/// Postgres user for this client (This comes from the user in the connection string)
target_user_name: String,
username: String,
/// Used to notify clients about an impending shutdown
shutdown_event_receiver: Receiver<()>,
shutdown: Receiver<()>,
}
/// Client entrypoint.
pub async fn client_entrypoint(
mut stream: TcpStream,
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
shutdown: Receiver<()>,
drain: Sender<i8>,
admin_only: bool,
) -> Result<(), Error> {
// Figure out if the client wants TLS or not.
let addr = stream.peer_addr().unwrap();
@@ -109,11 +111,21 @@ pub async fn client_entrypoint(
write_all(&mut stream, yes).await?;
// Negotiate TLS.
match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
match startup_tls(stream, client_server_map, shutdown, admin_only).await {
Ok(mut client) => {
info!("Client {:?} connected (TLS)", addr);
client.handle().await
if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
}
Err(err) => Err(err),
}
@@ -139,14 +151,25 @@ pub async fn client_entrypoint(
addr,
bytes,
client_server_map,
shutdown_event_receiver,
shutdown,
admin_only,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);
client.handle().await
if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
}
Err(err) => Err(err),
}
@@ -169,14 +192,25 @@ pub async fn client_entrypoint(
addr,
bytes,
client_server_map,
shutdown_event_receiver,
shutdown,
admin_only,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);
client.handle().await
// Register this client with the drain counter only when it is NOT an admin.
// The matching `-1` sent after `handle()` returns is guarded by
// `!client.is_admin()`, so the `+1` must use the same condition or the
// counter goes out of balance.
if !client.is_admin() {
    let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
}
Err(err) => Err(err),
}
@@ -187,20 +221,21 @@ pub async fn client_entrypoint(
let (read, write) = split(stream);
// Continue with cancel query request.
match Client::cancel(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
match Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await {
Ok(mut client) => {
info!("Client {:?} issued a cancel query request", addr);
client.handle().await
// Register this client with the drain counter only when it is NOT an admin.
// `Client::cancel` builds the client with `admin: false`, so the un-negated
// check never sent the `+1`, while the `-1` after `handle()` (guarded by
// `!client.is_admin()`) was always sent.
if !client.is_admin() {
    let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
}
Err(err) => Err(err),
@@ -253,7 +288,8 @@ where
pub async fn startup_tls(
stream: TcpStream,
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
shutdown: Receiver<()>,
admin_only: bool,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Negotiate TLS.
let tls = Tls::new()?;
@@ -283,7 +319,8 @@ pub async fn startup_tls(
addr,
bytes,
client_server_map,
shutdown_event_receiver,
shutdown,
admin_only,
)
.await
}
@@ -298,6 +335,10 @@ where
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
pub fn is_admin(&self) -> bool {
self.admin
}
/// Handle Postgres client startup after TLS negotiation is complete
/// or over plain text.
pub async fn startup(
@@ -306,29 +347,44 @@ where
addr: std::net::SocketAddr,
bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
shutdown: Receiver<()>,
admin_only: bool,
) -> Result<Client<S, T>, Error> {
let config = get_config();
let stats = get_reporter();
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let target_pool_name = match parameters.get("database") {
// These two parameters are mandatory by the protocol.
let pool_name = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let target_user_name = match parameters.get("user") {
let username = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &target_pool_name)
.filter(|db| *db == &pool_name)
.count()
== 1;
// Kick any client that's not admin while we're in admin-only mode.
if !admin && admin_only {
debug!(
"Rejecting non-admin connection to {} when in admin only mode",
pool_name
);
// The message is a constant, so pass the literal directly instead of
// allocating a temporary `String` through `format!`.
error_response_terminal(
    &mut write,
    "terminating connection due to administrator command",
)
.await?;
return Err(Error::ShuttingDown);
}
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
@@ -360,46 +416,55 @@ where
Err(_) => return Err(Error::SocketError),
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();
// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, target_user_name).await?;
wrong_password(&mut write, username).await?;
return Err(Error::ClientError);
}
(false, generate_server_info_for_admin())
} else {
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
}
// Authenticate normal user.
else {
let pool = match get_pool(pool_name.clone(), username.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
target_pool_name, target_user_name
pool_name, username
),
)
.await?;
return Err(Error::ClientError);
}
};
let transaction_mode = target_pool.settings.pool_mode == "transaction";
let server_info = target_pool.server_info();
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);
let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, &target_user_name).await?;
wrong_password(&mut write, username).await?;
return Err(Error::ClientError);
}
(transaction_mode, server_info)
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
(transaction_mode, pool.server_info())
};
debug!("Password authentication successful");
@@ -411,27 +476,24 @@ where
trace!("Startup OK");
// Split the read and write streams
// so we can control buffering.
return Ok(Client {
read: BufReader::new(read),
write: write,
addr,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
transaction_mode,
process_id,
secret_key,
client_server_map,
parameters: parameters.clone(),
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
target_pool_name: target_pool_name.clone(),
target_user_name: target_user_name.clone(),
shutdown_event_receiver: shutdown_event_receiver,
pool_name: pool_name.clone(),
username: username.clone(),
shutdown,
connected_to_server: false,
});
}
@@ -443,7 +505,7 @@ where
addr: std::net::SocketAddr,
mut bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
shutdown: Receiver<()>,
) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
@@ -454,17 +516,17 @@ where
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: false,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
process_id,
secret_key,
client_server_map,
parameters: HashMap::new(),
stats: get_reporter(),
admin: false,
last_address_id: None,
last_server_id: None,
target_pool_name: String::from("undefined"),
target_user_name: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver,
pool_name: String::from("undefined"),
username: String::from("undefined"),
shutdown,
connected_to_server: false,
});
}
@@ -486,7 +548,7 @@ where
process_id.clone(),
secret_key.clone(),
address.clone(),
port.clone(),
*port,
),
// The client doesn't know / got the wrong server,
@@ -498,7 +560,7 @@ where
// Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions
// take place.
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
return Ok(Server::cancel(&address, port, process_id, secret_key).await?);
}
// The query router determines where the query is going to go,
@@ -521,9 +583,19 @@ where
// SET SHARDING KEY TO 'bigint';
let mut message = tokio::select! {
_ = self.shutdown_event_receiver.recv() => {
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?;
return Ok(())
_ = self.shutdown.recv() => {
if !self.admin {
// The message is a constant, so pass the literal directly instead of
// allocating a temporary `String` through `format!`.
error_response_terminal(
    &mut self.write,
    "terminating connection due to administrator command",
)
.await?;
return Ok(())
}
// Admin clients ignore shutdown.
else {
read_message(&mut self.read).await?
}
},
message_result = read_message(&mut self.read) => message_result?
};
@@ -544,15 +616,14 @@ where
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let pool = match get_pool(self.target_pool_name.clone(), self.target_user_name.clone())
{
let pool = match get_pool(self.pool_name.clone(), self.username.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut self.write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.target_pool_name, self.target_user_name
self.pool_name, self.username
),
)
.await?;
@@ -649,14 +720,16 @@ where
match message[0] as char {
'P' | 'B' | 'E' | 'D' => (),
_ => {
error!("Could not get connection from pool: {:?}", err);
error_response(
&mut self.write,
"could not get connection from the pool",
)
.await?;
}
}
};
error!("Could not get connection from pool: {:?}", err);
continue;
}
};
@@ -739,15 +812,8 @@ where
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(
code,
original,
server,
&address,
query_router.shard(),
&pool,
)
.await?;
self.send_and_receive_loop(code, original, server, &address, &pool)
.await?;
if !server.in_transaction() {
// Report transaction executed statistics.
@@ -814,7 +880,6 @@ where
self.buffer.clone(),
server,
&address,
query_router.shard(),
&pool,
)
.await?;
@@ -836,32 +901,18 @@ where
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
self.send_server_message(server, original, &address, &pool)
.await?;
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
let response = self
.receive_server_message(server, &address, query_router.shard(), &pool)
self.send_server_message(server, original, &address, &pool)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
@@ -910,20 +961,17 @@ where
message: BytesMut,
server: &mut Server,
address: &Address,
shard: usize,
pool: &ConnectionPool,
) -> Result<(), Error> {
debug!("Sending {} to server", code);
self.send_server_message(server, message, &address, shard, &pool)
self.send_server_message(server, message, &address, &pool)
.await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self
.receive_server_message(server, &address, shard, &pool)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
@@ -949,13 +997,12 @@ where
server: &mut Server,
message: BytesMut,
address: &Address,
shard: usize,
pool: &ConnectionPool,
) -> Result<(), Error> {
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, shard, self.process_id);
pool.ban(address, self.process_id);
Err(err)
}
}
@@ -965,7 +1012,6 @@ where
&mut self,
server: &mut Server,
address: &Address,
shard: usize,
pool: &ConnectionPool,
) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 {
@@ -978,7 +1024,7 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, shard, self.process_id);
pool.ban(address, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
@@ -993,7 +1039,7 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, shard, self.process_id);
pool.ban(address, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
@@ -1002,7 +1048,7 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, shard, self.process_id);
pool.ban(address, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),