mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Handle and track startup parameters (#478)
* User server parameters struct instead of server info bytesmut * Refactor to use hashmap for all params and add server parameters to client * Sync parameters on client server checkout * minor refactor * update client side parameters when changed * Move the SET statement logic from the C packet to the S packet. * trigger build * revert validation changes * remove comment * Try fix * Reset cleanup state after sync * fix server version test * Track application name through client life for stats * Add tests * minor refactoring * fmt * fix * fmt
This commit is contained in:
17
src/admin.rs
17
src/admin.rs
@@ -1,4 +1,5 @@
|
||||
use crate::pool::BanReason;
|
||||
use crate::server::ServerParameters;
|
||||
use crate::stats::pool::PoolStats;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{error, info, trace};
|
||||
@@ -17,16 +18,16 @@ use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
|
||||
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
pub fn generate_server_parameters_for_admin() -> ServerParameters {
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
server_info.put(server_parameter_message("application_name", ""));
|
||||
server_info.put(server_parameter_message("client_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_encoding", "UTF8"));
|
||||
server_info.put(server_parameter_message("server_version", VERSION));
|
||||
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
|
||||
server_parameters.set_param("application_name".to_string(), "".to_string(), true);
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true);
|
||||
server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true);
|
||||
|
||||
server_info
|
||||
server_parameters
|
||||
}
|
||||
|
||||
/// Handle admin client.
|
||||
|
||||
107
src/client.rs
107
src/client.rs
@@ -12,7 +12,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
|
||||
@@ -22,7 +22,7 @@ use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::stats::{ClientStats, ServerStats};
|
||||
use crate::tls::Tls;
|
||||
|
||||
@@ -96,8 +96,8 @@ pub struct Client<S, T> {
|
||||
/// Postgres user for this client (This comes from the user in the connection string)
|
||||
username: String,
|
||||
|
||||
/// Application name for this client (defaults to pgcat)
|
||||
application_name: String,
|
||||
/// Server startup and session parameters that we're going to track
|
||||
server_parameters: ServerParameters,
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown: Receiver<()>,
|
||||
@@ -502,7 +502,7 @@ where
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let (transaction_mode, mut server_parameters) = if admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
@@ -521,7 +521,7 @@ where
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
(false, generate_server_parameters_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
else {
|
||||
@@ -654,13 +654,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
(transaction_mode, pool.server_info())
|
||||
(transaction_mode, pool.server_parameters())
|
||||
};
|
||||
|
||||
// Update the parameters to merge what the application sent and what's originally on the server
|
||||
server_parameters.set_from_hashmap(¶meters, false);
|
||||
|
||||
debug!("Password authentication successful");
|
||||
|
||||
auth_ok(&mut write).await?;
|
||||
write_all(&mut write, server_info).await?;
|
||||
write_all(&mut write, (&server_parameters).into()).await?;
|
||||
backend_key_data(&mut write, process_id, secret_key).await?;
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
@@ -690,7 +693,7 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
application_name: application_name.to_string(),
|
||||
server_parameters,
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
@@ -725,7 +728,7 @@ where
|
||||
last_server_stats: None,
|
||||
pool_name: String::from("undefined"),
|
||||
username: String::from("undefined"),
|
||||
application_name: String::from("undefined"),
|
||||
server_parameters: ServerParameters::new(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
@@ -774,8 +777,11 @@ where
|
||||
let mut prepared_statement = None;
|
||||
let mut will_prepare = false;
|
||||
|
||||
let client_identifier =
|
||||
ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name);
|
||||
let client_identifier = ClientIdentifier::new(
|
||||
&self.server_parameters.get_application_name(),
|
||||
&self.username,
|
||||
&self.pool_name,
|
||||
);
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
@@ -1115,10 +1121,7 @@ where
|
||||
server.address()
|
||||
);
|
||||
|
||||
// TODO: investigate other parameters and set them too.
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
server.sync_parameters(&self.server_parameters).await?;
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1296,7 +1299,9 @@ where
|
||||
if !server.in_transaction() {
|
||||
// Report transaction executed statistics.
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1446,7 +1451,9 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(&self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1495,7 +1502,9 @@ where
|
||||
|
||||
if !server.in_transaction() {
|
||||
self.stats.transaction();
|
||||
server.stats().transaction(&self.application_name);
|
||||
server
|
||||
.stats()
|
||||
.transaction(self.server_parameters.get_application_name());
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
@@ -1547,7 +1556,9 @@ where
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
|
||||
self.pool_name, self.username, self.application_name
|
||||
self.pool_name,
|
||||
self.username,
|
||||
self.server_parameters.get_application_name()
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -1704,7 +1715,7 @@ where
|
||||
client_stats.query();
|
||||
server.stats().query(
|
||||
Instant::now().duration_since(query_start).as_millis() as u64,
|
||||
&self.application_name,
|
||||
&self.server_parameters.get_application_name(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
@@ -1733,38 +1744,18 @@ where
|
||||
pool: &ConnectionPool,
|
||||
client_stats: &ClientStats,
|
||||
) -> Result<BytesMut, Error> {
|
||||
if pool.settings.user.statement_timeout > 0 {
|
||||
match tokio::time::timeout(
|
||||
tokio::time::Duration::from_millis(pool.settings.user.statement_timeout),
|
||||
server.recv(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("error receiving data from server: {:?}", err),
|
||||
)
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match server.recv().await {
|
||||
let statement_timeout_duration = match pool.settings.user.statement_timeout {
|
||||
0 => tokio::time::Duration::MAX,
|
||||
timeout => tokio::time::Duration::from_millis(timeout),
|
||||
};
|
||||
|
||||
match tokio::time::timeout(
|
||||
statement_timeout_duration,
|
||||
server.recv(Some(&mut self.server_parameters)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => match result {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => {
|
||||
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
|
||||
@@ -1775,6 +1766,16 @@ where
|
||||
.await?;
|
||||
Err(err)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
error!(
|
||||
"Statement timeout while talking to {:?} with user {}",
|
||||
address, pool.settings.user.username
|
||||
);
|
||||
server.mark_bad();
|
||||
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
|
||||
error_response_terminal(&mut self.write, "pool statement timeout").await?;
|
||||
Err(Error::StatementTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +144,10 @@ where
|
||||
bytes.put_slice(user.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
|
||||
// Application name
|
||||
bytes.put(&b"application_name\0"[..]);
|
||||
bytes.put_slice(&b"pgcat\0"[..]);
|
||||
|
||||
// Database
|
||||
bytes.put(&b"database\0"[..]);
|
||||
bytes.put_slice(database.as_bytes());
|
||||
@@ -731,6 +735,21 @@ impl BytesMutReader for Cursor<&BytesMut> {
|
||||
}
|
||||
}
|
||||
|
||||
impl BytesMutReader for BytesMut {
|
||||
/// Should only be used when reading strings from the message protocol.
|
||||
/// Can be used to read multiple strings from the same message which are separated by the null byte
|
||||
fn read_string(&mut self) -> Result<String, Error> {
|
||||
let null_index = self.iter().position(|&byte| byte == b'\0');
|
||||
|
||||
match null_index {
|
||||
Some(index) => {
|
||||
let string_bytes = self.split_to(index + 1);
|
||||
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
|
||||
}
|
||||
None => return Err(Error::ParseBytesError("Could not read string".to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
|
||||
@@ -78,7 +78,7 @@ impl MirroredClient {
|
||||
}
|
||||
|
||||
// Incoming data from server (we read to clear the socket buffer and discard the data)
|
||||
recv_result = server.recv() => {
|
||||
recv_result = server.recv(None) => {
|
||||
match recv_result {
|
||||
Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()),
|
||||
Err(err) => {
|
||||
|
||||
26
src/pool.rs
26
src/pool.rs
@@ -1,7 +1,6 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
@@ -25,7 +24,7 @@ use crate::errors::Error;
|
||||
|
||||
use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::plugins::prewarmer;
|
||||
use crate::server::Server;
|
||||
use crate::server::{Server, ServerParameters};
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{AddressStats, ClientStats, ServerStats};
|
||||
|
||||
@@ -196,10 +195,10 @@ pub struct ConnectionPool {
|
||||
/// that should not be queried.
|
||||
banlist: BanList,
|
||||
|
||||
/// The server information (K messages) have to be passed to the
|
||||
/// The server information has to be passed to the
|
||||
/// clients on startup. We pre-connect to all shards and replicas
|
||||
/// on pool creation and save the K messages here.
|
||||
server_info: Arc<RwLock<BytesMut>>,
|
||||
/// on pool creation and save the startup parameters here.
|
||||
original_server_parameters: Arc<RwLock<ServerParameters>>,
|
||||
|
||||
/// Pool configuration.
|
||||
pub settings: PoolSettings,
|
||||
@@ -445,7 +444,7 @@ impl ConnectionPool {
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: match user.pool_mode {
|
||||
@@ -528,7 +527,7 @@ impl ConnectionPool {
|
||||
for server in 0..self.servers(shard) {
|
||||
let databases = self.databases.clone();
|
||||
let validated = Arc::clone(&validated);
|
||||
let pool_server_info = Arc::clone(&self.server_info);
|
||||
let pool_server_parameters = Arc::clone(&self.original_server_parameters);
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
let connection = match databases[shard][server].get().await {
|
||||
@@ -541,11 +540,10 @@ impl ConnectionPool {
|
||||
|
||||
let proxy = connection;
|
||||
let server = &*proxy;
|
||||
let server_info = server.server_info();
|
||||
let server_parameters: ServerParameters = server.server_parameters();
|
||||
|
||||
let mut guard = pool_server_info.write();
|
||||
guard.clear();
|
||||
guard.put(server_info.clone());
|
||||
let mut guard = pool_server_parameters.write();
|
||||
*guard = server_parameters;
|
||||
validated.store(true, Ordering::Relaxed);
|
||||
});
|
||||
|
||||
@@ -557,7 +555,7 @@ impl ConnectionPool {
|
||||
|
||||
// TODO: compare server information to make sure
|
||||
// all shards are running identical configurations.
|
||||
if self.server_info.read().is_empty() {
|
||||
if !self.validated() {
|
||||
error!("Could not validate connection pool");
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
@@ -917,8 +915,8 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.original_server_parameters.read().clone()
|
||||
}
|
||||
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
|
||||
226
src/server.rs
226
src/server.rs
@@ -3,12 +3,13 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::io::Read;
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::mem;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Once};
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
@@ -19,6 +20,7 @@ use crate::config::{get_config, get_prepared_statements_cache_size, Address, Use
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::BytesMutReader;
|
||||
use crate::messages::*;
|
||||
use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
@@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState {
|
||||
}
|
||||
}
|
||||
|
||||
static TRACKED_PARAMETERS: Lazy<HashSet<String>> = Lazy::new(|| {
|
||||
let mut set = HashSet::new();
|
||||
set.insert("client_encoding".to_string());
|
||||
set.insert("DateStyle".to_string());
|
||||
set.insert("TimeZone".to_string());
|
||||
set.insert("standard_conforming_strings".to_string());
|
||||
set.insert("application_name".to_string());
|
||||
set
|
||||
});
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServerParameters {
|
||||
parameters: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl Default for ServerParameters {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameters {
|
||||
pub fn new() -> Self {
|
||||
let mut server_parameters = ServerParameters {
|
||||
parameters: HashMap::new(),
|
||||
};
|
||||
|
||||
server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false);
|
||||
server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false);
|
||||
server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false);
|
||||
server_parameters.set_param(
|
||||
"standard_conforming_strings".to_string(),
|
||||
"on".to_string(),
|
||||
false,
|
||||
);
|
||||
server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false);
|
||||
|
||||
server_parameters
|
||||
}
|
||||
|
||||
/// returns true if a tracked parameter was set, false if it was a non-tracked parameter
|
||||
/// if startup is false, then then only tracked parameters will be set
|
||||
pub fn set_param(&mut self, mut key: String, value: String, startup: bool) {
|
||||
// The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys
|
||||
if key == "timezone" {
|
||||
key = "TimeZone".to_string();
|
||||
} else if key == "datestyle" {
|
||||
key = "DateStyle".to_string();
|
||||
};
|
||||
|
||||
if TRACKED_PARAMETERS.contains(&key) {
|
||||
self.parameters.insert(key, value);
|
||||
} else {
|
||||
if startup {
|
||||
self.parameters.insert(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_from_hashmap(&mut self, parameters: &HashMap<String, String>, startup: bool) {
|
||||
// iterate through each and call set_param
|
||||
for (key, value) in parameters {
|
||||
self.set_param(key.to_string(), value.to_string(), startup);
|
||||
}
|
||||
}
|
||||
|
||||
// Gets the diff of the parameters
|
||||
fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap<String, String> {
|
||||
let mut diff = HashMap::new();
|
||||
|
||||
// iterate through tracked parameters
|
||||
for key in TRACKED_PARAMETERS.iter() {
|
||||
if let Some(incoming_value) = incoming_parameters.parameters.get(key) {
|
||||
if let Some(value) = self.parameters.get(key) {
|
||||
if value != incoming_value {
|
||||
diff.insert(key.to_string(), incoming_value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
diff
|
||||
}
|
||||
|
||||
pub fn get_application_name(&self) -> &String {
|
||||
// Can unwrap because we set it in the constructor
|
||||
self.parameters.get("application_name").unwrap()
|
||||
}
|
||||
|
||||
fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) {
|
||||
buffer.put_u8(b'S');
|
||||
|
||||
// 4 is len of i32, the plus for the null terminator
|
||||
let len = 4 + key.len() + 1 + value.len() + 1;
|
||||
|
||||
buffer.put_i32(len as i32);
|
||||
|
||||
buffer.put_slice(key.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
buffer.put_slice(value.as_bytes());
|
||||
buffer.put_u8(0);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ServerParameters> for BytesMut {
|
||||
fn from(server_parameters: &ServerParameters) -> Self {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
for (key, value) in &server_parameters.parameters {
|
||||
ServerParameters::add_parameter_message(key, value, &mut bytes);
|
||||
}
|
||||
|
||||
bytes
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn compare
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
/// Server host, e.g. localhost,
|
||||
@@ -158,7 +278,7 @@ pub struct Server {
|
||||
buffer: BytesMut,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_info: BytesMut,
|
||||
server_parameters: ServerParameters,
|
||||
|
||||
/// Backend id and secret key used for query cancellation.
|
||||
process_id: i32,
|
||||
@@ -347,7 +467,6 @@ impl Server {
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
@@ -359,6 +478,8 @@ impl Server {
|
||||
None => None,
|
||||
};
|
||||
|
||||
let mut server_parameters = ServerParameters::new();
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
@@ -616,9 +737,10 @@ impl Server {
|
||||
|
||||
// ParameterStatus
|
||||
'S' => {
|
||||
let mut param = vec![0u8; len as usize - 4];
|
||||
let mut bytes = BytesMut::with_capacity(len as usize - 4);
|
||||
bytes.resize(len as usize - mem::size_of::<i32>(), b'0');
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
match stream.read_exact(&mut bytes[..]).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ServerStartupError(
|
||||
@@ -628,12 +750,13 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let key = bytes.read_string().unwrap();
|
||||
let value = bytes.read_string().unwrap();
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
// These can be server_encoding, client_encoding, server timezone, Postgres version,
|
||||
// and many more interesting things we should know about the Postgres server we are talking to.
|
||||
server_info.put_u8(b'S');
|
||||
server_info.put_i32(len);
|
||||
server_info.put_slice(¶m[..]);
|
||||
server_parameters.set_param(key, value, true);
|
||||
}
|
||||
|
||||
// BackendKeyData
|
||||
@@ -675,11 +798,11 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let mut server = Server {
|
||||
let server = Server {
|
||||
address: address.clone(),
|
||||
stream: BufStream::new(stream),
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info,
|
||||
server_parameters,
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
@@ -691,7 +814,7 @@ impl Server {
|
||||
addr_set,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats,
|
||||
application_name: String::new(),
|
||||
application_name: "pgcat".to_string(),
|
||||
last_activity: SystemTime::now(),
|
||||
mirror_manager: match address.mirrors.len() {
|
||||
0 => None,
|
||||
@@ -705,8 +828,6 @@ impl Server {
|
||||
prepared_statements: BTreeSet::new(),
|
||||
};
|
||||
|
||||
server.set_name("pgcat").await?;
|
||||
|
||||
return Ok(server);
|
||||
}
|
||||
|
||||
@@ -776,7 +897,10 @@ impl Server {
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
pub async fn recv(
|
||||
&mut self,
|
||||
mut client_server_parameters: Option<&mut ServerParameters>,
|
||||
) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
Ok(message) => message,
|
||||
@@ -848,14 +972,13 @@ impl Server {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
|
||||
let mut command_tag = String::new();
|
||||
match message.reader().read_to_string(&mut command_tag) {
|
||||
Ok(_) => {
|
||||
match message.read_string() {
|
||||
Ok(command) => {
|
||||
// Non-exhaustive list of commands that are likely to change session variables/resources
|
||||
// which can leak between clients. This is a best effort to block bad clients
|
||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||
match command_tag.as_str() {
|
||||
"SET\0" => {
|
||||
match command.as_str() {
|
||||
"SET" => {
|
||||
// We don't detect set statements in transactions
|
||||
// No great way to differentiate between set and set local
|
||||
// As a result, we will miss cases when set statements are used in transactions
|
||||
@@ -865,7 +988,8 @@ impl Server {
|
||||
self.cleanup_state.needs_cleanup_set = true;
|
||||
}
|
||||
}
|
||||
"PREPARE\0" => {
|
||||
|
||||
"PREPARE" => {
|
||||
debug!("Server connection marked for clean up");
|
||||
self.cleanup_state.needs_cleanup_prepare = true;
|
||||
}
|
||||
@@ -879,6 +1003,17 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
'S' => {
|
||||
let key = message.read_string().unwrap();
|
||||
let value = message.read_string().unwrap();
|
||||
|
||||
if let Some(client_server_parameters) = client_server_parameters.as_mut() {
|
||||
client_server_parameters.set_param(key.clone(), value.clone(), false);
|
||||
}
|
||||
|
||||
self.server_parameters.set_param(key, value, false);
|
||||
}
|
||||
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
@@ -1089,9 +1224,28 @@ impl Server {
|
||||
}
|
||||
|
||||
/// Get server startup information to forward it to the client.
|
||||
/// Not used at the moment.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.clone()
|
||||
pub fn server_parameters(&self) -> ServerParameters {
|
||||
self.server_parameters.clone()
|
||||
}
|
||||
|
||||
pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> {
|
||||
let parameter_diff = self.server_parameters.compare_params(parameters);
|
||||
|
||||
if parameter_diff.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut query = String::from("");
|
||||
|
||||
for (key, value) in parameter_diff {
|
||||
query.push_str(&format!("SET {} TO '{}';", key, value));
|
||||
}
|
||||
|
||||
let res = self.query(&query).await;
|
||||
|
||||
self.cleanup_state.reset();
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// Indicate that this server connection cannot be re-used and must be discarded.
|
||||
@@ -1125,7 +1279,7 @@ impl Server {
|
||||
self.send(&query).await?;
|
||||
|
||||
loop {
|
||||
let _ = self.recv().await?;
|
||||
let _ = self.recv(None).await?;
|
||||
|
||||
if !self.data_available {
|
||||
break;
|
||||
@@ -1166,24 +1320,6 @@ impl Server {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A shorthand for `SET application_name = $1`.
|
||||
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
|
||||
if self.application_name != name {
|
||||
self.application_name = name.to_string();
|
||||
// We don't want `SET application_name` to mark the server connection
|
||||
// as needing cleanup
|
||||
let needs_cleanup_before = self.cleanup_state;
|
||||
|
||||
let result = Ok(self
|
||||
.query(&format!("SET application_name = '{}'", name))
|
||||
.await?);
|
||||
self.cleanup_state = needs_cleanup_before;
|
||||
result
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// get Server stats
|
||||
pub fn stats(&self) -> Arc<ServerStats> {
|
||||
self.stats.clone()
|
||||
@@ -1241,7 +1377,7 @@ impl Server {
|
||||
.await?;
|
||||
debug!("Connected!, sending query.");
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv().await?;
|
||||
let mut message = server.recv(None).await?;
|
||||
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
}
|
||||
|
||||
@@ -112,10 +112,16 @@ class PgcatProcess
|
||||
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
|
||||
end
|
||||
|
||||
def connection_string(pool_name, username, password = nil)
|
||||
def connection_string(pool_name, username, password = nil, parameters: {})
|
||||
cfg = current_config
|
||||
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
|
||||
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
|
||||
|
||||
# Add the additional parameters to the connection string
|
||||
parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
|
||||
connection_string += "?#{parameter_string}" unless parameter_string.empty?
|
||||
|
||||
connection_string
|
||||
end
|
||||
|
||||
def example_connection_string
|
||||
|
||||
@@ -294,6 +294,30 @@ describe "Miscellaneous" do
|
||||
|
||||
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
|
||||
end
|
||||
|
||||
it "Respects tracked parameters on startup" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" }))
|
||||
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "Respect tracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
|
||||
conn.async_exec("SET application_name to 'my_pgcat_test'")
|
||||
expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test")
|
||||
end
|
||||
|
||||
|
||||
it "Ignore untracked parameter on set statemet" do
|
||||
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]
|
||||
|
||||
conn.async_exec("SET statement_timeout to 1500")
|
||||
expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
context "transaction mode with transactions" do
|
||||
|
||||
Reference in New Issue
Block a user