mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
39 Commits
dependabot
...
kczimm-mea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc649aaee3 | ||
|
|
b4ba3b378c | ||
|
|
81536a0bad | ||
|
|
6eb01e51a0 | ||
|
|
ae3241b634 | ||
|
|
33724ea670 | ||
|
|
1c26aa3547 | ||
|
|
64eb417125 | ||
|
|
22d9d3c90a | ||
|
|
3162d550fd | ||
|
|
12522562ce | ||
|
|
4cf54a6122 | ||
|
|
2a8f3653a6 | ||
|
|
19cb8a3022 | ||
|
|
f85e5bd9e8 | ||
|
|
7bdb4e5cd9 | ||
|
|
5d87e3781e | ||
|
|
3e08c6bd8d | ||
|
|
15b6db8e4e | ||
|
|
b2e6dfd9bb | ||
|
|
3c9565d351 | ||
|
|
67579c9af4 | ||
|
|
cf7f6f35ab | ||
|
|
7205537b49 | ||
|
|
1ed6e925ed | ||
|
|
4b78af9676 | ||
|
|
73500c0c96 | ||
|
|
b167de5aa3 | ||
|
|
473bb3d17d | ||
|
|
c7d6273037 | ||
|
|
94c781881f | ||
|
|
a8c81e5df6 | ||
|
|
1d3746ec9e | ||
|
|
b5489dc1e6 | ||
|
|
557b425fb1 | ||
|
|
aca9738821 | ||
|
|
0bc453a771 | ||
|
|
b67c33b6d0 | ||
|
|
a8a30ad43b |
7
.github/workflows/build-and-push.yaml
vendored
7
.github/workflows/build-and-push.yaml
vendored
@@ -1,6 +1,11 @@
|
||||
name: Build and Push
|
||||
|
||||
on: push
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- v*
|
||||
|
||||
env:
|
||||
registry: ghcr.io
|
||||
|
||||
14
CONFIG.md
14
CONFIG.md
@@ -1,4 +1,4 @@
|
||||
# PgCat Configurations
|
||||
# PgCat Configurations
|
||||
## `general` Section
|
||||
|
||||
### host
|
||||
@@ -116,10 +116,10 @@ If we should log client disconnections
|
||||
### autoreload
|
||||
```
|
||||
path: general.autoreload
|
||||
default: 15000
|
||||
default: 15000 # milliseconds
|
||||
```
|
||||
|
||||
When set to true, PgCat reloads configs if it detects a change in the config file.
|
||||
When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
|
||||
|
||||
### worker_threads
|
||||
```
|
||||
@@ -151,7 +151,13 @@ path: general.tcp_keepalives_interval
|
||||
default: 5
|
||||
```
|
||||
|
||||
Number of seconds between keepalive packets.
|
||||
### tcp_user_timeout
|
||||
```
|
||||
path: general.tcp_user_timeout
|
||||
default: 10000
|
||||
```
|
||||
A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
|
||||
|
||||
|
||||
### tls_certificate
|
||||
```
|
||||
|
||||
916
Cargo.lock
generated
916
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "1.0.2-alpha3"
|
||||
version = "1.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -8,7 +8,7 @@ edition = "2021"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = "1"
|
||||
md-5 = "0.10"
|
||||
bb8 = "0.8.0"
|
||||
bb8 = "0.8.1"
|
||||
async-trait = "0.1"
|
||||
rand = "0.8"
|
||||
chrono = "0.4"
|
||||
@@ -19,7 +19,7 @@ serde_derive = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
sqlparser = {version = "0.33", features = ["visitor"] }
|
||||
sqlparser = {version = "0.34", features = ["visitor"] }
|
||||
log = "0.4"
|
||||
arc-swap = "1"
|
||||
env_logger = "0.10"
|
||||
@@ -46,6 +46,9 @@ trust-dns-resolver = "0.22.0"
|
||||
tokio-test = "0.4.2"
|
||||
serde_json = "1"
|
||||
itertools = "0.10"
|
||||
clap = { version = "4.3.1", features = ["derive", "env"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.17", features = ["json"]}
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM rust:bullseye
|
||||
FROM rust:1.70-bullseye
|
||||
|
||||
# Dependencies
|
||||
RUN apt-get update -y \
|
||||
|
||||
@@ -60,6 +60,12 @@ tcp_keepalives_count = 5
|
||||
# Number of seconds between keepalive packets.
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Handle prepared statements.
|
||||
prepared_statements = true
|
||||
|
||||
# Prepared statements server cache size.
|
||||
prepared_statements_cache_size = 500
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
|
||||
120
src/admin.rs
120
src/admin.rs
@@ -1,4 +1,5 @@
|
||||
use crate::pool::BanReason;
|
||||
use crate::stats::pool::PoolStats;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{error, info, trace};
|
||||
use nix::sys::signal::{self, Signal};
|
||||
@@ -14,7 +15,7 @@ use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::pool::{get_all_pools, get_pool};
|
||||
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
|
||||
use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState};
|
||||
|
||||
pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
let mut server_info = BytesMut::new();
|
||||
@@ -83,6 +84,10 @@ where
|
||||
shutdown(stream).await
|
||||
}
|
||||
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
|
||||
"HELP" => {
|
||||
trace!("SHOW HELP");
|
||||
show_help(stream).await
|
||||
}
|
||||
"BANS" => {
|
||||
trace!("SHOW BANS");
|
||||
show_bans(stream).await
|
||||
@@ -254,39 +259,51 @@ async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let all_pool_stats = get_pool_stats();
|
||||
let pool_lookup = PoolStats::construct_pool_lookup();
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&PoolStats::generate_header()));
|
||||
pool_lookup.iter().for_each(|(_identifier, pool_stats)| {
|
||||
res.put(data_row(&pool_stats.generate_row()));
|
||||
});
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
let columns = vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
("cl_idle", DataType::Numeric),
|
||||
("cl_active", DataType::Numeric),
|
||||
("cl_waiting", DataType::Numeric),
|
||||
("cl_cancel_req", DataType::Numeric),
|
||||
("sv_active", DataType::Numeric),
|
||||
("sv_idle", DataType::Numeric),
|
||||
("sv_used", DataType::Numeric),
|
||||
("sv_tested", DataType::Numeric),
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, &res).await
|
||||
}
|
||||
|
||||
/// Show all available options.
|
||||
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
let detail_msg = vec![
|
||||
"",
|
||||
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
|
||||
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
|
||||
// "SHOW FDS|SOCKETS|ACTIVE_SOCKETS|LISTS|MEM|STATE", // missing FDS|SOCKETS|ACTIVE_SOCKETS|MEM|STATE
|
||||
"SHOW LISTS",
|
||||
// "SHOW DNS_HOSTS|DNS_ZONES", // missing DNS_HOSTS|DNS_ZONES
|
||||
"SHOW STATS", // missing STATS_TOTALS|STATS_AVERAGES|TOTALS
|
||||
"SET key = arg",
|
||||
"RELOAD",
|
||||
"PAUSE [<db>, <user>]",
|
||||
"RESUME [<db>, <user>]",
|
||||
// "DISABLE <db>", // missing
|
||||
// "ENABLE <db>", // missing
|
||||
// "RECONNECT [<db>]", missing
|
||||
// "KILL <db>",
|
||||
// "SUSPEND",
|
||||
"SHUTDOWN",
|
||||
// "WAIT_CLOSE [<db>]", // missing
|
||||
];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
for ((_user_pool, _pool), pool_stats) in all_pool_stats {
|
||||
let mut row = vec![
|
||||
pool_stats.database(),
|
||||
pool_stats.user(),
|
||||
pool_stats.pool_mode().to_string(),
|
||||
];
|
||||
pool_stats.populate_row(&mut row);
|
||||
pool_stats.clear_maxwait();
|
||||
res.put(data_row(&row));
|
||||
}
|
||||
|
||||
res.put(notify("Console usage", detail_msg.join("\n\t")));
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
// ReadyForQuery
|
||||
@@ -334,17 +351,17 @@ where
|
||||
let paused = pool.paused();
|
||||
|
||||
res.put(data_row(&vec![
|
||||
address.name(), // name
|
||||
address.host.to_string(), // host
|
||||
address.port.to_string(), // port
|
||||
database_name.to_string(), // database
|
||||
pool_config.user.username.to_string(), // force_user
|
||||
pool_config.user.pool_size.to_string(), // pool_size
|
||||
"0".to_string(), // min_pool_size
|
||||
"0".to_string(), // reserve_pool
|
||||
pool_config.pool_mode.to_string(), // pool_mode
|
||||
pool_config.user.pool_size.to_string(), // max_connections
|
||||
pool_state.connections.to_string(), // current_connections
|
||||
address.name(), // name
|
||||
address.host.to_string(), // host
|
||||
address.port.to_string(), // port
|
||||
database_name.to_string(), // database
|
||||
pool_config.user.username.to_string(), // force_user
|
||||
pool_config.user.pool_size.to_string(), // pool_size
|
||||
pool_config.user.min_pool_size.unwrap_or(0).to_string(), // min_pool_size
|
||||
"0".to_string(), // reserve_pool
|
||||
pool_config.pool_mode.to_string(), // pool_mode
|
||||
pool_config.user.pool_size.to_string(), // max_connections
|
||||
pool_state.connections.to_string(), // current_connections
|
||||
match paused {
|
||||
// paused
|
||||
true => "1".to_string(),
|
||||
@@ -725,6 +742,9 @@ where
|
||||
("bytes_sent", DataType::Numeric),
|
||||
("bytes_received", DataType::Numeric),
|
||||
("age_seconds", DataType::Numeric),
|
||||
("prepare_cache_hit", DataType::Numeric),
|
||||
("prepare_cache_miss", DataType::Numeric),
|
||||
("prepare_cache_size", DataType::Numeric),
|
||||
];
|
||||
|
||||
let new_map = get_server_stats();
|
||||
@@ -748,6 +768,18 @@ where
|
||||
.duration_since(server.connect_time())
|
||||
.as_secs()
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_hit_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_miss_count
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
server
|
||||
.prepared_cache_size
|
||||
.load(Ordering::Relaxed)
|
||||
.to_string(),
|
||||
];
|
||||
|
||||
res.put(data_row(&row));
|
||||
@@ -768,7 +800,7 @@ async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
|
||||
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
@@ -815,7 +847,7 @@ async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
|
||||
where
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
|
||||
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
|
||||
|
||||
if parts.len() != 2 {
|
||||
error_response(
|
||||
|
||||
@@ -12,7 +12,7 @@ pub struct AuthPassthrough {
|
||||
|
||||
impl AuthPassthrough {
|
||||
/// Initializes an AuthPassthrough.
|
||||
pub fn new(query: &str, user: &str, password: &str) -> Self {
|
||||
pub fn new<S: ToString>(query: S, user: S, password: S) -> Self {
|
||||
AuthPassthrough {
|
||||
password: password.to_string(),
|
||||
query: query.to_string(),
|
||||
|
||||
302
src/client.rs
302
src/client.rs
@@ -3,8 +3,9 @@ use crate::pool::BanReason;
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{atomic::AtomicUsize, Arc};
|
||||
use std::time::Instant;
|
||||
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
@@ -13,18 +14,25 @@ use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_info_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
|
||||
};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
|
||||
use crate::query_router::{Command, QueryRouter};
|
||||
use crate::server::Server;
|
||||
use crate::stats::{ClientStats, PoolStats, ServerStats};
|
||||
use crate::stats::{ClientStats, ServerStats};
|
||||
use crate::tls::Tls;
|
||||
|
||||
use tokio_rustls::server::TlsStream;
|
||||
|
||||
/// Incrementally count prepared statements
|
||||
/// to avoid random conflicts in places where the random number generator is weak.
|
||||
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
|
||||
Lazy::new(|| Arc::new(AtomicUsize::new(0)));
|
||||
|
||||
/// Type of connection received from client.
|
||||
enum ClientConnectionType {
|
||||
Startup,
|
||||
@@ -93,6 +101,9 @@ pub struct Client<S, T> {
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown: Receiver<()>,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: HashMap<String, Parse>,
|
||||
}
|
||||
|
||||
/// Client entrypoint.
|
||||
@@ -112,7 +123,7 @@ pub async fn client_entrypoint(
|
||||
// Client requested a TLS connection.
|
||||
Ok((ClientConnectionType::Tls, _)) => {
|
||||
// TLS settings are configured, will setup TLS now.
|
||||
if tls_certificate != None {
|
||||
if tls_certificate.is_some() {
|
||||
debug!("Accepting TLS request");
|
||||
|
||||
let mut yes = BytesMut::new();
|
||||
@@ -420,7 +431,7 @@ where
|
||||
None => "pgcat",
|
||||
};
|
||||
|
||||
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
|
||||
let client_identifier = ClientIdentifier::new(application_name, username, pool_name);
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
@@ -654,24 +665,12 @@ where
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
trace!("Startup OK");
|
||||
let pool_stats = match get_pool(pool_name, username) {
|
||||
Some(pool) => {
|
||||
if !admin {
|
||||
pool.stats
|
||||
} else {
|
||||
Arc::new(PoolStats::default())
|
||||
}
|
||||
}
|
||||
None => Arc::new(PoolStats::default()),
|
||||
};
|
||||
|
||||
let stats = Arc::new(ClientStats::new(
|
||||
process_id,
|
||||
application_name,
|
||||
username,
|
||||
pool_name,
|
||||
tokio::time::Instant::now(),
|
||||
pool_stats,
|
||||
));
|
||||
|
||||
Ok(Client {
|
||||
@@ -694,6 +693,7 @@ where
|
||||
application_name: application_name.to_string(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -728,6 +728,7 @@ where
|
||||
application_name: String::from("undefined"),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
prepared_statements: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -769,6 +770,10 @@ where
|
||||
// Result returned by one of the plugins.
|
||||
let mut plugin_output = None;
|
||||
|
||||
// Prepared statement being executed
|
||||
let mut prepared_statement = None;
|
||||
let mut will_prepare = false;
|
||||
|
||||
// Our custom protocol loop.
|
||||
// We expect the client to either start a transaction with regular queries
|
||||
// or issue commands for our sharding and server selection protocol.
|
||||
@@ -778,13 +783,16 @@ where
|
||||
self.transaction_mode
|
||||
);
|
||||
|
||||
// Should we rewrite prepared statements and bind messages?
|
||||
let mut prepared_statements_enabled = get_prepared_statements();
|
||||
|
||||
// Read a complete message from the client, which normally would be
|
||||
// either a `Q` (query) or `P` (prepare, extended protocol).
|
||||
// We can parse it here before grabbing a server from the pool,
|
||||
// in case the client is sending some custom protocol messages, e.g.
|
||||
// SET SHARDING KEY TO 'bigint';
|
||||
|
||||
let message = tokio::select! {
|
||||
let mut message = tokio::select! {
|
||||
_ = self.shutdown.recv() => {
|
||||
if !self.admin {
|
||||
error_response_terminal(
|
||||
@@ -812,7 +820,21 @@ where
|
||||
// allocate a connection, we wouldn't be able to send back an error message
|
||||
// to the client so we buffer them and defer the decision to error out or not
|
||||
// to when we get the S message
|
||||
'D' | 'E' => {
|
||||
'D' => {
|
||||
if prepared_statements_enabled {
|
||||
let name;
|
||||
(name, message) = self.rewrite_describe(message).await?;
|
||||
|
||||
if let Some(name) = name {
|
||||
prepared_statement = Some(name);
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
continue;
|
||||
}
|
||||
|
||||
'E' => {
|
||||
self.buffer.put(&message[..]);
|
||||
continue;
|
||||
}
|
||||
@@ -842,6 +864,11 @@ where
|
||||
}
|
||||
|
||||
'P' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_parse(message)?;
|
||||
will_prepare = true;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
@@ -858,6 +885,10 @@ where
|
||||
}
|
||||
|
||||
'B' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_bind(message).await?;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
@@ -875,6 +906,19 @@ where
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Close (F)
|
||||
'C' => {
|
||||
if prepared_statements_enabled {
|
||||
let close: Close = (&message).try_into()?;
|
||||
|
||||
if close.is_prepared_statement() && !close.anonymous() {
|
||||
self.prepared_statements.remove(&close.name);
|
||||
write_all_flush(&mut self.write, &close_complete()).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => (),
|
||||
}
|
||||
|
||||
@@ -886,16 +930,12 @@ where
|
||||
}
|
||||
|
||||
// Check on plugin results.
|
||||
match plugin_output {
|
||||
Some(PluginOutput::Deny(error)) => {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
if let Some(PluginOutput::Deny(error)) = plugin_output {
|
||||
self.buffer.clear();
|
||||
error_response(&mut self.write, &error).await?;
|
||||
plugin_output = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get a pool instance referenced by the most up-to-date
|
||||
// pointer. This ensures we always read the latest config
|
||||
@@ -1066,7 +1106,58 @@ where
|
||||
// If the client is in session mode, no more custom protocol
|
||||
// commands will be accepted.
|
||||
loop {
|
||||
let message = match initial_message {
|
||||
// Only check if we should rewrite prepared statements
|
||||
// in session mode. In transaction mode, we check at the beginning of
|
||||
// each transaction.
|
||||
if !self.transaction_mode {
|
||||
prepared_statements_enabled = get_prepared_statements();
|
||||
}
|
||||
|
||||
debug!("Prepared statement active: {:?}", prepared_statement);
|
||||
|
||||
// We are processing a prepared statement.
|
||||
if let Some(ref name) = prepared_statement {
|
||||
debug!("Checking prepared statement is on server");
|
||||
// Get the prepared statement the server expects to see.
|
||||
let statement = match self.prepared_statements.get(name) {
|
||||
Some(statement) => {
|
||||
debug!("Prepared statement `{}` found in cache", name);
|
||||
statement
|
||||
}
|
||||
None => {
|
||||
return Err(Error::ClientError(format!(
|
||||
"prepared statement `{}` not found",
|
||||
name
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Since it's already in the buffer, we don't need to prepare it on this server.
|
||||
if will_prepare {
|
||||
server.will_prepare(&statement.name);
|
||||
will_prepare = false;
|
||||
} else {
|
||||
// The statement is not prepared on the server, so we need to prepare it.
|
||||
if server.should_prepare(&statement.name) {
|
||||
match server.prepare(statement).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
pool.ban(
|
||||
&address,
|
||||
BanReason::MessageSendFailed,
|
||||
Some(&self.stats),
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Done processing the prepared statement.
|
||||
prepared_statement = None;
|
||||
}
|
||||
|
||||
let mut message = match initial_message {
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
@@ -1118,7 +1209,7 @@ where
|
||||
|
||||
// Safe to unwrap because we know this message has a certain length and has the code
|
||||
// This reads the first byte without advancing the internal pointer and mutating the bytes
|
||||
let code = *message.get(0).unwrap() as char;
|
||||
let code = *message.first().unwrap() as char;
|
||||
|
||||
trace!("Message: {}", code);
|
||||
|
||||
@@ -1165,7 +1256,7 @@ where
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
if self.transaction_mode {
|
||||
if self.transaction_mode && !server.in_copy_mode() {
|
||||
self.stats.idle();
|
||||
|
||||
break;
|
||||
@@ -1179,12 +1270,21 @@ where
|
||||
self.stats.disconnect();
|
||||
self.release();
|
||||
|
||||
if prepared_statements_enabled {
|
||||
server.maintain_cache().await?;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Parse
|
||||
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
|
||||
'P' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_parse(message)?;
|
||||
will_prepare = true;
|
||||
}
|
||||
|
||||
if query_router.query_parser_enabled() {
|
||||
if let Ok(ast) = QueryRouter::parse(&message) {
|
||||
if let Ok(output) = query_router.execute_plugins(&ast).await {
|
||||
@@ -1199,17 +1299,42 @@ where
|
||||
// Bind
|
||||
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
|
||||
'B' => {
|
||||
if prepared_statements_enabled {
|
||||
(prepared_statement, message) = self.rewrite_bind(message).await?;
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Describe
|
||||
// Command a client can issue to describe a previously prepared named statement.
|
||||
'D' => {
|
||||
if prepared_statements_enabled {
|
||||
let name;
|
||||
(name, message) = self.rewrite_describe(message).await?;
|
||||
|
||||
if let Some(name) = name {
|
||||
prepared_statement = Some(name);
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
// Close the prepared statement.
|
||||
'C' => {
|
||||
if prepared_statements_enabled {
|
||||
let close: Close = (&message).try_into()?;
|
||||
|
||||
if close.is_prepared_statement() && !close.anonymous() {
|
||||
if let Some(parse) = self.prepared_statements.get(&close.name) {
|
||||
server.will_close(&parse.generated_name);
|
||||
} else {
|
||||
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
}
|
||||
|
||||
@@ -1244,10 +1369,10 @@ where
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char;
|
||||
|
||||
// Almost certainly true
|
||||
if first_message_code == 'P' {
|
||||
if first_message_code == 'P' && !prepared_statements_enabled {
|
||||
// Message layout
|
||||
// P followed by 32 int followed by null-terminated statement name
|
||||
// So message code should be in offset 0 of the buffer, first character
|
||||
@@ -1278,7 +1403,7 @@ where
|
||||
|
||||
// Release server back to the pool if we are in transaction mode.
|
||||
// If we are in session mode, we keep the server until the client disconnects.
|
||||
if self.transaction_mode {
|
||||
if self.transaction_mode && !server.in_copy_mode() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1343,7 +1468,13 @@ where
|
||||
|
||||
// The server is no longer bound to us, we can't cancel it's queries anymore.
|
||||
debug!("Releasing server back into the pool");
|
||||
|
||||
server.checkin_cleanup().await?;
|
||||
|
||||
if prepared_statements_enabled {
|
||||
server.maintain_cache().await?;
|
||||
}
|
||||
|
||||
server.stats().idle();
|
||||
self.connected_to_server = false;
|
||||
|
||||
@@ -1375,6 +1506,107 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrite Parse (F) message to set the prepared statement name to one we control.
|
||||
/// Save it into the client cache.
|
||||
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let parse: Parse = (&message).try_into()?;
|
||||
|
||||
let name = parse.name.clone();
|
||||
|
||||
// Don't rewrite anonymous prepared statements
|
||||
if parse.anonymous() {
|
||||
debug!("Anonymous prepared statement");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
let parse = parse.rename();
|
||||
|
||||
debug!(
|
||||
"Renamed prepared statement `{}` to `{}` and saved to cache",
|
||||
name, parse.name
|
||||
);
|
||||
|
||||
self.prepared_statements.insert(name.clone(), parse.clone());
|
||||
|
||||
Ok((Some(name), parse.try_into()?))
|
||||
}
|
||||
|
||||
/// Rewrite the Bind (F) message to use the prepared statement name
|
||||
/// saved in the client cache.
|
||||
async fn rewrite_bind(
|
||||
&mut self,
|
||||
message: BytesMut,
|
||||
) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let bind: Bind = (&message).try_into()?;
|
||||
let name = bind.prepared_statement.clone();
|
||||
|
||||
if bind.anonymous() {
|
||||
debug!("Anonymous bind message");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
match self.prepared_statements.get(&name) {
|
||||
Some(prepared_stmt) => {
|
||||
let bind = bind.reassign(prepared_stmt);
|
||||
|
||||
debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);
|
||||
|
||||
Ok((Some(name), bind.try_into()?))
|
||||
}
|
||||
None => {
|
||||
debug!("Got bind for unknown prepared statement {:?}", bind);
|
||||
|
||||
error_response(
|
||||
&mut self.write,
|
||||
&format!(
|
||||
"prepared statement \"{}\" does not exist",
|
||||
bind.prepared_statement
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Prepared statement `{}` doesn't exist",
|
||||
name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrite the Describe (F) message to use the prepared statement name
|
||||
/// saved in the client cache.
|
||||
async fn rewrite_describe(
|
||||
&mut self,
|
||||
message: BytesMut,
|
||||
) -> Result<(Option<String>, BytesMut), Error> {
|
||||
let describe: Describe = (&message).try_into()?;
|
||||
let name = describe.statement_name.clone();
|
||||
|
||||
if describe.anonymous() {
|
||||
debug!("Anonymous describe");
|
||||
return Ok((None, message));
|
||||
}
|
||||
|
||||
match self.prepared_statements.get(&name) {
|
||||
Some(prepared_stmt) => {
|
||||
let describe = describe.rename(&prepared_stmt.name);
|
||||
|
||||
debug!(
|
||||
"Rewrote describe `{}` to `{}`",
|
||||
name, describe.statement_name
|
||||
);
|
||||
|
||||
Ok((Some(name), describe.try_into()?))
|
||||
}
|
||||
|
||||
None => {
|
||||
debug!("Got describe for unknown prepared statement {:?}", describe);
|
||||
|
||||
Ok((None, message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Release the server from the client: it can't cancel its queries anymore.
|
||||
pub fn release(&self) {
|
||||
let mut guard = self.client_server_map.lock();
|
||||
|
||||
36
src/cmd_args.rs
Normal file
36
src/cmd_args.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
use clap::{Parser, ValueEnum};
|
||||
use tracing::Level;
|
||||
|
||||
/// PgCat: Nextgen PostgreSQL Pooler
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
#[arg(default_value_t = String::from("pgcat.toml"), env)]
|
||||
pub config_file: String,
|
||||
|
||||
#[arg(short, long, default_value_t = tracing::Level::INFO, env)]
|
||||
pub log_level: Level,
|
||||
|
||||
#[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)]
|
||||
pub log_format: LogFormat,
|
||||
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
default_value_t = false,
|
||||
env,
|
||||
help = "disable colors in the log output"
|
||||
)]
|
||||
pub no_color: bool,
|
||||
}
|
||||
|
||||
pub fn parse() -> Args {
|
||||
Args::parse()
|
||||
}
|
||||
|
||||
#[derive(ValueEnum, Clone, Debug)]
|
||||
pub enum LogFormat {
|
||||
Text,
|
||||
Structured,
|
||||
Debug,
|
||||
}
|
||||
146
src/config.rs
146
src/config.rs
@@ -217,19 +217,15 @@ impl Default for User {
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
if let Some(min_pool_size) = self.min_pool_size {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -261,6 +257,8 @@ pub struct General {
|
||||
pub tcp_keepalives_count: u32,
|
||||
#[serde(default = "General::default_tcp_keepalives_interval")]
|
||||
pub tcp_keepalives_interval: u64,
|
||||
#[serde(default = "General::default_tcp_user_timeout")]
|
||||
pub tcp_user_timeout: u64,
|
||||
|
||||
#[serde(default)] // False
|
||||
pub log_client_connections: bool,
|
||||
@@ -292,6 +290,9 @@ pub struct General {
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_server_round_robin")] // False
|
||||
pub server_round_robin: bool,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
@@ -317,6 +318,12 @@ pub struct General {
|
||||
pub auth_query: Option<String>,
|
||||
pub auth_query_user: Option<String>,
|
||||
pub auth_query_password: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub prepared_statements: bool,
|
||||
|
||||
#[serde(default = "General::default_prepared_statements_cache_size")]
|
||||
pub prepared_statements_cache_size: usize,
|
||||
}
|
||||
|
||||
impl General {
|
||||
@@ -329,7 +336,7 @@ impl General {
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
1000 * 60 * 60 // 1 hour
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
@@ -351,8 +358,12 @@ impl General {
|
||||
5 // 5 seconds
|
||||
}
|
||||
|
||||
pub fn default_tcp_user_timeout() -> u64 {
|
||||
10000 // 10000 milliseconds
|
||||
}
|
||||
|
||||
pub fn default_idle_timeout() -> u64 {
|
||||
60000 // 10 minutes
|
||||
600000 // 10 minutes
|
||||
}
|
||||
|
||||
pub fn default_shutdown_timeout() -> u64 {
|
||||
@@ -390,6 +401,14 @@ impl General {
|
||||
pub fn default_prometheus_exporter_port() -> i16 {
|
||||
9930
|
||||
}
|
||||
|
||||
pub fn default_server_round_robin() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn default_prepared_statements_cache_size() -> usize {
|
||||
500
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for General {
|
||||
@@ -410,6 +429,7 @@ impl Default for General {
|
||||
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
|
||||
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
|
||||
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
|
||||
tcp_user_timeout: Self::default_tcp_user_timeout(),
|
||||
log_client_connections: false,
|
||||
log_client_disconnections: false,
|
||||
autoreload: None,
|
||||
@@ -424,8 +444,11 @@ impl Default for General {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
server_lifetime: Self::default_server_lifetime(),
|
||||
server_round_robin: Self::default_server_round_robin(),
|
||||
validate_config: true,
|
||||
prepared_statements: false,
|
||||
prepared_statements_cache_size: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -604,9 +627,9 @@ impl Pool {
|
||||
Some(key) => {
|
||||
// No quotes in the key so we don't have to compare quoted
|
||||
// to unquoted idents.
|
||||
let key = key.replace("\"", "");
|
||||
let key = key.replace('\"', "");
|
||||
|
||||
if key.split(".").count() != 2 {
|
||||
if key.split('.').count() != 2 {
|
||||
error!(
|
||||
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
|
||||
key, key
|
||||
@@ -619,7 +642,7 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for (_, user) in &self.users {
|
||||
for user in self.users.values() {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
@@ -791,8 +814,8 @@ pub struct Query {
|
||||
impl Query {
|
||||
pub fn substitute(&mut self, db: &str, user: &str) {
|
||||
for col in self.result.iter_mut() {
|
||||
for i in 0..col.len() {
|
||||
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
|
||||
for c in col {
|
||||
*c = c.replace("${USER}", user).replace("${DATABASE}", db);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -902,8 +925,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
(
|
||||
format!("pools.{:?}.users", pool_name),
|
||||
pool.users
|
||||
.iter()
|
||||
.map(|(_username, user)| &user.username)
|
||||
.values()
|
||||
.map(|user| &user.username)
|
||||
.cloned()
|
||||
.collect::<Vec<String>>()
|
||||
.join(", "),
|
||||
@@ -983,17 +1006,14 @@ impl Config {
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
info!("Sever round robin: {}", self.general.server_round_robin);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
}
|
||||
|
||||
None => (),
|
||||
if let Some(tls_private_key) = self.general.tls_private_key.clone() {
|
||||
info!("TLS private key: {}", tls_private_key);
|
||||
info!("TLS support is enabled");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1006,6 +1026,13 @@ impl Config {
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
info!("Prepared statements: {}", self.general.prepared_statements);
|
||||
if self.general.prepared_statements {
|
||||
info!(
|
||||
"Prepared statements server cache size: {}",
|
||||
self.general.prepared_statements_cache_size
|
||||
);
|
||||
}
|
||||
info!(
|
||||
"Plugins: {}",
|
||||
match self.plugins {
|
||||
@@ -1021,8 +1048,8 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config
|
||||
.users
|
||||
.iter()
|
||||
.map(|(_, user_cfg)| user_cfg.pool_size)
|
||||
.values()
|
||||
.map(|user_cfg| user_cfg.pool_size)
|
||||
.sum::<u32>()
|
||||
.to_string()
|
||||
);
|
||||
@@ -1179,35 +1206,32 @@ impl Config {
|
||||
}
|
||||
|
||||
// Validate TLS!
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
},
|
||||
|
||||
None => {
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
if let Some(tls_certificate) = self.general.tls_certificate.clone() {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
Err(err) => {
|
||||
error!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
None => {
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
error!("tls_certificate is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
}
|
||||
|
||||
for pool in self.pools.values_mut() {
|
||||
pool.validate()?;
|
||||
@@ -1225,9 +1249,15 @@ pub fn get_config() -> Config {
|
||||
}
|
||||
|
||||
pub fn get_idle_client_in_transaction_timeout() -> u64 {
|
||||
(*(*CONFIG.load()))
|
||||
.general
|
||||
.idle_client_in_transaction_timeout
|
||||
CONFIG.load().general.idle_client_in_transaction_timeout
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements() -> bool {
|
||||
CONFIG.load().general.prepared_statements
|
||||
}
|
||||
|
||||
pub fn get_prepared_statements_cache_size() -> usize {
|
||||
CONFIG.load().general.prepared_statements_cache_size
|
||||
}
|
||||
|
||||
/// Parse the configuration file located at the path.
|
||||
|
||||
@@ -26,6 +26,7 @@ pub enum Error {
|
||||
AuthPassthroughError(String),
|
||||
UnsupportedStatement,
|
||||
QueryRouterParserError(String),
|
||||
QueryRouterError(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@@ -36,11 +37,11 @@ pub struct ClientIdentifier {
|
||||
}
|
||||
|
||||
impl ClientIdentifier {
|
||||
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
|
||||
pub fn new<S: ToString>(application_name: S, username: S, pool_name: S) -> ClientIdentifier {
|
||||
ClientIdentifier {
|
||||
application_name: application_name.into(),
|
||||
username: username.into(),
|
||||
pool_name: pool_name.into(),
|
||||
application_name: application_name.to_string(),
|
||||
username: username.to_string(),
|
||||
pool_name: pool_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,10 +63,10 @@ pub struct ServerIdentifier {
|
||||
}
|
||||
|
||||
impl ServerIdentifier {
|
||||
pub fn new(username: &str, database: &str) -> ServerIdentifier {
|
||||
pub fn new<S: ToString>(username: S, database: S) -> ServerIdentifier {
|
||||
ServerIdentifier {
|
||||
username: username.into(),
|
||||
database: database.into(),
|
||||
username: username.to_string(),
|
||||
database: database.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -83,41 +84,42 @@ impl std::fmt::Display for ServerIdentifier {
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match &self {
|
||||
&Error::ClientSocketError(error, client_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} from client {}",
|
||||
error, client_identifier
|
||||
),
|
||||
&Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{} {}", error, client_identifier)
|
||||
Error::ClientSocketError(error, client_identifier) => {
|
||||
write!(f, "Error reading {error} from client {client_identifier}",)
|
||||
}
|
||||
&Error::ClientAuthImpossible(username) => write!(
|
||||
Error::ClientGeneralError(error, client_identifier) => {
|
||||
write!(f, "{error} {client_identifier}")
|
||||
}
|
||||
Error::ClientAuthImpossible(username) => write!(
|
||||
f,
|
||||
"Client auth not possible, \
|
||||
no cleartext password set for username: {} \
|
||||
no cleartext password set for username: {username} \
|
||||
in config and auth passthrough (query_auth) \
|
||||
is not set up.",
|
||||
username
|
||||
is not set up."
|
||||
),
|
||||
&Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
Error::ClientAuthPassthroughError(error, client_identifier) => write!(
|
||||
f,
|
||||
"No cleartext password set, \
|
||||
and no auth passthrough could not \
|
||||
obtain the hash from server for {}, \
|
||||
the error was: {}",
|
||||
client_identifier, error
|
||||
obtain the hash from server for {client_identifier}, \
|
||||
the error was: {error}",
|
||||
),
|
||||
&Error::ServerStartupError(error, server_identifier) => write!(
|
||||
Error::ServerStartupError(error, server_identifier) => write!(
|
||||
f,
|
||||
"Error reading {} on server startup {}",
|
||||
error, server_identifier,
|
||||
"Error reading {error} on server startup {server_identifier}",
|
||||
),
|
||||
&Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{} for {}", error, server_identifier,)
|
||||
Error::ServerAuthError(error, server_identifier) => {
|
||||
write!(f, "{error} for {server_identifier}")
|
||||
}
|
||||
|
||||
// The rest can use Debug.
|
||||
err => write!(f, "{:?}", err),
|
||||
err => write!(f, "{err:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::ffi::NulError> for Error {
|
||||
fn from(err: std::ffi::NulError) -> Self {
|
||||
Error::QueryRouterError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
22
src/lib.rs
22
src/lib.rs
@@ -1,13 +1,14 @@
|
||||
pub mod admin;
|
||||
pub mod auth_passthrough;
|
||||
pub mod client;
|
||||
pub mod cmd_args;
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod dns_cache;
|
||||
pub mod errors;
|
||||
pub mod logger;
|
||||
pub mod messages;
|
||||
pub mod mirrors;
|
||||
pub mod multi_logger;
|
||||
pub mod plugins;
|
||||
pub mod pool;
|
||||
pub mod prometheus;
|
||||
@@ -24,18 +25,11 @@ pub mod tls;
|
||||
///
|
||||
/// * `duration` - A duration of time
|
||||
pub fn format_duration(duration: &chrono::Duration) -> String {
|
||||
let milliseconds = format!("{:0>3}", duration.num_milliseconds() % 1000);
|
||||
let milliseconds = duration.num_milliseconds() % 1000;
|
||||
let seconds = duration.num_seconds() % 60;
|
||||
let minutes = duration.num_minutes() % 60;
|
||||
let hours = duration.num_hours() % 24;
|
||||
let days = duration.num_days();
|
||||
|
||||
let seconds = format!("{:0>2}", duration.num_seconds() % 60);
|
||||
|
||||
let minutes = format!("{:0>2}", duration.num_minutes() % 60);
|
||||
|
||||
let hours = format!("{:0>2}", duration.num_hours() % 24);
|
||||
|
||||
let days = duration.num_days().to_string();
|
||||
|
||||
format!(
|
||||
"{}d {}:{}:{}.{}",
|
||||
days, hours, minutes, seconds, milliseconds
|
||||
)
|
||||
format!("{days}d {hours:0>2}:{minutes:0>2}:{seconds:0>2}.{milliseconds:0>3}")
|
||||
}
|
||||
|
||||
14
src/logger.rs
Normal file
14
src/logger.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::cmd_args::{Args, LogFormat};
|
||||
use tracing_subscriber;
|
||||
|
||||
pub fn init(args: &Args) {
|
||||
let trace_sub = tracing_subscriber::fmt()
|
||||
.with_max_level(args.log_level)
|
||||
.with_ansi(!args.no_color);
|
||||
|
||||
match args.log_format {
|
||||
LogFormat::Structured => trace_sub.json().init(),
|
||||
LogFormat::Debug => trace_sub.pretty().init(),
|
||||
_ => trace_sub.init(),
|
||||
};
|
||||
}
|
||||
20
src/main.rs
20
src/main.rs
@@ -61,15 +61,18 @@ use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use pgcat::cmd_args;
|
||||
use pgcat::config::{get_config, reload_config, VERSION};
|
||||
use pgcat::dns_cache;
|
||||
use pgcat::logger;
|
||||
use pgcat::messages::configure_socket;
|
||||
use pgcat::pool::{ClientServerMap, ConnectionPool};
|
||||
use pgcat::prometheus::start_metric_server;
|
||||
use pgcat::stats::{Collector, Reporter, REPORTER};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
pgcat::multi_logger::MultiLogger::init().unwrap();
|
||||
let args = cmd_args::parse();
|
||||
logger::init(&args);
|
||||
|
||||
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
|
||||
|
||||
@@ -78,20 +81,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
std::process::exit(exitcode::CONFIG);
|
||||
}
|
||||
|
||||
let args = std::env::args().collect::<Vec<String>>();
|
||||
|
||||
let config_file = if args.len() == 2 {
|
||||
args[1].to_string()
|
||||
} else {
|
||||
String::from("pgcat.toml")
|
||||
};
|
||||
|
||||
// Create a transient runtime for loading the config for the first time.
|
||||
{
|
||||
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
|
||||
|
||||
runtime.block_on(async {
|
||||
match pgcat::config::parse(&config_file).await {
|
||||
match pgcat::config::parse(args.config_file.as_str()).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Config parse error: {:?}", err);
|
||||
@@ -165,10 +160,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
};
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let mut stats_collector = Collector::default();
|
||||
stats_collector.collect().await;
|
||||
});
|
||||
Collector::collect();
|
||||
|
||||
info!("Config autoreloader: {}", match config.general.autoreload {
|
||||
Some(interval) => format!("{} ms", interval),
|
||||
|
||||
475
src/messages.rs
475
src/messages.rs
@@ -1,17 +1,21 @@
|
||||
/// Helper functions to send one-off protocol messages
|
||||
/// and handle TcpStream (TCP socket).
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::error;
|
||||
use log::{debug, error};
|
||||
use md5::{Digest, Md5};
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::client::PREPARED_STATEMENT_COUNTER;
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::CString;
|
||||
use std::io::{BufRead, Cursor};
|
||||
use std::mem;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Postgres data type mappings
|
||||
@@ -152,12 +156,10 @@ where
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,8 +235,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// First pass
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user.as_bytes());
|
||||
md5.update(password.as_bytes());
|
||||
md5.update(user.as_bytes());
|
||||
|
||||
let output = md5.finalize_reset();
|
||||
|
||||
@@ -270,7 +272,7 @@ where
|
||||
{
|
||||
let password = md5_hash_password(user, password, salt);
|
||||
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -284,7 +286,7 @@ where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
let password = md5_hash_second_pass(hash, salt);
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
let mut message = BytesMut::with_capacity(password.len() + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
@@ -505,7 +507,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(column);
|
||||
} else {
|
||||
data_row.put_i32(-1 as i32);
|
||||
data_row.put_i32(-1_i32);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -526,6 +528,33 @@ pub fn command_complete(command: &str) -> BytesMut {
|
||||
res
|
||||
}
|
||||
|
||||
/// Create a notify message.
|
||||
pub fn notify(message: &str, details: String) -> BytesMut {
|
||||
let mut notify_cmd = BytesMut::new();
|
||||
|
||||
notify_cmd.put_slice("SNOTICE\0".as_bytes());
|
||||
notify_cmd.put_slice("C00000\0".as_bytes());
|
||||
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
|
||||
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
|
||||
|
||||
// this extra byte says that is the end of the package
|
||||
notify_cmd.put_u8(0);
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'N');
|
||||
res.put_i32(notify_cmd.len() as i32 + 4);
|
||||
res.put(notify_cmd);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub fn flush() -> BytesMut {
|
||||
let mut bytes = BytesMut::new();
|
||||
bytes.put_u8(b'H');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
/// Write all data in the buffer to the TcpStream.
|
||||
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
||||
where
|
||||
@@ -533,12 +562,10 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,12 +576,10 @@ where
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -565,19 +590,15 @@ where
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
))),
|
||||
},
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -658,6 +679,13 @@ pub fn configure_socket(stream: &TcpStream) {
|
||||
let sock_ref = SockRef::from(stream);
|
||||
let conf = get_config();
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
|
||||
{
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
|
||||
}
|
||||
|
||||
match sock_ref.set_keepalive(true) {
|
||||
Ok(_) => {
|
||||
match sock_ref.set_tcp_keepalive(
|
||||
@@ -667,7 +695,7 @@ pub fn configure_socket(stream: &TcpStream) {
|
||||
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(err) => error!("Could not configure socket: {}", err),
|
||||
Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
|
||||
}
|
||||
}
|
||||
Err(err) => error!("Could not configure socket: {}", err),
|
||||
@@ -685,7 +713,378 @@ impl BytesMutReader for Cursor<&BytesMut> {
|
||||
let mut buf = vec![];
|
||||
match self.read_until(b'\0', &mut buf) {
|
||||
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
|
||||
Err(err) => return Err(Error::ParseBytesError(err.to_string())),
|
||||
Err(err) => Err(Error::ParseBytesError(err.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Parse {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
pub name: String,
|
||||
pub generated_name: String,
|
||||
query: String,
|
||||
num_params: i16,
|
||||
param_types: Vec<i32>,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Parse {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let name = cursor.read_string()?;
|
||||
let query = cursor.read_string()?;
|
||||
let num_params = cursor.get_i16();
|
||||
let mut param_types = Vec::new();
|
||||
|
||||
for _ in 0..num_params {
|
||||
param_types.push(cursor.get_i32());
|
||||
}
|
||||
|
||||
Ok(Parse {
|
||||
code,
|
||||
len,
|
||||
name,
|
||||
generated_name: prepared_statement_name(),
|
||||
query,
|
||||
num_params,
|
||||
param_types,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Parse> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(parse: Parse) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
let name_binding = CString::new(parse.name)?;
|
||||
let name = name_binding.as_bytes_with_nul();
|
||||
|
||||
let query_binding = CString::new(parse.query)?;
|
||||
let query = query_binding.as_bytes_with_nul();
|
||||
|
||||
// Recompute length of the message.
|
||||
let len = 4 // self
|
||||
+ name.len()
|
||||
+ query.len()
|
||||
+ 2
|
||||
+ 4 * parse.num_params as usize;
|
||||
|
||||
bytes.put_u8(parse.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(name);
|
||||
bytes.put_slice(query);
|
||||
bytes.put_i16(parse.num_params);
|
||||
for param in parse.param_types {
|
||||
bytes.put_i32(param);
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&Parse> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
|
||||
parse.clone().try_into()
|
||||
}
|
||||
}
|
||||
|
||||
impl Parse {
|
||||
pub fn rename(mut self) -> Self {
|
||||
self.name = self.generated_name.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bind (B) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Bind {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i64,
|
||||
portal: String,
|
||||
pub prepared_statement: String,
|
||||
num_param_format_codes: i16,
|
||||
param_format_codes: Vec<i16>,
|
||||
num_param_values: i16,
|
||||
param_values: Vec<(i32, BytesMut)>,
|
||||
num_result_column_format_codes: i16,
|
||||
result_columns_format_codes: Vec<i16>,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Bind {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
let num_param_format_codes = cursor.get_i16();
|
||||
let mut param_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_param_format_codes {
|
||||
param_format_codes.push(cursor.get_i16());
|
||||
}
|
||||
|
||||
let num_param_values = cursor.get_i16();
|
||||
let mut param_values = Vec::new();
|
||||
|
||||
for _ in 0..num_param_values {
|
||||
let param_len = cursor.get_i32();
|
||||
// There is special occasion when the parameter is NULL
|
||||
// In that case, param length is defined as -1
|
||||
// So if the passed parameter len is over 0
|
||||
if param_len > 0 {
|
||||
let mut param = BytesMut::with_capacity(param_len as usize);
|
||||
param.resize(param_len as usize, b'0');
|
||||
cursor.copy_to_slice(&mut param);
|
||||
// we push and the length and the parameter into vector
|
||||
param_values.push((param_len, param));
|
||||
} else {
|
||||
// otherwise we push a tuple with -1 and 0-len BytesMut
|
||||
// which means that after encountering -1 postgres proceeds
|
||||
// to processing another parameter
|
||||
param_values.push((param_len, BytesMut::new()));
|
||||
}
|
||||
}
|
||||
|
||||
let num_result_column_format_codes = cursor.get_i16();
|
||||
let mut result_columns_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_result_column_format_codes {
|
||||
result_columns_format_codes.push(cursor.get_i16());
|
||||
}
|
||||
|
||||
Ok(Bind {
|
||||
code,
|
||||
len: len as i64,
|
||||
portal,
|
||||
prepared_statement,
|
||||
num_param_format_codes,
|
||||
param_format_codes,
|
||||
num_param_values,
|
||||
param_values,
|
||||
num_result_column_format_codes,
|
||||
result_columns_format_codes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Bind> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bind: Bind) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
|
||||
let portal_binding = CString::new(bind.portal)?;
|
||||
let portal = portal_binding.as_bytes_with_nul();
|
||||
|
||||
let prepared_statement_binding = CString::new(bind.prepared_statement)?;
|
||||
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();
|
||||
|
||||
let mut len = 4 // self
|
||||
+ portal.len()
|
||||
+ prepared_statement.len()
|
||||
+ 2 // num_param_format_codes
|
||||
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
|
||||
+ 2; // num_param_values
|
||||
|
||||
for (param_len, _) in &bind.param_values {
|
||||
len += 4 + *param_len as usize;
|
||||
}
|
||||
len += 2; // num_result_column_format_codes
|
||||
len += 2 * bind.num_result_column_format_codes as usize;
|
||||
|
||||
bytes.put_u8(bind.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(portal);
|
||||
bytes.put_slice(prepared_statement);
|
||||
bytes.put_i16(bind.num_param_format_codes);
|
||||
for param_format_code in bind.param_format_codes {
|
||||
bytes.put_i16(param_format_code);
|
||||
}
|
||||
bytes.put_i16(bind.num_param_values);
|
||||
for (param_len, param) in bind.param_values {
|
||||
bytes.put_i32(param_len);
|
||||
bytes.put_slice(¶m);
|
||||
}
|
||||
bytes.put_i16(bind.num_result_column_format_codes);
|
||||
for result_column_format_code in bind.result_columns_format_codes {
|
||||
bytes.put_i16(result_column_format_code);
|
||||
}
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bind {
|
||||
pub fn reassign(mut self, parse: &Parse) -> Self {
|
||||
self.prepared_statement = parse.name.clone();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.prepared_statement.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Describe {
|
||||
code: char,
|
||||
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
target: char,
|
||||
pub statement_name: String,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Describe {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let target = cursor.get_u8() as char;
|
||||
let statement_name = cursor.read_string()?;
|
||||
|
||||
Ok(Describe {
|
||||
code,
|
||||
len,
|
||||
target,
|
||||
statement_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Describe> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
|
||||
let mut bytes = BytesMut::new();
|
||||
let statement_name_binding = CString::new(describe.statement_name)?;
|
||||
let statement_name = statement_name_binding.as_bytes_with_nul();
|
||||
let len = 4 + 1 + statement_name.len();
|
||||
|
||||
bytes.put_u8(describe.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_u8(describe.target as u8);
|
||||
bytes.put_slice(statement_name);
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Describe {
|
||||
pub fn rename(mut self, name: &str) -> Self {
|
||||
self.statement_name = name.to_string();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.statement_name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Close (F) message.
|
||||
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Close {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
close_type: char,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl TryFrom<&BytesMut> for Close {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_i32();
|
||||
let close_type = cursor.get_u8() as char;
|
||||
let name = cursor.read_string()?;
|
||||
|
||||
Ok(Close {
|
||||
code,
|
||||
len,
|
||||
close_type,
|
||||
name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<Close> for BytesMut {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(close: Close) -> Result<BytesMut, Error> {
|
||||
debug!("Close: {:?}", close);
|
||||
|
||||
let mut bytes = BytesMut::new();
|
||||
let name_binding = CString::new(close.name)?;
|
||||
let name = name_binding.as_bytes_with_nul();
|
||||
let len = 4 + 1 + name.len();
|
||||
|
||||
bytes.put_u8(close.code as u8);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_u8(close.close_type as u8);
|
||||
bytes.put_slice(name);
|
||||
|
||||
Ok(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Close {
|
||||
pub fn new(name: &str) -> Close {
|
||||
let name = name.to_string();
|
||||
|
||||
Close {
|
||||
code: 'C',
|
||||
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
|
||||
close_type: 'S',
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_prepared_statement(&self) -> bool {
|
||||
self.close_type == 'S'
|
||||
}
|
||||
|
||||
pub fn anonymous(&self) -> bool {
|
||||
self.name.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn close_complete() -> BytesMut {
|
||||
let mut bytes = BytesMut::new();
|
||||
bytes.put_u8(b'3');
|
||||
bytes.put_i32(4);
|
||||
bytes
|
||||
}
|
||||
|
||||
pub fn prepared_statement_name() -> String {
|
||||
format!(
|
||||
"P_{}",
|
||||
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use crate::config::{get_config, Address, Role, User};
|
||||
use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool};
|
||||
use crate::stats::PoolStats;
|
||||
use crate::pool::{ClientServerMap, ServerPool};
|
||||
use log::{error, info, trace, warn};
|
||||
use tokio::sync::mpsc::{channel, Receiver, Sender};
|
||||
|
||||
@@ -24,7 +23,7 @@ impl MirroredClient {
|
||||
async fn create_pool(&self) -> Pool<ServerPool> {
|
||||
let config = get_config();
|
||||
let default = std::time::Duration::from_millis(10_000).as_millis() as u64;
|
||||
let (connection_timeout, idle_timeout, cfg) =
|
||||
let (connection_timeout, idle_timeout, _cfg) =
|
||||
match config.pools.get(&self.address.pool_name) {
|
||||
Some(cfg) => (
|
||||
cfg.connect_timeout.unwrap_or(default),
|
||||
@@ -34,14 +33,11 @@ impl MirroredClient {
|
||||
None => (default, default, crate::config::Pool::default()),
|
||||
};
|
||||
|
||||
let identifier = PoolIdentifier::new(&self.database, &self.user.username);
|
||||
|
||||
let manager = ServerPool::new(
|
||||
self.address.clone(),
|
||||
self.user.clone(),
|
||||
self.database.as_str(),
|
||||
ClientServerMap::default(),
|
||||
Arc::new(PoolStats::new(identifier, cfg.clone())),
|
||||
Arc::new(RwLock::new(None)),
|
||||
None,
|
||||
true,
|
||||
@@ -146,12 +142,12 @@ impl MirroringManager {
|
||||
});
|
||||
|
||||
Self {
|
||||
byte_senders: byte_senders,
|
||||
byte_senders,
|
||||
disconnect_senders: exit_senders,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(self: &mut Self, bytes: &BytesMut) {
|
||||
pub fn send(&mut self, bytes: &BytesMut) {
|
||||
// We want to avoid performing an allocation if we won't be able to send the message
|
||||
// There is a possibility of a race here where we check the capacity and then the channel is
|
||||
// closed or the capacity is reduced to 0, but mirroring is best effort anyway
|
||||
@@ -173,7 +169,7 @@ impl MirroringManager {
|
||||
});
|
||||
}
|
||||
|
||||
pub fn disconnect(self: &mut Self) {
|
||||
pub fn disconnect(&mut self) {
|
||||
self.disconnect_senders
|
||||
.iter_mut()
|
||||
.for_each(|sender| match sender.try_send(()) {
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
use log::{Level, Log, Metadata, Record, SetLoggerError};
|
||||
|
||||
// This is a special kind of logger that allows sending logs to different
|
||||
// targets depending on the log level.
|
||||
//
|
||||
// By default, if nothing is set, it acts as a regular env_log logger,
|
||||
// it sends everything to standard error.
|
||||
//
|
||||
// If the Env variable `STDOUT_LOG` is defined, it will be used for
|
||||
// configuring the standard out logger.
|
||||
//
|
||||
// The behavior is:
|
||||
// - If it is an error, the message is written to standard error.
|
||||
// - If it is not, and it matches the log level of the standard output logger (`STDOUT_LOG` env var), it will be send to standard output.
|
||||
// - If the above is not true, it is sent to the stderr logger that will log it or not depending on the value
|
||||
// of the RUST_LOG env var.
|
||||
//
|
||||
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
|
||||
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
|
||||
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
|
||||
//
|
||||
pub struct MultiLogger {
|
||||
stderr_logger: env_logger::Logger,
|
||||
stdout_logger: env_logger::Logger,
|
||||
}
|
||||
|
||||
impl MultiLogger {
|
||||
fn new() -> Self {
|
||||
let stderr_logger = env_logger::builder().format_timestamp_micros().build();
|
||||
let stdout_logger = env_logger::Builder::from_env("STDOUT_LOG")
|
||||
.format_timestamp_micros()
|
||||
.target(env_logger::Target::Stdout)
|
||||
.build();
|
||||
|
||||
Self {
|
||||
stderr_logger,
|
||||
stdout_logger,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init() -> Result<(), SetLoggerError> {
|
||||
let logger = Self::new();
|
||||
|
||||
log::set_max_level(logger.stderr_logger.filter());
|
||||
log::set_boxed_logger(Box::new(logger))
|
||||
}
|
||||
}
|
||||
|
||||
impl Log for MultiLogger {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
self.stderr_logger.enabled(metadata) && self.stdout_logger.enabled(metadata)
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if record.level() == Level::Error {
|
||||
self.stderr_logger.log(record);
|
||||
} else {
|
||||
if self.stdout_logger.matches(record) {
|
||||
self.stdout_logger.log(record);
|
||||
} else {
|
||||
self.stderr_logger.log(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
self.stderr_logger.flush();
|
||||
self.stdout_logger.flush();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_init() {
|
||||
MultiLogger::init().unwrap();
|
||||
}
|
||||
}
|
||||
@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
|
||||
.map(|s| {
|
||||
let s = s.as_str().to_string();
|
||||
|
||||
if s == "" {
|
||||
if s.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(s)
|
||||
|
||||
@@ -30,6 +30,7 @@ pub enum PluginOutput {
|
||||
Intercept(BytesMut),
|
||||
}
|
||||
|
||||
#[allow(clippy::ptr_arg)]
|
||||
#[async_trait]
|
||||
pub trait Plugin {
|
||||
// Run before the query is sent to the server.
|
||||
|
||||
@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
|
||||
self.server.address(),
|
||||
query
|
||||
);
|
||||
self.server.query(&query).await?;
|
||||
self.server.query(query).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -31,7 +31,7 @@ impl<'a> Plugin for QueryLogger<'a> {
|
||||
.map(|q| q.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join("; ");
|
||||
info!("[pool: {}][user: {}] {}", self.user, self.db, query);
|
||||
info!("[pool: {}][user: {}] {}", self.db, self.user, query);
|
||||
|
||||
Ok(PluginOutput::Allow)
|
||||
}
|
||||
|
||||
@@ -30,27 +30,22 @@ impl<'a> Plugin for TableAccess<'a> {
|
||||
return Ok(PluginOutput::Allow);
|
||||
}
|
||||
|
||||
let mut found = None;
|
||||
|
||||
visit_relations(ast, |relation| {
|
||||
let control_flow = visit_relations(ast, |relation| {
|
||||
let relation = relation.to_string();
|
||||
let parts = relation.split(".").collect::<Vec<&str>>();
|
||||
let table_name = parts.last().unwrap();
|
||||
let table_name = relation.split('.').last().unwrap().to_string();
|
||||
|
||||
if self.tables.contains(&table_name.to_string()) {
|
||||
found = Some(table_name.to_string());
|
||||
ControlFlow::<()>::Break(())
|
||||
if self.tables.contains(&table_name) {
|
||||
ControlFlow::Break(table_name)
|
||||
} else {
|
||||
ControlFlow::<()>::Continue(())
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(found) = found {
|
||||
debug!("Blocking access to table \"{}\"", found);
|
||||
if let ControlFlow::Break(found) = control_flow {
|
||||
debug!("Blocking access to table \"{found}\"");
|
||||
|
||||
Ok(PluginOutput::Deny(format!(
|
||||
"permission for table \"{}\" denied",
|
||||
found
|
||||
"permission for table \"{found}\" denied",
|
||||
)))
|
||||
} else {
|
||||
Ok(PluginOutput::Allow)
|
||||
|
||||
85
src/pool.rs
85
src/pool.rs
@@ -1,6 +1,6 @@
|
||||
use arc_swap::ArcSwap;
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
use log::{debug, error, info, warn};
|
||||
@@ -10,6 +10,7 @@ use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
@@ -26,7 +27,7 @@ use crate::auth_passthrough::AuthPassthrough;
|
||||
use crate::plugins::prewarmer;
|
||||
use crate::server::Server;
|
||||
use crate::sharding::ShardingFunction;
|
||||
use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats};
|
||||
use crate::stats::{AddressStats, ClientStats, ServerStats};
|
||||
|
||||
pub type ProcessId = i32;
|
||||
pub type SecretKey = i32;
|
||||
@@ -76,6 +77,12 @@ impl PoolIdentifier {
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PoolIdentifier {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}@{}", self.user, self.db)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Address> for PoolIdentifier {
|
||||
fn from(address: &Address) -> PoolIdentifier {
|
||||
PoolIdentifier::new(&address.database, &address.username)
|
||||
@@ -202,9 +209,6 @@ pub struct ConnectionPool {
|
||||
paused: Arc<AtomicBool>,
|
||||
paused_waiter: Arc<Notify>,
|
||||
|
||||
/// Statistics.
|
||||
pub stats: Arc<PoolStats>,
|
||||
|
||||
/// AuthInfo
|
||||
pub auth_hash: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
@@ -225,20 +229,17 @@ impl ConnectionPool {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
if let Some(pool) = old_pool_ref {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -254,10 +255,6 @@ impl ConnectionPool {
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
@@ -358,7 +355,6 @@ impl ConnectionPool {
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
match pool_config.plugins {
|
||||
Some(ref plugins) => Some(plugins.clone()),
|
||||
@@ -390,6 +386,11 @@ impl ConnectionPool {
|
||||
.min()
|
||||
.unwrap();
|
||||
|
||||
let queue_strategy = match config.general.server_round_robin {
|
||||
true => QueueStrategy::Fifo,
|
||||
false => QueueStrategy::Lifo,
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[pool: {}][user: {}] Pool reaper rate: {}ms",
|
||||
pool_name, user.username, reaper_rate
|
||||
@@ -402,6 +403,7 @@ impl ConnectionPool {
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.reaper_rate(std::time::Duration::from_millis(reaper_rate))
|
||||
.queue_strategy(queue_strategy)
|
||||
.test_on_check_out(false);
|
||||
|
||||
let pool = if config.general.validate_config {
|
||||
@@ -429,7 +431,6 @@ impl ConnectionPool {
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
@@ -610,6 +611,10 @@ impl ConnectionPool {
|
||||
});
|
||||
}
|
||||
|
||||
// Indicate we're waiting on a server connection from a pool.
|
||||
let now = Instant::now();
|
||||
client_stats.waiting();
|
||||
|
||||
while !candidates.is_empty() {
|
||||
// Get the next candidate
|
||||
let address = match candidates.pop() {
|
||||
@@ -620,7 +625,7 @@ impl ConnectionPool {
|
||||
let mut force_healthcheck = false;
|
||||
|
||||
if self.is_banned(address) {
|
||||
if self.try_unban(&address).await {
|
||||
if self.try_unban(address).await {
|
||||
force_healthcheck = true;
|
||||
} else {
|
||||
debug!("Address {:?} is banned", address);
|
||||
@@ -628,10 +633,6 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
// Indicate we're waiting on a server connection from a pool.
|
||||
let now = Instant::now();
|
||||
client_stats.waiting();
|
||||
|
||||
// Check if we can connect
|
||||
let mut conn = match self.databases[address.shard][address.address_index]
|
||||
.get()
|
||||
@@ -669,7 +670,7 @@ impl ConnectionPool {
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_stats.application_name());
|
||||
server.stats().active(client_stats.application_name());
|
||||
|
||||
client_stats.active();
|
||||
return Ok((conn, address.clone()));
|
||||
}
|
||||
|
||||
@@ -677,11 +678,19 @@ impl ConnectionPool {
|
||||
.run_health_check(address, server, now, client_stats)
|
||||
.await
|
||||
{
|
||||
let checkout_time: u64 = now.elapsed().as_micros() as u64;
|
||||
client_stats.checkout_time(checkout_time);
|
||||
server
|
||||
.stats()
|
||||
.checkout_time(checkout_time, client_stats.application_name());
|
||||
server.stats().active(client_stats.application_name());
|
||||
client_stats.active();
|
||||
return Ok((conn, address.clone()));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
client_stats.idle();
|
||||
Err(Error::AllServersDown)
|
||||
}
|
||||
|
||||
@@ -736,8 +745,8 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
return false;
|
||||
self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
|
||||
false
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
@@ -849,10 +858,10 @@ impl ConnectionPool {
|
||||
let guard = self.banlist.read();
|
||||
for banlist in guard.iter() {
|
||||
for (address, (reason, timestamp)) in banlist.iter() {
|
||||
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
|
||||
bans.push((address.clone(), (reason.clone(), *timestamp)));
|
||||
}
|
||||
}
|
||||
return bans;
|
||||
bans
|
||||
}
|
||||
|
||||
/// Get the address from the host url
|
||||
@@ -909,7 +918,7 @@ impl ConnectionPool {
|
||||
}
|
||||
let busy = provisioned - idle;
|
||||
debug!("{:?} has {:?} busy connections", address, busy);
|
||||
return busy;
|
||||
busy
|
||||
}
|
||||
}
|
||||
|
||||
@@ -927,9 +936,6 @@ pub struct ServerPool {
|
||||
/// Client/server mapping.
|
||||
client_server_map: ClientServerMap,
|
||||
|
||||
/// Server statistics.
|
||||
stats: Arc<PoolStats>,
|
||||
|
||||
/// Server auth hash (for auth passthrough).
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
|
||||
@@ -946,7 +952,6 @@ impl ServerPool {
|
||||
user: User,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
stats: Arc<PoolStats>,
|
||||
auth_hash: Arc<RwLock<Option<String>>>,
|
||||
plugins: Option<Plugins>,
|
||||
cleanup_connections: bool,
|
||||
@@ -956,7 +961,6 @@ impl ServerPool {
|
||||
user: user.clone(),
|
||||
database: database.to_string(),
|
||||
client_server_map,
|
||||
stats,
|
||||
auth_hash,
|
||||
plugins,
|
||||
cleanup_connections,
|
||||
@@ -975,7 +979,6 @@ impl ManageConnection for ServerPool {
|
||||
|
||||
let stats = Arc::new(ServerStats::new(
|
||||
self.address.clone(),
|
||||
self.stats.clone(),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use log::{error, info, warn};
|
||||
use log::{debug, error, info};
|
||||
use phf::phf_map;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
@@ -9,8 +9,9 @@ use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::Address;
|
||||
use crate::pool::get_all_pools;
|
||||
use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
|
||||
use crate::pool::{get_all_pools, PoolIdentifier};
|
||||
use crate::stats::pool::PoolStats;
|
||||
use crate::stats::{get_server_stats, ServerStats};
|
||||
|
||||
struct MetricHelpType {
|
||||
help: &'static str,
|
||||
@@ -233,10 +234,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
Self::from_name(&format!("stats_{}", name), value, labels)
|
||||
}
|
||||
|
||||
fn from_pool(pool: &(String, String), name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
fn from_pool(pool_id: PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("pool", pool.0.clone());
|
||||
labels.insert("user", pool.1.clone());
|
||||
labels.insert("pool", pool_id.db);
|
||||
labels.insert("user", pool_id.user);
|
||||
|
||||
Self::from_name(&format!("pools_{}", name), value, labels)
|
||||
}
|
||||
@@ -274,7 +275,7 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -284,18 +285,15 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
|
||||
// Adds relevant metrics shown in a SHOW POOLS admin command.
|
||||
fn push_pool_stats(lines: &mut Vec<String>) {
|
||||
let pool_stats = get_pool_stats();
|
||||
for (pool, stats) in pool_stats.iter() {
|
||||
let stats = &**stats;
|
||||
let pool_stats = PoolStats::construct_pool_lookup();
|
||||
for (pool_id, stats) in pool_stats.iter() {
|
||||
for (name, value) in stats.clone() {
|
||||
if let Some(prometheus_metric) = PrometheusMetric::<u64>::from_pool(pool, &name, value)
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!(
|
||||
"Metric {} not implemented for ({},{})",
|
||||
name, pool.0, pool.1
|
||||
);
|
||||
debug!("Metric {} not implemented for ({})", name, *pool_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -320,7 +318,7 @@ fn push_database_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,7 +364,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
} else {
|
||||
warn!("Metric {} not implemented for {}", key, address.name());
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
|
||||
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
|
||||
|
||||
/// The query router.
|
||||
#[derive(Default)]
|
||||
pub struct QueryRouter {
|
||||
/// Which shard we should be talking to right now.
|
||||
active_shard: Option<usize>,
|
||||
@@ -91,7 +92,7 @@ impl QueryRouter {
|
||||
/// One-time initialization of regexes
|
||||
/// that parse our custom SQL protocol.
|
||||
pub fn setup() -> bool {
|
||||
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
|
||||
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
|
||||
Ok(rgx) => rgx,
|
||||
Err(err) => {
|
||||
error!("QueryRouter::setup Could not compile regex set: {:?}", err);
|
||||
@@ -116,15 +117,8 @@ impl QueryRouter {
|
||||
|
||||
/// Create a new instance of the query router.
|
||||
/// Each client gets its own.
|
||||
pub fn new() -> QueryRouter {
|
||||
QueryRouter {
|
||||
active_shard: None,
|
||||
active_role: None,
|
||||
query_parser_enabled: None,
|
||||
primary_reads_enabled: None,
|
||||
pool_settings: PoolSettings::default(),
|
||||
placeholders: Vec::new(),
|
||||
}
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Pool settings can change because of a config reload.
|
||||
@@ -132,7 +126,7 @@ impl QueryRouter {
|
||||
self.pool_settings = pool_settings;
|
||||
}
|
||||
|
||||
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
|
||||
pub fn pool_settings(&self) -> &PoolSettings {
|
||||
&self.pool_settings
|
||||
}
|
||||
|
||||
@@ -143,7 +137,7 @@ impl QueryRouter {
|
||||
let code = message_cursor.get_u8() as char;
|
||||
|
||||
// Check for any sharding regex matches in any queries
|
||||
match code as char {
|
||||
match code {
|
||||
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
|
||||
'P' | 'Q' => {
|
||||
if self.pool_settings.shard_id_regex.is_some()
|
||||
@@ -331,7 +325,7 @@ impl QueryRouter {
|
||||
Some((command, value))
|
||||
}
|
||||
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> {
|
||||
pub fn parse(message: &BytesMut) -> Result<Vec<Statement>, Error> {
|
||||
let mut message_cursor = Cursor::new(message);
|
||||
|
||||
let code = message_cursor.get_u8() as char;
|
||||
@@ -348,12 +342,13 @@ impl QueryRouter {
|
||||
// Parse (prepared statement)
|
||||
'P' => {
|
||||
// Reads statement name
|
||||
message_cursor.read_string().unwrap();
|
||||
let _name = message_cursor.read_string().unwrap();
|
||||
|
||||
// Reads query string
|
||||
let query = message_cursor.read_string().unwrap();
|
||||
|
||||
debug!("Prepared statement: '{}'", query);
|
||||
|
||||
query
|
||||
}
|
||||
|
||||
@@ -396,14 +391,10 @@ impl QueryRouter {
|
||||
// or discard shard selection. If they point to the same shard though,
|
||||
// we can let them through as-is.
|
||||
// This is basically building a database now :)
|
||||
match self.infer_shard(query) {
|
||||
Some(shard) => {
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
if let Some(shard) = self.infer_shard(query) {
|
||||
self.active_shard = Some(shard);
|
||||
debug!("Automatically using shard: {:?}", self.active_shard);
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
@@ -575,8 +566,8 @@ impl QueryRouter {
|
||||
.automatic_sharding_key
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.split(".")
|
||||
.map(|ident| Ident::new(ident))
|
||||
.split('.')
|
||||
.map(Ident::new)
|
||||
.collect::<Vec<Ident>>();
|
||||
|
||||
// Sharding key must be always fully qualified
|
||||
@@ -592,7 +583,7 @@ impl QueryRouter {
|
||||
Expr::Identifier(ident) => {
|
||||
// Only if we're dealing with only one table
|
||||
// and there is no ambiguity
|
||||
if &ident.value == &sharding_key[1].value {
|
||||
if ident.value == sharding_key[1].value {
|
||||
// Sharding key is unique enough, don't worry about
|
||||
// table names.
|
||||
if &sharding_key[0].value == "*" {
|
||||
@@ -605,13 +596,13 @@ impl QueryRouter {
|
||||
// SELECT * FROM t WHERE sharding_key = 5
|
||||
// Make sure the table name from the sharding key matches
|
||||
// the table name from the query.
|
||||
found = &sharding_key[0].value == &table[0].value;
|
||||
found = sharding_key[0].value == table[0].value;
|
||||
} else if table.len() == 2 {
|
||||
// Table name is fully qualified with the schema: e.g.
|
||||
// SELECT * FROM public.t WHERE sharding_key = 5
|
||||
// Ignore the schema (TODO: at some point, we want schema support)
|
||||
// and use the table name only.
|
||||
found = &sharding_key[0].value == &table[1].value;
|
||||
found = sharding_key[0].value == table[1].value;
|
||||
} else {
|
||||
debug!("Got table name with more than two idents, which is not possible");
|
||||
}
|
||||
@@ -623,8 +614,8 @@ impl QueryRouter {
|
||||
// The key is fully qualified in the query,
|
||||
// it will exist or Postgres will throw an error.
|
||||
if idents.len() == 2 {
|
||||
found = &sharding_key[0].value == &idents[0].value
|
||||
&& &sharding_key[1].value == &idents[1].value;
|
||||
found = sharding_key[0].value == idents[0].value
|
||||
&& sharding_key[1].value == idents[1].value;
|
||||
}
|
||||
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
|
||||
}
|
||||
@@ -656,7 +647,7 @@ impl QueryRouter {
|
||||
}
|
||||
|
||||
Expr::Value(Value::Placeholder(placeholder)) => {
|
||||
match placeholder.replace("$", "").parse::<i16>() {
|
||||
match placeholder.replace('$', "").parse::<i16>() {
|
||||
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
|
||||
Err(_) => {
|
||||
debug!(
|
||||
@@ -682,12 +673,9 @@ impl QueryRouter {
|
||||
|
||||
match &*query.body {
|
||||
SetExpr::Query(query) => {
|
||||
match self.infer_shard(&*query) {
|
||||
Some(shard) => {
|
||||
shards.insert(shard);
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
if let Some(shard) = self.infer_shard(query) {
|
||||
shards.insert(shard);
|
||||
}
|
||||
}
|
||||
|
||||
// SELECT * FROM ...
|
||||
@@ -697,38 +685,22 @@ impl QueryRouter {
|
||||
let mut table_names = Vec::new();
|
||||
|
||||
for table in select.from.iter() {
|
||||
match &table.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
if let TableFactor::Table { name, .. } = &table.relation {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
// Get table names from all the joins.
|
||||
for join in table.joins.iter() {
|
||||
match &join.relation {
|
||||
TableFactor::Table { name, .. } => {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
if let TableFactor::Table { name, .. } = &join.relation {
|
||||
table_names.push(name.0.clone());
|
||||
}
|
||||
|
||||
// We can filter results based on join conditions, e.g.
|
||||
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
|
||||
match &join.join_operator {
|
||||
JoinOperator::Inner(inner_join) => match &inner_join {
|
||||
JoinConstraint::On(expr) => {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
|
||||
_ => (),
|
||||
},
|
||||
|
||||
_ => (),
|
||||
};
|
||||
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
|
||||
// Parse the selection criteria later.
|
||||
exprs.push(expr.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -802,16 +774,16 @@ impl QueryRouter {
|
||||
db: &self.pool_settings.db,
|
||||
};
|
||||
|
||||
let _ = query_logger.run(&self, ast).await;
|
||||
let _ = query_logger.run(self, ast).await;
|
||||
}
|
||||
|
||||
if let Some(ref intercept) = plugins.intercept {
|
||||
let mut intercept = Intercept {
|
||||
enabled: intercept.enabled,
|
||||
config: &intercept,
|
||||
config: intercept,
|
||||
};
|
||||
|
||||
let result = intercept.run(&self, ast).await;
|
||||
let result = intercept.run(self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Intercept(output)) = result {
|
||||
return Ok(PluginOutput::Intercept(output));
|
||||
@@ -824,7 +796,7 @@ impl QueryRouter {
|
||||
tables: &table_access.tables,
|
||||
};
|
||||
|
||||
let result = table_access.run(&self, ast).await;
|
||||
let result = table_access.run(self, ast).await;
|
||||
|
||||
if let Ok(PluginOutput::Deny(error)) = result {
|
||||
return Ok(PluginOutput::Deny(error));
|
||||
@@ -860,7 +832,7 @@ impl QueryRouter {
|
||||
|
||||
/// Should we attempt to parse queries?
|
||||
pub fn query_parser_enabled(&self) -> bool {
|
||||
let enabled = match self.query_parser_enabled {
|
||||
match self.query_parser_enabled {
|
||||
None => {
|
||||
debug!(
|
||||
"Using pool settings, query_parser_enabled: {}",
|
||||
@@ -876,9 +848,7 @@ impl QueryRouter {
|
||||
);
|
||||
value
|
||||
}
|
||||
};
|
||||
|
||||
enabled
|
||||
}
|
||||
}
|
||||
|
||||
pub fn primary_reads_enabled(&self) -> bool {
|
||||
@@ -909,10 +879,14 @@ mod test {
|
||||
fn test_infer_replica() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
|
||||
.is_some());
|
||||
assert!(qr.query_parser_enabled());
|
||||
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
let queries = vec![
|
||||
simple_query("SELECT * FROM items WHERE id = 5"),
|
||||
@@ -953,7 +927,9 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
let query = simple_query("SELECT * FROM items WHERE id = 5");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
|
||||
.is_some());
|
||||
|
||||
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
|
||||
assert_eq!(qr.role(), None);
|
||||
@@ -964,7 +940,9 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
let prepared_stmt = BytesMut::from(
|
||||
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
|
||||
@@ -1132,9 +1110,11 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
let query = simple_query("SET SERVER ROLE TO 'auto'");
|
||||
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
|
||||
assert!(qr
|
||||
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
|
||||
.is_some());
|
||||
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
assert!(qr.query_parser_enabled());
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
@@ -1148,7 +1128,7 @@ mod test {
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
let query = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&query) != None);
|
||||
assert!(qr.try_execute_command(&query).is_some());
|
||||
assert!(!qr.query_parser_enabled());
|
||||
}
|
||||
|
||||
@@ -1193,11 +1173,11 @@ mod test {
|
||||
assert!(!qr.primary_reads_enabled());
|
||||
|
||||
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
|
||||
assert!(qr.try_execute_command(&q1) != None);
|
||||
assert!(qr.try_execute_command(&q1).is_some());
|
||||
assert_eq!(qr.active_role.unwrap(), Role::Primary);
|
||||
|
||||
let q2 = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(&q2) != None);
|
||||
assert!(qr.try_execute_command(&q2).is_some());
|
||||
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
|
||||
}
|
||||
|
||||
@@ -1262,17 +1242,17 @@ mod test {
|
||||
|
||||
// Make sure setting it works
|
||||
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q1) == None);
|
||||
assert!(qr.try_execute_command(&q1).is_none());
|
||||
assert_eq!(qr.active_shard, Some(1));
|
||||
|
||||
// And make sure changing it works
|
||||
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert_eq!(qr.active_shard, Some(0));
|
||||
|
||||
// Validate setting by shard with expected shard copied from sharding.rs tests
|
||||
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
|
||||
assert!(qr.try_execute_command(&q2) == None);
|
||||
assert!(qr.try_execute_command(&q2).is_none());
|
||||
assert_eq!(qr.active_shard, Some(2));
|
||||
}
|
||||
|
||||
@@ -1410,9 +1390,11 @@ mod test {
|
||||
};
|
||||
|
||||
QueryRouter::setup();
|
||||
let mut pool_settings = PoolSettings::default();
|
||||
pool_settings.query_parser_enabled = true;
|
||||
pool_settings.plugins = Some(plugins);
|
||||
let pool_settings = PoolSettings {
|
||||
query_parser_enabled: true,
|
||||
plugins: Some(plugins),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.update_pool_settings(pool_settings);
|
||||
|
||||
14
src/scram.rs
14
src/scram.rs
@@ -79,12 +79,12 @@ impl ScramSha256 {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
let salted_password = Self::hi(
|
||||
@@ -166,9 +166,9 @@ impl ScramSha256 {
|
||||
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
|
||||
let final_message = FinalMessage::parse(message)?;
|
||||
|
||||
let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
|
||||
let verifier = match general_purpose::STANDARD.decode(final_message.value) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
@@ -230,14 +230,14 @@ impl Message {
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
|
||||
};
|
||||
|
||||
Ok(Message {
|
||||
@@ -257,7 +257,7 @@ impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
return Err(Error::ProtocolSyncError("SCRAM".to_string()));
|
||||
}
|
||||
|
||||
Ok(FinalMessage {
|
||||
|
||||
180
src/server.rs
180
src/server.rs
@@ -5,7 +5,7 @@ use fallible_iterator::FallibleIterator;
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use postgres_protocol::message;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::io::Read;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -15,7 +15,7 @@ use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
@@ -170,6 +170,9 @@ pub struct Server {
|
||||
/// Is there more data for the client to read.
|
||||
data_available: bool,
|
||||
|
||||
/// Is the server in copy-in or copy-out modes
|
||||
in_copy_mode: bool,
|
||||
|
||||
/// Is the server broken? We'll remote it from the pool if so.
|
||||
bad: bool,
|
||||
|
||||
@@ -198,6 +201,9 @@ pub struct Server {
|
||||
|
||||
/// Should clean up dirty connections?
|
||||
cleanup_connections: bool,
|
||||
|
||||
/// Prepared statements
|
||||
prepared_statements: BTreeSet<String>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@@ -310,10 +316,7 @@ impl Server {
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
return Err(Error::SocketError(format!("Unknown message: {}", { m })));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -331,27 +334,18 @@ impl Server {
|
||||
None => &user.username,
|
||||
};
|
||||
|
||||
let password = match user.server_password {
|
||||
Some(ref server_password) => Some(server_password),
|
||||
None => match user.password {
|
||||
Some(ref password) => Some(password),
|
||||
None => None,
|
||||
},
|
||||
};
|
||||
let password = user.server_password.as_ref();
|
||||
|
||||
startup(&mut stream, username, database).await?;
|
||||
|
||||
let mut server_info = BytesMut::new();
|
||||
let mut process_id: i32 = 0;
|
||||
let mut secret_key: i32 = 0;
|
||||
let server_identifier = ServerIdentifier::new(username, &database);
|
||||
let server_identifier = ServerIdentifier::new(username, database);
|
||||
|
||||
// We'll be handling multiple packets, but they will all be structured the same.
|
||||
// We'll loop here until this exchange is complete.
|
||||
let mut scram: Option<ScramSha256> = match password {
|
||||
Some(password) => Some(ScramSha256::new(password)),
|
||||
None => None,
|
||||
};
|
||||
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
|
||||
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
@@ -674,6 +668,7 @@ impl Server {
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
in_copy_mode: false,
|
||||
data_available: false,
|
||||
bad: false,
|
||||
cleanup_state: CleanupState::new(),
|
||||
@@ -692,6 +687,7 @@ impl Server {
|
||||
)),
|
||||
},
|
||||
cleanup_connections,
|
||||
prepared_statements: BTreeSet::new(),
|
||||
};
|
||||
|
||||
server.set_name("pgcat").await?;
|
||||
@@ -745,7 +741,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
match write_all_flush(&mut self.stream, messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -824,8 +820,19 @@ impl Server {
|
||||
break;
|
||||
}
|
||||
|
||||
// ErrorResponse
|
||||
'E' => {
|
||||
if self.in_copy_mode {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
}
|
||||
|
||||
// CommandComplete
|
||||
'C' => {
|
||||
if self.in_copy_mode {
|
||||
self.in_copy_mode = false;
|
||||
}
|
||||
|
||||
let mut command_tag = String::new();
|
||||
match message.reader().read_to_string(&mut command_tag) {
|
||||
Ok(_) => {
|
||||
@@ -869,10 +876,14 @@ impl Server {
|
||||
}
|
||||
|
||||
// CopyInResponse: copy is starting from client to server.
|
||||
'G' => break,
|
||||
'G' => {
|
||||
self.in_copy_mode = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// CopyOutResponse: copy is starting from the server to the client.
|
||||
'H' => {
|
||||
self.in_copy_mode = true;
|
||||
self.data_available = true;
|
||||
break;
|
||||
}
|
||||
@@ -910,6 +921,115 @@ impl Server {
|
||||
Ok(bytes)
|
||||
}
|
||||
|
||||
/// Add the prepared statement to being tracked by this server.
|
||||
/// The client is processing data that will create a prepared statement on this server.
|
||||
pub fn will_prepare(&mut self, name: &str) {
|
||||
debug!("Will prepare `{}`", name);
|
||||
|
||||
self.prepared_statements.insert(name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
}
|
||||
|
||||
/// Check if we should prepare a statement on the server.
|
||||
pub fn should_prepare(&self, name: &str) -> bool {
|
||||
let should_prepare = !self.prepared_statements.contains(name);
|
||||
|
||||
debug!("Should prepare `{}`: {}", name, should_prepare);
|
||||
|
||||
if should_prepare {
|
||||
self.stats.prepared_cache_miss();
|
||||
} else {
|
||||
self.stats.prepared_cache_hit();
|
||||
}
|
||||
|
||||
should_prepare
|
||||
}
|
||||
|
||||
/// Create a prepared statement on the server.
|
||||
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
|
||||
debug!("Preparing `{}`", parse.name);
|
||||
|
||||
let bytes: BytesMut = parse.try_into()?;
|
||||
self.send(&bytes).await?;
|
||||
self.send(&flush()).await?;
|
||||
|
||||
// Read and discard ParseComplete (B)
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
self.prepared_statements.insert(parse.name.to_string());
|
||||
self.stats.prepared_cache_add();
|
||||
|
||||
debug!("Prepared `{}`", parse.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Maintain adequate cache size on the server.
|
||||
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
|
||||
debug!("Cache maintenance run");
|
||||
|
||||
let max_cache_size = get_prepared_statements_cache_size();
|
||||
let mut names = Vec::new();
|
||||
|
||||
while self.prepared_statements.len() >= max_cache_size {
|
||||
// The prepared statmeents are alphanumerically sorted by the BTree.
|
||||
// FIFO.
|
||||
if let Some(name) = self.prepared_statements.pop_last() {
|
||||
names.push(name);
|
||||
}
|
||||
}
|
||||
|
||||
self.deallocate(names).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the prepared statement from being tracked by this server.
|
||||
/// The client is processing data that will cause the server to close the prepared statement.
|
||||
pub fn will_close(&mut self, name: &str) {
|
||||
debug!("Will close `{}`", name);
|
||||
|
||||
self.prepared_statements.remove(name);
|
||||
}
|
||||
|
||||
/// Close a prepared statement on the server.
|
||||
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
|
||||
for name in &names {
|
||||
debug!("Deallocating prepared statement `{}`", name);
|
||||
|
||||
let close = Close::new(name);
|
||||
let bytes: BytesMut = close.try_into()?;
|
||||
|
||||
self.send(&bytes).await?;
|
||||
}
|
||||
|
||||
self.send(&flush()).await?;
|
||||
|
||||
// Read and discard CloseComplete (3)
|
||||
for name in &names {
|
||||
match read_message(&mut self.stream).await {
|
||||
Ok(_) => {
|
||||
self.prepared_statements.remove(name);
|
||||
self.stats.prepared_cache_remove();
|
||||
debug!("Closed `{}`", name);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
self.bad = true;
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// If the server is still inside a transaction.
|
||||
/// If the client disconnects while the server is in a transaction, we will clean it up.
|
||||
pub fn in_transaction(&self) -> bool {
|
||||
@@ -917,6 +1037,10 @@ impl Server {
|
||||
self.in_transaction
|
||||
}
|
||||
|
||||
pub fn in_copy_mode(&self) -> bool {
|
||||
self.in_copy_mode
|
||||
}
|
||||
|
||||
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
|
||||
/// The client is responsible to call `self.recv()` while this method returns true.
|
||||
pub fn is_data_available(&self) -> bool {
|
||||
@@ -1016,6 +1140,10 @@ impl Server {
|
||||
self.cleanup_state.reset();
|
||||
}
|
||||
|
||||
if self.in_copy_mode() {
|
||||
warn!("Server returned while still in copy-mode");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1059,16 +1187,14 @@ impl Server {
|
||||
}
|
||||
|
||||
pub fn mirror_send(&mut self, bytes: &BytesMut) {
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.send(bytes),
|
||||
None => (),
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.send(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mirror_disconnect(&mut self) {
|
||||
match self.mirror_manager.as_mut() {
|
||||
Some(manager) => manager.disconnect(),
|
||||
None => (),
|
||||
if let Some(manager) = self.mirror_manager.as_mut() {
|
||||
manager.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1096,7 +1222,7 @@ impl Server {
|
||||
server.send(&simple_query(query)).await?;
|
||||
let mut message = server.recv().await?;
|
||||
|
||||
Ok(parse_query_message(&mut message).await?)
|
||||
parse_query_message(&mut message).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Sharder {
|
||||
fn sha1(&self, key: i64) -> usize {
|
||||
let mut hasher = Sha1::new();
|
||||
|
||||
hasher.update(&key.to_string().as_bytes());
|
||||
hasher.update(key.to_string().as_bytes());
|
||||
|
||||
let result = hasher.finalize();
|
||||
|
||||
|
||||
26
src/stats.rs
26
src/stats.rs
@@ -1,4 +1,3 @@
|
||||
use crate::pool::PoolIdentifier;
|
||||
/// Statistics and reporting.
|
||||
use arc_swap::ArcSwap;
|
||||
|
||||
@@ -16,13 +15,11 @@ pub mod pool;
|
||||
pub mod server;
|
||||
pub use address::AddressStats;
|
||||
pub use client::{ClientState, ClientStats};
|
||||
pub use pool::PoolStats;
|
||||
pub use server::{ServerState, ServerStats};
|
||||
|
||||
/// Convenience types for various stats
|
||||
type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>;
|
||||
type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>;
|
||||
type PoolStatsLookup = HashMap<(String, String), Arc<PoolStats>>;
|
||||
|
||||
/// Stats for individual client connections
|
||||
/// Used in SHOW CLIENTS.
|
||||
@@ -34,11 +31,6 @@ static CLIENT_STATS: Lazy<Arc<RwLock<ClientStatesLookup>>> =
|
||||
static SERVER_STATS: Lazy<Arc<RwLock<ServerStatesLookup>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(ServerStatesLookup::default())));
|
||||
|
||||
/// Aggregate stats for each pool (a pool is identified by database name and username)
|
||||
/// Used in SHOW POOLS.
|
||||
static POOL_STATS: Lazy<Arc<RwLock<PoolStatsLookup>>> =
|
||||
Lazy::new(|| Arc::new(RwLock::new(PoolStatsLookup::default())));
|
||||
|
||||
/// The statistics reporter. An instance is given to each possible source of statistics,
|
||||
/// e.g. client stats, server stats, connection pool stats.
|
||||
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
|
||||
@@ -80,25 +72,17 @@ impl Reporter {
|
||||
fn server_disconnecting(&self, server_id: i32) {
|
||||
SERVER_STATS.write().remove(&server_id);
|
||||
}
|
||||
|
||||
/// Register a pool with the stats system.
|
||||
fn pool_register(&self, identifier: PoolIdentifier, stats: Arc<PoolStats>) {
|
||||
POOL_STATS
|
||||
.write()
|
||||
.insert((identifier.db, identifier.user), stats);
|
||||
}
|
||||
}
|
||||
|
||||
/// The statistics collector which used for calculating averages
|
||||
/// There is only one collector (kind of like a singleton)
|
||||
/// it updates averages every 15 seconds.
|
||||
#[derive(Default)]
|
||||
pub struct Collector {}
|
||||
pub struct Collector;
|
||||
|
||||
impl Collector {
|
||||
/// The statistics collection handler. It will collect statistics
|
||||
/// for `address_id`s starting at 0 up to `addresses`.
|
||||
pub async fn collect(&mut self) {
|
||||
pub fn collect() {
|
||||
info!("Events reporter started");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
@@ -139,12 +123,6 @@ pub fn get_server_stats() -> ServerStatesLookup {
|
||||
SERVER_STATS.read().clone()
|
||||
}
|
||||
|
||||
/// Get a snapshot of pool statistics.
|
||||
/// by the `Collector`.
|
||||
pub fn get_pool_stats() -> PoolStatsLookup {
|
||||
POOL_STATS.read().clone()
|
||||
}
|
||||
|
||||
/// Get the statistics reporter used to update stats across the pools/clients.
|
||||
pub fn get_reporter() -> Reporter {
|
||||
(*(*REPORTER.load())).clone()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use atomic_enum::atomic_enum;
|
||||
use std::sync::atomic::*;
|
||||
@@ -34,12 +33,14 @@ pub struct ClientStats {
|
||||
pool_name: String,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Total time spent waiting for a connection from pool, measures in microseconds
|
||||
pub total_wait_time: Arc<AtomicU64>,
|
||||
|
||||
/// Maximum time spent waiting for a connection from pool, measures in microseconds
|
||||
pub max_wait_time: Arc<AtomicU64>,
|
||||
|
||||
/// Current state of the client
|
||||
pub state: Arc<AtomicClientState>,
|
||||
|
||||
@@ -61,8 +62,8 @@ impl Default for ClientStats {
|
||||
application_name: String::new(),
|
||||
username: String::new(),
|
||||
pool_name: String::new(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
total_wait_time: Arc::new(AtomicU64::new(0)),
|
||||
max_wait_time: Arc::new(AtomicU64::new(0)),
|
||||
state: Arc::new(AtomicClientState::new(ClientState::Idle)),
|
||||
transaction_count: Arc::new(AtomicU64::new(0)),
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
@@ -79,11 +80,9 @@ impl ClientStats {
|
||||
username: &str,
|
||||
pool_name: &str,
|
||||
connect_time: Instant,
|
||||
pool_stats: Arc<PoolStats>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
application_name: application_name.to_string(),
|
||||
username: username.to_string(),
|
||||
@@ -96,8 +95,6 @@ impl ClientStats {
|
||||
/// update metrics on the corresponding pool.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.client_disconnecting(self.client_id);
|
||||
self.pool_stats
|
||||
.client_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Register a client with the stats system. The stats system uses client_id
|
||||
@@ -105,27 +102,20 @@ impl ClientStats {
|
||||
pub fn register(&self, stats: Arc<ClientStats>) {
|
||||
self.reporter.client_register(self.client_id, stats);
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
self.pool_stats.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done querying the server and is no longer assigned a server connection
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.client_idle(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is waiting for a connection
|
||||
pub fn waiting(&self) {
|
||||
self.pool_stats
|
||||
.client_waiting(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Waiting, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a client is done waiting for a connection and is about to query the server.
|
||||
pub fn active(&self) {
|
||||
self.pool_stats
|
||||
.client_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ClientState::Active, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
@@ -144,6 +134,8 @@ impl ClientStats {
|
||||
pub fn checkout_time(&self, microseconds: u64) {
|
||||
self.total_wait_time
|
||||
.fetch_add(microseconds, Ordering::Relaxed);
|
||||
self.max_wait_time
|
||||
.fetch_max(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
|
||||
@@ -1,36 +1,131 @@
|
||||
use crate::config::Pool;
|
||||
use crate::config::PoolMode;
|
||||
use crate::pool::PoolIdentifier;
|
||||
use std::sync::atomic::*;
|
||||
use std::sync::Arc;
|
||||
use log::debug;
|
||||
|
||||
use super::get_reporter;
|
||||
use super::Reporter;
|
||||
use super::{ClientState, ServerState};
|
||||
use crate::{config::PoolMode, messages::DataType, pool::PoolIdentifier};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::*;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
use crate::pool::get_all_pools;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// A struct that holds information about a Pool .
|
||||
pub struct PoolStats {
|
||||
// Pool identifier, cannot be changed after creating the instance
|
||||
identifier: PoolIdentifier,
|
||||
pub identifier: PoolIdentifier,
|
||||
pub mode: PoolMode,
|
||||
pub cl_idle: u64,
|
||||
pub cl_active: u64,
|
||||
pub cl_waiting: u64,
|
||||
pub cl_cancel_req: u64,
|
||||
pub sv_active: u64,
|
||||
pub sv_idle: u64,
|
||||
pub sv_used: u64,
|
||||
pub sv_tested: u64,
|
||||
pub sv_login: u64,
|
||||
pub maxwait: u64,
|
||||
}
|
||||
impl PoolStats {
|
||||
pub fn new(identifier: PoolIdentifier, mode: PoolMode) -> Self {
|
||||
PoolStats {
|
||||
identifier,
|
||||
mode,
|
||||
cl_idle: 0,
|
||||
cl_active: 0,
|
||||
cl_waiting: 0,
|
||||
cl_cancel_req: 0,
|
||||
sv_active: 0,
|
||||
sv_idle: 0,
|
||||
sv_used: 0,
|
||||
sv_tested: 0,
|
||||
sv_login: 0,
|
||||
maxwait: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Pool Config, cannot be changed after creating the instance
|
||||
config: Pool,
|
||||
pub fn construct_pool_lookup() -> HashMap<PoolIdentifier, PoolStats> {
|
||||
let mut map: HashMap<PoolIdentifier, PoolStats> = HashMap::new();
|
||||
let client_map = super::get_client_stats();
|
||||
let server_map = super::get_server_stats();
|
||||
|
||||
// A reference to the global reporter.
|
||||
reporter: Reporter,
|
||||
for (identifier, pool) in get_all_pools() {
|
||||
map.insert(
|
||||
identifier.clone(),
|
||||
PoolStats::new(identifier, pool.settings.pool_mode),
|
||||
);
|
||||
}
|
||||
|
||||
/// Counters (atomics)
|
||||
pub cl_idle: Arc<AtomicU64>,
|
||||
pub cl_active: Arc<AtomicU64>,
|
||||
pub cl_waiting: Arc<AtomicU64>,
|
||||
pub cl_cancel_req: Arc<AtomicU64>,
|
||||
pub sv_active: Arc<AtomicU64>,
|
||||
pub sv_idle: Arc<AtomicU64>,
|
||||
pub sv_used: Arc<AtomicU64>,
|
||||
pub sv_tested: Arc<AtomicU64>,
|
||||
pub sv_login: Arc<AtomicU64>,
|
||||
pub maxwait: Arc<AtomicU64>,
|
||||
for client in client_map.values() {
|
||||
match map.get_mut(&PoolIdentifier {
|
||||
db: client.pool_name(),
|
||||
user: client.username(),
|
||||
}) {
|
||||
Some(pool_stats) => {
|
||||
match client.state.load(Ordering::Relaxed) {
|
||||
ClientState::Active => pool_stats.cl_active += 1,
|
||||
ClientState::Idle => pool_stats.cl_idle += 1,
|
||||
ClientState::Waiting => pool_stats.cl_waiting += 1,
|
||||
}
|
||||
let max_wait = client.max_wait_time.load(Ordering::Relaxed);
|
||||
pool_stats.maxwait = std::cmp::max(pool_stats.maxwait, max_wait);
|
||||
}
|
||||
None => debug!("Client from an obselete pool"),
|
||||
}
|
||||
}
|
||||
|
||||
for server in server_map.values() {
|
||||
match map.get_mut(&PoolIdentifier {
|
||||
db: server.pool_name(),
|
||||
user: server.username(),
|
||||
}) {
|
||||
Some(pool_stats) => match server.state.load(Ordering::Relaxed) {
|
||||
ServerState::Active => pool_stats.sv_active += 1,
|
||||
ServerState::Idle => pool_stats.sv_idle += 1,
|
||||
ServerState::Login => pool_stats.sv_login += 1,
|
||||
ServerState::Tested => pool_stats.sv_tested += 1,
|
||||
},
|
||||
None => debug!("Server from an obselete pool"),
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
}
|
||||
|
||||
pub fn generate_header() -> Vec<(&'static str, DataType)> {
|
||||
vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("pool_mode", DataType::Text),
|
||||
("cl_idle", DataType::Numeric),
|
||||
("cl_active", DataType::Numeric),
|
||||
("cl_waiting", DataType::Numeric),
|
||||
("cl_cancel_req", DataType::Numeric),
|
||||
("sv_active", DataType::Numeric),
|
||||
("sv_idle", DataType::Numeric),
|
||||
("sv_used", DataType::Numeric),
|
||||
("sv_tested", DataType::Numeric),
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
]
|
||||
}
|
||||
|
||||
pub fn generate_row(&self) -> Vec<String> {
|
||||
vec![
|
||||
self.identifier.db.clone(),
|
||||
self.identifier.user.clone(),
|
||||
self.mode.to_string(),
|
||||
self.cl_idle.to_string(),
|
||||
self.cl_active.to_string(),
|
||||
self.cl_waiting.to_string(),
|
||||
self.cl_cancel_req.to_string(),
|
||||
self.sv_active.to_string(),
|
||||
self.sv_idle.to_string(),
|
||||
self.sv_used.to_string(),
|
||||
self.sv_tested.to_string(),
|
||||
self.sv_login.to_string(),
|
||||
(self.maxwait / 1_000_000).to_string(),
|
||||
(self.maxwait % 1_000_000).to_string(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for PoolStats {
|
||||
@@ -39,236 +134,18 @@ impl IntoIterator for PoolStats {
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
vec![
|
||||
("cl_idle".to_string(), self.cl_idle.load(Ordering::Relaxed)),
|
||||
(
|
||||
"cl_active".to_string(),
|
||||
self.cl_active.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_waiting".to_string(),
|
||||
self.cl_waiting.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"cl_cancel_req".to_string(),
|
||||
self.cl_cancel_req.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_active".to_string(),
|
||||
self.sv_active.load(Ordering::Relaxed),
|
||||
),
|
||||
("sv_idle".to_string(), self.sv_idle.load(Ordering::Relaxed)),
|
||||
("sv_used".to_string(), self.sv_used.load(Ordering::Relaxed)),
|
||||
(
|
||||
"sv_tested".to_string(),
|
||||
self.sv_tested.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"sv_login".to_string(),
|
||||
self.sv_login.load(Ordering::Relaxed),
|
||||
),
|
||||
(
|
||||
"maxwait".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) / 1_000_000,
|
||||
),
|
||||
(
|
||||
"maxwait_us".to_string(),
|
||||
self.maxwait.load(Ordering::Relaxed) % 1_000_000,
|
||||
),
|
||||
("cl_idle".to_string(), self.cl_idle),
|
||||
("cl_active".to_string(), self.cl_active),
|
||||
("cl_waiting".to_string(), self.cl_waiting),
|
||||
("cl_cancel_req".to_string(), self.cl_cancel_req),
|
||||
("sv_active".to_string(), self.sv_active),
|
||||
("sv_idle".to_string(), self.sv_idle),
|
||||
("sv_used".to_string(), self.sv_used),
|
||||
("sv_tested".to_string(), self.sv_tested),
|
||||
("sv_login".to_string(), self.sv_login),
|
||||
("maxwait".to_string(), self.maxwait / 1_000_000),
|
||||
("maxwait_us".to_string(), self.maxwait % 1_000_000),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl PoolStats {
|
||||
pub fn new(identifier: PoolIdentifier, config: Pool) -> Self {
|
||||
Self {
|
||||
identifier,
|
||||
config,
|
||||
reporter: get_reporter(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
// Getters
|
||||
pub fn register(&self, stats: Arc<PoolStats>) {
|
||||
self.reporter.pool_register(self.identifier.clone(), stats);
|
||||
}
|
||||
|
||||
pub fn database(&self) -> String {
|
||||
self.identifier.db.clone()
|
||||
}
|
||||
|
||||
pub fn user(&self) -> String {
|
||||
self.identifier.user.clone()
|
||||
}
|
||||
|
||||
pub fn pool_mode(&self) -> PoolMode {
|
||||
self.config.pool_mode
|
||||
}
|
||||
|
||||
/// Populates an array of strings with counters (used by admin in show pools)
|
||||
pub fn populate_row(&self, row: &mut Vec<String>) {
|
||||
for (_key, value) in self.clone() {
|
||||
row.push(value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
/// Deletes the maxwait counter, this is done everytime we obtain metrics
|
||||
pub fn clear_maxwait(&self) {
|
||||
self.maxwait.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool enters login state.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_login(&self, from: ServerState) {
|
||||
self.sv_login.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Login {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_active(&self, from: ServerState) {
|
||||
self.sv_active.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Active {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'tested'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_tested(&self, from: ServerState) {
|
||||
self.sv_tested.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Tested {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a server of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the server that notifies.
|
||||
pub fn server_idle(&self, from: ServerState) {
|
||||
self.sv_idle.fetch_add(1, Ordering::Relaxed);
|
||||
if from != ServerState::Idle {
|
||||
self.decrease_from_server_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'waiting'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_waiting(&self, from: ClientState) {
|
||||
if from != ClientState::Waiting {
|
||||
self.cl_waiting.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'active'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_active(&self, from: ClientState) {
|
||||
if from != ClientState::Active {
|
||||
self.cl_active.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client of the pool become 'idle'
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_idle(&self, from: ClientState) {
|
||||
if from != ClientState::Idle {
|
||||
self.cl_idle.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrease_from_client_state(from);
|
||||
}
|
||||
}
|
||||
|
||||
/// Notified when a client disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn client_disconnect(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
ClientState::Active => &self.cl_active,
|
||||
};
|
||||
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
/// Notified when a server disconnects.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// `from`: The state of the client that notifies.
|
||||
pub fn server_disconnect(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
// helpers for counter decrease
|
||||
fn decrease_from_server_state(&self, from: ServerState) {
|
||||
let counter = match from {
|
||||
ServerState::Tested => &self.sv_tested,
|
||||
ServerState::Active => &self.sv_active,
|
||||
ServerState::Idle => &self.sv_idle,
|
||||
ServerState::Login => &self.sv_login,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_from_client_state(&self, from: ClientState) {
|
||||
let counter = match from {
|
||||
ClientState::Active => &self.cl_active,
|
||||
ClientState::Idle => &self.cl_idle,
|
||||
ClientState::Waiting => &self.cl_waiting,
|
||||
};
|
||||
Self::decrease_counter(counter.clone());
|
||||
}
|
||||
|
||||
fn decrease_counter(value: Arc<AtomicU64>) {
|
||||
if value.load(Ordering::Relaxed) > 0 {
|
||||
value.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decrease() {
|
||||
let stat: PoolStats = PoolStats::default();
|
||||
stat.server_login(ServerState::Login);
|
||||
stat.server_idle(ServerState::Login);
|
||||
assert_eq!(stat.sv_login.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(stat.sv_idle.load(Ordering::Relaxed), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use super::AddressStats;
|
||||
use super::PoolStats;
|
||||
use super::{get_reporter, Reporter};
|
||||
use crate::config::Address;
|
||||
use atomic_enum::atomic_enum;
|
||||
@@ -38,7 +37,6 @@ pub struct ServerStats {
|
||||
address: Address,
|
||||
connect_time: Instant,
|
||||
|
||||
pool_stats: Arc<PoolStats>,
|
||||
reporter: Reporter,
|
||||
|
||||
/// Data
|
||||
@@ -49,6 +47,9 @@ pub struct ServerStats {
|
||||
pub transaction_count: Arc<AtomicU64>,
|
||||
pub query_count: Arc<AtomicU64>,
|
||||
pub error_count: Arc<AtomicU64>,
|
||||
pub prepared_hit_count: Arc<AtomicU64>,
|
||||
pub prepared_miss_count: Arc<AtomicU64>,
|
||||
pub prepared_cache_size: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl Default for ServerStats {
|
||||
@@ -57,7 +58,6 @@ impl Default for ServerStats {
|
||||
server_id: 0,
|
||||
application_name: Arc::new(RwLock::new(String::new())),
|
||||
address: Address::default(),
|
||||
pool_stats: Arc::new(PoolStats::default()),
|
||||
connect_time: Instant::now(),
|
||||
state: Arc::new(AtomicServerState::new(ServerState::Login)),
|
||||
bytes_sent: Arc::new(AtomicU64::new(0)),
|
||||
@@ -66,15 +66,17 @@ impl Default for ServerStats {
|
||||
query_count: Arc::new(AtomicU64::new(0)),
|
||||
error_count: Arc::new(AtomicU64::new(0)),
|
||||
reporter: get_reporter(),
|
||||
prepared_hit_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_miss_count: Arc::new(AtomicU64::new(0)),
|
||||
prepared_cache_size: Arc::new(AtomicU64::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerStats {
|
||||
pub fn new(address: Address, pool_stats: Arc<PoolStats>, connect_time: Instant) -> Self {
|
||||
pub fn new(address: Address, connect_time: Instant) -> Self {
|
||||
Self {
|
||||
address,
|
||||
pool_stats,
|
||||
connect_time,
|
||||
server_id: rand::random::<i32>(),
|
||||
..Default::default()
|
||||
@@ -96,9 +98,6 @@ impl ServerStats {
|
||||
/// Reports a server connection is no longer assigned to a client
|
||||
/// and is available for the next client to pick it up
|
||||
pub fn idle(&self) {
|
||||
self.pool_stats
|
||||
.server_idle(self.state.load(Ordering::Relaxed));
|
||||
|
||||
self.state.store(ServerState::Idle, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
@@ -106,22 +105,16 @@ impl ServerStats {
|
||||
/// Also updates metrics on the pool regarding server usage.
|
||||
pub fn disconnect(&self) {
|
||||
self.reporter.server_disconnecting(self.server_id);
|
||||
self.pool_stats
|
||||
.server_disconnect(self.state.load(Ordering::Relaxed))
|
||||
}
|
||||
|
||||
/// Reports a server connection is being tested before being given to a client.
|
||||
pub fn tested(&self) {
|
||||
self.set_undefined_application();
|
||||
self.pool_stats
|
||||
.server_tested(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Tested, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Reports a server connection is attempting to login.
|
||||
pub fn login(&self) {
|
||||
self.pool_stats
|
||||
.server_login(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Login, Ordering::Relaxed);
|
||||
self.set_undefined_application();
|
||||
}
|
||||
@@ -129,8 +122,6 @@ impl ServerStats {
|
||||
/// Reports a server connection has been assigned to a client that
|
||||
/// is about to query the server
|
||||
pub fn active(&self, application_name: String) {
|
||||
self.pool_stats
|
||||
.server_active(self.state.load(Ordering::Relaxed));
|
||||
self.state.store(ServerState::Active, Ordering::Relaxed);
|
||||
self.set_application(application_name);
|
||||
}
|
||||
@@ -152,11 +143,11 @@ impl ServerStats {
|
||||
|
||||
// Helper methods for show_servers
|
||||
pub fn pool_name(&self) -> String {
|
||||
self.pool_stats.database()
|
||||
self.address.pool_name.clone()
|
||||
}
|
||||
|
||||
pub fn username(&self) -> String {
|
||||
self.pool_stats.user()
|
||||
self.address.username.clone()
|
||||
}
|
||||
|
||||
pub fn address_name(&self) -> String {
|
||||
@@ -180,9 +171,6 @@ impl ServerStats {
|
||||
// Update server stats and address aggregation stats
|
||||
self.set_application(application_name);
|
||||
self.address.stats.wait_time_add(microseconds);
|
||||
self.pool_stats
|
||||
.maxwait
|
||||
.fetch_max(microseconds, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a query executed by a client against a server
|
||||
@@ -190,6 +178,7 @@ impl ServerStats {
|
||||
self.set_application(application_name.to_string());
|
||||
self.address.stats.query_count_add();
|
||||
self.address.stats.query_time_add(milliseconds);
|
||||
self.query_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a transaction executed by a client a server
|
||||
@@ -216,4 +205,22 @@ impl ServerStats {
|
||||
.fetch_add(amount_bytes as u64, Ordering::Relaxed);
|
||||
self.address.stats.bytes_received_add(amount_bytes as u64);
|
||||
}
|
||||
|
||||
/// Report a prepared statement that already exists on the server.
|
||||
pub fn prepared_cache_hit(&self) {
|
||||
self.prepared_hit_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Report a prepared statement that does not exist on the server yet.
|
||||
pub fn prepared_cache_miss(&self) {
|
||||
self.prepared_miss_count.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn prepared_cache_add(&self) {
|
||||
self.prepared_cache_size.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn prepared_cache_remove(&self) {
|
||||
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
24
src/tls.rs
24
src/tls.rs
@@ -44,25 +44,17 @@ impl Tls {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let config = get_config();
|
||||
|
||||
let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
Ok(certs) => certs,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
let certs = load_certs(Path::new(&config.general.tls_certificate.unwrap()))
|
||||
.map_err(|_| Error::TlsError)?;
|
||||
let key_der = load_keys(Path::new(&config.general.tls_private_key.unwrap()))
|
||||
.map_err(|_| Error::TlsError)?
|
||||
.remove(0);
|
||||
|
||||
let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
Ok(keys) => keys,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let config = match rustls::ServerConfig::builder()
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, keys.remove(0))
|
||||
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
.with_single_cert(certs, key_der)
|
||||
.map_err(|_| Error::TlsError)?;
|
||||
|
||||
Ok(Tls {
|
||||
acceptor: TlsAcceptor::from(Arc::new(config)),
|
||||
|
||||
@@ -63,6 +63,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
|
||||
|
||||
|
||||
def test_normal_db_access():
|
||||
pgcat_start()
|
||||
conn, cur = connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
|
||||
@@ -11,326 +11,6 @@ describe "Admin" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "SHOW STATS" do
|
||||
context "clients connect and make one query" do
|
||||
it "updates *_query_time and *_wait_time" do
|
||||
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
|
||||
end
|
||||
sleep(1)
|
||||
connections.map(&:close)
|
||||
|
||||
# wait for averages to be calculated, we shouldn't do this too often
|
||||
sleep(15.5)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW STATS")[0]
|
||||
admin_conn.close
|
||||
expect(results["total_query_time"].to_i).to be_within(200).of(750)
|
||||
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
|
||||
|
||||
expect(results["total_wait_time"].to_i).to_not eq(0)
|
||||
expect(results["avg_wait_time"].to_i).to_not eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "bad database name" do
|
||||
it "does not change any stats" do
|
||||
bad_db_url = URI(pgcat_conn_str)
|
||||
bad_db_url.path = "/wrong_db"
|
||||
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connect and make one query" do
|
||||
it "only affects cl_idle, sv_idle stats" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_active"]).to eq("5")
|
||||
expect(results["sv_active"]).to eq("5")
|
||||
|
||||
sleep(3)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects and opens a transaction and closes connection uncleanly" do
|
||||
it "produces correct statistics" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new do
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT pg_sleep(0.01)")
|
||||
c.close
|
||||
end
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client fail to checkout connection from the pool" do
|
||||
it "counts clients as idle" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
threads = []
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
|
||||
end
|
||||
|
||||
sleep(2)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect normally" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_between = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).not_to eq(clients_between)
|
||||
connections.each(&:close)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect abruptly" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
threads = []
|
||||
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
random_string = (0...8).map { (65 + rand(26)).chr }.join
|
||||
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
|
||||
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
|
||||
sleep(1)
|
||||
# psql starts two processes, we only know the pid of the parent, this
|
||||
# ensure both are killed
|
||||
`pkill -9 -f '#{random_string}'`
|
||||
Process.wait(faulty_client)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients overwhelm server pools" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it "cl_waiting is updated to show it" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["cl_waiting"]).to eq("2")
|
||||
expect(results["cl_active"]).to eq("2")
|
||||
expect(results["sv_active"]).to eq("2")
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("4")
|
||||
expect(results["sv_idle"]).to eq("2")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
|
||||
it "show correct max_wait" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
|
||||
expect(results["maxwait"]).to eq("1")
|
||||
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
|
||||
|
||||
sleep(4.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
expect(results["maxwait"]).to eq("0")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW CLIENTS" do
|
||||
it "reports correct number and application names" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(21) # count admin clients
|
||||
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
|
||||
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
|
||||
|
||||
connections[0..5].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(15)
|
||||
|
||||
connections[6..].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
|
||||
admin_conn.close
|
||||
end
|
||||
|
||||
it "reports correct number of queries and transactions" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
|
||||
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
|
||||
connections.each do |c|
|
||||
c.async_exec("SELECT 1")
|
||||
c.async_exec("SELECT 2")
|
||||
c.async_exec("SELECT 3")
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT 4")
|
||||
c.async_exec("SELECT 5")
|
||||
c.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(3)
|
||||
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
|
||||
expect(normal_client_results[0]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[1]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[0]["query_count"]).to eq("7")
|
||||
expect(normal_client_results[1]["query_count"]).to eq("7")
|
||||
|
||||
admin_conn.close
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
describe "Manual Banning" do
|
||||
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 10) }
|
||||
before do
|
||||
@@ -401,7 +81,7 @@ describe "Admin" do
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW users" do
|
||||
describe "SHOW USERS" do
|
||||
it "returns the right users" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW USERS")[0]
|
||||
|
||||
102
tests/ruby/copy_spec.rb
Normal file
102
tests/ruby/copy_spec.rb
Normal file
@@ -0,0 +1,102 @@
|
||||
# frozen_string_literal: true
|
||||
require_relative 'spec_helper'
|
||||
|
||||
|
||||
describe "COPY Handling" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
|
||||
before do
|
||||
new_configs = processes.pgcat.current_config
|
||||
|
||||
# Allow connections in the pool to expire faster
|
||||
new_configs["general"]["idle_timeout"] = 5
|
||||
processes.pgcat.update_config(new_configs)
|
||||
# We need to kill the old process that was using the default configs
|
||||
processes.pgcat.stop
|
||||
processes.pgcat.start
|
||||
processes.pgcat.wait_until_ready
|
||||
end
|
||||
|
||||
before do
|
||||
processes.all_databases.first.with_connection do |conn|
|
||||
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
|
||||
end
|
||||
end
|
||||
|
||||
after do
|
||||
processes.all_databases.first.with_connection do |conn|
|
||||
conn.async_exec "DROP TABLE copy_test_table;"
|
||||
end
|
||||
end
|
||||
|
||||
after do
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "COPY FROM" do
|
||||
context "within transaction" do
|
||||
it "finishes within alloted time" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
Timeout.timeout(3) do
|
||||
conn.async_exec("BEGIN")
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
sleep 0.5
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
conn.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
|
||||
expect(res).to eq([
|
||||
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
|
||||
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
|
||||
])
|
||||
end
|
||||
end
|
||||
|
||||
context "outside transaction" do
|
||||
it "finishes within alloted time" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
Timeout.timeout(3) do
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
sleep 0.5
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
end
|
||||
|
||||
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
|
||||
expect(res).to eq([
|
||||
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
|
||||
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
|
||||
])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "COPY TO" do
|
||||
before do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("BEGIN")
|
||||
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
|
||||
conn.put_copy_data "some,data,to,copy\n"
|
||||
conn.put_copy_data "more,data,to,copy\n"
|
||||
end
|
||||
conn.async_exec("COMMIT")
|
||||
conn.close
|
||||
end
|
||||
|
||||
it "works" do
|
||||
res = []
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
|
||||
while row=conn.get_copy_data
|
||||
res << row
|
||||
end
|
||||
end
|
||||
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
29
tests/ruby/prepared_spec.rb
Normal file
29
tests/ruby/prepared_spec.rb
Normal file
@@ -0,0 +1,29 @@
|
||||
require_relative 'spec_helper'
|
||||
|
||||
describe 'Prepared statements' do
|
||||
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
|
||||
|
||||
context 'enabled' do
|
||||
it 'will work over the same connection' do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
10.times do |i|
|
||||
statement_name = "statement_#{i}"
|
||||
conn.prepare(statement_name, 'SELECT $1::int')
|
||||
conn.exec_prepared(statement_name, [1])
|
||||
conn.describe_prepared(statement_name)
|
||||
end
|
||||
end
|
||||
|
||||
it 'will work with new connections' do
|
||||
10.times do
|
||||
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
|
||||
|
||||
statement_name = 'statement1'
|
||||
conn.prepare('statement1', 'SELECT $1::int')
|
||||
conn.exec_prepared('statement1', [1])
|
||||
conn.describe_prepared('statement1')
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
369
tests/ruby/stats_spec.rb
Normal file
369
tests/ruby/stats_spec.rb
Normal file
@@ -0,0 +1,369 @@
|
||||
# frozen_string_literal: true
|
||||
require 'open3'
|
||||
require_relative 'spec_helper'
|
||||
|
||||
describe "Stats" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
let(:pgcat_conn_str) { processes.pgcat.connection_string("sharded_db", "sharding_user") }
|
||||
|
||||
after do
|
||||
processes.all_databases.map(&:reset)
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
describe "SHOW STATS" do
|
||||
context "clients connect and make one query" do
|
||||
it "updates *_query_time and *_wait_time" do
|
||||
connections = Array.new(3) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(0.25)") }
|
||||
end
|
||||
sleep(1)
|
||||
connections.map(&:close)
|
||||
|
||||
# wait for averages to be calculated, we shouldn't do this too often
|
||||
sleep(15.5)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW STATS")[0]
|
||||
admin_conn.close
|
||||
expect(results["total_query_time"].to_i).to be_within(200).of(750)
|
||||
expect(results["avg_query_time"].to_i).to be_within(50).of(250)
|
||||
|
||||
expect(results["total_wait_time"].to_i).to_not eq(0)
|
||||
expect(results["avg_wait_time"].to_i).to_not eq(0)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW POOLS" do
|
||||
context "bad credentials" do
|
||||
it "does not change any stats" do
|
||||
bad_password_url = URI(pgcat_conn_str)
|
||||
bad_password_url.password = "wrong"
|
||||
expect { PG::connect("#{bad_password_url.to_s}?application_name=bad_password") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "bad database name" do
|
||||
it "does not change any stats" do
|
||||
bad_db_url = URI(pgcat_conn_str)
|
||||
bad_db_url.path = "/wrong_db"
|
||||
expect { PG::connect("#{bad_db_url.to_s}?application_name=bad_db") }.to raise_error(PG::ConnectionBad)
|
||||
|
||||
sleep(1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects but issues no queries" do
|
||||
it "only affects cl_idle stats" do
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
|
||||
before_test = admin_conn.async_exec("SHOW POOLS")[0]["sv_idle"]
|
||||
connections = Array.new(20) { PG::connect(pgcat_conn_str) }
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("20")
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1.1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_idle cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq(before_test)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connect and make one query" do
|
||||
it "only affects cl_idle, sv_idle stats" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new { c.async_exec("SELECT pg_sleep(2.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_active"]).to eq("5")
|
||||
expect(results["sv_active"]).to eq("5")
|
||||
|
||||
sleep(3)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
|
||||
connections.map(&:close)
|
||||
sleep(1)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client connects and opens a transaction and closes connection uncleanly" do
|
||||
it "produces correct statistics" do
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
Thread.new do
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT pg_sleep(0.01)")
|
||||
c.close
|
||||
end
|
||||
end
|
||||
|
||||
sleep(1.1)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["sv_idle"]).to eq("5")
|
||||
end
|
||||
end
|
||||
|
||||
context "client fail to checkout connection from the pool" do
|
||||
it "counts clients as idle" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
threads = []
|
||||
connections = Array.new(5) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1)") rescue PG::SystemError }
|
||||
end
|
||||
|
||||
sleep(2)
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("5")
|
||||
expect(results["sv_idle"]).to eq("1")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect normally" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") rescue nil }
|
||||
end
|
||||
clients_between = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).not_to eq(clients_between)
|
||||
connections.each(&:close)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients connects and disconnect abruptly" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 10) }
|
||||
|
||||
it 'shows the same number of clients before and after' do
|
||||
threads = []
|
||||
connections = Array.new(2) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT 1") }
|
||||
end
|
||||
clients_before = clients_connected_to_pool(processes: processes)
|
||||
random_string = (0...8).map { (65 + rand(26)).chr }.join
|
||||
connection_string = "#{pgcat_conn_str}?application_name=#{random_string}"
|
||||
faulty_client = Process.spawn("psql -Atx #{connection_string} >/dev/null")
|
||||
sleep(1)
|
||||
# psql starts two processes, we only know the pid of the parent, this
|
||||
# ensure both are killed
|
||||
`pkill -9 -f '#{random_string}'`
|
||||
Process.wait(faulty_client)
|
||||
clients_after = clients_connected_to_pool(processes: processes)
|
||||
expect(clients_before).to eq(clients_after)
|
||||
end
|
||||
end
|
||||
|
||||
context "clients overwhelm server pools" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 2) }
|
||||
|
||||
it "cl_waiting is updated to show it" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") }
|
||||
end
|
||||
|
||||
sleep(1.1) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_cancel_req sv_idle sv_used sv_tested sv_login maxwait].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
expect(results["cl_waiting"]).to eq("2")
|
||||
expect(results["cl_active"]).to eq("2")
|
||||
expect(results["sv_active"]).to eq("2")
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_active cl_waiting cl_cancel_req sv_active sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
expect(results["cl_idle"]).to eq("4")
|
||||
expect(results["sv_idle"]).to eq("2")
|
||||
|
||||
threads.map(&:join)
|
||||
connections.map(&:close)
|
||||
end
|
||||
|
||||
it "show correct max_wait" do
|
||||
threads = []
|
||||
connections = Array.new(4) { PG::connect("#{pgcat_conn_str}?application_name=one_query") }
|
||||
connections.each do |c|
|
||||
threads << Thread.new { c.async_exec("SELECT pg_sleep(1.5)") rescue nil }
|
||||
end
|
||||
|
||||
sleep(2.5) # Allow time for stats to update
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
|
||||
expect(results["maxwait"]).to eq("1")
|
||||
expect(results["maxwait_us"].to_i).to be_within(200_000).of(500_000)
|
||||
connections.map(&:close)
|
||||
|
||||
sleep(4.5) # Allow time for stats to update
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
expect(results["maxwait"]).to eq("0")
|
||||
|
||||
threads.map(&:join)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "SHOW CLIENTS" do
|
||||
it "reports correct number and application names" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
connections = Array.new(20) { |i| PG::connect("#{conn_str}?application_name=app#{i % 5}") }
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(21) # count admin clients
|
||||
expect(results.select { |c| c["application_name"] == "app3" || c["application_name"] == "app4" }.count).to eq(8)
|
||||
expect(results.select { |c| c["database"] == "pgcat" }.count).to eq(1)
|
||||
|
||||
connections[0..5].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(15)
|
||||
|
||||
connections[6..].map(&:close)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
expect(admin_conn.async_exec("SHOW CLIENTS").count).to eq(1)
|
||||
admin_conn.close
|
||||
end
|
||||
|
||||
it "reports correct number of queries and transactions" do
|
||||
conn_str = processes.pgcat.connection_string("sharded_db", "sharding_user")
|
||||
|
||||
connections = Array.new(2) { |i| PG::connect("#{conn_str}?application_name=app#{i}") }
|
||||
connections.each do |c|
|
||||
c.async_exec("SELECT 1")
|
||||
c.async_exec("SELECT 2")
|
||||
c.async_exec("SELECT 3")
|
||||
c.async_exec("BEGIN")
|
||||
c.async_exec("SELECT 4")
|
||||
c.async_exec("SELECT 5")
|
||||
c.async_exec("COMMIT")
|
||||
end
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
sleep(1) # Wait for stats to be updated
|
||||
|
||||
results = admin_conn.async_exec("SHOW CLIENTS")
|
||||
expect(results.count).to eq(3)
|
||||
normal_client_results = results.reject { |r| r["database"] == "pgcat" }
|
||||
expect(normal_client_results[0]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[1]["transaction_count"]).to eq("4")
|
||||
expect(normal_client_results[0]["query_count"]).to eq("7")
|
||||
expect(normal_client_results[1]["query_count"]).to eq("7")
|
||||
|
||||
admin_conn.close
|
||||
connections.map(&:close)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
describe "Query Storm" do
|
||||
context "when the proxy receives overwhelmingly large number of short quick queries" do
|
||||
it "should not have lingering clients or active servers" do
|
||||
new_configs = processes.pgcat.current_config
|
||||
|
||||
new_configs["general"]["connect_timeout"] = 500
|
||||
new_configs["general"]["ban_time"] = 1
|
||||
new_configs["general"]["shutdown_timeout"] = 1
|
||||
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
|
||||
processes.pgcat.update_config(new_configs)
|
||||
processes.pgcat.reload_config
|
||||
|
||||
Array.new(40) do
|
||||
Thread.new do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
conn.async_exec("SELECT pg_sleep(0.1)")
|
||||
rescue PG::SystemError
|
||||
ensure
|
||||
conn.close
|
||||
end
|
||||
end.each(&:join)
|
||||
|
||||
sleep 1
|
||||
|
||||
admin_conn = PG::connect(processes.pgcat.admin_connection_string)
|
||||
results = admin_conn.async_exec("SHOW POOLS")[0]
|
||||
%w[cl_idle cl_waiting cl_cancel_req sv_used sv_tested sv_login].each do |s|
|
||||
raise StandardError, "Field #{s} was expected to be 0 but found to be #{results[s]}" if results[s] != "0"
|
||||
end
|
||||
|
||||
admin_conn.close
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
1
tests/rust/.gitignore
vendored
Normal file
1
tests/rust/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
target/
|
||||
1322
tests/rust/Cargo.lock
generated
Normal file
1322
tests/rust/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
10
tests/rust/Cargo.toml
Normal file
10
tests/rust/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "rust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
sqlx = { version = "0.6.2", features = [ "runtime-tokio-rustls", "postgres", "json", "tls", "migrate", "time", "uuid", "ipnetwork"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
29
tests/rust/src/main.rs
Normal file
29
tests/rust/src/main.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
test_prepared_statements().await;
|
||||
}
|
||||
|
||||
async fn test_prepared_statements() {
|
||||
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for _ in 0..5 {
|
||||
let pool = pool.clone();
|
||||
let handle = tokio::task::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user