Handle and track startup parameters (#478)

* User server parameters struct instead of server info bytesmut

* Refactor to use hashmap for all params and add server parameters to client

* Sync parameters on client server checkout

* minor refactor

* update client side parameters when changed

* Move the SET statement logic from the C packet to the S packet.

* trigger build

* revert validation changes

* remove comment

* Try fix

* Reset cleanup state after sync

* fix server version test

* Track application name through client life for stats

* Add tests

* minor refactoring

* fmt

* fix

* fmt
This commit is contained in:
Zain Kabani
2023-08-10 11:18:46 -04:00
committed by GitHub
parent 9ab128579d
commit f94ce97ebc
8 changed files with 308 additions and 123 deletions

View File

@@ -1,4 +1,5 @@
use crate::pool::BanReason; use crate::pool::BanReason;
use crate::server::ServerParameters;
use crate::stats::pool::PoolStats; use crate::stats::pool::PoolStats;
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{error, info, trace}; use log::{error, info, trace};
@@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
use crate::pool::{get_all_pools, get_pool}; use crate::pool::{get_all_pools, get_pool};
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState}; use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
pub fn generate_server_info_for_admin() -> BytesMut { pub fn generate_server_parameters_for_admin() -> ServerParameters {
let mut server_info = BytesMut::new(); let mut server_parameters = ServerParameters::new();
server_info.put(server_parameter_message("application_name", "")); server_parameters.set_param("application_name".to_string(), "".to_string(), true);
server_info.put(server_parameter_message("client_encoding", "UTF8")); server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
server_info.put(server_parameter_message("server_encoding", "UTF8")); server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
server_info.put(server_parameter_message("server_version", VERSION)); server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
server_info.put(server_parameter_message("DateStyle", "ISO, MDY")); server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
server_info server_parameters
} }
/// Handle admin client. /// Handle admin client.

View File

