Handle and track startup parameters (#478)

* User server parameters struct instead of server info bytesmut

* Refactor to use hashmap for all params and add server parameters to client

* Sync parameters on client server checkout

* minor refactor

* update client side parameters when changed

* Move the SET statement logic from the C packet to the S packet.

* trigger build

* revert validation changes

* remove comment

* Try fix

* Reset cleanup state after sync

* fix server version test

* Track application name through client life for stats

* Add tests

* minor refactoring

* fmt

* fix

* fmt
This commit is contained in:
Zain Kabani
2023-08-10 11:18:46 -04:00
committed by GitHub
parent 9ab128579d
commit f94ce97ebc
8 changed files with 308 additions and 123 deletions

View File

@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
@@ -22,7 +22,7 @@ use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls;
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
/// Postgres user for this client (This comes from the user in the connection string)
username: String,
/// Application name for this client (defaults to pgcat)
application_name: String,
/// Server startup and session parameters that we're going to track
server_parameters: ServerParameters,
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
@@ -502,7 +502,7 @@ where
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();
// Compare server and client hashes.
@@ -521,7 +521,7 @@ where
return Err(error);
}
(false, generate_server_info_for_admin())
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
@@ -654,13 +654,16 @@ where
}
}
(transaction_mode, pool.server_info())
(transaction_mode, pool.server_parameters())
};
// Update the parameters to merge what the application sent and what's originally on the server
server_parameters.set_from_hashmap(&parameters, false);
debug!("Password authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?;
write_all(&mut write, (&server_parameters).into()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;
@@ -690,7 +693,7 @@ where
last_server_stats: None,
pool_name: pool_name.clone(),
username: username.clone(),
application_name: application_name.to_string(),
server_parameters,
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
@@ -725,7 +728,7 @@ where
last_server_stats: None,
pool_name: String::from("undefined"),
username: String::from("undefined"),
application_name: String::from("undefined"),
server_parameters: ServerParameters::new(),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
@@ -774,8 +777,11 @@ where
let mut prepared_statement = None;
let mut will_prepare = false;
let client_identifier =
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
let client_identifier = ClientIdentifier::new(
&self.server_parameters.get_application_name(),
&self.username,
&self.pool_name,
);
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
@@ -1115,10 +1121,7 @@ where
server.address()
);
// TODO: investigate other parameters and set them too.
// Set application_name.
server.set_name(&self.application_name).await?;
server.sync_parameters(&self.server_parameters).await?;
let mut initial_message = Some(message);
@@ -1296,7 +1299,9 @@ where
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
@@ -1446,7 +1451,9 @@ where
if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(&self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
@@ -1495,7 +1502,9 @@ where
if !server.in_transaction() {
self.stats.transaction();
server.stats().transaction(&self.application_name);
server
.stats()
.transaction(self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
@@ -1547,7 +1556,9 @@ where
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.pool_name, self.username, self.application_name
self.pool_name,
self.username,
self.server_parameters.get_application_name()
)))
}
}
@@ -1704,7 +1715,7 @@ where
client_stats.query();
server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64,
&self.application_name,
&self.server_parameters.get_application_name(),
);
Ok(())
@@ -1733,38 +1744,18 @@ where
pool: &ConnectionPool,
client_stats: &ClientStats,
) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 {
match tokio::time::timeout(
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
server.recv(),
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
} else {
match server.recv().await {
let statement_timeout_duration = match pool.settings.user.statement_timeout {
0 => tokio::time::Duration::MAX,
timeout => tokio::time::Duration::from_millis(timeout),
};
match tokio::time::timeout(
statement_timeout_duration,
server.recv(Some(&mut self.server_parameters)),
)
.await
{
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
@@ -1775,6 +1766,16 @@ where
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
}