Handle and track startup parameters (#478)

* User server parameters struct instead of server info bytesmut

* Refactor to use hashmap for all params and add server parameters to client

* Sync parameters on client server checkout

* minor refactor

* update client side parameters when changed

* Move the SET statement logic from the C packet to the S packet.

* trigger build

* revert validation changes

* remove comment

* Try fix

* Reset cleanup state after sync

* fix server version test

* Track application name through client life for stats

* Add tests

* minor refactoring

* fmt

* fix

* fmt
This commit is contained in:
Zain Kabani
2023-08-10 11:18:46 -04:00
committed by GitHub
parent 9ab128579d
commit f94ce97ebc
8 changed files with 308 additions and 123 deletions

View File

@@ -3,12 +3,13 @@
use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use postgres_protocol::message;
use std::collections::{BTreeSet, HashMap};
use std::io::Read;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::mem;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::{Arc, Once};
use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream;
@@ -19,6 +20,7 @@ use crate::config::{get_config, get_prepared_statements_cache_size, Address, Use
use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier};
use crate::messages::BytesMutReader;
use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
@@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState {
}
}
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
let mut set = HashSet::new();
set.insert("client_encoding".to_string());
set.insert("DateStyle".to_string());
set.insert("TimeZone".to_string());
set.insert("standard_conforming_strings".to_string());
set.insert("application_name".to_string());
set
});
#[derive(Debug, Clone)]
pub struct ServerParameters {
parameters: HashMap<String, String>,
}
impl Default for ServerParameters {
fn default() -> Self {
Self::new()
}
}
impl ServerParameters {
pub fn new() -> Self {
let mut server_parameters = ServerParameters {
parameters: HashMap::new(),
};
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);
server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false);
server_parameters.set_param(
"standard_conforming_strings".to_string(),
"on".to_string(),
false,
);
server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false);
server_parameters
}
/// returns true if a tracked parameter was set, false if it was a non-tracked parameter
/// if startup is false, then then only tracked parameters will be set
pub fn set_param(&mut self, mut key: String, value: String, startup: bool) {
// The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys
if key == "timezone" {
key = "TimeZone".to_string();
} else if key == "datestyle" {
key = "DateStyle".to_string();
};
if TRACKED_PARAMETERS.contains(&key) {
self.parameters.insert(key, value);
} else {
if startup {
self.parameters.insert(key, value);
}
}
}
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
// iterate through each and call set_param
for (key, value) in parameters {
self.set_param(key.to_string(), value.to_string(), startup);
}
}
// Gets the diff of the parameters
fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap<String, String> {
let mut diff = HashMap::new();
// iterate through tracked parameters
for key in TRACKED_PARAMETERS.iter() {
if let Some(incoming_value) = incoming_parameters.parameters.get(key) {
if let Some(value) = self.parameters.get(key) {
if value != incoming_value {
diff.insert(key.to_string(), incoming_value.to_string());
}
}
}
}
diff
}
pub fn get_application_name(&self) -> &String {
// Can unwrap because we set it in the constructor
self.parameters.get("application_name").unwrap()
}
fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) {
buffer.put_u8(b'S');
// 4 is len of i32, the plus for the null terminator
let len = 4 + key.len() + 1 + value.len() + 1;
buffer.put_i32(len as i32);
buffer.put_slice(key.as_bytes());
buffer.put_u8(0);
buffer.put_slice(value.as_bytes());
buffer.put_u8(0);
}
}
impl From<&ServerParameters> for BytesMut {
fn from(server_parameters: &ServerParameters) -> Self {
let mut bytes = BytesMut::new();
for (key, value) in &server_parameters.parameters {
ServerParameters::add_parameter_message(key, value, &mut bytes);
}
bytes
}
}
// pub fn compare
/// Server state.
pub struct Server {
/// Server host, e.g. localhost,
@@ -158,7 +278,7 @@ pub struct Server {
buffer: BytesMut,
/// Server information the server sent us over on startup.
server_info: BytesMut,
server_parameters: ServerParameters,
/// Backend id and secret key used for query cancellation.
process_id: i32,
@@ -347,7 +467,6 @@ impl Server {
startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database);
@@ -359,6 +478,8 @@ impl Server {
None => None,
};
let mut server_parameters = ServerParameters::new();
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
@@ -616,9 +737,10 @@ impl Server {
// ParameterStatus
'S' => {
let mut param = vec![0u8; len as usize - 4];
let mut bytes = BytesMut::with_capacity(len as usize - 4);
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');
match stream.read_exact(&mut param).await {
match stream.read_exact(&mut bytes[..]).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ServerStartupError(
@@ -628,12 +750,13 @@ impl Server {
}
};
let key = bytes.read_string().unwrap();
let value = bytes.read_string().unwrap();
// Save the parameter so we can pass it to the client later.
// These can be server_encoding, client_encoding, server timezone, Postgres version,
// and many more interesting things we should know about the Postgres server we are talking to.
server_info.put_u8(b'S');
server_info.put_i32(len);
server_info.put_slice(&param[..]);
server_parameters.set_param(key, value, true);
}
// BackendKeyData
@@ -675,11 +798,11 @@ impl Server {
}
};
let mut server = Server {
let server = Server {
address: address.clone(),
stream: BufStream::new(stream),
buffer: BytesMut::with_capacity(8196),
server_info,
server_parameters,
process_id,
secret_key,
in_transaction: false,
@@ -691,7 +814,7 @@ impl Server {
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(),
stats,
application_name: String::new(),
application_name: "pgcat".to_string(),
last_activity: SystemTime::now(),
mirror_manager: match address.mirrors.len() {
0 => None,
@@ -705,8 +828,6 @@ impl Server {
prepared_statements: BTreeSet::new(),
};
server.set_name("pgcat").await?;
return Ok(server);
}
@@ -776,7 +897,10 @@ impl Server {
/// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
pub async fn recv(
&mut self,
mut client_server_parameters: Option<&mut ServerParameters>,
) -> Result<BytesMut, Error> {
loop {
let mut message = match read_message(&mut self.stream).await {
Ok(message) => message,
@@ -848,14 +972,13 @@ impl Server {
self.in_copy_mode = false;
}
let mut command_tag = String::new();
match message.reader().read_to_string(&mut command_tag) {
Ok(_) => {
match message.read_string() {
Ok(command) => {
// Non-exhaustive list of commands that are likely to change session variables/resources
// which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables
match command_tag.as_str() {
"SET\0" => {
match command.as_str() {
"SET" => {
// We don't detect set statements in transactions
// No great way to differentiate between set and set local
// As a result, we will miss cases when set statements are used in transactions
@@ -865,7 +988,8 @@ impl Server {
self.cleanup_state.needs_cleanup_set = true;
}
}
"PREPARE\0" => {
"PREPARE" => {
debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_prepare = true;
}
@@ -879,6 +1003,17 @@ impl Server {
}
}
'S' => {
let key = message.read_string().unwrap();
let value = message.read_string().unwrap();
if let Some(client_server_parameters) = client_server_parameters.as_mut() {
client_server_parameters.set_param(key.clone(), value.clone(), false);
}
self.server_parameters.set_param(key, value, false);
}
// DataRow
'D' => {
// More data is available after this message, this is not the end of the reply.
@@ -1089,9 +1224,28 @@ impl Server {
}
/// Get server startup information to forward it to the client.
/// Not used at the moment.
pub fn server_info(&self) -> BytesMut {
self.server_info.clone()
pub fn server_parameters(&self) -> ServerParameters {
self.server_parameters.clone()
}
pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> {
let parameter_diff = self.server_parameters.compare_params(parameters);
if parameter_diff.is_empty() {
return Ok(());
}
let mut query = String::from("");
for (key, value) in parameter_diff {
query.push_str(&format!("SET {} TO '{}';", key, value));
}
let res = self.query(&query).await;
self.cleanup_state.reset();
res
}
/// Indicate that this server connection cannot be re-used and must be discarded.
@@ -1125,7 +1279,7 @@ impl Server {
self.send(&query).await?;
loop {
let _ = self.recv().await?;
let _ = self.recv(None).await?;
if !self.data_available {
break;
@@ -1166,24 +1320,6 @@ impl Server {
Ok(())
}
/// A shorthand for `SET application_name = $1`.
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
if self.application_name != name {
self.application_name = name.to_string();
// We don't want `SET application_name` to mark the server connection
// as needing cleanup
let needs_cleanup_before = self.cleanup_state;
let result = Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?);
self.cleanup_state = needs_cleanup_before;
result
} else {
Ok(())
}
}
/// get Server stats
pub fn stats(&self) -> Arc<ServerStats> {
self.stats.clone()
@@ -1241,7 +1377,7 @@ impl Server {
.await?;
debug!("Connected!, sending query.");
server.send(&simple_query(query)).await?;
let mut message = server.recv().await?;
let mut message = server.recv(None).await?;
Ok(parse_query_message(&mut message).await?)
}