@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver; use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash; use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{ use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode, get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
@@ -22,7 +22,7 @@ use crate::messages::*;
use crate::plugins::PluginOutput; use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats}; use crate::stats::{ClientStats, ServerStats};
use crate::tls::Tls; use crate::tls::Tls;
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
/// Postgres user for this client (This comes from the user in the connection string) /// Postgres user for this client (This comes from the user in the connection string)
username: String, username: String,
/// Application name for this client (defaults to pgcat) /// Server startup and session parameters that we're going to track
application_name: String, server_parameters: ServerParameters,
/// Used to notify clients about an impending shutdown /// Used to notify clients about an impending shutdown
shutdown: Receiver<()>, shutdown: Receiver<()>,
@@ -502,7 +502,7 @@ where
}; };
// Authenticate admin user. // Authenticate admin user.
let (transaction_mode, server_info) = if admin { let (transaction_mode, mut server_parameters) = if admin {
let config = get_config(); let config = get_config();
// Compare server and client hashes. // Compare server and client hashes.
@@ -521,7 +521,7 @@ where
return Err(error); return Err(error);
} }
(false, generate_server_info_for_admin()) (false, generate_server_parameters_for_admin())
} }
// Authenticate normal user. // Authenticate normal user.
else { else {
@@ -654,13 +654,16 @@ where
} }
} }
(transaction_mode, pool.server_info()) (transaction_mode, pool.server_parameters())
}; };
// Update the parameters to merge what the application sent and what's originally on the server
server_parameters.set_from_hashmap(&parameters, false);
debug!("Password authentication successful"); debug!("Password authentication successful");
auth_ok(&mut write).await?; auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?; write_all(&mut write, (&server_parameters).into()).await?;
backend_key_data(&mut write, process_id, secret_key).await?; backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?; ready_for_query(&mut write).await?;
@@ -690,7 +693,7 @@ where
last_server_stats: None, last_server_stats: None,
pool_name: pool_name.clone(), pool_name: pool_name.clone(),
username: username.clone(), username: username.clone(),
application_name: application_name.to_string(), server_parameters,
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
prepared_statements: HashMap::new(), prepared_statements: HashMap::new(),
@@ -725,7 +728,7 @@ where
last_server_stats: None, last_server_stats: None,
pool_name: String::from("undefined"), pool_name: String::from("undefined"),
username: String::from("undefined"), username: String::from("undefined"),
application_name: String::from("undefined"), server_parameters: ServerParameters::new(),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
prepared_statements: HashMap::new(), prepared_statements: HashMap::new(),
@@ -774,8 +777,11 @@ where
let mut prepared_statement = None; let mut prepared_statement = None;
let mut will_prepare = false; let mut will_prepare = false;
let client_identifier = let client_identifier = ClientIdentifier::new(
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name); &self.server_parameters.get_application_name(),
&self.username,
&self.pool_name,
);
// Our custom protocol loop. // Our custom protocol loop.
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
@@ -1115,10 +1121,7 @@ where
server.address() server.address()
); );
// TODO: investigate other parameters and set them too. server.sync_parameters(&self.server_parameters).await?;
// Set application_name.
server.set_name(&self.application_name).await?;
let mut initial_message = Some(message); let mut initial_message = Some(message);
@@ -1296,7 +1299,9 @@ where
if !server.in_transaction() { if !server.in_transaction() {
// Report transaction executed statistics. // Report transaction executed statistics.
self.stats.transaction(); self.stats.transaction();
server.stats().transaction(&self.application_name); server
.stats()
.transaction(&self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
@@ -1446,7 +1451,9 @@ where
if !server.in_transaction() { if !server.in_transaction() {
self.stats.transaction(); self.stats.transaction();
server.stats().transaction(&self.application_name); server
.stats()
.transaction(&self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
@@ -1495,7 +1502,9 @@ where
if !server.in_transaction() { if !server.in_transaction() {
self.stats.transaction(); self.stats.transaction();
server.stats().transaction(&self.application_name); server
.stats()
.transaction(self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
@@ -1547,7 +1556,9 @@ where
Err(Error::ClientError(format!( Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}", "Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.pool_name, self.username, self.application_name self.pool_name,
self.username,
self.server_parameters.get_application_name()
))) )))
} }
} }
@@ -1704,7 +1715,7 @@ where
client_stats.query(); client_stats.query();
server.stats().query( server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64, Instant::now().duration_since(query_start).as_millis() as u64,
&self.application_name, &self.server_parameters.get_application_name(),
); );
Ok(()) Ok(())
@@ -1733,38 +1744,18 @@ where
pool: &ConnectionPool, pool: &ConnectionPool,
client_stats: &ClientStats, client_stats: &ClientStats,
) -> Result<BytesMut, Error> { ) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 { let statement_timeout_duration = match pool.settings.user.statement_timeout {
match tokio::time::timeout( 0 => tokio::time::Duration::MAX,
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), timeout => tokio::time::Duration::from_millis(timeout),
server.recv(), };
)
.await match tokio::time::timeout(
{ statement_timeout_duration,
Ok(result) => match result { server.recv(Some(&mut self.server_parameters)),
Ok(message) => Ok(message), )
Err(err) => { .await
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); {
error_response_terminal( Ok(result) => match result {
&mut self.write,
&format!("error receiving data from server: {:?}", err),
)
.await?;
Err(err)
}
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
}
} else {
match server.recv().await {
Ok(message) => Ok(message), Ok(message) => Ok(message),
Err(err) => { Err(err) => {
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
@@ -1775,6 +1766,16 @@ where
.await?; .await?;
Err(err) Err(err)
} }
},
Err(_) => {
error!(
"Statement timeout while talking to {:?} with user {}",
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
} }
} }
} }

View File

@@ -144,6 +144,10 @@ where
bytes.put_slice(user.as_bytes()); bytes.put_slice(user.as_bytes());
bytes.put_u8(0); bytes.put_u8(0);
// Application name
bytes.put(&b"application_name\0"[..]);
bytes.put_slice(&b"pgcat\0"[..]);
// Database // Database
bytes.put(&b"database\0"[..]); bytes.put(&b"database\0"[..]);
bytes.put_slice(database.as_bytes()); bytes.put_slice(database.as_bytes());
@@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> {
} }
} }
impl BytesMutReader for BytesMut {
/// Should only be used when reading strings from the message protocol.
/// Can be used to read multiple strings from the same message which are separated by the null byte
fn read_string(&mut self) -> Result<String, Error> {
let null_index = self.iter().position(|&byte| byte == b'\0');
match null_index {
Some(index) => {
let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
}
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
}
}
}
/// Parse (F) message. /// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html> /// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

View File

@@ -78,7 +78,7 @@ impl MirroredClient {
} }
// Incoming data from server (we read to clear the socket buffer and discard the data) // Incoming data from server (we read to clear the socket buffer and discard the data)
recv_result = server.recv() => { recv_result = server.recv(None) => {
match recv_result { match recv_result {
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()), Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
Err(err) => { Err(err) => {

View File

@@ -1,7 +1,6 @@
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use async_trait::async_trait; use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy}; use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
use bytes::{BufMut, BytesMut};
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn}; use log::{debug, error, info, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
@@ -25,7 +24,7 @@ use crate::errors::Error;
use crate::auth_passthrough::AuthPassthrough; use crate::auth_passthrough::AuthPassthrough;
use crate::plugins::prewarmer; use crate::plugins::prewarmer;
use crate::server::Server; use crate::server::{Server, ServerParameters};
use crate::sharding::ShardingFunction; use crate::sharding::ShardingFunction;
use crate::stats::{AddressStats, ClientStats, ServerStats}; use crate::stats::{AddressStats, ClientStats, ServerStats};
@@ -196,10 +195,10 @@ pub struct ConnectionPool {
/// that should not be queried. /// that should not be queried.
banlist: BanList, banlist: BanList,
/// The server information (K messages) have to be passed to the /// The server information has to be passed to the
/// clients on startup. We pre-connect to all shards and replicas /// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here. /// on pool creation and save the startup parameters here.
server_info: Arc<RwLock<BytesMut>>, original_server_parameters: Arc<RwLock<ServerParameters>>,
/// Pool configuration. /// Pool configuration.
pub settings: PoolSettings, pub settings: PoolSettings,
@@ -445,7 +444,7 @@ impl ConnectionPool {
addresses, addresses,
banlist: Arc::new(RwLock::new(banlist)), banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value, config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())), original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
auth_hash: pool_auth_hash, auth_hash: pool_auth_hash,
settings: PoolSettings { settings: PoolSettings {
pool_mode: match user.pool_mode { pool_mode: match user.pool_mode {
@@ -528,7 +527,7 @@ impl ConnectionPool {
for server in 0..self.servers(shard) { for server in 0..self.servers(shard) {
let databases = self.databases.clone(); let databases = self.databases.clone();
let validated = Arc::clone(&validated); let validated = Arc::clone(&validated);
let pool_server_info = Arc::clone(&self.server_info); let pool_server_parameters = Arc::clone(&self.original_server_parameters);
let task = tokio::task::spawn(async move { let task = tokio::task::spawn(async move {
let connection = match databases[shard][server].get().await { let connection = match databases[shard][server].get().await {
@@ -541,11 +540,10 @@ impl ConnectionPool {
let proxy = connection; let proxy = connection;
let server = &*proxy; let server = &*proxy;
let server_info = server.server_info(); let server_parameters: ServerParameters = server.server_parameters();
let mut guard = pool_server_info.write(); let mut guard = pool_server_parameters.write();
guard.clear(); *guard = server_parameters;
guard.put(server_info.clone());
validated.store(true, Ordering::Relaxed); validated.store(true, Ordering::Relaxed);
}); });
@@ -557,7 +555,7 @@ impl ConnectionPool {
// TODO: compare server information to make sure // TODO: compare server information to make sure
// all shards are running identical configurations. // all shards are running identical configurations.
if self.server_info.read().is_empty() { if !self.validated() {
error!("Could not validate connection pool"); error!("Could not validate connection pool");
return Err(Error::AllServersDown); return Err(Error::AllServersDown);
} }
@@ -917,8 +915,8 @@ impl ConnectionPool {
&self.addresses[shard][server] &self.addresses[shard][server]
} }
pub fn server_info(&self) -> BytesMut { pub fn server_parameters(&self) -> ServerParameters {
self.server_info.read().clone() self.original_server_parameters.read().clone()
} }
fn busy_connection_count(&self, address: &Address) -> u32 { fn busy_connection_count(&self, address: &Address) -> u32 {

View File

@@ -3,12 +3,13 @@
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use postgres_protocol::message; use postgres_protocol::message;
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap, HashSet};
use std::io::Read; use std::mem;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::{Arc, Once};
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -19,6 +20,7 @@ use crate::config::{get_config, get_prepared_statements_cache_size, Address, Use
use crate::constants::*; use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier}; use crate::errors::{Error, ServerIdentifier};
use crate::messages::BytesMutReader;
use crate::messages::*; use crate::messages::*;
use crate::mirrors::MirroringManager; use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap; use crate::pool::ClientServerMap;
@@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState {
} }
} }
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
let mut set = HashSet::new();
set.insert("client_encoding".to_string());
set.insert("DateStyle".to_string());
set.insert("TimeZone".to_string());
set.insert("standard_conforming_strings".to_string());
set.insert("application_name".to_string());
set
});
#[derive(Debug, Clone)]
pub struct ServerParameters {
parameters: HashMap<String, String>,
}
impl Default for ServerParameters {
fn default() -> Self {
Self::new()
}
}
impl ServerParameters {
pub fn new() -> Self {
let mut server_parameters = ServerParameters {
parameters: HashMap::new(),
};
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);
server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false);
server_parameters.set_param(
"standard_conforming_strings".to_string(),
"on".to_string(),
false,
);
server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false);
server_parameters
}
/// returns true if a tracked parameter was set, false if it was a non-tracked parameter
/// if startup is false, then then only tracked parameters will be set
pub fn set_param(&mut self, mut key: String, value: String, startup: bool) {
// The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys
if key == "timezone" {
key = "TimeZone".to_string();
} else if key == "datestyle" {
key = "DateStyle".to_string();
};
if TRACKED_PARAMETERS.contains(&key) {
self.parameters.insert(key, value);
} else {
if startup {
self.parameters.insert(key, value);
}
}
}
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
// iterate through each and call set_param
for (key, value) in parameters {
self.set_param(key.to_string(), value.to_string(), startup);
}
}
// Gets the diff of the parameters
fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap<String, String> {
let mut diff = HashMap::new();
// iterate through tracked parameters
for key in TRACKED_PARAMETERS.iter() {
if let Some(incoming_value) = incoming_parameters.parameters.get(key) {
if let Some(value) = self.parameters.get(key) {
if value != incoming_value {
diff.insert(key.to_string(), incoming_value.to_string());
}
}
}
}
diff
}
pub fn get_application_name(&self) -> &String {
// Can unwrap because we set it in the constructor
self.parameters.get("application_name").unwrap()
}
fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) {
buffer.put_u8(b'S');
// 4 is len of i32, the plus for the null terminator
let len = 4 + key.len() + 1 + value.len() + 1;
buffer.put_i32(len as i32);
buffer.put_slice(key.as_bytes());
buffer.put_u8(0);
buffer.put_slice(value.as_bytes());
buffer.put_u8(0);
}
}
impl From<&ServerParameters> for BytesMut {
fn from(server_parameters: &ServerParameters) -> Self {
let mut bytes = BytesMut::new();
for (key, value) in &server_parameters.parameters {
ServerParameters::add_parameter_message(key, value, &mut bytes);
}
bytes
}
}
// pub fn compare
/// Server state. /// Server state.
pub struct Server { pub struct Server {
/// Server host, e.g. localhost, /// Server host, e.g. localhost,
@@ -158,7 +278,7 @@ pub struct Server {
buffer: BytesMut, buffer: BytesMut,
/// Server information the server sent us over on startup. /// Server information the server sent us over on startup.
server_info: BytesMut, server_parameters: ServerParameters,
/// Backend id and secret key used for query cancellation. /// Backend id and secret key used for query cancellation.
process_id: i32, process_id: i32,
@@ -347,7 +467,6 @@ impl Server {
startup(&mut stream, username, database).await?; startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database); let server_identifier = ServerIdentifier::new(username, &database);
@@ -359,6 +478,8 @@ impl Server {
None => None, None => None,
}; };
let mut server_parameters = ServerParameters::new();
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
@@ -616,9 +737,10 @@ impl Server {
// ParameterStatus // ParameterStatus
'S' => { 'S' => {
let mut param = vec![0u8; len as usize - 4]; let mut bytes = BytesMut::with_capacity(len as usize - 4);
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');
match stream.read_exact(&mut param).await { match stream.read_exact(&mut bytes[..]).await {
Ok(_) => (), Ok(_) => (),
Err(_) => { Err(_) => {
return Err(Error::ServerStartupError( return Err(Error::ServerStartupError(
@@ -628,12 +750,13 @@ impl Server {
} }
}; };
let key = bytes.read_string().unwrap();
let value = bytes.read_string().unwrap();
// Save the parameter so we can pass it to the client later. // Save the parameter so we can pass it to the client later.
// These can be server_encoding, client_encoding, server timezone, Postgres version, // These can be server_encoding, client_encoding, server timezone, Postgres version,
// and many more interesting things we should know about the Postgres server we are talking to. // and many more interesting things we should know about the Postgres server we are talking to.
server_info.put_u8(b'S'); server_parameters.set_param(key, value, true);
server_info.put_i32(len);
server_info.put_slice(&param[..]);
} }
// BackendKeyData // BackendKeyData
@@ -675,11 +798,11 @@ impl Server {
} }
}; };
let mut server = Server { let server = Server {
address: address.clone(), address: address.clone(),
stream: BufStream::new(stream), stream: BufStream::new(stream),
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info, server_parameters,
process_id, process_id,
secret_key, secret_key,
in_transaction: false, in_transaction: false,
@@ -691,7 +814,7 @@ impl Server {
addr_set, addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(), connected_at: chrono::offset::Utc::now().naive_utc(),
stats, stats,
application_name: String::new(), application_name: "pgcat".to_string(),
last_activity: SystemTime::now(), last_activity: SystemTime::now(),
mirror_manager: match address.mirrors.len() { mirror_manager: match address.mirrors.len() {
0 => None, 0 => None,
@@ -705,8 +828,6 @@ impl Server {
prepared_statements: BTreeSet::new(), prepared_statements: BTreeSet::new(),
}; };
server.set_name("pgcat").await?;
return Ok(server); return Ok(server);
} }
@@ -776,7 +897,10 @@ impl Server {
/// Receive data from the server in response to a client request. /// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true /// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> { pub async fn recv(
&mut self,
mut client_server_parameters: Option<&mut ServerParameters>,
) -> Result<BytesMut, Error> {
loop { loop {
let mut message = match read_message(&mut self.stream).await { let mut message = match read_message(&mut self.stream).await {
Ok(message) => message, Ok(message) => message,
@@ -848,14 +972,13 @@ impl Server {
self.in_copy_mode = false; self.in_copy_mode = false;
} }
let mut command_tag = String::new(); match message.read_string() {
match message.reader().read_to_string(&mut command_tag) { Ok(command) => {
Ok(_) => {
// Non-exhaustive list of commands that are likely to change session variables/resources // Non-exhaustive list of commands that are likely to change session variables/resources
// which can leak between clients. This is a best effort to block bad clients // which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables // from poisoning a transaction-mode pool by setting inappropriate session variables
match command_tag.as_str() { match command.as_str() {
"SET\0" => { "SET" => {
// We don't detect set statements in transactions // We don't detect set statements in transactions
// No great way to differentiate between set and set local // No great way to differentiate between set and set local
// As a result, we will miss cases when set statements are used in transactions // As a result, we will miss cases when set statements are used in transactions
@@ -865,7 +988,8 @@ impl Server {
self.cleanup_state.needs_cleanup_set = true; self.cleanup_state.needs_cleanup_set = true;
} }
} }
"PREPARE\0" => {
"PREPARE" => {
debug!("Server connection marked for clean up"); debug!("Server connection marked for clean up");
self.cleanup_state.needs_cleanup_prepare = true; self.cleanup_state.needs_cleanup_prepare = true;
} }
@@ -879,6 +1003,17 @@ impl Server {
} }
} }
'S' => {
let key = message.read_string().unwrap();
let value = message.read_string().unwrap();
if let Some(client_server_parameters) = client_server_parameters.as_mut() {
client_server_parameters.set_param(key.clone(), value.clone(), false);
}
self.server_parameters.set_param(key, value, false);
}
// DataRow // DataRow
'D' => { 'D' => {
// More data is available after this message, this is not the end of the reply. // More data is available after this message, this is not the end of the reply.
@@ -1089,9 +1224,28 @@ impl Server {
} }
/// Get server startup information to forward it to the client. /// Get server startup information to forward it to the client.
/// Not used at the moment. pub fn server_parameters(&self) -> ServerParameters {
pub fn server_info(&self) -> BytesMut { self.server_parameters.clone()
self.server_info.clone() }
pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> {
let parameter_diff = self.server_parameters.compare_params(parameters);
if parameter_diff.is_empty() {
return Ok(());
}
let mut query = String::from("");
for (key, value) in parameter_diff {
query.push_str(&format!("SET {} TO '{}';", key, value));
}
let res = self.query(&query).await;
self.cleanup_state.reset();
res
} }
/// Indicate that this server connection cannot be re-used and must be discarded. /// Indicate that this server connection cannot be re-used and must be discarded.
@@ -1125,7 +1279,7 @@ impl Server {
self.send(&query).await?; self.send(&query).await?;
loop { loop {
let _ = self.recv().await?; let _ = self.recv(None).await?;
if !self.data_available { if !self.data_available {
break; break;
@@ -1166,24 +1320,6 @@ impl Server {
Ok(()) Ok(())
} }
/// A shorthand for `SET application_name = $1`.
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
if self.application_name != name {
self.application_name = name.to_string();
// We don't want `SET application_name` to mark the server connection
// as needing cleanup
let needs_cleanup_before = self.cleanup_state;
let result = Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?);
self.cleanup_state = needs_cleanup_before;
result
} else {
Ok(())
}
}
/// get Server stats /// get Server stats
pub fn stats(&self) -> Arc<ServerStats> { pub fn stats(&self) -> Arc<ServerStats> {
self.stats.clone() self.stats.clone()
@@ -1241,7 +1377,7 @@ impl Server {
.await?; .await?;
debug!("Connected!, sending query."); debug!("Connected!, sending query.");
server.send(&simple_query(query)).await?; server.send(&simple_query(query)).await?;
let mut message = server.recv().await?; let mut message = server.recv(None).await?;
Ok(parse_query_message(&mut message).await?) Ok(parse_query_message(&mut message).await?)
} }

View File

@@ -112,10 +112,16 @@ class PgcatProcess
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
end end
def connection_string(pool_name, username, password = nil) def connection_string(pool_name, username, password = nil, parameters: {})
cfg = current_config cfg = current_config
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
# Add the additional parameters to the connection string
parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
connection_string += "?#{parameter_string}" unless parameter_string.empty?
connection_string
end end
def example_connection_string def example_connection_string

View File

@@ -294,6 +294,30 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(10) expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
end end
it "Respects tracked parameters on startup" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" }))
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
conn.close
end
it "Respect tracked parameter on set statemet" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET application_name to 'my_pgcat_test'")
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
end
it "Ignore untracked parameter on set statemet" do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]
conn.async_exec("SET statement_timeout to 1500")
expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
end
end end
context "transaction mode with transactions" do context "transaction mode with transactions" do